use crate::error::{LlmError, ReactError, Result};
use crate::llm::types::{
DeltaFunctionCall, DeltaMessage, DeltaToolCall, FunctionCall, Message, ToolCall,
};
use crate::llm::{ChatChunk, ChatRequest, ChatResponse, LlmClient};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// One scripted reply queued on a [`MockLlmClient`], consumed in FIFO order.
enum MockLlmResponse {
    /// Plain assistant text returned as the message content.
    Content(String),
    /// Tool-call requests returned on the assistant message.
    ToolCalls(Vec<ToolCall>),
    /// An error returned to the caller instead of a response.
    Err(ReactError),
}
/// Test double for [`LlmClient`]: replays pre-scripted responses and records
/// every request's messages for later inspection.
pub struct MockLlmClient {
    // Reported by `model_name()`; defaults to "mock-model".
    model_name: String,
    // FIFO queue of scripted responses; each chat call pops the front.
    // Arc<Mutex<..>> so the queue is shared and mutable behind `&self`.
    responses: Arc<Mutex<VecDeque<MockLlmResponse>>>,
    // Message lists from every chat/chat_stream call, in call order.
    calls: Arc<Mutex<Vec<Vec<Message>>>>,
}
impl Default for MockLlmClient {
fn default() -> Self {
Self::new()
}
}
impl MockLlmClient {
pub fn new() -> Self {
Self {
model_name: "mock-model".to_string(),
responses: Arc::new(Mutex::new(VecDeque::new())),
calls: Arc::new(Mutex::new(Vec::new())),
}
}
pub fn with_model_name(mut self, name: impl Into<String>) -> Self {
self.model_name = name.into();
self
}
pub fn with_response(self, text: impl Into<String>) -> Self {
self.responses
.lock()
.unwrap()
.push_back(MockLlmResponse::Content(text.into()));
self
}
pub fn with_responses(self, texts: impl IntoIterator<Item = impl Into<String>>) -> Self {
{
let mut q = self.responses.lock().unwrap();
for t in texts {
q.push_back(MockLlmResponse::Content(t.into()));
}
}
self
}
pub fn with_error(self, err: ReactError) -> Self {
self.responses
.lock()
.unwrap()
.push_back(MockLlmResponse::Err(err));
self
}
pub fn then_tool_call(
self,
id: impl Into<String>,
function_name: impl Into<String>,
arguments: impl Into<String>,
) -> Self {
let tc = ToolCall {
id: id.into(),
call_type: "function".to_string(),
function: FunctionCall {
name: function_name.into(),
arguments: arguments.into(),
},
};
self.responses
.lock()
.unwrap()
.push_back(MockLlmResponse::ToolCalls(vec![tc]));
self
}
pub fn then_tool_calls(self, calls: Vec<ToolCall>) -> Self {
self.responses
.lock()
.unwrap()
.push_back(MockLlmResponse::ToolCalls(calls));
self
}
pub fn with_network_error(self, msg: impl Into<String>) -> Self {
self.with_error(ReactError::Llm(Box::new(LlmError::NetworkError(
msg.into(),
))))
}
pub fn with_rate_limit_error(self) -> Self {
self.with_error(ReactError::Llm(Box::new(LlmError::ApiError {
status: 429,
message: "Too Many Requests".to_string(),
})))
}
pub fn call_count(&self) -> usize {
self.calls.lock().unwrap().len()
}
pub fn last_messages(&self) -> Option<Vec<Message>> {
self.calls.lock().unwrap().last().cloned()
}
pub fn all_calls(&self) -> Vec<Vec<Message>> {
self.calls.lock().unwrap().clone()
}
pub fn remaining(&self) -> usize {
self.responses.lock().unwrap().len()
}
pub fn reset_calls(&self) {
self.calls.lock().unwrap().clear();
}
fn pop_response(&self) -> Result<PopResult> {
match self.responses.lock().unwrap().pop_front() {
Some(MockLlmResponse::Content(text)) => Ok(PopResult::Content(text)),
Some(MockLlmResponse::ToolCalls(calls)) => Ok(PopResult::ToolCalls(calls)),
Some(MockLlmResponse::Err(e)) => Err(e),
None => Err(ReactError::Llm(Box::new(LlmError::EmptyResponse))),
}
}
}
/// Successful outcome of `pop_response`: the error variant of
/// `MockLlmResponse` has already been converted into `Err`, so only
/// the two response shapes remain.
enum PopResult {
    /// Plain assistant text.
    Content(String),
    /// Assistant tool-call requests.
    ToolCalls(Vec<ToolCall>),
}
impl LlmClient for MockLlmClient {
    /// Records the request's messages, then converts the next scripted
    /// response into a `ChatResponse` ("stop" for text, "tool_calls" for
    /// tool calls). Propagates a scripted error, or `EmptyResponse` when
    /// the queue is exhausted.
    fn chat(&self, request: ChatRequest) -> BoxFuture<'_, Result<ChatResponse>> {
        Box::pin(async move {
            self.calls.lock().unwrap().push(request.messages);
            let (message, finish_reason) = match self.pop_response()? {
                PopResult::Content(text) => (Message::assistant(text), "stop"),
                PopResult::ToolCalls(calls) => {
                    (Message::assistant_with_tools(calls), "tool_calls")
                }
            };
            Ok(ChatResponse {
                message,
                finish_reason: Some(finish_reason.to_string()),
                raw: crate::llm::types::ChatCompletionResponse::default(),
            })
        })
    }

    /// Streaming variant: records the request's messages and emits the next
    /// scripted response as a single-chunk stream carrying the full delta.
    fn chat_stream(
        &self,
        request: ChatRequest,
    ) -> BoxFuture<'_, Result<BoxStream<'_, Result<ChatChunk>>>> {
        Box::pin(async move {
            self.calls.lock().unwrap().push(request.messages);
            // Build the one chunk up front, then wrap it in a one-item stream.
            let chunk = match self.pop_response()? {
                PopResult::Content(text) => ChatChunk {
                    delta: DeltaMessage {
                        role: Some("assistant".to_string()),
                        content: Some(text),
                        reasoning_content: None,
                        tool_calls: None,
                    },
                    finish_reason: Some("stop".to_string()),
                    usage: None,
                },
                PopResult::ToolCalls(calls) => {
                    // Re-shape each complete ToolCall as a streaming delta,
                    // indexed by position.
                    let mut delta_calls = Vec::with_capacity(calls.len());
                    for (i, tc) in calls.into_iter().enumerate() {
                        delta_calls.push(DeltaToolCall {
                            index: i as u32,
                            id: Some(tc.id),
                            call_type: Some(tc.call_type),
                            function: Some(DeltaFunctionCall {
                                name: Some(tc.function.name),
                                arguments: Some(tc.function.arguments),
                            }),
                        });
                    }
                    ChatChunk {
                        delta: DeltaMessage {
                            role: Some("assistant".to_string()),
                            content: None,
                            reasoning_content: None,
                            tool_calls: Some(delta_calls),
                        },
                        finish_reason: Some("tool_calls".to_string()),
                        usage: None,
                    }
                }
            };
            let stream = futures::stream::once(async move { Ok(chunk) });
            Ok(Box::pin(stream) as BoxStream<'_, Result<ChatChunk>>)
        })
    }

    /// Name configured via `with_model_name` (default "mock-model").
    fn model_name(&self) -> &str {
        &self.model_name
    }
}