Skip to main content

harn_vm/
shells.rs

1use std::cell::RefCell;
2use std::collections::BTreeMap;
3#[cfg(not(windows))]
4use std::collections::BTreeSet;
5use std::path::{Path, PathBuf};
6use std::rc::Rc;
7
8use crate::value::{VmError, VmValue};
9
thread_local! {
    /// Per-thread override for the default shell id, written by `set_default_shell`.
    ///
    /// `None` means no explicit choice has been made on this thread; discovery then
    /// falls back to the first available shell.
    static SELECTED_DEFAULT_SHELL_ID: RefCell<Option<String>> =
        const { RefCell::new(None) };
}
13
/// Everything the VM knows about one candidate shell executable.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShellDescriptor {
    /// Short identifier derived from the executable name (e.g. `zsh`, `pwsh`, `cmd`);
    /// discovery keeps at most one entry per id.
    pub id: String,
    /// Human-readable name (e.g. `PowerShell 7`).
    pub label: String,
    /// Executable path, or a bare program name when it could not be resolved.
    pub path: String,
    /// Platform name such as `darwin`, `linux` or `windows`.
    pub platform: String,
    /// Whether the executable was found when the descriptor was built.
    pub available: bool,
    /// Whether `login_args` may be used for login invocations.
    pub supports_login: bool,
    /// Whether a leading `-i` may be added for interactive invocations.
    pub supports_interactive: bool,
    /// Arguments placed before the command in a normal invocation.
    pub default_args: Vec<String>,
    /// Arguments placed before the command in a login invocation.
    pub login_args: Vec<String>,
    /// Where the shell came from (`configured`, `env`, `login`, `etc_shells`,
    /// `fallback`, or a caller-supplied value such as `host`).
    pub source: String,
}
27
28#[derive(Clone, Debug, PartialEq, Eq)]
29pub struct ShellCatalog {
30    pub shells: Vec<ShellDescriptor>,
31    pub default_shell_id: Option<String>,
32}
33
34#[derive(Clone, Debug, PartialEq, Eq)]
35pub struct ShellInvocation {
36    pub program: String,
37    pub args: Vec<String>,
38    pub command_arg_index: usize,
39    pub shell: ShellDescriptor,
40}
41
42pub fn clear_selected_default_shell_for_test() {
43    SELECTED_DEFAULT_SHELL_ID.with(|selected| *selected.borrow_mut() = None);
44}
45
46pub fn discover_shells() -> ShellCatalog {
47    let shells = platform_shells();
48    let selected = SELECTED_DEFAULT_SHELL_ID.with(|selected| selected.borrow().clone());
49    let default_shell_id = selected
50        .filter(|id| {
51            shells
52                .iter()
53                .any(|shell| shell.id == *id && shell.available)
54        })
55        .or_else(|| {
56            shells
57                .iter()
58                .find(|shell| shell.available)
59                .map(|shell| shell.id.clone())
60        })
61        .or_else(|| shells.first().map(|shell| shell.id.clone()));
62    ShellCatalog {
63        shells,
64        default_shell_id,
65    }
66}
67
68pub fn get_default_shell() -> Option<ShellDescriptor> {
69    let catalog = discover_shells();
70    catalog
71        .default_shell_id
72        .as_deref()
73        .and_then(|id| catalog.shells.iter().find(|shell| shell.id == id))
74        .cloned()
75        .or_else(|| catalog.shells.first().cloned())
76}
77
78pub fn set_default_shell(shell_id: &str) -> Result<ShellDescriptor, String> {
79    let catalog = discover_shells();
80    let Some(shell) = catalog
81        .shells
82        .iter()
83        .find(|shell| shell.id == shell_id && shell.available)
84        .cloned()
85    else {
86        return Err(format!("unknown or unavailable shell id {shell_id:?}"));
87    };
88    SELECTED_DEFAULT_SHELL_ID.with(|selected| *selected.borrow_mut() = Some(shell.id.clone()));
89    Ok(shell)
90}
91
92pub fn list_shells_vm_value() -> VmValue {
93    shell_catalog_to_vm_value(&discover_shells())
94}
95
96pub fn default_shell_vm_value() -> VmValue {
97    get_default_shell()
98        .map(|shell| shell_descriptor_to_vm_value(&shell))
99        .unwrap_or(VmValue::Nil)
100}
101
102pub fn set_default_shell_vm_value(params: &BTreeMap<String, VmValue>) -> Result<VmValue, VmError> {
103    let shell_id = params
104        .get("shell_id")
105        .or_else(|| params.get("id"))
106        .and_then(vm_string)
107        .ok_or_else(|| {
108            VmError::Runtime("process.set_default_shell missing shell_id".to_string())
109        })?;
110    set_default_shell(shell_id)
111        .map(|shell| shell_descriptor_to_vm_value(&shell))
112        .map_err(|err| VmError::Runtime(format!("process.set_default_shell: {err}")))
113}
114
115pub fn shell_invocation_vm_value(params: &BTreeMap<String, VmValue>) -> Result<VmValue, VmError> {
116    resolve_invocation_from_vm_params(params)
117        .map(|invocation| shell_invocation_to_vm_value(&invocation))
118        .map_err(|err| VmError::Runtime(format!("process.shell_invocation: {err}")))
119}
120
121pub fn default_shell_invocation(command: &str) -> Result<ShellInvocation, String> {
122    let shell = get_default_shell().ok_or_else(|| "no shell candidates available".to_string())?;
123    Ok(invocation_for_shell(
124        shell,
125        command.to_string(),
126        false,
127        false,
128    ))
129}
130
131pub fn resolve_invocation_from_vm_params(
132    params: &BTreeMap<String, VmValue>,
133) -> Result<ShellInvocation, String> {
134    // "{command}" is a downstream template placeholder, not a format string.
135    #[allow(clippy::literal_string_with_formatting_args)]
136    let command = params
137        .get("command")
138        .and_then(vm_string)
139        .unwrap_or("{command}")
140        .to_string();
141    let login = optional_bool(params, "login").unwrap_or(false);
142    let interactive = optional_bool(params, "interactive").unwrap_or(false);
143    let shell = resolve_shell_from_vm_params(params)?;
144    Ok(invocation_for_shell(shell, command, login, interactive))
145}
146
147pub fn resolve_shell_from_vm_params(
148    params: &BTreeMap<String, VmValue>,
149) -> Result<ShellDescriptor, String> {
150    if let Some(value) = params.get("shell") {
151        if let Some(shell) = value.as_dict() {
152            return shell_descriptor_from_vm_dict(shell);
153        }
154        if !matches!(value, VmValue::Nil) {
155            return Err(format!("shell must be a dict, got {}", value.type_name()));
156        }
157    }
158    if let Some(value) = params.get("shell_id") {
159        if let Some(shell_id) = vm_string(value) {
160            return shell_by_id(shell_id);
161        }
162        if !matches!(value, VmValue::Nil) {
163            return Err(format!(
164                "shell_id must be a string, got {}",
165                value.type_name()
166            ));
167        }
168    }
169    get_default_shell().ok_or_else(|| "no default shell available".to_string())
170}
171
172pub fn shell_descriptor_to_vm_value(shell: &ShellDescriptor) -> VmValue {
173    let mut map = BTreeMap::new();
174    map.insert("id".to_string(), string(&shell.id));
175    map.insert("label".to_string(), string(&shell.label));
176    map.insert("path".to_string(), string(&shell.path));
177    map.insert("platform".to_string(), string(&shell.platform));
178    map.insert("available".to_string(), VmValue::Bool(shell.available));
179    map.insert(
180        "supports_login".to_string(),
181        VmValue::Bool(shell.supports_login),
182    );
183    map.insert(
184        "supports_interactive".to_string(),
185        VmValue::Bool(shell.supports_interactive),
186    );
187    map.insert("default_args".to_string(), string_list(&shell.default_args));
188    map.insert("login_args".to_string(), string_list(&shell.login_args));
189    map.insert("source".to_string(), string(&shell.source));
190    VmValue::Dict(Rc::new(map))
191}
192
193pub fn shell_invocation_to_vm_value(invocation: &ShellInvocation) -> VmValue {
194    let mut map = BTreeMap::new();
195    map.insert("program".to_string(), string(&invocation.program));
196    map.insert("args".to_string(), string_list(&invocation.args));
197    map.insert(
198        "command_arg_index".to_string(),
199        VmValue::Int(invocation.command_arg_index as i64),
200    );
201    map.insert(
202        "shell".to_string(),
203        shell_descriptor_to_vm_value(&invocation.shell),
204    );
205    VmValue::Dict(Rc::new(map))
206}
207
208fn shell_catalog_to_vm_value(catalog: &ShellCatalog) -> VmValue {
209    let mut map = BTreeMap::new();
210    map.insert(
211        "shells".to_string(),
212        VmValue::List(Rc::new(
213            catalog
214                .shells
215                .iter()
216                .map(shell_descriptor_to_vm_value)
217                .collect(),
218        )),
219    );
220    map.insert(
221        "default_shell_id".to_string(),
222        catalog
223            .default_shell_id
224            .as_ref()
225            .map(|id| string(id))
226            .unwrap_or(VmValue::Nil),
227    );
228    VmValue::Dict(Rc::new(map))
229}
230
231fn shell_descriptor_from_vm_dict(
232    dict: &BTreeMap<String, VmValue>,
233) -> Result<ShellDescriptor, String> {
234    if let Some(path) = dict.get("path").and_then(vm_string) {
235        let id = dict
236            .get("id")
237            .and_then(vm_string)
238            .map(ToString::to_string)
239            .unwrap_or_else(|| shell_id_from_path(path));
240        let platform = dict
241            .get("platform")
242            .and_then(vm_string)
243            .unwrap_or(platform_name())
244            .to_string();
245        let label = dict
246            .get("label")
247            .and_then(vm_string)
248            .map(ToString::to_string)
249            .unwrap_or_else(|| id.clone());
250        let default_args = dict
251            .get("default_args")
252            .and_then(vm_string_list)
253            .unwrap_or_else(|| default_args_for_id(&id));
254        let login_args = dict
255            .get("login_args")
256            .and_then(vm_string_list)
257            .unwrap_or_else(|| login_args_for_id(&id));
258        let available = dict
259            .get("available")
260            .and_then(|value| match value {
261                VmValue::Bool(value) => Some(*value),
262                _ => None,
263            })
264            .unwrap_or_else(|| executable_available(path));
265        let supports_login = dict
266            .get("supports_login")
267            .and_then(|value| match value {
268                VmValue::Bool(value) => Some(*value),
269                _ => None,
270            })
271            .unwrap_or_else(|| supports_login_for_id(&id));
272        let supports_interactive = dict
273            .get("supports_interactive")
274            .and_then(|value| match value {
275                VmValue::Bool(value) => Some(*value),
276                _ => None,
277            })
278            .unwrap_or_else(|| supports_interactive_for_id(&id));
279        return Ok(ShellDescriptor {
280            id,
281            label,
282            path: path.to_string(),
283            platform,
284            available,
285            supports_login,
286            supports_interactive,
287            default_args,
288            login_args,
289            source: dict
290                .get("source")
291                .and_then(vm_string)
292                .unwrap_or("host")
293                .to_string(),
294        });
295    }
296    if let Some(id) = dict.get("id").and_then(vm_string) {
297        return shell_by_id(id);
298    }
299    Err("shell object requires `path` or `id`".to_string())
300}
301
302fn shell_by_id(shell_id: &str) -> Result<ShellDescriptor, String> {
303    discover_shells()
304        .shells
305        .into_iter()
306        .find(|shell| shell.id == shell_id)
307        .ok_or_else(|| format!("unknown shell id {shell_id:?}"))
308}
309
310fn invocation_for_shell(
311    shell: ShellDescriptor,
312    command: String,
313    login: bool,
314    interactive: bool,
315) -> ShellInvocation {
316    let mut args = if login && shell.supports_login && !shell.login_args.is_empty() {
317        shell.login_args.clone()
318    } else {
319        shell.default_args.clone()
320    };
321    if interactive && shell.supports_interactive && !args.iter().any(|arg| arg == "-i") {
322        args.insert(0, "-i".to_string());
323    }
324    let command_arg_index = args.len();
325    args.push(command);
326    ShellInvocation {
327        program: shell.path.clone(),
328        args,
329        command_arg_index,
330        shell,
331    }
332}
333
#[cfg(windows)]
/// Windows discovery, in priority order: `HARN_DEFAULT_SHELL`, `COMSPEC`,
/// then PowerShell 7, Windows PowerShell and cmd. Each fallback is resolved on
/// `PATH` when possible and otherwise kept as its bare executable name.
fn platform_shells() -> Vec<ShellDescriptor> {
    const FALLBACKS: [(&str, &str, &str); 3] = [
        ("pwsh", "PowerShell 7", "pwsh.exe"),
        ("powershell", "Windows PowerShell", "powershell.exe"),
        ("cmd", "cmd", "cmd.exe"),
    ];

    let mut shells = Vec::new();
    for (variable, source) in [("HARN_DEFAULT_SHELL", "configured"), ("COMSPEC", "env")] {
        if let Ok(value) = std::env::var(variable) {
            push_shell(&mut shells, descriptor_for_path(&value, source));
        }
    }
    for (id, label, executable) in FALLBACKS {
        let path = find_on_path(executable).unwrap_or_else(|| executable.to_string());
        let shell = ShellDescriptor {
            id: id.to_string(),
            label: label.to_string(),
            ..descriptor_for_path(&path, "fallback")
        };
        push_shell(&mut shells, shell);
    }
    shells
}
356
357#[cfg(not(windows))]
358fn platform_shells() -> Vec<ShellDescriptor> {
359    let mut shells = Vec::new();
360    if let Ok(value) = std::env::var("HARN_DEFAULT_SHELL") {
361        push_shell(&mut shells, descriptor_for_path(&value, "configured"));
362    }
363    if let Ok(value) = std::env::var("SHELL") {
364        push_shell(&mut shells, descriptor_for_path(&value, "env"));
365    }
366    if let Some(value) = login_shell_from_passwd() {
367        push_shell(&mut shells, descriptor_for_path(&value, "login"));
368    }
369    for value in shells_from_etc_shells() {
370        push_shell(&mut shells, descriptor_for_path(&value, "etc_shells"));
371    }
372    for value in [
373        "/bin/zsh",
374        "/bin/bash",
375        "/bin/sh",
376        "/usr/bin/zsh",
377        "/usr/bin/bash",
378        "/usr/bin/sh",
379    ] {
380        push_shell(&mut shells, descriptor_for_path(value, "fallback"));
381    }
382    shells
383}
384
385fn push_shell(shells: &mut Vec<ShellDescriptor>, shell: ShellDescriptor) {
386    if shells.iter().any(|existing| existing.id == shell.id) {
387        return;
388    }
389    shells.push(shell);
390}
391
392fn descriptor_for_path(path: &str, source: &str) -> ShellDescriptor {
393    let id = shell_id_from_path(path);
394    ShellDescriptor {
395        id: id.clone(),
396        label: label_for_id(&id),
397        path: path.to_string(),
398        platform: platform_name().to_string(),
399        available: executable_available(path),
400        supports_login: supports_login_for_id(&id),
401        supports_interactive: supports_interactive_for_id(&id),
402        default_args: default_args_for_id(&id),
403        login_args: login_args_for_id(&id),
404        source: source.to_string(),
405    }
406}
407
/// Derives a shell id from an executable path.
///
/// Takes the lower-cased file name and strips a trailing `.exe`. `windowspowershell`
/// is normalised to `powershell`. An empty result becomes the generic `"shell"`.
/// If the path has no file name component, the whole path is used.
fn shell_id_from_path(path: &str) -> String {
    let base = Path::new(path)
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or(path)
        .to_ascii_lowercase();
    let stem = base.strip_suffix(".exe").unwrap_or(&base);
    if stem.is_empty() {
        "shell".to_string()
    } else if stem == "windowspowershell" {
        "powershell".to_string()
    } else {
        stem.to_string()
    }
}
426
/// Human-readable label for a shell id.
///
/// Only the PowerShell variants get a descriptive name; every other id
/// (cmd, bash, zsh, fish, sh, or anything custom) is its own label.
fn label_for_id(id: &str) -> String {
    let label = match id {
        "pwsh" => "PowerShell 7",
        "powershell" => "Windows PowerShell",
        other => other,
    };
    label.to_owned()
}
440
/// Arguments that precede the command string in a normal invocation.
fn default_args_for_id(id: &str) -> Vec<String> {
    let args: &[&str] = match id {
        "cmd" => &["/C"],
        "pwsh" | "powershell" => &["-NoProfile", "-Command"],
        // POSIX-style shells.
        _ => &["-c"],
    };
    args.iter().map(|arg| arg.to_string()).collect()
}
448
449fn login_args_for_id(id: &str) -> Vec<String> {
450    match id {
451        "cmd" | "pwsh" | "powershell" => default_args_for_id(id),
452        _ => vec!["-l".to_string(), "-c".to_string()],
453    }
454}
455
/// Whether the shell has a distinct login mode (everything except cmd/PowerShell).
fn supports_login_for_id(id: &str) -> bool {
    !["cmd", "pwsh", "powershell"].contains(&id)
}
459
/// Whether a leading `-i` flag is meaningful (everything except cmd/PowerShell).
fn supports_interactive_for_id(id: &str) -> bool {
    !["cmd", "pwsh", "powershell"].contains(&id)
}
463
/// Platform name reported to scripts: the target OS name, except that macOS is
/// reported as `darwin`.
fn platform_name() -> &'static str {
    match std::env::consts::OS {
        "macos" => "darwin",
        other => other,
    }
}
475
476fn executable_available(path: &str) -> bool {
477    let path_obj = Path::new(path);
478    if path_obj.components().count() > 1 || path_obj.is_absolute() {
479        return path_obj.is_file();
480    }
481    find_on_path(path).is_some()
482}
483
484fn find_on_path(program: &str) -> Option<String> {
485    let path = std::env::var_os("PATH")?;
486    let candidates = path_candidates(program);
487    for dir in std::env::split_paths(&path) {
488        for candidate in &candidates {
489            let full = dir.join(candidate);
490            if full.is_file() {
491                return Some(full.display().to_string());
492            }
493        }
494    }
495    None
496}
497
#[cfg(windows)]
/// File names to try for `program` in each `PATH` directory.
///
/// A name that already has an extension is tried as-is. Otherwise the bare name
/// is tried first, followed by the `.exe`, `.cmd` and `.bat` variants.
fn path_candidates(program: &str) -> Vec<PathBuf> {
    let has_extension = Path::new(program).extension().is_some();
    let suffixes: &[&str] = if has_extension {
        &[""]
    } else {
        &["", ".exe", ".cmd", ".bat"]
    };
    suffixes
        .iter()
        .map(|suffix| PathBuf::from(format!("{program}{suffix}")))
        .collect()
}
508
#[cfg(not(windows))]
/// File names to try for `program` in each `PATH` directory.
///
/// Unix has no implicit executable extensions, so the name itself is the only candidate.
fn path_candidates(program: &str) -> Vec<PathBuf> {
    vec![Path::new(program).to_path_buf()]
}
513
#[cfg(not(windows))]
/// Looks up the current user's login shell in `/etc/passwd`.
///
/// The user name comes from `$USER`, falling back to `$LOGNAME`. Entries whose
/// shell field is empty or ends in `/false` or `/nologin` are skipped. Returns
/// `None` when the user, the file, or a usable entry is missing.
fn login_shell_from_passwd() -> Option<String> {
    let user = match std::env::var("USER") {
        Ok(name) => name,
        Err(_) => std::env::var("LOGNAME").ok()?,
    };
    let passwd = std::fs::read_to_string("/etc/passwd").ok()?;
    let is_real_shell = |shell: &&str| {
        !shell.is_empty() && !shell.ends_with("/false") && !shell.ends_with("/nologin")
    };
    passwd
        .lines()
        .filter_map(|entry| {
            let mut fields = entry.split(':');
            // name:passwd:uid:gid:gecos:home:shell — after the name, the shell is
            // the sixth remaining field.
            if fields.next() == Some(user.as_str()) {
                fields.nth(5)
            } else {
                None
            }
        })
        .map(str::trim)
        .find(is_real_shell)
        .map(ToString::to_string)
}
535
#[cfg(not(windows))]
/// Reads the absolute shell paths listed in `/etc/shells`, de-duplicated,
/// in file order. A missing or unreadable file yields an empty list.
fn shells_from_etc_shells() -> Vec<String> {
    let content = match std::fs::read_to_string("/etc/shells") {
        Ok(content) => content,
        Err(_) => return Vec::new(),
    };
    let mut seen = BTreeSet::new();
    let mut ordered = Vec::new();
    for line in content.lines().map(str::trim) {
        // Requiring a leading '/' also rules out blank lines and `#` comments.
        if line.starts_with('/') && seen.insert(line) {
            ordered.push(line.to_string());
        }
    }
    ordered
}
550
551fn optional_bool(params: &BTreeMap<String, VmValue>, key: &str) -> Option<bool> {
552    match params.get(key) {
553        Some(VmValue::Bool(value)) => Some(*value),
554        _ => None,
555    }
556}
557
558fn vm_string(value: &VmValue) -> Option<&str> {
559    match value {
560        VmValue::String(value) => Some(value.as_ref()),
561        _ => None,
562    }
563}
564
565fn vm_string_list(value: &VmValue) -> Option<Vec<String>> {
566    let VmValue::List(values) = value else {
567        return None;
568    };
569    values
570        .iter()
571        .map(|value| vm_string(value).map(ToString::to_string))
572        .collect()
573}
574
575fn string(value: &str) -> VmValue {
576    VmValue::String(Rc::from(value.to_string()))
577}
578
579fn string_list(values: &[String]) -> VmValue {
580    VmValue::List(Rc::new(values.iter().map(|value| string(value)).collect()))
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully specified zsh descriptor that does not depend on the host.
    fn fixture_zsh() -> ShellDescriptor {
        ShellDescriptor {
            id: "zsh".to_string(),
            label: "zsh".to_string(),
            path: "/bin/zsh".to_string(),
            platform: "darwin".to_string(),
            available: true,
            supports_login: true,
            supports_interactive: true,
            default_args: vec!["-c".to_string()],
            login_args: vec!["-l".to_string(), "-c".to_string()],
            source: "test".to_string(),
        }
    }

    /// Builds a parameter map holding a single entry.
    fn single_param(key: &str, value: VmValue) -> BTreeMap<String, VmValue> {
        let mut params = BTreeMap::new();
        params.insert(key.to_string(), value);
        params
    }

    #[test]
    fn unix_shell_descriptor_uses_split_login_args() {
        let zsh = descriptor_for_path("/bin/zsh", "fallback");
        assert_eq!(zsh.id, "zsh");
        assert_eq!(zsh.default_args, vec!["-c"]);
        assert_eq!(zsh.login_args, vec!["-l", "-c"]);
        assert!(zsh.supports_login);
        assert!(zsh.supports_interactive);
    }

    #[test]
    fn windows_shell_descriptor_distinguishes_cmd_and_pwsh() {
        let cmd = descriptor_for_path("cmd.exe", "fallback");
        assert_eq!(cmd.id, "cmd");
        assert_eq!(cmd.default_args, vec!["/C"]);
        assert!(!cmd.supports_login);

        let pwsh = descriptor_for_path("pwsh.exe", "fallback");
        assert_eq!(pwsh.id, "pwsh");
        assert_eq!(pwsh.default_args, vec!["-NoProfile", "-Command"]);
    }

    #[test]
    fn invocation_appends_command_after_shell_args() {
        let invocation = invocation_for_shell(fixture_zsh(), "echo ok".to_string(), true, true);
        assert_eq!(invocation.program, "/bin/zsh");
        assert_eq!(invocation.args, vec!["-i", "-l", "-c", "echo ok"]);
        assert_eq!(invocation.command_arg_index, 3);
    }

    #[test]
    fn invocation_without_explicit_shell_uses_default_shell() {
        clear_selected_default_shell_for_test();
        let expected = get_default_shell().expect("test host should expose a default shell");

        let params = single_param("command", string("echo default-shell"));
        let invocation = resolve_invocation_from_vm_params(&params).unwrap();

        assert_eq!(invocation.shell, expected);
        assert_eq!(invocation.program, expected.path);
        assert_eq!(
            invocation.args[invocation.command_arg_index],
            "echo default-shell"
        );
    }

    #[test]
    fn malformed_explicit_shell_fields_do_not_fall_back_to_default() {
        let cases = [
            ("shell", "shell must be a dict, got int"),
            ("shell_id", "shell_id must be a string, got int"),
        ];
        for (key, expected) in cases {
            let params = single_param(key, VmValue::Int(1));
            assert_eq!(resolve_shell_from_vm_params(&params).unwrap_err(), expected);
        }
    }
}