use async_trait::async_trait;
use std::any::Any;
use std::sync::{Arc, Mutex};
use crate::core::Provider;
use crate::error::LlmConnectorError;
use crate::types::{
ChatRequest, ChatResponse, Choice, EmbedRequest, EmbedResponse, EmbeddingData, Message, Role,
ToolCall, Usage,
};
#[cfg(feature = "streaming")]
use crate::types::ChatStream;
/// An in-memory [`Provider`] implementation intended for tests.
///
/// Each `chat` call records the incoming request and then answers with the next
/// queued result, falling back to a fixed default response once the queue is empty.
pub struct MockProvider {
    /// Queued results; the next one to serve is the last element (taken with `pop`).
    responses: Mutex<Vec<Result<ChatResponse, LlmConnectorError>>>,
    /// Response cloned and returned whenever `responses` is empty.
    default_response: ChatResponse,
    /// Every request passed to `chat`, in call order.
    requests: Mutex<Vec<ChatRequest>>,
}
impl MockProvider {
pub fn new(content: impl Into<String>) -> Self {
let content = content.into();
Self {
responses: Mutex::new(Vec::new()),
default_response: Self::make_response(content, None, None),
requests: Mutex::new(Vec::new()),
}
}
pub fn with_error(error: LlmConnectorError) -> Self {
let mut provider = Self::new("");
provider.responses = Mutex::new(vec![Err(error)]);
provider
}
pub fn with_responses(responses: Vec<Result<ChatResponse, LlmConnectorError>>) -> Self {
Self {
responses: Mutex::new(responses.into_iter().rev().collect()),
default_response: Self::make_response("".to_string(), None, None),
requests: Mutex::new(Vec::new()),
}
}
pub fn get_requests(&self) -> Vec<ChatRequest> {
self.requests.lock().unwrap().clone()
}
pub fn request_count(&self) -> usize {
self.requests.lock().unwrap().len()
}
fn make_response(content: String, model: Option<String>, usage: Option<Usage>) -> ChatResponse {
let message = Message::text(Role::Assistant, &content);
ChatResponse {
id: "mock-id".to_string(),
object: "chat.completion".to_string(),
created: 0,
model: model.unwrap_or_else(|| "mock-model".to_string()),
choices: vec![Choice {
index: 0,
message,
finish_reason: Some("stop".to_string()),
logprobs: None,
}],
content,
reasoning_content: None,
usage,
system_fingerprint: None,
}
}
fn make_tool_call_response(
tool_calls: Vec<ToolCall>,
model: Option<String>,
usage: Option<Usage>,
) -> ChatResponse {
let message = Message::assistant_with_tool_calls(tool_calls);
let content = message.content_as_text();
ChatResponse {
id: "mock-id".to_string(),
object: "chat.completion".to_string(),
created: 0,
model: model.unwrap_or_else(|| "mock-model".to_string()),
choices: vec![Choice {
index: 0,
message,
finish_reason: Some("tool_calls".to_string()),
logprobs: None,
}],
content,
reasoning_content: None,
usage,
system_fingerprint: None,
}
}
}
#[async_trait]
impl Provider for MockProvider {
    /// Identifies this provider as `"mock"`.
    fn name(&self) -> &str {
        "mock"
    }

    /// Reports the default capability set.
    fn capabilities(&self) -> crate::protocols::common::capabilities::ProviderCapabilities {
        crate::protocols::common::capabilities::ProviderCapabilities::default()
    }

