Skip to main content

agent_sdk_providers/
refresh.rs

//! Provider wrapper that refreshes credentials on 401 and retries once.
//!
//! [`RefreshingProvider`] wraps any [`LlmProvider`] and adds a host-driven
//! credential refresh step: when the inner provider reports an unauthorized
//! error (HTTP 401, expired OAuth token, invalid API key, etc.), the wrapper
//! calls a host-supplied async callback to rebuild the inner provider with
//! fresh credentials and retries the original request once.
//!
//! This is the generic form of the per-provider refresh wrappers that
//! OAuth-backed hosts would otherwise copy across every provider they use.
//!
//! # Example
//!
//! ```no_run
//! # use std::sync::Arc;
//! # use anyhow::Result;
//! # use agent_sdk_providers::{LlmProvider, RefreshingProvider};
//! # async fn demo<P: LlmProvider + Clone + 'static>(initial: P) -> Result<()> {
//! let refreshed_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
//! let counter = Arc::clone(&refreshed_count);
//! let wrapped = RefreshingProvider::new(initial.clone(), move || {
//!     let counter = Arc::clone(&counter);
//!     let provider = initial.clone();
//!     async move {
//!         counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
//!         Ok(provider)
//!     }
//! });
//! # let _ = wrapped;
//! # Ok(())
//! # }
//! ```
//!
//! ## Retry semantics
//!
//! For non-streaming [`chat`](LlmProvider::chat), the retry happens when the
//! first call resolves to [`ChatOutcome::InvalidRequest`] with a 401-looking
//! message (as classified by [`is_unauthorized_error`]). The second call
//! replaces the first — callers only see the final outcome.
//!
//! For [`chat_stream`](LlmProvider::chat_stream), retry happens only when the
//! 401 arrives before any content delta is forwarded. If a
//! [`StreamDelta::TextDelta`], [`StreamDelta::ThinkingDelta`], tool-call
//! delta, thinking signature, or redacted-thinking block has already been
//! yielded to the consumer, the error is forwarded as-is — retrying would
//! duplicate partial output.
//!
//! At most one retry happens per call, whether streaming or not. If the
//! retried call fails in the same way, the wrapper surfaces the second
//! error unchanged.

