use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use crate::model_cache::RemoteModelInfo;
use crate::provider::{
ChatResponse, ChatStream, GenerationOverrides, LlmProvider, Message, ToolDefinition,
};
/// Configurable test double for [`LlmProvider`]: replays scripted chat
/// replies and errors, optionally records incoming messages, and lets
/// tests tune delays, streaming, and embedding behavior.
// The bool knobs are independent test switches, not encodable as one enum.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone)]
pub struct MockProvider {
    /// FIFO queue of scripted chat replies; shared across clones via `Arc`.
    responses: Arc<Mutex<VecDeque<String>>>,
    /// Reply used by `chat` once the scripted queue is empty.
    pub default_response: String,
    /// Vector returned by `embed` when embeddings are enabled.
    pub embedding: Vec<f32>,
    /// When false, `embed` returns `EmbedUnsupported`.
    pub supports_embeddings: bool,
    /// Value reported by `supports_streaming`.
    pub streaming: bool,
    /// When true, `chat` always returns an error.
    pub fail_chat: bool,
    /// Artificial delay applied at the start of `chat`, in milliseconds.
    pub delay_ms: u64,
    /// FIFO queue of scripted errors; popped by both `chat` and `embed`
    /// before any scripted response is considered.
    errors: Arc<Mutex<VecDeque<crate::LlmError>>>,
    /// When `Some`, `chat` pushes a copy of each incoming message slice
    /// here for later inspection by the test.
    recorded: Option<Arc<Mutex<Vec<Vec<Message>>>>>,
    /// FIFO queue of scripted responses for `chat_with_tools`.
    tool_responses: Arc<Mutex<VecDeque<ChatResponse>>>,
    /// Incremented on every `chat_with_tools` call.
    pub tool_call_count: Arc<Mutex<u32>>,
    /// Model list for tests that enumerate available models.
    pub models: Vec<RemoteModelInfo>,
    /// Overrides the `"mock"` provider name when set.
    pub name_override: Option<String>,
    /// When true, `embed` returns an `InvalidInput` error.
    pub embed_invalid_input: bool,
    /// Incremented on every `embed` call, even failing ones.
    pub embed_call_count: Arc<std::sync::atomic::AtomicU64>,
    /// Artificial delay applied at the start of `embed`, in milliseconds.
    pub embed_delay_ms: u64,
}
impl Default for MockProvider {
fn default() -> Self {
Self {
responses: Arc::new(Mutex::new(VecDeque::new())),
default_response: "mock response".into(),
embedding: vec![0.0; 384],
supports_embeddings: false,
streaming: false,
fail_chat: false,
delay_ms: 0,
errors: Arc::new(Mutex::new(VecDeque::new())),
recorded: None,
tool_responses: Arc::new(Mutex::new(VecDeque::new())),
tool_call_count: Arc::new(Mutex::new(0)),
models: vec![],
name_override: None,
embed_invalid_input: false,
embed_call_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
embed_delay_ms: 0,
}
}
}
impl MockProvider {
#[must_use]
pub fn with_responses(responses: Vec<String>) -> Self {
Self {
responses: Arc::new(Mutex::new(VecDeque::from(responses))),
..Self::default()
}
}
#[must_use]
pub fn failing() -> Self {
Self {
fail_chat: true,
..Self::default()
}
}
#[must_use]
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name_override = Some(name.into());
self
}
#[must_use]
pub fn with_embed_invalid_input(mut self) -> Self {
self.embed_invalid_input = true;
self.supports_embeddings = true;
self
}
#[must_use]
pub fn with_errors(mut self, errors: Vec<crate::LlmError>) -> Self {
self.errors = Arc::new(Mutex::new(VecDeque::from(errors)));
self
}
#[must_use]
pub fn with_streaming(mut self) -> Self {
self.streaming = true;
self
}
#[must_use]
pub fn with_delay(mut self, ms: u64) -> Self {
self.delay_ms = ms;
self
}
#[must_use]
pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
self.embedding = embedding;
self.supports_embeddings = true;
self
}
#[must_use]
pub fn with_embed_delay(mut self, ms: u64) -> Self {
self.embed_delay_ms = ms;
self.supports_embeddings = true;
self
}
#[must_use]
pub fn with_recording(mut self) -> (Self, Arc<Mutex<Vec<Vec<Message>>>>) {
let buf = Arc::new(Mutex::new(Vec::new()));
self.recorded = Some(Arc::clone(&buf));
(self, buf)
}
#[must_use]
pub fn with_generation_overrides(self, _overrides: GenerationOverrides) -> Self {
self
}
#[must_use]
pub fn with_models(mut self, models: Vec<RemoteModelInfo>) -> Self {
self.models = models;
self
}
#[must_use]
pub fn with_tool_use(mut self, responses: Vec<ChatResponse>) -> (Self, Arc<Mutex<u32>>) {
self.tool_responses = Arc::new(Mutex::new(VecDeque::from(responses)));
let counter = Arc::clone(&self.tool_call_count);
(self, counter)
}
}
impl LlmProvider for MockProvider {
    /// Provider name: `name_override` when set, otherwise `"mock"`.
    #[allow(clippy::unnecessary_literal_bound)]
    fn name(&self) -> &str {
        self.name_override.as_deref().unwrap_or("mock")
    }

