Skip to main content

harn_vm/
shells.rs

1use std::cell::RefCell;
2use std::collections::BTreeMap;
3#[cfg(not(windows))]
4use std::collections::BTreeSet;
5use std::path::{Path, PathBuf};
6use std::rc::Rc;
7
8use crate::value::{VmError, VmValue};
9
// Per-thread override for the default shell id.
//
// Written by `set_default_shell` and reset by `clear_selected_default_shell_for_test`.
// `None` means no explicit choice. `discover_shells` also ignores a stored id whose shell
// is no longer present or available, falling back to the first available shell instead.
thread_local! {
    static SELECTED_DEFAULT_SHELL_ID: RefCell<Option<String>> = const { RefCell::new(None) };
}
13
/// Description of one shell that commands can be run through.
///
/// Built either by host discovery (`descriptor_for_path`) or from a script-supplied
/// dict (`shell_descriptor_from_vm_dict`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShellDescriptor {
    /// Identifier derived from the executable name (e.g. `zsh`, `pwsh`, `cmd`).
    pub id: String,
    /// Human-readable name shown to users.
    pub label: String,
    /// Executable path, or bare program name, as discovered or supplied.
    pub path: String,
    /// Platform tag such as `darwin`, `windows`, or `linux`.
    pub platform: String,
    /// Whether the executable was found (or was declared available by the caller).
    pub available: bool,
    /// Whether `login_args` may be used when a login shell is requested.
    pub supports_login: bool,
    /// Whether `-i` may be prepended when an interactive shell is requested.
    pub supports_interactive: bool,
    /// Arguments placed before the command string (e.g. `-c`, `/C`).
    pub default_args: Vec<String>,
    /// Arguments used instead of `default_args` for login invocations.
    pub login_args: Vec<String>,
    /// Where the descriptor came from, e.g. `configured`, `env`, `login`,
    /// `etc_shells`, `fallback`, or `host` for caller-supplied dicts.
    pub source: String,
}
27
/// Result of shell discovery: every candidate plus the id chosen as the default.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShellCatalog {
    /// Candidates in priority order; at most one entry per `id`.
    pub shells: Vec<ShellDescriptor>,
    /// Id of the default shell; `None` only when `shells` is empty.
    pub default_shell_id: Option<String>,
}
33
/// A concrete program plus argument vector for running one command through a shell.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShellInvocation {
    /// Program to execute (the shell descriptor's `path`).
    pub program: String,
    /// Shell flags followed by the command string as the final element.
    pub args: Vec<String>,
    /// Index within `args` where the command string sits.
    pub command_arg_index: usize,
    /// Descriptor this invocation was built from.
    pub shell: ShellDescriptor,
}
41
42pub fn clear_selected_default_shell_for_test() {
43    SELECTED_DEFAULT_SHELL_ID.with(|selected| *selected.borrow_mut() = None);
44}
45
46pub fn discover_shells() -> ShellCatalog {
47    let shells = platform_shells();
48    let selected = SELECTED_DEFAULT_SHELL_ID.with(|selected| selected.borrow().clone());
49    let default_shell_id = selected
50        .filter(|id| {
51            shells
52                .iter()
53                .any(|shell| shell.id == *id && shell.available)
54        })
55        .or_else(|| {
56            shells
57                .iter()
58                .find(|shell| shell.available)
59                .map(|shell| shell.id.clone())
60        })
61        .or_else(|| shells.first().map(|shell| shell.id.clone()));
62    ShellCatalog {
63        shells,
64        default_shell_id,
65    }
66}
67
68pub fn get_default_shell() -> Option<ShellDescriptor> {
69    let catalog = discover_shells();
70    catalog
71        .default_shell_id
72        .as_deref()
73        .and_then(|id| catalog.shells.iter().find(|shell| shell.id == id))
74        .cloned()
75        .or_else(|| catalog.shells.first().cloned())
76}
77
78pub fn set_default_shell(shell_id: &str) -> Result<ShellDescriptor, String> {
79    let catalog = discover_shells();
80    let Some(shell) = catalog
81        .shells
82        .iter()
83        .find(|shell| shell.id == shell_id && shell.available)
84        .cloned()
85    else {
86        return Err(format!("unknown or unavailable shell id {shell_id:?}"));
87    };
88    SELECTED_DEFAULT_SHELL_ID.with(|selected| *selected.borrow_mut() = Some(shell.id.clone()));
89    Ok(shell)
90}
91
92pub fn list_shells_vm_value() -> VmValue {
93    shell_catalog_to_vm_value(&discover_shells())
94}
95
96pub fn default_shell_vm_value() -> VmValue {
97    get_default_shell()
98        .map(|shell| shell_descriptor_to_vm_value(&shell))
99        .unwrap_or(VmValue::Nil)
100}
101
102pub fn set_default_shell_vm_value(params: &BTreeMap<String, VmValue>) -> Result<VmValue, VmError> {
103    let shell_id = params
104        .get("shell_id")
105        .or_else(|| params.get("id"))
106        .and_then(vm_string)
107        .ok_or_else(|| {
108            VmError::Runtime("process.set_default_shell missing shell_id".to_string())
109        })?;
110    set_default_shell(shell_id)
111        .map(|shell| shell_descriptor_to_vm_value(&shell))
112        .map_err(|err| VmError::Runtime(format!("process.set_default_shell: {err}")))
113}
114
115pub fn shell_invocation_vm_value(params: &BTreeMap<String, VmValue>) -> Result<VmValue, VmError> {
116    resolve_invocation_from_vm_params(params)
117        .map(|invocation| shell_invocation_to_vm_value(&invocation))
118        .map_err(|err| VmError::Runtime(format!("process.shell_invocation: {err}")))
119}
120
121pub fn default_shell_invocation(command: &str) -> Result<ShellInvocation, String> {
122    let shell = get_default_shell().ok_or_else(|| "no shell candidates available".to_string())?;
123    Ok(invocation_for_shell(
124        shell,
125        command.to_string(),
126        false,
127        false,
128    ))
129}
130
131pub fn resolve_invocation_from_vm_params(
132    params: &BTreeMap<String, VmValue>,
133) -> Result<ShellInvocation, String> {
134    let command = params
135        .get("command")
136        .and_then(vm_string)
137        .unwrap_or("{command}")
138        .to_string();
139    let login = optional_bool(params, "login").unwrap_or(false);
140    let interactive = optional_bool(params, "interactive").unwrap_or(false);
141    let shell = resolve_shell_from_vm_params(params)?;
142    Ok(invocation_for_shell(shell, command, login, interactive))
143}
144
145pub fn resolve_shell_from_vm_params(
146    params: &BTreeMap<String, VmValue>,
147) -> Result<ShellDescriptor, String> {
148    if let Some(shell) = params.get("shell").and_then(|value| value.as_dict()) {
149        return shell_descriptor_from_vm_dict(shell);
150    }
151    if let Some(shell_id) = params.get("shell_id").and_then(vm_string) {
152        return shell_by_id(shell_id);
153    }
154    Err("shell mode requires `shell` or `shell_id`".to_string())
155}
156
157pub fn shell_descriptor_to_vm_value(shell: &ShellDescriptor) -> VmValue {
158    let mut map = BTreeMap::new();
159    map.insert("id".to_string(), string(&shell.id));
160    map.insert("label".to_string(), string(&shell.label));
161    map.insert("path".to_string(), string(&shell.path));
162    map.insert("platform".to_string(), string(&shell.platform));
163    map.insert("available".to_string(), VmValue::Bool(shell.available));
164    map.insert(
165        "supports_login".to_string(),
166        VmValue::Bool(shell.supports_login),
167    );
168    map.insert(
169        "supports_interactive".to_string(),
170        VmValue::Bool(shell.supports_interactive),
171    );
172    map.insert("default_args".to_string(), string_list(&shell.default_args));
173    map.insert("login_args".to_string(), string_list(&shell.login_args));
174    map.insert("source".to_string(), string(&shell.source));
175    VmValue::Dict(Rc::new(map))
176}
177
178pub fn shell_invocation_to_vm_value(invocation: &ShellInvocation) -> VmValue {
179    let mut map = BTreeMap::new();
180    map.insert("program".to_string(), string(&invocation.program));
181    map.insert("args".to_string(), string_list(&invocation.args));
182    map.insert(
183        "command_arg_index".to_string(),
184        VmValue::Int(invocation.command_arg_index as i64),
185    );
186    map.insert(
187        "shell".to_string(),
188        shell_descriptor_to_vm_value(&invocation.shell),
189    );
190    VmValue::Dict(Rc::new(map))
191}
192
193fn shell_catalog_to_vm_value(catalog: &ShellCatalog) -> VmValue {
194    let mut map = BTreeMap::new();
195    map.insert(
196        "shells".to_string(),
197        VmValue::List(Rc::new(
198            catalog
199                .shells
200                .iter()
201                .map(shell_descriptor_to_vm_value)
202                .collect(),
203        )),
204    );
205    map.insert(
206        "default_shell_id".to_string(),
207        catalog
208            .default_shell_id
209            .as_ref()
210            .map(|id| string(id))
211            .unwrap_or(VmValue::Nil),
212    );
213    VmValue::Dict(Rc::new(map))
214}
215
216fn shell_descriptor_from_vm_dict(
217    dict: &BTreeMap<String, VmValue>,
218) -> Result<ShellDescriptor, String> {
219    if let Some(path) = dict.get("path").and_then(vm_string) {
220        let id = dict
221            .get("id")
222            .and_then(vm_string)
223            .map(ToString::to_string)
224            .unwrap_or_else(|| shell_id_from_path(path));
225        let platform = dict
226            .get("platform")
227            .and_then(vm_string)
228            .unwrap_or(platform_name())
229            .to_string();
230        let label = dict
231            .get("label")
232            .and_then(vm_string)
233            .map(ToString::to_string)
234            .unwrap_or_else(|| id.clone());
235        let default_args = dict
236            .get("default_args")
237            .and_then(vm_string_list)
238            .unwrap_or_else(|| default_args_for_id(&id));
239        let login_args = dict
240            .get("login_args")
241            .and_then(vm_string_list)
242            .unwrap_or_else(|| login_args_for_id(&id));
243        let available = dict
244            .get("available")
245            .and_then(|value| match value {
246                VmValue::Bool(value) => Some(*value),
247                _ => None,
248            })
249            .unwrap_or_else(|| executable_available(path));
250        let supports_login = dict
251            .get("supports_login")
252            .and_then(|value| match value {
253                VmValue::Bool(value) => Some(*value),
254                _ => None,
255            })
256            .unwrap_or_else(|| supports_login_for_id(&id));
257        let supports_interactive = dict
258            .get("supports_interactive")
259            .and_then(|value| match value {
260                VmValue::Bool(value) => Some(*value),
261                _ => None,
262            })
263            .unwrap_or_else(|| supports_interactive_for_id(&id));
264        return Ok(ShellDescriptor {
265            id,
266            label,
267            path: path.to_string(),
268            platform,
269            available,
270            supports_login,
271            supports_interactive,
272            default_args,
273            login_args,
274            source: dict
275                .get("source")
276                .and_then(vm_string)
277                .unwrap_or("host")
278                .to_string(),
279        });
280    }
281    if let Some(id) = dict.get("id").and_then(vm_string) {
282        return shell_by_id(id);
283    }
284    Err("shell object requires `path` or `id`".to_string())
285}
286
287fn shell_by_id(shell_id: &str) -> Result<ShellDescriptor, String> {
288    discover_shells()
289        .shells
290        .into_iter()
291        .find(|shell| shell.id == shell_id)
292        .ok_or_else(|| format!("unknown shell id {shell_id:?}"))
293}
294
295fn invocation_for_shell(
296    shell: ShellDescriptor,
297    command: String,
298    login: bool,
299    interactive: bool,
300) -> ShellInvocation {
301    let mut args = if login && shell.supports_login && !shell.login_args.is_empty() {
302        shell.login_args.clone()
303    } else {
304        shell.default_args.clone()
305    };
306    if interactive && shell.supports_interactive && !args.iter().any(|arg| arg == "-i") {
307        args.insert(0, "-i".to_string());
308    }
309    let command_arg_index = args.len();
310    args.push(command);
311    ShellInvocation {
312        program: shell.path.clone(),
313        args,
314        command_arg_index,
315        shell,
316    }
317}
318
/// Windows shell candidates in priority order.
///
/// The order is `HARN_DEFAULT_SHELL`, then `COMSPEC`, then PowerShell 7, Windows
/// PowerShell, and cmd (resolved on `PATH` when possible).
///
/// Environment values are trimmed and ignored when blank. Otherwise an empty variable
/// would add an unavailable placeholder entry with id `"shell"`.
#[cfg(windows)]
fn platform_shells() -> Vec<ShellDescriptor> {
    let env_path = |name: &str| {
        std::env::var(name)
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    };
    let mut shells = Vec::new();
    if let Some(value) = env_path("HARN_DEFAULT_SHELL") {
        push_shell(&mut shells, descriptor_for_path(&value, "configured"));
    }
    if let Some(value) = env_path("COMSPEC") {
        push_shell(&mut shells, descriptor_for_path(&value, "env"));
    }
    for (id, label, executable) in [
        ("pwsh", "PowerShell 7", "pwsh.exe"),
        ("powershell", "Windows PowerShell", "powershell.exe"),
        ("cmd", "cmd", "cmd.exe"),
    ] {
        let path = find_on_path(executable).unwrap_or_else(|| executable.to_string());
        let mut shell = descriptor_for_path(&path, "fallback");
        shell.id = id.to_string();
        shell.label = label.to_string();
        push_shell(&mut shells, shell);
    }
    shells
}
341
342#[cfg(not(windows))]
343fn platform_shells() -> Vec<ShellDescriptor> {
344    let mut shells = Vec::new();
345    if let Ok(value) = std::env::var("HARN_DEFAULT_SHELL") {
346        push_shell(&mut shells, descriptor_for_path(&value, "configured"));
347    }
348    if let Ok(value) = std::env::var("SHELL") {
349        push_shell(&mut shells, descriptor_for_path(&value, "env"));
350    }
351    if let Some(value) = login_shell_from_passwd() {
352        push_shell(&mut shells, descriptor_for_path(&value, "login"));
353    }
354    for value in shells_from_etc_shells() {
355        push_shell(&mut shells, descriptor_for_path(&value, "etc_shells"));
356    }
357    for value in [
358        "/bin/zsh",
359        "/bin/bash",
360        "/bin/sh",
361        "/usr/bin/zsh",
362        "/usr/bin/bash",
363        "/usr/bin/sh",
364    ] {
365        push_shell(&mut shells, descriptor_for_path(value, "fallback"));
366    }
367    shells
368}
369
370fn push_shell(shells: &mut Vec<ShellDescriptor>, shell: ShellDescriptor) {
371    if shells.iter().any(|existing| existing.id == shell.id) {
372        return;
373    }
374    shells.push(shell);
375}
376
377fn descriptor_for_path(path: &str, source: &str) -> ShellDescriptor {
378    let id = shell_id_from_path(path);
379    ShellDescriptor {
380        id: id.clone(),
381        label: label_for_id(&id),
382        path: path.to_string(),
383        platform: platform_name().to_string(),
384        available: executable_available(path),
385        supports_login: supports_login_for_id(&id),
386        supports_interactive: supports_interactive_for_id(&id),
387        default_args: default_args_for_id(&id),
388        login_args: login_args_for_id(&id),
389        source: source.to_string(),
390    }
391}
392
/// Derives a normalized shell id from an executable path.
///
/// Steps:
/// - take the file name (or the whole string when there is none) and lowercase it;
/// - drop a trailing `.exe`;
/// - map `windowspowershell` to `powershell`.
///
/// An empty result becomes `"shell"`.
fn shell_id_from_path(path: &str) -> String {
    let base = match Path::new(path).file_name().and_then(|name| name.to_str()) {
        Some(name) => name.to_ascii_lowercase(),
        None => path.to_ascii_lowercase(),
    };
    let stem = base.strip_suffix(".exe").unwrap_or(base.as_str());
    if stem.is_empty() {
        return "shell".to_string();
    }
    if stem == "windowspowershell" {
        return "powershell".to_string();
    }
    stem.to_string()
}
411
/// Human-readable label for a shell id.
///
/// Only the PowerShell variants get a distinct name; every other id is its own label.
fn label_for_id(id: &str) -> String {
    let label = match id {
        "pwsh" => "PowerShell 7",
        "powershell" => "Windows PowerShell",
        other => other,
    };
    label.to_owned()
}
425
/// Arguments placed before the command string so the shell runs that single command.
fn default_args_for_id(id: &str) -> Vec<String> {
    let args: &[&str] = match id {
        "cmd" => &["/C"],
        "pwsh" | "powershell" => &["-NoProfile", "-Command"],
        _ => &["-c"],
    };
    args.iter().map(|arg| arg.to_string()).collect()
}
433
434fn login_args_for_id(id: &str) -> Vec<String> {
435    match id {
436        "cmd" | "pwsh" | "powershell" => default_args_for_id(id),
437        _ => vec!["-l".to_string(), "-c".to_string()],
438    }
439}
440
/// Whether a login invocation (the `-l -c` form from `login_args_for_id`) works for
/// this shell.
///
/// - cmd and PowerShell have no `-l` flag.
/// - csh/tcsh honour `-l` only when it is the sole flag, so `-l -c <command>` fails.
///
/// All of these are reported as not supporting login, so callers fall back to the
/// default arguments.
fn supports_login_for_id(id: &str) -> bool {
    !matches!(id, "cmd" | "pwsh" | "powershell" | "csh" | "tcsh")
}
444
/// Whether `-i` may be prepended to request an interactive shell.
///
/// False for cmd and the PowerShell variants, true for every other id.
fn supports_interactive_for_id(id: &str) -> bool {
    const NO_INTERACTIVE_FLAG: [&str; 3] = ["cmd", "pwsh", "powershell"];
    !NO_INTERACTIVE_FLAG.contains(&id)
}
448
/// Platform tag recorded on descriptors.
///
/// This is Rust's `std::env::consts::OS`, except that macOS is reported as `"darwin"`.
fn platform_name() -> &'static str {
    match std::env::consts::OS {
        "macos" => "darwin",
        os => os,
    }
}
460
/// Whether `path` names a program that can actually be run.
///
/// Values with a directory component (or absolute paths) are checked directly; bare
/// program names are searched for on `PATH` via [`find_on_path`].
fn executable_available(path: &str) -> bool {
    let path_obj = Path::new(path);
    if path_obj.components().count() > 1 || path_obj.is_absolute() {
        return is_executable_file(path_obj);
    }
    find_on_path(path).is_some()
}

