agcodex_core/subagents/built_in/
test_writer.rs1use crate::code_tools::ast_agent_tools::ASTAgentTools;
11use crate::code_tools::ast_agent_tools::AgentToolOp;
12use crate::code_tools::ast_agent_tools::AgentToolResult;
13use crate::modes::OperatingMode;
14use crate::subagents::AgentResult;
15use crate::subagents::AgentStatus;
16use crate::subagents::Finding;
17use crate::subagents::Severity;
18use crate::subagents::Subagent;
19use crate::subagents::SubagentContext;
20use crate::subagents::SubagentError;
21use crate::subagents::SubagentResult;
22use std::collections::HashMap;
23use std::future::Future;
24use std::path::Path;
25use std::path::PathBuf;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::sync::atomic::AtomicBool;
29use std::sync::atomic::Ordering;
30use std::time::Duration;
31use std::time::SystemTime;
32
/// Subagent that looks for functions lacking tests and drafts test scaffolding for them.
#[derive(Debug)]
pub struct TestWriterAgent {
    /// Agent identifier; `new` sets it to `"test-writer"`.
    name: String,
    /// Human-readable one-line description of the agent.
    description: String,
    /// Operating mode this agent asks for (`new` sets `Build`); not read elsewhere in this file as shown.
    _mode_override: Option<OperatingMode>,
    /// Names of the tools this agent expects to use; not read elsewhere in this file as shown.
    _tool_permissions: Vec<String>,
    /// System-prompt text for the agent; not read elsewhere in this file as shown.
    _prompt_template: String,
    /// Selected strategy; drives `execution_time_estimate` and is reported in result metrics.
    test_strategy: TestStrategy,
}
43
/// How thorough test generation is meant to be.
///
/// In this file the strategy only selects the time estimate (45s / 90s / 120s) and is
/// echoed in the `test_strategy` result metric; it does not yet change the generated code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TestStrategy {
    /// Lightest strategy (45s estimate).
    Basic,
    /// Default strategy chosen by `TestWriterAgent::new` (90s estimate).
    Comprehensive,
    /// Presumably property-based testing, per the name — TODO confirm (120s estimate).
    PropertyBased,
}
51
52impl Default for TestWriterAgent {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl TestWriterAgent {
59 pub fn new() -> Self {
61 Self {
62 name: "test-writer".to_string(),
63 description: "Generates comprehensive test suites with high coverage".to_string(),
64 _mode_override: Some(OperatingMode::Build),
65 _tool_permissions: vec![
66 "search".to_string(),
67 "edit".to_string(),
68 "tree".to_string(),
69 "think".to_string(),
70 ],
71 _prompt_template: r#"
72You are an expert test engineer focused on:
73- Achieving high code coverage (>90%)
74- Testing edge cases and error conditions
75- Creating maintainable test suites
76- Using appropriate testing patterns
77- Generating realistic test data
78
79Write tests that are:
80- Isolated and independent
81- Fast and deterministic
82- Clear and well-documented
83- Comprehensive yet focused
84"#
85 .to_string(),
86 test_strategy: TestStrategy::Comprehensive,
87 }
88 }
89
90 pub const fn with_strategy(mut self, strategy: TestStrategy) -> Self {
92 self.test_strategy = strategy;
93 self
94 }
95
96 async fn analyze_coverage(&self, ast_tools: &mut ASTAgentTools, file: &Path) -> Vec<Finding> {
98 let mut findings = Vec::new();
99
100 if let Ok(AgentToolResult::Functions(functions)) =
102 ast_tools.execute(AgentToolOp::ExtractFunctions {
103 file: file.to_path_buf(),
104 language: self.detect_language(file),
105 })
106 {
107 for func in functions {
108 let test_pattern = format!("test.*{}", func.name);
110 if let Ok(AgentToolResult::SearchResults(results)) =
111 ast_tools.execute(AgentToolOp::Search {
112 query: test_pattern,
113 scope: crate::code_tools::search::SearchScope::Directory(
114 file.parent().unwrap_or(Path::new(".")).to_path_buf(),
115 ),
116 })
117 && results.is_empty()
118 {
119 findings.push(Finding {
120 category: "test-coverage".to_string(),
121 severity: Severity::Medium,
122 title: format!("Missing Tests: {}", func.name),
123 description: format!(
124 "Function '{}' has no test coverage. This could lead to undetected bugs.",
125 func.name
126 ),
127 location: Some(crate::code_tools::ast_agent_tools::Location {
128 file: file.to_path_buf(),
129 line: func.start_line,
130 column: 0,
131 byte_offset: 0,
132 }),
133 suggestion: Some("Create unit tests covering normal, edge, and error cases".to_string()),
134 metadata: HashMap::from([
135 ("function_name".to_string(), serde_json::json!(func.name)),
136 ("needs_tests".to_string(), serde_json::json!(true)),
137 ]),
138 });
139 }
140 }
141 }
142
143 findings
144 }
145
146 async fn generate_test_cases(
148 &self,
149 _ast_tools: &mut ASTAgentTools,
150 function_name: &str,
151 file: &Path,
152 ) -> String {
153 let mut test_code = String::new();
154 let lang = self.detect_language(file);
155
156 match lang.as_str() {
158 "rust" => {
159 test_code.push_str(&format!(
160 r#"
161#[cfg(test)]
162mod test_{} {{
163 use super::*;
164
165 #[test]
166 fn test_{}_normal_case() {{
167 // Arrange
168 let input = /* TODO: Add test input */;
169
170 // Act
171 let result = {}(input);
172
173 // Assert
174 assert_eq!(result, /* expected value */);
175 }}
176
177 #[test]
178 fn test_{}_edge_case() {{
179 // Test with boundary values
180 let edge_input = /* TODO: Add edge case input */;
181 let result = {}(edge_input);
182 assert!(/* validation */);
183 }}
184
185 #[test]
186 #[should_panic(expected = "error message")]
187 fn test_{}_error_case() {{
188 // Test error handling
189 let invalid_input = /* TODO: Add invalid input */;
190 {}(invalid_input); // Should panic
191 }}
192}}"#,
193 function_name,
194 function_name,
195 function_name,
196 function_name,
197 function_name,
198 function_name,
199 function_name
200 ));
201 }
202 "python" => {
203 test_code.push_str(&format!(
204 r#"
205import unittest
206from unittest.mock import Mock, patch
207
208class Test{}(unittest.TestCase):
209
210 def test_{}_normal_case(self):
211 # Arrange
212 input_data = # TODO: Add test input
213
214 # Act
215 result = {}(input_data)
216
217 # Assert
218 self.assertEqual(result, # expected value)
219
220 def test_{}_edge_case(self):
221 # Test with boundary values
222 edge_input = # TODO: Add edge case input
223 result = {}(edge_input)
224 self.assertTrue(# validation)
225
226 def test_{}_error_case(self):
227 # Test error handling
228 invalid_input = # TODO: Add invalid input
229 with self.assertRaises(Exception):
230 {}(invalid_input)
231
232if __name__ == '__main__':
233 unittest.main()"#,
234 to_pascal_case(function_name),
235 function_name,
236 function_name,
237 function_name,
238 function_name,
239 function_name,
240 function_name
241 ));
242 }
243 "javascript" | "typescript" => {
244 test_code.push_str(&format!(
245 r#"
246describe('{}', () => {{
247
248 test('should handle normal case', () => {{
249 // Arrange
250 const input = /* TODO: Add test input */;
251
252 // Act
253 const result = {}(input);
254
255 // Assert
256 expect(result).toBe(/* expected value */);
257 }});
258
259 test('should handle edge case', () => {{
260 // Test with boundary values
261 const edgeInput = /* TODO: Add edge case input */;
262 const result = {}(edgeInput);
263 expect(result).toBeTruthy();
264 }});
265
266 test('should throw error for invalid input', () => {{
267 // Test error handling
268 const invalidInput = /* TODO: Add invalid input */;
269 expect(() => {{
270 {}(invalidInput);
271 }}).toThrow('error message');
272 }});
273}});"#,
274 function_name, function_name, function_name, function_name
275 ));
276 }
277 _ => {
278 test_code.push_str(&format!(
279 "// TODO: Generate tests for function '{}'\n",
280 function_name
281 ));
282 }
283 }
284
285 test_code
286 }
287
288 fn detect_language(&self, file: &Path) -> String {
290 file.extension()
291 .and_then(|ext| ext.to_str())
292 .map(|ext| match ext {
293 "rs" => "rust",
294 "py" => "python",
295 "js" => "javascript",
296 "ts" => "typescript",
297 "go" => "go",
298 "java" => "java",
299 _ => "text",
300 })
301 .unwrap_or("text")
302 .to_string()
303 }
304}
305
306impl Subagent for TestWriterAgent {
307 fn name(&self) -> &str {
308 &self.name
309 }
310
311 fn description(&self) -> &str {
312 &self.description
313 }
314
315 fn execute<'a>(
316 &'a self,
317 context: &'a SubagentContext,
318 ast_tools: &'a mut ASTAgentTools,
319 cancel_flag: Arc<AtomicBool>,
320 ) -> Pin<Box<dyn Future<Output = SubagentResult<AgentResult>> + Send + 'a>> {
321 Box::pin(async move {
322 let start_time = SystemTime::now();
323 let mut all_findings = Vec::new();
324 let mut analyzed_files = Vec::new();
325 let mut modified_files = Vec::new();
326 let mut tests_generated = 0;
327
328 let files = self.get_test_targets(context)?;
330
331 for file in &files {
332 if cancel_flag.load(Ordering::Acquire) {
333 return Err(SubagentError::ExecutionFailed(
334 "Test generation cancelled".to_string(),
335 ));
336 }
337
338 analyzed_files.push(file.clone());
339
340 let coverage_findings = self.analyze_coverage(ast_tools, file).await;
342
343 for finding in &coverage_findings {
345 if let Some(function_name) = finding.metadata.get("function_name")
346 && let Some(name) = function_name.as_str()
347 {
348 let test_code = self.generate_test_cases(ast_tools, name, file).await;
349
350 if context.mode == OperatingMode::Build && !test_code.is_empty() {
352 let test_file = self.get_test_file_path(file);
353 modified_files.push(test_file);
355 tests_generated += 1;
356 }
357 }
358 }
359
360 all_findings.extend(coverage_findings);
361 }
362
363 let summary = format!(
364 "Test generation completed: {} files analyzed, {} missing tests found, {} test files generated",
365 analyzed_files.len(),
366 all_findings.len(),
367 tests_generated
368 );
369
370 let missing_tests = all_findings.len();
372
373 let execution_time = SystemTime::now()
374 .duration_since(start_time)
375 .unwrap_or_else(|_| Duration::from_secs(0));
376
377 Ok(AgentResult {
378 agent_name: self.name.clone(),
379 status: AgentStatus::Completed,
380 findings: all_findings,
381 analyzed_files,
382 modified_files,
383 execution_time,
384 summary,
385 metrics: HashMap::from([
386 (
387 "missing_tests".to_string(),
388 serde_json::json!(missing_tests),
389 ),
390 (
391 "tests_generated".to_string(),
392 serde_json::json!(tests_generated),
393 ),
394 (
395 "test_strategy".to_string(),
396 serde_json::json!(format!("{:?}", self.test_strategy)),
397 ),
398 ]),
399 })
400 })
401 }
402
403 fn capabilities(&self) -> Vec<String> {
404 vec![
405 "test-generation".to_string(),
406 "coverage-analysis".to_string(),
407 "edge-case-generation".to_string(),
408 "mock-generation".to_string(),
409 "test-data-generation".to_string(),
410 ]
411 }
412
413 fn supports_file_type(&self, file_path: &Path) -> bool {
414 let supported = ["rs", "py", "js", "ts", "go", "java"];
415 file_path
416 .extension()
417 .and_then(|ext| ext.to_str())
418 .map(|ext| supported.contains(&ext))
419 .unwrap_or(false)
420 }
421
422 fn execution_time_estimate(&self) -> Duration {
423 match self.test_strategy {
424 TestStrategy::Basic => Duration::from_secs(45),
425 TestStrategy::Comprehensive => Duration::from_secs(90),
426 TestStrategy::PropertyBased => Duration::from_secs(120),
427 }
428 }
429}
430
431impl TestWriterAgent {
432 fn get_test_targets(&self, context: &SubagentContext) -> Result<Vec<PathBuf>, SubagentError> {
433 if let Some(files) = context.parameters.get("files") {
434 Ok(files.split(',').map(|s| PathBuf::from(s.trim())).collect())
435 } else {
436 Ok(vec![context.working_directory.clone()])
437 }
438 }
439
440 fn get_test_file_path(&self, source_file: &Path) -> PathBuf {
441 let stem = source_file.file_stem().unwrap_or_default();
442 let ext = source_file.extension().unwrap_or_default();
443 let parent = source_file.parent().unwrap_or(Path::new("."));
444
445 parent.join(format!(
446 "{}_test.{}",
447 stem.to_string_lossy(),
448 ext.to_string_lossy()
449 ))
450 }
451}
452
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Each `_`-separated segment has its first character upper-cased (which may expand
/// into several characters, e.g. `ß` -> `SS`); the rest of the segment is kept as-is.
/// Empty segments, from leading, trailing or doubled underscores, contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}