use std::future::Future;
use std::sync::Arc;
use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ThinkingConfig};
use anyhow::Result;
use async_trait::async_trait;
use futures::StreamExt;
use tokio::sync::Mutex;
use crate::provider::LlmProvider;
use crate::streaming::{StreamBox, StreamDelta};
/// An [`LlmProvider`] wrapper that can replace its inner provider at runtime
/// by invoking a caller-supplied refresh callback.
///
/// `P` is the wrapped provider type and `F` is the callback that produces a
/// replacement `P`. Clones share the same inner provider and callback.
pub struct RefreshingProvider<P, F> {
/// Current provider; shared between clones and overwritten on refresh.
inner: Arc<Mutex<P>>,
/// Callback invoked to build a replacement provider.
refresh: Arc<F>,
/// Model name read from the initial provider at construction time.
model: String,
/// Provider name read from the initial provider at construction time.
provider: &'static str,
/// Thinking configuration read from the initial provider at construction time.
thinking: Option<ThinkingConfig>,
}
// Hand-written so that cloning places no `Clone` bound on `P` or `F`: the
// provider and the callback are shared through their `Arc`s, not duplicated.
impl<P, F> Clone for RefreshingProvider<P, F> {
    /// Returns a handle that shares the inner provider and refresh callback
    /// with `self` and carries copies of the cached metadata.
    fn clone(&self) -> Self {
        // Destructuring keeps this impl in sync if a field is ever added.
        let Self {
            inner,
            refresh,
            model,
            provider,
            thinking,
        } = self;
        Self {
            inner: Arc::clone(inner),
            refresh: Arc::clone(refresh),
            model: model.clone(),
            provider: *provider,
            thinking: thinking.clone(),
        }
    }
}
impl<P, F, Fut> RefreshingProvider<P, F>
where
    P: LlmProvider + Clone + 'static,
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<P>> + Send + 'static,
{
    /// Wraps `inner`, using `refresh` to build a replacement provider on demand.
    ///
    /// The model name, provider name and thinking configuration are read from
    /// `inner` here, once; `run_refresh` does not update them.
    #[must_use]
    pub fn new(inner: P, refresh: F) -> Self {
        // Field initialisers run top to bottom, so the metadata is read from
        // `inner` before it is moved into the mutex.
        Self {
            model: inner.model().to_string(),
            provider: inner.provider(),
            thinking: inner.configured_thinking().cloned(),
            inner: Arc::new(Mutex::new(inner)),
            refresh: Arc::new(refresh),
        }
    }

    /// Returns a clone of the current inner provider.
    ///
    /// The lock is released before the clone is handed back, so callers never
    /// hold it while talking to the provider.
    async fn snapshot(&self) -> P {
        let current = self.inner.lock().await;
        P::clone(&current)
    }

    /// Runs the refresh callback and installs its result as the inner provider.
    ///
    /// The callback is awaited before the lock is taken. If it fails, the
    /// error is returned and the existing provider is left in place.
    async fn run_refresh(&self) -> Result<()> {
        let replacement = (self.refresh)().await?;
        let mut slot = self.inner.lock().await;
        *slot = replacement;
        Ok(())
    }
}
/// Returns `true` when `message` looks like an authorization failure.
///
/// The check is ASCII-case-insensitive and succeeds when the message either
/// contains `401` as a standalone token or contains one of the textual
/// markers `unauthorized`, `authentication`, `token_expired`,
/// `invalid api key` or `invalid_api_key`.
///
/// `401` counts as standalone only when it is not glued to other characters:
/// `"HTTP 401"`, `"status=401"`, `"401: expired"` and `"(401)"` match, while
/// `"4012 tokens"` or `"40100"` do not. This keeps token-count and size
/// messages from being mistaken for an authorization failure.
#[must_use]
pub fn is_unauthorized_error(message: &str) -> bool {
    const MARKERS: [&str; 5] = [
        "unauthorized",
        "authentication",
        "token_expired",
        "invalid api key",
        "invalid_api_key",
    ];
    let lower = message.to_ascii_lowercase();
    contains_standalone_401(&lower) || MARKERS.iter().any(|marker| lower.contains(*marker))
}

