use crate::error::Error;
use crate::llm::types::{CompletionRequest, CompletionResponse, ContentBlock, StopReason};
use crate::llm::{DynLlmProvider, LlmProvider, OnText};
/// Decides whether a response from a non-final cascade tier is good enough to return.
///
/// Returning `false` makes [`CascadingProvider`] escalate to the next tier. The final
/// tier's response is never passed through the gate. The `Send + Sync` bound keeps
/// `CascadingProvider`, which stores a boxed gate, `Send + Sync` as well.
pub trait ConfidenceGate: Sync + Send {
    /// Returns `true` to accept `response` (produced for `request`), `false` to escalate.
    fn accept(
        &self,
        request: &CompletionRequest,
        response: &CompletionResponse,
    ) -> bool;
}
/// Gate that judges a response with cheap checks that ignore the request:
/// tool-call presence, truncation, reported output length, and refusal phrases.
///
/// Values used by `Default`: `min_output_tokens = 5`, the built-in refusal phrase
/// list, `accept_tool_calls = true`, `escalate_on_max_tokens = true`.
pub struct HeuristicGate {
/// Responses whose reported `usage.output_tokens` is below this value are rejected.
pub min_output_tokens: u32,
/// Phrases that mark a response as a refusal or low-confidence answer; each is
/// matched case-insensitively as a substring of the response text.
pub refusal_patterns: Vec<String>,
/// When `true`, responses containing a tool-use block bypass the minimum-length and
/// refusal-pattern checks and are accepted.
pub accept_tool_calls: bool,
/// When `true`, a response whose stop reason is `StopReason::MaxTokens` fails the
/// gate's truncation check.
pub escalate_on_max_tokens: bool,
}
impl Default for HeuristicGate {
fn default() -> Self {
Self {
min_output_tokens: 5,
refusal_patterns: default_refusal_patterns(),
accept_tool_calls: true,
escalate_on_max_tokens: true,
}
}
}
/// Returns the built-in refusal / uncertainty phrases used by `HeuristicGate::default`.
fn default_refusal_patterns() -> Vec<String> {
    const PHRASES: [&str; 7] = [
        "I don't know",
        "I'm not sure",
        "I cannot",
        "I can't",
        "I'm unable",
        "beyond my capabilities",
        "I apologize, but",
    ];
    PHRASES.iter().copied().map(String::from).collect()
}
impl ConfidenceGate for HeuristicGate {
    /// Applies cheap heuristics to a candidate response; the request is not consulted.
    ///
    /// Checks run in this order and the first decisive one wins:
    ///
    /// 1. **Truncation**: if `escalate_on_max_tokens` is set and the response stopped with
    ///    [`StopReason::MaxTokens`], reject. This runs *before* the tool-call shortcut
    ///    because a response cut off at the token limit may end in an incomplete
    ///    tool-use block, and accepting it would hand a malformed tool call to the caller.
    /// 2. **Tool calls**: if `accept_tool_calls` is set and any content block is a
    ///    [`ContentBlock::ToolUse`], accept.
    /// 3. **Length**: reject if the reported output token count is below
    ///    `min_output_tokens`.
    /// 4. **Refusals**: reject if the response text contains any of `refusal_patterns`.
    ///    Matching is case-insensitive and treats the typographic apostrophe `’` (U+2019)
    ///    as `'`, so "I can’t" matches the pattern "I can't".
    fn accept(&self, _request: &CompletionRequest, response: &CompletionResponse) -> bool {
        // Truncation first: a MaxTokens stop can leave a partial tool-use block behind.
        if self.escalate_on_max_tokens && response.stop_reason == StopReason::MaxTokens {
            return false;
        }
        if self.accept_tool_calls
            && response
                .content
                .iter()
                .any(|b| matches!(b, ContentBlock::ToolUse { .. }))
        {
            return true;
        }
        if response.usage.output_tokens < self.min_output_tokens {
            return false;
        }
        let text = normalize_for_match(&response.text());
        !self
            .refusal_patterns
            .iter()
            .any(|pattern| text.contains(normalize_for_match(pattern).as_str()))
    }
}

