1use std::future::Future;
52use std::sync::Arc;
53
54use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ThinkingConfig};
55use anyhow::Result;
56use async_trait::async_trait;
57use futures::StreamExt;
58use tokio::sync::Mutex;
59
60use crate::provider::LlmProvider;
61use crate::streaming::{StreamBox, StreamDelta};
62
63pub struct RefreshingProvider<P, F> {
74 inner: Arc<Mutex<P>>,
75 refresh: Arc<F>,
76 model: String,
77 provider: &'static str,
78 thinking: Option<ThinkingConfig>,
79}
80
81impl<P, F> Clone for RefreshingProvider<P, F> {
82 fn clone(&self) -> Self {
83 Self {
84 inner: Arc::clone(&self.inner),
85 refresh: Arc::clone(&self.refresh),
86 model: self.model.clone(),
87 provider: self.provider,
88 thinking: self.thinking.clone(),
89 }
90 }
91}
92
93impl<P, F, Fut> RefreshingProvider<P, F>
94where
95 P: LlmProvider + Clone + 'static,
96 F: Fn() -> Fut + Send + Sync + 'static,
97 Fut: Future<Output = Result<P>> + Send + 'static,
98{
99 #[must_use]
107 pub fn new(inner: P, refresh: F) -> Self {
108 let model = inner.model().to_string();
109 let provider = inner.provider();
110 let thinking = inner.configured_thinking().cloned();
111 Self {
112 inner: Arc::new(Mutex::new(inner)),
113 refresh: Arc::new(refresh),
114 model,
115 provider,
116 thinking,
117 }
118 }
119
120 async fn snapshot(&self) -> P {
121 self.inner.lock().await.clone()
122 }
123
124 async fn run_refresh(&self) -> Result<()> {
125 let fresh = (self.refresh)().await?;
126 *self.inner.lock().await = fresh;
127 Ok(())
128 }
129}
130
/// Heuristically decides whether an error message describes an
/// authentication / authorization failure.
///
/// The check is case-insensitive (ASCII) and returns `true` when the message
/// either
///
/// * contains one of the textual markers `unauthorized`, `authentication`,
///   `token_expired`, `invalid api key` or `invalid_api_key`, or
/// * mentions the status code `401` as its own number: at the very start of
///   the message, after a space, or directly after `status=`, and not followed
///   by another digit.
///
/// The digit check keeps longer numbers such as ` 4010` or `status=4012` from
/// being mistaken for an HTTP 401.
#[must_use]
pub fn is_unauthorized_error(message: &str) -> bool {
    const MARKERS: [&str; 5] = [
        "unauthorized",
        "authentication",
        "token_expired",
        "invalid api key",
        "invalid_api_key",
    ];

    let lower = message.to_ascii_lowercase();
    MARKERS.iter().any(|marker| lower.contains(marker)) || mentions_status_401(&lower)
}

