agentic_warden/
supervisor.rs

1use crate::cli_type::CliType;
2use crate::core::models::ProcessTreeInfo;
3use crate::core::process_tree::ProcessTreeError;
4use crate::error::RegistryError;
5use crate::logging::debug;
6use crate::logging::warn;
7#[cfg(windows)]
8use crate::platform::ChildResources;
9use crate::platform::{self};
10use crate::provider::{AiType, ProviderManager};
11use crate::signal;
12use crate::storage::TaskStorage;
13use crate::task_record::TaskRecord;
14use crate::unified_registry::Registry;
15use chrono::{DateTime, Utc};
16use std::env;
17use std::ffi::OsString;
18use std::io;
19use std::path::PathBuf;
20use std::process::{ExitStatus, Stdio};
21use std::sync::Arc;
22use thiserror::Error;
23use tokio::fs::OpenOptions;
24use tokio::io::{AsyncRead, AsyncWriteExt, BufWriter};
25use tokio::process::Command;
26use tokio::sync::Mutex;
27
28#[derive(Debug, Error)]
29pub enum ProcessError {
30    #[error("IO error: {0}")]
31    Io(#[from] io::Error),
32    #[error("Registry error: {0}")]
33    Registry(#[from] RegistryError),
34    #[error("Process tree error: {0}")]
35    ProcessTree(#[from] ProcessTreeError),
36    #[error("CLI executable not found: {0}")]
37    CliNotFound(String),
38    #[error("{0}")]
39    Other(String),
40}
41
42async fn get_cli_command(cli_type: &CliType) -> Result<String, ProcessError> {
43    // First try environment variable
44    if let Ok(custom_path) = env::var(cli_type.env_var_name()) {
45        if custom_path.is_empty() {
46            return Err(ProcessError::CliNotFound(format!(
47                "{} environment variable is empty",
48                cli_type.env_var_name()
49            )));
50        }
51        return Ok(custom_path);
52    }
53
54    // Fall back to default command name
55    let default_cmd = cli_type.command_name();
56
57    // On Windows, try to find the actual executable path
58    if cfg!(windows) {
59        let output = Command::new("where")
60            .arg(default_cmd)
61            .output()
62            .await
63            .map_err(|_| {
64                ProcessError::CliNotFound(format!(
65                    "Failed to check if '{}' exists in PATH",
66                    default_cmd
67                ))
68            })?;
69
70        if output.status.success() {
71            let stdout = String::from_utf8_lossy(&output.stdout);
72            // Prefer .cmd files on Windows, otherwise use first result
73            for line in stdout.lines() {
74                if line.ends_with(".cmd") || line.ends_with(".bat") || line.ends_with(".exe") {
75                    return Ok(line.to_string());
76                }
77            }
78            // Fallback to first line if no Windows executable found
79            if let Some(first_line) = stdout.lines().next() {
80                return Ok(first_line.to_string());
81            }
82        }
83
84        return Err(ProcessError::CliNotFound(format!(
85            "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
86            default_cmd,
87            cli_type.env_var_name()
88        )));
89    } else {
90        let output = Command::new("which")
91            .arg(default_cmd)
92            .output()
93            .await
94            .map_err(|_| {
95                ProcessError::CliNotFound(format!(
96                    "Failed to check if '{}' exists in PATH",
97                    default_cmd
98                ))
99            })?;
100
101        if !output.status.success() {
102            return Err(ProcessError::CliNotFound(format!(
103                "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
104                default_cmd,
105                cli_type.env_var_name()
106            )));
107        }
108    }
109
110    Ok(default_cmd.to_string())
111}
112
113/// Output handling strategy for CLI execution
114enum OutputStrategy {
115    /// Mirror output to stdout/stderr
116    Mirror,
117    /// Capture stdout to buffer
118    Capture(Arc<Mutex<Vec<u8>>>),
119}
120
121pub async fn execute_cli<S: TaskStorage>(
122    registry: &Registry<S>,
123    cli_type: &CliType,
124    args: &[OsString],
125    provider: Option<String>,
126) -> Result<i32, ProcessError> {
127    execute_cli_internal(registry, cli_type, args, provider, None, OutputStrategy::Mirror)
128        .await
129        .map(|(exit_code, _)| exit_code)
130}
131
132/// Execute CLI and capture stdout output (for code generation)
133pub async fn execute_cli_with_output<S: TaskStorage>(
134    registry: &Registry<S>,
135    cli_type: &CliType,
136    args: &[OsString],
137    provider: Option<String>,
138    timeout: std::time::Duration,
139) -> Result<String, ProcessError> {
140    let buffer = Arc::new(Mutex::new(Vec::new()));
141    let (_, output_opt) = execute_cli_internal(
142        registry,
143        cli_type,
144        args,
145        provider,
146        Some(timeout),
147        OutputStrategy::Capture(buffer.clone()),
148    )
149    .await?;
150
151    match output_opt {
152        Some(output) => Ok(output),
153        None => Err(ProcessError::Other(
154            "Output capture failed unexpectedly".to_string(),
155        )),
156    }
157}
158
159/// Internal CLI execution with configurable output handling
160async fn execute_cli_internal<S: TaskStorage>(
161    registry: &Registry<S>,
162    cli_type: &CliType,
163    args: &[OsString],
164    provider: Option<String>,
165    timeout: Option<std::time::Duration>,
166    output_strategy: OutputStrategy,
167) -> Result<(i32, Option<String>), ProcessError> {
168    let is_capture_mode = matches!(output_strategy, OutputStrategy::Capture(_));
169
170    platform::init_platform();
171
172    let terminate_wrapper = |pid: u32| {
173        platform::terminate_process(pid);
174        Ok(())
175    };
176    registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
177
178    // Load provider configuration
179    let provider_manager = ProviderManager::new()
180        .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
181
182    // Determine which provider to use
183    let (provider_name, provider_config) = if let Some(name) = provider {
184        let config = provider_manager
185            .get_provider(&name)
186            .map_err(|e| ProcessError::Other(e.to_string()))?;
187        (name, config)
188    } else {
189        let (name, config) = provider_manager
190            .get_default_provider()
191            .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
192        (name, config)
193    };
194
195    // Validate compatibility
196    let _ai_type = match cli_type {
197        CliType::Claude => AiType::Claude,
198        CliType::Codex => AiType::Codex,
199        CliType::Gemini => AiType::Gemini,
200    };
201
202    // Display provider info if not using official
203    if provider_name != *"official" {
204        eprintln!(
205            "Using provider: {} ({})",
206            provider_name,
207            provider_config.summary()
208        );
209    }
210
211    let cli_command = get_cli_command(cli_type).await?;
212
213    let mut command = Command::new(&cli_command);
214    command.args(args);
215    command.stdin(Stdio::null());
216    command.stdout(Stdio::piped());
217    command.stderr(Stdio::piped());
218
219    // Platform-specific command preparation
220    #[cfg(unix)]
221    {
222        unsafe {
223            command.pre_exec(|| {
224                let result = libc::setpgid(0, 0);
225                if result != 0 {
226                    return Err(io::Error::last_os_error());
227                }
228                #[cfg(target_os = "linux")]
229                {
230                    let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
231                    if result != 0 {
232                        return Err(io::Error::last_os_error());
233                    }
234                }
235                Ok(())
236            });
237        }
238    }
239
240    // Inject environment variables
241    for (key, value) in &provider_config.env {
242        command.env(key, value);
243    }
244
245    let mut child = command.spawn()?;
246    let child_pid = child
247        .id()
248        .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
249
250    let log_path = match generate_log_path(child_pid) {
251        Ok(path) => path,
252        Err(err) => {
253            platform::terminate_process(child_pid);
254            let _ = child.wait();
255            return Err(err.into());
256        }
257    };
258
259    let log_file = match OpenOptions::new()
260        .create(true)
261        .write(true)
262        .truncate(true)
263        .open(&log_path)
264        .await
265    {
266        Ok(file) => file,
267        Err(err) => {
268            platform::terminate_process(child_pid);
269            let _ = child.wait().await;
270            return Err(err.into());
271        }
272    };
273
274    debug(format!(
275        "Started {} process pid={} log={}{}",
276        cli_type.display_name(),
277        child_pid,
278        log_path.display(),
279        if is_capture_mode { " (capture mode)" } else { "" }
280    ));
281
282    #[cfg(windows)]
283    {
284        let _resources = ChildResources::with_job(None);
285    }
286
287    let signal_guard = signal::install(child_pid)?;
288
289    let log_writer = Arc::new(Mutex::new(BufWriter::new(log_file)));
290    let mut copy_handles = Vec::new();
291
292    // Handle stdout based on strategy
293    if let Some(stdout) = child.stdout.take() {
294        match &output_strategy {
295            OutputStrategy::Mirror => {
296                copy_handles.push(tokio::spawn(spawn_copy(
297                    stdout,
298                    log_writer.clone(),
299                    StreamMirror::Stdout,
300                )));
301            }
302            OutputStrategy::Capture(buffer) => {
303                let buffer_clone = buffer.clone();
304                let writer_clone = log_writer.clone();
305                copy_handles.push(tokio::spawn(async move {
306                    spawn_copy_with_capture(stdout, writer_clone, buffer_clone).await
307                }));
308            }
309        }
310    }
311
312    if let Some(stderr) = child.stderr.take() {
313        copy_handles.push(tokio::spawn(spawn_copy(
314            stderr,
315            log_writer.clone(),
316            StreamMirror::Stderr,
317        )));
318    }
319
320    let registration_guard = {
321        let mut record = TaskRecord::new(
322            Utc::now(),
323            child_pid.to_string(),
324            log_path.to_string_lossy().into_owned(),
325            Some(platform::current_pid()),
326        );
327
328        // Get process tree information
329        match ProcessTreeInfo::current() {
330            Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
331                Ok(updated) => {
332                    record = updated;
333                }
334                Err(err) => {
335                    warn(format!("Failed to attach process tree info: {}", err));
336                }
337            },
338            Err(err) => {
339                warn(format!("Failed to get process tree info: {}", err));
340            }
341        }
342
343        if let Err(err) = registry.register(child_pid, &record) {
344            platform::terminate_process(child_pid);
345            let _ = child.wait().await;
346            return Err(err.into());
347        }
348        Some(RegistrationGuard::new(registry, child_pid))
349    };
350
351    // Wait with optional timeout
352    let status = if let Some(timeout_duration) = timeout {
353        tokio::select! {
354            result = child.wait() => result?,
355            _ = tokio::time::sleep(timeout_duration) => {
356                platform::terminate_process(child_pid);
357                let _ = child.wait().await;
358                return Err(ProcessError::Other(format!(
359                    "CLI execution timed out after {:?}",
360                    timeout_duration
361                )));
362            }
363        }
364    } else {
365        child.wait().await?
366    };
367
368    drop(signal_guard);
369
370    for handle in copy_handles {
371        match handle.await {
372            Ok(result) => result?,
373            Err(_) => {
374                return Err(io::Error::other("Log writer task failed").into());
375            }
376        }
377    }
378
379    {
380        let mut writer = log_writer.lock().await;
381        writer.flush().await?;
382        writer.get_ref().sync_all().await?;
383    }
384
385    // Display log file path to user (not in capture mode)
386    if !is_capture_mode {
387        eprintln!("完整日志已保存到: {}", log_path.display());
388    }
389
390    if let Some(guard) = registration_guard {
391        let completed_at = Utc::now();
392        let exit_code = status.code();
393        let result = match (status.success(), exit_code) {
394            (true, _) => Some(if is_capture_mode {
395                "codegen_success".to_owned()
396            } else {
397                "success".to_owned()
398            }),
399            (false, Some(code)) => Some(format!(
400                "{}_failed_with_exit_code_{code}",
401                if is_capture_mode { "codegen" } else { "cli" }
402            )),
403            (false, None) => Some(format!(
404                "{}_failed_without_exit_code",
405                if is_capture_mode { "codegen" } else { "cli" }
406            )),
407        };
408        let _ = guard.mark_completed(result, exit_code, completed_at);
409    }
410
411    // Extract captured output if in capture mode
412    let captured_output = if let OutputStrategy::Capture(buffer) = output_strategy {
413        let output = buffer.lock().await.clone();
414        let output_str = String::from_utf8_lossy(&output).to_string();
415
416        if !status.success() {
417            return Err(ProcessError::Other(format!(
418                "{} CLI failed with exit code {}: {}",
419                cli_type.display_name(),
420                extract_exit_code(status),
421                output_str
422            )));
423        }
424
425        Some(output_str)
426    } else {
427        None
428    };
429
430    Ok((extract_exit_code(status), captured_output))
431}
432
/// Return the log-file path for child process `pid`, creating the log directory if needed.
///
/// Logs live in `std::env::temp_dir()/.aiw/logs/<pid>.log`, e.g. `/tmp/.aiw/logs/` on
/// Linux/macOS and `%TEMP%\.aiw\logs\` on Windows. Runtime data goes in the temp
/// directory; persistent configuration lives in `~/.aiw/`. Only the directory is created
/// here; the caller opens the file itself.
///
/// Security (Unix): the directory has a predictable name in a shared location, so:
/// - it is created with mode `0700`, and so is any missing `.aiw` parent;
/// - it is rejected if it is a symlink or not a directory;
/// - if it already existed with other permissions, it is forced back to `0700`. This fails
///   with a permission error when the directory belongs to another user, instead of
///   silently writing logs into it.
///
/// # Errors
/// Any I/O error from creating, inspecting or re-permissioning the directory. Returns
/// `PermissionDenied` when the log directory is a symlink or not a directory.
fn generate_log_path(pid: u32) -> io::Result<PathBuf> {
    let log_dir = std::env::temp_dir().join(".aiw").join("logs");

    #[cfg(unix)]
    {
        use std::os::unix::fs::{DirBuilderExt, PermissionsExt};

        // Creates every missing component with 0700 up front (no create-then-chmod
        // window). This is a no-op when the directory already exists.
        std::fs::DirBuilder::new()
            .recursive(true)
            .mode(0o700)
            .create(&log_dir)?;

        // Inspect the entry itself, not the target of a symlink someone may have planted.
        let metadata = std::fs::symlink_metadata(&log_dir)?;
        if metadata.file_type().is_symlink() || !metadata.is_dir() {
            return Err(io::Error::new(
                io::ErrorKind::PermissionDenied,
                format!(
                    "refusing to use log directory {}: not a real directory",
                    log_dir.display()
                ),
            ));
        }

        // Enforce rwx------ even when the directory pre-existed.
        if metadata.permissions().mode() & 0o777 != 0o700 {
            std::fs::set_permissions(&log_dir, std::fs::Permissions::from_mode(0o700))?;
        }
    }

    #[cfg(not(unix))]
    {
        std::fs::create_dir_all(&log_dir)?;
    }

    Ok(log_dir.join(format!("{pid}.log")))
}
464
465#[derive(Copy, Clone)]
466enum StreamMirror {
467    Stdout,
468    Stderr,
469}
470
471impl StreamMirror {
472    async fn write(self, data: &[u8]) -> io::Result<()> {
473        use tokio::io::AsyncWriteExt;
474        match self {
475            StreamMirror::Stdout => {
476                let mut handle = tokio::io::stdout();
477                handle.write_all(data).await?;
478                handle.flush().await
479            }
480            StreamMirror::Stderr => {
481                let mut handle = tokio::io::stderr();
482                handle.write_all(data).await?;
483                handle.flush().await
484            }
485        }
486    }
487}
488
489async fn spawn_copy<R>(
490    mut reader: R,
491    writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
492    mirror: StreamMirror,
493) -> io::Result<()>
494where
495    R: AsyncRead + Unpin + Send + 'static,
496{
497    use tokio::io::AsyncReadExt;
498
499    let mut buffer = [0u8; 8192];
500    loop {
501        let read = reader.read(&mut buffer).await?;
502        if read == 0 {
503            break;
504        }
505        let chunk = &buffer[..read];
506        {
507            let mut guard = writer.lock().await;
508            guard.write_all(chunk).await?;
509            guard.flush().await?;
510        }
511        mirror.write(chunk).await?;
512    }
513    Ok(())
514}
515
516/// Copy stream to log file and capture to buffer (for code generation)
517async fn spawn_copy_with_capture<R>(
518    mut reader: R,
519    writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
520    capture_buffer: Arc<Mutex<Vec<u8>>>,
521) -> io::Result<()>
522where
523    R: AsyncRead + Unpin + Send + 'static,
524{
525    use tokio::io::AsyncReadExt;
526
527    let mut buffer = [0u8; 8192];
528    loop {
529        let read = reader.read(&mut buffer).await?;
530        if read == 0 {
531            break;
532        }
533        let chunk = &buffer[..read];
534
535        // Write to log file
536        {
537            let mut guard = writer.lock().await;
538            guard.write_all(chunk).await?;
539            guard.flush().await?;
540        }
541
542        // Capture to buffer
543        {
544            let mut capture = capture_buffer.lock().await;
545            capture.extend_from_slice(chunk);
546        }
547    }
548    Ok(())
549}
550
/// Convert a child's exit status into an integer exit code.
///
/// Falls back to `1` when the status carries no code (e.g. on Unix, when the process was
/// terminated by a signal).
fn extract_exit_code(status: ExitStatus) -> i32 {
    match status.code() {
        Some(code) => code,
        None => 1,
    }
}
554
555struct RegistrationGuard<'a, S: TaskStorage> {
556    registry: &'a Registry<S>,
557    pid: u32,
558    active: bool,
559}
560
561impl<'a, S: TaskStorage> RegistrationGuard<'a, S> {
562    fn new(registry: &'a Registry<S>, pid: u32) -> Self {
563        Self {
564            registry,
565            pid,
566            active: true,
567        }
568    }
569
570    fn mark_completed(
571        mut self,
572        result: Option<String>,
573        exit_code: Option<i32>,
574        completed_at: DateTime<Utc>,
575    ) -> Result<(), RegistryError> {
576        if self.active {
577            self.registry
578                .mark_completed(self.pid, result, exit_code, completed_at)?;
579            self.active = false;
580        }
581        Ok(())
582    }
583}
584
585impl<S: TaskStorage> Drop for RegistrationGuard<'_, S> {
586    fn drop(&mut self) {
587        // 注意:TaskStorage trait不提供remove方法
588        // 如果需要清理,应该通过mark_completed或sweep_stale_entries
589        // 这里我们什么都不做,让任务记录保留在注册表中
590    }
591}
592
593/// Start interactive CLI mode (directly launch AI CLI without task prompt)
594pub async fn start_interactive_cli<S: TaskStorage>(
595    registry: &Registry<S>,
596    cli_type: &CliType,
597    provider: Option<String>,
598) -> Result<i32, ProcessError> {
599    platform::init_platform();
600
601    let terminate_wrapper = |pid: u32| {
602        platform::terminate_process(pid);
603        Ok(())
604    };
605    registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
606
607    // Load provider configuration
608    let provider_manager = ProviderManager::new()
609        .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
610
611    // Determine which provider to use
612    let (provider_name, provider_config) = if let Some(name) = provider {
613        let config = provider_manager
614            .get_provider(&name)
615            .map_err(|e| ProcessError::Other(e.to_string()))?;
616        (name, config)
617    } else {
618        let (name, config) = provider_manager
619            .get_default_provider()
620            .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
621        (name, config)
622    };
623
624    // Validate compatibility
625    let _ai_type = match cli_type {
626        CliType::Claude => AiType::Claude,
627        CliType::Codex => AiType::Codex,
628        CliType::Gemini => AiType::Gemini,
629    };
630
631    // Display provider info if not using official
632    if provider_name != *"official" {
633        eprintln!(
634            "Using provider: {} ({})",
635            provider_name,
636            provider_config.summary()
637        );
638    }
639
640    let cli_command = get_cli_command(cli_type).await?;
641
642    // Interactive mode: launch CLI with stdin/stdout/stderr inherited
643    let mut command = Command::new(&cli_command);
644
645    // Add interactive args (e.g., "exec" for Codex, "-p" for Claude)
646    let interactive_args = cli_type.build_interactive_args();
647    command.args(&interactive_args);
648
649    command.stdin(Stdio::inherit());
650    command.stdout(Stdio::inherit());
651    command.stderr(Stdio::inherit());
652
653    // Platform-specific command preparation (Unix: set process group and death signal)
654    #[cfg(unix)]
655    {
656        unsafe {
657            command.pre_exec(|| {
658                // Set process group ID
659                let result = libc::setpgid(0, 0);
660                if result != 0 {
661                    return Err(io::Error::last_os_error());
662                }
663                // Set parent death signal on Linux
664                #[cfg(target_os = "linux")]
665                {
666                    let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
667                    if result != 0 {
668                        return Err(io::Error::last_os_error());
669                    }
670                }
671                Ok(())
672            });
673        }
674    }
675
676    // Inject environment variables
677    for (key, value) in &provider_config.env {
678        command.env(key, value);
679    }
680
681    let mut child = command.spawn()?;
682    let child_pid = child
683        .id()
684        .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
685
686    // Register the interactive CLI process
687    let log_path = generate_log_path(child_pid)?;
688    let record = TaskRecord::new(
689        Utc::now(),
690        child_pid.to_string(),
691        log_path.to_string_lossy().into_owned(),
692        Some(platform::current_pid()),
693    );
694
695    // Get process tree information
696    let record = match ProcessTreeInfo::current() {
697        Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
698            Ok(updated) => updated,
699            Err(err) => {
700                warn(format!("Failed to attach process tree info: {}", err));
701                record
702            }
703        },
704        Err(err) => {
705            warn(format!("Failed to get process tree info: {}", err));
706            record
707        }
708    };
709
710    if let Err(err) = registry.register(child_pid, &record) {
711        platform::terminate_process(child_pid);
712        let _ = child.wait().await;
713        return Err(err.into());
714    }
715
716    let registration_guard = RegistrationGuard::new(registry, child_pid);
717    let signal_guard = signal::install(child_pid)?;
718
719    let status = child.wait().await?;
720    drop(signal_guard);
721
722    // Mark as completed
723    let completed_at = Utc::now();
724    let exit_code = status.code();
725    let result = match (status.success(), exit_code) {
726        (true, _) => Some("interactive_session_completed".to_owned()),
727        (false, Some(code)) => Some(format!("interactive_session_failed_with_exit_code_{code}")),
728        (false, None) => Some("interactive_session_failed_without_exit_code".to_owned()),
729    };
730    let _ = registration_guard.mark_completed(result, exit_code, completed_at);
731
732    Ok(extract_exit_code(status))
733}
734
735/// Execute multiple CLI processes (for codex|claude|gemini syntax)
736pub async fn execute_multiple_clis<S: TaskStorage>(
737    registry: &Registry<S>,
738    cli_selector: &crate::cli_type::CliSelector,
739    task_prompt: &str,
740    provider: Option<String>,
741) -> Result<Vec<i32>, ProcessError> {
742    let mut exit_codes = Vec::new();
743
744    for cli_type in &cli_selector.types {
745        let cli_args = cli_type.build_full_access_args(task_prompt);
746        let os_args: Vec<OsString> = cli_args.into_iter().map(|s| s.into()).collect();
747
748        let exit_code = execute_cli(registry, cli_type, &os_args, provider.clone()).await?;
749        exit_codes.push(exit_code);
750    }
751
752    Ok(exit_codes)
753}