use ares::llm::client::{LLMClientFactoryTrait, Provider};
use ares::llm::{LLMClient, LLMResponse};
use ares::types::{AppError, Result, ToolCall, ToolDefinition};
use async_trait::async_trait;
use futures::stream::{self, StreamExt};
use std::sync::Arc;
/// Configurable in-memory stand-in for a real LLM client, used in tests.
///
/// Every generation method either returns the canned `response` (and
/// `tool_calls`, where applicable) or, when `should_fail` is set, an
/// `AppError::LLM` error.
#[derive(Clone)]
pub struct MockLLMClient {
    // Canned text returned by every generate/stream method.
    response: String,
    // Canned tool calls attached to the tool-aware responses.
    tool_calls: Vec<ToolCall>,
    // When true, every trait method returns an error instead of the response.
    should_fail: bool,
}
impl MockLLMClient {
    /// Builds a mock that succeeds with `response` and reports no tool calls.
    pub fn new(response: &str) -> Self {
        Self {
            response: response.to_owned(),
            tool_calls: Vec::new(),
            should_fail: false,
        }
    }

    /// Builds a mock that succeeds with `response` and the given tool calls.
    pub fn with_tool_calls(response: &str, tool_calls: Vec<ToolCall>) -> Self {
        let mut mock = Self::new(response);
        mock.tool_calls = tool_calls;
        mock
    }

    /// Builds a mock whose every trait method returns an error.
    pub fn failing() -> Self {
        let mut mock = Self::new("");
        mock.should_fail = true;
        mock
    }
}
impl MockLLMClient {
    /// Short-circuits with the canonical mock error when `should_fail` is set.
    fn check_failure(&self) -> Result<()> {
        if self.should_fail {
            return Err(AppError::LLM("Mock LLM failure".to_string()));
        }
        Ok(())
    }

    /// Builds the canned `LLMResponse` returned by the non-streaming methods.
    fn canned_response(&self, finish_reason: &str) -> LLMResponse {
        LLMResponse {
            content: self.response.clone(),
            tool_calls: self.tool_calls.clone(),
            finish_reason: finish_reason.to_string(),
            usage: None,
        }
    }

    /// Finish reason for the tool-aware methods: "tool_calls" when canned tool
    /// calls are configured, otherwise "stop".
    fn tool_finish_reason(&self) -> &'static str {
        if self.tool_calls.is_empty() {
            "stop"
        } else {
            "tool_calls"
        }
    }

    /// Splits the canned response into 5-character chunks and wraps them in an
    /// already-complete stream, mimicking incremental token delivery.
    fn chunked(&self) -> Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin> {
        let chunks: Vec<String> = self
            .response
            .chars()
            .collect::<Vec<_>>()
            .chunks(5)
            .map(|c| c.iter().collect())
            .collect();
        Box::new(stream::iter(chunks.into_iter().map(Ok)).boxed())
    }
}

#[async_trait]
impl LLMClient for MockLLMClient {
    /// Ignores the prompt; returns the canned response or the mock error.
    async fn generate(&self, _prompt: &str) -> Result<String> {
        self.check_failure()?;
        Ok(self.response.clone())
    }

    /// System prompt is ignored; behaves exactly like `generate`.
    async fn generate_with_system(&self, _system: &str, _prompt: &str) -> Result<String> {
        self.check_failure()?;
        Ok(self.response.clone())
    }

    /// History is ignored; always reports a "stop" finish reason.
    async fn generate_with_history(&self, _messages: &[(String, String)]) -> Result<LLMResponse> {
        self.check_failure()?;
        Ok(self.canned_response("stop"))
    }

    /// Tool definitions are ignored; the finish reason reflects whether canned
    /// tool calls were configured at construction.
    async fn generate_with_tools(
        &self,
        _prompt: &str,
        _tools: &[ToolDefinition],
    ) -> Result<LLMResponse> {
        self.check_failure()?;
        Ok(self.canned_response(self.tool_finish_reason()))
    }

    /// Streams the canned response in 5-character chunks.
    async fn stream(
        &self,
        _prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
        self.check_failure()?;
        Ok(self.chunked())
    }

    /// System prompt is ignored; behaves exactly like `stream`.
    async fn stream_with_system(
        &self,
        _system: &str,
        _prompt: &str,
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
        self.check_failure()?;
        Ok(self.chunked())
    }

    /// History is ignored; behaves exactly like `stream`.
    async fn stream_with_history(
        &self,
        _messages: &[(String, String)],
    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
        self.check_failure()?;
        Ok(self.chunked())
    }

    /// Fixed identifier so callers can distinguish the mock in logs.
    fn model_name(&self) -> &str {
        "mock-model"
    }

    /// Messages and tools are ignored; same behavior as `generate_with_tools`.
    async fn generate_with_tools_and_history(
        &self,
        _messages: &[ares::llm::ConversationMessage],
        _tools: &[ToolDefinition],
    ) -> Result<LLMResponse> {
        self.check_failure()?;
        Ok(self.canned_response(self.tool_finish_reason()))
    }
}
/// Test factory that hands out clones of a single pre-built `MockLLMClient`.
pub struct MockLLMFactory {
    // Placeholder provider reported by `default_provider`.
    provider: Provider,
    // Shared client cloned for every `create_*` call.
    client: Arc<MockLLMClient>,
}
impl MockLLMFactory {
    /// Wraps `client` in a factory carrying a placeholder Ollama provider.
    pub fn new(client: MockLLMClient) -> Self {
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "mock".to_string(),
            params: Default::default(),
        };
        Self {
            provider,
            client: Arc::new(client),
        }
    }
}
#[async_trait]
impl LLMClientFactoryTrait for MockLLMFactory {
    /// Returns the placeholder provider configured at construction.
    fn default_provider(&self) -> &Provider {
        &self.provider
    }

    /// Hands out a fresh clone of the shared mock client.
    async fn create_default(&self) -> Result<Box<dyn LLMClient>> {
        let clone = self.client.as_ref().clone();
        Ok(Box::new(clone))
    }

    /// The requested provider is ignored; every request yields the same mock.
    async fn create_with_provider(&self, _provider: Provider) -> Result<Box<dyn LLMClient>> {
        self.create_default().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // The canned response comes back verbatim from `generate`.
    #[tokio::test]
    async fn test_mock_client_generate() {
        let client = MockLLMClient::new("test response");
        let output = client.generate("prompt").await;
        assert!(output.is_ok());
        assert_eq!(output.unwrap(), "test response");
    }

    // A failing mock surfaces an error from `generate`.
    #[tokio::test]
    async fn test_mock_client_failing() {
        let client = MockLLMClient::failing();
        assert!(client.generate("prompt").await.is_err());
    }

    // Clients produced by the factory echo the wrapped mock's response.
    #[tokio::test]
    async fn test_mock_factory() {
        let factory = MockLLMFactory::new(MockLLMClient::new("factory response"));
        let llm = factory.create_default().await.unwrap();
        let output = llm.generate("test").await;
        assert!(output.is_ok());
        assert_eq!(output.unwrap(), "factory response");
    }
}