use crate::provider::{ChatMessage, ChatResponse, Provider, ToolSpec};
use anyhow::Result;
use async_trait::async_trait;
use std::sync::Arc;
use tracing::{info, warn};
pub struct FallbackProvider {
primary: Arc<dyn Provider>,
fallback: Arc<dyn Provider>,
}
impl FallbackProvider {
pub fn new(primary: Arc<dyn Provider>, fallback: Arc<dyn Provider>) -> Self {
Self { primary, fallback }
}
}
pub fn is_refusal(resp: &ChatResponse) -> bool {
if matches!(
resp.stop_reason.as_deref(),
Some("refusal") | Some("content_filter")
) {
return true;
}
if !resp.tool_calls.is_empty() {
return false;
}
let Some(text) = resp.text.as_deref() else {
return false;
};
if text.chars().count() > 600 {
return false;
}
let head = text
.trim_start()
.chars()
.take(120)
.collect::<String>()
.to_lowercase();
const PATTERNS: &[&str] = &[
"i can't help",
"i cannot help",
"i can't assist",
"i cannot assist",
"i'm not able to",
"i am not able to",
"i'm unable to",
"i am unable to",
"i won't",
"i will not",
"申し訳",
"お手伝いできません",
"対応できません",
];
PATTERNS.iter().any(|p| head.contains(p))
}
#[async_trait]
impl Provider for FallbackProvider {
fn name(&self) -> &str {
self.primary.name()
}
async fn chat(
&self,
system: Option<&str>,
messages: &[ChatMessage],
tools: Option<&[ToolSpec]>,
) -> Result<ChatResponse> {
match self.primary.chat(system, messages, tools).await {
Ok(resp) if is_refusal(&resp) => {
info!(
"Primary provider '{}' refused; retrying via fallback '{}'",
self.primary.name(),
self.fallback.name()
);
self.fallback.chat(system, messages, tools).await
}
Ok(resp) => Ok(resp),
Err(e) => {
warn!(
"Primary provider '{}' errored ({e:#}); retrying via fallback '{}'",
self.primary.name(),
self.fallback.name()
);
self.fallback.chat(system, messages, tools).await
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::{ChatMessage, ChatResponse, ToolCall};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Mutex;
struct ScriptedProvider {
name: String,
responses: Mutex<Vec<Result<ChatResponse, String>>>,
calls: Mutex<usize>,
}
impl ScriptedProvider {
fn new(name: &str, responses: Vec<Result<ChatResponse, String>>) -> Arc<Self> {
Arc::new(Self {
name: name.to_string(),
responses: Mutex::new(responses),
calls: Mutex::new(0),
})
}
fn calls(&self) -> usize {
*self.calls.lock().unwrap()
}
}
#[async_trait]
impl Provider for ScriptedProvider {
fn name(&self) -> &str {
&self.name
}
async fn chat(
&self,
_system: Option<&str>,
_messages: &[ChatMessage],
_tools: Option<&[ToolSpec]>,
) -> Result<ChatResponse> {
*self.calls.lock().unwrap() += 1;
let next = self.responses.lock().unwrap().remove(0);
next.map_err(|e| anyhow::anyhow!(e))
}
}
fn refusal_response() -> ChatResponse {
ChatResponse {
text: Some("I can't help with that request.".into()),
tool_calls: vec![],
stop_reason: Some("refusal".into()),
}
}
fn ok_response(text: &str) -> ChatResponse {
ChatResponse {
text: Some(text.into()),
tool_calls: vec![],
stop_reason: Some("end_turn".into()),
}
}
#[test]
fn refusal_stop_reason_is_detected() {
assert!(is_refusal(&refusal_response()));
}
#[test]
fn content_filter_stop_reason_is_detected() {
let resp = ChatResponse {
text: Some("...".into()),
tool_calls: vec![],
stop_reason: Some("content_filter".into()),
};
assert!(is_refusal(&resp));
}
#[test]
fn apology_pattern_without_stop_reason_is_detected() {
let resp = ChatResponse {
text: Some("I'm unable to assist with that one.".into()),
tool_calls: vec![],
stop_reason: Some("end_turn".into()),
};
assert!(is_refusal(&resp));
}
#[test]
fn legitimate_short_answer_is_not_a_refusal() {
let resp = ChatResponse {
text: Some("The file is at src/main.rs.".into()),
tool_calls: vec![],
stop_reason: Some("end_turn".into()),
};
assert!(!is_refusal(&resp));
}
#[test]
fn long_response_is_not_a_refusal_even_if_apology_appears() {
let mut text =
"I'm unable to immediately answer, but here's a detailed walkthrough: ".to_string();
text.push_str(&"x".repeat(800));
let resp = ChatResponse {
text: Some(text),
tool_calls: vec![],
stop_reason: Some("end_turn".into()),
};
assert!(!is_refusal(&resp));
}
#[test]
fn tool_call_short_circuits_text_heuristic() {
let resp = ChatResponse {
text: Some("I can't help directly, calling tool…".into()),
tool_calls: vec![ToolCall {
id: "x".into(),
name: "y".into(),
input: json!({}),
}],
stop_reason: Some("tool_use".into()),
};
assert!(!is_refusal(&resp));
}
#[tokio::test]
async fn fallback_takes_over_on_refusal() {
let primary = ScriptedProvider::new("primary", vec![Ok(refusal_response())]);
let fallback = ScriptedProvider::new("fallback", vec![Ok(ok_response("done"))]);
let p: Arc<dyn Provider> = primary.clone();
let f: Arc<dyn Provider> = fallback.clone();
let wrap = FallbackProvider::new(p, f);
let resp = wrap.chat(None, &[], None).await.unwrap();
assert_eq!(resp.text.as_deref(), Some("done"));
assert_eq!(primary.calls(), 1);
assert_eq!(fallback.calls(), 1);
}
#[tokio::test]
async fn fallback_takes_over_on_error() {
let primary = ScriptedProvider::new("primary", vec![Err("boom".to_string())]);
let fallback = ScriptedProvider::new("fallback", vec![Ok(ok_response("done"))]);
let p: Arc<dyn Provider> = primary.clone();
let f: Arc<dyn Provider> = fallback.clone();
let wrap = FallbackProvider::new(p, f);
let resp = wrap.chat(None, &[], None).await.unwrap();
assert_eq!(resp.text.as_deref(), Some("done"));
assert_eq!(primary.calls(), 1);
assert_eq!(fallback.calls(), 1);
}
#[tokio::test]
async fn primary_success_skips_fallback() {
let primary = ScriptedProvider::new("primary", vec![Ok(ok_response("hi"))]);
let fallback = ScriptedProvider::new("fallback", vec![Ok(ok_response("nope"))]);
let p: Arc<dyn Provider> = primary.clone();
let f: Arc<dyn Provider> = fallback.clone();
let wrap = FallbackProvider::new(p, f);
let resp = wrap.chat(None, &[], None).await.unwrap();
assert_eq!(resp.text.as_deref(), Some("hi"));
assert_eq!(primary.calls(), 1);
assert_eq!(fallback.calls(), 0);
}
}