1use std::future::Future;
52use std::sync::Arc;
53
54use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ThinkingConfig};
55use anyhow::Result;
56use async_trait::async_trait;
57use futures::StreamExt;
58use tokio::sync::Mutex;
59
60use crate::model_capabilities::ModelCapabilities;
61use crate::provider::{LlmProvider, StructuredOutputSupport};
62use crate::streaming::{StreamBox, StreamDelta};
63
/// Wrapper around a provider `P` whose current instance can be replaced by
/// the result of a refresh callback `F`.
///
/// The `impl` blocks require `F` to return a future that resolves to a new
/// `P`; the struct itself places no bounds on either parameter.
pub struct RefreshingProvider<P, F> {
    /// Provider currently used for requests. Shared (via `Arc`) between all
    /// clones of the wrapper and overwritten by a successful refresh.
    inner: Arc<Mutex<P>>,
    /// Callback invoked to build a replacement provider.
    refresh: Arc<F>,
    /// Clone of the provider passed to `new`. It is never replaced by a
    /// refresh and answers the synchronous capability/validation queries.
    template: P,
    /// Model name captured from the initial provider in `new`.
    model: String,
    /// Provider name captured from the initial provider in `new`.
    provider: &'static str,
    /// Thinking configuration captured from the initial provider, if any.
    thinking: Option<ThinkingConfig>,
}
88
89impl<P: Clone, F> Clone for RefreshingProvider<P, F> {
90 fn clone(&self) -> Self {
91 Self {
92 inner: Arc::clone(&self.inner),
93 refresh: Arc::clone(&self.refresh),
94 template: self.template.clone(),
95 model: self.model.clone(),
96 provider: self.provider,
97 thinking: self.thinking.clone(),
98 }
99 }
100}
101
102impl<P, F, Fut> RefreshingProvider<P, F>
103where
104 P: LlmProvider + Clone + 'static,
105 F: Fn() -> Fut + Send + Sync + 'static,
106 Fut: Future<Output = Result<P>> + Send + 'static,
107{
108 #[must_use]
116 pub fn new(inner: P, refresh: F) -> Self {
117 let model = inner.model().to_string();
118 let provider = inner.provider();
119 let thinking = inner.configured_thinking().cloned();
120 let template = inner.clone();
121 Self {
122 inner: Arc::new(Mutex::new(inner)),
123 refresh: Arc::new(refresh),
124 template,
125 model,
126 provider,
127 thinking,
128 }
129 }
130
131 async fn snapshot(&self) -> P {
132 self.inner.lock().await.clone()
133 }
134
135 async fn run_refresh(&self) -> Result<()> {
136 let fresh = (self.refresh)().await?;
137 *self.inner.lock().await = fresh;
138 Ok(())
139 }
140}
141
/// Heuristically decides whether an error message describes an
/// authentication failure.
///
/// Matching is ASCII-case-insensitive. A message qualifies when it contains
/// either
///
/// * an HTTP 401 status code written as ` 401` or `status=401`, or
/// * one of the phrases `unauthorized`, `authentication`, `token_expired`,
///   `invalid api key`, `invalid_api_key`.
///
/// The status code must not be directly followed by another digit, so
/// unrelated numbers such as `4012 tokens` or `status=4010` are not mistaken
/// for a 401.
#[must_use]
pub fn is_unauthorized_error(message: &str) -> bool {
    const STATUS_MARKERS: [&str; 2] = [" 401", "status=401"];
    const AUTH_PHRASES: [&str; 5] = [
        "unauthorized",
        "authentication",
        "token_expired",
        "invalid api key",
        "invalid_api_key",
    ];

    let lower = message.to_ascii_lowercase();
    STATUS_MARKERS
        .iter()
        .any(|marker| contains_status_code(&lower, marker))
        || AUTH_PHRASES.iter().any(|phrase| lower.contains(*phrase))
}

