mod types;
#[cfg(feature = "openai")]
pub mod openai;
pub use types::*;
use async_trait::async_trait;
use serde_json::Value;
use crate::Result;
use crate::error::TinyAgentsError;
use crate::harness::message::{AssistantMessage, ContentBlock, Message, MessageDelta};
use crate::harness::model::{
ChatModel, ModelProfile, ModelRequest, ModelResponse, ModelStream, ModelStreamItem,
};
use crate::harness::tool::ToolCall;
use crate::harness::usage::Usage;
/// Roughly estimates prompt tokens for `request` at four bytes of text per token.
///
/// Adds up the UTF-8 byte length of every message's text and rounds the quotient
/// up, so a request with no text yields `0`.
fn estimate_input_tokens(request: &ModelRequest) -> u64 {
    let total_bytes = request
        .messages
        .iter()
        .fold(0u64, |acc, message| acc + message.text().len() as u64);
    total_bytes.div_ceil(4)
}
/// Roughly estimates completion tokens for `text` at four bytes per token.
///
/// The byte count is rounded up, and the result never drops below `1`, so even an
/// empty completion is charged a single token.
fn estimate_output_tokens(text: &str) -> u64 {
    let byte_len = text.len() as u64;
    byte_len.div_ceil(4).max(1)
}
impl MockModel {
    /// Creates a mock whose replies repeat the text of the latest user message.
    pub fn echo() -> Self {
        Self {
            behavior: MockBehavior::Echo,
            inner: Default::default(),
        }
    }

    /// Creates a mock that answers every request with the same `text`.
    pub fn constant(text: impl Into<String>) -> Self {
        let reply: String = text.into();
        Self {
            behavior: MockBehavior::Constant(reply),
            inner: Default::default(),
        }
    }

    /// Creates a mock that replays `responses` in order, wrapping around to the
    /// first entry once the list is exhausted.
    ///
    /// # Panics
    ///
    /// Panics if `responses` is empty.
    pub fn with_responses(responses: Vec<ModelResponse>) -> Self {
        let has_script = !responses.is_empty();
        assert!(
            has_script,
            "MockModel::with_responses: responses must not be empty"
        );
        Self {
            behavior: MockBehavior::Scripted(responses),
            inner: Default::default(),
        }
    }

    /// Creates a mock that answers every request with one tool call named `name`
    /// carrying `arguments`.
    pub fn with_tool_call(name: impl Into<String>, arguments: impl Into<Value>) -> Self {
        let behavior = MockBehavior::ToolCall {
            name: name.into(),
            arguments: arguments.into(),
        };
        Self {
            behavior,
            inner: Default::default(),
        }
    }

    /// Number of `invoke` calls (including those made through `stream`) so far.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn call_count(&self) -> u64 {
        let guard = self.inner.lock().expect("MockModel inner state poisoned");
        guard.call_count
    }
}
/// Returns the single, lazily built `ModelProfile::permissive()` instance shared by
/// every `MockModel`; it is created on first use and reused afterwards.
fn mock_profile() -> &'static ModelProfile {
    use std::sync::OnceLock;

    static SHARED_PROFILE: OnceLock<ModelProfile> = OnceLock::new();
    SHARED_PROFILE.get_or_init(ModelProfile::permissive)
}
#[async_trait]
impl<State: Send + Sync> ChatModel<State> for MockModel {
    /// Always reports the shared permissive profile returned by `mock_profile`.
    fn profile(&self) -> Option<&ModelProfile> {
        Some(mock_profile())
    }

