use async_trait::async_trait;
use std::sync::Mutex;
use super::{CompletionRequest, CompletionResponse, LlmDriver, ToolCall};
use crate::agent::result::{AgentError, StopReason, TokenUsage};
use crate::serve::backends::PrivacyTier;
/// Scripted stand-in for a real LLM backend.
///
/// The driver holds a queue of pre-built [`CompletionResponse`] values and
/// hands them out one at a time, front first, on each `complete` call.
pub struct MockDriver {
    /// Remaining scripted responses; the front element is served next.
    /// Wrapped in a `Mutex` because `complete` only receives `&self`.
    responses: Mutex<Vec<CompletionResponse>>,
    /// Value reported by `context_window()`.
    context_window: usize,
    /// Multiplier applied to the total token count in `estimate_cost()`.
    cost_per_token: f64,
}
impl MockDriver {
    /// Creates a mock that serves `responses` in order, first element first.
    ///
    /// The context window starts at 4096 and the per-token cost at `0.0`;
    /// use the `with_*` builders to override either.
    pub fn new(responses: Vec<CompletionResponse>) -> Self {
        let queue = Mutex::new(responses);
        Self {
            responses: queue,
            context_window: 4096,
            cost_per_token: 0.0,
        }
    }

    /// Creates a mock scripted with exactly one end-of-turn reply carrying `text`.
    pub fn single_response(text: &str) -> Self {
        let reply = CompletionResponse {
            text: text.to_owned(),
            stop_reason: StopReason::EndTurn,
            tool_calls: vec![],
            usage: TokenUsage { input_tokens: 10, output_tokens: 5 },
        };
        Self::new(vec![reply])
    }

    /// Creates a mock scripted with two replies: a tool-use turn requesting
    /// `tool_name` with `tool_input`, followed by an end-of-turn reply
    /// carrying `final_text`.
    ///
    /// The tool call always uses the fixed id `"mock-1"`.
    pub fn tool_then_response(
        tool_name: &str,
        tool_input: serde_json::Value,
        final_text: &str,
    ) -> Self {
        let requested_call = ToolCall {
            id: "mock-1".into(),
            name: tool_name.to_owned(),
            input: tool_input,
        };
        let tool_turn = CompletionResponse {
            text: String::new(),
            stop_reason: StopReason::ToolUse,
            tool_calls: vec![requested_call],
            usage: TokenUsage { input_tokens: 10, output_tokens: 5 },
        };
        let final_turn = CompletionResponse {
            text: final_text.to_owned(),
            stop_reason: StopReason::EndTurn,
            tool_calls: vec![],
            usage: TokenUsage { input_tokens: 20, output_tokens: 10 },
        };
        Self::new(vec![tool_turn, final_turn])
    }

    /// Builder: overrides the context window size reported by this mock.
    #[must_use]
    pub fn with_context_window(mut self, size: usize) -> Self {
        self.context_window = size;
        self
    }

    /// Builder: overrides the per-token cost used by `estimate_cost`.
    #[must_use]
    pub fn with_cost_per_token(mut self, cost: f64) -> Self {
        self.cost_per_token = cost;
        self
    }
}
#[async_trait]
impl LlmDriver for MockDriver {
    /// Returns the next scripted response, ignoring the request entirely.
    ///
    /// Once the script is used up, every further call yields a placeholder
    /// end-of-turn response with the text `"[mock exhausted]"` and default
    /// token usage.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::Driver(DriverError::InferenceFailed(..))` if the
    /// internal mutex is poisoned.
    async fn complete(
        &self,
        _request: CompletionRequest,
    ) -> Result<CompletionResponse, AgentError> {
        let mut queue = match self.responses.lock() {
            Ok(guard) => guard,
            Err(e) => {
                return Err(AgentError::Driver(
                    crate::agent::result::DriverError::InferenceFailed(format!(
                        "mock lock poisoned: {e}"
                    )),
                ));
            }
        };

        // Serve scripted responses front-to-back while any remain.
        if !queue.is_empty() {
            return Ok(queue.remove(0));
        }

        // Script used up: answer with a recognizable placeholder instead of failing.
        Ok(CompletionResponse {
            text: "[mock exhausted]".into(),
            stop_reason: StopReason::EndTurn,
            tool_calls: vec![],
            usage: TokenUsage::default(),
        })
    }