    /// Returns the next scripted reply, falling back to `default_response`
    /// once the queue is drained.
    ///
    /// Order of effects: optional delay, optional recording of `messages`,
    /// `fail_chat`, scripted errors, then scripted responses.
    ///
    /// # Errors
    ///
    /// `LlmError::Other` when `fail_chat` is set, or the next queued error
    /// configured via `with_errors`.
    async fn chat(&self, messages: &[Message]) -> Result<String, crate::LlmError> {
        if self.delay_ms > 0 {
            tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
        }
        // Capture a copy of the conversation for later test inspection.
        if let Some(buf) = &self.recorded
            && let Ok(mut guard) = buf.lock()
        {
            guard.push(messages.to_vec());
        }
        if self.fail_chat {
            return Err(crate::LlmError::Other("mock LLM error".into()));
        }
        // Scripted errors take precedence over scripted responses.
        if let Ok(mut errors) = self.errors.lock()
            && !errors.is_empty()
        {
            return Err(errors.pop_front().expect("non-empty"));
        }
        let mut responses = self.responses.lock().unwrap();
        if responses.is_empty() {
            Ok(self.default_response.clone())
        } else {
            Ok(responses.pop_front().expect("non-empty"))
        }
    }

    /// Streams the `chat` reply one character per chunk.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::chat`].
    async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, crate::LlmError> {
        let response = self.chat(messages).await?;
        let chunks: Vec<Result<crate::StreamChunk, crate::LlmError>> = response
            .chars()
            .map(|c| Ok(crate::StreamChunk::Content(c.to_string())))
            .collect();
        Ok(Box::pin(tokio_stream::iter(chunks)))
    }

    fn supports_streaming(&self) -> bool {
        self.streaming
    }

    /// Returns the configured embedding vector.
    ///
    /// Always increments `embed_call_count` (even on failure), then applies
    /// the optional delay, scripted errors, and `embed_invalid_input`
    /// before answering.
    ///
    /// # Errors
    ///
    /// The next queued error, `InvalidInput` when `embed_invalid_input` is
    /// set, or `EmbedUnsupported` when embeddings are disabled.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>, crate::LlmError> {
        self.embed_call_count
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        if self.embed_delay_ms > 0 {
            tokio::time::sleep(std::time::Duration::from_millis(self.embed_delay_ms)).await;
        }
        if let Ok(mut errors) = self.errors.lock()
            && !errors.is_empty()
        {
            return Err(errors.pop_front().expect("non-empty"));
        }
        if self.embed_invalid_input {
            return Err(crate::LlmError::InvalidInput {
                provider: self.name().to_owned(),
                message: "input exceeds maximum sequence length".into(),
            });
        }
        if self.supports_embeddings {
            Ok(self.embedding.clone())
        } else {
            Err(crate::LlmError::EmbedUnsupported {
                // Fix: was hard-coded "mock", which ignored `name_override`;
                // now consistent with the `InvalidInput` branch above.
                provider: self.name().to_owned(),
            })
        }
    }

    fn supports_embeddings(&self) -> bool {
        self.supports_embeddings
    }

    /// Pops the next scripted tool response, falling back to wrapping a
    /// plain `chat` reply in `ChatResponse::Text`. Always increments
    /// `tool_call_count` first.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::chat`] on the fallback path.
    async fn chat_with_tools(
        &self,
        messages: &[Message],
        _tools: &[ToolDefinition],
    ) -> Result<ChatResponse, crate::LlmError> {
        *self.tool_call_count.lock().unwrap() += 1;
        let queued = self.tool_responses.lock().unwrap().pop_front();
        if let Some(response) = queued {
            return Ok(response);
        }
        Ok(ChatResponse::Text(self.chat(messages).await?))
    }
}