/// Folds `s` into the form used for refusal-pattern matching: lowercased, with the
/// typographic apostrophe U+2019 replaced by ASCII `'`.
fn normalize_for_match(s: &str) -> String {
    s.to_lowercase().replace('\u{2019}', "'")
}
/// One stage of the cascade: a provider plus the label that identifies it.
pub struct CascadeTier {
/// Provider queried when the cascade reaches this tier.
provider: Box<dyn DynLlmProvider>,
/// Name used in cascade log events; also written into `CompletionResponse::model` of
/// any response this tier produces and returned by `model_name` for the first tier.
label: String,
}
/// Provider that tries tiers in order and escalates when a tier errors or its response
/// is rejected by the confidence gate. Construct it with `CascadingProvider::builder()`.
pub struct CascadingProvider {
/// Tiers in the order they are tried; the builder guarantees at least one. The last
/// tier's response is returned without consulting the gate.
tiers: Vec<CascadeTier>,
/// Decides whether a non-final tier's response is returned or escalated.
gate: Box<dyn ConfidenceGate>,
}
impl CascadingProvider {
    /// Starts building a cascade with no tiers and no explicit gate.
    ///
    /// Add tiers in the order they should be tried with
    /// [`CascadingProviderBuilder::add_tier`]. If no gate is supplied,
    /// [`CascadingProviderBuilder::build`] falls back to `HeuristicGate::default()`.
    pub fn builder() -> CascadingProviderBuilder {
        let tiers = Vec::new();
        CascadingProviderBuilder { gate: None, tiers }
    }
}
impl LlmProvider for CascadingProvider {
fn model_name(&self) -> Option<&str> {
self.tiers.first().map(|t| t.label.as_str())
}
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, Error> {
for (i, tier) in self.tiers.iter().enumerate() {
let is_last = i == self.tiers.len() - 1;
match tier.provider.complete(request.clone()).await {
Ok(mut response) => {
if is_last || self.gate.accept(&request, &response) {
response.model = Some(tier.label.clone());
tracing::info!(
tier = %tier.label,
is_last,
output_tokens = response.usage.output_tokens,
"cascade: accepted response"
);
return Ok(response);
}
tracing::info!(
from = %tier.label,
to = %self.tiers[i + 1].label,
"cascade: gate rejected, escalating"
);
}
Err(e) if is_last => return Err(e),
Err(e) => {
tracing::warn!(
tier = %tier.label,
error = %e,
"cascade: tier failed, escalating"
);
}
}
}
unreachable!("cascade must have at least one tier")
}
async fn stream_complete(
&self,
request: CompletionRequest,
on_text: &OnText,
) -> Result<CompletionResponse, Error> {
if self.tiers.len() == 1 {
let mut resp = self.tiers[0]
.provider
.stream_complete(request, on_text)
.await?;
resp.model = Some(self.tiers[0].label.clone());
return Ok(resp);
}
for (i, tier) in self.tiers.iter().enumerate() {
let is_last = i == self.tiers.len() - 1;
if is_last {
let mut resp = tier.provider.stream_complete(request, on_text).await?;
resp.model = Some(tier.label.clone());
return Ok(resp);
}
match tier.provider.complete(request.clone()).await {
Ok(mut response) if self.gate.accept(&request, &response) => {
response.model = Some(tier.label.clone());
tracing::info!(
tier = %tier.label,
output_tokens = response.usage.output_tokens,
"cascade: cheap tier accepted (stream path)"
);
let text = response.text();
if !text.is_empty() {
on_text(&text);
}
return Ok(response);
}
Ok(_) => {
tracing::info!(
from = %tier.label,
to = %self.tiers[i + 1].label,
"cascade: gate rejected, escalating"
);
}
Err(e) => {
tracing::warn!(
tier = %tier.label,
error = %e,
"cascade: tier failed, escalating"
);
}
}
}
unreachable!("cascade stream_complete exhausted all tiers without returning")
}
}
/// Builder for [`CascadingProvider`]; obtain one from `CascadingProvider::builder()`.
pub struct CascadingProviderBuilder {
/// Tiers accumulated by `add_tier`, in the order they will be tried.
tiers: Vec<CascadeTier>,
/// Explicit gate, if one was set; `build` substitutes `HeuristicGate::default()` when `None`.
gate: Option<Box<dyn ConfidenceGate>>,
}
impl CascadingProviderBuilder {
    /// Appends a tier; tiers are tried in the order they are added, and the last one
    /// added is the final (ungated) tier.
    pub fn add_tier(
        mut self,
        label: impl Into<String>,
        provider: impl LlmProvider + 'static,
    ) -> Self {
        let tier = CascadeTier {
            provider: Box::new(provider),
            label: label.into(),
        };
        self.tiers.push(tier);
        self
    }

