1use async_trait::async_trait;
7use std::sync::Mutex;
8
9use super::{CompletionRequest, CompletionResponse, LlmDriver, ToolCall};
10use crate::agent::result::{AgentError, StopReason, TokenUsage};
11use crate::serve::backends::PrivacyTier;
12
13pub struct MockDriver {
15 responses: Mutex<Vec<CompletionResponse>>,
16 context_window: usize,
17 cost_per_token: f64,
19}
20
21impl MockDriver {
22 pub fn new(responses: Vec<CompletionResponse>) -> Self {
27 Self { responses: Mutex::new(responses), context_window: 4096, cost_per_token: 0.0 }
28 }
29
30 pub fn single_response(text: &str) -> Self {
32 Self::new(vec![CompletionResponse {
33 text: text.to_string(),
34 stop_reason: StopReason::EndTurn,
35 tool_calls: vec![],
36 usage: TokenUsage { input_tokens: 10, output_tokens: 5 },
37 }])
38 }
39
40 pub fn tool_then_response(
42 tool_name: &str,
43 tool_input: serde_json::Value,
44 final_text: &str,
45 ) -> Self {
46 Self::new(vec![
47 CompletionResponse {
48 text: String::new(),
49 stop_reason: StopReason::ToolUse,
50 tool_calls: vec![ToolCall {
51 id: "mock-1".into(),
52 name: tool_name.to_string(),
53 input: tool_input,
54 }],
55 usage: TokenUsage { input_tokens: 10, output_tokens: 5 },
56 },
57 CompletionResponse {
58 text: final_text.to_string(),
59 stop_reason: StopReason::EndTurn,
60 tool_calls: vec![],
61 usage: TokenUsage { input_tokens: 20, output_tokens: 10 },
62 },
63 ])
64 }
65
66 #[must_use]
68 pub fn with_context_window(mut self, size: usize) -> Self {
69 self.context_window = size;
70 self
71 }
72
73 #[must_use]
75 pub fn with_cost_per_token(mut self, cost: f64) -> Self {
76 self.cost_per_token = cost;
77 self
78 }
79}
80
81#[async_trait]
82impl LlmDriver for MockDriver {
83 async fn complete(
84 &self,
85 _request: CompletionRequest,
86 ) -> Result<CompletionResponse, AgentError> {
87 let mut responses = self.responses.lock().map_err(|e| {
88 AgentError::Driver(crate::agent::result::DriverError::InferenceFailed(format!(
89 "mock lock poisoned: {e}"
90 )))
91 })?;
92
93 if responses.is_empty() {
94 Ok(CompletionResponse {
95 text: "[mock exhausted]".into(),
96 stop_reason: StopReason::EndTurn,
97 tool_calls: vec![],
98 usage: TokenUsage::default(),
99 })
100 } else {
101 Ok(responses.remove(0))
102 }
103 }
104
105 fn context_window(&self) -> usize {
106 self.context_window
107 }
108
109 fn privacy_tier(&self) -> PrivacyTier {
110 PrivacyTier::Sovereign
111 }
112
113 #[allow(clippy::cast_precision_loss)] fn estimate_cost(&self, usage: &TokenUsage) -> f64 {
115 self.cost_per_token * usage.total() as f64
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the minimal request shared by every test; the mock ignores it.
    fn mock_request() -> CompletionRequest {
        CompletionRequest {
            model: "mock".into(),
            messages: vec![],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.0,
            system: None,
        }
    }

    /// Builds a plain end-of-turn response with default usage and no tool calls.
    fn plain_reply(text: &str) -> CompletionResponse {
        CompletionResponse {
            text: text.into(),
            stop_reason: StopReason::EndTurn,
            tool_calls: vec![],
            usage: TokenUsage::default(),
        }
    }

    #[tokio::test]
    async fn test_single_response() {
        let driver = MockDriver::single_response("hello world");

        let reply = driver.complete(mock_request()).await.expect("complete failed");

        assert_eq!(reply.text, "hello world");
        assert_eq!(reply.stop_reason, StopReason::EndTurn);
        assert!(reply.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_sequenced_responses() {
        let script = vec![plain_reply("first"), plain_reply("second")];
        let driver = MockDriver::new(script);
        let request = mock_request();

        let first = driver.complete(request.clone()).await.expect("first failed");
        assert_eq!(first.text, "first");

        let second = driver.complete(request).await.expect("second failed");
        assert_eq!(second.text, "second");
    }

    #[tokio::test]
    async fn test_exhausted_responses() {
        let driver = MockDriver::new(vec![]);

        let reply = driver.complete(mock_request()).await.expect("complete failed");

        assert_eq!(reply.text, "[mock exhausted]");
    }

    #[tokio::test]
    async fn test_tool_call_response() {
        let tool_args = serde_json::json!({"query": "test"});
        let driver = MockDriver::tool_then_response("rag", tool_args, "final answer");
        let request = mock_request();

        // First turn: the mock asks for the tool.
        let tool_turn = driver.complete(request.clone()).await.expect("first failed");
        assert_eq!(tool_turn.stop_reason, StopReason::ToolUse);
        assert_eq!(tool_turn.tool_calls.len(), 1);
        assert_eq!(tool_turn.tool_calls[0].name, "rag");

        // Second turn: the mock delivers the final text.
        let final_turn = driver.complete(request).await.expect("second failed");
        assert_eq!(final_turn.text, "final answer");
        assert_eq!(final_turn.stop_reason, StopReason::EndTurn);
    }

    #[test]
    fn test_context_window() {
        let default_driver = MockDriver::single_response("hi");
        assert_eq!(default_driver.context_window(), 4096);

        let resized_driver = default_driver.with_context_window(8192);
        assert_eq!(resized_driver.context_window(), 8192);
    }

    #[test]
    fn test_privacy_tier_sovereign() {
        let driver = MockDriver::single_response("hi");
        assert_eq!(driver.privacy_tier(), PrivacyTier::Sovereign);
    }
}