/// Returns `true` when `marker` occurs in `haystack` at least once without an
/// ASCII digit immediately after it.
fn contains_status_code(haystack: &str, marker: &str) -> bool {
    haystack.match_indices(marker).any(|(start, found)| {
        // A trailing digit means the match is only the prefix of a longer
        // number (e.g. "4012"), not the status code itself.
        let following = haystack.as_bytes().get(start + found.len());
        !matches!(following, Some(byte) if byte.is_ascii_digit())
    })
}
160
161#[async_trait]
162impl<P, F, Fut> LlmProvider for RefreshingProvider<P, F>
163where
164 P: LlmProvider + Clone + 'static,
165 F: Fn() -> Fut + Send + Sync + 'static,
166 Fut: Future<Output = Result<P>> + Send + 'static,
167{
168 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
169 let outcome = self.snapshot().await.chat(request.clone()).await?;
170 if let ChatOutcome::InvalidRequest(message) = &outcome
171 && is_unauthorized_error(message)
172 {
173 match self.run_refresh().await {
174 Ok(()) => return self.snapshot().await.chat(request).await,
175 Err(error) => {
176 log::warn!("RefreshingProvider refresh after 401 failed: {error:#}");
177 }
178 }
179 }
180 Ok(outcome)
181 }
182
183 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
184 let this = self.clone();
185 Box::pin(async_stream::stream! {
186 let mut refreshed = false;
187 'attempts: loop {
188 let provider = this.snapshot().await;
189 let mut stream = provider.chat_stream(request.clone());
190 let mut saw_output = false;
191
192 while let Some(item) = stream.next().await {
193 match item {
194 Ok(StreamDelta::Error { message, kind })
195 if !saw_output
196 && !refreshed
197 && is_unauthorized_error(&message) =>
198 {
199 match this.run_refresh().await {
200 Ok(()) => {
201 refreshed = true;
202 continue 'attempts;
203 }
204 Err(error) => {
205 log::warn!(
206 "RefreshingProvider refresh after streaming 401 failed: {error:#}"
207 );
208 yield Ok(StreamDelta::Error { message, kind });
209 return;
210 }
211 }
212 }
213 Ok(delta) => {
214 if matches!(
215 delta,
216 StreamDelta::TextDelta { .. }
217 | StreamDelta::ThinkingDelta { .. }
218 | StreamDelta::ToolUseStart { .. }
219 | StreamDelta::ToolInputDelta { .. }
220 | StreamDelta::SignatureDelta { .. }
221 | StreamDelta::RedactedThinking { .. }
222 ) {
223 saw_output = true;
224 }
225 let done = matches!(delta, StreamDelta::Done { .. });
226 yield Ok(delta);
227 if done {
228 return;
229 }
230 }
231 Err(error)
232 if !saw_output
233 && !refreshed
234 && is_unauthorized_error(&error.to_string()) =>
235 {
236 match this.run_refresh().await {
237 Ok(()) => {
238 refreshed = true;
239 continue 'attempts;
240 }
241 Err(refresh_error) => {
242 log::warn!(
243 "RefreshingProvider refresh after stream failure failed: {refresh_error:#}"
244 );
245 yield Err(error);
246 return;
247 }
248 }
249 }
250 Err(error) => {
251 yield Err(error);
252 return;
253 }
254 }
255 }
256 return;
257 }
258 })
259 }
260
261 fn model(&self) -> &str {
262 &self.model
263 }
264
265 fn provider(&self) -> &'static str {
266 self.provider
267 }
268
269 fn configured_thinking(&self) -> Option<&ThinkingConfig> {
270 self.thinking.as_ref()
271 }
272
273 fn capabilities(&self) -> Option<&'static ModelCapabilities> {
280 self.template.capabilities()
281 }
282
283 fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
284 self.template.validate_thinking_config(thinking)
285 }
286
287 fn default_max_tokens(&self) -> u32 {
288 self.template.default_max_tokens()
289 }
290
291 fn structured_output_support(&self) -> StructuredOutputSupport {
292 self.template.structured_output_support()
293 }
294}
295
296#[cfg(test)]
297mod tests {
298 use super::*;
299
300 use std::collections::VecDeque;
301 use std::sync::Mutex as StdMutex;
302 use std::sync::atomic::{AtomicUsize, Ordering};
303
304 use agent_sdk_foundation::llm::{ChatResponse, ContentBlock, StopReason, Usage};
305 use anyhow::Context;
306
307 use crate::streaming::StreamErrorKind;
308
    /// One scripted item of a mock stream.
    #[derive(Clone)]
    enum MockStreamItem {
        /// Yielded by the mock stream as `Ok(delta)`.
        Ok(StreamDelta),
        /// Yielded as an `Err` built from this message; the mock stream ends
        /// right after it.
        Err(String),
    }
314
    /// Scripted provider for the tests. All mutable state sits behind `Arc`s,
    /// so clones share the same queues and call counters.
    #[derive(Clone)]
    struct MockProvider {
        /// Value returned by `model()`.
        model: String,
        /// Value returned by `provider()`.
        provider_name: &'static str,
        /// Outcomes handed out by `chat`, front first.
        outcomes: Arc<StdMutex<VecDeque<ChatOutcome>>>,
        /// Item batches handed out by `chat_stream`, one batch per call.
        stream_batches: Arc<StdMutex<VecDeque<Vec<MockStreamItem>>>>,
        /// Number of `chat` invocations.
        chat_calls: Arc<AtomicUsize>,
        /// Number of `chat_stream` invocations.
        stream_calls: Arc<AtomicUsize>,
    }
324
325 impl MockProvider {
326 fn new() -> Self {
327 Self {
328 model: "mock-model".to_string(),
329 provider_name: "mock",
330 outcomes: Arc::new(StdMutex::new(VecDeque::new())),
331 stream_batches: Arc::new(StdMutex::new(VecDeque::new())),
332 chat_calls: Arc::new(AtomicUsize::new(0)),
333 stream_calls: Arc::new(AtomicUsize::new(0)),
334 }
335 }
336
337 fn queue_chat(&self, outcome: ChatOutcome) -> Result<()> {
338 self.outcomes
339 .lock()
340 .ok()
341 .context("outcomes lock poisoned")?
342 .push_back(outcome);
343 Ok(())
344 }
345
346 fn queue_stream(&self, batch: Vec<MockStreamItem>) -> Result<()> {
347 self.stream_batches
348 .lock()
349 .ok()
350 .context("stream_batches lock poisoned")?
351 .push_back(batch);
352 Ok(())
353 }
354
355 fn chat_call_count(&self) -> usize {
356 self.chat_calls.load(Ordering::SeqCst)
357 }
358
359 fn stream_call_count(&self) -> usize {
360 self.stream_calls.load(Ordering::SeqCst)
361 }
362 }
363
364 #[async_trait]
365 impl LlmProvider for MockProvider {
366 async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
367 self.chat_calls.fetch_add(1, Ordering::SeqCst);
368 let mut queue = self
369 .outcomes
370 .lock()
371 .ok()
372 .context("outcomes lock poisoned")?;
373 queue.pop_front().context("MockProvider: no queued outcome")
374 }
375
376 fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
377 self.stream_calls.fetch_add(1, Ordering::SeqCst);
378 let batch: Vec<MockStreamItem> = self
379 .stream_batches
380 .lock()
381 .ok()
382 .and_then(|mut q| q.pop_front())
383 .unwrap_or_else(|| vec![MockStreamItem::Err("no queued stream batch".into())]);
384 Box::pin(async_stream::stream! {
385 for item in batch {
386 match item {
387 MockStreamItem::Ok(delta) => yield Ok(delta),
388 MockStreamItem::Err(msg) => {
389 yield Err(anyhow::anyhow!(msg));
390 return;
391 }
392 }
393 }
394 })
395 }
396
397 fn model(&self) -> &str {
398 &self.model
399 }
400
401 fn provider(&self) -> &'static str {
402 self.provider_name
403 }
404
405 fn default_max_tokens(&self) -> u32 {
408 32_000
409 }
410
411 fn structured_output_support(&self) -> StructuredOutputSupport {
412 StructuredOutputSupport::Native
413 }
414
415 fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
416 if thinking.is_some() {
417 Err(anyhow::anyhow!("mock rejects thinking"))
418 } else {
419 Ok(())
420 }
421 }
422 }
423
424 fn success_response() -> ChatResponse {
425 ChatResponse {
426 id: "msg_test".to_string(),
427 content: vec![ContentBlock::Text {
428 text: "ok".to_string(),
429 }],
430 model: "mock-model".to_string(),
431 stop_reason: Some(StopReason::EndTurn),
432 usage: Usage {
433 input_tokens: 1,
434 output_tokens: 1,
435 cached_input_tokens: 0,
436 cache_creation_input_tokens: 0,
437 },
438 }
439 }
440
441 fn empty_request() -> ChatRequest {
442 ChatRequest {
443 system: String::new(),
444 messages: Vec::new(),
445 tools: None,
446 max_tokens: 100,
447 max_tokens_explicit: false,
448 session_id: None,
449 cached_content: None,
450 thinking: None,
451 tool_choice: None,
452 response_format: None,
453 }
454 }
455
    /// Boxed future produced by a test refresh callback.
    type BoxedFut = std::pin::Pin<Box<dyn Future<Output = Result<MockProvider>> + Send>>;
    /// Type-erased refresh callback, so success and failure callbacks share
    /// one wrapper type.
    type RefreshFn = Box<dyn Fn() -> BoxedFut + Send + Sync + 'static>;
    /// The wrapper under test, instantiated with the mock provider.
    type Wrapped = RefreshingProvider<MockProvider, RefreshFn>;
