1use chrono::Utc;
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12use uuid::Uuid;
13
14use super::boundary_checker::{create_boundary_checker, BoundaryChecker};
15use super::types::{AcceptanceTest, ArtifactType, Blueprint, TaskNode, TddPhase, TestResult};
16
/// Test framework used to run the tests a worker generates or is given.
///
/// Determines the shell command used to run a test file (see
/// [`TestFramework::get_test_command`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TestFramework {
    /// Rust's built-in harness, driven through `cargo test` (the default).
    #[default]
    Cargo,
    /// Vitest, invoked via `npx vitest run`.
    Vitest,
    /// Jest, invoked via `npx jest`.
    Jest,
    /// Mocha, invoked via `npx mocha`.
    Mocha,
    /// pytest, invoked directly as `pytest`.
    Pytest,
}

impl TestFramework {
    /// Builds the shell command that runs the tests in `test_file`.
    ///
    /// `test_file` is interpolated verbatim (no shell quoting), so paths
    /// containing spaces or shell metacharacters are not supported.
    ///
    /// For [`TestFramework::Cargo`], an `.rs` file directly inside a `tests`
    /// directory is an integration-test target and is run with
    /// `cargo test --test <file stem>`. Cargo's trailing argument is a test-name
    /// filter, not a path, so passing such a file after `--lib --` would match
    /// no tests and report success vacuously. Any other argument is used as a
    /// name filter over the library's unit tests.
    pub fn get_test_command(&self, test_file: &str) -> String {
        match self {
            Self::Cargo => match Self::cargo_integration_test_name(test_file) {
                Some(target) => format!("cargo test --test {}", target),
                None => format!("cargo test --lib -- {}", test_file),
            },
            Self::Vitest => format!("npx vitest run {}", test_file),
            Self::Jest => format!("npx jest {}", test_file),
            Self::Mocha => format!("npx mocha {}", test_file),
            Self::Pytest => format!("pytest {}", test_file),
        }
    }

