use std::sync::{Arc, Mutex};
use crate::message::AssistantMessage;
use crate::provider::{EventStream, ModelInfo, Provider, ProviderError, Request};
use crate::stream::{AssistantStreamEvent, StopReason, Usage};
#[doc(hidden)]
pub enum MockResponse {
Events(Vec<AssistantStreamEvent>),
Error(ProviderError),
}
#[doc(hidden)]
pub struct MockProvider {
id: String,
models: Vec<ModelInfo>,
responses: Arc<Mutex<Vec<MockResponse>>>,
call_log: Arc<Mutex<Vec<Request>>>,
}
impl MockProvider {
pub fn new(id: &str, responses: Vec<Vec<AssistantStreamEvent>>) -> Self {
Self::new_with_errors(
id,
responses.into_iter().map(MockResponse::Events).collect(),
)
}
pub fn new_with_errors(id: &str, responses: Vec<MockResponse>) -> Self {
Self {
id: id.to_owned(),
models: vec![ModelInfo {
id: "mock-model".into(),
display_name: "Mock Model".into(),
context_window: 100_000,
max_output_tokens: 4_096,
supports_streaming: true,
supports_thinking: false,
}],
responses: Arc::new(Mutex::new(responses)),
call_log: Arc::new(Mutex::new(Vec::new())),
}
}
pub fn stream_call_count(&self) -> usize {
self.call_log.lock().unwrap().len()
}
pub fn recorded_messages(&self) -> Vec<Vec<crate::message::Message>> {
self.call_log
.lock()
.unwrap()
.iter()
.map(|r| r.messages.clone())
.collect()
}
pub fn call_log_handle(&self) -> Arc<Mutex<Vec<Request>>> {
Arc::clone(&self.call_log)
}
}
pub fn base_assistant() -> AssistantMessage {
AssistantMessage {
content: vec![],
api: crate::ApiKind::Anthropic,
provider: "mock".into(),
model: "mock-model".into(),
response_model: None,
response_id: None,
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp_ms: 0,
}
}
pub fn text_response(text: &str) -> Vec<AssistantStreamEvent> {
let mut partial = base_assistant();
partial
.content
.push(crate::message::AssistantContent::Text { text: text.into() });
vec![
AssistantStreamEvent::Start {
partial: base_assistant(),
},
AssistantStreamEvent::TextDelta {
content_index: 0,
delta: text.into(),
partial: partial.clone(),
},
AssistantStreamEvent::Done {
reason: StopReason::Stop,
message: partial,
},
]
}
pub fn tool_call_response(
tool_call_id: &str,
tool_name: &str,
arguments: &str,
) -> Vec<AssistantStreamEvent> {
let tool_call = crate::message::ToolCall {
id: tool_call_id.into(),
name: tool_name.into(),
arguments: arguments.into(),
};
let mut partial = base_assistant();
partial
.content
.push(crate::message::AssistantContent::ToolCall {
tool_call: tool_call.clone(),
});
vec![
AssistantStreamEvent::Start {
partial: base_assistant(),
},
AssistantStreamEvent::ToolCallEnd {
content_index: 0,
tool_call,
partial: partial.clone(),
},
AssistantStreamEvent::Done {
reason: StopReason::ToolUse,
message: partial,
},
]
}
pub fn error_response(error_message: &str) -> Vec<AssistantStreamEvent> {
let mut partial = base_assistant();
partial.error_message = Some(error_message.into());
vec![
AssistantStreamEvent::Start {
partial: base_assistant(),
},
AssistantStreamEvent::Error {
reason: StopReason::Error,
message: partial,
},
]
}
impl Provider for MockProvider {
fn id(&self) -> &str {
&self.id
}
fn models(&self) -> &[ModelInfo] {
&self.models
}
fn stream(&self, request: Request) -> EventStream {
self.call_log.lock().unwrap().push(request);
let mut responses = self.responses.lock().unwrap();
assert!(
!responses.is_empty(),
"MockProvider: stream() called more times than responses were configured"
);
let response = responses.remove(0);
match response {
MockResponse::Events(events) => {
let stream =
futures_util::stream::iter(events.into_iter().map(Ok::<_, ProviderError>));
Box::pin(stream)
}
MockResponse::Error(e) => {
let stream = futures_util::stream::iter(vec![Err(e)]);
Box::pin(stream)
}
}
}
}