use std::collections::VecDeque;
use std::sync::Mutex;
use std::time::Duration;
use async_trait::async_trait;
use futures::stream::{self, BoxStream};
use crate::request::ChatRequest;
use crate::retry::{RetryClassification, Retryable};
use crate::stream::StreamChunk;
use crate::traits::CompletionModel;
/// Error type carried by scripted turns and scripted stream items.
///
/// Wraps a free-form message. The message text is also what the
/// `Retryable` implementation for this type inspects to decide whether the
/// error is transient or permanent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScriptedError(pub String);
impl std::fmt::Display for ScriptedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "scripted error: {}", self.0)
}
}
// Empty impl: no `Error` methods are overridden, so the trait's defaults
// apply to `ScriptedError`.
impl std::error::Error for ScriptedError {}
impl Retryable for ScriptedError {
    /// Classifies the error purely from its message text.
    ///
    /// * Any message containing `permanent` is `Permanent`, even if it also
    ///   contains `transient`.
    /// * Otherwise, a message containing `transient` is `Transient`. When the
    ///   first `transient` is immediately followed by `:` and at least one
    ///   ASCII digit, the leading run of digits is read as a millisecond
    ///   `retry_after` hint; if that run is missing or does not fit in a
    ///   `u64`, the hint is `None`.
    /// * Every other message is `Permanent`.
    fn retry_classification(&self) -> RetryClassification {
        let message = self.0.as_str();

        // Text following the first "transient", unless "permanent" overrides it.
        let tail = match message.split_once("transient") {
            Some((_, tail)) if !message.contains("permanent") => tail,
            _ => return RetryClassification::Permanent,
        };

        let retry_after = tail
            .strip_prefix(':')
            .and_then(|rest| {
                // Keep only the leading ASCII digits; anything after them
                // (e.g. " ms") is ignored.
                let end = rest
                    .find(|c: char| !c.is_ascii_digit())
                    .unwrap_or(rest.len());
                rest[..end].parse::<u64>().ok()
            })
            .map(Duration::from_millis);

        RetryClassification::Transient { retry_after }
    }
}
/// One scripted response to a single `chat_stream` call.
///
/// * Outer `Err(e)`: the call itself fails with `e` and no stream is returned.
/// * Outer `Ok(items)`: the call succeeds and the returned stream yields
///   `items` in order; each item is either `Ok(chunk)` or a mid-stream `Err`.
pub type ScriptedTurn = Result<Vec<Result<StreamChunk, ScriptedError>>, ScriptedError>;
/// A `CompletionModel` test double that replays pre-scripted turns.
///
/// Each `chat_stream` call removes the turn at the front of the queue and
/// replays it; the incoming request is ignored.
///
/// Derives `Debug` so the model can appear in assertion messages and inside
/// other `Debug`-deriving fixtures, matching `ScriptedError`.
#[derive(Debug)]
pub struct ScriptedModel {
    /// Value reported by `CompletionModel::name`.
    name: String,
    /// Value reported by `CompletionModel::model`.
    model: String,
    /// Turns not yet replayed, front first. Wrapped in a `Mutex` because
    /// `chat_stream` consumes turns through `&self`.
    scripts: Mutex<VecDeque<ScriptedTurn>>,
}
impl ScriptedModel {
pub fn new<I>(turns: I) -> Self
where
I: IntoIterator<Item = Vec<StreamChunk>>,
{
Self::with_turns(
turns
.into_iter()
.map(|chunks| Ok(chunks.into_iter().map(Ok).collect())),
)
}
pub fn with_turns<I>(turns: I) -> Self
where
I: IntoIterator<Item = ScriptedTurn>,
{
Self {
name: "scripted".into(),
model: "scripted".into(),
scripts: Mutex::new(turns.into_iter().collect()),
}
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = name.into();
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
}
#[async_trait]
impl CompletionModel for ScriptedModel {
    type Error = ScriptedError;

    /// Returns the configured name (`"scripted"` unless overridden).
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the configured model identifier (`"scripted"` unless overridden).
    fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Replays the next scripted turn; the request itself is ignored.
    ///
    /// * A turn scripted as `Err(e)` makes this call return `Err(e)`.
    /// * A turn scripted as `Ok(items)` returns a stream yielding `items` in
    ///   order, including any scripted mid-stream `Err` items.
    /// * Once every turn has been consumed, an empty stream is returned.
    ///
    /// # Panics
    ///
    /// Panics if the internal script mutex is poisoned.
    async fn chat_stream(
        &self,
        _req: ChatRequest,
    ) -> Result<BoxStream<'static, Result<StreamChunk, Self::Error>>, Self::Error> {
        // Keep the guard in its own scope so the lock is released as soon as
        // the front turn has been removed.
        let scripted = {
            let mut queue = self.scripts.lock().unwrap();
            queue.pop_front()
        };