/// Returns `true` when `lower` contains `401` as a standalone status code.
///
/// A candidate counts only if it starts the message, follows a space, or
/// follows `status=`, and the character after it (if any) is not an ASCII digit.
fn mentions_status_401(lower: &str) -> bool {
    lower.match_indices("401").any(|(start, code)| {
        // `match_indices` yields valid char boundaries, so slicing is safe.
        let head = &lower[..start];
        let tail = &lower[start + code.len()..];

        let boundary_before =
            head.is_empty() || head.ends_with(' ') || head.ends_with("status=");
        let boundary_after = !tail.starts_with(|c: char| c.is_ascii_digit());

        boundary_before && boundary_after
    })
}
149
150#[async_trait]
151impl<P, F, Fut> LlmProvider for RefreshingProvider<P, F>
152where
153 P: LlmProvider + Clone + 'static,
154 F: Fn() -> Fut + Send + Sync + 'static,
155 Fut: Future<Output = Result<P>> + Send + 'static,
156{
157 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
158 let outcome = self.snapshot().await.chat(request.clone()).await?;
159 if let ChatOutcome::InvalidRequest(message) = &outcome
160 && is_unauthorized_error(message)
161 {
162 match self.run_refresh().await {
163 Ok(()) => return self.snapshot().await.chat(request).await,
164 Err(error) => {
165 log::warn!("RefreshingProvider refresh after 401 failed: {error:#}");
166 }
167 }
168 }
169 Ok(outcome)
170 }
171
172 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
173 let this = self.clone();
174 Box::pin(async_stream::stream! {
175 let mut refreshed = false;
176 'attempts: loop {
177 let provider = this.snapshot().await;
178 let mut stream = provider.chat_stream(request.clone());
179 let mut saw_output = false;
180
181 while let Some(item) = stream.next().await {
182 match item {
183 Ok(StreamDelta::Error { message, kind })
184 if !saw_output
185 && !refreshed
186 && is_unauthorized_error(&message) =>
187 {
188 match this.run_refresh().await {
189 Ok(()) => {
190 refreshed = true;
191 continue 'attempts;
192 }
193 Err(error) => {
194 log::warn!(
195 "RefreshingProvider refresh after streaming 401 failed: {error:#}"
196 );
197 yield Ok(StreamDelta::Error { message, kind });
198 return;
199 }
200 }
201 }
202 Ok(delta) => {
203 if matches!(
204 delta,
205 StreamDelta::TextDelta { .. }
206 | StreamDelta::ThinkingDelta { .. }
207 | StreamDelta::ToolUseStart { .. }
208 | StreamDelta::ToolInputDelta { .. }
209 | StreamDelta::SignatureDelta { .. }
210 | StreamDelta::RedactedThinking { .. }
211 ) {
212 saw_output = true;
213 }
214 let done = matches!(delta, StreamDelta::Done { .. });
215 yield Ok(delta);
216 if done {
217 return;
218 }
219 }
220 Err(error)
221 if !saw_output
222 && !refreshed
223 && is_unauthorized_error(&error.to_string()) =>
224 {
225 match this.run_refresh().await {
226 Ok(()) => {
227 refreshed = true;
228 continue 'attempts;
229 }
230 Err(refresh_error) => {
231 log::warn!(
232 "RefreshingProvider refresh after stream failure failed: {refresh_error:#}"
233 );
234 yield Err(error);
235 return;
236 }
237 }
238 }
239 Err(error) => {
240 yield Err(error);
241 return;
242 }
243 }
244 }
245 return;
246 }
247 })
248 }
249
250 fn model(&self) -> &str {
251 &self.model
252 }
253
254 fn provider(&self) -> &'static str {
255 self.provider
256 }
257
258 fn configured_thinking(&self) -> Option<&ThinkingConfig> {
259 self.thinking.as_ref()
260 }
261}
262
263#[cfg(test)]
264mod tests {
265 use super::*;
266
267 use std::collections::VecDeque;
268 use std::sync::Mutex as StdMutex;
269 use std::sync::atomic::{AtomicUsize, Ordering};
270
271 use agent_sdk_foundation::llm::{ChatResponse, ContentBlock, StopReason, Usage};
272 use anyhow::Context;
273
274 use crate::streaming::StreamErrorKind;
275
276 #[derive(Clone)]
277 enum MockStreamItem {
278 Ok(StreamDelta),
279 Err(String),
280 }
281
282 #[derive(Clone)]
283 struct MockProvider {
284 model: String,
285 provider_name: &'static str,
286 outcomes: Arc<StdMutex<VecDeque<ChatOutcome>>>,
287 stream_batches: Arc<StdMutex<VecDeque<Vec<MockStreamItem>>>>,
288 chat_calls: Arc<AtomicUsize>,
289 stream_calls: Arc<AtomicUsize>,
290 }
291
292 impl MockProvider {
293 fn new() -> Self {
294 Self {
295 model: "mock-model".to_string(),
296 provider_name: "mock",
297 outcomes: Arc::new(StdMutex::new(VecDeque::new())),
298 stream_batches: Arc::new(StdMutex::new(VecDeque::new())),
299 chat_calls: Arc::new(AtomicUsize::new(0)),
300 stream_calls: Arc::new(AtomicUsize::new(0)),
301 }
302 }
303
304 fn queue_chat(&self, outcome: ChatOutcome) -> Result<()> {
305 self.outcomes
306 .lock()
307 .ok()
308 .context("outcomes lock poisoned")?
309 .push_back(outcome);
310 Ok(())
311 }
312
313 fn queue_stream(&self, batch: Vec<MockStreamItem>) -> Result<()> {
314 self.stream_batches
315 .lock()
316 .ok()
317 .context("stream_batches lock poisoned")?
318 .push_back(batch);
319 Ok(())
320 }
321
322 fn chat_call_count(&self) -> usize {
323 self.chat_calls.load(Ordering::SeqCst)
324 }
325
326 fn stream_call_count(&self) -> usize {
327 self.stream_calls.load(Ordering::SeqCst)
328 }
329 }
330
331 #[async_trait]
332 impl LlmProvider for MockProvider {
333 async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
334 self.chat_calls.fetch_add(1, Ordering::SeqCst);
335 let mut queue = self
336 .outcomes
337 .lock()
338 .ok()
339 .context("outcomes lock poisoned")?;
340 queue.pop_front().context("MockProvider: no queued outcome")
341 }
342
343 fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
344 self.stream_calls.fetch_add(1, Ordering::SeqCst);
345 let batch: Vec<MockStreamItem> = self
346 .stream_batches
347 .lock()
348 .ok()
349 .and_then(|mut q| q.pop_front())
350 .unwrap_or_else(|| vec![MockStreamItem::Err("no queued stream batch".into())]);
351 Box::pin(async_stream::stream! {
352 for item in batch {
353 match item {
354 MockStreamItem::Ok(delta) => yield Ok(delta),
355 MockStreamItem::Err(msg) => {
356 yield Err(anyhow::anyhow!(msg));
357 return;
358 }
359 }
360 }
361 })
362 }
363
364 fn model(&self) -> &str {
365 &self.model
366 }
367
368 fn provider(&self) -> &'static str {
369 self.provider_name
370 }
371 }
372
373 fn success_response() -> ChatResponse {
374 ChatResponse {
375 id: "msg_test".to_string(),
376 content: vec![ContentBlock::Text {
377 text: "ok".to_string(),
378 }],
379 model: "mock-model".to_string(),
380 stop_reason: Some(StopReason::EndTurn),
381 usage: Usage {
382 input_tokens: 1,
383 output_tokens: 1,
384 cached_input_tokens: 0,
385 cache_creation_input_tokens: 0,
386 },
387 }
388 }
389
390 fn empty_request() -> ChatRequest {
391 ChatRequest {
392 system: String::new(),
393 messages: Vec::new(),
394 tools: None,
395 max_tokens: 100,
396 max_tokens_explicit: false,
397 session_id: None,
398 cached_content: None,
399 thinking: None,
400 tool_choice: None,
401 response_format: None,
402 }
403 }
404
405 type BoxedFut = std::pin::Pin<Box<dyn Future<Output = Result<MockProvider>> + Send>>;
406 type RefreshFn = Box<dyn Fn() -> BoxedFut + Send + Sync + 'static>;
407 type Wrapped = RefreshingProvider<MockProvider, RefreshFn>;
408
409 fn wrap_success(mock: &MockProvider, counter: &Arc<AtomicUsize>) -> Wrapped {
410 let counter = Arc::clone(counter);
411 let template = mock.clone();
412 let cb: RefreshFn = Box::new(move || {
413 counter.fetch_add(1, Ordering::SeqCst);
414 let provider = template.clone();
415 Box::pin(async move { Ok(provider) })
416 });
417 RefreshingProvider::new(mock.clone(), cb)
418 }
419
420 fn wrap_failure(
421 mock: &MockProvider,
422 counter: &Arc<AtomicUsize>,
423 error: &'static str,
424 ) -> Wrapped {
425 let counter = Arc::clone(counter);
426 let cb: RefreshFn = Box::new(move || {
427 counter.fetch_add(1, Ordering::SeqCst);
428 Box::pin(async move { Err(anyhow::anyhow!(error)) })
429 });
430 RefreshingProvider::new(mock.clone(), cb)
431 }
432
/// Pins which messages the heuristic treats as authorization failures.
#[test]
fn is_unauthorized_error_matches_expected_strings() {
    let auth_failures = [
        "HTTP 401",
        "status=401 Unauthorized",
        "Invalid API key",
        "invalid_api_key",
        "token_expired",
        "Authentication failed",
        "UNAUTHORIZED",
    ];
    for message in auth_failures {
        assert!(is_unauthorized_error(message));
    }

    let unrelated = ["rate limited", "network error", "", "internal server error"];
    for message in unrelated {
        assert!(!is_unauthorized_error(message));
    }
}
449
450 #[tokio::test]
452 async fn chat_successful_pass_through_does_not_refresh() -> Result<()> {
453 let mock = MockProvider::new();
454 mock.queue_chat(ChatOutcome::Success(success_response()))?;
455
456 let refresh_count = Arc::new(AtomicUsize::new(0));
457 let wrapped = wrap_success(&mock, &refresh_count);
458
459 let outcome = wrapped.chat(empty_request()).await?;
460 assert!(matches!(outcome, ChatOutcome::Success(_)));
461 assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
462 assert_eq!(mock.chat_call_count(), 1);
463 Ok(())
464 }
465
466 #[tokio::test]
468 async fn chat_401_triggers_refresh_and_retries() -> Result<()> {
469 let mock = MockProvider::new();
470 mock.queue_chat(ChatOutcome::InvalidRequest("401 Unauthorized".into()))?;
471 mock.queue_chat(ChatOutcome::Success(success_response()))?;
472
473 let refresh_count = Arc::new(AtomicUsize::new(0));
474 let wrapped = wrap_success(&mock, &refresh_count);
475
476 let outcome = wrapped.chat(empty_request()).await?;
477 assert!(matches!(outcome, ChatOutcome::Success(_)));
478 assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
479 assert_eq!(mock.chat_call_count(), 2);
480 Ok(())
481 }
482
483 #[tokio::test]
485 async fn chat_surfaces_original_401_when_refresh_fails() -> Result<()> {
486 let mock = MockProvider::new();
487 mock.queue_chat(ChatOutcome::InvalidRequest(
488 "status=401 Unauthorized".into(),
489 ))?;
490
491 let refresh_count = Arc::new(AtomicUsize::new(0));
492 let wrapped = wrap_failure(&mock, &refresh_count, "refresh callback failed");
493
494 let outcome = wrapped.chat(empty_request()).await?;
495 match outcome {
496 ChatOutcome::InvalidRequest(msg) => assert!(
497 msg.contains("401"),
498 "expected original 401 message, got {msg}"
499 ),
500 other => panic!("expected InvalidRequest, got {other:?}"),
501 }
502 assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
503 assert_eq!(mock.chat_call_count(), 1);
504 Ok(())
505 }
506
507 async fn drain(mut stream: StreamBox<'_>) -> Vec<Result<StreamDelta>> {
508 let mut out = Vec::new();
509 while let Some(item) = stream.next().await {
510 out.push(item);
511 }
512 out
513 }
514
515 #[tokio::test]
517 async fn chat_stream_successful_pass_through() -> Result<()> {
518 let mock = MockProvider::new();
519 mock.queue_stream(vec![
520 MockStreamItem::Ok(StreamDelta::TextDelta {
521 delta: "hi".into(),
522 block_index: 0,
523 }),
524 MockStreamItem::Ok(StreamDelta::Done {
525 stop_reason: Some(StopReason::EndTurn),
526 }),
527 ])?;
528
529 let refresh_count = Arc::new(AtomicUsize::new(0));
530 let wrapped = wrap_success(&mock, &refresh_count);
531
532 let deltas = drain(wrapped.chat_stream(empty_request())).await;
533 assert_eq!(deltas.len(), 2);
534 assert!(matches!(
535 deltas[0].as_ref().ok(),
536 Some(StreamDelta::TextDelta { delta, .. }) if delta == "hi"
537 ));
538 assert!(matches!(
539 deltas[1].as_ref().ok(),
540 Some(StreamDelta::Done { .. })
541 ));
542 assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
543 assert_eq!(mock.stream_call_count(), 1);
544 Ok(())
545 }
546
547 #[tokio::test]
549 async fn chat_stream_401_before_output_retries() -> Result<()> {
550 let mock = MockProvider::new();
551 mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
552 message: "status=401 Unauthorized".into(),
553 kind: StreamErrorKind::InvalidRequest,
554 })])?;
555 mock.queue_stream(vec![
556 MockStreamItem::Ok(StreamDelta::TextDelta {
557 delta: "retried".into(),
558 block_index: 0,
559 }),
560 MockStreamItem::Ok(StreamDelta::Done {
561 stop_reason: Some(StopReason::EndTurn),
562 }),
563 ])?;
564
565 let refresh_count = Arc::new(AtomicUsize::new(0));
566 let wrapped = wrap_success(&mock, &refresh_count);
567
568 let deltas = drain(wrapped.chat_stream(empty_request())).await;
569 assert_eq!(deltas.len(), 2);
571 assert!(matches!(
572 deltas[0].as_ref().ok(),
573 Some(StreamDelta::TextDelta { delta, .. }) if delta == "retried"
574 ));
575 assert!(matches!(
576 deltas[1].as_ref().ok(),
577 Some(StreamDelta::Done { .. })
578 ));
579 assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
580 assert_eq!(mock.stream_call_count(), 2);
581 Ok(())
582 }
583
584 #[tokio::test]
586 async fn chat_stream_401_after_output_does_not_retry() -> Result<()> {
587 let mock = MockProvider::new();
588 mock.queue_stream(vec![
589 MockStreamItem::Ok(StreamDelta::TextDelta {
590 delta: "partial".into(),
591 block_index: 0,
592 }),
593 MockStreamItem::Ok(StreamDelta::Error {
594 message: "401 Unauthorized".into(),
595 kind: StreamErrorKind::InvalidRequest,
596 }),
597 ])?;
598
599 let refresh_count = Arc::new(AtomicUsize::new(0));
600 let wrapped = wrap_success(&mock, &refresh_count);
601
602 let deltas = drain(wrapped.chat_stream(empty_request())).await;
603 assert_eq!(deltas.len(), 2);
604 assert!(matches!(
605 deltas[0].as_ref().ok(),
606 Some(StreamDelta::TextDelta { delta, .. }) if delta == "partial"
607 ));
608 assert!(matches!(
609 deltas[1].as_ref().ok(),
610 Some(StreamDelta::Error { message, .. }) if message.contains("401")
611 ));
612 assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
613 assert_eq!(mock.stream_call_count(), 1);
614 Ok(())
615 }
616
617 #[tokio::test]
619 async fn chat_stream_only_one_retry_per_call() -> Result<()> {
620 let mock = MockProvider::new();
621 mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
622 message: "status=401 Unauthorized".into(),
623 kind: StreamErrorKind::InvalidRequest,
624 })])?;
625 mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
626 message: "still 401 Unauthorized".into(),
627 kind: StreamErrorKind::InvalidRequest,
628 })])?;
629
630 let refresh_count = Arc::new(AtomicUsize::new(0));
631 let wrapped = wrap_success(&mock, &refresh_count);
632
633 let deltas = drain(wrapped.chat_stream(empty_request())).await;
634 assert_eq!(deltas.len(), 1);
635 assert!(matches!(
636 deltas[0].as_ref().ok(),
637 Some(StreamDelta::Error { message, .. }) if message == "still 401 Unauthorized"
638 ));
639 assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
640 assert_eq!(mock.stream_call_count(), 2);
641 Ok(())
642 }
643
644 #[derive(Clone)]
650 struct ConcurrentMock {
651 model: String,
652 provider_name: &'static str,
653 total_calls: Arc<AtomicUsize>,
654 initial_barrier: Arc<tokio::sync::Barrier>,
655 }
656
657 type CMFut = std::pin::Pin<Box<dyn Future<Output = Result<ConcurrentMock>> + Send>>;
658 type CMRefresh = Box<dyn Fn() -> CMFut + Send + Sync + 'static>;
659
660 #[async_trait]
661 impl LlmProvider for ConcurrentMock {
662 async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
663 let call_index = self.total_calls.fetch_add(1, Ordering::SeqCst);
664 if call_index < 2 {
665 self.initial_barrier.wait().await;
666 Ok(ChatOutcome::InvalidRequest("401 Unauthorized".into()))
667 } else {
668 Ok(ChatOutcome::Success(success_response()))
669 }
670 }
671
672 fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
673 Box::pin(async_stream::stream! {
674 yield Err(anyhow::anyhow!("chat_stream not used in this test"));
675 })
676 }
677
678 fn model(&self) -> &str {
679 &self.model
680 }
681
682 fn provider(&self) -> &'static str {
683 self.provider_name
684 }
685 }
686
687 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
689 async fn chat_concurrent_callers_share_refresh() -> Result<()> {
690 let mock = ConcurrentMock {
691 model: "mock-model".to_string(),
692 provider_name: "mock",
693 total_calls: Arc::new(AtomicUsize::new(0)),
694 initial_barrier: Arc::new(tokio::sync::Barrier::new(2)),
695 };
696 let call_count = Arc::clone(&mock.total_calls);
697 let refresh_count = Arc::new(AtomicUsize::new(0));
698 let refresh_counter = Arc::clone(&refresh_count);
699 let template = mock.clone();
700
701 let cb: CMRefresh = Box::new(move || {
702 refresh_counter.fetch_add(1, Ordering::SeqCst);
703 let provider = template.clone();
704 Box::pin(async move { Ok(provider) })
705 });
706 let wrapped = RefreshingProvider::new(mock, cb);
707
708 let a = wrapped.clone();
709 let b = wrapped.clone();
710 let task_a = tokio::spawn(async move { a.chat(empty_request()).await });
711 let task_b = tokio::spawn(async move { b.chat(empty_request()).await });
712
713 let outcome_a = task_a.await.context("task_a join")??;
714 let outcome_b = task_b.await.context("task_b join")??;
715
716 assert!(matches!(outcome_a, ChatOutcome::Success(_)));
717 assert!(matches!(outcome_b, ChatOutcome::Success(_)));
718 assert_eq!(call_count.load(Ordering::SeqCst), 4);
719 let refreshes = refresh_count.load(Ordering::SeqCst);
720 assert!(
721 refreshes <= 2,
722 "expected at most 2 refresh calls (one per caller), got {refreshes}"
723 );
724 Ok(())
725 }
726}