use async_trait::async_trait;
use futures::stream;
use crate::llm::{
BoxStream, CallOptions, ChatModel, LlmError, Message, MessageChunk, ToolCall, ToolCallChunk,
ToolDefinition,
};
#[derive(Debug, thiserror::Error)]
#[error("Mock error")]
struct MockError;
/// Configurable test double implementing `ChatModel`.
///
/// Replies are canned: the configured response text and tool calls are
/// returned regardless of the input messages, unless error injection is on.
#[derive(Debug, Clone)]
pub struct MockChatModel {
    /// Name reported by `ChatModel::model_name`.
    model_name: String,
    /// Reply text; `None` is served as an empty string.
    response: Option<String>,
    /// Tool calls attached to every reply.
    tool_calls: Vec<ToolCall>,
    /// Tools recorded by `bind_tools`; `invoke` and `stream` do not consult them.
    tools: Vec<ToolDefinition>,
    /// When `true`, `invoke` returns `Err` and `stream` yields a single `Err` item.
    should_error: bool,
}
impl MockChatModel {
#[must_use]
pub fn new(model_name: impl Into<String>) -> Self {
Self {
model_name: model_name.into(),
response: None,
tool_calls: Vec::new(),
tools: Vec::new(),
should_error: false,
}
}
#[must_use]
pub fn with_response(mut self, response: impl Into<String>) -> Self {
self.response = Some(response.into());
self
}
#[must_use]
pub fn with_tool_calls(mut self, calls: Vec<ToolCall>) -> Self {
self.tool_calls = calls;
self
}
#[must_use]
pub const fn with_error(mut self) -> Self {
self.should_error = true;
self
}
}
impl Default for MockChatModel {
    /// Builds a mock named `"mock-model"`, with every other setting at its
    /// `MockChatModel::new` default.
    fn default() -> Self {
        const DEFAULT_NAME: &str = "mock-model";
        Self::new(DEFAULT_NAME)
    }
}
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl ChatModel for MockChatModel {
    /// Returns one AI message built from the configured response text (empty
    /// if unset) and tool calls, or a `MockError` when error injection is on.
    ///
    /// The input messages and call options are ignored.
    async fn invoke(
        &self,
        _messages: &[Message],
        _options: Option<&CallOptions>,
    ) -> Result<Message, LlmError> {
        if self.should_error {
            Err(LlmError::Other(Box::new(MockError)))
        } else {
            let text = self.response.as_deref().unwrap_or_default().to_owned();
            Ok(Message::ai_with_tool_calls(text, self.tool_calls.clone()))
        }
    }

    /// Yields exactly one item. When error injection is on, the item is an
    /// `Err(MockError)`. Otherwise it is a single chunk holding the whole response
    /// text and one tool-call chunk per configured call.
    ///
    /// The input messages and call options are ignored.
    fn stream(
        &self,
        _messages: &[Message],
        _options: Option<&CallOptions>,
    ) -> BoxStream<'_, Result<MessageChunk, LlmError>> {
        let item = if self.should_error {
            Err(LlmError::Other(Box::new(MockError)))
        } else {
            // Each call is emitted complete, so the whole serialized argument
            // payload goes into a single delta.
            let tool_call_chunks = self
                .tool_calls
                .iter()
                .enumerate()
                .map(|(position, tc)| ToolCallChunk {
                    index: position,
                    id: Some(tc.id.clone()),
                    name: Some(tc.name.clone()),
                    args_delta: tc.arguments.to_string(),
                })
                .collect();
            Ok(MessageChunk {
                content: self.response.as_deref().unwrap_or_default().to_owned(),
                tool_call_chunks,
                usage_delta: None,
            })
        };
        Box::pin(stream::once(async move { item }))
    }

    /// Returns a copy of this mock with `tools` recorded; all other settings are kept.
    fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self {
        Self {
            tools,
            ..self.clone()
        }
    }

    /// Returns the name given at construction.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }
}