use super::{Request, Response, StreamChunk};
use anyhow::Result;
use futures_core::Stream;
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
/// A scripted, in-memory stand-in for a real model, intended for tests.
///
/// Responses passed to [`TestModel::new`] / [`TestModel::with_both`] are handed out by
/// `send()` one per call, front-to-back. Each inner `Vec` passed to
/// [`TestModel::with_chunks`] / [`TestModel::with_both`] is the complete chunk sequence for
/// one `stream()` call, also consumed front-to-back.
///
/// Both queues live behind an `Arc`, so clones of a `TestModel` share (and drain) the same
/// scripted queues.
#[derive(Clone)]
pub struct TestModel {
    /// Scripted results for `send()`, consumed from the front.
    responses: Arc<Mutex<VecDeque<Response>>>,
    /// Scripted chunk sequences for `stream()`: one inner `Vec` per call, consumed from the front.
    chunks: Arc<Mutex<VecDeque<Vec<StreamChunk>>>>,
    /// Name reported by `active_model()`.
    model_name: String,
}

impl TestModel {
    /// Name reported by `active_model()` unless overridden with [`TestModel::with_model_name`].
    const DEFAULT_MODEL_NAME: &'static str = "test-model";

    /// Creates a model whose `send()` calls return `responses`, in order.
    ///
    /// No stream chunks are scripted, so `stream()` yields an error.
    pub fn new(responses: Vec<Response>) -> Self {
        Self::with_both(responses, Vec::new())
    }

    /// Creates a model whose successive `stream()` calls replay the successive chunk
    /// sequences in `chunks`.
    ///
    /// No responses are scripted, so `send()` returns an error.
    pub fn with_chunks(chunks: Vec<Vec<StreamChunk>>) -> Self {
        Self::with_both(Vec::new(), chunks)
    }

    /// Creates a model scripted for both `send()` (`responses`) and `stream()` (`chunks`).
    ///
    /// This is the single place the struct is assembled; the other constructors delegate here.
    pub fn with_both(responses: Vec<Response>, chunks: Vec<Vec<StreamChunk>>) -> Self {
        Self {
            responses: Arc::new(Mutex::new(responses.into())),
            chunks: Arc::new(Mutex::new(chunks.into())),
            model_name: Self::DEFAULT_MODEL_NAME.to_owned(),
        }
    }

    /// Overrides the name reported by `active_model()` (default: `"test-model"`).
    ///
    /// Useful for exercising caller code that depends on the active model's name.
    #[must_use]
    pub fn with_model_name(mut self, name: impl Into<String>) -> Self {
        self.model_name = name.into();
        self
    }
}
impl super::Model for TestModel {
    /// Returns the next scripted response; the request itself is ignored.
    ///
    /// # Errors
    /// Fails once every scripted response has been consumed.
    ///
    /// # Panics
    /// If the response queue's mutex is poisoned.
    async fn send(&self, _request: &Request) -> Result<Response> {
        // Bind the popped value so the lock guard is released at the end of this statement.
        let next = self.responses.lock().unwrap().pop_front();
        match next {
            Some(response) => Ok(response),
            None => Err(anyhow::anyhow!(
                "TestModel: no more scripted responses for send()"
            )),
        }
    }

    /// Replays the next scripted chunk sequence as a stream; the request itself is ignored.
    ///
    /// The sequence is dequeued as soon as `stream()` is called, not when the returned stream
    /// is first polled. If nothing is left to replay, the stream yields a single `Err` item.
    ///
    /// # Panics
    /// If the chunk queue's mutex is poisoned.
    fn stream(&self, _request: Request) -> impl Stream<Item = Result<StreamChunk>> + Send {
        let scripted = self.chunks.lock().unwrap().pop_front();
        async_stream::stream! {
            if let Some(batch) = scripted {
                for chunk in batch {
                    yield Ok(chunk);
                }
            } else {
                yield Err(anyhow::anyhow!("TestModel: no more scripted chunks for stream()"));
            }
        }
    }

    /// Returns a copy of this instance's `model_name`.
    fn active_model(&self) -> String {
        self.model_name.to_owned()
    }
}