459
460 fn wrap_success(mock: &MockProvider, counter: &Arc<AtomicUsize>) -> Wrapped {
461 let counter = Arc::clone(counter);
462 let template = mock.clone();
463 let cb: RefreshFn = Box::new(move || {
464 counter.fetch_add(1, Ordering::SeqCst);
465 let provider = template.clone();
466 Box::pin(async move { Ok(provider) })
467 });
468 RefreshingProvider::new(mock.clone(), cb)
469 }
470
471 fn wrap_failure(
472 mock: &MockProvider,
473 counter: &Arc<AtomicUsize>,
474 error: &'static str,
475 ) -> Wrapped {
476 let counter = Arc::clone(counter);
477 let cb: RefreshFn = Box::new(move || {
478 counter.fetch_add(1, Ordering::SeqCst);
479 Box::pin(async move { Err(anyhow::anyhow!(error)) })
480 });
481 RefreshingProvider::new(mock.clone(), cb)
482 }
483
484 #[test]
486 fn wrapper_delegates_capability_overrides_to_inner() {
487 let mock = MockProvider::new();
488 let refresh_count = Arc::new(AtomicUsize::new(0));
489 let wrapped = wrap_success(&mock, &refresh_count);
490
491 assert_eq!(wrapped.default_max_tokens(), 32_000);
494 assert_eq!(
495 wrapped.structured_output_support(),
496 StructuredOutputSupport::Native
497 );
498 assert!(
499 wrapped
500 .validate_thinking_config(Some(&ThinkingConfig::adaptive()))
501 .is_err()
502 );
503 assert!(wrapped.validate_thinking_config(None).is_ok());
504 }
505
506 #[test]
508 fn is_unauthorized_error_matches_expected_strings() {
509 assert!(is_unauthorized_error("HTTP 401"));
510 assert!(is_unauthorized_error("status=401 Unauthorized"));
511 assert!(is_unauthorized_error("Invalid API key"));
512 assert!(is_unauthorized_error("invalid_api_key"));
513 assert!(is_unauthorized_error("token_expired"));
514 assert!(is_unauthorized_error("Authentication failed"));
515 assert!(is_unauthorized_error("UNAUTHORIZED"));
516
517 assert!(!is_unauthorized_error("rate limited"));
518 assert!(!is_unauthorized_error("network error"));
519 assert!(!is_unauthorized_error(""));
520 assert!(!is_unauthorized_error("internal server error"));
521 }
522
523 #[tokio::test]
525 async fn chat_successful_pass_through_does_not_refresh() -> Result<()> {
526 let mock = MockProvider::new();
527 mock.queue_chat(ChatOutcome::Success(success_response()))?;
528
529 let refresh_count = Arc::new(AtomicUsize::new(0));
530 let wrapped = wrap_success(&mock, &refresh_count);
531
532 let outcome = wrapped.chat(empty_request()).await?;
533 assert!(matches!(outcome, ChatOutcome::Success(_)));
534 assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
535 assert_eq!(mock.chat_call_count(), 1);
536 Ok(())
537 }
538
539 #[tokio::test]
541 async fn chat_401_triggers_refresh_and_retries() -> Result<()> {
542 let mock = MockProvider::new();
543 mock.queue_chat(ChatOutcome::InvalidRequest("401 Unauthorized".into()))?;
544 mock.queue_chat(ChatOutcome::Success(success_response()))?;
545
546 let refresh_count = Arc::new(AtomicUsize::new(0));
547 let wrapped = wrap_success(&mock, &refresh_count);
548
549 let outcome = wrapped.chat(empty_request()).await?;
550 assert!(matches!(outcome, ChatOutcome::Success(_)));
551 assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
552 assert_eq!(mock.chat_call_count(), 2);
553 Ok(())
554 }
555
556 #[tokio::test]
558 async fn chat_surfaces_original_401_when_refresh_fails() -> Result<()> {
559 let mock = MockProvider::new();
560 mock.queue_chat(ChatOutcome::InvalidRequest(
561 "status=401 Unauthorized".into(),
562 ))?;
563
564 let refresh_count = Arc::new(AtomicUsize::new(0));
565 let wrapped = wrap_failure(&mock, &refresh_count, "refresh callback failed");
566
567 let outcome = wrapped.chat(empty_request()).await?;
568 match outcome {
569 ChatOutcome::InvalidRequest(msg) => assert!(
570 msg.contains("401"),
571 "expected original 401 message, got {msg}"
572 ),
573 other => panic!("expected InvalidRequest, got {other:?}"),
574 }
575 assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
576 assert_eq!(mock.chat_call_count(), 1);
577 Ok(())
578 }
579
580 async fn drain(mut stream: StreamBox<'_>) -> Vec<Result<StreamDelta>> {
581 let mut out = Vec::new();
582 while let Some(item) = stream.next().await {
583 out.push(item);
584 }
585 out
586 }
587
588 #[tokio::test]
590 async fn chat_stream_successful_pass_through() -> Result<()> {
591 let mock = MockProvider::new();
592 mock.queue_stream(vec![
593 MockStreamItem::Ok(StreamDelta::TextDelta {
594 delta: "hi".into(),
595 block_index: 0,
596 }),
597 MockStreamItem::Ok(StreamDelta::Done {
598 stop_reason: Some(StopReason::EndTurn),
599 }),
600 ])?;
601
602 let refresh_count = Arc::new(AtomicUsize::new(0));
603 let wrapped = wrap_success(&mock, &refresh_count);
604
605 let deltas = drain(wrapped.chat_stream(empty_request())).await;
606 assert_eq!(deltas.len(), 2);
607 assert!(matches!(
608 deltas[0].as_ref().ok(),
609 Some(StreamDelta::TextDelta { delta, .. }) if delta == "hi"
610 ));
611 assert!(matches!(
612 deltas[1].as_ref().ok(),
613 Some(StreamDelta::Done { .. })
614 ));
615 assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
616 assert_eq!(mock.stream_call_count(), 1);
617 Ok(())
618 }
619
620 #[tokio::test]
622 async fn chat_stream_401_before_output_retries() -> Result<()> {
623 let mock = MockProvider::new();
624 mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
625 message: "status=401 Unauthorized".into(),
626 kind: StreamErrorKind::InvalidRequest,
627 })])?;
628 mock.queue_stream(vec![
629 MockStreamItem::Ok(StreamDelta::TextDelta {
630 delta: "retried".into(),
631 block_index: 0,
632 }),
633 MockStreamItem::Ok(StreamDelta::Done {
634 stop_reason: Some(StopReason::EndTurn),
635 }),
636 ])?;
637
638 let refresh_count = Arc::new(AtomicUsize::new(0));
639 let wrapped = wrap_success(&mock, &refresh_count);
640
641 let deltas = drain(wrapped.chat_stream(empty_request())).await;
642 assert_eq!(deltas.len(), 2);
644 assert!(matches!(
645 deltas[0].as_ref().ok(),
646 Some(StreamDelta::TextDelta { delta, .. }) if delta == "retried"
647 ));
648 assert!(matches!(
649 deltas[1].as_ref().ok(),
650 Some(StreamDelta::Done { .. })
651 ));
652 assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
653 assert_eq!(mock.stream_call_count(), 2);
654 Ok(())
655 }
656
657 #[tokio::test]
659 async fn chat_stream_401_after_output_does_not_retry() -> Result<()> {
660 let mock = MockProvider::new();
661 mock.queue_stream(vec![
662 MockStreamItem::Ok(StreamDelta::TextDelta {
663 delta: "partial".into(),
664 block_index: 0,
665 }),
666 MockStreamItem::Ok(StreamDelta::Error {
667 message: "401 Unauthorized".into(),
668 kind: StreamErrorKind::InvalidRequest,
669 }),
670 ])?;
671
672 let refresh_count = Arc::new(AtomicUsize::new(0));
673 let wrapped = wrap_success(&mock, &refresh_count);
674
675 let deltas = drain(wrapped.chat_stream(empty_request())).await;
676 assert_eq!(deltas.len(), 2);
677 assert!(matches!(
678 deltas[0].as_ref().ok(),
679 Some(StreamDelta::TextDelta { delta, .. }) if delta == "partial"
680 ));
681 assert!(matches!(
682 deltas[1].as_ref().ok(),
683 Some(StreamDelta::Error { message, .. }) if message.contains("401")
684 ));
685 assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
686 assert_eq!(mock.stream_call_count(), 1);
687 Ok(())
688 }
689
690 #[tokio::test]
692 async fn chat_stream_only_one_retry_per_call() -> Result<()> {
693 let mock = MockProvider::new();
694 mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
695 message: "status=401 Unauthorized".into(),
696 kind: StreamErrorKind::InvalidRequest,
697 })])?;
698 mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
699 message: "still 401 Unauthorized".into(),
700 kind: StreamErrorKind::InvalidRequest,
701 })])?;
702
703 let refresh_count = Arc::new(AtomicUsize::new(0));
704 let wrapped = wrap_success(&mock, &refresh_count);
705
706 let deltas = drain(wrapped.chat_stream(empty_request())).await;
707 assert_eq!(deltas.len(), 1);
708 assert!(matches!(
709 deltas[0].as_ref().ok(),
710 Some(StreamDelta::Error { message, .. }) if message == "still 401 Unauthorized"
711 ));
712 assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
713 assert_eq!(mock.stream_call_count(), 2);
714 Ok(())
715 }
716
    /// Provider for the concurrency test: its first two `chat` calls meet at
    /// a barrier and report a 401, every later call succeeds. Clones share
    /// the counter and the barrier.
    #[derive(Clone)]
    struct ConcurrentMock {
        /// Value returned by `model()`.
        model: String,
        /// Value returned by `provider()`.
        provider_name: &'static str,
        /// Number of `chat` calls across all clones.
        total_calls: Arc<AtomicUsize>,
        /// Barrier the first two `chat` calls wait on before answering.
        initial_barrier: Arc<tokio::sync::Barrier>,
    }