    /// Returns a deterministic response chosen by `self.behavior`.
    ///
    /// Each call increments the call counter under the mutex; the new value (starting
    /// at 1) is this call's id. It names the message (`mock-msg-{id}`) and the tool call
    /// (`mock-tool-{id}`), and selects which scripted response is replayed (cycling
    /// through the script). A message id is only filled in when the chosen response
    /// does not already carry one.
    ///
    /// # Errors
    ///
    /// Returns `TinyAgentsError::Model` if the internal mutex is poisoned.
    async fn invoke(&self, _state: &State, request: ModelRequest) -> Result<ModelResponse> {
        let call_id = {
            let mut inner = self
                .inner
                .lock()
                .map_err(|e| TinyAgentsError::Model(format!("MockModel lock poisoned: {e}")))?;
            inner.call_count += 1;
            inner.call_count
        };
        let msg_id = format!("mock-msg-{call_id}");
        let input_tokens = estimate_input_tokens(&request);
        let mut response = match &self.behavior {
            MockBehavior::Echo => {
                // Echo the most recent user message, or an empty string if there is none.
                let text = request
                    .messages
                    .iter()
                    .rev()
                    .find(|m| matches!(m, Message::User(_)))
                    .map(|m| m.text())
                    .unwrap_or_default();
                let output_tokens = estimate_output_tokens(&text);
                ModelResponse::assistant(text)
                    .with_usage(Usage::new(input_tokens, output_tokens))
                    .with_finish_reason("stop")
            }
            MockBehavior::Constant(text) => {
                let output_tokens = estimate_output_tokens(text);
                ModelResponse::assistant(text.clone())
                    .with_usage(Usage::new(input_tokens, output_tokens))
                    .with_finish_reason("stop")
            }
            MockBehavior::Scripted(responses) => {
                // Derive the slot from this call's own id. Re-reading `call_count` here
                // would race with concurrent calls that incremented the counter after
                // ours. Two calls could then replay the same response and skip another.
                // Assumes a non-empty script (enforced by `with_responses`).
                let index = ((call_id - 1) % responses.len() as u64) as usize;
                {
                    let mut inner = self.inner.lock().map_err(|e| {
                        TinyAgentsError::Model(format!("MockModel lock poisoned: {e}"))
                    })?;
                    inner.scripted_index = index;
                }
                responses[index].clone()
            }
            MockBehavior::ToolCall { name, arguments } => {
                let tool_call = ToolCall {
                    id: format!("mock-tool-{call_id}"),
                    name: name.clone(),
                    arguments: arguments.clone(),
                };
                // Fixed nominal output-token count for the synthetic tool call.
                let usage = Usage::new(input_tokens, 5);
                let message = AssistantMessage {
                    id: Some(msg_id.clone()),
                    content: Vec::new(),
                    tool_calls: vec![tool_call],
                    usage: Some(usage),
                };
                ModelResponse {
                    message,
                    usage: Some(usage),
                    finish_reason: Some("tool_calls".to_string()),
                    raw: None,
                    resolved_model: None,
                }
            }
        };
        if response.message.id.is_none() {
            response.message.id = Some(msg_id);
        }
        Ok(response)
    }

    /// Simulates streaming by running `invoke` and replaying its result.
    ///
    /// The stream yields `Started`, then the deltas, then `Completed` carrying the full
    /// response. With no text there is a single empty delta; otherwise the text is split
    /// into two deltas at its middle character, which is always a `char` boundary.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `invoke`.
    async fn stream(&self, state: &State, request: ModelRequest) -> Result<ModelStream> {
        let response = self.invoke(state, request).await?;
        let text = response.text();
        let mut items = vec![ModelStreamItem::Started];
        if text.is_empty() {
            items.push(ModelStreamItem::MessageDelta(MessageDelta::default()));
        } else {
            let chars: Vec<char> = text.chars().collect();
            let mid = chars.len() / 2;
            let first: String = chars[..mid].iter().collect();
            let second: String = chars[mid..].iter().collect();
            items.push(ModelStreamItem::MessageDelta(MessageDelta {
                text: first,
                tool_call: None,
            }));
            items.push(ModelStreamItem::MessageDelta(MessageDelta {
                text: second,
                tool_call: None,
            }));
        }
        items.push(ModelStreamItem::Completed(response));
        Ok(Box::pin(futures::stream::iter(items)))
    }
}
impl MockModel {
    /// Builds a canned plain-text `ModelResponse`, e.g. for `MockModel::with_responses`.
    ///
    /// The response has no message id, a `"stop"` finish reason, and the same usage
    /// (10 input tokens plus an output estimate derived from `text`) attached to both
    /// the message and the response.
    pub fn text_response(text: impl Into<String>) -> ModelResponse {
        let body: String = text.into();
        let usage = Usage::new(10, estimate_output_tokens(&body));
        let message = AssistantMessage {
            id: None,
            content: vec![ContentBlock::Text(body)],
            tool_calls: Vec::new(),
            usage: Some(usage),
        };
        ModelResponse {
            message,
            usage: Some(usage),
            finish_reason: Some(String::from("stop")),
            raw: None,
            resolved_model: None,
        }
    }
}
#[cfg(test)]
mod test;