use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::error::{CognisError, Result};
use cognis_core::language_models::chat_model::{
BaseChatModel, ChatStream, ModelProfile, ToolChoice,
};
use cognis_core::messages::{AIMessage, Message};
use cognis_core::outputs::{ChatGeneration, ChatResult};
use cognis_core::tools::ToolSchema;
/// Decorator around a [`BaseChatModel`] that converts generation errors into a
/// canned fallback response instead of propagating them to the caller.
pub struct GracefulChatModel {
    // The wrapped model that performs the actual generation.
    inner: Box<dyn BaseChatModel>,
    // Text returned (as a single AI generation) whenever `inner` fails.
    fallback_message: String,
    // Optional observer invoked with each swallowed error (e.g. logging/metrics).
    #[allow(clippy::type_complexity)]
    on_error: Option<Arc<dyn Fn(&CognisError) + Send + Sync>>,
}
impl GracefulChatModel {
    /// Wraps `inner` so that generation errors are replaced with
    /// `fallback_message` instead of being surfaced to the caller.
    pub fn new(inner: Box<dyn BaseChatModel>, fallback_message: String) -> Self {
        Self {
            on_error: None,
            fallback_message,
            inner,
        }
    }

    /// Registers a callback that is invoked with every swallowed error
    /// (useful for logging or metrics). Builder-style: consumes `self`.
    pub fn with_on_error<F>(mut self, callback: F) -> Self
    where
        F: Fn(&CognisError) + Send + Sync + 'static,
    {
        self.on_error = Some(Arc::new(callback));
        self
    }

    /// Builds the canned single-generation `ChatResult` used when the
    /// wrapped model fails.
    fn fallback_result(&self) -> ChatResult {
        let text = self.fallback_message.clone();
        let generation = ChatGeneration {
            message: Message::Ai(AIMessage::new(&text)),
            text,
            generation_info: None,
        };
        ChatResult {
            generations: vec![generation],
            llm_output: None,
        }
    }
}
#[async_trait]
impl BaseChatModel for GracefulChatModel {
async fn _generate(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatResult> {
match self.inner._generate(messages, stop).await {
Ok(result) => Ok(result),
Err(e) => {
if let Some(ref callback) = self.on_error {
callback(&e);
}
Ok(self.fallback_result())
}
}
}
fn llm_type(&self) -> &str {
self.inner.llm_type()
}
async fn _stream(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatStream> {
match self.inner._stream(messages, stop).await {
Ok(stream) => Ok(stream),
Err(e) => {
if let Some(ref callback) = self.on_error {
callback(&e);
}
use cognis_core::messages::AIMessageChunk;
use cognis_core::outputs::ChatGenerationChunk;
use futures::stream;
let chunk = ChatGenerationChunk {
text: self.fallback_message.clone(),
message: AIMessageChunk::new(&self.fallback_message),
generation_info: None,
};
Ok(Box::pin(stream::once(async move { Ok(chunk) })))
}
}
}
fn bind_tools(
&self,
tools: &[ToolSchema],
tool_choice: Option<ToolChoice>,
) -> Result<Box<dyn BaseChatModel>> {
self.inner.bind_tools(tools, tool_choice)
}
fn profile(&self) -> ModelProfile {
self.inner.profile()
}
fn get_num_tokens_from_messages(&self, messages: &[Message]) -> usize {
self.inner.get_num_tokens_from_messages(messages)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::messages::HumanMessage;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Stub model that always produces one fixed successful generation.
    struct SuccessModel;

    #[async_trait]
    impl BaseChatModel for SuccessModel {
        async fn _generate(
            &self,
            _messages: &[Message],
            _stop: Option<&[String]>,
        ) -> Result<ChatResult> {
            let generation = ChatGeneration {
                text: "Real response".into(),
                message: Message::Ai(AIMessage::new("Real response")),
                generation_info: None,
            };
            Ok(ChatResult {
                generations: vec![generation],
                llm_output: None,
            })
        }

        fn llm_type(&self) -> &str {
            "success_mock"
        }
    }

    /// Stub model that always fails with an HTTP 500 error.
    struct FailModel;

    #[async_trait]
    impl BaseChatModel for FailModel {
        async fn _generate(
            &self,
            _messages: &[Message],
            _stop: Option<&[String]>,
        ) -> Result<ChatResult> {
            Err(CognisError::HttpError {
                status: 500,
                body: "Internal Server Error".into(),
            })
        }

        fn llm_type(&self) -> &str {
            "fail_mock"
        }
    }

    /// Shared single-message prompt used by every test below.
    fn prompt() -> Vec<Message> {
        vec![Message::Human(HumanMessage::new("hi"))]
    }

    #[tokio::test]
    async fn test_graceful_passes_through_on_success() {
        let wrapped = GracefulChatModel::new(Box::new(SuccessModel), "Fallback message".into());
        let output = wrapped
            ._generate(&prompt(), None)
            .await
            .expect("success path must yield Ok");
        assert_eq!(output.generations[0].text, "Real response");
    }

    #[tokio::test]
    async fn test_graceful_returns_fallback_on_error() {
        let wrapped =
            GracefulChatModel::new(Box::new(FailModel), "Sorry, service unavailable".into());
        let output = wrapped
            ._generate(&prompt(), None)
            .await
            .expect("fallback path must yield Ok");
        assert_eq!(output.generations[0].text, "Sorry, service unavailable");
    }

    #[tokio::test]
    async fn test_graceful_calls_on_error_callback() {
        let seen = Arc::new(AtomicBool::new(false));
        let seen_in_hook = Arc::clone(&seen);
        let wrapped = GracefulChatModel::new(Box::new(FailModel), "Fallback".into())
            .with_on_error(move |_err| seen_in_hook.store(true, Ordering::SeqCst));
        let outcome = wrapped._generate(&prompt(), None).await;
        assert!(outcome.is_ok());
        assert!(
            seen.load(Ordering::SeqCst),
            "on_error callback should have been called"
        );
    }
}