/// Searches the `PATH` directories in order and returns the first runnable match,
/// rendered with `Path::display`.
///
/// Each directory is tried with every name from [`path_candidates`] for `program`.
fn find_on_path(program: &str) -> Option<String> {
    let path = std::env::var_os("PATH")?;
    let candidates = path_candidates(program);
    for dir in std::env::split_paths(&path) {
        for candidate in &candidates {
            let full = dir.join(candidate);
            if is_executable_file(&full) {
                return Some(full.display().to_string());
            }
        }
    }
    None
}

/// Whether `path` is a regular file with at least one execute permission bit set.
///
/// POSIX command search only considers executable files. A plain `is_file` check
/// would report shells as available that cannot be spawned, and could pick a
/// non-executable file from an earlier `PATH` entry. `metadata` follows symlinks,
/// as `is_file` does.
#[cfg(unix)]
fn is_executable_file(path: &Path) -> bool {
    use std::os::unix::fs::PermissionsExt;
    std::fs::metadata(path)
        .map(|meta| meta.is_file() && (meta.permissions().mode() & 0o111) != 0)
        .unwrap_or(false)
}

/// Whether `path` is a regular file (there are no Unix execute bits to inspect here).
#[cfg(not(unix))]
fn is_executable_file(path: &Path) -> bool {
    path.is_file()
}