    /// Returns the Cargo integration-test target name (the file stem) when
    /// `test_file` is an `.rs` file directly inside a directory named `tests`.
    fn cargo_integration_test_name(test_file: &str) -> Option<&str> {
        let path = Path::new(test_file);
        if path.extension()? != "rs" {
            return None;
        }
        // Only `tests/<name>.rs` maps onto a target of the same name; files in
        // nested directories are left to the name-filter fallback.
        if path.parent()?.file_name()? != "tests" {
            return None;
        }
        path.file_stem()?.to_str()
    }
}
48}
49
/// Configuration for a `WorkerExecutor`.
#[derive(Debug, Clone)]
pub struct WorkerExecutorConfig {
    /// Model identifier.
    /// NOTE(review): not read by the executor code visible here — presumably
    /// consumed once the stub generators are backed by a real model.
    pub model: String,
    /// Maximum tokens per model request.
    /// NOTE(review): not read by the executor code visible here.
    pub max_tokens: u32,
    /// Sampling temperature for model requests.
    /// NOTE(review): not read by the executor code visible here.
    pub temperature: f32,
    /// Directory against which relative file paths are resolved when files are saved.
    pub project_root: PathBuf,
    /// Framework used to build test commands and default test file locations.
    pub test_framework: TestFramework,
    /// Test timeout; the default is 60000, presumably milliseconds (TODO confirm).
    /// NOTE(review): the current `run_test` stub does not consult it.
    pub test_timeout: u64,
    /// When true, executor log messages are printed to stdout.
    pub debug: bool,
}
68
69impl Default for WorkerExecutorConfig {
70 fn default() -> Self {
71 Self {
72 model: "claude-3-haiku".to_string(),
73 max_tokens: 8000,
74 temperature: 0.3,
75 project_root: std::env::current_dir().unwrap_or_default(),
76 test_framework: TestFramework::default(),
77 test_timeout: 60000,
78 debug: false,
79 }
80 }
81}
82
/// A piece of existing source code attached to an `ExecutionContext`.
#[derive(Debug, Clone)]
pub struct CodeSnippet {
    /// Path of the file the snippet comes from.
    pub file_path: String,
    /// The snippet's text.
    pub content: String,
}
93
/// Inputs for executing one TDD phase of a task.
#[derive(Debug, Clone)]
pub struct ExecutionContext {
    /// The task being worked on.
    pub task: TaskNode,
    /// Optional free-form description of the surrounding project.
    pub project_context: Option<String>,
    /// Existing code relevant to the task.
    pub code_snippets: Vec<CodeSnippet>,
    /// Error from the previous test run; the code-writing phase includes it in
    /// its prompt so the next attempt can address it.
    pub last_error: Option<String>,
    /// Test code the implementation must satisfy (treated as empty when `None`).
    pub test_code: Option<String>,
    /// Acceptance tests to run; when non-empty, the run phases execute these
    /// instead of the task's test-spec file.
    pub acceptance_tests: Vec<AcceptanceTest>,
}
110
111impl ExecutionContext {
112 pub fn new(task: TaskNode) -> Self {
114 Self {
115 task,
116 project_context: None,
117 code_snippets: Vec::new(),
118 last_error: None,
119 test_code: None,
120 acceptance_tests: Vec::new(),
121 }
122 }
123}
124
/// A file produced (or read back) by a phase: its path and full content.
#[derive(Debug, Clone)]
pub struct CodeArtifactOutput {
    /// File path. When saved, relative paths are resolved against the project
    /// root and absolute paths are used as-is.
    pub file_path: String,
    /// Complete file content.
    pub content: String,
}
135
/// Outcome of executing one TDD phase.
#[derive(Debug, Clone)]
pub struct PhaseResult {
    /// Whether the phase itself completed (not whether tests passed).
    pub success: bool,
    /// Named values reported by the phase (e.g. "message", "test_file_path").
    pub data: HashMap<String, serde_json::Value>,
    /// Failure description; set by `PhaseResult::failure`.
    pub error: Option<String>,
    /// Files written during the phase.
    pub artifacts: Vec<CodeArtifactOutput>,
    /// Test outcome, set by the phases that run tests.
    pub test_result: Option<TestResult>,
}
150
151impl PhaseResult {
152 pub fn success() -> Self {
154 Self {
155 success: true,
156 data: HashMap::new(),
157 error: None,
158 artifacts: Vec::new(),
159 test_result: None,
160 }
161 }
162
163 pub fn failure(error: impl Into<String>) -> Self {
165 Self {
166 success: false,
167 data: HashMap::new(),
168 error: Some(error.into()),
169 artifacts: Vec::new(),
170 test_result: None,
171 }
172 }
173
174 pub fn with_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
176 self.data.insert(key.into(), value);
177 self
178 }
179
180 pub fn with_artifact(mut self, file_path: String, content: String) -> Self {
182 self.artifacts
183 .push(CodeArtifactOutput { file_path, content });
184 self
185 }
186
187 pub fn with_test_result(mut self, result: TestResult) -> Self {
189 self.test_result = Some(result);
190 self
191 }
192}
193
/// Executes the phases of a TDD cycle for a single task.
///
/// The phases are: write test, run (red), write code, run (green), refactor.
/// Generated files are saved with relative paths resolved against the
/// configured project root.
pub struct WorkerExecutor {
    /// Executor settings.
    config: WorkerExecutorConfig,
    /// Optional checker consulted before every file write; installed by `set_blueprint`.
    boundary_checker: Option<BoundaryChecker>,
    /// Module id passed to the boundary checker; set by `set_current_task_module`.
    current_task_module_id: Option<String>,
}
206
207impl WorkerExecutor {
208 pub fn new(config: WorkerExecutorConfig) -> Self {
210 Self {
211 config,
212 boundary_checker: None,
213 current_task_module_id: None,
214 }
215 }
216
217 pub fn set_blueprint(&mut self, blueprint: &Blueprint) {
219 self.boundary_checker = Some(create_boundary_checker(blueprint.clone(), None));
220 }
221
222 pub fn set_current_task_module(&mut self, module_id: Option<String>) {
224 self.current_task_module_id = module_id;
225 }
226
227 pub async fn execute_phase(&self, phase: TddPhase, context: &ExecutionContext) -> PhaseResult {
233 self.log(&format!("[Worker] 执行阶段: {:?}", phase));
234
235 match phase {
236 TddPhase::WriteTest => self.execute_write_test(context).await,
237 TddPhase::RunTestRed => self.execute_run_test_red(context).await,
238 TddPhase::WriteCode => self.execute_write_code(context).await,
239 TddPhase::RunTestGreen => self.execute_run_test_green(context).await,
240 TddPhase::Refactor => self.execute_refactor(context).await,
241 TddPhase::Done => PhaseResult::success()
242 .with_data("message".to_string(), serde_json::json!("TDD 循环完成")),
243 }
244 }
245
246 async fn execute_write_test(&self, context: &ExecutionContext) -> PhaseResult {
251 let task = &context.task;
252
253 if !task.acceptance_tests.is_empty() {
255 self.log("[Worker] 任务已有验收测试,跳过测试编写阶段");
256 return PhaseResult::success()
257 .with_data(
258 "message".to_string(),
259 serde_json::json!("任务已有验收测试,无需编写额外测试"),
260 )
261 .with_data(
262 "acceptance_test_count".to_string(),
263 serde_json::json!(task.acceptance_tests.len()),
264 );
265 }
266
267 let test_code = self.generate_test(task).await;
269
270 let test_file_path = self.determine_test_file_path(task);
272
273 if let Err(e) = self.save_file(&test_file_path, &test_code).await {
275 return PhaseResult::failure(format!("保存测试文件失败: {}", e));
276 }
277
278 let test_command = self.config.test_framework.get_test_command(&test_file_path);
279
280 PhaseResult::success()
281 .with_data("test_code".to_string(), serde_json::json!(test_code))
282 .with_data(
283 "test_file_path".to_string(),
284 serde_json::json!(test_file_path),
285 )
286 .with_data("test_command".to_string(), serde_json::json!(test_command))
287 .with_artifact(test_file_path, test_code)
288 }
289
290 async fn generate_test(&self, task: &TaskNode) -> String {
292 let _prompt = self.build_test_prompt(task);
293
294 format!(
297 r#"// 自动生成的测试代码
298// 任务: {}
299// 描述: {}
300
301#[cfg(test)]
302mod tests {{
303 use super::*;
304
305 #[test]
306 fn test_placeholder() {{
307 // TODO: 实现测试
308 assert!(true);
309 }}
310}}
311"#,
312 task.name, task.description
313 )
314 }
315
316 async fn execute_run_test_red(&self, context: &ExecutionContext) -> PhaseResult {
321 let task = &context.task;
322
323 if !context.acceptance_tests.is_empty() {
325 let mut results = Vec::new();
326
327 for test in &context.acceptance_tests {
328 let result = self.run_test(&test.test_file_path).await;
329 results.push(result);
330 }
331
332 let all_failed = results.iter().all(|r| !r.passed);
334
335 return PhaseResult::success()
336 .with_data("expected_to_fail".to_string(), serde_json::json!(true))
337 .with_data("actually_failed".to_string(), serde_json::json!(all_failed))
338 .with_test_result(results.into_iter().next().unwrap_or_else(|| TestResult {
339 id: Uuid::new_v4().to_string(),
340 timestamp: Utc::now(),
341 passed: false,
342 duration: 0,
343 output: String::new(),
344 error_message: None,
345 coverage: None,
346 details: None,
347 }));
348 }
349
350 if let Some(ref test_spec) = task.test_spec {
352 if let Some(ref test_file_path) = test_spec.test_file_path {
353 let result = self.run_test(test_file_path).await;
354
355 return PhaseResult::success()
356 .with_data("expected_to_fail".to_string(), serde_json::json!(true))
357 .with_data(
358 "actually_failed".to_string(),
359 serde_json::json!(!result.passed),
360 )
361 .with_test_result(result);
362 }
363 }
364
365 PhaseResult::failure("没有找到可运行的测试")
366 }
367
368 async fn execute_write_code(&self, context: &ExecutionContext) -> PhaseResult {
373 let task = &context.task;
374 let test_code = context.test_code.as_deref().unwrap_or("");
375 let last_error = context.last_error.as_deref();
376
377 let code_artifacts = self.generate_code(task, test_code, last_error).await;
379
380 let mut result = PhaseResult::success().with_data(
382 "file_count".to_string(),
383 serde_json::json!(code_artifacts.len()),
384 );
385
386 for artifact in code_artifacts {
387 if let Err(e) = self.save_file(&artifact.file_path, &artifact.content).await {
388 return PhaseResult::failure(format!("保存代码文件失败: {}", e));
389 }
390 result = result.with_artifact(artifact.file_path, artifact.content);
391 }
392
393 result
394 }
395
396 async fn generate_code(
398 &self,
399 task: &TaskNode,
400 test_code: &str,
401 last_error: Option<&str>,
402 ) -> Vec<CodeArtifactOutput> {
403 let _prompt = self.build_code_prompt(task, test_code, last_error);
404
405 vec![CodeArtifactOutput {
408 file_path: format!("src/{}.rs", task.id),
409 content: format!(
410 r#"//! 自动生成的实现代码
411//! 任务: {}
412//! 描述: {}
413
414pub fn placeholder() {{
415 // TODO: 实现功能
416}}
417"#,
418 task.name, task.description
419 ),
420 }]
421 }
422
423 async fn execute_run_test_green(&self, context: &ExecutionContext) -> PhaseResult {
428 let task = &context.task;
429
430 if !context.acceptance_tests.is_empty() {
432 let mut results = Vec::new();
433 let mut total_duration = 0u64;
434 let mut all_output = String::new();
435
436 for test in &context.acceptance_tests {
437 let result = self.run_test(&test.test_file_path).await;
438 total_duration += result.duration;
439 all_output.push_str(&result.output);
440 all_output.push_str("\n\n");
441 results.push(result);
442 }
443
444 let all_passed = results.iter().all(|r| r.passed);
445 let error_message = if all_passed {
446 None
447 } else {
448 Some(
449 results
450 .iter()
451 .filter(|r| !r.passed)
452 .filter_map(|r| r.error_message.clone())
453 .collect::<Vec<_>>()
454 .join("\n"),
455 )
456 };
457
458 return PhaseResult::success()
459 .with_data("expected_to_pass".to_string(), serde_json::json!(true))
460 .with_data("actually_passed".to_string(), serde_json::json!(all_passed))
461 .with_test_result(TestResult {
462 id: Uuid::new_v4().to_string(),
463 timestamp: Utc::now(),
464 passed: all_passed,
465 duration: total_duration,
466 output: all_output,
467 error_message,
468 coverage: None,
469 details: None,
470 });
471 }
472
473 if let Some(ref test_spec) = task.test_spec {
475 if let Some(ref test_file_path) = test_spec.test_file_path {
476 let result = self.run_test(test_file_path).await;
477
478 return PhaseResult::success()
479 .with_data("expected_to_pass".to_string(), serde_json::json!(true))
480 .with_data(
481 "actually_passed".to_string(),
482 serde_json::json!(result.passed),
483 )
484 .with_test_result(result);
485 }
486 }
487
488 PhaseResult::failure("没有找到可运行的测试")
489 }
490
491 async fn execute_refactor(&self, context: &ExecutionContext) -> PhaseResult {
496 let task = &context.task;
497
498 let current_code = self.read_task_code(task);
500
501 if current_code.is_empty() {
502 return PhaseResult::success().with_data(
503 "message".to_string(),
504 serde_json::json!("没有需要重构的代码"),
505 );
506 }
507
508 let refactored_artifacts = self.refactor_code(task, ¤t_code).await;
510
511 let mut result = PhaseResult::success().with_data(
513 "file_count".to_string(),
514 serde_json::json!(refactored_artifacts.len()),
515 );
516
517 for artifact in refactored_artifacts {
518 if let Err(e) = self.save_file(&artifact.file_path, &artifact.content).await {
519 return PhaseResult::failure(format!("保存重构代码失败: {}", e));
520 }
521 result = result.with_artifact(artifact.file_path, artifact.content);
522 }
523
524 result
525 }
526
527 async fn refactor_code(
529 &self,
530 task: &TaskNode,
531 current_code: &[CodeArtifactOutput],
532 ) -> Vec<CodeArtifactOutput> {
533 let _prompt = self.build_refactor_prompt(task, current_code);
534
535 current_code.to_vec()
538 }
539
540 async fn run_test(&self, test_file_path: &str) -> TestResult {
546 let start_time = std::time::Instant::now();
547 let command = self.config.test_framework.get_test_command(test_file_path);
548
549 let duration = start_time.elapsed().as_millis() as u64;
552
553 TestResult {
554 id: Uuid::new_v4().to_string(),
555 timestamp: Utc::now(),
556 passed: true, duration,
558 output: format!("运行测试: {}\n测试通过", command),
559 error_message: None,
560 coverage: None,
561 details: None,
562 }
563 }
564
565 fn build_test_prompt(&self, task: &TaskNode) -> String {
571 format!(
572 r#"# 任务:编写测试用例
573
574## 任务描述
575{}
576
577{}
578
579## 要求
5801. 使用 {:?} 测试框架
5812. 测试应该覆盖主要功能和边界情况
5823. 测试应该失败(因为还没有实现代码)
5834. 使用清晰的测试描述和断言
584
585## 输出格式
586请输出完整的测试代码,使用代码块包裹。
587只输出测试代码,不要包含其他说明文字。"#,
588 task.name, task.description, self.config.test_framework
589 )
590 }
591
592 fn build_code_prompt(
594 &self,
595 task: &TaskNode,
596 test_code: &str,
597 last_error: Option<&str>,
598 ) -> String {
599 let mut prompt = format!(
600 r#"# 任务:编写实现代码
601
602## 任务描述
603{}
604
605{}
606
607## 测试代码
608```
609{}
610```
611"#,
612 task.name, task.description, test_code
613 );
614
615 if let Some(error) = last_error {
616 prompt.push_str(&format!(
617 r#"
618## 上次测试错误
619```
620{}
621```
622
623请修复上述错误。
624"#,
625 error
626 ));
627 }
628
629 prompt.push_str(
630 r#"
631## 要求
6321. 编写最小可行代码使测试通过
6332. 不要过度设计
6343. 专注于当前测试
6354. 遵循项目代码风格
636
637## 输出格式
638请为每个文件输出代码,使用如下格式:
639
640### 文件:src/example.rs
641```rust
642// 代码内容
643```
644
645只输出代码文件,不要包含其他说明文字。"#,
646 );
647
648 prompt
649 }
650
651 fn build_refactor_prompt(
653 &self,
654 task: &TaskNode,
655 current_code: &[CodeArtifactOutput],
656 ) -> String {
657 let mut prompt = format!(
658 r#"# 任务:重构代码
659
660## 任务描述
661{}
662
663## 当前代码
664"#,
665 task.name
666 );
667
668 for file in current_code {
669 prompt.push_str(&format!(
670 r#"
671### 文件:{}
672```rust
673{}
674```
675"#,
676 file.file_path, file.content
677 ));
678 }
679
680 prompt.push_str(
681 r#"
682## 重构建议
6831. 消除重复代码
6842. 改善命名
6853. 简化逻辑
6864. 提高可读性
6875. 确保测试仍然通过
688
689## 输出格式
690请为每个需要修改的文件输出重构后的代码。
691如果某个文件不需要重构,不用输出。
692只输出代码文件,不要包含其他说明文字。"#,
693 );
694
695 prompt
696 }
697
698 fn determine_test_file_path(&self, task: &TaskNode) -> String {
704 if let Some(ref test_spec) = task.test_spec {
706 if let Some(ref path) = test_spec.test_file_path {
707 return path.clone();
708 }
709 }
710
711 match self.config.test_framework {
713 TestFramework::Cargo => format!("tests/{}_test.rs", task.id),
714 TestFramework::Vitest | TestFramework::Jest => {
715 format!("__tests__/{}.test.ts", task.id)
716 }
717 TestFramework::Mocha => format!("test/{}.test.js", task.id),
718 TestFramework::Pytest => format!("tests/test_{}.py", task.id),
719 }
720 }
721
722 fn read_task_code(&self, task: &TaskNode) -> Vec<CodeArtifactOutput> {
724 task.code_artifacts
725 .iter()
726 .filter_map(|artifact| {
727 if artifact.artifact_type == ArtifactType::File {
728 Some(CodeArtifactOutput {
729 file_path: artifact.file_path.clone().unwrap_or_default(),
730 content: artifact.content.clone().unwrap_or_default(),
731 })
732 } else {
733 None
734 }
735 })
736 .collect()
737 }
738
739 async fn save_file(&self, file_path: &str, content: &str) -> Result<(), String> {
741 let full_path = if Path::new(file_path).is_absolute() {
742 PathBuf::from(file_path)
743 } else {
744 self.config.project_root.join(file_path)
745 };
746
747 if let Some(ref checker) = self.boundary_checker {
749 let result = checker.check_task_boundary(
750 self.current_task_module_id.as_deref(),
751 full_path.to_str().unwrap_or(""),
752 );
753 if !result.allowed {
754 return Err(format!(
755 "[边界检查失败] {}",
756 result.reason.unwrap_or_default()
757 ));
758 }
759 }
760
761 if let Some(parent) = full_path.parent() {
763 std::fs::create_dir_all(parent).map_err(|e| format!("创建目录失败: {}", e))?;
764 }
765
766 std::fs::write(&full_path, content).map_err(|e| format!("写入文件失败: {}", e))?;
768
769 self.log(&format!("[Worker] 保存文件: {}", file_path));
770 Ok(())
771 }
772
773 fn log(&self, message: &str) {
775 if self.config.debug {
776 println!("{}", message);
777 }
778 }
779
780 pub fn set_model(&mut self, model: impl Into<String>) {
786 self.config.model = model.into();
787 }
788
789 pub fn set_project_root(&mut self, project_root: PathBuf) {
791 self.config.project_root = project_root;
792 }
793
794 pub fn set_test_framework(&mut self, framework: TestFramework) {
796 self.config.test_framework = framework;
797 }
798
799 pub fn config(&self) -> &WorkerExecutorConfig {
801 &self.config
802 }
803}
804
805impl Default for WorkerExecutor {
806 fn default() -> Self {
807 Self::new(WorkerExecutorConfig::default())
808 }
809}
810
811pub fn create_worker_executor(config: WorkerExecutorConfig) -> WorkerExecutor {
817 WorkerExecutor::new(config)
818}
819
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_worker_executor_config_default() {
        let WorkerExecutorConfig {
            model,
            max_tokens,
            test_framework,
            ..
        } = WorkerExecutorConfig::default();

        assert_eq!(model, "claude-3-haiku");
        assert_eq!(max_tokens, 8000);
        assert_eq!(test_framework, TestFramework::Cargo);
    }

    #[test]
    fn test_test_framework_command() {
        let cases = [
            (TestFramework::Cargo, "cargo test"),
            (TestFramework::Vitest, "vitest"),
        ];

        for &(framework, needle) in cases.iter() {
            let command = framework.get_test_command("test_file");
            assert!(
                command.contains(needle),
                "{:?} produced {:?}",
                framework,
                command
            );
        }
    }

    #[test]
    fn test_phase_result_builder() {
        let result = PhaseResult::success()
            .with_data("key", serde_json::json!("value"))
            .with_artifact(String::from("file.rs"), String::from("content"));

        assert!(result.success);
        assert_eq!(result.artifacts.len(), 1);
        assert!(result.data.contains_key("key"));
    }
}