intelli_shell/utils/process.rs

1use std::{
2    borrow::Cow,
3    cmp::Ordering,
4    collections::BTreeMap,
5    env,
6    ffi::OsStr,
7    io::{self, Read, Write},
8    ops::Deref,
9    path::{Path, PathBuf},
10    process::{self, ExitStatus, Stdio},
11    sync::LazyLock,
12    time::Duration,
13};
14
15use color_eyre::eyre::Context;
16use ignore::WalkBuilder;
17use os_info::Info;
18use sysinfo::{Pid, System};
19use tokio::io::{AsyncBufReadExt, BufReader};
20use tokio_util::sync::CancellationToken;
21use wait_timeout::ChildExt;
22
/// Information about the shell this process was launched from
#[derive(Debug)]
pub struct ShellInfo {
    /// The detected kind of shell
    pub kind: ShellType,
    /// The shell's version (the trimmed first line of its version output), if it could be determined
    pub version: Option<String>,
}
28
/// The kinds of shell that can be detected and used to run commands.
///
/// Parsing (`strum::EnumString`) matches the names given in each variant's `serialize` / `to_string`
/// attributes. Any other string falls back to the `default` variant [`ShellType::Other`], which keeps the
/// string verbatim, so parsing never fails. `Display` (`strum::Display`) produces the program name that is
/// later used to spawn the shell.
///
/// NOTE(review): for variants with several `serialize` values and no `to_string` (`Cmd`,
/// `WindowsPowerShell`), strum chooses one of them for `Display` — confirm which one is emitted.
#[derive(Clone, Debug, PartialEq, Eq, strum::Display, strum::EnumString)]
pub enum ShellType {
    /// Windows Command Prompt
    #[strum(serialize = "cmd", serialize = "cmd.exe")]
    Cmd,
    /// Windows PowerShell (`powershell.exe`)
    #[strum(serialize = "powershell", serialize = "powershell.exe")]
    WindowsPowerShell,
    /// Cross-platform PowerShell (`pwsh`)
    #[strum(to_string = "pwsh", serialize = "pwsh.exe")]
    PowerShellCore,
    /// GNU Bash
    #[strum(to_string = "bash", serialize = "bash.exe")]
    Bash,
    /// POSIX `sh`
    #[strum(serialize = "sh")]
    Sh,
    /// The fish shell
    #[strum(serialize = "fish")]
    Fish,
    /// Z shell
    #[strum(serialize = "zsh")]
    Zsh,
    /// Nushell (`nu`)
    #[strum(to_string = "nu", serialize = "nu.exe")]
    Nushell,
    /// Any unrecognized shell name, kept (and displayed) verbatim
    #[strum(default, to_string = "{0}")]
    Other(String),
}
50
51static PARENT_SHELL_INFO: LazyLock<ShellInfo> = LazyLock::new(|| {
52    // Default to `sh` on Unix and `powershell` on Windows if detection fails
53    let default = if cfg!(target_os = "windows") {
54        ShellType::WindowsPowerShell
55    } else {
56        ShellType::Sh
57    };
58
59    // In test mode, always return the default shell to avoid complications
60    if cfg!(test) {
61        tracing::info!("Using default shell for tests: {default}");
62        return ShellInfo {
63            kind: default,
64            version: None,
65        };
66    }
67
68    // Otherwise, try to detect the parent shell process
69    let pid = Pid::from_u32(process::id());
70
71    tracing::debug!("Retrieving info for pid {pid}");
72    let sys = System::new_all();
73
74    let parent_process = sys
75        .process(Pid::from_u32(process::id()))
76        .expect("Couldn't retrieve current process from pid")
77        .parent()
78        .and_then(|parent_pid| sys.process(parent_pid));
79
80    let Some(parent) = parent_process else {
81        tracing::warn!("Couldn't detect shell, assuming {default}");
82        return ShellInfo {
83            kind: default,
84            version: None,
85        };
86    };
87
88    let parent_name = parent
89        .name()
90        .to_str()
91        .expect("Invalid parent shell name")
92        .trim()
93        .to_lowercase();
94
95    let kind = if parent_name == "cargo" || parent_name == "cargo.exe" {
96        tracing::warn!("Executed through cargo, assuming {default}");
97        return ShellInfo {
98            kind: default,
99            version: None,
100        };
101    } else {
102        ShellType::try_from(parent_name.as_str()).expect("infallible")
103    };
104
105    tracing::info!("Detected shell: {kind}");
106
107    let exe_path = parent
108        .exe()
109        .map(|p| p.as_os_str())
110        .filter(|p| !p.is_empty())
111        .unwrap_or_else(|| parent_name.as_ref());
112    let version = get_shell_version(&kind, exe_path).inspect(|v| tracing::info!("Detected shell version: {v}"));
113
114    ShellInfo { kind, version }
115});
116
117/// A helper function to get the version from a shell's executable path
118fn get_shell_version(shell_kind: &ShellType, shell_path: impl AsRef<OsStr>) -> Option<String> {
119    // `cmd.exe` version is tied to the OS version, so we don't query it
120    if *shell_kind == ShellType::Cmd {
121        return None;
122    }
123
124    // Most shells respond to `--version`, except PowerShell
125    let mut command = std::process::Command::new(shell_path);
126    if matches!(shell_kind, ShellType::PowerShellCore | ShellType::WindowsPowerShell) {
127        command.args([
128            "-NoProfile",
129            "-Command",
130            "'PowerShell {0} ({1} Edition)' -f $PSVersionTable.PSVersion, $PSVersionTable.PSEdition",
131        ]);
132    } else {
133        command.arg("--version");
134    }
135
136    // Configure pipes for stdout and stderr to capture the output manually
137    let mut child = match command.stdout(Stdio::piped()).stderr(Stdio::piped()).spawn() {
138        Ok(child) => child,
139        Err(err) => {
140            tracing::warn!("Failed to spawn shell process: {err}");
141            return None;
142        }
143    };
144
145    // Wait for the process to exit, with a timeout
146    match child.wait_timeout(Duration::from_millis(250)) {
147        // The command finished within the timeout period
148        Ok(Some(status)) => {
149            if status.success() {
150                let mut output = String::new();
151                // Read the output from the stdout pipe
152                if let Some(mut stdout) = child.stdout {
153                    stdout.read_to_string(&mut output).unwrap_or_default();
154                }
155                // Return just the first line of the output
156                Some(output.lines().next().unwrap_or("").trim().to_string()).filter(|v| !v.is_empty())
157            } else {
158                tracing::warn!("Shell version command failed with status: {}", status);
159                None
160            }
161        }
162        // The command timed out
163        Ok(None) => {
164            // Kill the child process to prevent it from running forever
165            if let Err(err) = child.kill() {
166                tracing::warn!("Failed to kill timed-out process: {err}");
167            }
168            tracing::warn!("Shell version command timed out");
169            None
170        }
171        // An error occurred while waiting
172        Err(err) => {
173            tracing::warn!("Error waiting for shell version command: {err}");
174            None
175        }
176    }
177}
178
179/// Retrieves information about the current shell, including its type and version
180pub fn get_shell_info() -> &'static ShellInfo {
181    PARENT_SHELL_INFO.deref()
182}
183
184/// Retrieves the current shell type
185pub fn get_shell_type() -> &'static ShellType {
186    &get_shell_info().kind
187}
188
189/// A helper function to get the version from an executable (e.g. git)
190pub fn get_executable_version(root_cmd: impl AsRef<OsStr>) -> Option<String> {
191    if root_cmd.as_ref().is_empty() {
192        return None;
193    }
194
195    // Most shells commands respond to `--version`
196    let mut child = std::process::Command::new(root_cmd)
197        .arg("--version")
198        .stdout(Stdio::piped())
199        .stderr(Stdio::piped())
200        .spawn()
201        .ok()?;
202
203    // Wait for the process to exit, with a timeout
204    match child.wait_timeout(Duration::from_millis(250)) {
205        Ok(Some(status)) if status.success() => {
206            let mut output = String::new();
207            if let Some(mut stdout) = child.stdout {
208                stdout.read_to_string(&mut output).unwrap_or_default();
209            }
210            Some(output.lines().next().unwrap_or("").trim().to_string()).filter(|v| !v.is_empty())
211        }
212        Ok(None) => {
213            if let Err(err) = child.kill() {
214                tracing::warn!("Failed to kill timed-out process: {err}");
215            }
216            None
217        }
218        _ => None,
219    }
220}
221
222static OS_INFO: LazyLock<Info> = LazyLock::new(|| {
223    let info = os_info::get();
224    tracing::info!("Detected OS: {info}");
225    info
226});
227
228/// Retrieves the operating system information
229pub fn get_os_info() -> &'static Info {
230    &OS_INFO
231}
232
233static WORING_DIR: LazyLock<String> = LazyLock::new(|| {
234    std::env::current_dir()
235        .inspect_err(|err| tracing::warn!("Couldn't retrieve current dir: {err}"))
236        .ok()
237        .and_then(|p| p.to_str().map(|s| s.to_owned()))
238        .unwrap_or_default()
239});
240
241/// Retrieves the working directory
242pub fn get_working_dir() -> &'static str {
243    WORING_DIR.deref()
244}
245
246/// Formats an env var name into its shell representation, based on the current shell
247pub fn format_env_var(var: impl AsRef<str>) -> String {
248    let var = var.as_ref();
249    match get_shell_type() {
250        ShellType::Cmd => format!("%{var}%"),
251        ShellType::WindowsPowerShell | ShellType::PowerShellCore => format!("$env:{var}"),
252        ShellType::Nushell => format!("$env.{var}"),
253        _ => format!("${var}"),
254    }
255}
256
257/// Generates a string representation of the current working directory tree, respecting .gitignore files
258pub fn generate_working_dir_tree(max_depth: usize, entry_limit: usize) -> Option<String> {
259    let root = PathBuf::from(get_working_dir());
260    if !root.is_dir() {
261        return None;
262    }
263
264    let root_canon = root.canonicalize().ok()?;
265
266    // Phase 1: Collect all entries by depth and also get total child counts for every directory
267    let mut entries_by_depth: BTreeMap<usize, Vec<ignore::DirEntry>> = BTreeMap::new();
268    let mut total_child_counts: BTreeMap<PathBuf, usize> = BTreeMap::new();
269    let walker = WalkBuilder::new(&root_canon).max_depth(Some(max_depth + 1)).build();
270
271    for entry in walker.flatten() {
272        if entry.depth() == 0 {
273            continue;
274        }
275        if let Some(parent_path) = entry.path().parent() {
276            *total_child_counts.entry(parent_path.to_path_buf()).or_default() += 1;
277        }
278        entries_by_depth.entry(entry.depth()).or_default().push(entry);
279    }
280
281    // Phase 2: Create a limited list of entries using the breadth-first approach
282    let mut limited_entries: Vec<ignore::DirEntry> = Vec::with_capacity(entry_limit);
283    'outer: for (_depth, entries) in entries_by_depth {
284        for entry in entries {
285            if limited_entries.len() >= entry_limit {
286                break 'outer;
287            }
288            limited_entries.push(entry);
289        }
290    }
291
292    // Phase 3: Populate the display tree and add "..." where contents are truncated
293    let mut dir_children: BTreeMap<PathBuf, Vec<(String, bool)>> = BTreeMap::new();
294    for entry in limited_entries {
295        let is_dir = entry.path().is_dir();
296        if let Some(parent_path) = entry.path().parent() {
297            let file_name = entry.file_name().to_string_lossy().to_string();
298            dir_children
299                .entry(parent_path.to_path_buf())
300                .or_default()
301                .push((file_name, is_dir));
302        }
303    }
304    for (path, total_count) in total_child_counts {
305        let displayed_count = dir_children.get(&path).map_or(0, |v| v.len());
306        if displayed_count < total_count {
307            dir_children.entry(path).or_default().push(("...".to_string(), false));
308        }
309    }
310
311    // Sort the children in each directory alphabetically for consistent output
312    for children in dir_children.values_mut() {
313        children.sort_by(|a, b| {
314            // "..." is always last
315            if a.0 == "..." {
316                Ordering::Greater
317            } else if b.0 == "..." {
318                Ordering::Less
319            } else {
320                // Otherwise, sort alphabetically
321                a.0.cmp(&b.0)
322            }
323        });
324    }
325
326    // Phase 4: Build the final string
327    let mut tree_string = format!("{} (current working dir)\n", root_canon.display());
328    build_tree_from_map(&root_canon, "", &mut tree_string, &dir_children);
329    Some(tree_string)
330}
331
/// Appends the tree rendering of `dir_path`'s children (as recorded in `dir_children`) to `output`.
///
/// Each child gets one line made of `prefix`, a box-drawing connector, and the child's name. Directories
/// get a trailing `/` and are rendered recursively with a deeper prefix. A directory whose only child is
/// another directory is collapsed onto a single `a/b/c/` line. Directories missing from `dir_children`
/// render nothing.
fn build_tree_from_map(
    dir_path: &Path,
    prefix: &str,
    output: &mut String,
    dir_children: &BTreeMap<PathBuf, Vec<(String, bool)>>,
) {
    let Some(children) = dir_children.get(dir_path) else {
        return;
    };

    let count = children.len();
    for (index, (name, is_dir)) in children.iter().enumerate() {
        let last = index + 1 == count;
        let (connector, indent) = if last { ("└── ", "    ") } else { ("├── ", "│   ") };

        if !*is_dir {
            // Plain files and the "..." truncation marker are printed as-is
            output.push_str(&format!("{prefix}{connector}{name}\n"));
            continue;
        }

        // Follow chains of directories that contain exactly one child directory, to print them on one line
        let mut segments: Vec<&str> = vec![name.as_str()];
        let mut deepest = dir_path.join(name);
        while let Some([(only_name, true)]) = dir_children.get(&deepest).map(Vec::as_slice) {
            segments.push(only_name);
            deepest.push(only_name);
        }

        output.push_str(&format!("{prefix}{connector}{}/\n", segments.join("/")));

        // Continue below the last directory of the collapsed chain
        let child_prefix = format!("{prefix}{indent}");
        build_tree_from_map(&deepest, &child_prefix, output, dir_children);
    }
}
381
382/// Decodes the output of a shell command based on the OS and some heuristics
383pub fn decode_output(bytes: &[u8]) -> Cow<'_, str> {
384    // On Windows, PowerShell is known to have inconsistent output encoding
385    if cfg!(windows) {
386        // A UTF-8 BOM is a strong signal that the content is UTF-8
387        if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) {
388            // It has a UTF-8 BOM, so treat as UTF-8 (after skipping the BOM)
389            return String::from_utf8_lossy(&bytes[3..]);
390        }
391
392        // If the byte stream contains NUL bytes, it is almost certainly UTF-16LE
393        if bytes.contains(&0) {
394            let (cow, _encoding_used, _had_errors) = encoding_rs::UTF_16LE.decode(bytes);
395            return cow;
396        }
397    }
398
399    // For all other cases (Linux, macOS, non-PowerShell on Windows, or PowerShell output without NUL bytes)
400    String::from_utf8_lossy(bytes)
401}
402
403/// Executes a shell command, inheriting the parent's `stdout` and `stderr`
404pub async fn execute_shell_command_inherit(
405    command: &str,
406    include_prompt: bool,
407    cancellation_token: CancellationToken,
408) -> color_eyre::Result<ExitStatus> {
409    let mut cmd = prepare_command_execution(command, true, include_prompt)?;
410
411    // Spawn the child process to get a handle to it
412    let mut child = cmd
413        .spawn()
414        .with_context(|| format!("Failed to spawn command: `{command}`"))?;
415
416    // Race the child process against the cancellation token
417    let status = tokio::select! {
418        // Prioritize cancellation token
419        biased;
420        // Task is cancelled
421        _ = cancellation_token.cancelled() => {
422            tracing::info!("Received cancellation signal, terminating child process...");
423            // Send a kill signal to the child process
424            child.kill().await.with_context(|| format!("Failed to kill child process for command: `{command}`"))?;
425            // Wait for the process to exit and get its status
426            child.wait().await.with_context(|| "Failed to await child process after kill")?
427        }
428        // The child process completes on its own
429        status = child.wait() => {
430            status.with_context(|| format!("Child process for command `{command}` failed"))?
431        }
432    };
433
434    Ok(status)
435}
436
437/// Executes a shell command, capturing `stdout` and `stderr`.
438///
439/// While capturing, it simultaneously prints both streams to the parent's `stderr` in real-time.
440pub async fn execute_shell_command_capture(
441    command: &str,
442    include_prompt: bool,
443    cancellation_token: CancellationToken,
444) -> color_eyre::Result<(ExitStatus, String, bool)> {
445    let mut cmd = prepare_command_execution(command, true, include_prompt)?;
446
447    // Configure the command to capture output streams by creating pipes
448    cmd.stdout(Stdio::piped());
449    cmd.stderr(Stdio::piped());
450
451    let mut child = cmd
452        .spawn()
453        .with_context(|| format!("Failed to spawn command: `{command}`"))?;
454
455    // Create buffered readers for the child's output streams
456    let mut stdout_reader = BufReader::new(child.stdout.take().unwrap()).lines();
457    let mut stderr_reader = BufReader::new(child.stderr.take().unwrap()).lines();
458
459    let mut output_capture = String::new();
460
461    // Flag to track if the process was terminated by the cancellation token
462    let mut terminated_by_token = false;
463
464    // Use boolean flags to track when each stream is finished
465    let mut stdout_done = false;
466    let mut stderr_done = false;
467
468    // Loop until both stdout and stderr streams have been completely read
469    while !stdout_done || !stderr_done {
470        tokio::select! {
471            // Prioritize cancellation token
472            biased;
473            // Task is cancelled
474            _ = cancellation_token.cancelled() => {
475                tracing::info!("Received cancellation signal, terminating child process...");
476                // Kill the child process, this will also cause the stdout/stderr streams to close
477                child.kill().await.with_context(|| format!("Failed to kill child process for command: `{command}`"))?;
478                // Set the flag to true since we handled the signal
479                terminated_by_token = true;
480                // Break the loop to proceed to the final `child.wait()`
481                break;
482            },
483            // Read from stdout if it's not done yet
484            res = stdout_reader.next_line(), if !stdout_done => {
485                match res {
486                    Ok(Some(line)) => {
487                        writeln!(io::stderr(), "{line}")?;
488                        output_capture.push_str(&line);
489                        output_capture.push('\n');
490                    },
491                    _ => stdout_done = true,
492                }
493            },
494            // Read from stderr if it's not done yet
495            res = stderr_reader.next_line(), if !stderr_done => {
496                match res {
497                    Ok(Some(line)) => {
498                        writeln!(io::stderr(), "{line}")?;
499                        output_capture.push_str(&line);
500                        output_capture.push('\n');
501                    },
502                    _ => stderr_done = true,
503                }
504            },
505            // This branch is taken once both output streams are done
506            else => break,
507        }
508    }
509
510    // Wait for the process to fully exit to get its final status
511    let status = child.wait().await.wrap_err("Failed to wait for command")?;
512
513    Ok((status, output_capture, terminated_by_token))
514}
515
516/// Builds a base `Command` object for executing a command string via the OS shell
517pub fn prepare_command_execution(
518    command: &str,
519    output_command: bool,
520    include_prompt: bool,
521) -> color_eyre::Result<tokio::process::Command> {
522    // Let the OS shell parse the command, supporting complex commands, arguments, and pipelines
523    let shell = get_shell_type();
524    let shell_arg = match shell {
525        ShellType::Cmd => "/c",
526        ShellType::WindowsPowerShell => "-Command",
527        _ => "-c",
528    };
529
530    tracing::info!("Executing command: {shell} {shell_arg} -- {command}");
531
532    // Print the command on stderr
533    if output_command {
534        let write_result = if include_prompt {
535            writeln!(
536                io::stderr(),
537                "{}{command}",
538                env::var("INTELLI_EXEC_PROMPT").as_deref().unwrap_or("> "),
539            )
540        } else {
541            writeln!(io::stderr(), "{command}")
542        };
543        // Handle broken pipe
544        if let Err(err) = write_result {
545            if err.kind() != io::ErrorKind::BrokenPipe {
546                return Err(err).wrap_err("Failed writing to stderr");
547            }
548            tracing::error!("Failed writing to stderr: Broken pipe");
549        };
550    }
551
552    // Build the base command object
553    let mut cmd = tokio::process::Command::new(shell.to_string());
554    cmd.arg(shell_arg).arg(command).kill_on_drop(true);
555    Ok(cmd)
556}