mod types;
#[cfg(feature = "openai")]
pub mod openai;
pub use types::*;
use async_trait::async_trait;
use serde_json::Value;
use crate::Result;
use crate::error::RustAgentsError;
use crate::harness::message::{AssistantMessage, ContentBlock, Message};
use crate::harness::model::{ChatModel, ModelDelta, ModelRequest, ModelResponse};
use crate::harness::tool::ToolCall;
use crate::harness::usage::Usage;
/// Rough prompt-size estimate used by the mock: the summed `len()` of every
/// message's `text()`, divided by 4 and rounded up.
///
/// Unlike the output estimate this is not clamped, so an empty request yields 0.
fn estimate_input_tokens(request: &ModelRequest) -> u64 {
    let byte_total = request
        .messages
        .iter()
        .map(|message| message.text().len())
        .fold(0u64, |acc, len| acc + len as u64);
    byte_total.div_ceil(4)
}
/// Rough completion-size estimate: the byte length of `text` divided by 4,
/// rounded up, and never less than 1 (an empty reply still reports one token).
fn estimate_output_tokens(text: &str) -> u64 {
    (text.len() as u64).div_ceil(4).max(1)
}
impl MockModel {
pub fn echo() -> Self {
Self {
behavior: MockBehavior::Echo,
inner: std::sync::Mutex::new(MockInner::default()),
}
}
pub fn constant(text: impl Into<String>) -> Self {
Self {
behavior: MockBehavior::Constant(text.into()),
inner: std::sync::Mutex::new(MockInner::default()),
}
}
pub fn with_responses(responses: Vec<ModelResponse>) -> Self {
assert!(
!responses.is_empty(),
"MockModel::with_responses: responses must not be empty"
);
Self {
behavior: MockBehavior::Scripted(responses),
inner: std::sync::Mutex::new(MockInner::default()),
}
}
pub fn with_tool_call(name: impl Into<String>, arguments: impl Into<Value>) -> Self {
Self {
behavior: MockBehavior::ToolCall {
name: name.into(),
arguments: arguments.into(),
},
inner: std::sync::Mutex::new(MockInner::default()),
}
}
pub fn call_count(&self) -> u64 {
self.inner
.lock()
.expect("MockModel inner state poisoned")
.call_count
}
}
#[async_trait]
impl<State: Send + Sync> ChatModel<State> for MockModel {
    /// Returns one canned response chosen by the mock's configured behavior.
    ///
    /// Each call increments the shared call counter. The new 1-based count is the
    /// call id embedded in generated ids (`mock-msg-{id}`, `mock-tool-{id}`), and
    /// any response without a message id is given `mock-msg-{id}`.
    ///
    /// - `Echo`: replies with the text of the last user message (empty if none).
    /// - `Constant`: replies with the configured text.
    /// - `Scripted`: replays the configured responses in order, cycling.
    /// - `ToolCall`: replies with a single tool call and no content blocks.
    ///
    /// # Errors
    ///
    /// Returns [`RustAgentsError::Model`] if the internal mutex is poisoned, or if
    /// a scripted mock has no responses to serve.
    async fn invoke(&self, _state: &State, request: ModelRequest) -> Result<ModelResponse> {
        // Bump the counter AND pick the scripted slot in one critical section, so
        // each concurrent call gets its own call id and the slot that matches it.
        // Re-locking later and re-reading `call_count` would let another call
        // interleave, serving one script entry twice and skipping another.
        let (call_id, scripted_index) = {
            let mut inner = self
                .inner
                .lock()
                .map_err(|e| RustAgentsError::Model(format!("MockModel lock poisoned: {e}")))?;
            inner.call_count += 1;
            let call_id = inner.call_count;
            let scripted_index = match &self.behavior {
                // Reduce in u64 before narrowing, so the cast cannot truncate on
                // 32-bit targets. `None` here means the script is empty.
                MockBehavior::Scripted(responses) => (call_id - 1)
                    .checked_rem(responses.len() as u64)
                    .map(|slot| slot as usize),
                MockBehavior::Echo
                | MockBehavior::Constant(_)
                | MockBehavior::ToolCall { .. } => None,
            };
            if let Some(index) = scripted_index {
                inner.scripted_index = index;
            }
            (call_id, scripted_index)
        };

        let msg_id = format!("mock-msg-{call_id}");
        let input_tokens = estimate_input_tokens(&request);
        let mut response = match &self.behavior {
            MockBehavior::Echo => {
                // Echo the most recent user turn, or an empty string if there is none.
                let text = request
                    .messages
                    .iter()
                    .rev()
                    .find(|m| matches!(m, Message::User(_)))
                    .map(|m| m.text())
                    .unwrap_or_default();
                let output_tokens = estimate_output_tokens(&text);
                ModelResponse::assistant(text)
                    .with_usage(Usage::new(input_tokens, output_tokens))
                    .with_finish_reason("stop")
            }
            MockBehavior::Constant(text) => {
                let output_tokens = estimate_output_tokens(text);
                ModelResponse::assistant(text.clone())
                    .with_usage(Usage::new(input_tokens, output_tokens))
                    .with_finish_reason("stop")
            }
            MockBehavior::Scripted(responses) => {
                // `with_responses` rejects empty scripts, but the behavior can be
                // built other ways; fail with an error instead of a divide-by-zero panic.
                let index = scripted_index.ok_or_else(|| {
                    RustAgentsError::Model(
                        "MockModel: scripted responses must not be empty".to_string(),
                    )
                })?;
                responses[index].clone()
            }
            MockBehavior::ToolCall { name, arguments } => {
                let tool_call = ToolCall {
                    id: format!("mock-tool-{call_id}"),
                    name: name.clone(),
                    arguments: arguments.clone(),
                };
                // Output usage is a fixed 5 tokens for tool-call replies.
                let usage = Usage::new(input_tokens, 5);
                let message = AssistantMessage {
                    id: Some(msg_id.clone()),
                    content: Vec::new(),
                    tool_calls: vec![tool_call],
                    usage: Some(usage),
                };
                ModelResponse {
                    message,
                    usage: Some(usage),
                    finish_reason: Some("tool_calls".to_string()),
                    raw: None,
                    resolved_model: None,
                }
            }
        };
        // Scripted responses may carry their own id; only fill in a missing one.
        if response.message.id.is_none() {
            response.message.id = Some(msg_id);
        }
        Ok(response)
    }

