1use crate::cli_type::CliType;
2use crate::core::models::ProcessTreeInfo;
3use crate::core::process_tree::ProcessTreeError;
4use crate::error::RegistryError;
5use crate::logging::debug;
6use crate::logging::warn;
7#[cfg(windows)]
8use crate::platform::ChildResources;
9use crate::platform::{self};
10use crate::provider::{AiType, ProviderManager};
11use crate::signal;
12use crate::storage::TaskStorage;
13use crate::task_record::TaskRecord;
14use crate::unified_registry::Registry;
15use chrono::{DateTime, Utc};
16use std::env;
17use std::ffi::OsString;
18use std::io;
19use std::path::PathBuf;
20use std::process::{ExitStatus, Stdio};
21use std::sync::Arc;
22use thiserror::Error;
23use tokio::fs::OpenOptions;
24use tokio::io::{AsyncRead, AsyncWriteExt, BufWriter};
25use tokio::process::Command;
26use tokio::sync::Mutex;
27
28#[derive(Debug, Error)]
29pub enum ProcessError {
30 #[error("IO error: {0}")]
31 Io(#[from] io::Error),
32 #[error("Registry error: {0}")]
33 Registry(#[from] RegistryError),
34 #[error("Process tree error: {0}")]
35 ProcessTree(#[from] ProcessTreeError),
36 #[error("CLI executable not found: {0}")]
37 CliNotFound(String),
38 #[error("{0}")]
39 Other(String),
40}
41
42async fn get_cli_command(cli_type: &CliType) -> Result<String, ProcessError> {
43 if let Ok(custom_path) = env::var(cli_type.env_var_name()) {
45 if custom_path.is_empty() {
46 return Err(ProcessError::CliNotFound(format!(
47 "{} environment variable is empty",
48 cli_type.env_var_name()
49 )));
50 }
51 return Ok(custom_path);
52 }
53
54 let default_cmd = cli_type.command_name();
56
57 if cfg!(windows) {
59 let output = Command::new("where")
60 .arg(default_cmd)
61 .output()
62 .await
63 .map_err(|_| {
64 ProcessError::CliNotFound(format!(
65 "Failed to check if '{}' exists in PATH",
66 default_cmd
67 ))
68 })?;
69
70 if output.status.success() {
71 let stdout = String::from_utf8_lossy(&output.stdout);
72 for line in stdout.lines() {
74 if line.ends_with(".cmd") || line.ends_with(".bat") || line.ends_with(".exe") {
75 return Ok(line.to_string());
76 }
77 }
78 if let Some(first_line) = stdout.lines().next() {
80 return Ok(first_line.to_string());
81 }
82 }
83
84 return Err(ProcessError::CliNotFound(format!(
85 "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
86 default_cmd,
87 cli_type.env_var_name()
88 )));
89 } else {
90 let output = Command::new("which")
91 .arg(default_cmd)
92 .output()
93 .await
94 .map_err(|_| {
95 ProcessError::CliNotFound(format!(
96 "Failed to check if '{}' exists in PATH",
97 default_cmd
98 ))
99 })?;
100
101 if !output.status.success() {
102 return Err(ProcessError::CliNotFound(format!(
103 "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
104 default_cmd,
105 cli_type.env_var_name()
106 )));
107 }
108 }
109
110 Ok(default_cmd.to_string())
111}
112
113enum OutputStrategy {
115 Mirror,
117 Capture(Arc<Mutex<Vec<u8>>>),
119}
120
121pub async fn execute_cli<S: TaskStorage>(
122 registry: &Registry<S>,
123 cli_type: &CliType,
124 args: &[OsString],
125 provider: Option<String>,
126) -> Result<i32, ProcessError> {
127 execute_cli_internal(
128 registry,
129 cli_type,
130 args,
131 provider,
132 None,
133 OutputStrategy::Mirror,
134 )
135 .await
136 .map(|(exit_code, _)| exit_code)
137}
138
139pub async fn execute_cli_with_output<S: TaskStorage>(
141 registry: &Registry<S>,
142 cli_type: &CliType,
143 args: &[OsString],
144 provider: Option<String>,
145 timeout: std::time::Duration,
146) -> Result<String, ProcessError> {
147 let buffer = Arc::new(Mutex::new(Vec::new()));
148 let (_, output_opt) = execute_cli_internal(
149 registry,
150 cli_type,
151 args,
152 provider,
153 Some(timeout),
154 OutputStrategy::Capture(buffer.clone()),
155 )
156 .await?;
157
158 match output_opt {
159 Some(output) => Ok(output),
160 None => Err(ProcessError::Other(
161 "Output capture failed unexpectedly".to_string(),
162 )),
163 }
164}
165
166async fn execute_cli_internal<S: TaskStorage>(
168 registry: &Registry<S>,
169 cli_type: &CliType,
170 args: &[OsString],
171 provider: Option<String>,
172 timeout: Option<std::time::Duration>,
173 output_strategy: OutputStrategy,
174) -> Result<(i32, Option<String>), ProcessError> {
175 let is_capture_mode = matches!(output_strategy, OutputStrategy::Capture(_));
176
177 platform::init_platform();
178
179 let terminate_wrapper = |pid: u32| {
180 platform::terminate_process(pid);
181 Ok(())
182 };
183 registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
184
185 let provider_manager = ProviderManager::new()
187 .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
188
189 let (provider_name, provider_config) = if let Some(name) = provider {
191 let config = provider_manager
192 .get_provider(&name)
193 .map_err(|e| ProcessError::Other(e.to_string()))?;
194 (name, config)
195 } else {
196 let (name, config) = provider_manager
197 .get_default_provider()
198 .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
199 (name, config)
200 };
201
202 let _ai_type = match cli_type {
204 CliType::Claude => AiType::Claude,
205 CliType::Codex => AiType::Codex,
206 CliType::Gemini => AiType::Gemini,
207 };
208
209 if provider_name != *"official" {
211 eprintln!(
212 "Using provider: {} ({})",
213 provider_name,
214 provider_config.summary()
215 );
216 }
217
218 let cli_command = get_cli_command(cli_type).await?;
219
220 let mut command = Command::new(&cli_command);
221 command.args(args);
222 command.stdin(Stdio::null());
223 command.stdout(Stdio::piped());
224 command.stderr(Stdio::piped());
225
226 #[cfg(unix)]
228 {
229 unsafe {
230 command.pre_exec(|| {
231 let result = libc::setpgid(0, 0);
232 if result != 0 {
233 return Err(io::Error::last_os_error());
234 }
235 #[cfg(target_os = "linux")]
236 {
237 let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
238 if result != 0 {
239 return Err(io::Error::last_os_error());
240 }
241 }
242 Ok(())
243 });
244 }
245 }
246
247 for (key, value) in &provider_config.env {
249 command.env(key, value);
250 }
251
252 let mut child = command.spawn()?;
253 let child_pid = child
254 .id()
255 .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
256
257 let log_path = match generate_log_path(child_pid) {
258 Ok(path) => path,
259 Err(err) => {
260 platform::terminate_process(child_pid);
261 let _ = child.wait();
262 return Err(err.into());
263 }
264 };
265
266 let log_file = match OpenOptions::new()
267 .create(true)
268 .write(true)
269 .truncate(true)
270 .open(&log_path)
271 .await
272 {
273 Ok(file) => file,
274 Err(err) => {
275 platform::terminate_process(child_pid);
276 let _ = child.wait().await;
277 return Err(err.into());
278 }
279 };
280
281 debug(format!(
282 "Started {} process pid={} log={}{}",
283 cli_type.display_name(),
284 child_pid,
285 log_path.display(),
286 if is_capture_mode {
287 " (capture mode)"
288 } else {
289 ""
290 }
291 ));
292
293 #[cfg(windows)]
294 {
295 let _resources = ChildResources::with_job(None);
296 }
297
298 let signal_guard = signal::install(child_pid)?;
299
300 let log_writer = Arc::new(Mutex::new(BufWriter::new(log_file)));
301 let mut copy_handles = Vec::new();
302
303 if let Some(stdout) = child.stdout.take() {
305 match &output_strategy {
306 OutputStrategy::Mirror => {
307 copy_handles.push(tokio::spawn(spawn_copy(
308 stdout,
309 log_writer.clone(),
310 StreamMirror::Stdout,
311 )));
312 }
313 OutputStrategy::Capture(buffer) => {
314 let buffer_clone = buffer.clone();
315 let writer_clone = log_writer.clone();
316 copy_handles.push(tokio::spawn(async move {
317 spawn_copy_with_capture(stdout, writer_clone, buffer_clone).await
318 }));
319 }
320 }
321 }
322
323 if let Some(stderr) = child.stderr.take() {
324 copy_handles.push(tokio::spawn(spawn_copy(
325 stderr,
326 log_writer.clone(),
327 StreamMirror::Stderr,
328 )));
329 }
330
331 let registration_guard = {
332 let mut record = TaskRecord::new(
333 Utc::now(),
334 child_pid.to_string(),
335 log_path.to_string_lossy().into_owned(),
336 Some(platform::current_pid()),
337 );
338
339 match ProcessTreeInfo::current() {
341 Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
342 Ok(updated) => {
343 record = updated;
344 }
345 Err(err) => {
346 warn(format!("Failed to attach process tree info: {}", err));
347 }
348 },
349 Err(err) => {
350 warn(format!("Failed to get process tree info: {}", err));
351 }
352 }
353
354 if let Err(err) = registry.register(child_pid, &record) {
355 platform::terminate_process(child_pid);
356 let _ = child.wait().await;
357 return Err(err.into());
358 }
359 Some(RegistrationGuard::new(registry, child_pid))
360 };
361
362 let status = if let Some(timeout_duration) = timeout {
364 tokio::select! {
365 result = child.wait() => result?,
366 _ = tokio::time::sleep(timeout_duration) => {
367 platform::terminate_process(child_pid);
368 let _ = child.wait().await;
369 return Err(ProcessError::Other(format!(
370 "CLI execution timed out after {:?}",
371 timeout_duration
372 )));
373 }
374 }
375 } else {
376 child.wait().await?
377 };
378
379 drop(signal_guard);
380
381 for handle in copy_handles {
382 match handle.await {
383 Ok(result) => result?,
384 Err(_) => {
385 return Err(io::Error::other("Log writer task failed").into());
386 }
387 }
388 }
389
390 {
391 let mut writer = log_writer.lock().await;
392 writer.flush().await?;
393 writer.get_ref().sync_all().await?;
394 }
395
396 if !is_capture_mode {
398 eprintln!("完整日志已保存到: {}", log_path.display());
399 }
400
401 if let Some(guard) = registration_guard {
402 let completed_at = Utc::now();
403 let exit_code = status.code();
404 let result = match (status.success(), exit_code) {
405 (true, _) => Some(if is_capture_mode {
406 "codegen_success".to_owned()
407 } else {
408 "success".to_owned()
409 }),
410 (false, Some(code)) => Some(format!(
411 "{}_failed_with_exit_code_{code}",
412 if is_capture_mode { "codegen" } else { "cli" }
413 )),
414 (false, None) => Some(format!(
415 "{}_failed_without_exit_code",
416 if is_capture_mode { "codegen" } else { "cli" }
417 )),
418 };
419 let _ = guard.mark_completed(result, exit_code, completed_at);
420 }
421
422 let captured_output = if let OutputStrategy::Capture(buffer) = output_strategy {
424 let output = buffer.lock().await.clone();
425 let output_str = String::from_utf8_lossy(&output).to_string();
426
427 if !status.success() {
428 return Err(ProcessError::Other(format!(
429 "{} CLI failed with exit code {}: {}",
430 cli_type.display_name(),
431 extract_exit_code(status),
432 output_str
433 )));
434 }
435
436 Some(output_str)
437 } else {
438 None
439 };
440
441 Ok((extract_exit_code(status), captured_output))
442}
443
/// Returns the log-file path for the child with `pid`, creating the log directory if needed.
///
/// Logs live at `<temp_dir>/.aiw/logs/<pid>.log`. The directory receives the child's stdout
/// and stderr. On Unix it is therefore restricted to owner-only access (`0o700`) on every
/// call, not only when this call created it. A directory left with looser permissions, or
/// created concurrently by another instance, is tightened before a log is written into it.
///
/// NOTE(review): the location is a predictable path under the shared temp dir; a per-user
/// directory would be more robust against other local users — consider.
///
/// # Errors
/// Propagates any I/O error from creating the directory or reading/updating its permissions.
fn generate_log_path(pid: u32) -> io::Result<PathBuf> {
    let log_dir = std::env::temp_dir().join(".aiw").join("logs");

    // `create_dir_all` already succeeds when the directory exists, so a separate (racy)
    // `exists()` pre-check is unnecessary.
    std::fs::create_dir_all(&log_dir)?;

    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;

        let mut perms = std::fs::metadata(&log_dir)?.permissions();
        if perms.mode() & 0o777 != 0o700 {
            perms.set_mode(0o700);
            std::fs::set_permissions(&log_dir, perms)?;
        }
    }

    Ok(log_dir.join(format!("{pid}.log")))
}
475
476#[derive(Copy, Clone)]
477enum StreamMirror {
478 Stdout,
479 Stderr,
480}
481
482impl StreamMirror {
483 async fn write(self, data: &[u8]) -> io::Result<()> {
484 use tokio::io::AsyncWriteExt;
485 match self {
486 StreamMirror::Stdout => {
487 let mut handle = tokio::io::stdout();
488 handle.write_all(data).await?;
489 handle.flush().await
490 }
491 StreamMirror::Stderr => {
492 let mut handle = tokio::io::stderr();
493 handle.write_all(data).await?;
494 handle.flush().await
495 }
496 }
497 }
498}
499
500async fn spawn_copy<R>(
501 mut reader: R,
502 writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
503 mirror: StreamMirror,
504) -> io::Result<()>
505where
506 R: AsyncRead + Unpin + Send + 'static,
507{
508 use tokio::io::AsyncReadExt;
509
510 let mut buffer = [0u8; 8192];
511 loop {
512 let read = reader.read(&mut buffer).await?;
513 if read == 0 {
514 break;
515 }
516 let chunk = &buffer[..read];
517 {
518 let mut guard = writer.lock().await;
519 guard.write_all(chunk).await?;
520 guard.flush().await?;
521 }
522 mirror.write(chunk).await?;
523 }
524 Ok(())
525}
526
527async fn spawn_copy_with_capture<R>(
529 mut reader: R,
530 writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
531 capture_buffer: Arc<Mutex<Vec<u8>>>,
532) -> io::Result<()>
533where
534 R: AsyncRead + Unpin + Send + 'static,
535{
536 use tokio::io::AsyncReadExt;
537
538 let mut buffer = [0u8; 8192];
539 loop {
540 let read = reader.read(&mut buffer).await?;
541 if read == 0 {
542 break;
543 }
544 let chunk = &buffer[..read];
545
546 {
548 let mut guard = writer.lock().await;
549 guard.write_all(chunk).await?;
550 guard.flush().await?;
551 }
552
553 {
555 let mut capture = capture_buffer.lock().await;
556 capture.extend_from_slice(chunk);
557 }
558 }
559 Ok(())
560}
561
/// Converts an [`ExitStatus`] into a numeric exit code.
///
/// Uses the status's own code when present. A status without a code (for example a child
/// killed by a signal on Unix) is reported as `1`.
fn extract_exit_code(status: ExitStatus) -> i32 {
    let Some(code) = status.code() else {
        return 1;
    };
    code
}
565
566struct RegistrationGuard<'a, S: TaskStorage> {
567 registry: &'a Registry<S>,
568 pid: u32,
569 active: bool,
570}
571
572impl<'a, S: TaskStorage> RegistrationGuard<'a, S> {
573 fn new(registry: &'a Registry<S>, pid: u32) -> Self {
574 Self {
575 registry,
576 pid,
577 active: true,
578 }
579 }
580
581 fn mark_completed(
582 mut self,
583 result: Option<String>,
584 exit_code: Option<i32>,
585 completed_at: DateTime<Utc>,
586 ) -> Result<(), RegistryError> {
587 if self.active {
588 self.registry
589 .mark_completed(self.pid, result, exit_code, completed_at)?;
590 self.active = false;
591 }
592 Ok(())
593 }
594}
595
596impl<S: TaskStorage> Drop for RegistrationGuard<'_, S> {
597 fn drop(&mut self) {
598 }
602}
603
604pub async fn start_interactive_cli<S: TaskStorage>(
606 registry: &Registry<S>,
607 cli_type: &CliType,
608 provider: Option<String>,
609) -> Result<i32, ProcessError> {
610 platform::init_platform();
611
612 let terminate_wrapper = |pid: u32| {
613 platform::terminate_process(pid);
614 Ok(())
615 };
616 registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
617
618 let provider_manager = ProviderManager::new()
620 .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
621
622 let (provider_name, provider_config) = if let Some(name) = provider {
624 let config = provider_manager
625 .get_provider(&name)
626 .map_err(|e| ProcessError::Other(e.to_string()))?;
627 (name, config)
628 } else {
629 let (name, config) = provider_manager
630 .get_default_provider()
631 .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
632 (name, config)
633 };
634
635 let _ai_type = match cli_type {
637 CliType::Claude => AiType::Claude,
638 CliType::Codex => AiType::Codex,
639 CliType::Gemini => AiType::Gemini,
640 };
641
642 if provider_name != *"official" {
644 eprintln!(
645 "Using provider: {} ({})",
646 provider_name,
647 provider_config.summary()
648 );
649 }
650
651 let cli_command = get_cli_command(cli_type).await?;
652
653 let mut command = Command::new(&cli_command);
655
656 let interactive_args = cli_type.build_interactive_args();
658 command.args(&interactive_args);
659
660 command.stdin(Stdio::inherit());
661 command.stdout(Stdio::inherit());
662 command.stderr(Stdio::inherit());
663
664 #[cfg(unix)]
666 {
667 unsafe {
668 command.pre_exec(|| {
669 let result = libc::setpgid(0, 0);
671 if result != 0 {
672 return Err(io::Error::last_os_error());
673 }
674 #[cfg(target_os = "linux")]
676 {
677 let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
678 if result != 0 {
679 return Err(io::Error::last_os_error());
680 }
681 }
682 Ok(())
683 });
684 }
685 }
686
687 for (key, value) in &provider_config.env {
689 command.env(key, value);
690 }
691
692 let mut child = command.spawn()?;
693 let child_pid = child
694 .id()
695 .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
696
697 let log_path = generate_log_path(child_pid)?;
699 let record = TaskRecord::new(
700 Utc::now(),
701 child_pid.to_string(),
702 log_path.to_string_lossy().into_owned(),
703 Some(platform::current_pid()),
704 );
705
706 let record = match ProcessTreeInfo::current() {
708 Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
709 Ok(updated) => updated,
710 Err(err) => {
711 warn(format!("Failed to attach process tree info: {}", err));
712 record
713 }
714 },
715 Err(err) => {
716 warn(format!("Failed to get process tree info: {}", err));
717 record
718 }
719 };
720
721 if let Err(err) = registry.register(child_pid, &record) {
722 platform::terminate_process(child_pid);
723 let _ = child.wait().await;
724 return Err(err.into());
725 }
726
727 let registration_guard = RegistrationGuard::new(registry, child_pid);
728 let signal_guard = signal::install(child_pid)?;
729
730 let status = child.wait().await?;
731 drop(signal_guard);
732
733 let completed_at = Utc::now();
735 let exit_code = status.code();
736 let result = match (status.success(), exit_code) {
737 (true, _) => Some("interactive_session_completed".to_owned()),
738 (false, Some(code)) => Some(format!("interactive_session_failed_with_exit_code_{code}")),
739 (false, None) => Some("interactive_session_failed_without_exit_code".to_owned()),
740 };
741 let _ = registration_guard.mark_completed(result, exit_code, completed_at);
742
743 Ok(extract_exit_code(status))
744}
745
746pub async fn execute_multiple_clis<S: TaskStorage>(
748 registry: &Registry<S>,
749 cli_selector: &crate::cli_type::CliSelector,
750 task_prompt: &str,
751 provider: Option<String>,
752) -> Result<Vec<i32>, ProcessError> {
753 let mut exit_codes = Vec::new();
754
755 for cli_type in &cli_selector.types {
756 let cli_args = cli_type.build_full_access_args(task_prompt);
757 let os_args: Vec<OsString> = cli_args.into_iter().map(|s| s.into()).collect();
758
759 let exit_code = execute_cli(registry, cli_type, &os_args, provider.clone()).await?;
760 exit_codes.push(exit_code);
761 }
762
763 Ok(exit_codes)
764}