51use std::future::Future;
52use std::sync::Arc;
53
54use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ThinkingConfig};
55use anyhow::Result;
56use async_trait::async_trait;
57use futures::StreamExt;
58use tokio::sync::Mutex;
59
60use crate::model_capabilities::ModelCapabilities;
61use crate::provider::{LlmProvider, StructuredOutputSupport};
62use crate::streaming::{StreamBox, StreamDelta};
63
/// Wraps a provider with host-driven credential refresh on 401.
///
/// The live provider is stored behind `Arc<Mutex<P>>` so it can be replaced
/// under the lock when the refresh callback produces a new provider. Clones
/// of the wrapper share the live provider and the refresh callback (both are
/// `Arc`s), so a refresh performed through one clone is seen by all of them.
/// Each clone does, however, own its own copy of `template`, `model` and
/// `thinking`, so cloning costs one `P::clone` on top of the `Arc` bumps.
///
/// Metadata (`model`, `provider`, `configured_thinking`) and capability shaping
/// (`capabilities`, `default_max_tokens`, `validate_thinking_config`,
/// `structured_output_support`) are captured from the initial provider at
/// construction time and assumed constant across refreshes (the refresh
/// callback rebuilds the same provider shape with a fresh token, not a
/// different model).
pub struct RefreshingProvider<P, F> {
    /// The provider currently used for `chat`/`chat_stream`; replaced
    /// wholesale after a successful refresh.
    inner: Arc<Mutex<P>>,
    /// Host-supplied callback that builds a provider with fresh credentials.
    refresh: Arc<F>,
    /// A clone of the initial provider kept solely for delegating the
    /// **synchronous** capability methods (which never touch credentials), so
    /// wrapping a provider never changes how requests are shaped. The async
    /// `chat`/`chat_stream` paths still go through the refreshable `inner`.
    template: P,
    /// Model name read from the initial provider.
    model: String,
    /// Provider identifier read from the initial provider.
    provider: &'static str,
    /// Thinking configuration read from the initial provider.
    thinking: Option<ThinkingConfig>,
}
88
89impl<P: Clone, F> Clone for RefreshingProvider<P, F> {
90    fn clone(&self) -> Self {
91        Self {
92            inner: Arc::clone(&self.inner),
93            refresh: Arc::clone(&self.refresh),
94            template: self.template.clone(),
95            model: self.model.clone(),
96            provider: self.provider,
97            thinking: self.thinking.clone(),
98        }
99    }
100}
101
102impl<P, F, Fut> RefreshingProvider<P, F>
103where
104    P: LlmProvider + Clone + 'static,
105    F: Fn() -> Fut + Send + Sync + 'static,
106    Fut: Future<Output = Result<P>> + Send + 'static,
107{
108    /// Build a wrapper from an initial provider and a refresh callback.
109    ///
110    /// The refresh callback is invoked each time the inner provider emits a
111    /// 401 response. It must be idempotent and safe to call concurrently.
112    /// The callback should return a fully-built provider ready to use;
113    /// typically it reads fresh credentials from its auth store and calls
114    /// the inner provider's constructor.
115    #[must_use]
116    pub fn new(inner: P, refresh: F) -> Self {
117        let model = inner.model().to_string();
118        let provider = inner.provider();
119        let thinking = inner.configured_thinking().cloned();
120        let template = inner.clone();
121        Self {
122            inner: Arc::new(Mutex::new(inner)),
123            refresh: Arc::new(refresh),
124            template,
125            model,
126            provider,
127            thinking,
128        }
129    }
130
131    async fn snapshot(&self) -> P {
132        self.inner.lock().await.clone()
133    }
134
135    async fn run_refresh(&self) -> Result<()> {
136        let fresh = (self.refresh)().await?;
137        *self.inner.lock().await = fresh;
138        Ok(())
139    }
140}
141
/// Classify a provider error message as a 401 / unauthorized condition.
///
/// Returns `true` when the message looks like the provider rejected the
/// request because of missing, invalid, or expired credentials. Detection is
/// case-insensitive and matches the error-body shapes emitted by the
/// first-party providers in this crate. Hosts that implement their own retry
/// logic on top of [`LlmProvider`] can gate on the same helper.
///
/// The numeric status markers (` 401` and `status=401`) only count when the
/// `401` is not immediately followed by another digit. As a result, messages
/// that merely contain a longer number starting with `401` (for example
/// `"prompt is too long: 40123 tokens"`) are not misread as auth failures and
/// do not trigger a pointless credential refresh.
#[must_use]
pub fn is_unauthorized_error(message: &str) -> bool {
    // Markers that directly precede a bare `401` status code.
    const STATUS_401_MARKERS: [&str; 2] = [" 401", "status=401"];
    // Substrings that identify an auth failure wherever they appear.
    const AUTH_KEYWORDS: [&str; 5] = [
        "unauthorized",
        "authentication",
        "token_expired",
        "invalid api key",
        "invalid_api_key",
    ];

    let lower = message.to_ascii_lowercase();
    STATUS_401_MARKERS
        .iter()
        .any(|&marker| contains_standalone_status(&lower, marker))
        || AUTH_KEYWORDS.iter().any(|&keyword| lower.contains(keyword))
}