    /// Simulates streaming by calling `invoke` and splitting the response text
    /// into two deltas at its character midpoint (the first half may be empty
    /// for one-character replies).
    ///
    /// A response with empty text yields a single empty delta. Every delta uses
    /// the response's message id as its `call_id`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `invoke`.
    // NOTE(review): tool calls on the underlying response are not forwarded
    // (`tool_call` is always `None`); confirm no consumer relies on streaming them.
    async fn stream(&self, state: &State, request: ModelRequest) -> Result<Vec<ModelDelta>> {
        let response = self.invoke(state, request).await?;
        // `invoke` always assigns an id, so the fallback is purely defensive.
        let call_id = response
            .message
            .id
            .clone()
            .unwrap_or_else(|| "mock-stream".to_string());
        let text = response.text();
        if text.is_empty() {
            return Ok(vec![ModelDelta {
                call_id,
                content: String::new(),
                tool_call: None,
            }]);
        }
        // Split on char boundaries (not bytes) so multi-byte text stays valid UTF-8.
        let chars: Vec<char> = text.chars().collect();
        let mid = chars.len() / 2;
        let first: String = chars[..mid].iter().collect();
        let second: String = chars[mid..].iter().collect();
        Ok(vec![
            ModelDelta {
                call_id: call_id.clone(),
                content: first,
                tool_call: None,
            },
            ModelDelta {
                call_id,
                content: second,
                tool_call: None,
            },
        ])
    }
}
impl MockModel {
    /// Builds a plain-text, `"stop"`-terminated [`ModelResponse`], handy for
    /// scripting mocks via [`MockModel::with_responses`].
    ///
    /// The message id is left as `None`. Usage, a fixed 10 input tokens plus an
    /// output estimate derived from the text's length, is attached to both the
    /// message and the response.
    pub fn text_response(text: impl Into<String>) -> ModelResponse {
        let body: String = text.into();
        let usage = Usage::new(10, estimate_output_tokens(&body));
        let message = AssistantMessage {
            id: None,
            content: vec![ContentBlock::Text(body)],
            tool_calls: Vec::new(),
            usage: Some(usage),
        };
        ModelResponse {
            message,
            usage: Some(usage),
            finish_reason: Some(String::from("stop")),
            raw: None,
            resolved_model: None,
        }
    }
}
#[cfg(test)]
mod test;