use crate::{
EmbeddingModel, EmbeddingResult, FinishReason, LanguageModel, ModelMessage, ModelRequest,
ModelResponse, Part, ProviderRegistration, Result, ToolCall, ToolChoice, ToolDefinition,
ToolResult, Usage,
};
/// Offline [`LanguageModel`] implementation that produces canned, deterministic
/// responses without contacting any service.
///
/// The only state is the identifier reported by `model_id` and echoed in
/// generated replies.
#[derive(Debug, Clone)]
pub struct MockLanguageModel {
    /// Identifier returned from `model_id` and embedded in response metadata.
    model_id: String,
}
impl MockLanguageModel {
    /// Returns the text of the most recent user message whose text is non-empty.
    ///
    /// User messages for which `text()` yields an empty string are skipped, so
    /// the result is the latest user prompt that actually carries text.
    fn last_user_text(&self, messages: &[ModelMessage]) -> Option<String> {
        messages
            .iter()
            .rev()
            .filter(|message| message.role == crate::Role::User)
            .map(|message| message.text())
            .find(|text| !text.is_empty())
    }

    /// Returns the most recent tool result that is newer than the latest user
    /// prompt, if any.
    ///
    /// Messages are scanned newest-first; within a message, parts are scanned
    /// last-to-first. The scan stops (returning `None`) at the first user message
    /// with non-empty text. A newer user prompt therefore takes precedence over
    /// tool output from an earlier turn, instead of that stale output being
    /// summarised again on every later turn.
    fn last_tool_result<'a>(&self, messages: &'a [ModelMessage]) -> Option<&'a ToolResult> {
        for message in messages.iter().rev() {
            let newest_result = message.parts.iter().rev().find_map(|part| match part {
                Part::ToolResult(result) => Some(result),
                _ => None,
            });
            if let Some(result) = newest_result {
                return Some(result);
            }
            // Parts are checked first so a user-role message that carries a tool
            // result is still honoured; only a text prompt ends the search.
            if message.role == crate::Role::User && !message.text().is_empty() {
                return None;
            }
        }
        None
    }
}
impl LanguageModel for MockLanguageModel {
    fn model_id(&self) -> &str {
        &self.model_id
    }

