Skip to main content

agent_sdk_providers/
refresh.rs

1//! Provider wrapper that refreshes credentials on 401 and retries once.
2//!
3//! [`RefreshingProvider`] wraps any [`LlmProvider`] and adds a host-driven
4//! credential refresh step: when the inner provider reports an unauthorized
5//! error (HTTP 401, expired OAuth token, invalid API key, etc.), the wrapper
6//! calls a host-supplied async callback to rebuild the inner provider with
7//! fresh credentials and retries the original request once.
8//!
9//! This is the generic form of the per-provider refresh wrappers that
10//! OAuth-backed hosts would otherwise copy across every provider they use.
11//!
12//! # Example
13//!
14//! ```no_run
15//! # use std::sync::Arc;
16//! # use anyhow::Result;
17//! # use agent_sdk_providers::{LlmProvider, RefreshingProvider};
18//! # async fn demo<P: LlmProvider + Clone + 'static>(initial: P) -> Result<()> {
19//! let refreshed_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
20//! let counter = Arc::clone(&refreshed_count);
21//! let wrapped = RefreshingProvider::new(initial.clone(), move || {
22//!     let counter = Arc::clone(&counter);
23//!     let provider = initial.clone();
24//!     async move {
25//!         counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
26//!         Ok(provider)
27//!     }
28//! });
29//! # let _ = wrapped;
30//! # Ok(())
31//! # }
32//! ```
33//!
34//! ## Streaming semantics
35//!
36//! For non-streaming [`chat`](LlmProvider::chat), the retry happens when the
37//! first call resolves to [`ChatOutcome::InvalidRequest`] with a 401-looking
38//! message. The second call replaces the first — callers only see the final
39//! outcome.
40//!
41//! For [`chat_stream`](LlmProvider::chat_stream), retry happens only when the
42//! 401 arrives before any content delta is forwarded. If a
43//! [`StreamDelta::TextDelta`], [`StreamDelta::ThinkingDelta`], tool-call
44//! delta, or thinking signature has already been yielded to the consumer,
45//! the error is forwarded as-is — retrying would duplicate partial output.
46//!
47//! At most one retry happens per call, whether streaming or not. If the
48//! retried call fails in the same way, the wrapper surfaces the second
49//! error unchanged.
50
51use std::future::Future;
52use std::sync::Arc;
53
54use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ThinkingConfig};
55use anyhow::Result;
56use async_trait::async_trait;
57use futures::StreamExt;
58use tokio::sync::Mutex;
59
60use crate::provider::LlmProvider;
61use crate::streaming::{StreamBox, StreamDelta};
62
63/// Wraps a provider with host-driven credential refresh on 401.
64///
65/// The inner provider is stored behind `Arc<Mutex<P>>` so it can be swapped
66/// atomically when the refresh callback produces a new provider. Cloning a
67/// wrapper is cheap — clones share the same inner state.
68///
69/// Metadata (`model`, `provider`, `configured_thinking`) is captured from the
70/// initial provider at construction time and assumed constant across
71/// refreshes (the refresh callback rebuilds the same provider shape with a
72/// fresh token, not a different model).
73pub struct RefreshingProvider<P, F> {
74    inner: Arc<Mutex<P>>,
75    refresh: Arc<F>,
76    model: String,
77    provider: &'static str,
78    thinking: Option<ThinkingConfig>,
79}
80
81impl<P, F> Clone for RefreshingProvider<P, F> {
82    fn clone(&self) -> Self {
83        Self {
84            inner: Arc::clone(&self.inner),
85            refresh: Arc::clone(&self.refresh),
86            model: self.model.clone(),
87            provider: self.provider,
88            thinking: self.thinking.clone(),
89        }
90    }
91}
92
93impl<P, F, Fut> RefreshingProvider<P, F>
94where
95    P: LlmProvider + Clone + 'static,
96    F: Fn() -> Fut + Send + Sync + 'static,
97    Fut: Future<Output = Result<P>> + Send + 'static,
98{
99    /// Build a wrapper from an initial provider and a refresh callback.
100    ///
101    /// The refresh callback is invoked each time the inner provider emits a
102    /// 401 response. It must be idempotent and safe to call concurrently.
103    /// The callback should return a fully-built provider ready to use;
104    /// typically it reads fresh credentials from its auth store and calls
105    /// the inner provider's constructor.
106    #[must_use]
107    pub fn new(inner: P, refresh: F) -> Self {
108        let model = inner.model().to_string();
109        let provider = inner.provider();
110        let thinking = inner.configured_thinking().cloned();
111        Self {
112            inner: Arc::new(Mutex::new(inner)),
113            refresh: Arc::new(refresh),
114            model,
115            provider,
116            thinking,
117        }
118    }
119
120    async fn snapshot(&self) -> P {
121        self.inner.lock().await.clone()
122    }
123
124    async fn run_refresh(&self) -> Result<()> {
125        let fresh = (self.refresh)().await?;
126        *self.inner.lock().await = fresh;
127        Ok(())
128    }
129}
130
/// Classify a provider error message as a 401 / unauthorized condition.
///
/// Returns `true` when the message looks like the provider rejected the
/// request because of missing, invalid, or expired credentials. A message
/// qualifies when either:
///
/// - it contains the status code `401` as a standalone token, i.e. with no
///   ASCII letter or digit directly before or after it (`"HTTP 401"`,
///   `"status=401"`, `"(401)"`), so longer numbers such as `4010` are not
///   mistaken for the status code; or
/// - it contains, ignoring ASCII case, one of `unauthorized`,
///   `authentication`, `token_expired`, `invalid api key`, or
///   `invalid_api_key`.
///
/// Hosts that implement their own retry logic on top of [`LlmProvider`] can
/// gate on the same helper.
#[must_use]
pub fn is_unauthorized_error(message: &str) -> bool {
    /// Lower-case markers that indicate a credential problem.
    const MARKERS: [&str; 5] = [
        "unauthorized",
        "authentication",
        "token_expired",
        "invalid api key",
        "invalid_api_key",
    ];

    // Digits have no case, so the status-code check can run on the original
    // text and skip the lower-casing allocation when it already matches.
    if has_standalone_401(message) {
        return true;
    }
    let lower = message.to_ascii_lowercase();
    MARKERS.iter().any(|&marker| lower.contains(marker))
}

