Skip to main content

heartbit_core/llm/
cascade.rs

//! Cascading provider — tries cheaper models first and escalates to more expensive tiers when
//! the confidence gate rejects a response or a tier returns an error.
2
3use crate::error::Error;
4use crate::llm::types::{CompletionRequest, CompletionResponse, ContentBlock, StopReason};
5use crate::llm::{DynLlmProvider, LlmProvider, OnText};
6
7/// Evaluates whether a cheaper model's response is "good enough"
8/// to avoid escalating to a more expensive tier.
9pub trait ConfidenceGate: Send + Sync {
10    /// Return `true` if the response is good enough to accept without escalating to a higher tier.
11    fn accept(&self, request: &CompletionRequest, response: &CompletionResponse) -> bool;
12}
13
/// Zero-cost heuristic [`ConfidenceGate`] that makes no extra LLM calls.
///
/// Checks are applied in a fixed order and the first decisive one wins:
/// tool calls (accept, if `accept_tool_calls`), `MaxTokens` stop reason
/// (reject, if `escalate_on_max_tokens`), too few output tokens (reject),
/// refusal phrase found (reject). Anything surviving every check is accepted.
///
/// Derives `Debug` (for logging configuration) and `Clone` (so a default gate
/// can be copied and tweaked).
#[derive(Debug, Clone)]
pub struct HeuristicGate {
    /// Minimum output tokens for acceptance (default: 5).
    pub min_output_tokens: u32,
    /// Refusal phrases that trigger escalation, matched case-insensitively as
    /// substrings of the response text.
    pub refusal_patterns: Vec<String>,
    /// Accept responses that include tool calls (default: true).
    pub accept_tool_calls: bool,
    /// Escalate on MaxTokens stop reason (default: true).
    pub escalate_on_max_tokens: bool,
}
25
26impl Default for HeuristicGate {
27    fn default() -> Self {
28        Self {
29            min_output_tokens: 5,
30            refusal_patterns: default_refusal_patterns(),
31            accept_tool_calls: true,
32            escalate_on_max_tokens: true,
33        }
34    }
35}
36
fn default_refusal_patterns() -> Vec<String> {
    // SECURITY (F-LLM-7): short generic phrases such as "I cannot" and
    // "I can't" used to be on this list. They are easy to smuggle into a
    // response through user input, letting an attacker force every cheap-tier
    // answer to escalate and inflate cost. Only longer phrases characteristic
    // of a genuine refusal remain.
    //
    // Matching is a plain case-insensitive substring search in
    // `HeuristicGate::accept`; it does not respect word boundaries. For
    // serious refusal detection, plug in a custom `ConfidenceGate`.
    const PHRASES: [&str; 4] = [
        "I don't have enough information",
        "I'm unable to help",
        "beyond my capabilities",
        "I apologize, but I cannot",
    ];
    PHRASES.iter().copied().map(String::from).collect()
}
55
56impl ConfidenceGate for HeuristicGate {
57    fn accept(&self, _request: &CompletionRequest, response: &CompletionResponse) -> bool {
58        // 1. Accept tool calls immediately
59        if self.accept_tool_calls
60            && response
61                .content
62                .iter()
63                .any(|b| matches!(b, ContentBlock::ToolUse { .. }))
64        {
65            return true;
66        }
67
68        // 2. Reject on MaxTokens
69        if self.escalate_on_max_tokens && response.stop_reason == StopReason::MaxTokens {
70            return false;
71        }
72
73        // 3. Reject short responses
74        if response.usage.output_tokens < self.min_output_tokens {
75            return false;
76        }
77
78        // 4. Reject refusal patterns (case-insensitive)
79        let text = response.text().to_lowercase();
80        for pattern in &self.refusal_patterns {
81            if text.contains(&pattern.to_lowercase()) {
82                return false;
83            }
84        }
85
86        // 5. Accept
87        true
88    }
89}
90
91/// A tier in the cascade: a provider with a human-readable label.
92pub struct CascadeTier {
93    provider: Box<dyn DynLlmProvider>,
94    label: String,
95}
96
97/// Tries cheaper models first, escalating to more expensive tiers
98/// when the confidence gate rejects a response or a tier errors.
99///
100/// The final tier always accepts (no gate check).
101/// Non-final tiers use `complete()` even for `stream_complete()` calls
102/// to avoid streaming tokens that might be discarded.
103pub struct CascadingProvider {
104    tiers: Vec<CascadeTier>,
105    gate: Box<dyn ConfidenceGate>,
106}
107
108impl CascadingProvider {
109    /// Create a new [`CascadingProviderBuilder`].
110    pub fn builder() -> CascadingProviderBuilder {
111        CascadingProviderBuilder {
112            tiers: Vec::new(),
113            gate: None,
114        }
115    }
116}
117
118impl LlmProvider for CascadingProvider {
119    fn model_name(&self) -> Option<&str> {
120        self.tiers.first().map(|t| t.label.as_str())
121    }
122
123    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, Error> {
124        for (i, tier) in self.tiers.iter().enumerate() {
125            let is_last = i == self.tiers.len() - 1;
126            match tier.provider.complete(request.clone()).await {
127                Ok(mut response) => {
128                    if is_last || self.gate.accept(&request, &response) {
129                        response.model = Some(tier.label.clone());
130                        tracing::info!(
131                            tier = %tier.label,
132                            is_last,
133                            output_tokens = response.usage.output_tokens,
134                            "cascade: accepted response"
135                        );
136                        return Ok(response);
137                    }
138                    tracing::info!(
139                        from = %tier.label,
140                        to = %self.tiers[i + 1].label,
141                        "cascade: gate rejected, escalating"
142                    );
143                }
144                Err(e) if is_last => return Err(e),
145                Err(e) => {
146                    tracing::warn!(
147                        tier = %tier.label,
148                        error = %e,
149                        "cascade: tier failed, escalating"
150                    );
151                }
152            }
153        }
154        unreachable!("cascade must have at least one tier")
155    }
156
157    async fn stream_complete(
158        &self,
159        request: CompletionRequest,
160        on_text: &OnText,
161    ) -> Result<CompletionResponse, Error> {
162        // Single tier: stream directly
163        if self.tiers.len() == 1 {
164            let mut resp = self.tiers[0]
165                .provider
166                .stream_complete(request, on_text)
167                .await?;
168            resp.model = Some(self.tiers[0].label.clone());
169            return Ok(resp);
170        }
171
172        // Non-final tiers: use complete() to avoid streaming tokens we might discard
173        for (i, tier) in self.tiers.iter().enumerate() {
174            let is_last = i == self.tiers.len() - 1;
175            if is_last {
176                let mut resp = tier.provider.stream_complete(request, on_text).await?;
177                resp.model = Some(tier.label.clone());
178                return Ok(resp);
179            }
180            match tier.provider.complete(request.clone()).await {
181                Ok(mut response) if self.gate.accept(&request, &response) => {
182                    response.model = Some(tier.label.clone());
183                    tracing::info!(
184                        tier = %tier.label,
185                        output_tokens = response.usage.output_tokens,
186                        "cascade: cheap tier accepted (stream path)"
187                    );
188                    // Emit text as a single chunk for streaming callers
189                    let text = response.text();
190                    if !text.is_empty() {
191                        on_text(&text);
192                    }
193                    return Ok(response);
194                }
195                Ok(_) => {
196                    tracing::info!(
197                        from = %tier.label,
198                        to = %self.tiers[i + 1].label,
199                        "cascade: gate rejected, escalating"
200                    );
201                }
202                Err(e) => {
203                    tracing::warn!(
204                        tier = %tier.label,
205                        error = %e,
206                        "cascade: tier failed, escalating"
207                    );
208                }
209            }
210        }
211        unreachable!("cascade stream_complete exhausted all tiers without returning")
212    }
213}
214
215/// Builder for [`CascadingProvider`].
216pub struct CascadingProviderBuilder {
217    tiers: Vec<CascadeTier>,
218    gate: Option<Box<dyn ConfidenceGate>>,
219}
220
221impl CascadingProviderBuilder {
222    /// Add a tier (cheapest first, most expensive last).
223    pub fn add_tier(
224        mut self,
225        label: impl Into<String>,
226        provider: impl LlmProvider + 'static,
227    ) -> Self {
228        self.tiers.push(CascadeTier {
229            provider: Box::new(provider),
230            label: label.into(),
231        });
232        self
233    }
234
235    /// Set the confidence gate. Defaults to [`HeuristicGate`] with default settings.
236    pub fn gate(mut self, gate: impl ConfidenceGate + 'static) -> Self {
237        self.gate = Some(Box::new(gate));
238        self
239    }
240
241    /// Build the cascading provider.
242    pub fn build(self) -> Result<CascadingProvider, Error> {
243        if self.tiers.is_empty() {
244            return Err(Error::Config(
245                "CascadingProvider requires at least one tier".into(),
246            ));
247        }
248        Ok(CascadingProvider {
249            tiers: self.tiers,
250            gate: self
251                .gate
252                .unwrap_or_else(|| Box::new(HeuristicGate::default())),
253        })
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{ContentBlock, Message, StopReason, TokenUsage};
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    fn text_response(text: &str, output_tokens: u32) -> CompletionResponse {
        CompletionResponse {
            content: vec![ContentBlock::Text { text: text.into() }],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage {
                output_tokens,
                ..Default::default()
            },
            model: None,
        }
    }

    fn tool_response() -> CompletionResponse {
        CompletionResponse {
            content: vec![ContentBlock::ToolUse {
                id: "call-1".into(),
                name: "search".into(),
                input: json!({"q": "rust"}),
            }],
            stop_reason: StopReason::ToolUse,
            usage: TokenUsage {
                output_tokens: 20,
                ..Default::default()
            },
            model: None,
        }
    }

    fn max_tokens_response() -> CompletionResponse {
        CompletionResponse {
            content: vec![ContentBlock::Text {
                text: "truncated...".into(),
            }],
            stop_reason: StopReason::MaxTokens,
            usage: TokenUsage {
                output_tokens: 100,
                ..Default::default()
            },
            model: None,
        }
    }

    fn test_request() -> CompletionRequest {
        CompletionRequest {
            system: String::new(),
            messages: vec![Message::user("hello")],
            tools: vec![],
            max_tokens: 1024,
            tool_choice: None,
            reasoning_effort: None,
        }
    }

    // -- HeuristicGate tests --

    #[test]
    fn heuristic_gate_accepts_normal_response() {
        let gate = HeuristicGate::default();
        let req = test_request();
        let resp = text_response("Salut Pascal! Comment vas-tu?", 10);
        assert!(gate.accept(&req, &resp));
    }

    #[test]
    fn heuristic_gate_rejects_short_response() {
        let gate = HeuristicGate::default();
        let req = test_request();
        let resp = text_response("Hi", 2);
        assert!(!gate.accept(&req, &resp));
    }

    #[test]
    fn heuristic_gate_rejects_refusal_patterns() {
        let gate = HeuristicGate::default();
        let req = test_request();

        // SECURITY (F-LLM-7): pattern list trimmed to less-injectable phrases.
        // The previous broad list ("I cannot", "I can't") was injectable
        // through user input, forcing escalation to expensive tiers.
        let patterns = [
            "I don't have enough information to answer.",
            "I'm unable to help with that request.",
            "That topic is beyond my capabilities.",
            "I apologize, but I cannot help with this.",
        ];
        for text in patterns {
            let resp = text_response(text, 20);
            assert!(!gate.accept(&req, &resp), "should reject: {text}");
        }
    }

    #[test]
    fn heuristic_gate_accepts_tool_calls() {
        let gate = HeuristicGate::default();
        let req = test_request();
        let resp = tool_response();
        assert!(gate.accept(&req, &resp));
    }

    #[test]
    fn heuristic_gate_rejects_max_tokens() {
        let gate = HeuristicGate::default();
        let req = test_request();
        let resp = max_tokens_response();
        assert!(!gate.accept(&req, &resp));
    }

    #[test]
    fn heuristic_gate_default_patterns() {
        let gate = HeuristicGate::default();
        assert_eq!(gate.min_output_tokens, 5);
        assert!(gate.accept_tool_calls);
        assert!(gate.escalate_on_max_tokens);
        // SECURITY (F-LLM-7): trimmed list — at least one pattern still
        // present. The prior threshold (>= 7) encoded the wide list that
        // was vulnerable to user-text injection.
        assert!(!gate.refusal_patterns.is_empty());
    }

    #[test]
    fn heuristic_gate_case_insensitive_refusal() {
        let gate = HeuristicGate::default();
        let req = test_request();
        // SECURITY (F-LLM-7): "I'M UNABLE TO HELP" should match
        // "I'm unable to help" via lowercase comparison.
        let resp = text_response("I'M UNABLE TO HELP with that", 10);
        assert!(!gate.accept(&req, &resp));
    }

    // -- Mock providers for CascadingProvider tests --

    /// Provider returning a fixed result and counting how often it is called.
    ///
    /// The counter is shared through an `Arc` so tests can still read it after
    /// the provider has been moved into a cascade tier.
    struct FixedProvider {
        label: &'static str,
        response: Result<CompletionResponse, Error>,
        call_count: Arc<AtomicUsize>,
    }

    impl FixedProvider {
        fn ok(label: &'static str, response: CompletionResponse) -> Self {
            Self {
                label,
                response: Ok(response),
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        fn err(label: &'static str) -> Self {
            Self {
                label,
                response: Err(Error::Api {
                    status: 500,
                    message: "tier error".into(),
                }),
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Handle to this provider's call counter, valid after the provider
        /// is moved into the cascade.
        fn calls(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.call_count)
        }
    }

    impl LlmProvider for FixedProvider {
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            match &self.response {
                Ok(r) => Ok(r.clone()),
                Err(e) => Err(Error::Api {
                    status: match e {
                        Error::Api { status, .. } => *status,
                        _ => 500,
                    },
                    message: format!("{} error", self.label),
                }),
            }
        }

        async fn stream_complete(
            &self,
            _request: CompletionRequest,
            on_text: &OnText,
        ) -> Result<CompletionResponse, Error> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            match &self.response {
                Ok(r) => {
                    let text = r.text();
                    if !text.is_empty() {
                        on_text(&text);
                    }
                    Ok(r.clone())
                }
                Err(_) => Err(Error::Api {
                    status: 500,
                    message: format!("{} error", self.label),
                }),
            }
        }

        fn model_name(&self) -> Option<&str> {
            Some(self.label)
        }
    }

    /// Gate that always accepts.
    struct AlwaysAccept;
    impl ConfidenceGate for AlwaysAccept {
        fn accept(&self, _req: &CompletionRequest, _resp: &CompletionResponse) -> bool {
            true
        }
    }

    /// Gate that always rejects.
    struct AlwaysReject;
    impl ConfidenceGate for AlwaysReject {
        fn accept(&self, _req: &CompletionRequest, _resp: &CompletionResponse) -> bool {
            false
        }
    }

    // -- CascadingProvider tests --

    #[tokio::test]
    async fn single_tier_delegates_directly() {
        let provider = CascadingProvider::builder()
            .add_tier(
                "haiku",
                FixedProvider::ok("haiku", text_response("hello", 10)),
            )
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "hello");
        assert_eq!(resp.model.as_deref(), Some("haiku"));
    }

    #[tokio::test]
    async fn two_tier_accepts_cheap_when_gate_passes() {
        let cheap = FixedProvider::ok("haiku", text_response("Salut!", 10));
        let expensive = FixedProvider::ok("sonnet", text_response("expensive", 50));
        let cheap_calls = cheap.calls();
        let expensive_calls = expensive.calls();

        let provider = CascadingProvider::builder()
            .add_tier("haiku", cheap)
            .add_tier("sonnet", expensive)
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "Salut!");
        assert_eq!(resp.model.as_deref(), Some("haiku"));
        // The expensive tier must never be called once the cheap one is accepted.
        assert_eq!(cheap_calls.load(Ordering::SeqCst), 1);
        assert_eq!(expensive_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn two_tier_escalates_when_gate_rejects() {
        let cheap = FixedProvider::ok("haiku", text_response("dunno", 10));
        let expensive = FixedProvider::ok("sonnet", text_response("great answer", 50));
        let cheap_calls = cheap.calls();
        let expensive_calls = expensive.calls();

        let provider = CascadingProvider::builder()
            .add_tier("haiku", cheap)
            .add_tier("sonnet", expensive)
            .gate(AlwaysReject)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        // Gate rejected haiku, so sonnet should be used.
        // Final tier always accepts regardless of gate.
        assert_eq!(resp.text(), "great answer");
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
        assert_eq!(cheap_calls.load(Ordering::SeqCst), 1);
        assert_eq!(expensive_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn three_tier_skips_erroring_tier() {
        let opus = FixedProvider::ok("opus", text_response("expensive", 50));
        let opus_calls = opus.calls();

        let provider = CascadingProvider::builder()
            .add_tier("haiku", FixedProvider::err("haiku"))
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("mid", 10)),
            )
            .add_tier("opus", opus)
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "mid");
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
        assert_eq!(opus_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn final_tier_always_accepts() {
        // AlwaysReject gate, but final tier should still be returned
        let provider = CascadingProvider::builder()
            .add_tier(
                "haiku",
                FixedProvider::ok("haiku", text_response("cheap", 10)),
            )
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("final", 50)),
            )
            .gate(AlwaysReject)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "final");
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
    }

    #[tokio::test]
    async fn stream_uses_complete_for_non_final_tiers() {
        // Track which method was called. We use a special provider that panics on
        // stream_complete for the cheap tier (non-final tiers should use complete()).
        struct CompleteOnlyProvider;
        impl LlmProvider for CompleteOnlyProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                Ok(text_response("cheap answer", 10))
            }
            async fn stream_complete(
                &self,
                _request: CompletionRequest,
                _on_text: &OnText,
            ) -> Result<CompletionResponse, Error> {
                panic!("non-final tier should not call stream_complete");
            }
        }

        let expensive = FixedProvider::ok("expensive", text_response("expensive", 50));
        let expensive_calls = expensive.calls();

        let provider = CascadingProvider::builder()
            .add_tier("cheap", CompleteOnlyProvider)
            .add_tier("expensive", expensive)
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let on_text: &OnText = &|_| {};
        let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(resp.text(), "cheap answer");
        assert_eq!(resp.model.as_deref(), Some("cheap"));
        assert_eq!(expensive_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn stream_emits_text_when_cheap_accepted() {
        let collected = Arc::new(Mutex::new(Vec::<String>::new()));
        let collected_clone = collected.clone();
        let on_text: &OnText = &move |text: &str| {
            collected_clone.lock().expect("lock").push(text.to_string());
        };

        let provider = CascadingProvider::builder()
            .add_tier(
                "cheap",
                FixedProvider::ok("cheap", text_response("hello world", 10)),
            )
            .add_tier(
                "expensive",
                FixedProvider::ok("expensive", text_response("expensive", 50)),
            )
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(resp.text(), "hello world");

        let texts = collected.lock().expect("lock");
        assert_eq!(*texts, vec!["hello world"]);
    }

    #[tokio::test]
    async fn stream_streams_final_tier() {
        // When gate rejects cheap tier, final tier should use stream_complete
        let streamed = Arc::new(Mutex::new(Vec::<String>::new()));
        let streamed_clone = streamed.clone();
        let on_text: &OnText = &move |text: &str| {
            streamed_clone.lock().expect("lock").push(text.to_string());
        };

        struct StreamingProvider;
        impl LlmProvider for StreamingProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                panic!("final tier with streaming should use stream_complete");
            }
            async fn stream_complete(
                &self,
                _request: CompletionRequest,
                on_text: &OnText,
            ) -> Result<CompletionResponse, Error> {
                on_text("streamed ");
                on_text("response");
                Ok(CompletionResponse {
                    content: vec![ContentBlock::Text {
                        text: "streamed response".into(),
                    }],
                    stop_reason: StopReason::EndTurn,
                    usage: TokenUsage {
                        output_tokens: 20,
                        ..Default::default()
                    },
                    model: None,
                })
            }
        }

        let cheap = FixedProvider::ok("cheap", text_response("dunno", 10));
        let cheap_calls = cheap.calls();

        let provider = CascadingProvider::builder()
            .add_tier("cheap", cheap)
            .add_tier("expensive", StreamingProvider)
            .gate(AlwaysReject)
            .build()
            .unwrap();

        let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(resp.text(), "streamed response");
        assert_eq!(resp.model.as_deref(), Some("expensive"));
        assert_eq!(cheap_calls.load(Ordering::SeqCst), 1);

        // The rejected cheap answer must not leak into the stream.
        let texts = streamed.lock().expect("lock");
        assert_eq!(*texts, vec!["streamed ", "response"]);
    }

    #[tokio::test]
    async fn response_model_set_to_accepting_tier() {
        let provider = CascadingProvider::builder()
            .add_tier("haiku", FixedProvider::err("haiku"))
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("answer", 10)),
            )
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
    }

    #[test]
    fn builder_rejects_zero_tiers() {
        let result = CascadingProvider::builder().gate(AlwaysAccept).build();
        assert!(result.is_err());
    }

    #[test]
    fn cascading_provider_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CascadingProvider>();
    }

    #[test]
    fn builder_defaults_to_heuristic_gate() {
        let provider = CascadingProvider::builder()
            .add_tier("haiku", FixedProvider::ok("haiku", text_response("hi", 10)))
            .build()
            .unwrap();
        // Should build without explicit gate
        assert_eq!(LlmProvider::model_name(&provider), Some("haiku"));
    }

    #[tokio::test]
    async fn single_tier_streams_directly() {
        // Single tier should use stream_complete, not complete
        struct StreamOnlyProvider;
        impl LlmProvider for StreamOnlyProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                panic!("single tier should stream directly");
            }
            async fn stream_complete(
                &self,
                _request: CompletionRequest,
                on_text: &OnText,
            ) -> Result<CompletionResponse, Error> {
                on_text("streamed");
                Ok(text_response("streamed", 10))
            }
        }

        let provider = CascadingProvider::builder()
            .add_tier("only", StreamOnlyProvider)
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let on_text: &OnText = &|_| {};
        let resp = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(resp.text(), "streamed");
        assert_eq!(resp.model.as_deref(), Some("only"));
    }

    #[tokio::test]
    async fn all_tiers_error_returns_last_error() {
        let first = FixedProvider::err("tier1");
        let second = FixedProvider::err("tier2");
        let first_calls = first.calls();
        let second_calls = second.calls();

        let provider = CascadingProvider::builder()
            .add_tier("tier1", first)
            .add_tier("tier2", second)
            .gate(AlwaysAccept)
            .build()
            .unwrap();

        let err = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("tier2"), "error: {err}");
        assert_eq!(first_calls.load(Ordering::SeqCst), 1);
        assert_eq!(second_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn heuristic_gate_integration_with_cascade() {
        // Cheap gives short answer → gate rejects → escalates to expensive
        let provider = CascadingProvider::builder()
            .add_tier("haiku", FixedProvider::ok("haiku", text_response("Hi", 2)))
            .add_tier(
                "sonnet",
                FixedProvider::ok("sonnet", text_response("detailed answer here", 30)),
            )
            // Default HeuristicGate with min_output_tokens=5
            .build()
            .unwrap();

        let resp = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(resp.text(), "detailed answer here");
        assert_eq!(resp.model.as_deref(), Some("sonnet"));
    }
}