        // A scripted setup error propagates here; an exhausted script falls
        // back to an empty chunk list, i.e. a stream that yields nothing.
        let chunks = scripted.transpose()?.unwrap_or_default();
        Ok(Box::pin(stream::iter(chunks)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::{FinishReason, Usage};
    use futures::StreamExt;

    /// A successful turn replays its chunks in the scripted order.
    #[tokio::test]
    async fn replays_chunks_in_order() {
        let model = ScriptedModel::new([vec![
            StreamChunk::TextDelta {
                delta: "hello".into(),
            },
            StreamChunk::TurnFinished {
                reason: FinishReason::EndTurn,
                usage: Usage::default(),
                service_tier: None,
            },
        ]]);
        let stream = model
            .chat_stream(ChatRequest::new(vec![], 0))
            .await
            .unwrap();
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 2);
        // Check the order itself, not just the count.
        assert!(matches!(chunks[0], Ok(StreamChunk::TextDelta { .. })));
        assert!(matches!(chunks[1], Ok(StreamChunk::TurnFinished { .. })));
    }

    /// Once every scripted turn is consumed, further calls succeed with a
    /// stream that yields nothing.
    #[tokio::test]
    async fn exhausted_script_yields_empty_stream() {
        let model = ScriptedModel::new(Vec::<Vec<StreamChunk>>::new());
        let stream = model
            .chat_stream(ChatRequest::new(vec![], 0))
            .await
            .expect("an exhausted script should still open a stream");
        let chunks: Vec<_> = stream.collect().await;
        assert!(chunks.is_empty(), "expected no chunks from an exhausted script");
    }

    /// Each call consumes exactly one turn, front first.
    #[tokio::test]
    async fn consumes_one_turn_per_call_in_order() {
        let model = ScriptedModel::with_turns([
            Err(ScriptedError("first".into())),
            Err(ScriptedError("second".into())),
        ]);
        let first = model.chat_stream(ChatRequest::new(vec![], 0)).await.err();
        let second = model.chat_stream(ChatRequest::new(vec![], 0)).await.err();
        assert_eq!(first, Some(ScriptedError("first".into())));
        assert_eq!(second, Some(ScriptedError("second".into())));
    }

    /// An `Err` scripted inside a turn is delivered as a stream item after
    /// the chunks that precede it.
    #[tokio::test]
    async fn surfaces_mid_stream_error_after_chunks() {
        let model = ScriptedModel::with_turns([Ok(vec![
            Ok(StreamChunk::TextDelta {
                delta: "partial".into(),
            }),
            Err(ScriptedError("connection dropped".into())),
        ])]);
        let stream = model
            .chat_stream(ChatRequest::new(vec![], 0))
            .await
            .expect("chat_stream should open the stream");
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 2, "expected one Ok chunk then one Err");
        assert!(matches!(chunks[0], Ok(StreamChunk::TextDelta { .. })));
        match &chunks[1] {
            Err(ScriptedError(msg)) => assert_eq!(msg, "connection dropped"),
            other => panic!("expected mid-stream Err, got {other:?}"),
        }
    }

    /// A turn scripted as `Err` fails the `chat_stream` call itself.
    #[tokio::test]
    async fn surfaces_setup_time_error_per_turn() {
        let model = ScriptedModel::with_turns([Err(ScriptedError("rate limited".into()))]);
        let result = model.chat_stream(ChatRequest::new(vec![], 0)).await;
        match result {
            Err(ScriptedError(msg)) => assert_eq!(msg, "rate limited"),
            Ok(_) => panic!("expected error"),
        }
    }

    /// The message text alone determines the retry classification.
    #[test]
    fn scripted_error_is_retryable_by_convention() {
        assert_eq!(
            ScriptedError("permanent: bad auth".into()).retry_classification(),
            RetryClassification::Permanent,
        );
        assert_eq!(
            ScriptedError("transient".into()).retry_classification(),
            RetryClassification::Transient { retry_after: None },
        );
        assert_eq!(
            ScriptedError("transient:75 ms".into()).retry_classification(),
            RetryClassification::Transient {
                retry_after: Some(Duration::from_millis(75))
            },
        );
        assert_eq!(
            ScriptedError("something else entirely".into()).retry_classification(),
            RetryClassification::Permanent,
        );
    }
}