agentic_warden/
wait_mode.rs

1use crate::config::{
2    LEGACY_WAIT_INTERVAL_ENV, MAX_WAIT_DURATION, WAIT_INTERVAL_DEFAULT, WAIT_INTERVAL_ENV,
3};
4use crate::core::models::ProcessTreeInfo;
5use crate::core::process_tree::ProcessTreeError;
6use crate::error::RegistryError;
7use crate::logging::warn;
8use crate::platform;
9use crate::registry_factory::{create_cli_registry, create_mcp_registry};
10use crate::storage::{CleanupReason, RegistryEntry};
11use crate::task_record::TaskRecord;
12use crate::task_record::TaskStatus;
13use chrono::{DateTime, Local, Utc};
14use std::collections::HashSet;
15use std::fmt::Write;
16use std::thread;
17use std::time::{Duration, Instant};
18use thiserror::Error;
19
20#[derive(Debug, Error)]
21pub enum WaitError {
22    #[error("registry error: {0}")]
23    Registry(#[from] RegistryError),
24    #[error("process tree error: {0}")]
25    ProcessTree(#[from] ProcessTreeError),
26}
27
28pub fn run() -> Result<(), WaitError> {
29    let cli_registry = create_cli_registry()?;
30    let mcp_registry = create_mcp_registry();
31    let interval = read_interval();
32    let start = Instant::now();
33    let mut processed_pids: HashSet<u32> = HashSet::new();
34    let mut report = TaskReport::new();
35
36    // Get current process root parent for task filtering (core functionality)
37    let current_root_parent = match ProcessTreeInfo::current() {
38        Ok(tree_info) => tree_info.get_ai_cli_root(),
39        Err(err) => {
40            warn(format!("Failed to get process tree info: {}", err));
41            None
42        }
43    };
44
45    let terminate_wrapper = |pid: u32| {
46        platform::terminate_process(pid);
47        Ok(())
48    };
49
50    loop {
51        let now = chrono::Utc::now();
52
53        // Process CLI registry tasks
54        let cli_cleanups =
55            cli_registry.sweep_stale_entries(now, platform::process_alive, &terminate_wrapper)?;
56        for event in cli_cleanups {
57            if event.reason == CleanupReason::Timeout {
58                continue;
59            }
60
61            // Filter tasks by root parent PID
62            if !should_process_task(&event.record, current_root_parent) {
63                continue;
64            }
65
66            let pid = event._pid;
67            if processed_pids.insert(pid) {
68                let completion = TaskCompletion::from_record(pid, event.record);
69                emit_realtime_update(&completion);
70                report.add_completion(completion);
71            }
72        }
73
74        for (pid, record) in cli_registry.get_completed_unread_tasks()? {
75            // Filter tasks by root parent PID
76            if !should_process_task(&record, current_root_parent) {
77                continue;
78            }
79
80            if processed_pids.insert(pid) {
81                let completion = TaskCompletion::from_record(pid, record);
82                emit_realtime_update(&completion);
83                report.add_completion(completion);
84            }
85        }
86
87        // Process MCP registry tasks (without root parent filtering for cross-process)
88        let mcp_cleanups =
89            mcp_registry.sweep_stale_entries(now, platform::process_alive, &terminate_wrapper)?;
90        for event in mcp_cleanups {
91            if event.reason == CleanupReason::Timeout {
92                continue;
93            }
94
95            let pid = event._pid;
96            if processed_pids.insert(pid) {
97                let completion = TaskCompletion::from_record(pid, event.record);
98                emit_realtime_update(&completion);
99                report.add_completion(completion);
100            }
101        }
102
103        for (pid, record) in mcp_registry.get_completed_unread_tasks()? {
104            if processed_pids.insert(pid) {
105                let completion = TaskCompletion::from_record(pid, record);
106                emit_realtime_update(&completion);
107                report.add_completion(completion);
108            }
109        }
110
111        // Check both registries for running tasks
112        let cli_entries = cli_registry.entries()?;
113        let mcp_entries = mcp_registry.entries()?;
114
115        let cli_has_running = cli_entries.iter().any(|entry| {
116            entry.record.status == TaskStatus::Running
117                && should_process_task(&entry.record, current_root_parent)
118        });
119
120        let mcp_has_running = mcp_entries
121            .iter()
122            .any(|entry| entry.record.status == TaskStatus::Running);
123
124        // Only exit when both registries have no running tasks
125        if !cli_has_running && !mcp_has_running {
126            print_report(&report, None, false, start.elapsed());
127            return Ok(());
128        }
129
130        if start.elapsed() >= MAX_WAIT_DURATION {
131            // Filter running entries to only show related processes
132            let mut filtered_entries: Vec<RegistryEntry> = cli_entries
133                .iter()
134                .filter(|entry| should_process_task(&entry.record, current_root_parent))
135                .cloned()
136                .collect();
137
138            // Add MCP entries (no filtering for cross-process)
139            filtered_entries.extend(mcp_entries.iter().cloned());
140
141            print_report(&report, Some(&filtered_entries), true, start.elapsed());
142            return Ok(());
143        }
144
145        thread::sleep(interval);
146    }
147}
148
149/// Check if a task should be processed based on root parent PID
150pub fn should_process_task(record: &TaskRecord, current_root_parent: Option<u32>) -> bool {
151    // If we don't have root parent info, process all tasks
152    let task_root_parent = match record.resolved_root_parent_pid() {
153        Some(pid) => pid,
154        None => return true, // Process tasks without root parent info
155    };
156
157    match current_root_parent {
158        Some(current_pid) => task_root_parent == current_pid,
159        None => true, // No filtering if we couldn't get current root parent
160    }
161}
162
163fn read_interval() -> Duration {
164    read_env_interval(WAIT_INTERVAL_ENV)
165        .or_else(|| read_env_interval(LEGACY_WAIT_INTERVAL_ENV))
166        .unwrap_or(WAIT_INTERVAL_DEFAULT)
167}
168
169fn read_env_interval(var: &str) -> Option<Duration> {
170    match std::env::var(var) {
171        Ok(raw) => match raw.parse::<u64>() {
172            Ok(seconds) if seconds > 0 => Some(Duration::from_secs(seconds)),
173            _ => {
174                warn(format!(
175                    "environment variable {var} invalid, using default 30s"
176                ));
177                None
178            }
179        },
180        Err(_) => None,
181    }
182}
183
184fn emit_realtime_update(task: &TaskCompletion) {
185    let exit_code = task
186        .exit_code
187        .map(|code| code.to_string())
188        .unwrap_or_else(|| "未提供".to_string());
189    let status_word = if task.is_success() {
190        "完成"
191    } else {
192        "失败"
193    };
194    let header = format!(
195        "{} 任务{} PID={} (exit_code: {}) @ {}",
196        task.status_icon(),
197        status_word,
198        task.pid,
199        exit_code,
200        task.completed_time_local()
201    );
202    let log_line = format!("日志文件: {}", task.log_path);
203    let summary_line = format!("{}: {}", task.summary_label(), task.summary_text());
204
205    if task.is_success() {
206        println!("{header}");
207        println!("{log_line}");
208        println!("{summary_line}");
209    } else {
210        eprintln!("{header}");
211        eprintln!("{log_line}");
212        eprintln!("{summary_line}");
213    }
214}
215
216fn print_report(
217    report: &TaskReport,
218    running_entries: Option<&[RegistryEntry]>,
219    timed_out: bool,
220    wait_elapsed: Duration,
221) {
222    let mut buffer = String::new();
223    report
224        .render(&mut buffer, running_entries, timed_out, wait_elapsed)
225        .expect("rendering wait report");
226    println!("{buffer}");
227}
228
229#[derive(Clone)]
230struct TaskCompletion {
231    pid: u32,
232    log_path: String,
233    started_at: DateTime<Utc>,
234    completed_at: DateTime<Utc>,
235    exit_code: Option<i32>,
236    result: Option<String>,
237    cleanup_reason: Option<String>,
238}
239
240impl TaskCompletion {
241    fn from_record(pid: u32, mut record: TaskRecord) -> Self {
242        let completed_at = record.completed_at.unwrap_or_else(Utc::now);
243        record.completed_at = Some(completed_at);
244        Self {
245            pid,
246            log_path: record.log_path,
247            started_at: record.started_at,
248            completed_at,
249            exit_code: record.exit_code,
250            result: record.result,
251            cleanup_reason: record.cleanup_reason,
252        }
253    }
254
255    fn is_success(&self) -> bool {
256        self.cleanup_reason.is_none() && self.exit_code.unwrap_or(0) == 0
257    }
258
259    fn status_icon(&self) -> &'static str {
260        if self.is_success() {
261            "✅"
262        } else {
263            "❌"
264        }
265    }
266
267    fn completed_time_local(&self) -> String {
268        self.completed_at
269            .with_timezone(&Local)
270            .format("%Y-%m-%d %H:%M:%S")
271            .to_string()
272    }
273
274    fn summary_label(&self) -> &'static str {
275        if self.is_success() {
276            "结果摘要"
277        } else {
278            "错误摘要"
279        }
280    }
281
282    fn summary_text(&self) -> String {
283        if let Some(result) = &self.result {
284            result.clone()
285        } else if let Some(reason) = &self.cleanup_reason {
286            format!("任务被清理: {reason}")
287        } else if self.is_success() {
288            "任务成功完成,但未提供摘要。".to_string()
289        } else {
290            "任务失败,未提供错误摘要。".to_string()
291        }
292    }
293}
294
295struct TaskReport {
296    completions: Vec<TaskCompletion>,
297    earliest_start: Option<DateTime<Utc>>,
298    latest_completion: Option<DateTime<Utc>>,
299}
300
301impl TaskReport {
302    fn new() -> Self {
303        Self {
304            completions: Vec::new(),
305            earliest_start: None,
306            latest_completion: None,
307        }
308    }
309
310    fn add_completion(&mut self, completion: TaskCompletion) {
311        if self
312            .earliest_start
313            .is_none_or(|current| completion.started_at < current)
314        {
315            self.earliest_start = Some(completion.started_at);
316        }
317        if self
318            .latest_completion
319            .is_none_or(|current| completion.completed_at > current)
320        {
321            self.latest_completion = Some(completion.completed_at);
322        }
323        self.completions.push(completion);
324    }
325
326    fn total_count(&self) -> usize {
327        self.completions.len()
328    }
329
330    fn successful_count(&self) -> usize {
331        self.completions.iter().filter(|c| c.is_success()).count()
332    }
333
334    fn failed_count(&self) -> usize {
335        self.total_count() - self.successful_count()
336    }
337
338    fn total_duration(&self) -> Option<chrono::Duration> {
339        match (self.earliest_start, self.latest_completion) {
340            (Some(start), Some(end)) => Some(end.signed_duration_since(start)),
341            _ => None,
342        }
343    }
344
345    fn render(
346        &self,
347        buffer: &mut String,
348        running_entries: Option<&[RegistryEntry]>,
349        timed_out: bool,
350        wait_elapsed: Duration,
351    ) -> Result<(), std::fmt::Error> {
352        writeln!(buffer, "## 📋 任务执行完成报告")?;
353        if timed_out {
354            writeln!(buffer, "\n⚠️ 等待已达到最大时长,仍检测到未完成的任务。")?;
355        }
356
357        writeln!(buffer, "\n### ✅ 已完成任务列表")?;
358        if self.completions.is_empty() {
359            writeln!(buffer, "- 暂无完成任务")?;
360        } else {
361            let mut items = self.completions.clone();
362            items.sort_by_key(|item| item.completed_at);
363            for (idx, completion) in items.iter().enumerate() {
364                writeln!(buffer, "{}. **PID**: {}", idx + 1, completion.pid)?;
365                writeln!(
366                    buffer,
367                    "   - **状态**: {}",
368                    completion.status_icon_with_exit_code()
369                )?;
370                writeln!(buffer, "   - **日志文件**: {}", completion.log_path)?;
371                writeln!(
372                    buffer,
373                    "   - **完成时间**: {}",
374                    completion.completed_time_local()
375                )?;
376                writeln!(
377                    buffer,
378                    "   - **{}**: {}",
379                    completion.summary_label(),
380                    completion.summary_text()
381                )?;
382            }
383        }
384
385        let total_duration = self
386            .total_duration()
387            .or_else(|| chrono::Duration::from_std(wait_elapsed).ok())
388            .unwrap_or_else(chrono::Duration::zero);
389        writeln!(buffer, "\n### 📊 执行统计")?;
390        writeln!(buffer, "- 总任务数: {}", self.total_count())?;
391        writeln!(buffer, "- 成功: {}个", self.successful_count())?;
392        writeln!(buffer, "- 失败: {}个", self.failed_count())?;
393        writeln!(
394            buffer,
395            "- 总耗时: {}",
396            format_human_duration(total_duration)
397        )?;
398
399        writeln!(buffer, "\n### 📂 完整日志文件路径")?;
400        let mut log_paths: Vec<String> = Vec::new();
401        if self.completions.is_empty() {
402            writeln!(buffer, "- 无可用日志")?;
403        } else {
404            let mut paths: Vec<&String> = self.completions.iter().map(|c| &c.log_path).collect();
405            paths.sort();
406            paths.dedup();
407            for path in &paths {
408                writeln!(buffer, "- {path}")?;
409            }
410            log_paths = paths.iter().map(|path| (*path).clone()).collect();
411        }
412
413        if let Some(entries) = running_entries {
414            let running: Vec<&RegistryEntry> = entries
415                .iter()
416                .filter(|entry| entry.record.status == TaskStatus::Running)
417                .collect();
418            if !running.is_empty() {
419                writeln!(buffer, "\n### ⏳ 仍在运行的任务")?;
420                for entry in running {
421                    let started = entry
422                        .record
423                        .started_at
424                        .with_timezone(&Local)
425                        .format("%Y-%m-%d %H:%M:%S");
426                    writeln!(
427                        buffer,
428                        "- PID {} (启动于 {started}) -> {}",
429                        entry.pid, entry.record.log_path
430                    )?;
431                }
432            }
433        }
434
435        writeln!(
436            buffer,
437            "\n现在请基于上述结果继续你的工作,必要时查看日志文件。"
438        )?;
439        writeln!(buffer, "\n### 🧠 Claude 日志阅读提示")?;
440        writeln!(
441            buffer,
442            "- Claude,请分批次读取体积较大的日志文件,避免一次性请求全部内容。"
443        )?;
444        writeln!(
445            buffer,
446            "- 请在读取日志时使用 `offset`/`limit` 参数来控制输出范围,逐段检查关键信息。"
447        )?;
448        if log_paths.is_empty() {
449            writeln!(
450                buffer,
451                "- 当前没有可供阅读的日志文件路径,可在任务完成后再尝试。"
452            )?;
453        } else {
454            writeln!(buffer, "- 建议按照以下路径逐个读取日志:")?;
455            for path in &log_paths {
456                writeln!(buffer, "  - {path}")?;
457            }
458        }
459        writeln!(
460            buffer,
461            "- 读取完一批内容后,请说明下一步需要的 `offset`/`limit` 或指出新的文件路径,以便继续协助你。"
462        )?;
463        Ok(())
464    }
465}
466
467impl TaskCompletion {
468    fn status_icon_with_exit_code(&self) -> String {
469        let exit_code = self
470            .exit_code
471            .map(|code| code.to_string())
472            .unwrap_or_else(|| "未提供".to_string());
473        if let Some(reason) = &self.cleanup_reason {
474            format!(
475                "{} {} (exit_code: {exit_code}, cleanup: {reason})",
476                self.status_icon(),
477                if self.is_success() {
478                    "完成"
479                } else {
480                    "失败"
481                }
482            )
483        } else {
484            format!(
485                "{} {} (exit_code: {exit_code})",
486                self.status_icon(),
487                if self.is_success() {
488                    "完成"
489                } else {
490                    "失败"
491                }
492            )
493        }
494    }
495}
496
497fn format_human_duration(duration: chrono::Duration) -> String {
498    let mut seconds = duration.num_seconds();
499    if seconds < 0 {
500        seconds = 0;
501    }
502    let hours = seconds / 3600;
503    let minutes = (seconds % 3600) / 60;
504    let remaining_seconds = seconds % 60;
505
506    let mut parts = Vec::new();
507    if hours > 0 {
508        parts.push(format!("{hours}小时"));
509    }
510    if minutes > 0 {
511        parts.push(format!("{minutes}分"));
512    }
513    if remaining_seconds > 0 || parts.is_empty() {
514        parts.push(format!("{remaining_seconds}秒"));
515    }
516
517    parts.join("")
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task_record::TaskRecord;
    use chrono::Utc;
    use serial_test::serial;
    use std::env;

    /// Sets an environment variable for a test.
    ///
    /// Environment variables are process-global, so every test that touches
    /// them must be marked `#[serial]` (from the `serial_test` crate).
    fn set_test_env(key: &str, value: &str) {
        env::set_var(key, value);
    }

    /// Removes both interval variables so each test starts from a clean slate.
    fn clear_test_env() {
        env::remove_var(WAIT_INTERVAL_ENV);
        env::remove_var(LEGACY_WAIT_INTERVAL_ENV);
    }

    /// Guard that clears the interval variables when the test ends, including
    /// when an assertion panics.
    struct EnvCleanup;

    impl Drop for EnvCleanup {
        fn drop(&mut self) {
            clear_test_env();
        }
    }

    #[test]
    #[serial] // env vars are process-global
    fn prefers_primary_interval_env() {
        let _cleanup = EnvCleanup;
        clear_test_env();

        // Set both variables so the assertion actually proves precedence.
        set_test_env(WAIT_INTERVAL_ENV, "45");
        set_test_env(LEGACY_WAIT_INTERVAL_ENV, "90");
        assert_eq!(read_interval(), Duration::from_secs(45));
    }

    #[test]
    #[serial] // env vars are process-global
    fn falls_back_to_legacy_env() {
        let _cleanup = EnvCleanup;
        clear_test_env();

        set_test_env(LEGACY_WAIT_INTERVAL_ENV, "90");
        assert_eq!(read_interval(), Duration::from_secs(90));
    }

    #[test]
    #[serial] // env vars are process-global
    fn invalid_primary_falls_back_to_legacy() {
        let _cleanup = EnvCleanup;
        clear_test_env();

        set_test_env(WAIT_INTERVAL_ENV, "not-a-number");
        set_test_env(LEGACY_WAIT_INTERVAL_ENV, "90");
        assert_eq!(read_interval(), Duration::from_secs(90));
    }

    #[test]
    #[serial] // env vars are process-global
    fn returns_default_on_invalid_values() {
        let _cleanup = EnvCleanup;
        clear_test_env();

        set_test_env(WAIT_INTERVAL_ENV, "not-a-number");
        assert_eq!(read_interval(), WAIT_INTERVAL_DEFAULT);
    }

    #[test]
    #[serial] // env vars are process-global
    fn returns_default_on_zero_interval() {
        let _cleanup = EnvCleanup;
        clear_test_env();

        // Zero is rejected: only positive intervals are accepted.
        set_test_env(WAIT_INTERVAL_ENV, "0");
        assert_eq!(read_interval(), WAIT_INTERVAL_DEFAULT);
    }

    #[test]
    fn test_should_process_task_filtering() {
        let base_time = Utc::now();

        // Task whose process tree is [1000, 100]; the assertions below expect
        // it to match root parent 100.
        let task_with_root = TaskRecord::new(
            base_time,
            "1001".to_string(),
            "/tmp/1001.log".to_string(),
            Some(1000),
        );
        let task_with_root = task_with_root
            .with_process_tree_info(ProcessTreeInfo::new(vec![1000, 100]))
            .expect("process tree should attach");

        // Task without root parent info (backward compatibility).
        let task_without_root = TaskRecord::new(
            base_time,
            "1002".to_string(),
            "/tmp/1002.log".to_string(),
            Some(1000),
        );

        // Task attached to a different process tree.
        let task_different_root = TaskRecord::new(
            base_time,
            "1003".to_string(),
            "/tmp/1003.log".to_string(),
            Some(1000),
        );
        let task_different_root = task_different_root
            .with_process_tree_info(ProcessTreeInfo::new(vec![2000, 200]))
            .expect("process tree should attach");

        // Filtering by root parent 100.
        assert!(should_process_task(&task_with_root, Some(100)));
        // Tasks without root info are always included.
        assert!(should_process_task(&task_without_root, Some(100)));
        assert!(!should_process_task(&task_different_root, Some(100)));

        // No current root known: nothing is filtered out.
        assert!(should_process_task(&task_with_root, None));
        assert!(should_process_task(&task_without_root, None));
        assert!(should_process_task(&task_different_root, None));
    }
}
625}