agentic_warden/
supervisor.rs

1use crate::cli_type::CliType;
2use crate::core::models::ProcessTreeInfo;
3use crate::core::process_tree::ProcessTreeError;
4use crate::error::RegistryError;
5use crate::logging::debug;
6use crate::logging::warn;
7#[cfg(windows)]
8use crate::platform::ChildResources;
9use crate::platform::{self};
10use crate::provider::{AiType, ProviderManager};
11use crate::signal;
12use crate::storage::TaskStorage;
13use crate::task_record::TaskRecord;
14use crate::unified_registry::Registry;
15use chrono::{DateTime, Utc};
16use std::env;
17use std::ffi::OsString;
18use std::io;
19use std::path::PathBuf;
20use std::process::{ExitStatus, Stdio};
21use std::sync::Arc;
22use thiserror::Error;
23use tokio::fs::OpenOptions;
24use tokio::io::{AsyncRead, AsyncWriteExt, BufWriter};
25use tokio::process::Command;
26use tokio::sync::Mutex;
27
28#[derive(Debug, Error)]
29pub enum ProcessError {
30    #[error("IO error: {0}")]
31    Io(#[from] io::Error),
32    #[error("Registry error: {0}")]
33    Registry(#[from] RegistryError),
34    #[error("Process tree error: {0}")]
35    ProcessTree(#[from] ProcessTreeError),
36    #[error("CLI executable not found: {0}")]
37    CliNotFound(String),
38    #[error("{0}")]
39    Other(String),
40}
41
42async fn get_cli_command(cli_type: &CliType) -> Result<String, ProcessError> {
43    // First try environment variable
44    if let Ok(custom_path) = env::var(cli_type.env_var_name()) {
45        if custom_path.is_empty() {
46            return Err(ProcessError::CliNotFound(format!(
47                "{} environment variable is empty",
48                cli_type.env_var_name()
49            )));
50        }
51        return Ok(custom_path);
52    }
53
54    // Fall back to default command name
55    let default_cmd = cli_type.command_name();
56
57    // On Windows, try to find the actual executable path
58    if cfg!(windows) {
59        let output = Command::new("where")
60            .arg(default_cmd)
61            .output()
62            .await
63            .map_err(|_| {
64                ProcessError::CliNotFound(format!(
65                    "Failed to check if '{}' exists in PATH",
66                    default_cmd
67                ))
68            })?;
69
70        if output.status.success() {
71            let stdout = String::from_utf8_lossy(&output.stdout);
72            // Prefer .cmd files on Windows, otherwise use first result
73            for line in stdout.lines() {
74                if line.ends_with(".cmd") || line.ends_with(".bat") || line.ends_with(".exe") {
75                    return Ok(line.to_string());
76                }
77            }
78            // Fallback to first line if no Windows executable found
79            if let Some(first_line) = stdout.lines().next() {
80                return Ok(first_line.to_string());
81            }
82        }
83
84        return Err(ProcessError::CliNotFound(format!(
85            "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
86            default_cmd,
87            cli_type.env_var_name()
88        )));
89    } else {
90        let output = Command::new("which")
91            .arg(default_cmd)
92            .output()
93            .await
94            .map_err(|_| {
95                ProcessError::CliNotFound(format!(
96                    "Failed to check if '{}' exists in PATH",
97                    default_cmd
98                ))
99            })?;
100
101        if !output.status.success() {
102            return Err(ProcessError::CliNotFound(format!(
103                "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
104                default_cmd,
105                cli_type.env_var_name()
106            )));
107        }
108    }
109
110    Ok(default_cmd.to_string())
111}
112
113/// Output handling strategy for CLI execution
114enum OutputStrategy {
115    /// Mirror output to stdout/stderr
116    Mirror,
117    /// Capture stdout to buffer
118    Capture(Arc<Mutex<Vec<u8>>>),
119}
120
121pub async fn execute_cli<S: TaskStorage>(
122    registry: &Registry<S>,
123    cli_type: &CliType,
124    args: &[OsString],
125    provider: Option<String>,
126) -> Result<i32, ProcessError> {
127    execute_cli_internal(
128        registry,
129        cli_type,
130        args,
131        provider,
132        None,
133        OutputStrategy::Mirror,
134    )
135    .await
136    .map(|(exit_code, _)| exit_code)
137}
138
139/// Execute CLI and capture stdout output (for code generation)
140pub async fn execute_cli_with_output<S: TaskStorage>(
141    registry: &Registry<S>,
142    cli_type: &CliType,
143    args: &[OsString],
144    provider: Option<String>,
145    timeout: std::time::Duration,
146) -> Result<String, ProcessError> {
147    let buffer = Arc::new(Mutex::new(Vec::new()));
148    let (_, output_opt) = execute_cli_internal(
149        registry,
150        cli_type,
151        args,
152        provider,
153        Some(timeout),
154        OutputStrategy::Capture(buffer.clone()),
155    )
156    .await?;
157
158    match output_opt {
159        Some(output) => Ok(output),
160        None => Err(ProcessError::Other(
161            "Output capture failed unexpectedly".to_string(),
162        )),
163    }
164}
165
166/// Internal CLI execution with configurable output handling
167async fn execute_cli_internal<S: TaskStorage>(
168    registry: &Registry<S>,
169    cli_type: &CliType,
170    args: &[OsString],
171    provider: Option<String>,
172    timeout: Option<std::time::Duration>,
173    output_strategy: OutputStrategy,
174) -> Result<(i32, Option<String>), ProcessError> {
175    let is_capture_mode = matches!(output_strategy, OutputStrategy::Capture(_));
176
177    platform::init_platform();
178
179    let terminate_wrapper = |pid: u32| {
180        platform::terminate_process(pid);
181        Ok(())
182    };
183    registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
184
185    // Load provider configuration
186    let provider_manager = ProviderManager::new()
187        .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
188
189    // Determine which provider to use
190    let (provider_name, provider_config) = if let Some(name) = provider {
191        let config = provider_manager
192            .get_provider(&name)
193            .map_err(|e| ProcessError::Other(e.to_string()))?;
194        (name, config)
195    } else {
196        let (name, config) = provider_manager
197            .get_default_provider()
198            .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
199        (name, config)
200    };
201
202    // Validate compatibility
203    let _ai_type = match cli_type {
204        CliType::Claude => AiType::Claude,
205        CliType::Codex => AiType::Codex,
206        CliType::Gemini => AiType::Gemini,
207    };
208
209    // Display provider info if not using official
210    if provider_name != *"official" {
211        eprintln!(
212            "Using provider: {} ({})",
213            provider_name,
214            provider_config.summary()
215        );
216    }
217
218    let cli_command = get_cli_command(cli_type).await?;
219
220    let mut command = Command::new(&cli_command);
221    command.args(args);
222    command.stdin(Stdio::null());
223    command.stdout(Stdio::piped());
224    command.stderr(Stdio::piped());
225
226    // Platform-specific command preparation
227    #[cfg(unix)]
228    {
229        unsafe {
230            command.pre_exec(|| {
231                let result = libc::setpgid(0, 0);
232                if result != 0 {
233                    return Err(io::Error::last_os_error());
234                }
235                #[cfg(target_os = "linux")]
236                {
237                    let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
238                    if result != 0 {
239                        return Err(io::Error::last_os_error());
240                    }
241                }
242                Ok(())
243            });
244        }
245    }
246
247    // Inject environment variables
248    for (key, value) in &provider_config.env {
249        command.env(key, value);
250    }
251
252    let mut child = command.spawn()?;
253    let child_pid = child
254        .id()
255        .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
256
257    let log_path = match generate_log_path(child_pid) {
258        Ok(path) => path,
259        Err(err) => {
260            platform::terminate_process(child_pid);
261            let _ = child.wait();
262            return Err(err.into());
263        }
264    };
265
266    let log_file = match OpenOptions::new()
267        .create(true)
268        .write(true)
269        .truncate(true)
270        .open(&log_path)
271        .await
272    {
273        Ok(file) => file,
274        Err(err) => {
275            platform::terminate_process(child_pid);
276            let _ = child.wait().await;
277            return Err(err.into());
278        }
279    };
280
281    debug(format!(
282        "Started {} process pid={} log={}{}",
283        cli_type.display_name(),
284        child_pid,
285        log_path.display(),
286        if is_capture_mode {
287            " (capture mode)"
288        } else {
289            ""
290        }
291    ));
292
293    #[cfg(windows)]
294    {
295        let _resources = ChildResources::with_job(None);
296    }
297
298    let signal_guard = signal::install(child_pid)?;
299
300    let log_writer = Arc::new(Mutex::new(BufWriter::new(log_file)));
301    let mut copy_handles = Vec::new();
302
303    // Handle stdout based on strategy
304    if let Some(stdout) = child.stdout.take() {
305        match &output_strategy {
306            OutputStrategy::Mirror => {
307                copy_handles.push(tokio::spawn(spawn_copy(
308                    stdout,
309                    log_writer.clone(),
310                    StreamMirror::Stdout,
311                )));
312            }
313            OutputStrategy::Capture(buffer) => {
314                let buffer_clone = buffer.clone();
315                let writer_clone = log_writer.clone();
316                copy_handles.push(tokio::spawn(async move {
317                    spawn_copy_with_capture(stdout, writer_clone, buffer_clone).await
318                }));
319            }
320        }
321    }
322
323    if let Some(stderr) = child.stderr.take() {
324        copy_handles.push(tokio::spawn(spawn_copy(
325            stderr,
326            log_writer.clone(),
327            StreamMirror::Stderr,
328        )));
329    }
330
331    let registration_guard = {
332        let mut record = TaskRecord::new(
333            Utc::now(),
334            child_pid.to_string(),
335            log_path.to_string_lossy().into_owned(),
336            Some(platform::current_pid()),
337        );
338
339        // Get process tree information
340        match ProcessTreeInfo::current() {
341            Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
342                Ok(updated) => {
343                    record = updated;
344                }
345                Err(err) => {
346                    warn(format!("Failed to attach process tree info: {}", err));
347                }
348            },
349            Err(err) => {
350                warn(format!("Failed to get process tree info: {}", err));
351            }
352        }
353
354        if let Err(err) = registry.register(child_pid, &record) {
355            platform::terminate_process(child_pid);
356            let _ = child.wait().await;
357            return Err(err.into());
358        }
359        Some(RegistrationGuard::new(registry, child_pid))
360    };
361
362    // Wait with optional timeout
363    let status = if let Some(timeout_duration) = timeout {
364        tokio::select! {
365            result = child.wait() => result?,
366            _ = tokio::time::sleep(timeout_duration) => {
367                platform::terminate_process(child_pid);
368                let _ = child.wait().await;
369                return Err(ProcessError::Other(format!(
370                    "CLI execution timed out after {:?}",
371                    timeout_duration
372                )));
373            }
374        }
375    } else {
376        child.wait().await?
377    };
378
379    drop(signal_guard);
380
381    for handle in copy_handles {
382        match handle.await {
383            Ok(result) => result?,
384            Err(_) => {
385                return Err(io::Error::other("Log writer task failed").into());
386            }
387        }
388    }
389
390    {
391        let mut writer = log_writer.lock().await;
392        writer.flush().await?;
393        writer.get_ref().sync_all().await?;
394    }
395
396    // Display log file path to user (not in capture mode)
397    if !is_capture_mode {
398        eprintln!("完整日志已保存到: {}", log_path.display());
399    }
400
401    if let Some(guard) = registration_guard {
402        let completed_at = Utc::now();
403        let exit_code = status.code();
404        let result = match (status.success(), exit_code) {
405            (true, _) => Some(if is_capture_mode {
406                "codegen_success".to_owned()
407            } else {
408                "success".to_owned()
409            }),
410            (false, Some(code)) => Some(format!(
411                "{}_failed_with_exit_code_{code}",
412                if is_capture_mode { "codegen" } else { "cli" }
413            )),
414            (false, None) => Some(format!(
415                "{}_failed_without_exit_code",
416                if is_capture_mode { "codegen" } else { "cli" }
417            )),
418        };
419        let _ = guard.mark_completed(result, exit_code, completed_at);
420    }
421
422    // Extract captured output if in capture mode
423    let captured_output = if let OutputStrategy::Capture(buffer) = output_strategy {
424        let output = buffer.lock().await.clone();
425        let output_str = String::from_utf8_lossy(&output).to_string();
426
427        if !status.success() {
428            return Err(ProcessError::Other(format!(
429                "{} CLI failed with exit code {}: {}",
430                cli_type.display_name(),
431                extract_exit_code(status),
432                output_str
433            )));
434        }
435
436        Some(output_str)
437    } else {
438        None
439    };
440
441    Ok((extract_exit_code(status), captured_output))
442}
443
/// Build the log file path for the child with `pid`, creating the log
/// directory if needed.
///
/// The path is `temp_dir()/.aiw/logs/{pid}.log`, e.g. `/tmp/.aiw/logs/` on
/// Linux or `%TEMP%\.aiw\logs\` on Windows. Runtime data (logs, temp files)
/// lives under `temp_dir()/.aiw/`; persistent configuration lives in `~/.aiw/`.
/// Whether the temp directory is cleared on reboot depends on the OS.
///
/// Security on Unix:
/// - Missing directories are created with mode `0700` directly (subject to
///   the umask), so there is no window with default permissions.
/// - The log directory is reset to `0700` on every call, not only when it is
///   first created.
/// - The temp directory is shared between users. The call therefore fails if
///   the log directory is a symlink. It also fails if the directory belongs to
///   another user, because only the owner may change its permissions. A
///   pre-planted directory is rejected rather than written into.
///
/// On other platforms the directory is created with default permissions.
///
/// # Errors
/// Returns any I/O error from creating, inspecting or re-permissioning the
/// directory. Also returns an error if the log directory is not a real
/// directory.
fn generate_log_path(pid: u32) -> io::Result<PathBuf> {
    let log_dir = std::env::temp_dir().join(".aiw").join("logs");

    #[cfg(unix)]
    {
        use std::os::unix::fs::{DirBuilderExt, PermissionsExt};

        // Create any missing components with restrictive permissions up front.
        std::fs::DirBuilder::new()
            .recursive(true)
            .mode(0o700)
            .create(&log_dir)?;

        // Do not follow a symlink someone may have planted in the shared temp dir.
        let metadata = std::fs::symlink_metadata(&log_dir)?;
        if !metadata.is_dir() {
            return Err(io::Error::other(format!(
                "log directory {} is not a real directory",
                log_dir.display()
            )));
        }

        // Enforce rwx------ even if the directory already existed. chmod only
        // succeeds for the owner (or root), so this also rejects other users' dirs.
        std::fs::set_permissions(&log_dir, std::fs::Permissions::from_mode(0o700))?;
    }

    #[cfg(not(unix))]
    {
        std::fs::create_dir_all(&log_dir)?;
    }

    Ok(log_dir.join(format!("{pid}.log")))
}
475
476#[derive(Copy, Clone)]
477enum StreamMirror {
478    Stdout,
479    Stderr,
480}
481
482impl StreamMirror {
483    async fn write(self, data: &[u8]) -> io::Result<()> {
484        use tokio::io::AsyncWriteExt;
485        match self {
486            StreamMirror::Stdout => {
487                let mut handle = tokio::io::stdout();
488                handle.write_all(data).await?;
489                handle.flush().await
490            }
491            StreamMirror::Stderr => {
492                let mut handle = tokio::io::stderr();
493                handle.write_all(data).await?;
494                handle.flush().await
495            }
496        }
497    }
498}
499
500async fn spawn_copy<R>(
501    mut reader: R,
502    writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
503    mirror: StreamMirror,
504) -> io::Result<()>
505where
506    R: AsyncRead + Unpin + Send + 'static,
507{
508    use tokio::io::AsyncReadExt;
509
510    let mut buffer = [0u8; 8192];
511    loop {
512        let read = reader.read(&mut buffer).await?;
513        if read == 0 {
514            break;
515        }
516        let chunk = &buffer[..read];
517        {
518            let mut guard = writer.lock().await;
519            guard.write_all(chunk).await?;
520            guard.flush().await?;
521        }
522        mirror.write(chunk).await?;
523    }
524    Ok(())
525}
526
527/// Copy stream to log file and capture to buffer (for code generation)
528async fn spawn_copy_with_capture<R>(
529    mut reader: R,
530    writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
531    capture_buffer: Arc<Mutex<Vec<u8>>>,
532) -> io::Result<()>
533where
534    R: AsyncRead + Unpin + Send + 'static,
535{
536    use tokio::io::AsyncReadExt;
537
538    let mut buffer = [0u8; 8192];
539    loop {
540        let read = reader.read(&mut buffer).await?;
541        if read == 0 {
542            break;
543        }
544        let chunk = &buffer[..read];
545
546        // Write to log file
547        {
548            let mut guard = writer.lock().await;
549            guard.write_all(chunk).await?;
550            guard.flush().await?;
551        }
552
553        // Capture to buffer
554        {
555            let mut capture = capture_buffer.lock().await;
556            capture.extend_from_slice(chunk);
557        }
558    }
559    Ok(())
560}
561
/// Convert a child's [`ExitStatus`] into a numeric exit code.
///
/// A status without a code maps to `1`. On Unix this happens when the
/// process was terminated by a signal.
fn extract_exit_code(status: ExitStatus) -> i32 {
    const CODE_WHEN_MISSING: i32 = 1;
    status.code().unwrap_or(CODE_WHEN_MISSING)
}
565
566struct RegistrationGuard<'a, S: TaskStorage> {
567    registry: &'a Registry<S>,
568    pid: u32,
569    active: bool,
570}
571
572impl<'a, S: TaskStorage> RegistrationGuard<'a, S> {
573    fn new(registry: &'a Registry<S>, pid: u32) -> Self {
574        Self {
575            registry,
576            pid,
577            active: true,
578        }
579    }
580
581    fn mark_completed(
582        mut self,
583        result: Option<String>,
584        exit_code: Option<i32>,
585        completed_at: DateTime<Utc>,
586    ) -> Result<(), RegistryError> {
587        if self.active {
588            self.registry
589                .mark_completed(self.pid, result, exit_code, completed_at)?;
590            self.active = false;
591        }
592        Ok(())
593    }
594}
595
596impl<S: TaskStorage> Drop for RegistrationGuard<'_, S> {
597    fn drop(&mut self) {
598        // 注意:TaskStorage trait不提供remove方法
599        // 如果需要清理,应该通过mark_completed或sweep_stale_entries
600        // 这里我们什么都不做,让任务记录保留在注册表中
601    }
602}
603
604/// Start interactive CLI mode (directly launch AI CLI without task prompt)
605pub async fn start_interactive_cli<S: TaskStorage>(
606    registry: &Registry<S>,
607    cli_type: &CliType,
608    provider: Option<String>,
609) -> Result<i32, ProcessError> {
610    platform::init_platform();
611
612    let terminate_wrapper = |pid: u32| {
613        platform::terminate_process(pid);
614        Ok(())
615    };
616    registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
617
618    // Load provider configuration
619    let provider_manager = ProviderManager::new()
620        .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
621
622    // Determine which provider to use
623    let (provider_name, provider_config) = if let Some(name) = provider {
624        let config = provider_manager
625            .get_provider(&name)
626            .map_err(|e| ProcessError::Other(e.to_string()))?;
627        (name, config)
628    } else {
629        let (name, config) = provider_manager
630            .get_default_provider()
631            .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
632        (name, config)
633    };
634
635    // Validate compatibility
636    let _ai_type = match cli_type {
637        CliType::Claude => AiType::Claude,
638        CliType::Codex => AiType::Codex,
639        CliType::Gemini => AiType::Gemini,
640    };
641
642    // Display provider info if not using official
643    if provider_name != *"official" {
644        eprintln!(
645            "Using provider: {} ({})",
646            provider_name,
647            provider_config.summary()
648        );
649    }
650
651    let cli_command = get_cli_command(cli_type).await?;
652
653    // Interactive mode: launch CLI with stdin/stdout/stderr inherited
654    let mut command = Command::new(&cli_command);
655
656    // Add interactive args (e.g., "exec" for Codex, "-p" for Claude)
657    let interactive_args = cli_type.build_interactive_args();
658    command.args(&interactive_args);
659
660    command.stdin(Stdio::inherit());
661    command.stdout(Stdio::inherit());
662    command.stderr(Stdio::inherit());
663
664    // Platform-specific command preparation (Unix: set process group and death signal)
665    #[cfg(unix)]
666    {
667        unsafe {
668            command.pre_exec(|| {
669                // Set process group ID
670                let result = libc::setpgid(0, 0);
671                if result != 0 {
672                    return Err(io::Error::last_os_error());
673                }
674                // Set parent death signal on Linux
675                #[cfg(target_os = "linux")]
676                {
677                    let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
678                    if result != 0 {
679                        return Err(io::Error::last_os_error());
680                    }
681                }
682                Ok(())
683            });
684        }
685    }
686
687    // Inject environment variables
688    for (key, value) in &provider_config.env {
689        command.env(key, value);
690    }
691
692    let mut child = command.spawn()?;
693    let child_pid = child
694        .id()
695        .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
696
697    // Register the interactive CLI process
698    let log_path = generate_log_path(child_pid)?;
699    let record = TaskRecord::new(
700        Utc::now(),
701        child_pid.to_string(),
702        log_path.to_string_lossy().into_owned(),
703        Some(platform::current_pid()),
704    );
705
706    // Get process tree information
707    let record = match ProcessTreeInfo::current() {
708        Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
709            Ok(updated) => updated,
710            Err(err) => {
711                warn(format!("Failed to attach process tree info: {}", err));
712                record
713            }
714        },
715        Err(err) => {
716            warn(format!("Failed to get process tree info: {}", err));
717            record
718        }
719    };
720
721    if let Err(err) = registry.register(child_pid, &record) {
722        platform::terminate_process(child_pid);
723        let _ = child.wait().await;
724        return Err(err.into());
725    }
726
727    let registration_guard = RegistrationGuard::new(registry, child_pid);
728    let signal_guard = signal::install(child_pid)?;
729
730    let status = child.wait().await?;
731    drop(signal_guard);
732
733    // Mark as completed
734    let completed_at = Utc::now();
735    let exit_code = status.code();
736    let result = match (status.success(), exit_code) {
737        (true, _) => Some("interactive_session_completed".to_owned()),
738        (false, Some(code)) => Some(format!("interactive_session_failed_with_exit_code_{code}")),
739        (false, None) => Some("interactive_session_failed_without_exit_code".to_owned()),
740    };
741    let _ = registration_guard.mark_completed(result, exit_code, completed_at);
742
743    Ok(extract_exit_code(status))
744}
745
746/// Execute multiple CLI processes (for codex|claude|gemini syntax)
747pub async fn execute_multiple_clis<S: TaskStorage>(
748    registry: &Registry<S>,
749    cli_selector: &crate::cli_type::CliSelector,
750    task_prompt: &str,
751    provider: Option<String>,
752) -> Result<Vec<i32>, ProcessError> {
753    let mut exit_codes = Vec::new();
754
755    for cli_type in &cli_selector.types {
756        let cli_args = cli_type.build_full_access_args(task_prompt);
757        let os_args: Vec<OsString> = cli_args.into_iter().map(|s| s.into()).collect();
758
759        let exit_code = execute_cli(registry, cli_type, &os_args, provider.clone()).await?;
760        exit_codes.push(exit_code);
761    }
762
763    Ok(exit_codes)
764}