    /// Records `request`, then yields the next queued result or, when the queue
    /// is empty, a clone of the default response.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        self.requests.lock().unwrap().push(request.clone());
        let queued = self.responses.lock().unwrap().pop();
        queued.unwrap_or_else(|| Ok(self.default_response.clone()))
    }

    /// Streaming is not implemented by the mock; always fails with `UnsupportedOperation`.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, _request: &ChatRequest) -> Result<ChatStream, LlmConnectorError> {
        let reason = String::from("MockProvider does not support streaming");
        Err(LlmConnectorError::UnsupportedOperation(reason))
    }

    /// Advertises the single model name `"mock-model"`.
    async fn models(&self) -> Result<Vec<String>, LlmConnectorError> {
        let model = String::from("mock-model");
        Ok(vec![model])
    }

    /// Returns one fixed three-component embedding, echoing the requested model name.
    async fn embed(&self, request: &EmbedRequest) -> Result<EmbedResponse, LlmConnectorError> {
        let embedding = EmbeddingData {
            object: String::from("embedding"),
            embedding: vec![0.1, 0.2, 0.3],
            index: 0,
        };
        Ok(EmbedResponse {
            object: String::from("list"),
            data: vec![embedding],
            model: request.model.clone(),
            usage: Usage::default(),
        })
    }

    /// Exposes `self` as `&dyn Any` so callers can downcast.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Fluent builder for configuring a [`MockProvider`].
pub struct MockProviderBuilder {
    /// Optional text content; `build` falls back to an empty string when unset.
    content: Option<String>,
    /// Optional model name applied to responses generated by the builder.
    model: Option<String>,
    /// Optional usage applied to responses generated by the builder.
    usage: Option<Usage>,
    /// Optional tool calls; `build` prefers these over `content`.
    tool_calls: Option<Vec<ToolCall>>,
    /// Queued results, in the order they were added.
    responses: Vec<Result<ChatResponse, LlmConnectorError>>,
}
impl MockProviderBuilder {
    /// Creates a builder with nothing configured.
    pub fn new() -> Self {
        Self {
            content: None,
            model: None,
            usage: None,
            tool_calls: None,
            responses: Vec::new(),
        }
    }

    /// Sets the text of the default response.
    pub fn with_content(mut self, content: impl Into<String>) -> Self {
        self.content = Some(content.into());
        self
    }

    /// Sets the model name stamped on the default response.
    ///
    /// Also applies to responses queued with [`Self::add_response_content`]
    /// *after* this call.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the usage stamped on the default response.
    ///
    /// Also applies to responses queued with [`Self::add_response_content`]
    /// *after* this call.
    pub fn with_usage(mut self, usage: Usage) -> Self {
        self.usage = Some(usage);
        self
    }

    /// Makes the default response a tool-call response; takes precedence over `with_content`.
    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
        self.tool_calls = Some(tool_calls);
        self
    }

    /// Queues a text response, stamped with the model and usage configured so far.
    pub fn add_response_content(mut self, content: impl Into<String>) -> Self {
        let response =
            MockProvider::make_response(content.into(), self.model.clone(), self.usage.clone());
        self.responses.push(Ok(response));
        self
    }

    /// Queues a prebuilt response.
    pub fn add_response(mut self, response: ChatResponse) -> Self {
        self.responses.push(Ok(response));
        self
    }

    /// Queues an error result.
    pub fn add_error(mut self, error: LlmConnectorError) -> Self {
        self.responses.push(Err(error));
        self
    }

    /// Builds the provider.
    ///
    /// Queued responses are served in the order they were added. Once they are
    /// exhausted, every call returns the default response built from
    /// `with_tool_calls` (preferred) or `with_content` (empty when unset), stamped
    /// with the configured model and usage.
    pub fn build(self) -> MockProvider {
        let Self {
            content,
            model,
            usage,
            tool_calls,
            responses,
        } = self;
        // The default is always derived from the builder configuration, so settings
        // such as `with_content` keep applying after the queued responses run out.
        let default_response = match tool_calls {
            Some(tool_calls) => MockProvider::make_tool_call_response(tool_calls, model, usage),
            None => MockProvider::make_response(content.unwrap_or_default(), model, usage),
        };
        // `with_responses` arranges the queue so `chat` serves it first-in, first-out.
        let mut provider = MockProvider::with_responses(responses);
        provider.default_response = default_response;
        provider
    }

    /// Builds the provider and wraps it in an `LlmClient`.
    pub fn build_client(self) -> crate::client::LlmClient {
        let provider = self.build();
        crate::client::LlmClient::from_provider(Arc::new(provider))
    }
}
impl Default for MockProviderBuilder {
fn default() -> Self {
Self::new()
}
}