1use crate::cli_type::CliType;
2use crate::core::models::ProcessTreeInfo;
3use crate::core::process_tree::ProcessTreeError;
4use crate::error::RegistryError;
5use crate::logging::debug;
6use crate::logging::warn;
7#[cfg(windows)]
8use crate::platform::ChildResources;
9use crate::platform::{self};
10use crate::provider::{AiType, ProviderManager};
11use crate::signal;
12use crate::storage::TaskStorage;
13use crate::task_record::TaskRecord;
14use crate::unified_registry::Registry;
15use chrono::{DateTime, Utc};
16use std::env;
17use std::ffi::OsString;
18use std::io;
19use std::path::PathBuf;
20use std::process::{ExitStatus, Stdio};
21use std::sync::Arc;
22use thiserror::Error;
23use tokio::fs::OpenOptions;
24use tokio::io::{AsyncRead, AsyncWriteExt, BufWriter};
25use tokio::process::Command;
26use tokio::sync::Mutex;
27
28#[derive(Debug, Error)]
29pub enum ProcessError {
30 #[error("IO error: {0}")]
31 Io(#[from] io::Error),
32 #[error("Registry error: {0}")]
33 Registry(#[from] RegistryError),
34 #[error("Process tree error: {0}")]
35 ProcessTree(#[from] ProcessTreeError),
36 #[error("CLI executable not found: {0}")]
37 CliNotFound(String),
38 #[error("{0}")]
39 Other(String),
40}
41
42async fn get_cli_command(cli_type: &CliType) -> Result<String, ProcessError> {
43 if let Ok(custom_path) = env::var(cli_type.env_var_name()) {
45 if custom_path.is_empty() {
46 return Err(ProcessError::CliNotFound(format!(
47 "{} environment variable is empty",
48 cli_type.env_var_name()
49 )));
50 }
51 return Ok(custom_path);
52 }
53
54 let default_cmd = cli_type.command_name();
56
57 if cfg!(windows) {
59 let output = Command::new("where")
60 .arg(default_cmd)
61 .output()
62 .await
63 .map_err(|_| {
64 ProcessError::CliNotFound(format!(
65 "Failed to check if '{}' exists in PATH",
66 default_cmd
67 ))
68 })?;
69
70 if output.status.success() {
71 let stdout = String::from_utf8_lossy(&output.stdout);
72 for line in stdout.lines() {
74 if line.ends_with(".cmd") || line.ends_with(".bat") || line.ends_with(".exe") {
75 return Ok(line.to_string());
76 }
77 }
78 if let Some(first_line) = stdout.lines().next() {
80 return Ok(first_line.to_string());
81 }
82 }
83
84 return Err(ProcessError::CliNotFound(format!(
85 "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
86 default_cmd,
87 cli_type.env_var_name()
88 )));
89 } else {
90 let output = Command::new("which")
91 .arg(default_cmd)
92 .output()
93 .await
94 .map_err(|_| {
95 ProcessError::CliNotFound(format!(
96 "Failed to check if '{}' exists in PATH",
97 default_cmd
98 ))
99 })?;
100
101 if !output.status.success() {
102 return Err(ProcessError::CliNotFound(format!(
103 "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
104 default_cmd,
105 cli_type.env_var_name()
106 )));
107 }
108 }
109
110 Ok(default_cmd.to_string())
111}
112
113enum OutputStrategy {
115 Mirror,
117 Capture(Arc<Mutex<Vec<u8>>>),
119}
120
121pub async fn execute_cli<S: TaskStorage>(
122 registry: &Registry<S>,
123 cli_type: &CliType,
124 args: &[OsString],
125 provider: Option<String>,
126) -> Result<i32, ProcessError> {
127 execute_cli_internal(registry, cli_type, args, provider, None, OutputStrategy::Mirror)
128 .await
129 .map(|(exit_code, _)| exit_code)
130}
131
132pub async fn execute_cli_with_output<S: TaskStorage>(
134 registry: &Registry<S>,
135 cli_type: &CliType,
136 args: &[OsString],
137 provider: Option<String>,
138 timeout: std::time::Duration,
139) -> Result<String, ProcessError> {
140 let buffer = Arc::new(Mutex::new(Vec::new()));
141 let (_, output_opt) = execute_cli_internal(
142 registry,
143 cli_type,
144 args,
145 provider,
146 Some(timeout),
147 OutputStrategy::Capture(buffer.clone()),
148 )
149 .await?;
150
151 match output_opt {
152 Some(output) => Ok(output),
153 None => Err(ProcessError::Other(
154 "Output capture failed unexpectedly".to_string(),
155 )),
156 }
157}
158
159async fn execute_cli_internal<S: TaskStorage>(
161 registry: &Registry<S>,
162 cli_type: &CliType,
163 args: &[OsString],
164 provider: Option<String>,
165 timeout: Option<std::time::Duration>,
166 output_strategy: OutputStrategy,
167) -> Result<(i32, Option<String>), ProcessError> {
168 let is_capture_mode = matches!(output_strategy, OutputStrategy::Capture(_));
169
170 platform::init_platform();
171
172 let terminate_wrapper = |pid: u32| {
173 platform::terminate_process(pid);
174 Ok(())
175 };
176 registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
177
178 let provider_manager = ProviderManager::new()
180 .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
181
182 let (provider_name, provider_config) = if let Some(name) = provider {
184 let config = provider_manager
185 .get_provider(&name)
186 .map_err(|e| ProcessError::Other(e.to_string()))?;
187 (name, config)
188 } else {
189 let (name, config) = provider_manager
190 .get_default_provider()
191 .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
192 (name, config)
193 };
194
195 let _ai_type = match cli_type {
197 CliType::Claude => AiType::Claude,
198 CliType::Codex => AiType::Codex,
199 CliType::Gemini => AiType::Gemini,
200 };
201
202 if provider_name != *"official" {
204 eprintln!(
205 "Using provider: {} ({})",
206 provider_name,
207 provider_config.summary()
208 );
209 }
210
211 let cli_command = get_cli_command(cli_type).await?;
212
213 let mut command = Command::new(&cli_command);
214 command.args(args);
215 command.stdin(Stdio::null());
216 command.stdout(Stdio::piped());
217 command.stderr(Stdio::piped());
218
219 #[cfg(unix)]
221 {
222 unsafe {
223 command.pre_exec(|| {
224 let result = libc::setpgid(0, 0);
225 if result != 0 {
226 return Err(io::Error::last_os_error());
227 }
228 #[cfg(target_os = "linux")]
229 {
230 let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
231 if result != 0 {
232 return Err(io::Error::last_os_error());
233 }
234 }
235 Ok(())
236 });
237 }
238 }
239
240 for (key, value) in &provider_config.env {
242 command.env(key, value);
243 }
244
245 let mut child = command.spawn()?;
246 let child_pid = child
247 .id()
248 .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
249
250 let log_path = match generate_log_path(child_pid) {
251 Ok(path) => path,
252 Err(err) => {
253 platform::terminate_process(child_pid);
254 let _ = child.wait();
255 return Err(err.into());
256 }
257 };
258
259 let log_file = match OpenOptions::new()
260 .create(true)
261 .write(true)
262 .truncate(true)
263 .open(&log_path)
264 .await
265 {
266 Ok(file) => file,
267 Err(err) => {
268 platform::terminate_process(child_pid);
269 let _ = child.wait().await;
270 return Err(err.into());
271 }
272 };
273
274 debug(format!(
275 "Started {} process pid={} log={}{}",
276 cli_type.display_name(),
277 child_pid,
278 log_path.display(),
279 if is_capture_mode { " (capture mode)" } else { "" }
280 ));
281
282 #[cfg(windows)]
283 {
284 let _resources = ChildResources::with_job(None);
285 }
286
287 let signal_guard = signal::install(child_pid)?;
288
289 let log_writer = Arc::new(Mutex::new(BufWriter::new(log_file)));
290 let mut copy_handles = Vec::new();
291
292 if let Some(stdout) = child.stdout.take() {
294 match &output_strategy {
295 OutputStrategy::Mirror => {
296 copy_handles.push(tokio::spawn(spawn_copy(
297 stdout,
298 log_writer.clone(),
299 StreamMirror::Stdout,
300 )));
301 }
302 OutputStrategy::Capture(buffer) => {
303 let buffer_clone = buffer.clone();
304 let writer_clone = log_writer.clone();
305 copy_handles.push(tokio::spawn(async move {
306 spawn_copy_with_capture(stdout, writer_clone, buffer_clone).await
307 }));
308 }
309 }
310 }
311
312 if let Some(stderr) = child.stderr.take() {
313 copy_handles.push(tokio::spawn(spawn_copy(
314 stderr,
315 log_writer.clone(),
316 StreamMirror::Stderr,
317 )));
318 }
319
320 let registration_guard = {
321 let mut record = TaskRecord::new(
322 Utc::now(),
323 child_pid.to_string(),
324 log_path.to_string_lossy().into_owned(),
325 Some(platform::current_pid()),
326 );
327
328 match ProcessTreeInfo::current() {
330 Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
331 Ok(updated) => {
332 record = updated;
333 }
334 Err(err) => {
335 warn(format!("Failed to attach process tree info: {}", err));
336 }
337 },
338 Err(err) => {
339 warn(format!("Failed to get process tree info: {}", err));
340 }
341 }
342
343 if let Err(err) = registry.register(child_pid, &record) {
344 platform::terminate_process(child_pid);
345 let _ = child.wait().await;
346 return Err(err.into());
347 }
348 Some(RegistrationGuard::new(registry, child_pid))
349 };
350
351 let status = if let Some(timeout_duration) = timeout {
353 tokio::select! {
354 result = child.wait() => result?,
355 _ = tokio::time::sleep(timeout_duration) => {
356 platform::terminate_process(child_pid);
357 let _ = child.wait().await;
358 return Err(ProcessError::Other(format!(
359 "CLI execution timed out after {:?}",
360 timeout_duration
361 )));
362 }
363 }
364 } else {
365 child.wait().await?
366 };
367
368 drop(signal_guard);
369
370 for handle in copy_handles {
371 match handle.await {
372 Ok(result) => result?,
373 Err(_) => {
374 return Err(io::Error::other("Log writer task failed").into());
375 }
376 }
377 }
378
379 {
380 let mut writer = log_writer.lock().await;
381 writer.flush().await?;
382 writer.get_ref().sync_all().await?;
383 }
384
385 if !is_capture_mode {
387 eprintln!("完整日志已保存到: {}", log_path.display());
388 }
389
390 if let Some(guard) = registration_guard {
391 let completed_at = Utc::now();
392 let exit_code = status.code();
393 let result = match (status.success(), exit_code) {
394 (true, _) => Some(if is_capture_mode {
395 "codegen_success".to_owned()
396 } else {
397 "success".to_owned()
398 }),
399 (false, Some(code)) => Some(format!(
400 "{}_failed_with_exit_code_{code}",
401 if is_capture_mode { "codegen" } else { "cli" }
402 )),
403 (false, None) => Some(format!(
404 "{}_failed_without_exit_code",
405 if is_capture_mode { "codegen" } else { "cli" }
406 )),
407 };
408 let _ = guard.mark_completed(result, exit_code, completed_at);
409 }
410
411 let captured_output = if let OutputStrategy::Capture(buffer) = output_strategy {
413 let output = buffer.lock().await.clone();
414 let output_str = String::from_utf8_lossy(&output).to_string();
415
416 if !status.success() {
417 return Err(ProcessError::Other(format!(
418 "{} CLI failed with exit code {}: {}",
419 cli_type.display_name(),
420 extract_exit_code(status),
421 output_str
422 )));
423 }
424
425 Some(output_str)
426 } else {
427 None
428 };
429
430 Ok((extract_exit_code(status), captured_output))
431}
432
/// Returns the log-file path for child `pid`: `<temp_dir>/.aiw/logs/<pid>.log`.
///
/// The log directory is created if it does not exist. On Unix every directory this
/// function creates (both `.aiw` and `logs`) is created with mode `0o700`. That
/// mode is applied atomically at creation, so the directories are never briefly
/// visible with umask-default permissions. The mode is then re-asserted on `logs`
/// in case the umask stripped owner bits. An already-existing directory is used
/// as-is.
///
/// # Errors
///
/// Any I/O error from creating the directory or setting its permissions.
fn generate_log_path(pid: u32) -> io::Result<PathBuf> {
    let log_dir = std::env::temp_dir().join(".aiw").join("logs");

    if !log_dir.exists() {
        let mut builder = std::fs::DirBuilder::new();
        builder.recursive(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::DirBuilderExt;
            builder.mode(0o700);
        }
        builder.create(&log_dir)?;

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            std::fs::set_permissions(&log_dir, std::fs::Permissions::from_mode(0o700))?;
        }
    }

    Ok(log_dir.join(format!("{pid}.log")))
}
464
465#[derive(Copy, Clone)]
466enum StreamMirror {
467 Stdout,
468 Stderr,
469}
470
471impl StreamMirror {
472 async fn write(self, data: &[u8]) -> io::Result<()> {
473 use tokio::io::AsyncWriteExt;
474 match self {
475 StreamMirror::Stdout => {
476 let mut handle = tokio::io::stdout();
477 handle.write_all(data).await?;
478 handle.flush().await
479 }
480 StreamMirror::Stderr => {
481 let mut handle = tokio::io::stderr();
482 handle.write_all(data).await?;
483 handle.flush().await
484 }
485 }
486 }
487}
488
489async fn spawn_copy<R>(
490 mut reader: R,
491 writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
492 mirror: StreamMirror,
493) -> io::Result<()>
494where
495 R: AsyncRead + Unpin + Send + 'static,
496{
497 use tokio::io::AsyncReadExt;
498
499 let mut buffer = [0u8; 8192];
500 loop {
501 let read = reader.read(&mut buffer).await?;
502 if read == 0 {
503 break;
504 }
505 let chunk = &buffer[..read];
506 {
507 let mut guard = writer.lock().await;
508 guard.write_all(chunk).await?;
509 guard.flush().await?;
510 }
511 mirror.write(chunk).await?;
512 }
513 Ok(())
514}
515
516async fn spawn_copy_with_capture<R>(
518 mut reader: R,
519 writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
520 capture_buffer: Arc<Mutex<Vec<u8>>>,
521) -> io::Result<()>
522where
523 R: AsyncRead + Unpin + Send + 'static,
524{
525 use tokio::io::AsyncReadExt;
526
527 let mut buffer = [0u8; 8192];
528 loop {
529 let read = reader.read(&mut buffer).await?;
530 if read == 0 {
531 break;
532 }
533 let chunk = &buffer[..read];
534
535 {
537 let mut guard = writer.lock().await;
538 guard.write_all(chunk).await?;
539 guard.flush().await?;
540 }
541
542 {
544 let mut capture = capture_buffer.lock().await;
545 capture.extend_from_slice(chunk);
546 }
547 }
548 Ok(())
549}
550
/// Converts a child's [`ExitStatus`] into an integer exit code.
///
/// Returns the OS-reported code when there is one. When the status carries no code
/// (on Unix, when the child was killed by a signal), it returns `1`.
fn extract_exit_code(status: ExitStatus) -> i32 {
    const NO_CODE_FALLBACK: i32 = 1;
    status.code().unwrap_or(NO_CODE_FALLBACK)
}
554
555struct RegistrationGuard<'a, S: TaskStorage> {
556 registry: &'a Registry<S>,
557 pid: u32,
558 active: bool,
559}
560
561impl<'a, S: TaskStorage> RegistrationGuard<'a, S> {
562 fn new(registry: &'a Registry<S>, pid: u32) -> Self {
563 Self {
564 registry,
565 pid,
566 active: true,
567 }
568 }
569
570 fn mark_completed(
571 mut self,
572 result: Option<String>,
573 exit_code: Option<i32>,
574 completed_at: DateTime<Utc>,
575 ) -> Result<(), RegistryError> {
576 if self.active {
577 self.registry
578 .mark_completed(self.pid, result, exit_code, completed_at)?;
579 self.active = false;
580 }
581 Ok(())
582 }
583}
584
585impl<S: TaskStorage> Drop for RegistrationGuard<'_, S> {
586 fn drop(&mut self) {
587 }
591}
592
593pub async fn start_interactive_cli<S: TaskStorage>(
595 registry: &Registry<S>,
596 cli_type: &CliType,
597 provider: Option<String>,
598) -> Result<i32, ProcessError> {
599 platform::init_platform();
600
601 let terminate_wrapper = |pid: u32| {
602 platform::terminate_process(pid);
603 Ok(())
604 };
605 registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
606
607 let provider_manager = ProviderManager::new()
609 .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
610
611 let (provider_name, provider_config) = if let Some(name) = provider {
613 let config = provider_manager
614 .get_provider(&name)
615 .map_err(|e| ProcessError::Other(e.to_string()))?;
616 (name, config)
617 } else {
618 let (name, config) = provider_manager
619 .get_default_provider()
620 .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
621 (name, config)
622 };
623
624 let _ai_type = match cli_type {
626 CliType::Claude => AiType::Claude,
627 CliType::Codex => AiType::Codex,
628 CliType::Gemini => AiType::Gemini,
629 };
630
631 if provider_name != *"official" {
633 eprintln!(
634 "Using provider: {} ({})",
635 provider_name,
636 provider_config.summary()
637 );
638 }
639
640 let cli_command = get_cli_command(cli_type).await?;
641
642 let mut command = Command::new(&cli_command);
644
645 let interactive_args = cli_type.build_interactive_args();
647 command.args(&interactive_args);
648
649 command.stdin(Stdio::inherit());
650 command.stdout(Stdio::inherit());
651 command.stderr(Stdio::inherit());
652
653 #[cfg(unix)]
655 {
656 unsafe {
657 command.pre_exec(|| {
658 let result = libc::setpgid(0, 0);
660 if result != 0 {
661 return Err(io::Error::last_os_error());
662 }
663 #[cfg(target_os = "linux")]
665 {
666 let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
667 if result != 0 {
668 return Err(io::Error::last_os_error());
669 }
670 }
671 Ok(())
672 });
673 }
674 }
675
676 for (key, value) in &provider_config.env {
678 command.env(key, value);
679 }
680
681 let mut child = command.spawn()?;
682 let child_pid = child
683 .id()
684 .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
685
686 let log_path = generate_log_path(child_pid)?;
688 let record = TaskRecord::new(
689 Utc::now(),
690 child_pid.to_string(),
691 log_path.to_string_lossy().into_owned(),
692 Some(platform::current_pid()),
693 );
694
695 let record = match ProcessTreeInfo::current() {
697 Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
698 Ok(updated) => updated,
699 Err(err) => {
700 warn(format!("Failed to attach process tree info: {}", err));
701 record
702 }
703 },
704 Err(err) => {
705 warn(format!("Failed to get process tree info: {}", err));
706 record
707 }
708 };
709
710 if let Err(err) = registry.register(child_pid, &record) {
711 platform::terminate_process(child_pid);
712 let _ = child.wait().await;
713 return Err(err.into());
714 }
715
716 let registration_guard = RegistrationGuard::new(registry, child_pid);
717 let signal_guard = signal::install(child_pid)?;
718
719 let status = child.wait().await?;
720 drop(signal_guard);
721
722 let completed_at = Utc::now();
724 let exit_code = status.code();
725 let result = match (status.success(), exit_code) {
726 (true, _) => Some("interactive_session_completed".to_owned()),
727 (false, Some(code)) => Some(format!("interactive_session_failed_with_exit_code_{code}")),
728 (false, None) => Some("interactive_session_failed_without_exit_code".to_owned()),
729 };
730 let _ = registration_guard.mark_completed(result, exit_code, completed_at);
731
732 Ok(extract_exit_code(status))
733}
734
735pub async fn execute_multiple_clis<S: TaskStorage>(
737 registry: &Registry<S>,
738 cli_selector: &crate::cli_type::CliSelector,
739 task_prompt: &str,
740 provider: Option<String>,
741) -> Result<Vec<i32>, ProcessError> {
742 let mut exit_codes = Vec::new();
743
744 for cli_type in &cli_selector.types {
745 let cli_args = cli_type.build_full_access_args(task_prompt);
746 let os_args: Vec<OsString> = cli_args.into_iter().map(|s| s.into()).collect();
747
748 let exit_code = execute_cli(registry, cli_type, &os_args, provider.clone()).await?;
749 exit_codes.push(exit_code);
750 }
751
752 Ok(exit_codes)
753}