/// File names to try in each `PATH` directory.
///
/// This is the name itself, plus `.exe`, `.cmd`, and `.bat` variants when the name
/// has no extension.
#[cfg(windows)]
fn path_candidates(program: &str) -> Vec<PathBuf> {
    let mut candidates = vec![PathBuf::from(program)];
    if Path::new(program).extension().is_none() {
        for ext in [".exe", ".cmd", ".bat"] {
            candidates.push(PathBuf::from(format!("{program}{ext}")));
        }
    }
    candidates
}

/// File names to try in each `PATH` directory: just the name itself.
#[cfg(not(windows))]
fn path_candidates(program: &str) -> Vec<PathBuf> {
    vec![PathBuf::from(program)]
}
498
/// Looks up the current user's login shell in `/etc/passwd`.
///
/// The user name comes from `USER`, falling back to `LOGNAME`. Lines are scanned in
/// order, and the first line for that user with a usable shell field (the 7th) wins.
///
/// Returns `None` when:
/// - neither variable is set, or the file cannot be read;
/// - no line for the user has a usable shell — one that is non-empty and does not
///   end in `/false` or `/nologin`.
#[cfg(not(windows))]
fn login_shell_from_passwd() -> Option<String> {
    let username = match std::env::var("USER") {
        Ok(name) => name,
        Err(_) => std::env::var("LOGNAME").ok()?,
    };
    let passwd = std::fs::read_to_string("/etc/passwd").ok()?;
    for entry in passwd.lines() {
        let mut fields = entry.split(':');
        if fields.next() != Some(username.as_str()) {
            continue;
        }
        // Fields after the name: passwd, uid, gid, gecos, home, shell.
        let Some(shell) = fields.nth(5).map(str::trim) else {
            continue;
        };
        let placeholder =
            shell.is_empty() || shell.ends_with("/false") || shell.ends_with("/nologin");
        if !placeholder {
            return Some(shell.to_string());
        }
    }
    None
}
520
/// Reads `/etc/shells` and returns each absolute-path entry once, in file order.
///
/// Lines are trimmed; blank lines and `#` comments are skipped. A missing or
/// unreadable file yields an empty list.
#[cfg(not(windows))]
fn shells_from_etc_shells() -> Vec<String> {
    let content = match std::fs::read_to_string("/etc/shells") {
        Ok(content) => content,
        Err(_) => return Vec::new(),
    };
    let mut seen = BTreeSet::new();
    let mut entries = Vec::new();
    for line in content.lines().map(str::trim) {
        // Anything starting with '/' is necessarily non-empty and not a comment.
        if line.starts_with('/') && seen.insert(line) {
            entries.push(line.to_string());
        }
    }
    entries
}
535
536fn optional_bool(params: &BTreeMap<String, VmValue>, key: &str) -> Option<bool> {
537    match params.get(key) {
538        Some(VmValue::Bool(value)) => Some(*value),
539        _ => None,
540    }
541}
542
543fn vm_string(value: &VmValue) -> Option<&str> {
544    match value {
545        VmValue::String(value) => Some(value.as_ref()),
546        _ => None,
547    }
548}
549
550fn vm_string_list(value: &VmValue) -> Option<Vec<String>> {
551    let VmValue::List(values) = value else {
552        return None;
553    };
554    values
555        .iter()
556        .map(|value| vm_string(value).map(ToString::to_string))
557        .collect()
558}
559
560fn string(value: &str) -> VmValue {
561    VmValue::String(Rc::from(value.to_string()))
562}
563
564fn string_list(values: &[String]) -> VmValue {
565    VmValue::List(Rc::new(values.iter().map(|value| string(value)).collect()))
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully capable zsh descriptor used as the starting point for invocation tests.
    fn zsh_fixture() -> ShellDescriptor {
        ShellDescriptor {
            id: "zsh".to_string(),
            label: "zsh".to_string(),
            path: "/bin/zsh".to_string(),
            platform: "darwin".to_string(),
            available: true,
            supports_login: true,
            supports_interactive: true,
            default_args: vec!["-c".to_string()],
            login_args: vec!["-l".to_string(), "-c".to_string()],
            source: "test".to_string(),
        }
    }

    #[test]
    fn unix_shell_descriptor_uses_split_login_args() {
        let shell = descriptor_for_path("/bin/zsh", "fallback");
        assert_eq!(shell.id, "zsh");
        assert_eq!(shell.default_args, vec!["-c"]);
        assert_eq!(shell.login_args, vec!["-l", "-c"]);
        assert!(shell.supports_login);
        assert!(shell.supports_interactive);
    }

    #[test]
    fn windows_shell_descriptor_distinguishes_cmd_and_pwsh() {
        let cmd = descriptor_for_path("cmd.exe", "fallback");
        assert_eq!(cmd.id, "cmd");
        assert_eq!(cmd.default_args, vec!["/C"]);
        assert!(!cmd.supports_login);

        let pwsh = descriptor_for_path("pwsh.exe", "fallback");
        assert_eq!(pwsh.id, "pwsh");
        assert_eq!(pwsh.default_args, vec!["-NoProfile", "-Command"]);
    }

    #[test]
    fn shell_ids_are_normalized_from_paths() {
        assert_eq!(shell_id_from_path("/usr/local/bin/Fish"), "fish");
        assert_eq!(shell_id_from_path("WindowsPowerShell.exe"), "powershell");
        assert_eq!(shell_id_from_path(""), "shell");
    }

    #[test]
    fn invocation_appends_command_after_shell_args() {
        let invocation = invocation_for_shell(zsh_fixture(), "echo ok".to_string(), true, true);
        assert_eq!(invocation.program, "/bin/zsh");
        assert_eq!(invocation.args, vec!["-i", "-l", "-c", "echo ok"]);
        assert_eq!(invocation.command_arg_index, 3);
    }

    #[test]
    fn invocation_without_flags_uses_default_args() {
        let invocation =
            invocation_for_shell(zsh_fixture(), "echo ok".to_string(), false, false);
        assert_eq!(invocation.args, vec!["-c", "echo ok"]);
        assert_eq!(invocation.command_arg_index, 1);
        assert_eq!(invocation.args[invocation.command_arg_index], "echo ok");
    }

    #[test]
    fn login_and_interactive_are_ignored_when_unsupported() {
        let mut shell = zsh_fixture();
        shell.supports_login = false;
        shell.supports_interactive = false;
        let invocation = invocation_for_shell(shell, "echo ok".to_string(), true, true);
        assert_eq!(invocation.args, vec!["-c", "echo ok"]);
        assert_eq!(invocation.command_arg_index, 1);
    }

    #[test]
    fn interactive_flag_is_not_duplicated() {
        let mut shell = zsh_fixture();
        shell.default_args = vec!["-i".to_string(), "-c".to_string()];
        let invocation = invocation_for_shell(shell, "true".to_string(), false, true);
        assert_eq!(invocation.args, vec!["-i", "-c", "true"]);
        assert_eq!(invocation.command_arg_index, 2);
    }

    #[test]
    fn unknown_default_shell_is_rejected() {
        clear_selected_default_shell_for_test();
        let err = set_default_shell("__no_such_shell__").unwrap_err();
        assert!(err.contains("__no_such_shell__"));
    }
}