atomcode_core/agent/
subtask_driver.rs1use std::collections::HashSet;
13
/// A single file the agent has been asked to edit as part of a larger plan.
#[derive(Debug, Clone)]
pub struct Subtask {
    /// File name as extracted from the plan text (the last `/`-separated
    /// component when produced by `SubtaskDriver::extract_from_plan`).
    pub file: String,
    /// Set to `true` by `SubtaskDriver::advance` once this file has been handled.
    pub done: bool,
}
20
21#[derive(Debug, Clone)]
23pub struct SubtaskDriver {
24 pub subtasks: Vec<Subtask>,
25 pub current_idx: usize,
26 pub active: bool,
27}
28
29impl SubtaskDriver {
30 pub fn new() -> Self {
31 Self {
32 subtasks: Vec::new(),
33 current_idx: 0,
34 active: false,
35 }
36 }
37
38 pub fn extract_from_plan(&mut self, plan_text: &str) {
43 let mut files = Vec::new();
44 let mut seen = HashSet::new();
45
46 let reference_files = extract_reference_files(plan_text);
52
53 for word in plan_text.split(|c: char| {
61 c.is_whitespace()
62 || c == ','
63 || c == '`'
64 || c == '"'
65 || c == '\''
66 || c == '('
67 || c == ')'
68 || c == '['
69 || c == ']'
70 || c == '\u{FF0C}' || c == '\u{3002}' || c == '\u{3001}' || c == '\u{FF1B}' || c == '\u{FF1A}' || c == '\u{FF08}' || c == '\u{FF09}' || c == '\u{300A}' || c == '\u{300B}' || c == '\u{300C}' || c == '\u{300D}' || c == '\u{FF1F}' || c == '\u{FF01}' || c == '\u{2014}' }) {
86 let trimmed = word
87 .trim()
88 .trim_matches(|c: char| c == '`' || c == '*' || c == ':');
89 if trimmed.is_empty() {
90 continue;
91 }
92
93 if is_source_file(trimmed) {
94 let file_name = trimmed.rsplit('/').next().unwrap_or(trimmed);
95 if !file_name.is_empty()
96 && seen.insert(file_name.to_string())
97 && !reference_files.contains(file_name)
98 {
99 files.push(file_name.to_string());
100 }
101 }
102 }
103
104 if files.is_empty() {
105 self.active = false;
106 return;
107 }
108
109 files.sort_by(|a, b| {
111 let a_backend = a.ends_with(".java")
112 || a.ends_with(".py")
113 || a.ends_with(".go")
114 || a.ends_with(".rs");
115 let b_backend = b.ends_with(".java")
116 || b.ends_with(".py")
117 || b.ends_with(".go")
118 || b.ends_with(".rs");
119 b_backend.cmp(&a_backend) });
121
122 self.subtasks = files
123 .into_iter()
124 .map(|f| Subtask {
125 file: f,
126 done: false,
127 })
128 .collect();
129 self.current_idx = 0;
130 self.active = true;
131 }
132
133 pub fn current_instruction(&self) -> Option<String> {
136 if !self.active {
137 return None;
138 }
139 let task = self.subtasks.get(self.current_idx)?;
140 if task.done {
141 return None;
142 }
143
144 let total = self.subtasks.len();
145 let remaining: Vec<&str> = self.subtasks[self.current_idx + 1..]
146 .iter()
147 .filter(|t| !t.done)
148 .map(|t| t.file.as_str())
149 .collect();
150
151 let next_hint = if remaining.is_empty() {
152 "This is the last file.".to_string()
153 } else {
154 format!("After this: {}", remaining.join(", "))
155 };
156
157 Some(format!(
158 "[Subtask {}/{}: Edit {} \u{2014} make ALL needed changes in ONE edit. {}]",
159 self.current_idx + 1,
160 total,
161 task.file,
162 next_hint,
163 ))
164 }
165
166 pub fn advance(&mut self) {
168 if let Some(task) = self.subtasks.get_mut(self.current_idx) {
169 task.done = true;
170 }
171 self.current_idx += 1;
172 if self.current_idx >= self.subtasks.len() {
173 self.active = false;
174 }
175 }
176
177 pub fn matches_current(&self, edited_file: &str) -> bool {
179 if let Some(task) = self.subtasks.get(self.current_idx) {
180 edited_file.contains(&task.file) || task.file.contains(edited_file)
181 } else {
182 false
183 }
184 }
185
186 pub fn all_done(&self) -> bool {
188 self.subtasks.iter().all(|t| t.done)
189 }
190}
191
/// Returns `true` when `s` ends with one of the recognised source-file
/// extensions. The check is case-sensitive and purely textual; no filesystem
/// access is involved.
fn is_source_file(s: &str) -> bool {
    const SOURCE_EXTENSIONS: [&str; 9] = [
        ".java", ".vue", ".ts", ".tsx", ".py", ".rs", ".go", ".js", ".svelte",
    ];
    SOURCE_EXTENSIONS.iter().any(|ext| s.ends_with(*ext))
}
204
205fn extract_reference_files(plan_text: &str) -> HashSet<String> {
208 let mut refs = HashSet::new();
209
210 let ref_kw: &[&str] = &[
211 "\u{53C2}\u{8003}", "\u{53C2}\u{7167}", "\u{4EFF}\u{7167}", "\u{7C7B}\u{4F3C}", "reference",
216 "following",
217 "same as",
218 "style of",
219 "follow",
220 ];
221 let modify_kw: &[&str] = &[
222 "\u{4FEE}\u{6539}", "\u{7F16}\u{8F91}", "\u{66F4}\u{65B0}", "\u{6DFB}\u{52A0}", "\u{5B9E}\u{73B0}", "\u{6539}", "modify",
229 "edit",
230 "update",
231 "add",
232 "change",
233 "implement",
234 ];
235
236 for line in plan_text.lines() {
237 let lower = line.to_lowercase();
238 let has_ref = ref_kw.iter().any(|k| lower.contains(k));
239 if !has_ref {
240 continue;
241 }
242
243 let modify_pos = modify_kw.iter().filter_map(|k| lower.find(k)).min();
245
246 let ref_portion = match modify_pos {
248 Some(pos) => &line[..pos],
249 None => line,
250 };
251
252 for word in ref_portion.split(|c: char| {
254 c.is_whitespace()
255 || c == ','
256 || c == '`'
257 || c == '"'
258 || c == '\''
259 || c == '('
260 || c == ')'
261 || c == '\u{FF0C}'
262 }) {
263 let trimmed = word
264 .trim()
265 .trim_matches(|c: char| c == '`' || c == '*' || c == ':');
266 if is_source_file(trimmed) {
267 let file_name = trimmed.rsplit('/').next().unwrap_or(trimmed);
268 refs.insert(file_name.to_string());
269 }
270 }
271 }
272
273 refs
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extract_from_plan` on a fresh driver and returns it.
    fn driver_for(plan: &str) -> SubtaskDriver {
        let mut driver = SubtaskDriver::new();
        driver.extract_from_plan(plan);
        driver
    }

    /// Chinese punctuation (fullwidth comma, ideographic full stop) must split
    /// tokens; otherwise names such as `types.rs，...` would be mangled or missed.
    #[test]
    fn extract_handles_chinese_punctuation_separators() {
        // "Now handle 4 files one by one. First constants.rs and types.rs, which
        // already have some Chinese comments but are incomplete. platform.rs and
        // mod.rs also need filling in."
        let plan = "\u{73B0}\u{5728}\u{9010}\u{4E00}\u{5904}\u{7406} 4 \u{4E2A}\u{6587}\u{4EF6}\u{3002}\u{5148}\u{5904}\u{7406} constants.rs \u{548C} types.rs\u{FF0C}\u{5B83}\u{4EEC}\u{5DF2}\u{7ECF}\u{6709}\u{4E00}\u{4E9B}\u{4E2D}\u{6587}\u{6CE8}\u{91CA}\u{4F46}\u{4E0D}\u{591F}\u{5B8C}\u{6574}\u{3002}platform.rs \u{548C} mod.rs \u{4E5F}\u{9700}\u{8981}\u{8865}\u{5168}\u{3002}";

        let driver = driver_for(plan);

        assert!(driver.active);
        assert_eq!(driver.subtasks.len(), 4, "expected 4 .rs files extracted, got: {:?}", driver.subtasks);
        let names: Vec<&str> = driver.subtasks.iter().map(|s| s.file.as_str()).collect();
        assert!(names.contains(&"constants.rs"));
        assert!(names.contains(&"types.rs"));
        assert!(names.contains(&"platform.rs"));
        assert!(names.contains(&"mod.rs"));
        for s in &driver.subtasks {
            assert!(
                !s.file.contains('\u{FF0C}') && !s.file.contains('\u{3002}'),
                "extracted name `{}` contains Chinese punctuation — splitter missed",
                s.file
            );
        }
    }

    /// A numbered multi-line plan yields one subtask per file, backend files first.
    #[test]
    fn extract_files_from_plan() {
        // "I plan to modify the following files:" followed by three numbered items.
        let plan = "\u{6211}\u{8BA1}\u{5212}\u{4FEE}\u{6539}\u{4EE5}\u{4E0B}\u{6587}\u{4EF6}\u{FF1A}\n1. TagRebuildTaskService.java \u{2014} \u{6DFB}\u{52A0} token \u{7EDF}\u{8BA1}\n2. AITagExtractionService.java \u{2014} \u{8FD4}\u{56DE} token \u{6D88}\u{8017}\n3. SettingsView.vue \u{2014} \u{524D}\u{7AEF}\u{663E}\u{793A}";

        let driver = driver_for(plan);

        assert!(driver.active);
        assert_eq!(driver.subtasks.len(), 3);
        assert!(driver.subtasks[0].file.ends_with(".java"));
        assert!(driver.subtasks[1].file.ends_with(".java"));
        assert!(driver.subtasks[2].file.ends_with(".vue"));
    }

    /// A file named only as a style reference ("参考 X") is not a subtask.
    #[test]
    fn reference_files_filtered_out() {
        // "I will reference the style of ProductCenter.vue and modify TestCenter.vue
        // to add status filtering."
        let plan = "\u{6211}\u{5C06}\u{53C2}\u{8003} ProductCenter.vue \u{7684}\u{98CE}\u{683C}\u{FF0C}\u{4FEE}\u{6539} TestCenter.vue \u{6DFB}\u{52A0}\u{72B6}\u{6001}\u{7B5B}\u{9009}\u{529F}\u{80FD}\u{3002}";

        let driver = driver_for(plan);

        assert_eq!(driver.subtasks.len(), 1);
        assert_eq!(driver.subtasks[0].file, "TestCenter.vue");
    }

    /// English reference phrasing ("follow the style of") is filtered too.
    #[test]
    fn reference_file_english() {
        let plan =
            "I'll follow the style of IdeaCenter.vue and modify DevCenter.vue to add code reviews.";

        let driver = driver_for(plan);

        assert_eq!(driver.subtasks.len(), 1);
        assert_eq!(driver.subtasks[0].file, "DevCenter.vue");
    }

    /// Without reference keywords, every mentioned file becomes a subtask.
    #[test]
    fn multiple_modify_targets_no_reference() {
        // "Modify Service.java's interface, then update Controller.java's calls."
        let plan = "\u{4FEE}\u{6539} Service.java \u{7684}\u{63A5}\u{53E3}\u{FF0C}\u{7136}\u{540E}\u{66F4}\u{65B0} Controller.java \u{7684}\u{8C03}\u{7528}";

        let driver = driver_for(plan);

        assert_eq!(driver.subtasks.len(), 2);
    }

    #[test]
    fn instruction_format() {
        let driver = driver_for("\u{4FEE}\u{6539} TagService.java \u{548C} SettingsView.vue");

        let instr = driver.current_instruction().unwrap();
        assert!(instr.contains("Subtask 1/2"));
        assert!(instr.contains("TagService.java"));
        assert!(instr.contains("ONE edit"));
    }

    /// The first instruction names the next file; the last one says so; after
    /// the final `advance` there is no instruction left.
    #[test]
    fn instruction_hints_next_file_then_last_file() {
        let mut driver = driver_for("\u{4FEE}\u{6539} TagService.java \u{548C} SettingsView.vue");

        let first = driver.current_instruction().unwrap();
        assert!(first.contains("After this: SettingsView.vue"));

        driver.advance();
        let last = driver.current_instruction().unwrap();
        assert!(last.contains("Subtask 2/2"));
        assert!(last.contains("SettingsView.vue"));
        assert!(last.contains("This is the last file."));

        driver.advance();
        assert!(driver.current_instruction().is_none());
    }

    #[test]
    fn advance_through_subtasks() {
        let mut driver = driver_for("\u{4FEE}\u{6539} A.java \u{548C} B.vue");

        assert_eq!(driver.current_idx, 0);
        driver.advance();
        assert_eq!(driver.current_idx, 1);
        assert!(!driver.all_done());
        driver.advance();
        assert!(driver.all_done());
        assert!(!driver.active);
    }

    /// An edited path matches when it points at the current file, whether given
    /// as a bare name or with a directory prefix.
    #[test]
    fn matches_current_accepts_bare_name_or_path() {
        let mut driver = driver_for("\u{4FEE}\u{6539} TagService.java \u{548C} SettingsView.vue");

        assert!(driver.matches_current("TagService.java"));
        assert!(driver.matches_current("src/main/java/TagService.java"));
        assert!(!driver.matches_current("SettingsView.vue"));

        driver.advance();
        assert!(driver.matches_current("web/src/SettingsView.vue"));
    }

    #[test]
    fn matches_current_false_without_subtasks() {
        let driver = SubtaskDriver::new();
        assert!(!driver.matches_current("A.java"));
    }

    /// A plan that mentions no file names leaves the driver inactive.
    #[test]
    fn empty_plan_no_subtasks() {
        // "I think some code needs to be modified."
        let driver = driver_for("\u{6211}\u{89C9}\u{5F97}\u{9700}\u{8981}\u{4FEE}\u{6539}\u{4E00}\u{4E9B}\u{4EE3}\u{7801}");
        assert!(!driver.active);
    }
}