1use crate::cli_type::CliType;
2use crate::core::models::ProcessTreeInfo;
3use crate::core::process_tree::ProcessTreeError;
4use crate::error::RegistryError;
5use crate::logging::debug;
6use crate::logging::warn;
7#[cfg(windows)]
8use crate::platform::ChildResources;
9use crate::platform::{self};
10use crate::provider::{AiType, ProviderManager};
11use crate::signal;
12use crate::storage::TaskStorage;
13use crate::task_record::TaskRecord;
14use crate::unified_registry::Registry;
15use chrono::{DateTime, Utc};
16use std::collections::VecDeque;
17use std::env;
18use std::ffi::OsString;
19use std::io;
20use std::path::PathBuf;
21use std::process::{ExitStatus, Stdio};
22use std::sync::Arc;
23use thiserror::Error;
24use tokio::fs::OpenOptions;
25use tokio::io::{AsyncRead, AsyncWriteExt, BufWriter};
26use tokio::process::Command;
27use tokio::sync::Mutex;
28
29#[derive(Debug, Error)]
30pub enum ProcessError {
31 #[error("IO error: {0}")]
32 Io(#[from] io::Error),
33 #[error("Registry error: {0}")]
34 Registry(#[from] RegistryError),
35 #[error("Process tree error: {0}")]
36 ProcessTree(#[from] ProcessTreeError),
37 #[error("CLI executable not found: {0}")]
38 CliNotFound(String),
39 #[error("{0}")]
40 Other(String),
41}
42
43async fn get_cli_command(cli_type: &CliType) -> Result<String, ProcessError> {
44 if let Ok(custom_path) = env::var(cli_type.env_var_name()) {
46 if custom_path.is_empty() {
47 return Err(ProcessError::CliNotFound(format!(
48 "{} environment variable is empty",
49 cli_type.env_var_name()
50 )));
51 }
52 return Ok(custom_path);
53 }
54
55 let default_cmd = cli_type.command_name();
57
58 if cfg!(windows) {
60 let output = Command::new("where")
61 .arg(default_cmd)
62 .output()
63 .await
64 .map_err(|_| {
65 ProcessError::CliNotFound(format!(
66 "Failed to check if '{}' exists in PATH",
67 default_cmd
68 ))
69 })?;
70
71 if output.status.success() {
72 let stdout = String::from_utf8_lossy(&output.stdout);
73 for line in stdout.lines() {
75 if line.ends_with(".cmd") || line.ends_with(".bat") || line.ends_with(".exe") {
76 return Ok(line.to_string());
77 }
78 }
79 if let Some(first_line) = stdout.lines().next() {
81 return Ok(first_line.to_string());
82 }
83 }
84
85 return Err(ProcessError::CliNotFound(format!(
86 "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
87 default_cmd,
88 cli_type.env_var_name()
89 )));
90 } else {
91 let output = Command::new("which")
92 .arg(default_cmd)
93 .output()
94 .await
95 .map_err(|_| {
96 ProcessError::CliNotFound(format!(
97 "Failed to check if '{}' exists in PATH",
98 default_cmd
99 ))
100 })?;
101
102 if !output.status.success() {
103 return Err(ProcessError::CliNotFound(format!(
104 "'{}' not found in PATH. Set {} environment variable or ensure it's in PATH",
105 default_cmd,
106 cli_type.env_var_name()
107 )));
108 }
109 }
110
111 Ok(default_cmd.to_string())
112}
113
114enum OutputStrategy {
116 Mirror,
118 Capture(Arc<Mutex<Vec<u8>>>),
120}
121
122pub async fn execute_cli<S: TaskStorage>(
123 registry: &Registry<S>,
124 cli_type: &CliType,
125 args: &[OsString],
126 provider: Option<String>,
127) -> Result<i32, ProcessError> {
128 execute_cli_internal(
129 registry,
130 cli_type,
131 args,
132 provider,
133 None,
134 OutputStrategy::Mirror,
135 )
136 .await
137 .map(|(exit_code, _)| exit_code)
138}
139
140pub async fn execute_cli_with_output<S: TaskStorage>(
142 registry: &Registry<S>,
143 cli_type: &CliType,
144 args: &[OsString],
145 provider: Option<String>,
146 timeout: std::time::Duration,
147) -> Result<String, ProcessError> {
148 let buffer = Arc::new(Mutex::new(Vec::new()));
149 let (_, output_opt) = execute_cli_internal(
150 registry,
151 cli_type,
152 args,
153 provider,
154 Some(timeout),
155 OutputStrategy::Capture(buffer.clone()),
156 )
157 .await?;
158
159 match output_opt {
160 Some(output) => Ok(output),
161 None => Err(ProcessError::Other(
162 "Output capture failed unexpectedly".to_string(),
163 )),
164 }
165}
166
167async fn execute_cli_internal<S: TaskStorage>(
169 registry: &Registry<S>,
170 cli_type: &CliType,
171 args: &[OsString],
172 provider: Option<String>,
173 timeout: Option<std::time::Duration>,
174 output_strategy: OutputStrategy,
175) -> Result<(i32, Option<String>), ProcessError> {
176 let is_capture_mode = matches!(output_strategy, OutputStrategy::Capture(_));
177
178 platform::init_platform();
179
180 let terminate_wrapper = |pid: u32| {
181 platform::terminate_process(pid);
182 Ok(())
183 };
184 registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
185
186 let provider_manager = ProviderManager::new()
188 .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
189
190 let (provider_name, provider_config) = if let Some(name) = provider {
192 let config = provider_manager
193 .get_provider(&name)
194 .map_err(|e| ProcessError::Other(e.to_string()))?;
195 (name, config)
196 } else {
197 let (name, config) = provider_manager
198 .get_default_provider()
199 .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
200 (name, config)
201 };
202
203 let _ai_type = match cli_type {
205 CliType::Claude => AiType::Claude,
206 CliType::Codex => AiType::Codex,
207 CliType::Gemini => AiType::Gemini,
208 };
209
210 if provider_name != *"official" {
212 eprintln!(
213 "Using provider: {} ({})",
214 provider_name,
215 provider_config.summary()
216 );
217 }
218
219 let cli_command = get_cli_command(cli_type).await?;
220
221 let mut command = Command::new(&cli_command);
222 command.args(args);
223 command.stdin(Stdio::null());
224 command.stdout(Stdio::piped());
225 command.stderr(Stdio::piped());
226
227 #[cfg(unix)]
229 {
230 unsafe {
231 command.pre_exec(|| {
232 let result = libc::setpgid(0, 0);
233 if result != 0 {
234 return Err(io::Error::last_os_error());
235 }
236 #[cfg(target_os = "linux")]
237 {
238 let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
239 if result != 0 {
240 return Err(io::Error::last_os_error());
241 }
242 }
243 Ok(())
244 });
245 }
246 }
247
248 for (key, value) in &provider_config.env {
250 command.env(key, value);
251 }
252
253 let mut child = command.spawn()?;
254 let child_pid = child
255 .id()
256 .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
257
258 let log_path = match generate_log_path(child_pid) {
259 Ok(path) => path,
260 Err(err) => {
261 platform::terminate_process(child_pid);
262 let _ = child.wait();
263 return Err(err.into());
264 }
265 };
266
267 let log_file = match OpenOptions::new()
268 .create(true)
269 .write(true)
270 .truncate(true)
271 .open(&log_path)
272 .await
273 {
274 Ok(file) => file,
275 Err(err) => {
276 platform::terminate_process(child_pid);
277 let _ = child.wait().await;
278 return Err(err.into());
279 }
280 };
281
282 debug(format!(
283 "Started {} process pid={} log={}{}",
284 cli_type.display_name(),
285 child_pid,
286 log_path.display(),
287 if is_capture_mode {
288 " (capture mode)"
289 } else {
290 ""
291 }
292 ));
293
294 #[cfg(windows)]
295 {
296 let _resources = ChildResources::with_job(None);
297 }
298
299 let signal_guard = signal::install(child_pid)?;
300
301 let log_writer = Arc::new(Mutex::new(BufWriter::new(log_file)));
302 let mut copy_handles = Vec::new();
303
304 let scrolling_display = Arc::new(Mutex::new(ScrollingDisplay::new(DEFAULT_MAX_DISPLAY_LINES)));
306
307 if let Some(stdout) = child.stdout.take() {
309 match &output_strategy {
310 OutputStrategy::Mirror => {
311 copy_handles.push(tokio::spawn(spawn_copy(
312 stdout,
313 log_writer.clone(),
314 StreamMirror::Stdout,
315 scrolling_display.clone(),
316 )));
317 }
318 OutputStrategy::Capture(buffer) => {
319 let buffer_clone = buffer.clone();
320 let writer_clone = log_writer.clone();
321 copy_handles.push(tokio::spawn(async move {
322 spawn_copy_with_capture(stdout, writer_clone, buffer_clone).await
323 }));
324 }
325 }
326 }
327
328 if let Some(stderr) = child.stderr.take() {
329 copy_handles.push(tokio::spawn(spawn_copy(
330 stderr,
331 log_writer.clone(),
332 StreamMirror::Stderr,
333 scrolling_display.clone(),
334 )));
335 }
336
337 let registration_guard = {
338 let mut record = TaskRecord::new(
339 Utc::now(),
340 child_pid.to_string(),
341 log_path.to_string_lossy().into_owned(),
342 Some(platform::current_pid()),
343 );
344
345 match ProcessTreeInfo::current() {
347 Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
348 Ok(updated) => {
349 record = updated;
350 }
351 Err(err) => {
352 warn(format!("Failed to attach process tree info: {}", err));
353 }
354 },
355 Err(err) => {
356 warn(format!("Failed to get process tree info: {}", err));
357 }
358 }
359
360 if let Err(err) = registry.register(child_pid, &record) {
361 platform::terminate_process(child_pid);
362 let _ = child.wait().await;
363 return Err(err.into());
364 }
365 Some(RegistrationGuard::new(registry, child_pid))
366 };
367
368 let status = if let Some(timeout_duration) = timeout {
370 tokio::select! {
371 result = child.wait() => result?,
372 _ = tokio::time::sleep(timeout_duration) => {
373 platform::terminate_process(child_pid);
374 let _ = child.wait().await;
375 return Err(ProcessError::Other(format!(
376 "CLI execution timed out after {:?}",
377 timeout_duration
378 )));
379 }
380 }
381 } else {
382 child.wait().await?
383 };
384
385 drop(signal_guard);
386
387 for handle in copy_handles {
388 match handle.await {
389 Ok(result) => result?,
390 Err(_) => {
391 return Err(io::Error::other("Log writer task failed").into());
392 }
393 }
394 }
395
396 {
397 let mut writer = log_writer.lock().await;
398 writer.flush().await?;
399 writer.get_ref().sync_all().await?;
400 }
401
402 if !is_capture_mode {
404 eprintln!("完整日志已保存到: {}", log_path.display());
405 }
406
407 if let Some(guard) = registration_guard {
408 let completed_at = Utc::now();
409 let exit_code = status.code();
410 let result = match (status.success(), exit_code) {
411 (true, _) => Some(if is_capture_mode {
412 "codegen_success".to_owned()
413 } else {
414 "success".to_owned()
415 }),
416 (false, Some(code)) => Some(format!(
417 "{}_failed_with_exit_code_{code}",
418 if is_capture_mode { "codegen" } else { "cli" }
419 )),
420 (false, None) => Some(format!(
421 "{}_failed_without_exit_code",
422 if is_capture_mode { "codegen" } else { "cli" }
423 )),
424 };
425 let _ = guard.mark_completed(result, exit_code, completed_at);
426 }
427
428 let captured_output = if let OutputStrategy::Capture(buffer) = output_strategy {
430 let output = buffer.lock().await.clone();
431 let output_str = String::from_utf8_lossy(&output).to_string();
432
433 if !status.success() {
434 return Err(ProcessError::Other(format!(
435 "{} CLI failed with exit code {}: {}",
436 cli_type.display_name(),
437 extract_exit_code(status),
438 output_str
439 )));
440 }
441
442 Some(output_str)
443 } else {
444 None
445 };
446
447 Ok((extract_exit_code(status), captured_output))
448}
449
/// Returns the log file path for the process `pid`: `<temp dir>/.aiw/logs/<pid>.log`.
///
/// The `logs` directory is created when it does not exist yet; on Unix a freshly created
/// directory has its mode set to `0o700`. An existing directory is left untouched.
///
/// # Errors
/// Returns the underlying I/O error if the directory cannot be created or its permissions
/// cannot be read or changed.
fn generate_log_path(pid: u32) -> io::Result<PathBuf> {
    let mut target = std::env::temp_dir();
    target.push(".aiw");
    target.push("logs");

    let already_present = target.exists();
    if !already_present {
        std::fs::create_dir_all(&target)?;

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;

            // Restrict the freshly created directory to its owner.
            let mut owner_only = std::fs::metadata(&target)?.permissions();
            owner_only.set_mode(0o700);
            std::fs::set_permissions(&target, owner_only)?;
        }
    }

    target.push(format!("{pid}.log"));
    Ok(target)
}
481
/// Line-oriented renderer that keeps only the most recent `max_lines` completed lines of a
/// byte stream and produces the terminal output needed to show them as a scrolling window.
struct ScrollingDisplay {
    /// Completed lines currently inside the window, oldest first.
    lines: VecDeque<String>,
    /// Maximum number of completed lines kept in the window.
    max_lines: usize,
    /// Text of the line still being received (no terminator seen yet).
    current_line_buffer: String,
    /// Number of lines emitted so far, i.e. how far `redraw` must move the cursor up.
    displayed_count: usize,
    /// `true` when the last character seen was `\r` and it is not yet known whether it is the
    /// first half of a `\r\n` terminator or a bare carriage return.
    pending_cr: bool,
}

impl ScrollingDisplay {
    /// Creates an empty display that retains at most `max_lines` lines.
    fn new(max_lines: usize) -> Self {
        Self {
            lines: VecDeque::with_capacity(max_lines),
            max_lines,
            current_line_buffer: String::new(),
            displayed_count: 0,
            pending_cr: false,
        }
    }

    /// Feeds a chunk of raw output and returns the text to write to the terminal.
    ///
    /// Bytes are decoded lossily as UTF-8. A `\n` completes the current line. A `\r` followed
    /// by `\n` (even across chunk boundaries) is a plain line terminator, so CRLF output keeps
    /// its content. A bare `\r` discards the text received so far on the line as soon as the
    /// next character arrives, which is how progress-bar rewrites are handled.
    fn process(&mut self, data: &[u8]) -> String {
        let text = String::from_utf8_lossy(data);
        let mut output = String::new();

        for ch in text.chars() {
            match ch {
                '\n' => {
                    // A preceding `\r` was just the first half of CRLF.
                    self.pending_cr = false;
                    self.commit_line(&mut output);
                }
                '\r' => {
                    // Defer the decision until the next character is known.
                    self.pending_cr = true;
                }
                _ => {
                    if self.pending_cr {
                        // Bare carriage return: the new text overwrites the old line.
                        self.current_line_buffer.clear();
                        self.pending_cr = false;
                    }
                    self.current_line_buffer.push(ch);
                }
            }
        }

        output
    }

    /// Moves the current line into the window and appends the resulting terminal output.
    fn commit_line(&mut self, output: &mut String) {
        let line = std::mem::take(&mut self.current_line_buffer);
        self.lines.push_back(line);

        if self.lines.len() > self.max_lines {
            // Window is full: drop the oldest line and repaint everything.
            self.lines.pop_front();
            output.push_str(&self.redraw());
        } else {
            // Still filling the window: just append the new line.
            if let Some(last) = self.lines.back() {
                output.push_str(last);
                output.push('\n');
            }
            self.displayed_count = self.lines.len();
        }
    }

    /// Returns the escape sequences and text that repaint the whole window in place.
    fn redraw(&mut self) -> String {
        let mut output = String::new();

        if self.displayed_count > 0 {
            // Cursor up by the number of lines shown, then clear to the end of the screen.
            output.push_str(&format!("\x1b[{}A", self.displayed_count));
            output.push_str("\x1b[J");
        }

        for line in &self.lines {
            output.push_str(line);
            output.push('\n');
        }

        self.displayed_count = self.lines.len();
        output
    }

    /// Returns the unterminated trailing line (with a newline appended), or an empty string.
    ///
    /// Text followed only by a bare `\r` is discarded, as it was before carriage returns were
    /// deferred.
    fn flush_remaining(&mut self) -> String {
        if std::mem::take(&mut self.pending_cr) {
            self.current_line_buffer.clear();
        }
        if self.current_line_buffer.is_empty() {
            return String::new();
        }
        let line = std::mem::take(&mut self.current_line_buffer);
        format!("{}\n", line)
    }
}
566
567#[derive(Copy, Clone)]
568enum StreamMirror {
569 Stdout,
570 Stderr,
571}
572
573impl StreamMirror {
574 async fn write(self, data: &[u8]) -> io::Result<()> {
575 use tokio::io::AsyncWriteExt;
576 match self {
577 StreamMirror::Stdout => {
578 let mut handle = tokio::io::stdout();
579 handle.write_all(data).await?;
580 handle.flush().await
581 }
582 StreamMirror::Stderr => {
583 let mut handle = tokio::io::stderr();
584 handle.write_all(data).await?;
585 handle.flush().await
586 }
587 }
588 }
589
590 async fn write_str(self, data: &str) -> io::Result<()> {
591 self.write(data.as_bytes()).await
592 }
593}
594
/// Number of lines kept by the `ScrollingDisplay` used when mirroring child output.
const DEFAULT_MAX_DISPLAY_LINES: usize = 50;
597
598async fn spawn_copy<R>(
599 mut reader: R,
600 writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
601 mirror: StreamMirror,
602 scrolling_display: Arc<Mutex<ScrollingDisplay>>,
603) -> io::Result<()>
604where
605 R: AsyncRead + Unpin + Send + 'static,
606{
607 use tokio::io::AsyncReadExt;
608
609 let mut buffer = [0u8; 8192];
610 loop {
611 let read = reader.read(&mut buffer).await?;
612 if read == 0 {
613 break;
614 }
615 let chunk = &buffer[..read];
616
617 {
619 let mut guard = writer.lock().await;
620 guard.write_all(chunk).await?;
621 guard.flush().await?;
622 }
623
624 let display_output = {
626 let mut display = scrolling_display.lock().await;
627 display.process(chunk)
628 };
629 if !display_output.is_empty() {
630 mirror.write_str(&display_output).await?;
631 }
632 }
633
634 let remaining = {
636 let mut display = scrolling_display.lock().await;
637 display.flush_remaining()
638 };
639 if !remaining.is_empty() {
640 mirror.write_str(&remaining).await?;
641 }
642
643 Ok(())
644}
645
646async fn spawn_copy_with_capture<R>(
648 mut reader: R,
649 writer: Arc<Mutex<BufWriter<tokio::fs::File>>>,
650 capture_buffer: Arc<Mutex<Vec<u8>>>,
651) -> io::Result<()>
652where
653 R: AsyncRead + Unpin + Send + 'static,
654{
655 use tokio::io::AsyncReadExt;
656
657 let mut buffer = [0u8; 8192];
658 loop {
659 let read = reader.read(&mut buffer).await?;
660 if read == 0 {
661 break;
662 }
663 let chunk = &buffer[..read];
664
665 {
667 let mut guard = writer.lock().await;
668 guard.write_all(chunk).await?;
669 guard.flush().await?;
670 }
671
672 {
674 let mut capture = capture_buffer.lock().await;
675 capture.extend_from_slice(chunk);
676 }
677 }
678 Ok(())
679}
680
/// Converts an [`ExitStatus`] into a plain exit code, using `1` when the status carries none.
fn extract_exit_code(status: ExitStatus) -> i32 {
    match status.code() {
        Some(code) => code,
        None => 1,
    }
}
684
685struct RegistrationGuard<'a, S: TaskStorage> {
686 registry: &'a Registry<S>,
687 pid: u32,
688 active: bool,
689}
690
691impl<'a, S: TaskStorage> RegistrationGuard<'a, S> {
692 fn new(registry: &'a Registry<S>, pid: u32) -> Self {
693 Self {
694 registry,
695 pid,
696 active: true,
697 }
698 }
699
700 fn mark_completed(
701 mut self,
702 result: Option<String>,
703 exit_code: Option<i32>,
704 completed_at: DateTime<Utc>,
705 ) -> Result<(), RegistryError> {
706 if self.active {
707 self.registry
708 .mark_completed(self.pid, result, exit_code, completed_at)?;
709 self.active = false;
710 }
711 Ok(())
712 }
713}
714
715impl<S: TaskStorage> Drop for RegistrationGuard<'_, S> {
716 fn drop(&mut self) {
717 }
721}
722
723pub async fn start_interactive_cli<S: TaskStorage>(
725 registry: &Registry<S>,
726 cli_type: &CliType,
727 provider: Option<String>,
728) -> Result<i32, ProcessError> {
729 platform::init_platform();
730
731 let terminate_wrapper = |pid: u32| {
732 platform::terminate_process(pid);
733 Ok(())
734 };
735 registry.sweep_stale_entries(Utc::now(), platform::process_alive, &terminate_wrapper)?;
736
737 let provider_manager = ProviderManager::new()
739 .map_err(|e| ProcessError::Other(format!("Failed to load provider: {}", e)))?;
740
741 let (provider_name, provider_config) = if let Some(name) = provider {
743 let config = provider_manager
744 .get_provider(&name)
745 .map_err(|e| ProcessError::Other(e.to_string()))?;
746 (name, config)
747 } else {
748 let (name, config) = provider_manager
749 .get_default_provider()
750 .ok_or_else(|| ProcessError::Other("No default provider configured".to_string()))?;
751 (name, config)
752 };
753
754 let _ai_type = match cli_type {
756 CliType::Claude => AiType::Claude,
757 CliType::Codex => AiType::Codex,
758 CliType::Gemini => AiType::Gemini,
759 };
760
761 if provider_name != *"official" {
763 eprintln!(
764 "Using provider: {} ({})",
765 provider_name,
766 provider_config.summary()
767 );
768 }
769
770 let cli_command = get_cli_command(cli_type).await?;
771
772 let mut command = Command::new(&cli_command);
774
775 let interactive_args = cli_type.build_interactive_args();
777 command.args(&interactive_args);
778
779 command.stdin(Stdio::inherit());
780 command.stdout(Stdio::inherit());
781 command.stderr(Stdio::inherit());
782
783 #[cfg(unix)]
785 {
786 unsafe {
787 command.pre_exec(|| {
788 let result = libc::setpgid(0, 0);
790 if result != 0 {
791 return Err(io::Error::last_os_error());
792 }
793 #[cfg(target_os = "linux")]
795 {
796 let result = libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM);
797 if result != 0 {
798 return Err(io::Error::last_os_error());
799 }
800 }
801 Ok(())
802 });
803 }
804 }
805
806 for (key, value) in &provider_config.env {
808 command.env(key, value);
809 }
810
811 let mut child = command.spawn()?;
812 let child_pid = child
813 .id()
814 .ok_or_else(|| io::Error::other("Failed to get child PID"))?;
815
816 let log_path = generate_log_path(child_pid)?;
818 let record = TaskRecord::new(
819 Utc::now(),
820 child_pid.to_string(),
821 log_path.to_string_lossy().into_owned(),
822 Some(platform::current_pid()),
823 );
824
825 let record = match ProcessTreeInfo::current() {
827 Ok(tree_info) => match record.clone().with_process_tree_info(tree_info) {
828 Ok(updated) => updated,
829 Err(err) => {
830 warn(format!("Failed to attach process tree info: {}", err));
831 record
832 }
833 },
834 Err(err) => {
835 warn(format!("Failed to get process tree info: {}", err));
836 record
837 }
838 };
839
840 if let Err(err) = registry.register(child_pid, &record) {
841 platform::terminate_process(child_pid);
842 let _ = child.wait().await;
843 return Err(err.into());
844 }
845
846 let registration_guard = RegistrationGuard::new(registry, child_pid);
847 let signal_guard = signal::install(child_pid)?;
848
849 let status = child.wait().await?;
850 drop(signal_guard);
851
852 let completed_at = Utc::now();
854 let exit_code = status.code();
855 let result = match (status.success(), exit_code) {
856 (true, _) => Some("interactive_session_completed".to_owned()),
857 (false, Some(code)) => Some(format!("interactive_session_failed_with_exit_code_{code}")),
858 (false, None) => Some("interactive_session_failed_without_exit_code".to_owned()),
859 };
860 let _ = registration_guard.mark_completed(result, exit_code, completed_at);
861
862 Ok(extract_exit_code(status))
863}
864
865pub async fn execute_multiple_clis<S: TaskStorage>(
867 registry: &Registry<S>,
868 cli_selector: &crate::cli_type::CliSelector,
869 task_prompt: &str,
870 provider: Option<String>,
871) -> Result<Vec<i32>, ProcessError> {
872 let mut exit_codes = Vec::new();
873
874 for cli_type in &cli_selector.types {
875 let cli_args = cli_type.build_full_access_args(task_prompt);
876 let os_args: Vec<OsString> = cli_args.into_iter().map(|s| s.into()).collect();
877
878 let exit_code = execute_cli(registry, cli_type, &os_args, provider.clone()).await?;
879 exit_codes.push(exit_code);
880 }
881
882 Ok(exit_codes)
883}