Skip to main content

harn_vm/
shells.rs

1use std::cell::RefCell;
2use std::collections::BTreeMap;
3#[cfg(not(windows))]
4use std::collections::BTreeSet;
5use std::path::{Path, PathBuf};
6
7use crate::value::{VmError, VmValue};
8
thread_local! {
    /// Shell id explicitly chosen with [`set_default_shell`] on the current thread.
    ///
    /// `None` means no explicit choice has been made; [`discover_shells`] then picks
    /// the first available discovered shell instead.
    static SELECTED_DEFAULT_SHELL_ID: RefCell<Option<String>> = const { RefCell::new(None) };
}
12
/// Everything needed to launch one shell candidate.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShellDescriptor {
    /// Short identifier derived from the executable name (e.g. `zsh`, `pwsh`, `cmd`).
    pub id: String,
    /// Human-readable name shown to users.
    pub label: String,
    /// Executable path, or a bare program name when it was not resolved on `PATH`.
    pub path: String,
    /// Platform name (`darwin`, `linux`, `windows`, or another OS string).
    pub platform: String,
    /// Whether the executable was found when the descriptor was built.
    pub available: bool,
    /// Whether `login_args` may be used for login invocations.
    pub supports_login: bool,
    /// Whether `-i` may be prepended for interactive invocations.
    pub supports_interactive: bool,
    /// Arguments placed before the command for a regular invocation (e.g. `-c`).
    pub default_args: Vec<String>,
    /// Arguments used instead of `default_args` for login invocations.
    pub login_args: Vec<String>,
    /// Where the candidate came from (`configured`, `env`, `login`, `etc_shells`,
    /// `fallback`, or `host` for script-supplied shells).
    pub source: String,
}
26
27#[derive(Clone, Debug, PartialEq, Eq)]
28pub struct ShellCatalog {
29    pub shells: Vec<ShellDescriptor>,
30    pub default_shell_id: Option<String>,
31}
32
33#[derive(Clone, Debug, PartialEq, Eq)]
34pub struct ShellInvocation {
35    pub program: String,
36    pub args: Vec<String>,
37    pub command_arg_index: usize,
38    pub shell: ShellDescriptor,
39}
40
41pub fn clear_selected_default_shell_for_test() {
42    SELECTED_DEFAULT_SHELL_ID.with(|selected| *selected.borrow_mut() = None);
43}
44
45pub fn discover_shells() -> ShellCatalog {
46    let shells = platform_shells();
47    let selected = SELECTED_DEFAULT_SHELL_ID.with(|selected| selected.borrow().clone());
48    let default_shell_id = selected
49        .filter(|id| {
50            shells
51                .iter()
52                .any(|shell| shell.id == *id && shell.available)
53        })
54        .or_else(|| {
55            shells
56                .iter()
57                .find(|shell| shell.available)
58                .map(|shell| shell.id.clone())
59        })
60        .or_else(|| shells.first().map(|shell| shell.id.clone()));
61    ShellCatalog {
62        shells,
63        default_shell_id,
64    }
65}
66
67pub fn get_default_shell() -> Option<ShellDescriptor> {
68    let catalog = discover_shells();
69    catalog
70        .default_shell_id
71        .as_deref()
72        .and_then(|id| catalog.shells.iter().find(|shell| shell.id == id))
73        .cloned()
74        .or_else(|| catalog.shells.first().cloned())
75}
76
77pub fn set_default_shell(shell_id: &str) -> Result<ShellDescriptor, String> {
78    let catalog = discover_shells();
79    let Some(shell) = catalog
80        .shells
81        .iter()
82        .find(|shell| shell.id == shell_id && shell.available)
83        .cloned()
84    else {
85        return Err(format!("unknown or unavailable shell id {shell_id:?}"));
86    };
87    SELECTED_DEFAULT_SHELL_ID.with(|selected| *selected.borrow_mut() = Some(shell.id.clone()));
88    Ok(shell)
89}
90
91pub fn list_shells_vm_value() -> VmValue {
92    shell_catalog_to_vm_value(&discover_shells())
93}
94
95pub fn default_shell_vm_value() -> VmValue {
96    get_default_shell()
97        .map(|shell| shell_descriptor_to_vm_value(&shell))
98        .unwrap_or(VmValue::Nil)
99}
100
101pub fn set_default_shell_vm_value(params: &BTreeMap<String, VmValue>) -> Result<VmValue, VmError> {
102    let shell_id = params
103        .get("shell_id")
104        .or_else(|| params.get("id"))
105        .and_then(vm_string)
106        .ok_or_else(|| {
107            VmError::Runtime("process.set_default_shell missing shell_id".to_string())
108        })?;
109    set_default_shell(shell_id)
110        .map(|shell| shell_descriptor_to_vm_value(&shell))
111        .map_err(|err| VmError::Runtime(format!("process.set_default_shell: {err}")))
112}
113
114pub fn shell_invocation_vm_value(params: &BTreeMap<String, VmValue>) -> Result<VmValue, VmError> {
115    resolve_invocation_from_vm_params(params)
116        .map(|invocation| shell_invocation_to_vm_value(&invocation))
117        .map_err(|err| VmError::Runtime(format!("process.shell_invocation: {err}")))
118}
119
120pub fn default_shell_invocation(command: &str) -> Result<ShellInvocation, String> {
121    let shell = get_default_shell().ok_or_else(|| "no shell candidates available".to_string())?;
122    Ok(invocation_for_shell(
123        shell,
124        command.to_string(),
125        false,
126        false,
127    ))
128}
129
130pub fn resolve_invocation_from_vm_params(
131    params: &BTreeMap<String, VmValue>,
132) -> Result<ShellInvocation, String> {
133    // "{command}" is a downstream template placeholder, not a format string.
134    #[allow(clippy::literal_string_with_formatting_args)]
135    let command = params
136        .get("command")
137        .and_then(vm_string)
138        .unwrap_or("{command}")
139        .to_string();
140    let login = optional_bool(params, "login").unwrap_or(false);
141    let interactive = optional_bool(params, "interactive").unwrap_or(false);
142    let shell = resolve_shell_from_vm_params(params)?;
143    Ok(invocation_for_shell(shell, command, login, interactive))
144}
145
146pub fn resolve_shell_from_vm_params(
147    params: &BTreeMap<String, VmValue>,
148) -> Result<ShellDescriptor, String> {
149    if let Some(value) = params.get("shell") {
150        if let Some(shell) = value.as_dict() {
151            return shell_descriptor_from_vm_dict(shell);
152        }
153        if !matches!(value, VmValue::Nil) {
154            return Err(format!("shell must be a dict, got {}", value.type_name()));
155        }
156    }
157    if let Some(value) = params.get("shell_id") {
158        if let Some(shell_id) = vm_string(value) {
159            return shell_by_id(shell_id);
160        }
161        if !matches!(value, VmValue::Nil) {
162            return Err(format!(
163                "shell_id must be a string, got {}",
164                value.type_name()
165            ));
166        }
167    }
168    get_default_shell().ok_or_else(|| "no default shell available".to_string())
169}
170
171pub fn shell_descriptor_to_vm_value(shell: &ShellDescriptor) -> VmValue {
172    let mut map = BTreeMap::new();
173    map.insert("id".to_string(), string(&shell.id));
174    map.insert("label".to_string(), string(&shell.label));
175    map.insert("path".to_string(), string(&shell.path));
176    map.insert("platform".to_string(), string(&shell.platform));
177    map.insert("available".to_string(), VmValue::Bool(shell.available));
178    map.insert(
179        "supports_login".to_string(),
180        VmValue::Bool(shell.supports_login),
181    );
182    map.insert(
183        "supports_interactive".to_string(),
184        VmValue::Bool(shell.supports_interactive),
185    );
186    map.insert("default_args".to_string(), string_list(&shell.default_args));
187    map.insert("login_args".to_string(), string_list(&shell.login_args));
188    map.insert("source".to_string(), string(&shell.source));
189    VmValue::Dict(std::sync::Arc::new(map))
190}
191
192pub fn shell_invocation_to_vm_value(invocation: &ShellInvocation) -> VmValue {
193    let mut map = BTreeMap::new();
194    map.insert("program".to_string(), string(&invocation.program));
195    map.insert("args".to_string(), string_list(&invocation.args));
196    map.insert(
197        "command_arg_index".to_string(),
198        VmValue::Int(invocation.command_arg_index as i64),
199    );
200    map.insert(
201        "shell".to_string(),
202        shell_descriptor_to_vm_value(&invocation.shell),
203    );
204    VmValue::Dict(std::sync::Arc::new(map))
205}
206
207fn shell_catalog_to_vm_value(catalog: &ShellCatalog) -> VmValue {
208    let mut map = BTreeMap::new();
209    map.insert(
210        "shells".to_string(),
211        VmValue::List(std::sync::Arc::new(
212            catalog
213                .shells
214                .iter()
215                .map(shell_descriptor_to_vm_value)
216                .collect(),
217        )),
218    );
219    map.insert(
220        "default_shell_id".to_string(),
221        catalog
222            .default_shell_id
223            .as_ref()
224            .map(|id| string(id))
225            .unwrap_or(VmValue::Nil),
226    );
227    VmValue::Dict(std::sync::Arc::new(map))
228}
229
230fn shell_descriptor_from_vm_dict(
231    dict: &BTreeMap<String, VmValue>,
232) -> Result<ShellDescriptor, String> {
233    if let Some(path) = dict.get("path").and_then(vm_string) {
234        let id = dict
235            .get("id")
236            .and_then(vm_string)
237            .map(ToString::to_string)
238            .unwrap_or_else(|| shell_id_from_path(path));
239        let platform = dict
240            .get("platform")
241            .and_then(vm_string)
242            .unwrap_or(platform_name())
243            .to_string();
244        let label = dict
245            .get("label")
246            .and_then(vm_string)
247            .map(ToString::to_string)
248            .unwrap_or_else(|| id.clone());
249        let default_args = dict
250            .get("default_args")
251            .and_then(vm_string_list)
252            .unwrap_or_else(|| default_args_for_id(&id));
253        let login_args = dict
254            .get("login_args")
255            .and_then(vm_string_list)
256            .unwrap_or_else(|| login_args_for_id(&id));
257        let available = dict
258            .get("available")
259            .and_then(|value| match value {
260                VmValue::Bool(value) => Some(*value),
261                _ => None,
262            })
263            .unwrap_or_else(|| executable_available(path));
264        let supports_login = dict
265            .get("supports_login")
266            .and_then(|value| match value {
267                VmValue::Bool(value) => Some(*value),
268                _ => None,
269            })
270            .unwrap_or_else(|| supports_login_for_id(&id));
271        let supports_interactive = dict
272            .get("supports_interactive")
273            .and_then(|value| match value {
274                VmValue::Bool(value) => Some(*value),
275                _ => None,
276            })
277            .unwrap_or_else(|| supports_interactive_for_id(&id));
278        return Ok(ShellDescriptor {
279            id,
280            label,
281            path: path.to_string(),
282            platform,
283            available,
284            supports_login,
285            supports_interactive,
286            default_args,
287            login_args,
288            source: dict
289                .get("source")
290                .and_then(vm_string)
291                .unwrap_or("host")
292                .to_string(),
293        });
294    }
295    if let Some(id) = dict.get("id").and_then(vm_string) {
296        return shell_by_id(id);
297    }
298    Err("shell object requires `path` or `id`".to_string())
299}
300
301fn shell_by_id(shell_id: &str) -> Result<ShellDescriptor, String> {
302    discover_shells()
303        .shells
304        .into_iter()
305        .find(|shell| shell.id == shell_id)
306        .ok_or_else(|| format!("unknown shell id {shell_id:?}"))
307}
308
309fn invocation_for_shell(
310    shell: ShellDescriptor,
311    command: String,
312    login: bool,
313    interactive: bool,
314) -> ShellInvocation {
315    let mut args = if login && shell.supports_login && !shell.login_args.is_empty() {
316        shell.login_args.clone()
317    } else {
318        shell.default_args.clone()
319    };
320    if interactive && shell.supports_interactive && !args.iter().any(|arg| arg == "-i") {
321        args.insert(0, "-i".to_string());
322    }
323    let command_arg_index = args.len();
324    args.push(command);
325    ShellInvocation {
326        program: shell.path.clone(),
327        args,
328        command_arg_index,
329        shell,
330    }
331}
332
/// Windows shell candidates in priority order.
///
/// Order: `HARN_DEFAULT_SHELL`, then `COMSPEC`, then PowerShell 7, Windows
/// PowerShell and cmd. The fallbacks use their `PATH` location when found, or else
/// the bare executable name.
#[cfg(windows)]
fn platform_shells() -> Vec<ShellDescriptor> {
    const FALLBACKS: [(&str, &str, &str); 3] = [
        ("pwsh", "PowerShell 7", "pwsh.exe"),
        ("powershell", "Windows PowerShell", "powershell.exe"),
        ("cmd", "cmd", "cmd.exe"),
    ];
    let mut shells = Vec::new();
    for (variable, source) in [("HARN_DEFAULT_SHELL", "configured"), ("COMSPEC", "env")] {
        if let Ok(path) = std::env::var(variable) {
            push_shell(&mut shells, descriptor_for_path(&path, source));
        }
    }
    for (id, label, executable) in FALLBACKS {
        let path = find_on_path(executable).unwrap_or_else(|| executable.to_string());
        let shell = ShellDescriptor {
            id: id.to_string(),
            label: label.to_string(),
            ..descriptor_for_path(&path, "fallback")
        };
        push_shell(&mut shells, shell);
    }
    shells
}
355
356#[cfg(not(windows))]
357fn platform_shells() -> Vec<ShellDescriptor> {
358    let mut shells = Vec::new();
359    if let Ok(value) = std::env::var("HARN_DEFAULT_SHELL") {
360        push_shell(&mut shells, descriptor_for_path(&value, "configured"));
361    }
362    if let Ok(value) = std::env::var("SHELL") {
363        push_shell(&mut shells, descriptor_for_path(&value, "env"));
364    }
365    if let Some(value) = login_shell_from_passwd() {
366        push_shell(&mut shells, descriptor_for_path(&value, "login"));
367    }
368    for value in shells_from_etc_shells() {
369        push_shell(&mut shells, descriptor_for_path(&value, "etc_shells"));
370    }
371    for value in [
372        "/bin/zsh",
373        "/bin/bash",
374        "/bin/sh",
375        "/usr/bin/zsh",
376        "/usr/bin/bash",
377        "/usr/bin/sh",
378    ] {
379        push_shell(&mut shells, descriptor_for_path(value, "fallback"));
380    }
381    shells
382}
383
384fn push_shell(shells: &mut Vec<ShellDescriptor>, shell: ShellDescriptor) {
385    if shells.iter().any(|existing| existing.id == shell.id) {
386        return;
387    }
388    shells.push(shell);
389}
390
391fn descriptor_for_path(path: &str, source: &str) -> ShellDescriptor {
392    let id = shell_id_from_path(path);
393    ShellDescriptor {
394        id: id.clone(),
395        label: label_for_id(&id),
396        path: path.to_string(),
397        platform: platform_name().to_string(),
398        available: executable_available(path),
399        supports_login: supports_login_for_id(&id),
400        supports_interactive: supports_interactive_for_id(&id),
401        default_args: default_args_for_id(&id),
402        login_args: login_args_for_id(&id),
403        source: source.to_string(),
404    }
405}
406
/// Derives a shell id from an executable path.
///
/// Takes the file name (or the whole string when there is none), lowercases it and
/// strips a trailing `.exe`. `windowspowershell` is normalised to `powershell`, and
/// an empty result becomes `shell`.
fn shell_id_from_path(path: &str) -> String {
    let lowered = match Path::new(path).file_name().and_then(|name| name.to_str()) {
        Some(name) => name.to_ascii_lowercase(),
        None => path.to_ascii_lowercase(),
    };
    let stem = lowered.strip_suffix(".exe").unwrap_or(lowered.as_str());
    if stem.is_empty() {
        "shell".to_string()
    } else if stem == "windowspowershell" {
        "powershell".to_string()
    } else {
        stem.to_string()
    }
}
425
/// Human-readable label for a shell id.
///
/// Only the PowerShell variants get a distinct display name; every other id is its
/// own label.
fn label_for_id(id: &str) -> String {
    let label = match id {
        "pwsh" => "PowerShell 7",
        "powershell" => "Windows PowerShell",
        other => other,
    };
    label.to_owned()
}
439
/// Arguments placed before the command for a regular (non-login) invocation.
///
/// cmd takes `/C`, the PowerShell variants take `-NoProfile -Command`, and every
/// other shell takes `-c`.
fn default_args_for_id(id: &str) -> Vec<String> {
    let args: &[&str] = match id {
        "cmd" => &["/C"],
        "pwsh" | "powershell" => &["-NoProfile", "-Command"],
        _ => &["-c"],
    };
    args.iter().map(|arg| arg.to_string()).collect()
}
447
448fn login_args_for_id(id: &str) -> Vec<String> {
449    match id {
450        "cmd" | "pwsh" | "powershell" => default_args_for_id(id),
451        _ => vec!["-l".to_string(), "-c".to_string()],
452    }
453}
454
/// Whether login invocations should use `login_args` for this shell id.
///
/// False for cmd and the PowerShell variants, whose login args are just their
/// default args.
fn supports_login_for_id(id: &str) -> bool {
    let without_login_mode = ["cmd", "pwsh", "powershell"];
    !without_login_mode.contains(&id)
}
458
/// Whether `-i` may be prepended for interactive invocations of this shell id.
///
/// False for cmd and the PowerShell variants.
fn supports_interactive_for_id(id: &str) -> bool {
    let without_dash_i = ["cmd", "pwsh", "powershell"];
    !without_dash_i.contains(&id)
}
462
/// Platform name reported in shell descriptors.
///
/// This is the compile-time OS name, except that macOS is reported as `darwin`.
fn platform_name() -> &'static str {
    match std::env::consts::OS {
        "macos" => "darwin",
        other => other,
    }
}
474
/// Reports whether `path` refers to a runnable executable.
///
/// An absolute path, or one with more than one component, is checked directly. A
/// bare program name is searched for on `PATH` via [`find_on_path`].
fn executable_available(path: &str) -> bool {
    let path_obj = Path::new(path);
    if path_obj.components().count() > 1 || path_obj.is_absolute() {
        return is_executable_file(path_obj);
    }
    find_on_path(path).is_some()
}