    /// Reports the configured context window size.
    fn context_window(&self) -> usize {
        self.context_window
    }

    /// Always reports `PrivacyTier::Sovereign`.
    fn privacy_tier(&self) -> PrivacyTier {
        PrivacyTier::Sovereign
    }

    /// Estimates cost as the configured per-token cost times `usage.total()`.
    // The integer-to-float cast is intentional; precision loss on very large
    // token totals is accepted for a cost estimate.
    #[allow(clippy::cast_precision_loss)]
    fn estimate_cost(&self, usage: &TokenUsage) -> f64 {
        let total_tokens = usage.total() as f64;
        self.cost_per_token * total_tokens
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the minimal request every test sends; the mock never inspects it.
    fn mock_request() -> CompletionRequest {
        CompletionRequest {
            model: "mock".into(),
            messages: vec![],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.0,
            system: None,
        }
    }

    /// Builds a plain end-of-turn response with default usage and the given text.
    fn end_turn(text: &str) -> CompletionResponse {
        CompletionResponse {
            text: text.into(),
            stop_reason: StopReason::EndTurn,
            tool_calls: vec![],
            usage: TokenUsage::default(),
        }
    }

    /// A single scripted reply comes back unchanged.
    #[tokio::test]
    async fn test_single_response() {
        let mock = MockDriver::single_response("hello world");

        let reply = mock.complete(mock_request()).await.expect("complete failed");

        assert_eq!(reply.text, "hello world");
        assert_eq!(reply.stop_reason, StopReason::EndTurn);
        assert!(reply.tool_calls.is_empty());
    }

    /// Multiple scripted replies are served in insertion order.
    #[tokio::test]
    async fn test_sequenced_responses() {
        let mock = MockDriver::new(vec![end_turn("first"), end_turn("second")]);
        let request = mock_request();

        let first = mock.complete(request.clone()).await.expect("first failed");
        assert_eq!(first.text, "first");

        let second = mock.complete(request).await.expect("second failed");
        assert_eq!(second.text, "second");
    }

    /// An empty script yields the placeholder reply rather than an error.
    #[tokio::test]
    async fn test_exhausted_responses() {
        let mock = MockDriver::new(vec![]);

        let reply = mock.complete(mock_request()).await.expect("complete failed");

        assert_eq!(reply.text, "[mock exhausted]");
    }

    /// The tool-then-response script emits the tool call first, then the final text.
    #[tokio::test]
    async fn test_tool_call_response() {
        let tool_args = serde_json::json!({"query": "test"});
        let mock = MockDriver::tool_then_response("rag", tool_args, "final answer");
        let request = mock_request();

        let tool_turn = mock.complete(request.clone()).await.expect("first failed");
        assert_eq!(tool_turn.stop_reason, StopReason::ToolUse);
        assert_eq!(tool_turn.tool_calls.len(), 1);
        assert_eq!(tool_turn.tool_calls[0].name, "rag");

        let final_turn = mock.complete(request).await.expect("second failed");
        assert_eq!(final_turn.text, "final answer");
        assert_eq!(final_turn.stop_reason, StopReason::EndTurn);
    }

    /// The context window defaults to 4096 and follows the builder override.
    #[test]
    fn test_context_window() {
        let default_mock = MockDriver::single_response("hi");
        assert_eq!(default_mock.context_window(), 4096);

        let widened = default_mock.with_context_window(8192);
        assert_eq!(widened.context_window(), 8192);
    }

    /// The mock always reports the sovereign privacy tier.
    #[test]
    fn test_privacy_tier_sovereign() {
        let mock = MockDriver::single_response("hi");
        assert_eq!(mock.privacy_tier(), PrivacyTier::Sovereign);
    }
}