use std::sync::Arc;
use std::time::Duration;
use futures::StreamExt;
use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError};
use synaptic_models::{RetryChatModel, RetryPolicy};
use tokio::sync::Mutex;
struct FailThenSucceedModel {
attempts: Arc<Mutex<usize>>,
fail_count: usize,
error_kind: &'static str,
}
impl FailThenSucceedModel {
fn new(fail_count: usize, error_kind: &'static str) -> Self {
Self {
attempts: Arc::new(Mutex::new(0)),
fail_count,
error_kind,
}
}
}
#[async_trait::async_trait]
impl ChatModel for FailThenSucceedModel {
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, SynapticError> {
let mut attempts = self.attempts.lock().await;
*attempts += 1;
if *attempts <= self.fail_count {
match self.error_kind {
"rate_limit" => Err(SynapticError::RateLimit("rate limited".to_string())),
"timeout" => Err(SynapticError::Timeout("timed out".to_string())),
_ => Err(SynapticError::Model("non-retryable".to_string())),
}
} else {
Ok(ChatResponse {
message: Message::ai("success"),
usage: None,
})
}
}
}
#[tokio::test]
async fn retries_on_rate_limit() {
let inner = Arc::new(FailThenSucceedModel::new(2, "rate_limit"));
let policy = RetryPolicy {
max_attempts: 3,
base_delay: Duration::from_millis(1),
};
let model = RetryChatModel::new(inner.clone(), policy);
let request = ChatRequest::new(vec![Message::human("hi")]);
let response = model.chat(request).await.unwrap();
assert_eq!(response.message.content(), "success");
assert_eq!(*inner.attempts.lock().await, 3);
}
#[tokio::test]
async fn retries_on_timeout() {
let inner = Arc::new(FailThenSucceedModel::new(1, "timeout"));
let policy = RetryPolicy {
max_attempts: 3,
base_delay: Duration::from_millis(1),
};
let model = RetryChatModel::new(inner.clone(), policy);
let request = ChatRequest::new(vec![Message::human("hi")]);
let response = model.chat(request).await.unwrap();
assert_eq!(response.message.content(), "success");
assert_eq!(*inner.attempts.lock().await, 2);
}
#[tokio::test]
async fn does_not_retry_non_retryable_error() {
let inner = Arc::new(FailThenSucceedModel::new(1, "model"));
let policy = RetryPolicy {
max_attempts: 3,
base_delay: Duration::from_millis(1),
};
let model = RetryChatModel::new(inner.clone(), policy);
let request = ChatRequest::new(vec![Message::human("hi")]);
let err = model.chat(request).await.unwrap_err();
assert!(err.to_string().contains("non-retryable"));
assert_eq!(*inner.attempts.lock().await, 1);
}
#[tokio::test]
async fn exhausts_retries() {
let inner = Arc::new(FailThenSucceedModel::new(5, "rate_limit"));
let policy = RetryPolicy {
max_attempts: 3,
base_delay: Duration::from_millis(1),
};
let model = RetryChatModel::new(inner.clone(), policy);
let request = ChatRequest::new(vec![Message::human("hi")]);
let err = model.chat(request).await.unwrap_err();
assert!(err.to_string().contains("rate limit"));
assert_eq!(*inner.attempts.lock().await, 3);
}
struct StreamOnceModel;
#[async_trait::async_trait]
impl ChatModel for StreamOnceModel {
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, SynapticError> {
Ok(ChatResponse {
message: Message::ai("streamed"),
usage: None,
})
}
fn stream_chat(&self, _request: ChatRequest) -> ChatStream<'_> {
Box::pin(async_stream::stream! {
yield Ok(synaptic_core::AIMessageChunk {
content: "chunk".to_string(),
..Default::default()
});
})
}
}
#[tokio::test]
async fn retry_stream_chat_succeeds() {
let inner = Arc::new(StreamOnceModel);
let policy = RetryPolicy {
max_attempts: 2,
base_delay: Duration::from_millis(1),
};
let model = RetryChatModel::new(inner, policy);
let request = ChatRequest::new(vec![Message::human("hi")]);
let chunks: Vec<_> = model
.stream_chat(request)
.collect::<Vec<_>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].content, "chunk");
}