/// Searches the `PATH` directories in order for `program` and returns the first
/// executable match as a display string.
///
/// Each directory is tried with every name from [`path_candidates`]. Returns `None`
/// when `PATH` is unset or nothing matches.
fn find_on_path(program: &str) -> Option<String> {
    let path = std::env::var_os("PATH")?;
    let candidates = path_candidates(program);
    for dir in std::env::split_paths(&path) {
        for candidate in &candidates {
            let full = dir.join(candidate);
            if is_executable_file(&full) {
                return Some(full.display().to_string());
            }
        }
    }
    None
}

/// True when `path` (after following symlinks) is a regular file that has at least
/// one execute permission bit.
///
/// Without the bit check, a plain data file would be reported as an available shell,
/// and a non-executable file earlier on `PATH` would shadow the real binary.
#[cfg(unix)]
fn is_executable_file(path: &Path) -> bool {
    std::fs::metadata(path)
        .map(|meta| {
            meta.is_file()
                && std::os::unix::fs::PermissionsExt::mode(&meta.permissions()) & 0o111 != 0
        })
        .unwrap_or(false)
}

/// True when `path` is a regular file. Non-Unix platforms have no execute bit, so
/// this is the same check as before.
#[cfg(not(unix))]
fn is_executable_file(path: &Path) -> bool {
    path.is_file()
}