/// Returns `true` when `message` contains `401` with no ASCII letter or digit
/// immediately before or after it.
fn has_standalone_401(message: &str) -> bool {
    const CODE: &str = "401";
    let bytes = message.as_bytes();
    // A missing neighbour (start or end of the message) counts as a boundary.
    let is_boundary =
        |neighbour: Option<&u8>| !neighbour.is_some_and(u8::is_ascii_alphanumeric);

    message.match_indices(CODE).any(|(start, _)| {
        let before = start.checked_sub(1).and_then(|index| bytes.get(index));
        let after = bytes.get(start + CODE.len());
        is_boundary(before) && is_boundary(after)
    })
}
149
150#[async_trait]
151impl<P, F, Fut> LlmProvider for RefreshingProvider<P, F>
152where
153    P: LlmProvider + Clone + 'static,
154    F: Fn() -> Fut + Send + Sync + 'static,
155    Fut: Future<Output = Result<P>> + Send + 'static,
156{
157    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
158        let outcome = self.snapshot().await.chat(request.clone()).await?;
159        if let ChatOutcome::InvalidRequest(message) = &outcome
160            && is_unauthorized_error(message)
161        {
162            match self.run_refresh().await {
163                Ok(()) => return self.snapshot().await.chat(request).await,
164                Err(error) => {
165                    log::warn!("RefreshingProvider refresh after 401 failed: {error:#}");
166                }
167            }
168        }
169        Ok(outcome)
170    }
171
172    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
173        let this = self.clone();
174        Box::pin(async_stream::stream! {
175            let mut refreshed = false;
176            'attempts: loop {
177                let provider = this.snapshot().await;
178                let mut stream = provider.chat_stream(request.clone());
179                let mut saw_output = false;
180
181                while let Some(item) = stream.next().await {
182                    match item {
183                        Ok(StreamDelta::Error { message, kind })
184                            if !saw_output
185                                && !refreshed
186                                && is_unauthorized_error(&message) =>
187                        {
188                            match this.run_refresh().await {
189                                Ok(()) => {
190                                    refreshed = true;
191                                    continue 'attempts;
192                                }
193                                Err(error) => {
194                                    log::warn!(
195                                        "RefreshingProvider refresh after streaming 401 failed: {error:#}"
196                                    );
197                                    yield Ok(StreamDelta::Error { message, kind });
198                                    return;
199                                }
200                            }
201                        }
202                        Ok(delta) => {
203                            if matches!(
204                                delta,
205                                StreamDelta::TextDelta { .. }
206                                    | StreamDelta::ThinkingDelta { .. }
207                                    | StreamDelta::ToolUseStart { .. }
208                                    | StreamDelta::ToolInputDelta { .. }
209                                    | StreamDelta::SignatureDelta { .. }
210                                    | StreamDelta::RedactedThinking { .. }
211                            ) {
212                                saw_output = true;
213                            }
214                            let done = matches!(delta, StreamDelta::Done { .. });
215                            yield Ok(delta);
216                            if done {
217                                return;
218                            }
219                        }
220                        Err(error)
221                            if !saw_output
222                                && !refreshed
223                                && is_unauthorized_error(&error.to_string()) =>
224                        {
225                            match this.run_refresh().await {
226                                Ok(()) => {
227                                    refreshed = true;
228                                    continue 'attempts;
229                                }
230                                Err(refresh_error) => {
231                                    log::warn!(
232                                        "RefreshingProvider refresh after stream failure failed: {refresh_error:#}"
233                                    );
234                                    yield Err(error);
235                                    return;
236                                }
237                            }
238                        }
239                        Err(error) => {
240                            yield Err(error);
241                            return;
242                        }
243                    }
244                }
245                return;
246            }
247        })
248    }
249
250    fn model(&self) -> &str {
251        &self.model
252    }
253
254    fn provider(&self) -> &'static str {
255        self.provider
256    }
257
258    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
259        self.thinking.as_ref()
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::VecDeque;
    use std::sync::Mutex as StdMutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use agent_sdk_foundation::llm::{ChatResponse, ContentBlock, StopReason, Usage};
    use anyhow::Context;

    use crate::streaming::StreamErrorKind;

    /// One scripted stream event: a delta to yield, or an error message that
    /// ends the mock stream.
    #[derive(Clone)]
    enum MockStreamItem {
        Ok(StreamDelta),
        Err(String),
    }

    /// Scripted provider. Clones share the queues and call counters, so a
    /// "refreshed" clone continues the same script.
    #[derive(Clone)]
    struct MockProvider {
        model: String,
        provider_name: &'static str,
        outcomes: Arc<StdMutex<VecDeque<ChatOutcome>>>,
        stream_batches: Arc<StdMutex<VecDeque<Vec<MockStreamItem>>>>,
        chat_calls: Arc<AtomicUsize>,
        stream_calls: Arc<AtomicUsize>,
    }

    impl MockProvider {
        /// Empty script, zeroed counters.
        fn new() -> Self {
            Self {
                model: String::from("mock-model"),
                provider_name: "mock",
                outcomes: Arc::default(),
                stream_batches: Arc::default(),
                chat_calls: Arc::default(),
                stream_calls: Arc::default(),
            }
        }

        /// Append the outcome returned by the next unanswered `chat` call.
        fn queue_chat(&self, outcome: ChatOutcome) -> Result<()> {
            let mut pending = self
                .outcomes
                .lock()
                .ok()
                .context("outcomes lock poisoned")?;
            pending.push_back(outcome);
            Ok(())
        }

        /// Append the batch replayed by the next unanswered `chat_stream` call.
        fn queue_stream(&self, batch: Vec<MockStreamItem>) -> Result<()> {
            let mut pending = self
                .stream_batches
                .lock()
                .ok()
                .context("stream_batches lock poisoned")?;
            pending.push_back(batch);
            Ok(())
        }

        /// Number of `chat` calls seen across all clones.
        fn chat_call_count(&self) -> usize {
            self.chat_calls.load(Ordering::SeqCst)
        }

        /// Number of `chat_stream` calls seen across all clones.
        fn stream_call_count(&self) -> usize {
            self.stream_calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            self.chat_calls.fetch_add(1, Ordering::SeqCst);
            self.outcomes
                .lock()
                .ok()
                .context("outcomes lock poisoned")?
                .pop_front()
                .context("MockProvider: no queued outcome")
        }

        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            self.stream_calls.fetch_add(1, Ordering::SeqCst);
            // A poisoned lock or an empty script both fall back to one error item.
            let scripted = match self.stream_batches.lock() {
                Ok(mut queue) => queue.pop_front(),
                Err(_) => None,
            };
            let batch: Vec<MockStreamItem> = scripted
                .unwrap_or_else(|| vec![MockStreamItem::Err("no queued stream batch".into())]);
            Box::pin(async_stream::stream! {
                for scripted_item in batch {
                    match scripted_item {
                        MockStreamItem::Ok(delta) => yield Ok(delta),
                        MockStreamItem::Err(text) => {
                            yield Err(anyhow::anyhow!(text));
                            return;
                        }
                    }
                }
            })
        }

        fn model(&self) -> &str {
            self.model.as_str()
        }

        fn provider(&self) -> &'static str {
            self.provider_name
        }
    }

    /// Minimal successful response used wherever a test needs `Success`.
    fn success_response() -> ChatResponse {
        let usage = Usage {
            input_tokens: 1,
            output_tokens: 1,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
        };
        let content = vec![ContentBlock::Text {
            text: String::from("ok"),
        }];
        ChatResponse {
            id: String::from("msg_test"),
            content,
            model: String::from("mock-model"),
            stop_reason: Some(StopReason::EndTurn),
            usage,
        }
    }

    /// Request with no messages; the mocks ignore its contents.
    fn empty_request() -> ChatRequest {
        ChatRequest {
            system: String::new(),
            messages: Vec::new(),
            tools: None,
            max_tokens: 100,
            max_tokens_explicit: false,
            session_id: None,
            cached_content: None,
            thinking: None,
            tool_choice: None,
            response_format: None,
        }
    }

    /// Scripted text delta at block index 0.
    fn text_item(text: &'static str) -> MockStreamItem {
        MockStreamItem::Ok(StreamDelta::TextDelta {
            delta: text.into(),
            block_index: 0,
        })
    }

    /// Scripted end-of-turn marker.
    fn done_item() -> MockStreamItem {
        MockStreamItem::Ok(StreamDelta::Done {
            stop_reason: Some(StopReason::EndTurn),
        })
    }

    /// Scripted in-band error of kind `InvalidRequest`.
    fn invalid_request_item(message: &'static str) -> MockStreamItem {
        MockStreamItem::Ok(StreamDelta::Error {
            message: message.into(),
            kind: StreamErrorKind::InvalidRequest,
        })
    }

    type BoxedFut = std::pin::Pin<Box<dyn Future<Output = Result<MockProvider>> + Send>>;
    type RefreshFn = Box<dyn Fn() -> BoxedFut + Send + Sync + 'static>;
    type Wrapped = RefreshingProvider<MockProvider, RefreshFn>;

    /// Wrap `mock` with a callback that counts its invocations and hands back
    /// a clone of `mock`.
    fn wrap_success(mock: &MockProvider, counter: &Arc<AtomicUsize>) -> Wrapped {
        let invocations = Arc::clone(counter);
        let replacement = mock.clone();
        let callback: RefreshFn = Box::new(move || {
            invocations.fetch_add(1, Ordering::SeqCst);
            let next = replacement.clone();
            Box::pin(async move { Ok(next) })
        });
        RefreshingProvider::new(mock.clone(), callback)
    }

    /// Wrap `mock` with a callback that counts its invocations and always
    /// fails with `error`.
    fn wrap_failure(
        mock: &MockProvider,
        counter: &Arc<AtomicUsize>,
        error: &'static str,
    ) -> Wrapped {
        let invocations = Arc::clone(counter);
        let callback: RefreshFn = Box::new(move || {
            invocations.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Err(anyhow::anyhow!(error)) })
        });
        RefreshingProvider::new(mock.clone(), callback)
    }

    // Test 1: the classifier accepts auth failures and rejects everything else.
    #[test]
    fn is_unauthorized_error_matches_expected_strings() {
        let unauthorized = [
            "HTTP 401",
            "status=401 Unauthorized",
            "Invalid API key",
            "invalid_api_key",
            "token_expired",
            "Authentication failed",
            "UNAUTHORIZED",
        ];
        for message in unauthorized {
            assert!(is_unauthorized_error(message));
        }

        let unrelated = ["rate limited", "network error", "", "internal server error"];
        for message in unrelated {
            assert!(!is_unauthorized_error(message));
        }
    }

    // Test 2: a successful chat is passed through and never refreshes.
    #[tokio::test]
    async fn chat_successful_pass_through_does_not_refresh() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_chat(ChatOutcome::Success(success_response()))?;

        let refreshes = Arc::new(AtomicUsize::new(0));
        let provider = wrap_success(&mock, &refreshes);

        let outcome = provider.chat(empty_request()).await?;
        assert!(matches!(outcome, ChatOutcome::Success(_)));
        assert_eq!(refreshes.load(Ordering::SeqCst), 0);
        assert_eq!(mock.chat_call_count(), 1);
        Ok(())
    }

    // Test 3: a 401 outcome refreshes once and the retry result is returned.
    #[tokio::test]
    async fn chat_401_triggers_refresh_and_retries() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_chat(ChatOutcome::InvalidRequest("401 Unauthorized".into()))?;
        mock.queue_chat(ChatOutcome::Success(success_response()))?;

        let refreshes = Arc::new(AtomicUsize::new(0));
        let provider = wrap_success(&mock, &refreshes);

        let outcome = provider.chat(empty_request()).await?;
        assert!(matches!(outcome, ChatOutcome::Success(_)));
        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
        assert_eq!(mock.chat_call_count(), 2);
        Ok(())
    }

    // Test 4: when the refresh callback fails, the original 401 is returned
    // and no retry is attempted.
    #[tokio::test]
    async fn chat_surfaces_original_401_when_refresh_fails() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_chat(ChatOutcome::InvalidRequest(
            "status=401 Unauthorized".into(),
        ))?;

        let refreshes = Arc::new(AtomicUsize::new(0));
        let provider = wrap_failure(&mock, &refreshes, "refresh callback failed");

        match provider.chat(empty_request()).await? {
            ChatOutcome::InvalidRequest(msg) => assert!(
                msg.contains("401"),
                "expected original 401 message, got {msg}"
            ),
            other => panic!("expected InvalidRequest, got {other:?}"),
        }
        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
        assert_eq!(mock.chat_call_count(), 1);
        Ok(())
    }

    /// Collect every item of `stream` in order.
    async fn drain(stream: StreamBox<'_>) -> Vec<Result<StreamDelta>> {
        stream.collect().await
    }

    // Test 5: a clean stream is forwarded untouched.
    #[tokio::test]
    async fn chat_stream_successful_pass_through() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_stream(vec![text_item("hi"), done_item()])?;

        let refreshes = Arc::new(AtomicUsize::new(0));
        let provider = wrap_success(&mock, &refreshes);

        let events = drain(provider.chat_stream(empty_request())).await;
        assert_eq!(events.len(), 2);
        assert!(matches!(
            events[0].as_ref().ok(),
            Some(StreamDelta::TextDelta { delta, .. }) if delta == "hi"
        ));
        assert!(matches!(
            events[1].as_ref().ok(),
            Some(StreamDelta::Done { .. })
        ));
        assert_eq!(refreshes.load(Ordering::SeqCst), 0);
        assert_eq!(mock.stream_call_count(), 1);
        Ok(())
    }

    // Test 6: a 401 before any content restarts the stream after a refresh.
    #[tokio::test]
    async fn chat_stream_401_before_output_retries() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_stream(vec![invalid_request_item("status=401 Unauthorized")])?;
        mock.queue_stream(vec![text_item("retried"), done_item()])?;

        let refreshes = Arc::new(AtomicUsize::new(0));
        let provider = wrap_success(&mock, &refreshes);

        let events = drain(provider.chat_stream(empty_request())).await;
        // Only the post-refresh stream reaches the consumer.
        assert_eq!(events.len(), 2);
        assert!(matches!(
            events[0].as_ref().ok(),
            Some(StreamDelta::TextDelta { delta, .. }) if delta == "retried"
        ));
        assert!(matches!(
            events[1].as_ref().ok(),
            Some(StreamDelta::Done { .. })
        ));
        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
        assert_eq!(mock.stream_call_count(), 2);
        Ok(())
    }

    // Test 7: a 401 after content has been forwarded is surfaced, not retried.
    #[tokio::test]
    async fn chat_stream_401_after_output_does_not_retry() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_stream(vec![
            text_item("partial"),
            invalid_request_item("401 Unauthorized"),
        ])?;

        let refreshes = Arc::new(AtomicUsize::new(0));
        let provider = wrap_success(&mock, &refreshes);

        let events = drain(provider.chat_stream(empty_request())).await;
        assert_eq!(events.len(), 2);
        assert!(matches!(
            events[0].as_ref().ok(),
            Some(StreamDelta::TextDelta { delta, .. }) if delta == "partial"
        ));
        assert!(matches!(
            events[1].as_ref().ok(),
            Some(StreamDelta::Error { message, .. }) if message.contains("401")
        ));
        assert_eq!(refreshes.load(Ordering::SeqCst), 0);
        assert_eq!(mock.stream_call_count(), 1);
        Ok(())
    }

    // Test 8: a second 401 after the refresh is forwarded; no second refresh.
    #[tokio::test]
    async fn chat_stream_only_one_retry_per_call() -> Result<()> {
        let mock = MockProvider::new();
        mock.queue_stream(vec![invalid_request_item("status=401 Unauthorized")])?;
        mock.queue_stream(vec![invalid_request_item("still 401 Unauthorized")])?;

        let refreshes = Arc::new(AtomicUsize::new(0));
        let provider = wrap_success(&mock, &refreshes);

        let events = drain(provider.chat_stream(empty_request())).await;
        assert_eq!(events.len(), 1);
        assert!(matches!(
            events[0].as_ref().ok(),
            Some(StreamDelta::Error { message, .. }) if message == "still 401 Unauthorized"
        ));
        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
        assert_eq!(mock.stream_call_count(), 2);
        Ok(())
    }

    // Mock for the concurrent scenario. The first two `chat` calls meet at a
    // barrier and then both report a 401, so each concurrent task is
    // guaranteed to see an unauthorized first attempt regardless of
    // scheduling. Every later call succeeds.
    #[derive(Clone)]
    struct ConcurrentMock {
        model: String,
        provider_name: &'static str,
        total_calls: Arc<AtomicUsize>,
        initial_barrier: Arc<tokio::sync::Barrier>,
    }

    type CMFut = std::pin::Pin<Box<dyn Future<Output = Result<ConcurrentMock>> + Send>>;
    type CMRefresh = Box<dyn Fn() -> CMFut + Send + Sync + 'static>;

    #[async_trait]
    impl LlmProvider for ConcurrentMock {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            let earlier_calls = self.total_calls.fetch_add(1, Ordering::SeqCst);
            if earlier_calls >= 2 {
                return Ok(ChatOutcome::Success(success_response()));
            }
            self.initial_barrier.wait().await;
            Ok(ChatOutcome::InvalidRequest("401 Unauthorized".into()))
        }

        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            Box::pin(async_stream::stream! {
                yield Err(anyhow::anyhow!("chat_stream not used in this test"));
            })
        }

        fn model(&self) -> &str {
            self.model.as_str()
        }

        fn provider(&self) -> &'static str {
            self.provider_name
        }
    }

    // Test 9: two concurrent callers that both hit a 401 both end up with a
    // successful retry.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn chat_concurrent_callers_share_refresh() -> Result<()> {
        let total_calls = Arc::new(AtomicUsize::new(0));
        let mock = ConcurrentMock {
            model: String::from("mock-model"),
            provider_name: "mock",
            total_calls: Arc::clone(&total_calls),
            initial_barrier: Arc::new(tokio::sync::Barrier::new(2)),
        };
        let refresh_count = Arc::new(AtomicUsize::new(0));
        let invocations = Arc::clone(&refresh_count);
        let replacement = mock.clone();

        let callback: CMRefresh = Box::new(move || {
            invocations.fetch_add(1, Ordering::SeqCst);
            let next = replacement.clone();
            Box::pin(async move { Ok(next) })
        });
        let wrapped = RefreshingProvider::new(mock, callback);

        let first = wrapped.clone();
        let second = wrapped.clone();
        let task_a = tokio::spawn(async move { first.chat(empty_request()).await });
        let task_b = tokio::spawn(async move { second.chat(empty_request()).await });

        let outcome_a = task_a.await.context("task_a join")??;
        let outcome_b = task_b.await.context("task_b join")??;

        assert!(matches!(outcome_a, ChatOutcome::Success(_)));
        assert!(matches!(outcome_b, ChatOutcome::Success(_)));
        assert_eq!(total_calls.load(Ordering::SeqCst), 4);
        let refreshes = refresh_count.load(Ordering::SeqCst);
        assert!(
            refreshes <= 2,
            "expected at most 2 refresh calls (one per caller), got {refreshes}"
        );
        Ok(())
    }
}