729
    /// Boxed future produced by the `ConcurrentMock` refresh callback.
    type CMFut = std::pin::Pin<Box<dyn Future<Output = Result<ConcurrentMock>> + Send>>;
    /// Type-erased refresh callback returning a `ConcurrentMock`.
    type CMRefresh = Box<dyn Fn() -> CMFut + Send + Sync + 'static>;
732
733 #[async_trait]
734 impl LlmProvider for ConcurrentMock {
735 async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
736 let call_index = self.total_calls.fetch_add(1, Ordering::SeqCst);
737 if call_index < 2 {
738 self.initial_barrier.wait().await;
739 Ok(ChatOutcome::InvalidRequest("401 Unauthorized".into()))
740 } else {
741 Ok(ChatOutcome::Success(success_response()))
742 }
743 }
744
745 fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
746 Box::pin(async_stream::stream! {
747 yield Err(anyhow::anyhow!("chat_stream not used in this test"));
748 })
749 }
750
751 fn model(&self) -> &str {
752 &self.model
753 }
754
755 fn provider(&self) -> &'static str {
756 self.provider_name
757 }
758 }
759
    /// Two tasks call `chat` on clones of one wrapper at the same time. The
    /// mock holds both first calls at a barrier and answers each with a 401,
    /// so both callers go through the refresh-and-retry path.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn chat_concurrent_callers_share_refresh() -> Result<()> {
        let mock = ConcurrentMock {
            model: "mock-model".to_string(),
            provider_name: "mock",
            total_calls: Arc::new(AtomicUsize::new(0)),
            // Two participants: one initial `chat` call per task.
            initial_barrier: Arc::new(tokio::sync::Barrier::new(2)),
        };
        let call_count = Arc::clone(&mock.total_calls);
        let refresh_count = Arc::new(AtomicUsize::new(0));
        let refresh_counter = Arc::clone(&refresh_count);
        let template = mock.clone();

        // Refresh always succeeds and hands back a clone that shares the
        // call counter and barrier with the original mock.
        let cb: CMRefresh = Box::new(move || {
            refresh_counter.fetch_add(1, Ordering::SeqCst);
            let provider = template.clone();
            Box::pin(async move { Ok(provider) })
        });
        let wrapped = RefreshingProvider::new(mock, cb);

        let a = wrapped.clone();
        let b = wrapped.clone();
        let task_a = tokio::spawn(async move { a.chat(empty_request()).await });
        let task_b = tokio::spawn(async move { b.chat(empty_request()).await });

        // First `?` surfaces a join failure, second `?` the `chat` error.
        let outcome_a = task_a.await.context("task_a join")??;
        let outcome_b = task_b.await.context("task_b join")??;

        assert!(matches!(outcome_a, ChatOutcome::Success(_)));
        assert!(matches!(outcome_b, ChatOutcome::Success(_)));
        // Two rejected first attempts plus two successful retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
        let refreshes = refresh_count.load(Ordering::SeqCst);
        assert!(
            refreshes <= 2,
            "expected at most 2 refresh calls (one per caller), got {refreshes}"
        );
        Ok(())
    }
799}