    /// Sets the gate that judges non-final tier responses, replacing any earlier one.
    pub fn gate(self, gate: impl ConfidenceGate + 'static) -> Self {
        Self {
            gate: Some(Box::new(gate)),
            ..self
        }
    }

    /// Finishes the cascade, defaulting the gate to `HeuristicGate::default()`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] when no tier was added.
    pub fn build(self) -> Result<CascadingProvider, Error> {
        let Self { tiers, gate } = self;
        if tiers.is_empty() {
            return Err(Error::Config(
                "CascadingProvider requires at least one tier".into(),
            ));
        }
        let gate: Box<dyn ConfidenceGate> = match gate {
            Some(explicit) => explicit,
            None => Box::new(HeuristicGate::default()),
        };
        Ok(CascadingProvider { tiers, gate })
    }
}
// Unit tests for `HeuristicGate` and for how `CascadingProvider` selects a tier in both
// the `complete` and `stream_complete` paths.
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::types::{ContentBlock, Message, StopReason, TokenUsage};
use serde_json::json;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
// Fixture: a single text block that ended normally (`EndTurn`) with the given output token count.
fn text_response(text: &str, output_tokens: u32) -> CompletionResponse {
CompletionResponse {
content: vec![ContentBlock::Text { text: text.into() }],
stop_reason: StopReason::EndTurn,
usage: TokenUsage {
output_tokens,
..Default::default()
},
model: None,
}
}
// Fixture: one `ToolUse` block with a `ToolUse` stop reason and 20 output tokens.
fn tool_response() -> CompletionResponse {
CompletionResponse {
content: vec![ContentBlock::ToolUse {
id: "call-1".into(),
name: "search".into(),
input: json!({"q": "rust"}),
}],
stop_reason: StopReason::ToolUse,
usage: TokenUsage {
output_tokens: 20,
..Default::default()
},
model: None,
}
}
// Fixture: a text response that stopped at the token limit (`MaxTokens`).
fn max_tokens_response() -> CompletionResponse {
CompletionResponse {
content: vec![ContentBlock::Text {
text: "truncated...".into(),
}],
stop_reason: StopReason::MaxTokens,
usage: TokenUsage {
output_tokens: 100,
..Default::default()
},
model: None,
}
}
// Fixture: minimal request with one user message and no tools.
fn test_request() -> CompletionRequest {
CompletionRequest {
system: String::new(),
messages: vec![Message::user("hello")],
tools: vec![],
max_tokens: 1024,
tool_choice: None,
reasoning_effort: None,
}
}
#[test]
fn heuristic_gate_accepts_normal_response() {
let gate = HeuristicGate::default();
let req = test_request();
let resp = text_response("Salut Pascal! Comment vas-tu?", 10);
assert!(gate.accept(&req, &resp));
}
// 2 output tokens is below the default `min_output_tokens` of 5.
#[test]
fn heuristic_gate_rejects_short_response() {
let gate = HeuristicGate::default();
let req = test_request();
let resp = text_response("Hi", 2);
assert!(!gate.accept(&req, &resp));
}
// One sentence per default refusal phrase; each must trigger a rejection.
#[test]
fn heuristic_gate_rejects_refusal_patterns() {
let gate = HeuristicGate::default();
let req = test_request();
let patterns = [
"I don't know the answer to that.",
"I'm not sure about this topic.",
"I cannot help with that request.",
"I can't do that.",
"I'm unable to assist with this.",
"That is beyond my capabilities.",
"I apologize, but I need more context.",
];
for text in patterns {
let resp = text_response(text, 20);
assert!(!gate.accept(&req, &resp), "should reject: {text}");
}
}
#[test]
fn heuristic_gate_accepts_tool_calls() {
let gate = HeuristicGate::default();
let req = test_request();
let resp = tool_response();
assert!(gate.accept(&req, &resp));
}
#[test]
fn heuristic_gate_rejects_max_tokens() {
let gate = HeuristicGate::default();
let req = test_request();
let resp = max_tokens_response();
assert!(!gate.accept(&req, &resp));
}
// Pins the `Default` field values.
#[test]
fn heuristic_gate_default_patterns() {
let gate = HeuristicGate::default();
assert_eq!(gate.min_output_tokens, 5);
assert!(gate.accept_tool_calls);
assert!(gate.escalate_on_max_tokens);
assert!(!gate.refusal_patterns.is_empty());
assert!(gate.refusal_patterns.len() >= 7);
}
#[test]
fn heuristic_gate_case_insensitive_refusal() {
let gate = HeuristicGate::default();
let req = test_request();
let resp = text_response("I DON'T KNOW about that", 10);
assert!(!gate.accept(&req, &resp));
}
// Test double that returns a fixed result and counts its calls.
// NOTE(review): `call_count` is incremented but never asserted on by any test below.
struct FixedProvider {
label: &'static str,
response: Result<CompletionResponse, Error>,
call_count: AtomicUsize,
}
impl FixedProvider {
fn ok(label: &'static str, response: CompletionResponse) -> Self {
Self {
label,
response: Ok(response),
call_count: AtomicUsize::new(0),
}
}
fn err(label: &'static str) -> Self {
Self {
label,
response: Err(Error::Api {
status: 500,
message: "tier error".into(),
}),
call_count: AtomicUsize::new(0),
}
}
}
impl LlmProvider for FixedProvider {
// A fresh `Error::Api` is built per call (presumably because `Error` is not
// `Clone` — TODO confirm); its message embeds the provider label so tests can
// tell which tier's error surfaced.
async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
self.call_count.fetch_add(1, Ordering::Relaxed);
match &self.response {
Ok(r) => Ok(r.clone()),
Err(e) => Err(Error::Api {
status: match e {
Error::Api { status, .. } => *status,
_ => 500,
},
message: format!("{} error", self.label),
}),
}
}
// Emits the whole response text through `on_text` in one call, then returns it.
async fn stream_complete(
&self,
_request: CompletionRequest,
on_text: &OnText,
) -> Result<CompletionResponse, Error> {
self.call_count.fetch_add(1, Ordering::Relaxed);
match &self.response {
Ok(r) => {
let text = r.text();
if !text.is_empty() {
on_text(&text);
}
Ok(r.clone())
}
Err(_) => Err(Error::Api {
status: 500,
message: format!("{} error", self.label),
}),
}
}
fn model_name(&self) -> Option<&str> {
Some(self.label)
}
}
// Gate stub that accepts every response, so the first successful tier wins.
struct AlwaysAccept;
impl ConfidenceGate for AlwaysAccept {
fn accept(&self, _req: &CompletionRequest, _resp: &CompletionResponse) -> bool {
true
}
}
// Gate stub that rejects every response, forcing escalation to the final tier.
struct AlwaysReject;
impl ConfidenceGate for AlwaysReject {
fn accept(&self, _req: &CompletionRequest, _resp: &CompletionResponse) -> bool {
false
}
}
#[tokio::test]
async fn single_tier_delegates_directly() {
let provider = CascadingProvider::builder()
.add_tier(
"haiku",
FixedProvider::ok("haiku", text_response("hello", 10)),
)
.gate(AlwaysAccept)
.build()
.unwrap();
let resp = LlmProvider::complete(&provider, test_request())
.await
.unwrap();
assert_eq!(resp.text(), "hello");
assert_eq!(resp.model.as_deref(), Some("haiku"));
}
#[tokio::test]
async fn two_tier_accepts_cheap_when_gate_passes() {
let provider = CascadingProvider::builder()
.add_tier(
"haiku",
FixedProvider::ok("haiku", text_response("Salut!", 10)),
)
.add_tier(
"sonnet",
FixedProvider::ok("sonnet", text_response("expensive", 50)),
)
.gate(AlwaysAccept)
.build()
.unwrap();
let resp = LlmProvider::complete(&provider, test_request())
.await
.unwrap();
assert_eq!(resp.text(), "Salut!");
assert_eq!(resp.model.as_deref(), Some("haiku"));
}
#[tokio::test]
async fn two_tier_escalates_when_gate_rejects() {
let provider = CascadingProvider::builder()
.add_tier(
"haiku",
FixedProvider::ok("haiku", text_response("dunno", 10)),
)
.add_tier(
"sonnet",
FixedProvider::ok("sonnet", text_response("great answer", 50)),
)
.gate(AlwaysReject)
.build()
.unwrap();
let resp = LlmProvider::complete(&provider, test_request())
.await
.unwrap();
assert_eq!(resp.text(), "great answer");
assert_eq!(resp.model.as_deref(), Some("sonnet"));
}
// An erroring first tier is skipped; the middle tier is accepted before the final one is tried.
#[tokio::test]
async fn three_tier_skips_erroring_tier() {
let provider = CascadingProvider::builder()
.add_tier("haiku", FixedProvider::err("haiku"))
.add_tier(
"sonnet",
FixedProvider::ok("sonnet", text_response("mid", 10)),
)
.add_tier(
"opus",
FixedProvider::ok("opus", text_response("expensive", 50)),
)
.gate(AlwaysAccept)
.build()
.unwrap();
let resp = LlmProvider::complete(&provider, test_request())
.await
.unwrap();
assert_eq!(resp.text(), "mid");
assert_eq!(resp.model.as_deref(), Some("sonnet"));
}
// Even a gate that rejects everything cannot reject the final tier.
#[tokio::test]
async fn final_tier_always_accepts() {
let provider = CascadingProvider::builder()
.add_tier(
"haiku",
FixedProvider::ok("haiku", text_response("cheap", 10)),
)
.add_tier(
"sonnet",
FixedProvider::ok("sonnet", text_response("final", 50)),
)
.gate(AlwaysReject)
.build()
.unwrap();
let resp = LlmProvider::complete(&provider, test_request())
.await
.unwrap();
assert_eq!(resp.text(), "final");
assert_eq!(resp.model.as_deref(), Some("sonnet"));
}
// In the stream path, non-final tiers must be queried via `complete`; calling
// `stream_complete` on the cheap tier panics the test.
#[tokio::test]
async fn stream_uses_complete_for_non_final_tiers() {
struct CompleteOnlyProvider;
impl LlmProvider for CompleteOnlyProvider {
async fn complete(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse, Error> {
Ok(text_response("cheap answer", 10))
}
async fn stream_complete(
&self,
_request: CompletionRequest,
_on_text: &OnText,
) -> Result<CompletionResponse, Error> {
panic!("non-final tier should not call stream_complete");
}
}
let provider = CascadingProvider::builder()
.add_tier("cheap", CompleteOnlyProvider)
.add_tier(
"expensive",
FixedProvider::ok("expensive", text_response("expensive", 50)),
)
.gate(AlwaysAccept)
.build()
.unwrap();
let on_text: &OnText = &|_| {};
let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
.await
.unwrap();
assert_eq!(resp.text(), "cheap answer");
}
// An accepted non-final tier's text is delivered to `on_text` exactly once, as one chunk.
#[tokio::test]
async fn stream_emits_text_when_cheap_accepted() {
let collected = Arc::new(Mutex::new(Vec::<String>::new()));
let collected_clone = collected.clone();
let on_text: &OnText = &move |text: &str| {
collected_clone.lock().expect("lock").push(text.to_string());
};
let provider = CascadingProvider::builder()
.add_tier(
"cheap",
FixedProvider::ok("cheap", text_response("hello world", 10)),
)
.add_tier(
"expensive",
FixedProvider::ok("expensive", text_response("expensive", 50)),
)
.gate(AlwaysAccept)
.build()
.unwrap();
let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
.await
.unwrap();
assert_eq!(resp.text(), "hello world");
let texts = collected.lock().expect("lock");
assert_eq!(*texts, vec!["hello world"]);
}
// After escalation, the final tier streams its own chunks through to `on_text` unchanged.
#[tokio::test]
async fn stream_streams_final_tier() {
let streamed = Arc::new(Mutex::new(Vec::<String>::new()));
let streamed_clone = streamed.clone();
let on_text: &OnText = &move |text: &str| {
streamed_clone.lock().expect("lock").push(text.to_string());
};
struct StreamingProvider;
impl LlmProvider for StreamingProvider {
async fn complete(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse, Error> {
panic!("final tier with streaming should use stream_complete");
}
async fn stream_complete(
&self,
_request: CompletionRequest,
on_text: &OnText,
) -> Result<CompletionResponse, Error> {
on_text("streamed ");
on_text("response");
Ok(CompletionResponse {
content: vec![ContentBlock::Text {
text: "streamed response".into(),
}],
stop_reason: StopReason::EndTurn,
usage: TokenUsage {
output_tokens: 20,
..Default::default()
},
model: None,
})
}
}
let provider = CascadingProvider::builder()
.add_tier(
"cheap",
FixedProvider::ok("cheap", text_response("dunno", 10)),
)
.add_tier("expensive", StreamingProvider)
.gate(AlwaysReject)
.build()
.unwrap();
let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
.await
.unwrap();
assert_eq!(resp.text(), "streamed response");
assert_eq!(resp.model.as_deref(), Some("expensive"));
let texts = streamed.lock().expect("lock");
assert_eq!(*texts, vec!["streamed ", "response"]);
}
#[tokio::test]
async fn response_model_set_to_accepting_tier() {
let provider = CascadingProvider::builder()
.add_tier("haiku", FixedProvider::err("haiku"))
.add_tier(
"sonnet",
FixedProvider::ok("sonnet", text_response("answer", 10)),
)
.gate(AlwaysAccept)
.build()
.unwrap();
let resp = LlmProvider::complete(&provider, test_request())
.await
.unwrap();
assert_eq!(resp.model.as_deref(), Some("sonnet"));
}
#[test]
fn builder_rejects_zero_tiers() {
let result = CascadingProvider::builder().gate(AlwaysAccept).build();
assert!(result.is_err());
}
// Compile-time check: the boxed providers and gate keep `CascadingProvider` thread-safe.
#[test]
fn cascading_provider_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<CascadingProvider>();
}
// NOTE(review): despite its name, this only checks that `build()` succeeds without an
// explicit gate and that `model_name` reports the first tier's label; the default
// gate's behavior is exercised by `heuristic_gate_integration_with_cascade`.
#[test]
fn builder_defaults_to_heuristic_gate() {
let provider = CascadingProvider::builder()
.add_tier("haiku", FixedProvider::ok("haiku", text_response("hi", 10)))
.build()
.unwrap();
assert_eq!(LlmProvider::model_name(&provider), Some("haiku"));
}
// A single-tier cascade must stream from that tier rather than calling `complete`.
#[tokio::test]
async fn single_tier_streams_directly() {
struct StreamOnlyProvider;
impl LlmProvider for StreamOnlyProvider {
async fn complete(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse, Error> {
panic!("single tier should stream directly");
}
async fn stream_complete(
&self,
_request: CompletionRequest,
on_text: &OnText,
) -> Result<CompletionResponse, Error> {
on_text("streamed");
Ok(text_response("streamed", 10))
}
}
let provider = CascadingProvider::builder()
.add_tier("only", StreamOnlyProvider)
.gate(AlwaysAccept)
.build()
.unwrap();
let on_text: &OnText = &|_| {};
let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
.await
.unwrap();
assert_eq!(resp.text(), "streamed");
assert_eq!(resp.model.as_deref(), Some("only"));
}
// The final tier's error is propagated; `FixedProvider` embeds its label in the message.
#[tokio::test]
async fn all_tiers_error_returns_last_error() {
let provider = CascadingProvider::builder()
.add_tier("tier1", FixedProvider::err("tier1"))
.add_tier("tier2", FixedProvider::err("tier2"))
.gate(AlwaysAccept)
.build()
.unwrap();
let err = LlmProvider::complete(&provider, test_request())
.await
.unwrap_err();
assert!(err.to_string().contains("tier2"), "error: {err}");
}
// With the default `HeuristicGate`, the cheap tier's 2-token reply is below
// `min_output_tokens` (5), so the cascade escalates to the second tier.
#[tokio::test]
async fn heuristic_gate_integration_with_cascade() {
let provider = CascadingProvider::builder()
.add_tier("haiku", FixedProvider::ok("haiku", text_response("Hi", 2)))
.add_tier(
"sonnet",
FixedProvider::ok("sonnet", text_response("detailed answer here", 30)),
)
.build()
.unwrap();
let resp = LlmProvider::complete(&provider, test_request())
.await
.unwrap();
assert_eq!(resp.text(), "detailed answer here");
assert_eq!(resp.model.as_deref(), Some("sonnet"));
}
}