1use crate::agent::backend::LlmBackend;
20use crate::agent::{LLMResponse, Message, TokenCallback, TokenUsage, ToolCallRequest};
21use crate::tools::ToolDefinition;
22use crate::Result;
23use async_trait::async_trait;
24#[allow(unused_imports)]
25use serde_json::json;
26use serde_json::Value;
27use std::sync::atomic::{AtomicUsize, Ordering};
28use std::sync::Arc;
29
/// One scripted reply for [`MockBackend`]; the backend hands these out in
/// order, one per `generate` call.
#[derive(Clone, Debug)]
pub enum MockResponse {
    /// Plain assistant text, no tool calls, no usage report.
    Text(String),
    /// Plain assistant text together with a token-usage report.
    TextWithUsage { text: String, usage: TokenUsage },
    /// Exactly one tool call in the turn.
    ToolCall {
        /// Identifier copied into the resulting `ToolCallRequest::id`.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// JSON arguments copied into `ToolCallRequest::arguments`.
        args: Value,
    },
    /// Several tool calls emitted together in a single turn.
    ToolSequence(Vec<ToolCallRequest>),
}
46
47impl MockResponse {
48 pub fn text(s: impl Into<String>) -> Self {
49 Self::Text(s.into())
50 }
51
52 pub fn tool_call(name: impl Into<String>, args: Value) -> Self {
53 Self::ToolCall {
54 id: uuid::Uuid::new_v4().to_string(),
55 name: name.into(),
56 args,
57 }
58 }
59
60 pub fn tool_sequence(calls: Vec<(&str, Value)>) -> Self {
61 Self::ToolSequence(
62 calls
63 .into_iter()
64 .map(|(name, args)| ToolCallRequest {
65 id: uuid::Uuid::new_v4().to_string(),
66 name: name.to_string(),
67 arguments: args,
68 })
69 .collect(),
70 )
71 }
72}
73
/// Named, pre-scripted mock conversations.
///
/// A scenario is selected by a `PARITY_SCENARIO: <name>` marker in a message
/// (see [`MockScenario::detect`]) and expands to a fixed reply script via
/// [`MockScenario::responses`].
#[derive(Debug, Clone, PartialEq)]
pub enum MockScenario {
    /// `text_only`: a single plain-text reply.
    TextOnly,
    /// `read_file_roundtrip`: one `read_file` call, then a text reply.
    ReadFileRoundtrip,
    /// `bash_roundtrip`: one `bash` call, then a text reply.
    BashRoundtrip,
    /// `multi_tool_turn`: `read_file` and `grep_search` in the same turn, then a text reply.
    MultiToolTurn,
    /// `edit_roundtrip`: a `read_file` call, an `edit_file` call, then a text reply.
    EditRoundtrip,
}
93
94impl MockScenario {
95 pub fn detect(messages: &[Message]) -> Option<Self> {
97 for msg in messages {
98 if let Some(pos) = msg.content.find("PARITY_SCENARIO:") {
99 let rest = msg.content[pos + 16..].trim();
100 let name = rest.split_whitespace().next().unwrap_or("");
101 return match name {
102 "text_only" => Some(Self::TextOnly),
103 "read_file_roundtrip" => Some(Self::ReadFileRoundtrip),
104 "bash_roundtrip" => Some(Self::BashRoundtrip),
105 "multi_tool_turn" => Some(Self::MultiToolTurn),
106 "edit_roundtrip" => Some(Self::EditRoundtrip),
107 _ => None,
108 };
109 }
110 }
111 None
112 }
113
114 pub fn responses(&self) -> Vec<MockResponse> {
116 match self {
117 Self::TextOnly => vec![MockResponse::text("Scenario complete: text only")],
118 Self::ReadFileRoundtrip => vec![
119 MockResponse::tool_call("read_file", serde_json::json!({"path": "src/lib.rs"})),
120 MockResponse::text("I read the file successfully."),
121 ],
122 Self::BashRoundtrip => vec![
123 MockResponse::tool_call("bash", serde_json::json!({"command": "echo hello"})),
124 MockResponse::text("Command executed successfully."),
125 ],
126 Self::MultiToolTurn => vec![
127 MockResponse::tool_sequence(vec![
128 ("read_file", serde_json::json!({"path": "Cargo.toml"})),
129 ("grep_search", serde_json::json!({"pattern": "version"})),
130 ]),
131 MockResponse::text("Found version info in both files."),
132 ],
133 Self::EditRoundtrip => vec![
134 MockResponse::tool_call("read_file", serde_json::json!({"path": "test.rs"})),
135 MockResponse::tool_call(
136 "edit_file",
137 serde_json::json!({
138 "path": "test.rs",
139 "old_string": "old",
140 "new_string": "new"
141 }),
142 ),
143 MockResponse::text("Edit complete."),
144 ],
145 }
146 }
147}
148
149pub fn mock_from_scenario(scenario: MockScenario) -> MockBackend {
151 MockBackend::new(scenario.responses())
152}
153
/// Scripted [`LlmBackend`] that replays a fixed list of [`MockResponse`]s.
///
/// Each `generate` call returns the next scripted reply in order. Once the
/// script is exhausted, it returns an empty text reply.
pub struct MockBackend {
    /// The scripted replies, in replay order.
    responses: Arc<Vec<MockResponse>>,
    /// Position of the next reply to return. It is atomic so that `generate`
    /// can advance it through `&self`.
    // NOTE(review): both fields are `Arc`-wrapped, presumably so that clones
    // would share the script and its cursor, but no `Clone` impl is visible
    // here — confirm the intent.
    index: Arc<AtomicUsize>,
}
162
163impl MockBackend {
164 pub fn new(responses: Vec<MockResponse>) -> Self {
165 Self {
166 responses: Arc::new(responses),
167 index: Arc::new(AtomicUsize::new(0)),
168 }
169 }
170
171 pub fn with_text(text: impl Into<String>) -> Self {
173 Self::new(vec![MockResponse::text(text)])
174 }
175
176 pub fn with_tool_call(id: &str, name: &str, args: Value, content: &str) -> Self {
178 Self::new(vec![
179 MockResponse::ToolCall {
180 id: id.to_string(),
181 name: name.to_string(),
182 args,
183 },
184 MockResponse::Text(content.to_string()),
185 ])
186 }
187
188 pub fn with_repeated_tool_call(name: &str) -> Self {
190 let mut responses = Vec::new();
191 for i in 0..32 {
192 responses.push(MockResponse::ToolCall {
193 id: format!("call_{i}"),
194 name: name.to_string(),
195 args: json!({}),
196 });
197 }
198 Self::new(responses)
199 }
200
201 pub fn with_multiple_tool_calls(calls: Vec<(&str, &str, Value)>) -> Self {
203 let tool_calls: Vec<ToolCallRequest> = calls
204 .into_iter()
205 .map(|(id, name, args)| ToolCallRequest {
206 id: id.to_string(),
207 name: name.to_string(),
208 arguments: args,
209 })
210 .collect();
211 Self::new(vec![
212 MockResponse::ToolSequence(tool_calls),
213 MockResponse::Text("Done".to_string()),
214 ])
215 }
216
217 pub fn with_text_and_usage(text: &str, prompt_tokens: u64, completion_tokens: u64) -> Self {
219 let reasoning_tokens = completion_tokens / 3;
220 let action_tokens = completion_tokens - reasoning_tokens;
221 let usage = TokenUsage {
222 prompt_tokens,
223 completion_tokens,
224 total_tokens: prompt_tokens + completion_tokens,
225 reasoning_tokens,
226 action_tokens,
227 };
228 Self::new(vec![MockResponse::TextWithUsage {
229 text: text.to_string(),
230 usage,
231 }])
232 }
233}
234
235#[async_trait]
236impl LlmBackend for MockBackend {
237 async fn generate(
238 &self,
239 _messages: &[Message],
240 _tools: &[ToolDefinition],
241 _on_token: Option<&TokenCallback>,
242 ) -> Result<LLMResponse> {
243 let idx = self.index.fetch_add(1, Ordering::SeqCst);
244
245 let response = self.responses.get(idx).cloned().unwrap_or_else(|| {
246 MockResponse::Text(String::new())
248 });
249
250 Ok(match response {
251 MockResponse::Text(content) => LLMResponse {
252 content,
253 reasoning: None,
254 tool_calls: vec![],
255 finish_reason: "stop".to_string(),
256 usage: None,
257 },
258 MockResponse::TextWithUsage { text, usage } => LLMResponse {
259 content: text,
260 reasoning: None,
261 tool_calls: vec![],
262 finish_reason: "stop".to_string(),
263 usage: Some(usage),
264 },
265 MockResponse::ToolCall { id, name, args } => LLMResponse {
266 content: String::new(),
267 reasoning: None,
268 tool_calls: vec![ToolCallRequest {
269 id,
270 name,
271 arguments: args,
272 }],
273 finish_reason: "tool_calls".to_string(),
274 usage: None,
275 },
276 MockResponse::ToolSequence(calls) => LLMResponse {
277 content: String::new(),
278 reasoning: None,
279 tool_calls: calls,
280 finish_reason: "tool_calls".to_string(),
281 usage: None,
282 },
283 })
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a user message with `content` and no tool activity.
    fn user_msg(content: &'static str) -> Message {
        Message {
            role: crate::agent::Role::User,
            content: content.into(),
            tool_calls: vec![],
            tool_result: None,
        }
    }

    #[test]
    fn test_scenario_detect_text_only() {
        let messages = vec![user_msg("PARITY_SCENARIO: text_only\nDo something")];
        assert_eq!(
            MockScenario::detect(&messages),
            Some(MockScenario::TextOnly)
        );
    }

    #[test]
    fn test_scenario_detect_read_file() {
        let messages = vec![user_msg("Please PARITY_SCENARIO: read_file_roundtrip")];
        assert_eq!(
            MockScenario::detect(&messages),
            Some(MockScenario::ReadFileRoundtrip)
        );
    }

    #[test]
    fn test_scenario_detect_without_space_after_marker() {
        let messages = vec![user_msg("PARITY_SCENARIO:edit_roundtrip")];
        assert_eq!(
            MockScenario::detect(&messages),
            Some(MockScenario::EditRoundtrip)
        );
    }

    #[test]
    fn test_scenario_detect_in_later_message() {
        let messages = vec![
            user_msg("Just a normal message"),
            user_msg("PARITY_SCENARIO: bash_roundtrip"),
        ];
        assert_eq!(
            MockScenario::detect(&messages),
            Some(MockScenario::BashRoundtrip)
        );
    }

    #[test]
    fn test_scenario_detect_none() {
        let messages = vec![user_msg("Just a normal message")];
        assert_eq!(MockScenario::detect(&messages), None);
    }

    #[test]
    fn test_scenario_detect_empty_conversation() {
        assert_eq!(MockScenario::detect(&[]), None);
    }

    #[test]
    fn test_scenario_detect_marker_without_name() {
        let messages = vec![user_msg("PARITY_SCENARIO:   ")];
        assert_eq!(MockScenario::detect(&messages), None);
    }

    #[test]
    fn test_scenario_detect_unknown() {
        let messages = vec![user_msg("PARITY_SCENARIO: nonexistent_scenario")];
        assert_eq!(MockScenario::detect(&messages), None);
    }

    #[test]
    fn test_scenario_responses_text_only() {
        let responses = MockScenario::TextOnly.responses();
        assert_eq!(responses.len(), 1);
        assert!(matches!(&responses[0], MockResponse::Text(_)));
    }

    #[test]
    fn test_scenario_responses_read_file() {
        let responses = MockScenario::ReadFileRoundtrip.responses();
        assert_eq!(responses.len(), 2);
        assert!(
            matches!(&responses[0], MockResponse::ToolCall { name, .. } if name == "read_file")
        );
        assert!(matches!(&responses[1], MockResponse::Text(_)));
    }

    #[test]
    fn test_scenario_responses_bash() {
        let responses = MockScenario::BashRoundtrip.responses();
        assert_eq!(responses.len(), 2);
        assert!(matches!(&responses[0], MockResponse::ToolCall { name, .. } if name == "bash"));
        assert!(matches!(&responses[1], MockResponse::Text(_)));
    }

    #[test]
    fn test_scenario_responses_multi_tool() {
        let responses = MockScenario::MultiToolTurn.responses();
        assert_eq!(responses.len(), 2);
        assert!(matches!(&responses[0], MockResponse::ToolSequence(calls) if calls.len() == 2));
    }

    #[test]
    fn test_scenario_responses_edit_roundtrip() {
        let responses = MockScenario::EditRoundtrip.responses();
        assert_eq!(responses.len(), 3);
        assert!(
            matches!(&responses[0], MockResponse::ToolCall { name, .. } if name == "read_file")
        );
        assert!(
            matches!(&responses[1], MockResponse::ToolCall { name, .. } if name == "edit_file")
        );
        assert!(matches!(&responses[2], MockResponse::Text(_)));
    }

    #[test]
    fn test_mock_from_scenario() {
        let backend = mock_from_scenario(MockScenario::TextOnly);
        assert_eq!(backend.responses.len(), 1);
    }

    #[test]
    fn test_tool_sequence_constructor() {
        let resp = MockResponse::tool_sequence(vec![
            ("read_file", serde_json::json!({"path": "a.rs"})),
            ("bash", serde_json::json!({"command": "ls"})),
        ]);
        if let MockResponse::ToolSequence(calls) = resp {
            assert_eq!(calls.len(), 2);
            assert_eq!(calls[0].name, "read_file");
            assert_eq!(calls[1].name, "bash");
            // Each call gets its own generated id.
            assert_ne!(calls[0].id, calls[1].id);
        } else {
            panic!("Expected ToolSequence");
        }
    }

    #[tokio::test]
    async fn test_mock_backend_tool_sequence() {
        let backend = MockBackend::new(vec![
            MockResponse::tool_sequence(vec![
                ("read_file", serde_json::json!({"path": "a.rs"})),
                ("grep_search", serde_json::json!({"pattern": "fn"})),
            ]),
            MockResponse::text("Done"),
        ]);

        let resp = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(resp.tool_calls.len(), 2);
        assert_eq!(resp.finish_reason, "tool_calls");

        let resp2 = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(resp2.content, "Done");
        assert!(resp2.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_mock_backend_with_tool_call() {
        let backend = MockBackend::with_tool_call(
            "call_1",
            "read_file",
            serde_json::json!({"path": "a.rs"}),
            "ok",
        );

        let first = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(first.finish_reason, "tool_calls");
        assert!(first.content.is_empty());
        assert_eq!(first.tool_calls.len(), 1);
        assert_eq!(first.tool_calls[0].id, "call_1");
        assert_eq!(first.tool_calls[0].name, "read_file");
        assert_eq!(
            first.tool_calls[0].arguments,
            serde_json::json!({"path": "a.rs"})
        );

        let second = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(second.content, "ok");
        assert_eq!(second.finish_reason, "stop");
        assert!(second.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_mock_backend_multiple_tool_calls_keep_ids() {
        let backend = MockBackend::with_multiple_tool_calls(vec![
            ("id_a", "read_file", serde_json::json!({"path": "a.rs"})),
            ("id_b", "bash", serde_json::json!({"command": "ls"})),
        ]);

        let resp = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(resp.finish_reason, "tool_calls");
        let ids: Vec<&str> = resp.tool_calls.iter().map(|c| c.id.as_str()).collect();
        assert_eq!(ids, ["id_a", "id_b"]);

        let done = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(done.content, "Done");
    }

    #[tokio::test]
    async fn test_mock_backend_repeated_tool_call() {
        let backend = MockBackend::with_repeated_tool_call("bash");
        assert_eq!(backend.responses.len(), 32);

        let first = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(first.tool_calls.len(), 1);
        assert_eq!(first.tool_calls[0].id, "call_0");
        assert_eq!(first.tool_calls[0].name, "bash");

        let second = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(second.tool_calls[0].id, "call_1");
    }

    #[tokio::test]
    async fn test_mock_backend_text_and_usage() {
        let backend = MockBackend::with_text_and_usage("hi", 100, 30);
        let resp = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(resp.content, "hi");
        assert_eq!(resp.finish_reason, "stop");

        let usage = resp.usage.as_ref().expect("usage should be reported");
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 30);
        assert_eq!(usage.total_tokens, 130);
        assert_eq!(usage.reasoning_tokens, 10);
        assert_eq!(usage.action_tokens, 20);
    }

    #[tokio::test]
    async fn test_mock_backend_exhausted() {
        let backend = MockBackend::new(vec![MockResponse::text("first")]);
        let r1 = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(r1.content, "first");

        let r2 = backend.generate(&[], &[], None).await.unwrap();
        assert_eq!(r2.content, "");
        assert_eq!(r2.finish_reason, "stop");
        assert!(r2.tool_calls.is_empty());
    }
}