    /// Produces a canned, deterministic response for `request`.
    ///
    /// Decision order:
    /// 1. If `last_tool_result` finds a tool result, reply with a text summary of
    ///    its name and output.
    /// 2. Else, if `should_call_tool` selects a tool for the (trimmed) latest user
    ///    text, emit a single tool call. Its input is that text with a leading
    ///    `"remember "` or `"search "` removed.
    /// 3. Otherwise reply with a fixed JSON object when the prompt contains
    ///    "json" (ASCII case-insensitive), or echo the prompt back.
    ///
    /// Usage is estimated from the request messages plus the generated output
    /// text. For tool calls it is estimated from the user text.
    fn generate(&self, request: &ModelRequest) -> Result<ModelResponse> {
        if let Some(result) = self.last_tool_result(&request.messages) {
            let text = format!("I checked {} and found: {}", result.name, result.output);
            // Estimate from the text actually returned, as the plain-text branch does.
            let usage = crate::estimate_usage(&request.messages, &text);
            return Ok(ModelResponse {
                parts: vec![Part::Text(text)],
                finish_reason: FinishReason::Stop,
                usage,
                response_metadata: crate::metadata_with_provider("mock", &self.model_id),
            });
        }

        let user_text = self.last_user_text(&request.messages).unwrap_or_default();
        let user_text = user_text.trim();

        if let Some(tool_name) = should_call_tool(user_text, &request.tools, &request.tool_choice) {
            // Drop the command word that triggered the tool so only its argument remains.
            let input = user_text
                .strip_prefix("remember ")
                .or_else(|| user_text.strip_prefix("search "))
                .unwrap_or(user_text)
                .trim()
                .to_string();
            return Ok(ModelResponse {
                parts: vec![Part::ToolCall(ToolCall::new(tool_name, input))],
                finish_reason: FinishReason::ToolCalls,
                usage: crate::estimate_usage(&request.messages, user_text),
                response_metadata: crate::metadata_with_provider("mock", &self.model_id),
            });
        }

        let text = if user_text.to_ascii_lowercase().contains("json") {
            "{\"topic\":\"llm\",\"status\":\"ok\"}".to_string()
        } else {
            format!("Mock response from {}: {}", self.model_id, user_text)
        };
        // Compute usage while `text` is still borrowed, then move it into the part.
        let usage = crate::estimate_usage(&request.messages, &text);
        Ok(ModelResponse {
            parts: vec![Part::Text(text)],
            finish_reason: FinishReason::Stop,
            usage,
            response_metadata: crate::metadata_with_provider("mock", &self.model_id),
        })
    }
}
/// Decides which tool, if any, the mock model should call for `user_text`.
///
/// * `ToolChoice::None` never selects a tool.
/// * `ToolChoice::Required(name)` selects `name` when a tool of that name is in
///   `tools`.
/// * `ToolChoice::Auto` routes text starting with `"remember "` to
///   `add_resource` and text starting with `"search "` to `get_information`,
///   but only when that tool is present in `tools`.
///
/// Returns the selected tool's name, or `None` when no tool should be called.
fn should_call_tool(
    user_text: &str,
    tools: &[ToolDefinition],
    tool_choice: &ToolChoice,
) -> Option<String> {
    // Command prefix -> tool it triggers under automatic tool choice.
    const AUTO_ROUTES: [(&str, &str); 2] = [
        ("remember ", "add_resource"),
        ("search ", "get_information"),
    ];

    match tool_choice {
        ToolChoice::None => None,
        ToolChoice::Required(name) => {
            let chosen = tools.iter().find(|tool| tool.name == *name)?;
            Some(chosen.name.clone())
        }
        ToolChoice::Auto => {
            let wanted = AUTO_ROUTES
                .iter()
                .find(|&&(prefix, _)| user_text.starts_with(prefix))
                .map(|&(_, tool_name)| tool_name)?;
            let chosen = tools.iter().find(|tool| tool.name == wanted)?;
            Some(chosen.name.clone())
        }
    }
}
/// Offline [`EmbeddingModel`] implementation that turns text into a small,
/// deterministic letter-count vector.
///
/// The only state is the identifier reported by `model_id`.
#[derive(Debug, Clone)]
pub struct MockEmbeddingModel {
    /// Identifier returned from `model_id`.
    model_id: String,
}
impl EmbeddingModel for MockEmbeddingModel {
    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// Embeds `value` as an 8-dimensional letter-count histogram.
    ///
    /// Every character is lowercased. Each resulting ASCII letter increments the
    /// bucket at `(letter - 'a') % 8`, so `a`, `i`, `q` and `y` share bucket 0,
    /// and so on. All other characters are ignored, and the counts are not
    /// normalised. Input tokens are reported via `crate::count_tokens`; output
    /// tokens are always zero.
    fn embed(&self, value: &str) -> Result<EmbeddingResult> {
        const BUCKET_COUNT: usize = 8;

        let mut histogram = vec![0.0; BUCKET_COUNT];
        value
            .chars()
            .flat_map(char::to_lowercase)
            // After lowercasing, an ASCII letter can only be `a..=z`, so this is
            // the same set as `is_ascii_alphabetic` and the subtraction cannot wrap.
            .filter(char::is_ascii_lowercase)
            .for_each(|letter| {
                let slot = usize::from(letter as u8 - b'a') % BUCKET_COUNT;
                histogram[slot] += 1.0;
            });

        let usage = Usage {
            input_tokens: crate::count_tokens(value),
            output_tokens: 0,
        };
        Ok(EmbeddingResult {
            embedding: histogram,
            usage,
        })
    }
}
/// Language-model factory for the `"mock"` provider.
///
/// Builds a boxed [`MockLanguageModel`] that reports `model_id`; never fails.
fn mock_language_model(model_id: &str) -> Result<Box<dyn LanguageModel>> {
    let model = MockLanguageModel {
        model_id: String::from(model_id),
    };
    Ok(Box::new(model))
}
/// Embedding-model factory for the `"mock"` provider.
///
/// Builds a boxed [`MockEmbeddingModel`] that reports `model_id`; never fails.
fn mock_embedding_model(model_id: &str) -> Result<Box<dyn EmbeddingModel>> {
    // The return type drives the unsized coercion to `Box<dyn EmbeddingModel>`,
    // so no explicit `as` cast is needed (mirrors `mock_language_model`).
    Ok(Box::new(MockEmbeddingModel {
        model_id: model_id.to_string(),
    }))
}
// Registers the "mock" provider's factories with the `inventory` crate's
// distributed registry. Presumably the crate resolves providers by `id` by
// iterating `ProviderRegistration` entries — confirm against the lookup code.
inventory::submit! {
ProviderRegistration {
id: "mock",
language_model: mock_language_model,
embedding_model: mock_embedding_model,
}
}