/// File names to try for `program` in each `PATH` directory on Windows.
///
/// If `program` has no extension, the name is also tried with `.exe`, `.cmd` and
/// `.bat` appended.
#[cfg(windows)]
fn path_candidates(program: &str) -> Vec<PathBuf> {
    let mut candidates = vec![PathBuf::from(program)];
    if Path::new(program).extension().is_none() {
        for ext in [".exe", ".cmd", ".bat"] {
            candidates.push(PathBuf::from(format!("{program}{ext}")));
        }
    }
    candidates
}

/// File names to try for `program` in each `PATH` directory: just the name itself.
#[cfg(not(windows))]
fn path_candidates(program: &str) -> Vec<PathBuf> {
    vec![PathBuf::from(program)]
}
512
/// Looks up the current user's login shell in `/etc/passwd`.
///
/// The user name comes from `$USER`, or `$LOGNAME` if `$USER` is unset. Entries whose
/// shell field is empty or ends in `/false` or `/nologin` are skipped, and scanning
/// continues with later lines. Returns `None` if nothing usable is found or the file
/// cannot be read.
#[cfg(not(windows))]
fn login_shell_from_passwd() -> Option<String> {
    let username = match std::env::var("USER") {
        Ok(name) => name,
        Err(_) => std::env::var("LOGNAME").ok()?,
    };
    let passwd = std::fs::read_to_string("/etc/passwd").ok()?;
    for entry in passwd.lines() {
        let mut fields = entry.split(':');
        if fields.next() != Some(username.as_str()) {
            continue;
        }
        // Fields after the name: password, uid, gid, gecos, home, shell.
        let Some(shell) = fields.nth(5).map(str::trim) else {
            continue;
        };
        let usable =
            !shell.is_empty() && !shell.ends_with("/false") && !shell.ends_with("/nologin");
        if usable {
            return Some(shell.to_string());
        }
    }
    None
}
534
/// Reads absolute shell paths from `/etc/shells`.
///
/// Lines are trimmed; only lines starting with `/` are kept (which also skips blanks
/// and `#` comments), and duplicates are dropped with first occurrence order kept.
/// An unreadable file yields an empty list.
#[cfg(not(windows))]
fn shells_from_etc_shells() -> Vec<String> {
    let content = match std::fs::read_to_string("/etc/shells") {
        Ok(content) => content,
        Err(_) => return Vec::new(),
    };
    let mut seen: BTreeSet<&str> = BTreeSet::new();
    let mut shells = Vec::new();
    for line in content.lines().map(str::trim) {
        if line.starts_with('/') && seen.insert(line) {
            shells.push(line.to_string());
        }
    }
    shells
}
549
550fn optional_bool(params: &BTreeMap<String, VmValue>, key: &str) -> Option<bool> {
551    match params.get(key) {
552        Some(VmValue::Bool(value)) => Some(*value),
553        _ => None,
554    }
555}
556
557fn vm_string(value: &VmValue) -> Option<&str> {
558    match value {
559        VmValue::String(value) => Some(value.as_ref()),
560        _ => None,
561    }
562}
563
564fn vm_string_list(value: &VmValue) -> Option<Vec<String>> {
565    let VmValue::List(values) = value else {
566        return None;
567    };
568    values
569        .iter()
570        .map(|value| vm_string(value).map(ToString::to_string))
571        .collect()
572}
573
574fn string(value: &str) -> VmValue {
575    VmValue::String(std::sync::Arc::from(value.to_string()))
576}
577
578fn string_list(values: &[String]) -> VmValue {
579    VmValue::List(std::sync::Arc::new(
580        values.iter().map(|value| string(value)).collect(),
581    ))
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully specified zsh descriptor that does not depend on the host's shells.
    fn zsh_fixture() -> ShellDescriptor {
        let owned = |items: &[&str]| items.iter().map(|item| item.to_string()).collect();
        ShellDescriptor {
            id: "zsh".into(),
            label: "zsh".into(),
            path: "/bin/zsh".into(),
            platform: "darwin".into(),
            available: true,
            supports_login: true,
            supports_interactive: true,
            default_args: owned(&["-c"]),
            login_args: owned(&["-l", "-c"]),
            source: "test".into(),
        }
    }

    #[test]
    fn unix_shell_descriptor_uses_split_login_args() {
        let zsh = descriptor_for_path("/bin/zsh", "fallback");
        assert_eq!(zsh.id, "zsh");
        assert_eq!(zsh.default_args, ["-c"]);
        assert_eq!(zsh.login_args, ["-l", "-c"]);
        assert!(zsh.supports_login);
        assert!(zsh.supports_interactive);
    }

    #[test]
    fn windows_shell_descriptor_distinguishes_cmd_and_pwsh() {
        let cmd = descriptor_for_path("cmd.exe", "fallback");
        assert_eq!(cmd.id, "cmd");
        assert_eq!(cmd.default_args, ["/C"]);
        assert!(!cmd.supports_login);

        let pwsh = descriptor_for_path("pwsh.exe", "fallback");
        assert_eq!(pwsh.id, "pwsh");
        assert_eq!(pwsh.default_args, ["-NoProfile", "-Command"]);
    }

    #[test]
    fn invocation_appends_command_after_shell_args() {
        let invocation = invocation_for_shell(zsh_fixture(), "echo ok".to_string(), true, true);
        assert_eq!(invocation.program, "/bin/zsh");
        assert_eq!(invocation.args, ["-i", "-l", "-c", "echo ok"]);
        assert_eq!(invocation.command_arg_index, 3);
    }

    #[test]
    fn invocation_without_explicit_shell_uses_default_shell() {
        clear_selected_default_shell_for_test();
        let expected = get_default_shell().expect("test host should expose a default shell");

        let params = BTreeMap::from([("command".to_string(), string("echo default-shell"))]);
        let invocation = resolve_invocation_from_vm_params(&params).unwrap();

        assert_eq!(invocation.shell, expected);
        assert_eq!(invocation.program, expected.path);
        assert_eq!(
            invocation.args[invocation.command_arg_index],
            "echo default-shell"
        );
    }

    #[test]
    fn malformed_explicit_shell_fields_do_not_fall_back_to_default() {
        let cases = [
            ("shell", "shell must be a dict, got int"),
            ("shell_id", "shell_id must be a string, got int"),
        ];
        for (key, expected) in cases {
            let params = BTreeMap::from([(key.to_string(), VmValue::Int(1))]);
            assert_eq!(resolve_shell_from_vm_params(&params).unwrap_err(), expected);
        }
    }
}