/// Reports whether `text` contains `401` as its own token.
///
/// An occurrence is rejected when the byte before it is an ASCII letter,
/// digit or `.` (part of a longer number, decimal or identifier), or when the
/// byte after it is an ASCII letter or digit.
fn contains_standalone_401(text: &str) -> bool {
    text.match_indices("401").any(|(start, code)| {
        // `start` and the end of the match are char boundaries, so slicing
        // here cannot panic; non-ASCII neighbours count as separators.
        let glued_before = text[..start]
            .bytes()
            .next_back()
            .is_some_and(|b| b.is_ascii_alphanumeric() || b == b'.');
        let glued_after = text[start + code.len()..]
            .bytes()
            .next()
            .is_some_and(|b| b.is_ascii_alphanumeric());
        !glued_before && !glued_after
    })
}
// `LlmProvider` implementation that forwards every call to a snapshot of the
// current inner provider and, when a response looks like an authorization
// failure (per `is_unauthorized_error`), runs the refresh callback and retries.
#[async_trait]
impl<P, F, Fut> LlmProvider for RefreshingProvider<P, F>
where
P: LlmProvider + Clone + 'static,
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<P>> + Send + 'static,
{
/// Sends `request` to the current provider, refreshing and retrying once if
/// the outcome is an `InvalidRequest` whose message looks unauthorized.
///
/// Errors returned by the inner `chat` call are propagated unchanged and do
/// not trigger a refresh. If the refresh callback fails, a warning is logged
/// and the original outcome is returned. The outcome of the retry is
/// returned as-is, whatever it is.
async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
// The request is cloned up front because it may be needed again for the retry.
let outcome = self.snapshot().await.chat(request.clone()).await?;
if let ChatOutcome::InvalidRequest(message) = &outcome
&& is_unauthorized_error(message)
{
match self.run_refresh().await {
// Retry against a fresh snapshot, i.e. the provider just installed.
Ok(()) => return self.snapshot().await.chat(request).await,
Err(error) => {
log::warn!("RefreshingProvider refresh after 401 failed: {error:#}");
}
}
}
Ok(outcome)
}
/// Streams `request` through the current provider, restarting the stream
/// once after a refresh if an unauthorized error arrives before any content.
///
/// An unauthorized error (either a `StreamDelta::Error` item or an `Err`
/// item) triggers a refresh only while no content delta has been seen and
/// no refresh has happened yet in this call. In every other case the item
/// is forwarded to the caller. The returned stream ends after a `Done`
/// delta, after an `Err` item, or when the inner stream ends.
fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
// The stream works on its own clone of the wrapper.
let this = self.clone();
Box::pin(async_stream::stream! {
// Set after the first successful refresh; limits each call to one retry.
let mut refreshed = false;
'attempts: loop {
let provider = this.snapshot().await;
let mut stream = provider.chat_stream(request.clone());
// Set once a content-bearing delta has been forwarded in this attempt.
let mut saw_output = false;
while let Some(item) = stream.next().await {
match item {
// In-band unauthorized error before any content: refresh and restart.
Ok(StreamDelta::Error { message, kind })
if !saw_output
&& !refreshed
&& is_unauthorized_error(&message) =>
{
match this.run_refresh().await {
Ok(()) => {
refreshed = true;
continue 'attempts;
}
Err(error) => {
log::warn!(
"RefreshingProvider refresh after streaming 401 failed: {error:#}"
);
// Refresh failed: surface the original error delta and stop.
yield Ok(StreamDelta::Error { message, kind });
return;
}
}
}
// Every other delta is forwarded, including error deltas that did not
// qualify for a retry above.
Ok(delta) => {
// Only these variants count as output. Deltas outside this list do not
// block a later retry even though they have already been forwarded.
if matches!(
delta,
StreamDelta::TextDelta { .. }
| StreamDelta::ThinkingDelta { .. }
| StreamDelta::ToolUseStart { .. }
| StreamDelta::ToolInputDelta { .. }
| StreamDelta::SignatureDelta { .. }
| StreamDelta::RedactedThinking { .. }
) {
saw_output = true;
}
let done = matches!(delta, StreamDelta::Done { .. });
yield Ok(delta);
if done {
return;
}
}
// Stream-level unauthorized failure before any content: same retry policy
// as the in-band case, matched on the error's display text.
Err(error)
if !saw_output
&& !refreshed
&& is_unauthorized_error(&error.to_string()) =>
{
match this.run_refresh().await {
Ok(()) => {
refreshed = true;
continue 'attempts;
}
Err(refresh_error) => {
log::warn!(
"RefreshingProvider refresh after stream failure failed: {refresh_error:#}"
);
// Refresh failed: surface the original stream error and stop.
yield Err(error);
return;
}
}
}
// Any other stream error ends the stream after being forwarded.
Err(error) => {
yield Err(error);
return;
}
}
}
// The inner stream ended without a `Done` delta; end the outer stream too.
return;
}
})
}
/// Returns the model name captured in `new`.
fn model(&self) -> &str {
&self.model
}
/// Returns the provider name captured in `new`.
fn provider(&self) -> &'static str {
self.provider
}
/// Returns the thinking configuration captured in `new`, if any.
fn configured_thinking(&self) -> Option<&ThinkingConfig> {
self.thinking.as_ref()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::VecDeque;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use agent_sdk_foundation::llm::{ChatResponse, ContentBlock, StopReason, Usage};
use anyhow::Context;
use crate::streaming::StreamErrorKind;
/// One scripted item for a mock stream: either a delta to yield or an error
/// message to yield as a stream-level `Err`.
#[derive(Clone)]
enum MockStreamItem {
/// Yielded as `Ok(delta)`.
Ok(StreamDelta),
/// Yielded as `Err` built from this message; the mock stream ends afterwards.
Err(String),
}
/// Scripted provider for tests.
///
/// The queues and counters sit behind `Arc`s, so every clone of a
/// `MockProvider` shares the same script and the same call counts.
#[derive(Clone)]
struct MockProvider {
model: String,
provider_name: &'static str,
/// Outcomes returned by successive `chat` calls, front first.
outcomes: Arc<StdMutex<VecDeque<ChatOutcome>>>,
/// Item batches replayed by successive `chat_stream` calls, front first.
stream_batches: Arc<StdMutex<VecDeque<Vec<MockStreamItem>>>>,
/// Number of `chat` calls made so far.
chat_calls: Arc<AtomicUsize>,
/// Number of `chat_stream` calls made so far.
stream_calls: Arc<AtomicUsize>,
}
impl MockProvider {
    /// Creates a mock with empty queues and zeroed call counters.
    fn new() -> Self {
        let outcomes = Arc::new(StdMutex::new(VecDeque::new()));
        let stream_batches = Arc::new(StdMutex::new(VecDeque::new()));
        Self {
            model: String::from("mock-model"),
            provider_name: "mock",
            outcomes,
            stream_batches,
            chat_calls: Arc::new(AtomicUsize::new(0)),
            stream_calls: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Appends `outcome` to the queue consumed by `chat`.
    fn queue_chat(&self, outcome: ChatOutcome) -> Result<()> {
        let mut pending = self
            .outcomes
            .lock()
            .ok()
            .context("outcomes lock poisoned")?;
        pending.push_back(outcome);
        Ok(())
    }

    /// Appends `batch` to the queue consumed by `chat_stream`.
    fn queue_stream(&self, batch: Vec<MockStreamItem>) -> Result<()> {
        let mut pending = self
            .stream_batches
            .lock()
            .ok()
            .context("stream_batches lock poisoned")?;
        pending.push_back(batch);
        Ok(())
    }

    /// Number of `chat` calls observed across all clones of this mock.
    fn chat_call_count(&self) -> usize {
        let calls = &self.chat_calls;
        calls.load(Ordering::SeqCst)
    }

    /// Number of `chat_stream` calls observed across all clones of this mock.
    fn stream_call_count(&self) -> usize {
        let calls = &self.stream_calls;
        calls.load(Ordering::SeqCst)
    }
}
#[async_trait]
impl LlmProvider for MockProvider {
    /// Counts the call and returns the next queued outcome, failing if the
    /// queue is empty or its lock is poisoned.
    async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
        self.chat_calls.fetch_add(1, Ordering::SeqCst);
        let next = {
            let mut pending = self
                .outcomes
                .lock()
                .ok()
                .context("outcomes lock poisoned")?;
            pending.pop_front()
        };
        next.context("MockProvider: no queued outcome")
    }

    /// Counts the call and replays the next queued batch. With nothing
    /// queued (or a poisoned lock) the stream yields a single error.
    fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
        self.stream_calls.fetch_add(1, Ordering::SeqCst);
        let queued = match self.stream_batches.lock() {
            Ok(mut pending) => pending.pop_front(),
            Err(_) => None,
        };
        let script: Vec<MockStreamItem> = queued
            .unwrap_or_else(|| vec![MockStreamItem::Err("no queued stream batch".into())]);
        Box::pin(async_stream::stream! {
            for step in script {
                match step {
                    MockStreamItem::Ok(delta) => yield Ok(delta),
                    // A scripted error terminates the stream.
                    MockStreamItem::Err(text) => {
                        yield Err(anyhow::anyhow!(text));
                        return;
                    }
                }
            }
        })
    }

    fn model(&self) -> &str {
        self.model.as_str()
    }

    fn provider(&self) -> &'static str {
        self.provider_name
    }
}
/// Builds a minimal successful response: one `"ok"` text block, an `EndTurn`
/// stop reason and one input plus one output token of usage.
fn success_response() -> ChatResponse {
    let usage = Usage {
        input_tokens: 1,
        output_tokens: 1,
        cached_input_tokens: 0,
        cache_creation_input_tokens: 0,
    };
    let content = vec![ContentBlock::Text {
        text: String::from("ok"),
    }];
    ChatResponse {
        id: String::from("msg_test"),
        content,
        model: String::from("mock-model"),
        stop_reason: Some(StopReason::EndTurn),
        usage,
    }
}
/// Builds a request with no system prompt, no messages and every optional
/// field unset; only `max_tokens` carries a value (100).
fn empty_request() -> ChatRequest {
    ChatRequest {
        // Prompt content.
        system: String::default(),
        messages: vec![],
        // Token budget.
        max_tokens: 100,
        max_tokens_explicit: false,
        // Optional features, all left unset.
        tools: None,
        session_id: None,
        cached_content: None,
        thinking: None,
        tool_choice: None,
        response_format: None,
    }
}
// Boxed, type-erased future returned by the test refresh callbacks.
type BoxedFut = std::pin::Pin<Box<dyn Future<Output = Result<MockProvider>> + Send>>;
// Boxed refresh callback, so the success and failure helpers share one type.
type RefreshFn = Box<dyn Fn() -> BoxedFut + Send + Sync + 'static>;
// The wrapper type exercised by the `MockProvider`-based tests.
type Wrapped = RefreshingProvider<MockProvider, RefreshFn>;
/// Wraps `mock` with a refresh callback that always succeeds.
///
/// Each callback invocation bumps `counter` and hands back a clone of
/// `mock`, which shares the original's queues and call counters.
fn wrap_success(mock: &MockProvider, counter: &Arc<AtomicUsize>) -> Wrapped {
    let replacement = mock.clone();
    let refreshes = Arc::clone(counter);
    let on_refresh: RefreshFn = Box::new(move || {
        refreshes.fetch_add(1, Ordering::SeqCst);
        let fresh = replacement.clone();
        Box::pin(async move { Ok(fresh) })
    });
    RefreshingProvider::new(mock.clone(), on_refresh)
}
/// Wraps `mock` with a refresh callback that always fails with `error`.
///
/// Each callback invocation still bumps `counter`, so tests can tell that a
/// refresh was attempted.
fn wrap_failure(
    mock: &MockProvider,
    counter: &Arc<AtomicUsize>,
    error: &'static str,
) -> Wrapped {
    let refreshes = Arc::clone(counter);
    let on_refresh: RefreshFn = Box::new(move || {
        refreshes.fetch_add(1, Ordering::SeqCst);
        Box::pin(async move { Err(anyhow::anyhow!(error)) })
    });
    let initial = mock.clone();
    RefreshingProvider::new(initial, on_refresh)
}
/// Checks the classifier against messages that must and must not be treated
/// as authorization failures.
#[test]
fn is_unauthorized_error_matches_expected_strings() {
    let auth_failures = [
        "HTTP 401",
        "status=401 Unauthorized",
        "Invalid API key",
        "invalid_api_key",
        "token_expired",
        "Authentication failed",
        "UNAUTHORIZED",
    ];
    for message in auth_failures {
        assert!(is_unauthorized_error(message));
    }

    let unrelated = ["rate limited", "network error", "", "internal server error"];
    for message in unrelated {
        assert!(!is_unauthorized_error(message));
    }
}
/// A successful `chat` is returned untouched: one provider call, no refresh.
#[tokio::test]
async fn chat_successful_pass_through_does_not_refresh() -> Result<()> {
    let refreshes = Arc::new(AtomicUsize::new(0));
    let backend = MockProvider::new();
    backend.queue_chat(ChatOutcome::Success(success_response()))?;
    let provider = wrap_success(&backend, &refreshes);

    let result = provider.chat(empty_request()).await?;

    assert!(matches!(result, ChatOutcome::Success(_)));
    assert_eq!(refreshes.load(Ordering::SeqCst), 0);
    assert_eq!(backend.chat_call_count(), 1);
    Ok(())
}
/// An unauthorized `InvalidRequest` causes exactly one refresh followed by a
/// retry, and the retry's outcome is what the caller receives.
#[tokio::test]
async fn chat_401_triggers_refresh_and_retries() -> Result<()> {
    let refreshes = Arc::new(AtomicUsize::new(0));
    let backend = MockProvider::new();
    // First call is rejected as unauthorized; the retry succeeds.
    let scripted = [
        ChatOutcome::InvalidRequest("401 Unauthorized".into()),
        ChatOutcome::Success(success_response()),
    ];
    for outcome in scripted {
        backend.queue_chat(outcome)?;
    }
    let provider = wrap_success(&backend, &refreshes);

    let result = provider.chat(empty_request()).await?;

    assert!(matches!(result, ChatOutcome::Success(_)));
    assert_eq!(refreshes.load(Ordering::SeqCst), 1);
    assert_eq!(backend.chat_call_count(), 2);
    Ok(())
}
/// When the refresh callback fails, the original unauthorized outcome is
/// returned and the provider is not called a second time.
#[tokio::test]
async fn chat_surfaces_original_401_when_refresh_fails() -> Result<()> {
    let refreshes = Arc::new(AtomicUsize::new(0));
    let backend = MockProvider::new();
    let rejection = ChatOutcome::InvalidRequest("status=401 Unauthorized".into());
    backend.queue_chat(rejection)?;
    let provider = wrap_failure(&backend, &refreshes, "refresh callback failed");

    let result = provider.chat(empty_request()).await?;

    let msg = match result {
        ChatOutcome::InvalidRequest(msg) => msg,
        other => panic!("expected InvalidRequest, got {other:?}"),
    };
    assert!(
        msg.contains("401"),
        "expected original 401 message, got {msg}"
    );
    assert_eq!(refreshes.load(Ordering::SeqCst), 1);
    assert_eq!(backend.chat_call_count(), 1);
    Ok(())
}
/// Runs `stream` to completion and returns every item in arrival order.
async fn drain(stream: StreamBox<'_>) -> Vec<Result<StreamDelta>> {
    stream.collect().await
}
/// A clean stream is forwarded item for item without any refresh.
#[tokio::test]
async fn chat_stream_successful_pass_through() -> Result<()> {
    let refreshes = Arc::new(AtomicUsize::new(0));
    let backend = MockProvider::new();
    let script = vec![
        MockStreamItem::Ok(StreamDelta::TextDelta {
            delta: "hi".into(),
            block_index: 0,
        }),
        MockStreamItem::Ok(StreamDelta::Done {
            stop_reason: Some(StopReason::EndTurn),
        }),
    ];
    backend.queue_stream(script)?;
    let provider = wrap_success(&backend, &refreshes);

    let items = drain(provider.chat_stream(empty_request())).await;

    assert_eq!(items.len(), 2);
    assert!(matches!(
        &items[0],
        Ok(StreamDelta::TextDelta { delta, .. }) if delta == "hi"
    ));
    assert!(matches!(&items[1], Ok(StreamDelta::Done { .. })));
    assert_eq!(refreshes.load(Ordering::SeqCst), 0);
    assert_eq!(backend.stream_call_count(), 1);
    Ok(())
}
#[tokio::test]
async fn chat_stream_401_before_output_retries() -> Result<()> {
let mock = MockProvider::new();
mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
message: "status=401 Unauthorized".into(),
kind: StreamErrorKind::InvalidRequest,
})])?;
mock.queue_stream(vec![
MockStreamItem::Ok(StreamDelta::TextDelta {
delta: "retried".into(),
block_index: 0,
}),
MockStreamItem::Ok(StreamDelta::Done {
stop_reason: Some(StopReason::EndTurn),
}),
])?;
let refresh_count = Arc::new(AtomicUsize::new(0));
let wrapped = wrap_success(&mock, &refresh_count);
let deltas = drain(wrapped.chat_stream(empty_request())).await;
assert_eq!(deltas.len(), 2);
assert!(matches!(
deltas[0].as_ref().ok(),
Some(StreamDelta::TextDelta { delta, .. }) if delta == "retried"
));
assert!(matches!(
deltas[1].as_ref().ok(),
Some(StreamDelta::Done { .. })
));
assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
assert_eq!(mock.stream_call_count(), 2);
Ok(())
}
#[tokio::test]
async fn chat_stream_401_after_output_does_not_retry() -> Result<()> {
let mock = MockProvider::new();
mock.queue_stream(vec![
MockStreamItem::Ok(StreamDelta::TextDelta {
delta: "partial".into(),
block_index: 0,
}),
MockStreamItem::Ok(StreamDelta::Error {
message: "401 Unauthorized".into(),
kind: StreamErrorKind::InvalidRequest,
}),
])?;
let refresh_count = Arc::new(AtomicUsize::new(0));
let wrapped = wrap_success(&mock, &refresh_count);
let deltas = drain(wrapped.chat_stream(empty_request())).await;
assert_eq!(deltas.len(), 2);
assert!(matches!(
deltas[0].as_ref().ok(),
Some(StreamDelta::TextDelta { delta, .. }) if delta == "partial"
));
assert!(matches!(
deltas[1].as_ref().ok(),
Some(StreamDelta::Error { message, .. }) if message.contains("401")
));
assert_eq!(refresh_count.load(Ordering::SeqCst), 0);
assert_eq!(mock.stream_call_count(), 1);
Ok(())
}
#[tokio::test]
async fn chat_stream_only_one_retry_per_call() -> Result<()> {
let mock = MockProvider::new();
mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
message: "status=401 Unauthorized".into(),
kind: StreamErrorKind::InvalidRequest,
})])?;
mock.queue_stream(vec![MockStreamItem::Ok(StreamDelta::Error {
message: "still 401 Unauthorized".into(),
kind: StreamErrorKind::InvalidRequest,
})])?;
let refresh_count = Arc::new(AtomicUsize::new(0));
let wrapped = wrap_success(&mock, &refresh_count);
let deltas = drain(wrapped.chat_stream(empty_request())).await;
assert_eq!(deltas.len(), 1);
assert!(matches!(
deltas[0].as_ref().ok(),
Some(StreamDelta::Error { message, .. }) if message == "still 401 Unauthorized"
));
assert_eq!(refresh_count.load(Ordering::SeqCst), 1);
assert_eq!(mock.stream_call_count(), 2);
Ok(())
}
/// Provider used by the concurrency test. Its call counter and barrier sit
/// behind `Arc`s, so they are shared by every clone.
#[derive(Clone)]
struct ConcurrentMock {
model: String,
provider_name: &'static str,
/// Number of `chat` calls made so far, across all clones.
total_calls: Arc<AtomicUsize>,
/// Barrier the first two `chat` calls wait on before responding.
initial_barrier: Arc<tokio::sync::Barrier>,
}
// Boxed, type-erased future returned by the `ConcurrentMock` refresh callback.
type CMFut = std::pin::Pin<Box<dyn Future<Output = Result<ConcurrentMock>> + Send>>;
// Boxed refresh callback type for `ConcurrentMock`.
type CMRefresh = Box<dyn Fn() -> CMFut + Send + Sync + 'static>;
#[async_trait]
impl LlmProvider for ConcurrentMock {
/// The first two calls (by shared call index) wait on the barrier and then
/// report an unauthorized `InvalidRequest`; every later call succeeds.
async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
// `fetch_add` returns the previous value, so indices start at 0.
let call_index = self.total_calls.fetch_add(1, Ordering::SeqCst);
if call_index < 2 {
// Hold the response until the barrier releases.
self.initial_barrier.wait().await;
Ok(ChatOutcome::InvalidRequest("401 Unauthorized".into()))
} else {
Ok(ChatOutcome::Success(success_response()))
}
}
/// Not exercised by the concurrency test; yields a single error.
fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
Box::pin(async_stream::stream! {
yield Err(anyhow::anyhow!("chat_stream not used in this test"));
})
}
fn model(&self) -> &str {
&self.model
}
fn provider(&self) -> &'static str {
self.provider_name
}
}
/// Two callers hit an unauthorized response at the same time; both must end
/// up with a successful outcome after refreshing and retrying.
///
/// Despite the name, the assertion allows up to one refresh per caller
/// (at most two in total) rather than requiring a single shared refresh.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn chat_concurrent_callers_share_refresh() -> Result<()> {
let mock = ConcurrentMock {
model: "mock-model".to_string(),
provider_name: "mock",
total_calls: Arc::new(AtomicUsize::new(0)),
// Two parties: the two initial `chat` calls wait here for each other.
initial_barrier: Arc::new(tokio::sync::Barrier::new(2)),
};
// Keep handles to the shared counters before `mock` is moved into the wrapper.
let call_count = Arc::clone(&mock.total_calls);
let refresh_count = Arc::new(AtomicUsize::new(0));
let refresh_counter = Arc::clone(&refresh_count);
let template = mock.clone();
// The refresh callback counts invocations and returns a clone of the mock,
// which shares the same call counter and barrier.
let cb: CMRefresh = Box::new(move || {
refresh_counter.fetch_add(1, Ordering::SeqCst);
let provider = template.clone();
Box::pin(async move { Ok(provider) })
});
let wrapped = RefreshingProvider::new(mock, cb);
let a = wrapped.clone();
let b = wrapped.clone();
// Run both callers as separate tasks so their first calls can overlap.
let task_a = tokio::spawn(async move { a.chat(empty_request()).await });
let task_b = tokio::spawn(async move { b.chat(empty_request()).await });
// The first `?` handles the join error, the second the `chat` result.
let outcome_a = task_a.await.context("task_a join")??;
let outcome_b = task_b.await.context("task_b join")??;
assert!(matches!(outcome_a, ChatOutcome::Success(_)));
assert!(matches!(outcome_b, ChatOutcome::Success(_)));
// Two initial unauthorized calls plus one retry per caller.
assert_eq!(call_count.load(Ordering::SeqCst), 4);
let refreshes = refresh_count.load(Ordering::SeqCst);
assert!(
refreshes <= 2,
"expected at most 2 refresh calls (one per caller), got {refreshes}"
);
Ok(())
}
}