Skip to main content

batuta/agent/driver/
mock.rs

1//! Mock LLM driver for deterministic testing.
2//!
3//! Returns pre-configured responses in sequence. Essential for
4//! testing the agent loop without actual model inference.
5
6use async_trait::async_trait;
7use std::sync::Mutex;
8
9use super::{CompletionRequest, CompletionResponse, LlmDriver, ToolCall};
10use crate::agent::result::{AgentError, StopReason, TokenUsage};
11use crate::serve::backends::PrivacyTier;
12
/// Mock driver that returns pre-configured responses in order.
///
/// Responses sit behind a [`Mutex`] so `complete` can consume them
/// one at a time through `&self`, as the `LlmDriver` trait requires.
pub struct MockDriver {
    // Remaining canned responses; `complete` pops from the front.
    responses: Mutex<Vec<CompletionResponse>>,
    // Value reported by `context_window()`; defaults to 4096 in `new`.
    context_window: usize,
    /// Cost per token (input + output) for testing cost budgets.
    cost_per_token: f64,
}
20
21impl MockDriver {
22    /// Create a mock driver with a sequence of responses.
23    ///
24    /// Responses are returned in order. If exhausted, returns
25    /// a default "end of mock responses" response.
26    pub fn new(responses: Vec<CompletionResponse>) -> Self {
27        Self { responses: Mutex::new(responses), context_window: 4096, cost_per_token: 0.0 }
28    }
29
30    /// Create a mock that returns a single text response.
31    pub fn single_response(text: &str) -> Self {
32        Self::new(vec![CompletionResponse {
33            text: text.to_string(),
34            stop_reason: StopReason::EndTurn,
35            tool_calls: vec![],
36            usage: TokenUsage { input_tokens: 10, output_tokens: 5 },
37        }])
38    }
39
40    /// Create a mock that first requests a tool call, then responds.
41    pub fn tool_then_response(
42        tool_name: &str,
43        tool_input: serde_json::Value,
44        final_text: &str,
45    ) -> Self {
46        Self::new(vec![
47            CompletionResponse {
48                text: String::new(),
49                stop_reason: StopReason::ToolUse,
50                tool_calls: vec![ToolCall {
51                    id: "mock-1".into(),
52                    name: tool_name.to_string(),
53                    input: tool_input,
54                }],
55                usage: TokenUsage { input_tokens: 10, output_tokens: 5 },
56            },
57            CompletionResponse {
58                text: final_text.to_string(),
59                stop_reason: StopReason::EndTurn,
60                tool_calls: vec![],
61                usage: TokenUsage { input_tokens: 20, output_tokens: 10 },
62            },
63        ])
64    }
65
66    /// Set context window size.
67    #[must_use]
68    pub fn with_context_window(mut self, size: usize) -> Self {
69        self.context_window = size;
70        self
71    }
72
73    /// Set cost per token for testing cost budget enforcement.
74    #[must_use]
75    pub fn with_cost_per_token(mut self, cost: f64) -> Self {
76        self.cost_per_token = cost;
77        self
78    }
79}
80
81#[async_trait]
82impl LlmDriver for MockDriver {
83    async fn complete(
84        &self,
85        _request: CompletionRequest,
86    ) -> Result<CompletionResponse, AgentError> {
87        let mut responses = self.responses.lock().map_err(|e| {
88            AgentError::Driver(crate::agent::result::DriverError::InferenceFailed(format!(
89                "mock lock poisoned: {e}"
90            )))
91        })?;
92
93        if responses.is_empty() {
94            Ok(CompletionResponse {
95                text: "[mock exhausted]".into(),
96                stop_reason: StopReason::EndTurn,
97                tool_calls: vec![],
98                usage: TokenUsage::default(),
99            })
100        } else {
101            Ok(responses.remove(0))
102        }
103    }
104
105    fn context_window(&self) -> usize {
106        self.context_window
107    }
108
109    fn privacy_tier(&self) -> PrivacyTier {
110        PrivacyTier::Sovereign
111    }
112
113    #[allow(clippy::cast_precision_loss)] // token counts fit in f64 mantissa
114    fn estimate_cost(&self, usage: &TokenUsage) -> f64 {
115        self.cost_per_token * usage.total() as f64
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the request fixture shared by every test; the mock
    /// ignores the request contents entirely.
    fn mock_request() -> CompletionRequest {
        CompletionRequest {
            model: "mock".into(),
            messages: vec![],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.0,
            system: None,
        }
    }

    #[tokio::test]
    async fn test_single_response() {
        let driver = MockDriver::single_response("hello world");

        let resp = driver.complete(mock_request()).await.expect("complete failed");
        assert_eq!(resp.text, "hello world");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert!(resp.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_sequenced_responses() {
        let driver = MockDriver::new(vec![
            CompletionResponse {
                text: "first".into(),
                stop_reason: StopReason::EndTurn,
                tool_calls: vec![],
                usage: TokenUsage::default(),
            },
            CompletionResponse {
                text: "second".into(),
                stop_reason: StopReason::EndTurn,
                tool_calls: vec![],
                usage: TokenUsage::default(),
            },
        ]);

        let r1 = driver.complete(mock_request()).await.expect("first failed");
        assert_eq!(r1.text, "first");

        let r2 = driver.complete(mock_request()).await.expect("second failed");
        assert_eq!(r2.text, "second");
    }

    #[tokio::test]
    async fn test_exhausted_responses() {
        let driver = MockDriver::new(vec![]);

        let resp = driver.complete(mock_request()).await.expect("complete failed");
        assert_eq!(resp.text, "[mock exhausted]");
    }

    #[tokio::test]
    async fn test_tool_call_response() {
        let driver = MockDriver::tool_then_response(
            "rag",
            serde_json::json!({"query": "test"}),
            "final answer",
        );

        let r1 = driver.complete(mock_request()).await.expect("first failed");
        assert_eq!(r1.stop_reason, StopReason::ToolUse);
        assert_eq!(r1.tool_calls.len(), 1);
        assert_eq!(r1.tool_calls[0].name, "rag");

        let r2 = driver.complete(mock_request()).await.expect("second failed");
        assert_eq!(r2.text, "final answer");
        assert_eq!(r2.stop_reason, StopReason::EndTurn);
    }

    #[test]
    fn test_context_window() {
        let driver = MockDriver::single_response("hi");
        assert_eq!(driver.context_window(), 4096);

        let driver = driver.with_context_window(8192);
        assert_eq!(driver.context_window(), 8192);
    }

    #[test]
    fn test_privacy_tier_sovereign() {
        let driver = MockDriver::single_response("hi");
        assert_eq!(driver.privacy_tier(), PrivacyTier::Sovereign);
    }

    #[test]
    fn test_estimate_cost_default_zero() {
        // Default cost_per_token is 0.0, so any usage costs nothing.
        let driver = MockDriver::single_response("hi");
        let usage = TokenUsage { input_tokens: 10, output_tokens: 5 };
        assert!(driver.estimate_cost(&usage) == 0.0);
    }
}