1use crate::error::Error;
4use crate::llm::types::{CompletionRequest, CompletionResponse, ContentBlock, StopReason};
5use crate::llm::{DynLlmProvider, LlmProvider, OnText};
6
/// Decides whether a provider's response is good enough to be returned.
///
/// Implementors must be `Send + Sync` (supertrait bounds). In this file the
/// gate is stored as `Box<dyn ConfidenceGate>` inside `CascadingProvider`,
/// which calls it on responses from non-final tiers.
pub trait ConfidenceGate: Send + Sync {
    /// Returns `true` to accept `response` as the answer to `request`, or
    /// `false` to reject it. `CascadingProvider` moves on to the next tier
    /// when a response is rejected.
    fn accept(&self, request: &CompletionRequest, response: &CompletionResponse) -> bool;
}
13
/// Rule-based [`ConfidenceGate`] configured entirely through public fields.
///
/// The fields are plain data, so the type derives `Debug`, `Clone`,
/// `PartialEq` and `Eq`; this lets callers log, copy and compare gate
/// configurations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeuristicGate {
    /// Responses whose reported output-token count is below this value are
    /// rejected.
    pub min_output_tokens: u32,
    /// Substrings that mark a response as a refusal. Matching is
    /// case-insensitive; a response whose text contains any of them is
    /// rejected.
    pub refusal_patterns: Vec<String>,
    /// When `true`, a response containing at least one tool-use block is
    /// accepted before any other rule is evaluated.
    pub accept_tool_calls: bool,
    /// When `true`, a response that stopped because it hit the token limit
    /// is rejected.
    pub escalate_on_max_tokens: bool,
}
25
26impl Default for HeuristicGate {
27 fn default() -> Self {
28 Self {
29 min_output_tokens: 5,
30 refusal_patterns: default_refusal_patterns(),
31 accept_tool_calls: true,
32 escalate_on_max_tokens: true,
33 }
34 }
35}
36
/// Returns the built-in refusal phrases as owned strings, in a fixed order.
fn default_refusal_patterns() -> Vec<String> {
    const PHRASES: &[&str] = &[
        "I don't have enough information",
        "I'm unable to help",
        "beyond my capabilities",
        "I apologize, but I cannot",
    ];

    let mut patterns = Vec::with_capacity(PHRASES.len());
    for phrase in PHRASES {
        patterns.push(String::from(*phrase));
    }
    patterns
}
55
56impl ConfidenceGate for HeuristicGate {
57 fn accept(&self, _request: &CompletionRequest, response: &CompletionResponse) -> bool {
58 if self.accept_tool_calls
60 && response
61 .content
62 .iter()
63 .any(|b| matches!(b, ContentBlock::ToolUse { .. }))
64 {
65 return true;
66 }
67
68 if self.escalate_on_max_tokens && response.stop_reason == StopReason::MaxTokens {
70 return false;
71 }
72
73 if response.usage.output_tokens < self.min_output_tokens {
75 return false;
76 }
77
78 let text = response.text().to_lowercase();
80 for pattern in &self.refusal_patterns {
81 if text.contains(&pattern.to_lowercase()) {
82 return false;
83 }
84 }
85
86 true
88 }
89}
90
/// One step of the cascade: a provider together with the label used for it.
pub struct CascadeTier {
    /// Type-erased provider that requests for this tier are sent to.
    provider: Box<dyn DynLlmProvider>,
    /// Name used in log fields and written into the `model` field of a
    /// response returned from this tier.
    label: String,
}
96
/// Provider that tries a list of tiers in order and returns the first
/// response the gate accepts; the final tier's response is returned without
/// consulting the gate.
///
/// Instances are created through `CascadingProvider::builder()`.
pub struct CascadingProvider {
    /// Tiers in the order they are tried.
    tiers: Vec<CascadeTier>,
    /// Gate that decides whether a non-final tier's response is accepted.
    gate: Box<dyn ConfidenceGate>,
}
107
108impl CascadingProvider {
109 pub fn builder() -> CascadingProviderBuilder {
111 CascadingProviderBuilder {
112 tiers: Vec::new(),
113 gate: None,
114 }
115 }
116}
117
118impl LlmProvider for CascadingProvider {
119 fn model_name(&self) -> Option<&str> {
120 self.tiers.first().map(|t| t.label.as_str())
121 }
122
123 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, Error> {
124 for (i, tier) in self.tiers.iter().enumerate() {
125 let is_last = i == self.tiers.len() - 1;
126 match tier.provider.complete(request.clone()).await {
127 Ok(mut response) => {
128 if is_last || self.gate.accept(&request, &response) {
129 response.model = Some(tier.label.clone());
130 tracing::info!(
131 tier = %tier.label,
132 is_last,
133 output_tokens = response.usage.output_tokens,
134 "cascade: accepted response"
135 );
136 return Ok(response);
137 }
138 tracing::info!(
139 from = %tier.label,
140 to = %self.tiers[i + 1].label,
141 "cascade: gate rejected, escalating"
142 );
143 }
144 Err(e) if is_last => return Err(e),
145 Err(e) => {
146 tracing::warn!(
147 tier = %tier.label,
148 error = %e,
149 "cascade: tier failed, escalating"
150 );
151 }
152 }
153 }
154 unreachable!("cascade must have at least one tier")
155 }
156
157 async fn stream_complete(
158 &self,
159 request: CompletionRequest,
160 on_text: &OnText,
161 ) -> Result<CompletionResponse, Error> {
162 if self.tiers.len() == 1 {
164 let mut resp = self.tiers[0]
165 .provider
166 .stream_complete(request, on_text)
167 .await?;
168 resp.model = Some(self.tiers[0].label.clone());
169 return Ok(resp);
170 }
171
172 for (i, tier) in self.tiers.iter().enumerate() {
174 let is_last = i == self.tiers.len() - 1;
175 if is_last {
176 let mut resp = tier.provider.stream_complete(request, on_text).await?;
177 resp.model = Some(tier.label.clone());
178 return Ok(resp);
179 }
180 match tier.provider.complete(request.clone()).await {
181 Ok(mut response) if self.gate.accept(&request, &response) => {
182 response.model = Some(tier.label.clone());
183 tracing::info!(
184 tier = %tier.label,
185 output_tokens = response.usage.output_tokens,
186 "cascade: cheap tier accepted (stream path)"
187 );
188 let text = response.text();
190 if !text.is_empty() {
191 on_text(&text);
192 }
193 return Ok(response);
194 }
195 Ok(_) => {
196 tracing::info!(
197 from = %tier.label,
198 to = %self.tiers[i + 1].label,
199 "cascade: gate rejected, escalating"
200 );
201 }
202 Err(e) => {
203 tracing::warn!(
204 tier = %tier.label,
205 error = %e,
206 "cascade: tier failed, escalating"
207 );
208 }
209 }
210 }
211 unreachable!("cascade stream_complete exhausted all tiers without returning")
212 }
213}
214
/// Builder for `CascadingProvider`; collects tiers and an optional gate.
pub struct CascadingProviderBuilder {
    /// Tiers added so far, in the order they were added.
    tiers: Vec<CascadeTier>,
    /// Gate chosen by the caller; when `None`, `build` falls back to
    /// `HeuristicGate::default()`.
    gate: Option<Box<dyn ConfidenceGate>>,
}
220
221impl CascadingProviderBuilder {
222 pub fn add_tier(
224 mut self,
225 label: impl Into<String>,
226 provider: impl LlmProvider + 'static,
227 ) -> Self {
228 self.tiers.push(CascadeTier {
229 provider: Box::new(provider),
230 label: label.into(),
231 });
232 self
233 }
234
235 pub fn gate(mut self, gate: impl ConfidenceGate + 'static) -> Self {
237 self.gate = Some(Box::new(gate));
238 self
239 }
240
241 pub fn build(self) -> Result<CascadingProvider, Error> {
243 if self.tiers.is_empty() {
244 return Err(Error::Config(
245 "CascadingProvider requires at least one tier".into(),
246 ));
247 }
248 Ok(CascadingProvider {
249 tiers: self.tiers,
250 gate: self
251 .gate
252 .unwrap_or_else(|| Box::new(HeuristicGate::default())),
253 })
254 }
255}
256
// Unit tests for `HeuristicGate` and `CascadingProvider`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{ContentBlock, Message, StopReason, TokenUsage};
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    // Builds a response with one text block, an `EndTurn` stop reason and the
    // given output-token count.
    fn text_response(text: &str, output_tokens: u32) -> CompletionResponse {
        CompletionResponse {
            content: vec![ContentBlock::Text { text: text.into() }],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage {
                output_tokens,
                ..Default::default()
            },
            model: None,
        }
    }

    // Builds a response consisting of a single tool-use block.
    fn tool_response() -> CompletionResponse {
        CompletionResponse {
            content: vec![ContentBlock::ToolUse {
                id: "call-1".into(),
                name: "search".into(),
                input: json!({"q": "rust"}),
            }],
            stop_reason: StopReason::ToolUse,
            usage: TokenUsage {
                output_tokens: 20,
                ..Default::default()
            },
            model: None,
        }
    }

    // Builds a text response that stopped because it hit the token limit.
    fn max_tokens_response() -> CompletionResponse {
        CompletionResponse {
            content: vec![ContentBlock::Text {
                text: "truncated...".into(),
            }],
            stop_reason: StopReason::MaxTokens,
            usage: TokenUsage {
                output_tokens: 100,
                ..Default::default()
            },
            model: None,
        }
    }

    // Minimal request with one user message and no tools.
    fn test_request() -> CompletionRequest {
        CompletionRequest {
            system: String::new(),
            messages: vec![Message::user("hello")],
            tools: vec![],
            max_tokens: 1024,
            tool_choice: None,
            reasoning_effort: None,
        }
    }

    // An ordinary answer above the token minimum with no refusal phrase passes.
    #[test]
    fn heuristic_gate_accepts_normal_response() {
        let gate = HeuristicGate::default();
        let req = test_request();
        let resp = text_response("Salut Pascal! Comment vas-tu?", 10);
        assert!(gate.accept(&req, &resp));
    }

    // 2 output tokens is below the default minimum of 5.
    #[test]
    fn heuristic_gate_rejects_short_response() {
        let gate = HeuristicGate::default();
        let req = test_request();
        let resp = text_response("Hi", 2);
        assert!(!gate.accept(&req, &resp));
    }

    // Each sentence embeds one of the default refusal phrases.
    #[test]
    fn heuristic_gate_rejects_refusal_patterns() {
        let gate = HeuristicGate::default();
        let req = test_request();

        let patterns = [
            "I don't have enough information to answer.",
            "I'm unable to help with that request.",
            "That topic is beyond my capabilities.",
            "I apologize, but I cannot help with this.",
        ];
        for text in patterns {
            let resp = text_response(text, 20);
            assert!(!gate.accept(&req, &resp), "should reject: {text}");
        }
    }

    // With the default settings a tool-use response is accepted.
    #[test]
    fn heuristic_gate_accepts_tool_calls() {
        let gate = HeuristicGate::default();
        let req = test_request();
        let resp = tool_response();
        assert!(gate.accept(&req, &resp));
    }

    // With the default settings a token-limit stop is rejected.
    #[test]
    fn heuristic_gate_rejects_max_tokens() {
        let gate = HeuristicGate::default();
        let req = test_request();
        let resp = max_tokens_response();
        assert!(!gate.accept(&req, &resp));
    }

    // Pins the default field values.
    #[test]
    fn heuristic_gate_default_patterns() {
        let gate = HeuristicGate::default();
        assert_eq!(gate.min_output_tokens, 5);
        assert!(gate.accept_tool_calls);
        assert!(gate.escalate_on_max_tokens);
        assert!(!gate.refusal_patterns.is_empty());
    }

    // Refusal matching ignores letter case.
    #[test]
    fn heuristic_gate_case_insensitive_refusal() {
        let gate = HeuristicGate::default();
        let req = test_request();
        let resp = text_response("I'M UNABLE TO HELP with that", 10);
        assert!(!gate.accept(&req, &resp));
    }

    // Test double that always returns the same canned result and counts how
    // many times it was called.
    struct FixedProvider {
        label: &'static str,
        response: Result<CompletionResponse, Error>,
        call_count: AtomicUsize,
    }

    impl FixedProvider {
        // Provider that always succeeds with `response`.
        fn ok(label: &'static str, response: CompletionResponse) -> Self {
            Self {
                label,
                response: Ok(response),
                call_count: AtomicUsize::new(0),
            }
        }

        // Provider that always fails with an API error.
        fn err(label: &'static str) -> Self {
            Self {
                label,
                response: Err(Error::Api {
                    status: 500,
                    message: "tier error".into(),
                }),
                call_count: AtomicUsize::new(0),
            }
        }
    }

    impl LlmProvider for FixedProvider {
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            match &self.response {
                Ok(r) => Ok(r.clone()),
                // The stored error is rebuilt rather than cloned; the message
                // carries this provider's label so tests can tell tiers apart.
                Err(e) => Err(Error::Api {
                    status: match e {
                        Error::Api { status, .. } => *status,
                        _ => 500,
                    },
                    message: format!("{} error", self.label),
                }),
            }
        }

        async fn stream_complete(
            &self,
            _request: CompletionRequest,
            on_text: &OnText,
        ) -> Result<CompletionResponse, Error> {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            match &self.response {
                Ok(r) => {
                    // Emit the whole canned text as a single chunk.
                    let text = r.text();
                    if !text.is_empty() {
                        on_text(&text);
                    }
                    Ok(r.clone())
                }
                Err(_) => Err(Error::Api {
                    status: 500,
                    message: format!("{} error", self.label),
                }),
            }
        }

        fn model_name(&self) -> Option<&str> {
            Some(self.label)
        }
    }

    // Gate that accepts every response.
    struct AlwaysAccept;
    impl ConfidenceGate for AlwaysAccept {
        fn accept(&self, _req: &CompletionRequest, _resp: &CompletionResponse) -> bool {
            true
        }
    }

    // Gate that rejects every response.
    struct AlwaysReject;
    impl ConfidenceGate for AlwaysReject {
        fn accept(&self, _req: &CompletionRequest, _resp: &CompletionResponse) -> bool {
            false
        }
    }

    // With one tier, the response comes from it and is tagged with its label.
    #[tokio::test]
    async fn single_tier_delegates_directly() {
        let provider = CascadingProvider::builder()
            .add_tier(
                "haiku",
                FixedProvider::ok("haiku", text_response("hello", 10)),
            )
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "hello");
        assert_eq!(resp.model.as_deref(), Some("haiku"));
    }

    // When the gate accepts the first tier, its response is returned.
    #[tokio::test]
    async fn two_tier_accepts_cheap_when_gate_passes() {
        let provider = CascadingProvider::builder()
            .add_tier(
                "haiku",
                FixedProvider::ok("haiku", text_response("Salut!", 10)),
            )
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("expensive", 50)),
            )
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "Salut!");
        assert_eq!(resp.model.as_deref(), Some("haiku"));
    }

    // When the gate rejects the first tier, the second tier's response is returned.
    #[tokio::test]
    async fn two_tier_escalates_when_gate_rejects() {
        let provider = CascadingProvider::builder()
            .add_tier(
                "haiku",
                FixedProvider::ok("haiku", text_response("dunno", 10)),
            )
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("great answer", 50)),
            )
            .gate(AlwaysReject)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "great answer");
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
    }

    // A failing first tier is skipped; the next accepted tier answers.
    #[tokio::test]
    async fn three_tier_skips_erroring_tier() {
        let provider = CascadingProvider::builder()
            .add_tier("haiku", FixedProvider::err("haiku"))
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("mid", 10)),
            )
            .add_tier(
                "opus",
                FixedProvider::ok("opus", text_response("expensive", 50)),
            )
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "mid");
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
    }

    // Even with a gate that rejects everything, the last tier's response is returned.
    #[tokio::test]
    async fn final_tier_always_accepts() {
        let provider = CascadingProvider::builder()
            .add_tier(
                "haiku",
                FixedProvider::ok("haiku", text_response("cheap", 10)),
            )
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("final", 50)),
            )
            .gate(AlwaysReject)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "final");
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
    }

    // On the streaming path a non-final tier must be queried via `complete`.
    #[tokio::test]
    async fn stream_uses_complete_for_non_final_tiers() {
        // Panics if the cascade ever streams from it.
        struct CompleteOnlyProvider;
        impl LlmProvider for CompleteOnlyProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                Ok(text_response("cheap answer", 10))
            }
            async fn stream_complete(
                &self,
                _request: CompletionRequest,
                _on_text: &OnText,
            ) -> Result<CompletionResponse, Error> {
                panic!("non-final tier should not call stream_complete");
            }
        }

        let provider = CascadingProvider::builder()
            .add_tier("cheap", CompleteOnlyProvider)
            .add_tier(
                "expensive",
                FixedProvider::ok("expensive", text_response("expensive", 50)),
            )
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let on_text: &OnText = &|_| {};
        let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(resp.text(), "cheap answer");
    }

    // An accepted non-final response is delivered to the callback as one chunk.
    #[tokio::test]
    async fn stream_emits_text_when_cheap_accepted() {
        let collected = Arc::new(Mutex::new(Vec::<String>::new()));
        let collected_clone = collected.clone();
        let on_text: &OnText = &move |text: &str| {
            collected_clone.lock().expect("lock").push(text.to_string());
        };

        let provider = CascadingProvider::builder()
            .add_tier(
                "cheap",
                FixedProvider::ok("cheap", text_response("hello world", 10)),
            )
            .add_tier(
                "expensive",
                FixedProvider::ok("expensive", text_response("expensive", 50)),
            )
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(resp.text(), "hello world");

        let texts = collected.lock().expect("lock");
        assert_eq!(*texts, vec!["hello world"]);
    }

    // After escalation the final tier is streamed, chunk by chunk.
    #[tokio::test]
    async fn stream_streams_final_tier() {
        let streamed = Arc::new(Mutex::new(Vec::<String>::new()));
        let streamed_clone = streamed.clone();
        let on_text: &OnText = &move |text: &str| {
            streamed_clone.lock().expect("lock").push(text.to_string());
        };

        // Panics if queried via `complete`; streams two chunks otherwise.
        struct StreamingProvider;
        impl LlmProvider for StreamingProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                panic!("final tier with streaming should use stream_complete");
            }
            async fn stream_complete(
                &self,
                _request: CompletionRequest,
                on_text: &OnText,
            ) -> Result<CompletionResponse, Error> {
                on_text("streamed ");
                on_text("response");
                Ok(CompletionResponse {
                    content: vec![ContentBlock::Text {
                        text: "streamed response".into(),
                    }],
                    stop_reason: StopReason::EndTurn,
                    usage: TokenUsage {
                        output_tokens: 20,
                        ..Default::default()
                    },
                    model: None,
                })
            }
        }

        let provider = CascadingProvider::builder()
            .add_tier(
                "cheap",
                FixedProvider::ok("cheap", text_response("dunno", 10)),
            )
            .add_tier("expensive", StreamingProvider)
            .gate(AlwaysReject)
            .build()
            .unwrap();

        let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(resp.text(), "streamed response");
        assert_eq!(resp.model.as_deref(), Some("expensive"));

        let texts = streamed.lock().expect("lock");
        assert_eq!(*texts, vec!["streamed ", "response"]);
    }

    // `model` names the tier that actually answered, not the first tier.
    #[tokio::test]
    async fn response_model_set_to_accepting_tier() {
        let provider = CascadingProvider::builder()
            .add_tier("haiku", FixedProvider::err("haiku"))
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("answer", 10)),
            )
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
    }

    // Building without any tier is an error.
    #[test]
    fn builder_rejects_zero_tiers() {
        let result = CascadingProvider::builder().gate(AlwaysAccept).build();
        assert!(result.is_err());
    }

    // Compile-time check that the provider can be shared across threads.
    #[test]
    fn cascading_provider_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CascadingProvider>();
    }

    // Building without calling `gate` succeeds; only `model_name` is checked here.
    #[test]
    fn builder_defaults_to_heuristic_gate() {
        let provider = CascadingProvider::builder()
            .add_tier("haiku", FixedProvider::ok("haiku", text_response("hi", 10)))
            .build()
            .unwrap();
        assert_eq!(LlmProvider::model_name(&provider), Some("haiku"));
    }

    // With one tier, the streaming path streams from it directly.
    #[tokio::test]
    async fn single_tier_streams_directly() {
        // Panics if queried via `complete`.
        struct StreamOnlyProvider;
        impl LlmProvider for StreamOnlyProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                panic!("single tier should stream directly");
            }
            async fn stream_complete(
                &self,
                _request: CompletionRequest,
                on_text: &OnText,
            ) -> Result<CompletionResponse, Error> {
                on_text("streamed");
                Ok(text_response("streamed", 10))
            }
        }

        let provider = CascadingProvider::builder()
            .add_tier("only", StreamOnlyProvider)
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let on_text: &OnText = &|_| {};
        let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(resp.text(), "streamed");
        assert_eq!(resp.model.as_deref(), Some("only"));
    }

    // When every tier fails, the error surfaced is the last tier's.
    #[tokio::test]
    async fn all_tiers_error_returns_last_error() {
        let provider = CascadingProvider::builder()
            .add_tier("tier1", FixedProvider::err("tier1"))
            .add_tier("tier2", FixedProvider::err("tier2"))
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let err = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("tier2"), "error: {err}");
    }

    // End-to-end with the default gate: the first tier's 2-token reply is
    // below the minimum, so the second tier answers.
    #[tokio::test]
    async fn heuristic_gate_integration_with_cascade() {
        let provider = CascadingProvider::builder()
            .add_tier("haiku", FixedProvider::ok("haiku", text_response("Hi", 2)))
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("detailed answer here", 30)),
            )
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "detailed answer here");
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
    }
}