/// Return `true` if some occurrence of `marker` in `haystack` is not
/// immediately followed by an ASCII digit (end of string counts as a match).
fn contains_standalone_status(haystack: &str, marker: &str) -> bool {
    haystack.match_indices(marker).any(|(start, _)| {
        !haystack
            .as_bytes()
            .get(start + marker.len())
            .is_some_and(u8::is_ascii_digit)
    })
}
160
161#[async_trait]
162impl<P, F, Fut> LlmProvider for RefreshingProvider<P, F>
163where
164    P: LlmProvider + Clone + 'static,
165    F: Fn() -> Fut + Send + Sync + 'static,
166    Fut: Future<Output = Result<P>> + Send + 'static,
167{
168    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
169        let outcome = self.snapshot().await.chat(request.clone()).await?;
170        if let ChatOutcome::InvalidRequest(message) = &outcome
171            && is_unauthorized_error(message)
172        {
173            match self.run_refresh().await {
174                Ok(()) => return self.snapshot().await.chat(request).await,
175                Err(error) => {
176                    log::warn!("RefreshingProvider refresh after 401 failed: {error:#}");
177                }
178            }
179        }
180        Ok(outcome)
181    }
182
183    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
184        let this = self.clone();
185        Box::pin(async_stream::stream! {
186            let mut refreshed = false;
187            'attempts: loop {
188                let provider = this.snapshot().await;
189                let mut stream = provider.chat_stream(request.clone());
190                let mut saw_output = false;
191
192                while let Some(item) = stream.next().await {
193                    match item {
194                        Ok(StreamDelta::Error { message, kind })
195                            if !saw_output
196                                && !refreshed
197                                && is_unauthorized_error(&message) =>
198                        {
199                            match this.run_refresh().await {
200                                Ok(()) => {
201                                    refreshed = true;
202                                    continue 'attempts;
203                                }
204                                Err(error) => {
205                                    log::warn!(
206                                        "RefreshingProvider refresh after streaming 401 failed: {error:#}"
207                                    );
208                                    yield Ok(StreamDelta::Error { message, kind });
209                                    return;
210                                }
211                            }
212                        }
213                        Ok(delta) => {
214                            if matches!(
215                                delta,
216                                StreamDelta::TextDelta { .. }
217                                    | StreamDelta::ThinkingDelta { .. }
218                                    | StreamDelta::ToolUseStart { .. }
219                                    | StreamDelta::ToolInputDelta { .. }
220                                    | StreamDelta::SignatureDelta { .. }
221                                    | StreamDelta::RedactedThinking { .. }
222                            ) {
223                                saw_output = true;
224                            }
225                            let done = matches!(delta, StreamDelta::Done { .. });
226                            yield Ok(delta);
227                            if done {
228                                return;
229                            }
230                        }
231                        Err(error)
232                            if !saw_output
233                                && !refreshed
234                                && is_unauthorized_error(&error.to_string()) =>
235                        {
236                            match this.run_refresh().await {
237                                Ok(()) => {
238                                    refreshed = true;
239                                    continue 'attempts;
240                                }
241                                Err(refresh_error) => {
242                                    log::warn!(
243                                        "RefreshingProvider refresh after stream failure failed: {refresh_error:#}"
244                                    );
245                                    yield Err(error);
246                                    return;
247                                }
248                            }
249                        }
250                        Err(error) => {
251                            yield Err(error);
252                            return;
253                        }
254                    }
255                }
256                return;
257            }
258        })
259    }
260
261    fn model(&self) -> &str {
262        &self.model
263    }
264
265    fn provider(&self) -> &'static str {
266        self.provider
267    }
268
269    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
270        self.thinking.as_ref()
271    }
272
273    // Delegate capability shaping to the wrapped provider so that wrapping
274    // never silently changes request shaping (e.g. losing Vertex's max-token
275    // clamp or a provider's adaptive-thinking validation). These methods are
276    // synchronous and credential-independent, so they go through the captured
277    // `template` rather than locking the async-refreshable `inner`.
278
279    fn capabilities(&self) -> Option<&'static ModelCapabilities> {
280        self.template.capabilities()
281    }
282
283    fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
284        self.template.validate_thinking_config(thinking)
285    }
286
287    fn default_max_tokens(&self) -> u32 {
288        self.template.default_max_tokens()
289    }
290
291    fn structured_output_support(&self) -> StructuredOutputSupport {
292        self.template.structured_output_support()
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::VecDeque;
    use std::sync::Mutex as StdMutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use agent_sdk_foundation::llm::{ChatResponse, ContentBlock, StopReason, Usage};
    use anyhow::Context;

    use crate::streaming::StreamErrorKind;

    /// One scripted item in a mock stream batch.
    #[derive(Clone)]
    enum MockStreamItem {
        /// Yielded to the consumer as `Ok(delta)`.
        Ok(StreamDelta),
        /// Yielded as `Err(anyhow!(msg))`; the mock stream then ends.
        Err(String),
    }

    /// Scripted provider: `chat` pops from `outcomes` and `chat_stream` pops a
    /// batch from `stream_batches`. All state is behind `Arc`, so clones
    /// (including those handed back by refresh callbacks) share the same
    /// queues and call counters.
    #[derive(Clone)]
    struct MockProvider {
        model: String,
        provider_name: &'static str,
        outcomes: Arc<StdMutex<VecDeque<ChatOutcome>>>,
        stream_batches: Arc<StdMutex<VecDeque<Vec<MockStreamItem>>>>,
        chat_calls: Arc<AtomicUsize>,
        stream_calls: Arc<AtomicUsize>,
    }

    impl MockProvider {
        fn new() -> Self {
            Self {
                model: "mock-model".to_string(),
                provider_name: "mock",
                outcomes: Arc::new(StdMutex::new(VecDeque::new())),
                stream_batches: Arc::new(StdMutex::new(VecDeque::new())),
                chat_calls: Arc::new(AtomicUsize::new(0)),
                stream_calls: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Append an outcome for a future `chat` call (FIFO order).
        fn queue_chat(&self, outcome: ChatOutcome) -> Result<()> {
            self.outcomes
                .lock()
                .ok()
                .context("outcomes lock poisoned")?
                .push_back(outcome);
            Ok(())
        }

        /// Append a batch of items for a future `chat_stream` call (FIFO order).
        fn queue_stream(&self, batch: Vec<MockStreamItem>) -> Result<()> {
            self.stream_batches
                .lock()
                .ok()
                .context("stream_batches lock poisoned")?
                .push_back(batch);
            Ok(())
        }

        /// Number of `chat` calls made on this mock and all of its clones.
        fn chat_call_count(&self) -> usize {
            self.chat_calls.load(Ordering::SeqCst)
        }

        /// Number of `chat_stream` calls made on this mock and all of its clones.
        fn stream_call_count(&self) -> usize {
            self.stream_calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            self.chat_calls.fetch_add(1, Ordering::SeqCst);
            let mut queue = self
                .outcomes
                .lock()
                .ok()
                .context("outcomes lock poisoned")?;
            // Running out of scripted outcomes is a test bug, reported as Err.
            queue.pop_front().context("MockProvider: no queued outcome")
        }

        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            self.stream_calls.fetch_add(1, Ordering::SeqCst);
            // With no batch queued (or a poisoned lock), emit a single error item.
            let batch: Vec<MockStreamItem> = self
                .stream_batches
                .lock()
                .ok()
                .and_then(|mut q| q.pop_front())
                .unwrap_or_else(|| vec![MockStreamItem::Err("no queued stream batch".into())]);
            Box::pin(async_stream::stream! {
                for item in batch {
                    match item {
                        MockStreamItem::Ok(delta) => yield Ok(delta),
                        MockStreamItem::Err(msg) => {
                            yield Err(anyhow::anyhow!(msg));
                            return;
                        }
                    }
                }
            })
        }

        fn model(&self) -> &str {
            &self.model
        }

        fn provider(&self) -> &'static str {
            self.provider_name
        }

        // Sentinel capability overrides used to prove the wrapper delegates
        // rather than falling back to trait defaults.
        fn default_max_tokens(&self) -> u32 {
            32_000
        }

        fn structured_output_support(&self) -> StructuredOutputSupport {
            StructuredOutputSupport::Native
        }

        fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
            if thinking.is_some() {
                Err(anyhow::anyhow!("mock rejects thinking"))
            } else {
                Ok(())
            }
        }
    }

    /// A minimal successful response used by scripted `chat` outcomes.
    fn success_response() -> ChatResponse {
        ChatResponse {
            id: "msg_test".to_string(),
            content: vec![ContentBlock::Text {
                text: "ok".to_string(),
            }],
            model: "mock-model".to_string(),
            stop_reason: Some(StopReason::EndTurn),
            usage: Usage {
                input_tokens: 1,
                output_tokens: 1,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        }
    }

    /// A request with no messages; the mocks ignore request contents.
    fn empty_request() -> ChatRequest {
        ChatRequest {
            system: String::new(),
            messages: Vec::new(),
            tools: None,
            max_tokens: 100,
            max_tokens_explicit: false,
            session_id: None,
            cached_content: None,
            thinking: None,
            tool_choice: None,
            response_format: None,
        }
    }

    // Boxed callback types, so both helper constructors below produce the
    // same concrete `Wrapped` type.
    type BoxedFut = std::pin::Pin<Box<dyn Future<Output = Result<MockProvider>> + Send>>;
    type RefreshFn = Box<dyn Fn() -> BoxedFut + Send + Sync + 'static>;
    type Wrapped = RefreshingProvider<MockProvider, RefreshFn>;

    /// Wrap `mock` with a refresh callback that bumps `counter` and always
    /// succeeds, returning a clone that shares the mock's queues.
    fn wrap_success(mock: &MockProvider, counter: &Arc<AtomicUsize>) -> Wrapped {
        let counter = Arc::clone(counter);
        let template = mock.clone();
        let cb: RefreshFn = Box::new(move || {
            counter.fetch_add(1, Ordering::SeqCst);
            let provider = template.clone();
            Box::pin(async move { Ok(provider) })
        });
        RefreshingProvider::new(mock.clone(), cb)
    }

    /// Wrap `mock` with a refresh callback that bumps `counter` and always
    /// fails with `error`.
    fn wrap_failure(
        mock: &MockProvider,
        counter: &Arc<AtomicUsize>,
        error: &'static str,
    ) -> Wrapped {
        let counter = Arc::clone(counter);
        let cb: RefreshFn = Box::new(move || {
            counter.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Err(anyhow::anyhow!(error)) })
        });
        RefreshingProvider::new(mock.clone(), cb)
    }

    // Test 0: capability methods are delegated to the wrapped provider.
    #[test]
    fn wrapper_delegates_capability_overrides_to_inner() {
        let mock = MockProvider::new();
        let refresh_count = Arc::new(AtomicUsize::new(0));
        let wrapped = wrap_success(&mock, &refresh_count);

        // Without delegation these would fall back to the trait defaults
        // (4096 max tokens and an `Ok` thinking validation), masking
        // per-provider clamps and validation. NOTE(review): `Native` may
        // coincide with the trait default, so that assertion proves less.
        assert_eq!(wrapped.default_max_tokens(), 32_000);
        assert_eq!(
            wrapped.structured_output_support(),
            StructuredOutputSupport::Native
        );
        assert!(
            wrapped
                .validate_thinking_config(Some(&ThinkingConfig::adaptive()))
                .is_err()
        );
        assert!(wrapped.validate_thinking_config(None).is_ok());
    }

    // Test 1: classifier accepts auth-failure shapes and rejects others.
    #[test]
    fn is_unauthorized_error_matches_expected_strings() {
        assert!(is_unauthorized_error("HTTP 401"));
        assert!(is_unauthorized_error("status=401 Unauthorized"));
        assert!(is_unauthorized_error("Invalid API key"));
        assert!(is_unauthorized_error("invalid_api_key"));
        assert!(is_unauthorized_error("token_expired"));
        assert!(is_unauthorized_error("Authentication failed"));
        assert!(is_unauthorized_error("UNAUTHORIZED"));

        assert!(!is_unauthorized_error("rate limited"));
        assert!(!is_unauthorized_error("network error"));
        assert!(!is_unauthorized_error(""));
        assert!(!is_unauthorized_error("internal server error"));
    }

    // Test 2: a successful chat never invokes the refresh callback.
    #[tokio::test]
    async fn chat_successful_pass_through_does_not_refresh() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_chat(ChatOutcome::Success(success_response()))?;

        let refresh_count = Arc::new(AtomicUsize::new(0));
        let wrapped = wrap_success(&mock, &refresh_count);

        let outcome = wrapped.chat(empty_request()).await?;
        assert!(matches!(outcome, ChatOutcome::Success(_)));
        assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
        assert_eq!(mock.chat_call_count(), 1);
        Ok(())
    }

    // Test 3: a 401 outcome refreshes once and the retry's outcome is returned.
    #[tokio::test]
    async fn chat_401_triggers_refresh_and_retries() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_chat(ChatOutcome::InvalidRequest("401 Unauthorized".into()))?;
        mock.queue_chat(ChatOutcome::Success(success_response()))?;

        let refresh_count = Arc::new(AtomicUsize::new(0));
        let wrapped = wrap_success(&mock, &refresh_count);

        let outcome = wrapped.chat(empty_request()).await?;
        assert!(matches!(outcome, ChatOutcome::Success(_)));
        assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
        assert_eq!(mock.chat_call_count(), 2);
        Ok(())
    }

    // Test 4: if the refresh callback fails, the original 401 outcome is
    // returned and no retry is attempted.
    #[tokio::test]
    async fn chat_surfaces_original_401_when_refresh_fails() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_chat(ChatOutcome::InvalidRequest(
            "status=401 Unauthorized".into(),
        ))?;

        let refresh_count = Arc::new(AtomicUsize::new(0));
        let wrapped = wrap_failure(&mock, &refresh_count, "refresh callback failed");

        let outcome = wrapped.chat(empty_request()).await?;
        match outcome {
            ChatOutcome::InvalidRequest(msg) => assert!(
                msg.contains("401"),
                "expected original 401 message, got {msg}"
            ),
            other => panic!("expected InvalidRequest, got {other:?}"),
        }
        assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
        assert_eq!(mock.chat_call_count(), 1);
        Ok(())
    }

    /// Collect every item a stream yields, in order.
    async fn drain(mut stream: StreamBox<'_>) -> Vec<Result<StreamDelta>> {
        let mut out = Vec::new();
        while let Some(item) = stream.next().await {
            out.push(item);
        }
        out
    }

    // Test 5: a clean stream is forwarded unchanged without refreshing.
    #[tokio::test]
    async fn chat_stream_successful_pass_through() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_stream(vec![
            MockStreamItem::Ok(StreamDelta::TextDelta {
                delta: "hi".into(),
                block_index: 0,
            }),
            MockStreamItem::Ok(StreamDelta::Done {
                stop_reason: Some(StopReason::EndTurn),
            }),
        ])?;

        let refresh_count = Arc::new(AtomicUsize::new(0));
        let wrapped = wrap_success(&mock, &refresh_count);

        let deltas = drain(wrapped.chat_stream(empty_request())).await;
        assert_eq!(deltas.len(), 2);
        assert!(matches!(
            deltas[0].as_ref().ok(),
            Some(StreamDelta::TextDelta { delta, .. }) if delta == "hi"
        ));
        assert!(matches!(
            deltas[1].as_ref().ok(),
            Some(StreamDelta::Done { .. })
        ));
        assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
        assert_eq!(mock.stream_call_count(), 1);
        Ok(())
    }

    // Test 6: a 401 before any content restarts the stream after a refresh.
    #[tokio::test]
    async fn chat_stream_401_before_output_retries() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
            message: "status=401 Unauthorized".into(),
            kind: StreamErrorKind::InvalidRequest,
        })])?;
        mock.queue_stream(vec![
            MockStreamItem::Ok(StreamDelta::TextDelta {
                delta: "retried".into(),
                block_index: 0,
            }),
            MockStreamItem::Ok(StreamDelta::Done {
                stop_reason: Some(StopReason::EndTurn),
            }),
        ])?;

        let refresh_count = Arc::new(AtomicUsize::new(0));
        let wrapped = wrap_success(&mock, &refresh_count);

        let deltas = drain(wrapped.chat_stream(empty_request())).await;
        // Consumer sees only the post-refresh stream.
        assert_eq!(deltas.len(), 2);
        assert!(matches!(
            deltas[0].as_ref().ok(),
            Some(StreamDelta::TextDelta { delta, .. }) if delta == "retried"
        ));
        assert!(matches!(
            deltas[1].as_ref().ok(),
            Some(StreamDelta::Done { .. })
        ));
        assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
        assert_eq!(mock.stream_call_count(), 2);
        Ok(())
    }

    // Test 7: once content has been forwarded, a 401 is passed through as-is.
    #[tokio::test]
    async fn chat_stream_401_after_output_does_not_retry() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_stream(vec![
            MockStreamItem::Ok(StreamDelta::TextDelta {
                delta: "partial".into(),
                block_index: 0,
            }),
            MockStreamItem::Ok(StreamDelta::Error {
                message: "401 Unauthorized".into(),
                kind: StreamErrorKind::InvalidRequest,
            }),
        ])?;

        let refresh_count = Arc::new(AtomicUsize::new(0));
        let wrapped = wrap_success(&mock, &refresh_count);

        let deltas = drain(wrapped.chat_stream(empty_request())).await;
        assert_eq!(deltas.len(), 2);
        assert!(matches!(
            deltas[0].as_ref().ok(),
            Some(StreamDelta::TextDelta { delta, .. }) if delta == "partial"
        ));
        assert!(matches!(
            deltas[1].as_ref().ok(),
            Some(StreamDelta::Error { message, .. }) if message.contains("401")
        ));
        assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
        assert_eq!(mock.stream_call_count(), 1);
        Ok(())
    }

    // Test 8: a second 401 after the refresh is surfaced, not retried again.
    #[tokio::test]
    async fn chat_stream_only_one_retry_per_call() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
            message: "status=401 Unauthorized".into(),
            kind: StreamErrorKind::InvalidRequest,
        })])?;
        mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
            message: "still 401 Unauthorized".into(),
            kind: StreamErrorKind::InvalidRequest,
        })])?;

        let refresh_count = Arc::new(AtomicUsize::new(0));
        let wrapped = wrap_success(&mock, &refresh_count);

        let deltas = drain(wrapped.chat_stream(empty_request())).await;
        assert_eq!(deltas.len(), 1);
        assert!(matches!(
            deltas[0].as_ref().ok(),
            Some(StreamDelta::Error { message, .. }) if message == "still 401 Unauthorized"
        ));
        assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
        assert_eq!(mock.stream_call_count(), 2);
        Ok(())
    }

    // Deterministic mock for the concurrent scenario. The first two chat()
    // calls both wait on a barrier, so both concurrent tasks observe a 401 on
    // their initial call; every later call returns Success. This avoids the
    // flakiness a shared FIFO queue would have under arbitrary scheduling.
    #[derive(Clone)]
    struct ConcurrentMock {
        model: String,
        provider_name: &'static str,
        total_calls: Arc<AtomicUsize>,
        initial_barrier: Arc<tokio::sync::Barrier>,
    }

    type CMFut = std::pin::Pin<Box<dyn Future<Output = Result<ConcurrentMock>> + Send>>;
    type CMRefresh = Box<dyn Fn() -> CMFut + Send + Sync + 'static>;

    #[async_trait]
    impl LlmProvider for ConcurrentMock {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            let call_index = self.total_calls.fetch_add(1, Ordering::SeqCst);
            if call_index < 2 {
                self.initial_barrier.wait().await;
                Ok(ChatOutcome::InvalidRequest("401 Unauthorized".into()))
            } else {
                Ok(ChatOutcome::Success(success_response()))
            }
        }

        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            Box::pin(async_stream::stream! {
                yield Err(anyhow::anyhow!("chat_stream not used in this test"));
            })
        }

        fn model(&self) -> &str {
            &self.model
        }

        fn provider(&self) -> &'static str {
            self.provider_name
        }
    }

    // Test 9: two callers that both hit a 401 each end in Success. The wrapper
    // does not deduplicate refreshes, so up to one refresh per caller is
    // allowed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn chat_concurrent_callers_share_refresh() -> Result<()> {
        let mock = ConcurrentMock {
            model: "mock-model".to_string(),
            provider_name: "mock",
            total_calls: Arc::new(AtomicUsize::new(0)),
            initial_barrier: Arc::new(tokio::sync::Barrier::new(2)),
        };
        let call_count = Arc::clone(&mock.total_calls);
        let refresh_count = Arc::new(AtomicUsize::new(0));
        let refresh_counter = Arc::clone(&refresh_count);
        let template = mock.clone();

        let cb: CMRefresh = Box::new(move || {
            refresh_counter.fetch_add(1, Ordering::SeqCst);
            let provider = template.clone();
            Box::pin(async move { Ok(provider) })
        });
        let wrapped = RefreshingProvider::new(mock, cb);

        let a = wrapped.clone();
        let b = wrapped.clone();
        let task_a = tokio::spawn(async move { a.chat(empty_request()).await });
        let task_b = tokio::spawn(async move { b.chat(empty_request()).await });

        let outcome_a = task_a.await.context("task_a join")??;
        let outcome_b = task_b.await.context("task_b join")??;

        assert!(matches!(outcome_a, ChatOutcome::Success(_)));
        assert!(matches!(outcome_b, ChatOutcome::Success(_)));
        // Two initial 401s plus one retry per caller.
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
        let refreshes = refresh_count.load(Ordering::SeqCst);
        assert!(
            refreshes <= 2,
            "expected at most 2 refresh calls (one per caller), got {refreshes}"
        );
        Ok(())
    }
}