use std::sync::Arc;
use tracing::{info, warn};
use converge_core::traits::{
BoxFuture, ChatBackend, ChatRequest, ChatResponse, DynChatBackend, LlmError,
};
pub struct ResilientChatBackend {
primary: Arc<dyn DynChatBackend>,
fallback: Option<Arc<dyn DynChatBackend>>,
primary_label: String,
fallback_label: String,
}
impl ResilientChatBackend {
#[must_use]
pub fn new(primary: Arc<dyn DynChatBackend>, label: impl Into<String>) -> Self {
Self {
primary,
fallback: None,
primary_label: label.into(),
fallback_label: String::new(),
}
}
#[must_use]
pub fn with_fallback(
mut self,
fallback: Arc<dyn DynChatBackend>,
label: impl Into<String>,
) -> Self {
self.fallback = Some(fallback);
self.fallback_label = label.into();
self
}
async fn chat_async(&self, req: ChatRequest) -> Result<ChatResponse, LlmError> {
let original_format = req.response_format;
match self.primary.chat(req.clone()).await {
Ok(response) => return Ok(response),
Err(e) if is_retryable_with_format_change(&e) => {
if let Some(fallback_format) = original_format.fallback() {
warn!(
primary = %self.primary_label,
original_format = ?original_format,
fallback_format = ?fallback_format,
"Format failure, retrying with fallback format"
);
let mut retry_req = req.clone();
retry_req.response_format = fallback_format;
match self.primary.chat(retry_req).await {
Ok(response) => return Ok(response),
Err(retry_err) => return Err(retry_err),
}
}
return Err(e);
}
Err(e) if is_retryable_with_model_change(&e) => {
if let Some(fallback) = &self.fallback {
warn!(
primary = %self.primary_label,
fallback = %self.fallback_label,
error = %e,
"Model failure, retrying with fallback backend"
);
match fallback.chat(req.clone()).await {
Ok(response) => {
info!(
fallback = %self.fallback_label,
"Fallback backend succeeded"
);
return Ok(response);
}
Err(fallback_err) => {
warn!(
fallback = %self.fallback_label,
error = %fallback_err,
"Fallback backend also failed"
);
return Err(e);
}
}
}
return Err(e);
}
Err(e) => return Err(e),
}
}
}
impl ChatBackend for ResilientChatBackend {
type ChatFut<'a>
= BoxFuture<'a, Result<ChatResponse, LlmError>>
where
Self: 'a;
fn chat(&self, req: ChatRequest) -> Self::ChatFut<'_> {
Box::pin(async move { self.chat_async(req).await })
}
}
fn is_retryable_with_format_change(error: &LlmError) -> bool {
matches!(
error,
LlmError::InvalidRequest { .. }
| LlmError::ContentFiltered { .. }
| LlmError::ResponseFormatMismatch { .. }
)
}
fn is_retryable_with_model_change(error: &LlmError) -> bool {
matches!(
error,
LlmError::RateLimited { .. }
| LlmError::ProviderError { .. }
| LlmError::ModelNotFound { .. }
| LlmError::NetworkError { .. }
| LlmError::Timeout { .. }
)
}
#[cfg(test)]
mod tests {
use std::sync::{Arc, Mutex};
use converge_core::traits::{ChatMessage, ChatRole, ResponseFormat};
use super::*;
struct FormatAwareBackend {
seen_formats: Mutex<Vec<ResponseFormat>>,
fail_json: bool,
}
impl FormatAwareBackend {
fn new(fail_json: bool) -> Self {
Self {
seen_formats: Mutex::new(Vec::new()),
fail_json,
}
}
fn seen_formats(&self) -> Vec<ResponseFormat> {
self.seen_formats.lock().unwrap().clone()
}
}
impl ChatBackend for FormatAwareBackend {
type ChatFut<'a>
= BoxFuture<'a, Result<ChatResponse, LlmError>>
where
Self: 'a;
fn chat(&self, req: ChatRequest) -> Self::ChatFut<'_> {
self.seen_formats.lock().unwrap().push(req.response_format);
Box::pin(async move {
match req.response_format {
ResponseFormat::Yaml => Err(LlmError::ResponseFormatMismatch {
expected: ResponseFormat::Yaml,
message: "yaml parse failed".to_string(),
}),
ResponseFormat::Json => {
if self.fail_json {
Err(LlmError::ResponseFormatMismatch {
expected: ResponseFormat::Json,
message: "json parse failed".to_string(),
})
} else {
Ok(ChatResponse {
content: "{\"facts\":[]}".to_string(),
tool_calls: Vec::new(),
usage: None,
model: None,
finish_reason: None,
})
}
}
_ => unreachable!(),
}
})
}
}
fn request(response_format: ResponseFormat) -> ChatRequest {
ChatRequest {
messages: vec![ChatMessage {
role: ChatRole::User,
content: "Return structured output".to_string(),
tool_calls: Vec::new(),
tool_call_id: None,
}],
system: None,
tools: Vec::new(),
response_format,
max_tokens: None,
temperature: None,
stop_sequences: Vec::new(),
model: None,
}
}
#[tokio::test(flavor = "current_thread")]
async fn retries_with_json_after_format_mismatch() {
let primary = Arc::new(FormatAwareBackend::new(false));
let backend = ResilientChatBackend::new(primary.clone(), "primary");
let response = ChatBackend::chat(&backend, request(ResponseFormat::Yaml))
.await
.unwrap();
assert_eq!(response.content, "{\"facts\":[]}");
assert_eq!(
primary.seen_formats(),
vec![ResponseFormat::Yaml, ResponseFormat::Json]
);
}
#[tokio::test(flavor = "current_thread")]
async fn preserves_json_format_mismatch_when_no_fallback_exists() {
let primary = Arc::new(FormatAwareBackend::new(true));
let backend = ResilientChatBackend::new(primary, "primary");
let error = ChatBackend::chat(&backend, request(ResponseFormat::Json))
.await
.unwrap_err();
assert!(matches!(
error,
LlmError::ResponseFormatMismatch {
expected: ResponseFormat::Json,
..
}
));
}
}