1use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13use super::types::*;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct TddLoopState {
22 pub task_id: String,
23 pub phase: TddPhase,
24 pub iteration: u32,
25 pub max_iterations: u32,
26
27 pub test_spec: Option<TestSpec>,
29 pub test_written: bool,
31 pub code_written: bool,
33
34 pub last_test_result: Option<TestResult>,
36 pub last_error: Option<String>,
38
39 pub started_at: DateTime<Utc>,
41 pub phase_durations: HashMap<String, u64>,
43}
44
45impl TddLoopState {
46 pub fn new(task_id: String) -> Self {
48 Self {
49 task_id,
50 phase: TddPhase::WriteTest,
51 iteration: 0,
52 max_iterations: 10,
53 test_spec: None,
54 test_written: false,
55 code_written: false,
56 last_test_result: None,
57 last_error: None,
58 started_at: Utc::now(),
59 phase_durations: HashMap::new(),
60 }
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct TddConfig {
71 pub max_iterations: u32,
73 pub test_timeout: u64,
75 pub auto_refactor: bool,
77 pub continue_on_red_failure: bool,
79}
80
81impl Default for TddConfig {
82 fn default() -> Self {
83 Self {
84 max_iterations: 10,
85 test_timeout: 60000,
86 auto_refactor: true,
87 continue_on_red_failure: true,
88 }
89 }
90}
91
/// Namespace for the fixed instruction prompt shown in each TDD phase.
///
/// This is a stateless marker type, so it derives the common traits for free.
#[derive(Debug, Clone, Copy, Default)]
pub struct TddPrompts;
98
99impl TddPrompts {
100 pub fn write_test() -> &'static str {
102 r#"你现在处于 TDD 的「编写测试」阶段。
103
104请根据任务描述编写测试代码:
1051. 测试应该覆盖主要功能和边界情况
1062. 测试应该是失败的(因为还没有实现代码)
1073. 使用清晰的测试描述和断言
108
109输出格式:
110```
111// 测试代码
112```"#
113 }
114
115 pub fn run_test_red() -> &'static str {
117 r#"你现在处于 TDD 的「红灯」阶段。
118
119请运行测试并确认测试失败:
1201. 执行测试命令
1212. 确认测试失败(这是预期的)
1223. 记录失败信息
123
124如果测试意外通过,说明测试可能有问题。"#
125 }
126
127 pub fn write_code() -> &'static str {
129 r#"你现在处于 TDD 的「编写代码」阶段。
130
131请编写最小可行代码使测试通过:
1321. 只编写让测试通过的代码
1332. 不要过度设计
1343. 专注于当前测试
135
136输出格式:
137### 文件:path/to/file.rs
138```rust
139// 代码内容
140```"#
141 }
142
143 pub fn run_test_green() -> &'static str {
145 r#"你现在处于 TDD 的「绿灯」阶段。
146
147请运行测试并确认测试通过:
1481. 执行测试命令
1492. 确认所有测试通过
1503. 如果测试失败,返回「编写代码」阶段"#
151 }
152
153 pub fn refactor() -> &'static str {
155 r#"你现在处于 TDD 的「重构」阶段。
156
157请在保持测试通过的前提下优化代码:
1581. 消除重复代码
1592. 改善命名
1603. 简化逻辑
1614. 提高可读性
162
163重构后再次运行测试确认通过。"#
164 }
165
166 pub fn get_prompt(phase: TddPhase) -> &'static str {
168 match phase {
169 TddPhase::WriteTest => Self::write_test(),
170 TddPhase::RunTestRed => Self::run_test_red(),
171 TddPhase::WriteCode => Self::write_code(),
172 TddPhase::RunTestGreen => Self::run_test_green(),
173 TddPhase::Refactor => Self::refactor(),
174 TddPhase::Done => "TDD 循环已完成。",
175 }
176 }
177}
178
179pub struct TddExecutor {
185 config: TddConfig,
186 active_loops: HashMap<String, TddLoopState>,
188}
189
190impl Default for TddExecutor {
191 fn default() -> Self {
192 Self::new(TddConfig::default())
193 }
194}
195
196impl TddExecutor {
197 pub fn new(config: TddConfig) -> Self {
199 Self {
200 config,
201 active_loops: HashMap::new(),
202 }
203 }
204
205 pub fn start_loop(&mut self, task_id: String) -> &TddLoopState {
207 let mut state = TddLoopState::new(task_id.clone());
208 state.max_iterations = self.config.max_iterations;
209 self.active_loops.insert(task_id.clone(), state);
210 self.active_loops.get(&task_id).unwrap()
211 }
212
213 pub fn is_in_loop(&self, task_id: &str) -> bool {
215 self.active_loops.contains_key(task_id)
216 }
217
218 pub fn get_loop_state(&self, task_id: &str) -> Option<&TddLoopState> {
220 self.active_loops.get(task_id)
221 }
222
223 pub fn get_loop_state_mut(&mut self, task_id: &str) -> Option<&mut TddLoopState> {
225 self.active_loops.get_mut(task_id)
226 }
227
228 pub fn end_loop(&mut self, task_id: &str) -> Option<TddLoopState> {
230 self.active_loops.remove(task_id)
231 }
232
233 pub fn advance_phase(&mut self, task_id: &str) -> Result<TddPhase, String> {
235 let state = self
236 .active_loops
237 .get_mut(task_id)
238 .ok_or_else(|| format!("任务 {} 不在 TDD 循环中", task_id))?;
239
240 let next_phase = match state.phase {
241 TddPhase::WriteTest => TddPhase::RunTestRed,
242 TddPhase::RunTestRed => TddPhase::WriteCode,
243 TddPhase::WriteCode => TddPhase::RunTestGreen,
244 TddPhase::RunTestGreen => {
245 if let Some(ref result) = state.last_test_result {
247 if result.passed {
248 TddPhase::Refactor
249 } else {
250 state.iteration += 1;
252 if state.iteration >= state.max_iterations {
253 return Err(format!(
254 "任务 {} 达到最大迭代次数 {}",
255 task_id, state.max_iterations
256 ));
257 }
258 TddPhase::WriteCode
259 }
260 } else {
261 TddPhase::WriteCode
262 }
263 }
264 TddPhase::Refactor => TddPhase::Done,
265 TddPhase::Done => TddPhase::Done,
266 };
267
268 state.phase = next_phase;
269 Ok(next_phase)
270 }
271
272 pub fn record_test_result(&mut self, task_id: &str, result: TestResult) -> Result<(), String> {
274 let state = self
275 .active_loops
276 .get_mut(task_id)
277 .ok_or_else(|| format!("任务 {} 不在 TDD 循环中", task_id))?;
278
279 state.last_test_result = Some(result);
280 Ok(())
281 }
282
283 pub fn record_error(&mut self, task_id: &str, error: String) -> Result<(), String> {
285 let state = self
286 .active_loops
287 .get_mut(task_id)
288 .ok_or_else(|| format!("任务 {} 不在 TDD 循环中", task_id))?;
289
290 state.last_error = Some(error);
291 Ok(())
292 }
293
294 pub fn set_test_spec(&mut self, task_id: &str, spec: TestSpec) -> Result<(), String> {
296 let state = self
297 .active_loops
298 .get_mut(task_id)
299 .ok_or_else(|| format!("任务 {} 不在 TDD 循环中", task_id))?;
300
301 state.test_spec = Some(spec);
302 state.test_written = true;
303 Ok(())
304 }
305
306 pub fn mark_code_written(&mut self, task_id: &str) -> Result<(), String> {
308 let state = self
309 .active_loops
310 .get_mut(task_id)
311 .ok_or_else(|| format!("任务 {} 不在 TDD 循环中", task_id))?;
312
313 state.code_written = true;
314 Ok(())
315 }
316
317 pub fn get_current_prompt(&self, task_id: &str) -> Option<&'static str> {
319 self.active_loops
320 .get(task_id)
321 .map(|state| TddPrompts::get_prompt(state.phase))
322 }
323
324 pub fn can_skip_write_test(&self, task_id: &str, has_acceptance_tests: bool) -> bool {
327 if let Some(state) = self.active_loops.get(task_id) {
328 state.phase == TddPhase::WriteTest && has_acceptance_tests
329 } else {
330 false
331 }
332 }
333
334 pub fn skip_write_test(&mut self, task_id: &str) -> Result<(), String> {
336 let state = self
337 .active_loops
338 .get_mut(task_id)
339 .ok_or_else(|| format!("任务 {} 不在 TDD 循环中", task_id))?;
340
341 if state.phase != TddPhase::WriteTest {
342 return Err("只能在 WriteTest 阶段跳过".to_string());
343 }
344
345 state.phase = TddPhase::RunTestRed;
346 state.test_written = true;
347 Ok(())
348 }
349
350 pub fn get_active_loops(&self) -> Vec<&TddLoopState> {
352 self.active_loops.values().collect()
353 }
354
355 pub fn get_config(&self) -> &TddConfig {
357 &self.config
358 }
359
360 pub fn update_config(&mut self, config: TddConfig) {
362 self.config = config;
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tdd_executor_creation() {
        let executor = TddExecutor::default();
        assert_eq!(executor.config.max_iterations, 10);
        assert!(executor.active_loops.is_empty());
    }

    #[test]
    fn test_start_loop() {
        let mut executor = TddExecutor::default();
        let state = executor.start_loop("task-1".to_string());

        assert_eq!(state.task_id, "task-1");
        assert_eq!(state.phase, TddPhase::WriteTest);
        assert_eq!(state.iteration, 0);
    }

    #[test]
    fn test_start_loop_uses_configured_max_iterations() {
        let mut executor = TddExecutor::new(TddConfig {
            max_iterations: 3,
            ..TddConfig::default()
        });
        assert_eq!(executor.start_loop("task-1".to_string()).max_iterations, 3);
    }

    #[test]
    fn test_advance_phase() {
        let mut executor = TddExecutor::default();
        executor.start_loop("task-1".to_string());

        let phase = executor.advance_phase("task-1").unwrap();
        assert_eq!(phase, TddPhase::RunTestRed);

        let phase = executor.advance_phase("task-1").unwrap();
        assert_eq!(phase, TddPhase::WriteCode);
    }

    #[test]
    fn test_green_without_result_returns_to_write_code() {
        let mut executor = TddExecutor::default();
        executor.start_loop("task-1".to_string());
        for _ in 0..3 {
            executor.advance_phase("task-1").unwrap();
        }
        assert_eq!(
            executor.get_loop_state("task-1").unwrap().phase,
            TddPhase::RunTestGreen
        );

        // No result recorded: back to WriteCode without counting an iteration.
        assert_eq!(executor.advance_phase("task-1").unwrap(), TddPhase::WriteCode);
        assert_eq!(executor.get_loop_state("task-1").unwrap().iteration, 0);
    }

    #[test]
    fn test_unknown_task_is_rejected() {
        let mut executor = TddExecutor::default();
        assert!(!executor.is_in_loop("missing"));
        assert!(executor.advance_phase("missing").is_err());
        assert!(executor.mark_code_written("missing").is_err());
        assert!(executor.record_error("missing", "boom".to_string()).is_err());
        assert!(executor.skip_write_test("missing").is_err());
        assert!(executor.get_current_prompt("missing").is_none());
        assert!(!executor.can_skip_write_test("missing", true));
    }

    #[test]
    fn test_skip_write_test() {
        let mut executor = TddExecutor::default();
        executor.start_loop("task-1".to_string());

        assert!(executor.can_skip_write_test("task-1", true));
        assert!(!executor.can_skip_write_test("task-1", false));

        executor.skip_write_test("task-1").unwrap();
        let state = executor.get_loop_state("task-1").unwrap();
        assert_eq!(state.phase, TddPhase::RunTestRed);
        assert!(state.test_written);

        // Skipping is only valid from WriteTest.
        assert!(executor.skip_write_test("task-1").is_err());
        assert!(!executor.can_skip_write_test("task-1", true));
    }

    #[test]
    fn test_end_loop() {
        let mut executor = TddExecutor::default();
        executor.start_loop("task-1".to_string());
        assert!(executor.is_in_loop("task-1"));

        let state = executor.end_loop("task-1").unwrap();
        assert_eq!(state.task_id, "task-1");
        assert!(!executor.is_in_loop("task-1"));
        assert!(executor.end_loop("task-1").is_none());
    }

    #[test]
    fn test_current_prompt_follows_phase() {
        let mut executor = TddExecutor::default();
        executor.start_loop("task-1".to_string());
        assert_eq!(
            executor.get_current_prompt("task-1"),
            Some(TddPrompts::write_test())
        );

        executor.advance_phase("task-1").unwrap();
        assert_eq!(
            executor.get_current_prompt("task-1"),
            Some(TddPrompts::run_test_red())
        );
    }

    #[test]
    fn test_tdd_prompts() {
        assert!(!TddPrompts::write_test().is_empty());
        assert!(!TddPrompts::run_test_red().is_empty());
        assert!(!TddPrompts::write_code().is_empty());
        assert!(!TddPrompts::run_test_green().is_empty());
        assert!(!TddPrompts::refactor().is_empty());
        assert!(!TddPrompts::get_prompt(TddPhase::Done).is_empty());
    }
}