Skip to main content

pty_mcp/
config.rs

1use std::{env, path::PathBuf};
2
3use anyhow::{Result, bail};
4
/// The concrete binaries an [`SshConfig`] resolved to, as produced by
/// [`SshConfig::resolved_bin_paths`].
///
/// Each field is `None` when no usable binary could be located.
#[derive(Clone, Debug, Default)]
pub struct SshResolvedBinPaths {
    /// Resolved `ssh` client.
    pub ssh: Option<PathBuf>,
    /// Resolved `sshfs` binary.
    pub sshfs: Option<PathBuf>,
    /// Resolved `umount` binary.
    pub umount: Option<PathBuf>,
    /// Resolved `diskutil` binary (only has default candidates on macOS).
    pub diskutil: Option<PathBuf>,
}
12
/// SSH-related policy and binary locations.
///
/// Defaults come from [`SshConfig::default`]; `Config::from_env` overrides them from the
/// `PTY_MCP_SSH_*` and `PTY_MCP_*_BIN_PATH` environment variables.
#[derive(Clone, Debug)]
pub struct SshConfig {
    /// `ssh` binary override (`PTY_MCP_SSH_BIN_PATH`).
    pub ssh_bin_path: Option<PathBuf>,
    /// `sshfs` binary override (`PTY_MCP_SSHFS_BIN_PATH`).
    pub sshfs_bin_path: Option<PathBuf>,
    /// `umount` binary override (`PTY_MCP_UMOUNT_BIN_PATH`).
    pub umount_bin_path: Option<PathBuf>,
    /// `diskutil` binary override (`PTY_MCP_DISKUTIL_BIN_PATH`).
    pub diskutil_bin_path: Option<PathBuf>,
    /// `PTY_MCP_SSH_MACOS_BLOCK_APPLE_METADATA`; defaults to `true` on macOS and `false`
    /// elsewhere.
    pub macos_block_apple_metadata: bool,
    /// `PTY_MCP_SSH_MANAGED_MOUNT_ROOT`; when set, `Config::from_env` also appends it to the
    /// allowed cwd roots.
    pub managed_mount_root: Option<PathBuf>,
    /// `PTY_MCP_SSH_ALLOWED_HOSTS` (comma-separated); empty by default.
    pub allowed_hosts: Vec<String>,
    /// `PTY_MCP_SSH_DENIED_HOSTS` (comma-separated); empty by default.
    pub denied_hosts: Vec<String>,
    /// `PTY_MCP_SSH_ALLOWED_USERS` (comma-separated); empty by default.
    pub allowed_users: Vec<String>,
    /// `PTY_MCP_SSH_ALLOWED_AUTH_KINDS`; entries are lower-cased, deduplicated members of
    /// `config_alias`, `ssh_agent`, `identity_file`. Empty by default.
    pub allowed_auth_kinds: Vec<String>,
    /// `PTY_MCP_SSH_ALLOWED_TUNNEL_BIND_HOSTS`; defaults to `127.0.0.1`, `::1`, `localhost`.
    pub allowed_tunnel_bind_hosts: Vec<String>,
    /// `PTY_MCP_SSH_ALLOW_EXPLICIT_MOUNT_PATHS`; defaults to `true`.
    pub allow_explicit_mount_paths: bool,
    /// `PTY_MCP_SSH_ALLOWED_MOUNT_ROOTS` (colon-separated); `Config::from_env` falls back to
    /// the allowed cwd roots when this is empty.
    pub allowed_mount_roots: Vec<PathBuf>,
    /// Lower end of the SSH port range (`PTY_MCP_SSH_PORT_MIN`, default 1).
    pub port_min: u16,
    /// Upper end of the SSH port range (`PTY_MCP_SSH_PORT_MAX`, default `u16::MAX`);
    /// `Config::from_env` rejects `port_min > port_max`.
    pub port_max: u16,
}
31
32impl Default for SshConfig {
33    fn default() -> Self {
34        Self {
35            ssh_bin_path: resolve_bin_path(None, "ssh", ssh_default_paths()),
36            sshfs_bin_path: resolve_bin_path(None, "sshfs", sshfs_default_paths()),
37            umount_bin_path: resolve_bin_path(None, "umount", umount_default_paths()),
38            diskutil_bin_path: resolve_bin_path(None, "diskutil", diskutil_default_paths()),
39            macos_block_apple_metadata: default_macos_block_apple_metadata(),
40            managed_mount_root: None,
41            allowed_hosts: Vec::new(),
42            denied_hosts: Vec::new(),
43            allowed_users: Vec::new(),
44            allowed_auth_kinds: Vec::new(),
45            allowed_tunnel_bind_hosts: vec![
46                "127.0.0.1".to_string(),
47                "::1".to_string(),
48                "localhost".to_string(),
49            ],
50            allow_explicit_mount_paths: true,
51            allowed_mount_roots: Vec::new(),
52            port_min: 1,
53            port_max: u16::MAX,
54        }
55    }
56}
57
58impl SshConfig {
59    pub fn resolved_ssh_bin_path(&self) -> Option<PathBuf> {
60        resolve_bin_path(self.ssh_bin_path.clone(), "ssh", ssh_default_paths())
61    }
62
63    pub fn resolved_sshfs_bin_path(&self) -> Option<PathBuf> {
64        resolve_bin_path(self.sshfs_bin_path.clone(), "sshfs", sshfs_default_paths())
65    }
66
67    pub fn resolved_umount_bin_path(&self) -> Option<PathBuf> {
68        resolve_bin_path(
69            self.umount_bin_path.clone(),
70            "umount",
71            umount_default_paths(),
72        )
73    }
74
75    pub fn resolved_diskutil_bin_path(&self) -> Option<PathBuf> {
76        resolve_bin_path(
77            self.diskutil_bin_path.clone(),
78            "diskutil",
79            diskutil_default_paths(),
80        )
81    }
82
83    pub fn resolved_bin_paths(&self) -> SshResolvedBinPaths {
84        SshResolvedBinPaths {
85            ssh: self.resolved_ssh_bin_path(),
86            sshfs: self.resolved_sshfs_bin_path(),
87            umount: self.resolved_umount_bin_path(),
88            diskutil: self.resolved_diskutil_bin_path(),
89        }
90    }
91}
92
93#[derive(Debug, Clone)]
94pub struct Config {
95    pub session_limit: usize,
96    pub default_read_limit: usize,
97    pub max_buffer_lines: usize,
98    pub allowed_cwd_roots: Vec<PathBuf>,
99    pub allowed_commands: Vec<String>,
100    pub denied_commands: Vec<String>,
101    pub allowed_env_vars: Vec<String>,
102    pub denied_env_vars: Vec<String>,
103    pub ssh: SshConfig,
104}
105
106impl Default for Config {
107    fn default() -> Self {
108        Self {
109            session_limit: 32,
110            default_read_limit: 200,
111            max_buffer_lines: 50_000,
112            allowed_cwd_roots: vec![env::current_dir().unwrap_or_else(|_| PathBuf::from("."))],
113            allowed_commands: Vec::new(),
114            denied_commands: Vec::new(),
115            allowed_env_vars: Vec::new(),
116            denied_env_vars: vec![
117                "LD_PRELOAD".to_string(),
118                "LD_LIBRARY_PATH".to_string(),
119                "DYLD_INSERT_LIBRARIES".to_string(),
120                "DYLD_LIBRARY_PATH".to_string(),
121            ],
122            ssh: SshConfig::default(),
123        }
124    }
125}
126
127impl Config {
128    pub fn from_env() -> Result<Self> {
129        let mut config = Self::default();
130
131        if let Ok(value) = env::var("PTY_MCP_SESSION_LIMIT") {
132            config.session_limit = parse_usize("PTY_MCP_SESSION_LIMIT", &value)?;
133        }
134
135        if let Ok(value) = env::var("PTY_MCP_DEFAULT_READ_LIMIT") {
136            config.default_read_limit = parse_usize("PTY_MCP_DEFAULT_READ_LIMIT", &value)?;
137        }
138
139        if let Ok(value) = env::var("PTY_MCP_MAX_BUFFER_LINES") {
140            config.max_buffer_lines = parse_usize("PTY_MCP_MAX_BUFFER_LINES", &value)?;
141        }
142
143        if let Ok(value) = env::var("PTY_MCP_ALLOWED_CWD_ROOTS") {
144            config.allowed_cwd_roots = value
145                .split(':')
146                .filter(|segment| !segment.trim().is_empty())
147                .map(PathBuf::from)
148                .collect();
149        }
150
151        if let Ok(value) = env::var("PTY_MCP_ALLOWED_COMMANDS") {
152            config.allowed_commands = parse_csv(&value);
153        }
154
155        if let Ok(value) = env::var("PTY_MCP_DENIED_COMMANDS") {
156            config.denied_commands = parse_csv(&value);
157        }
158
159        if let Ok(value) = env::var("PTY_MCP_ALLOWED_ENV_VARS") {
160            config.allowed_env_vars = parse_csv(&value);
161        }
162
163        if let Ok(value) = env::var("PTY_MCP_DENIED_ENV_VARS") {
164            config.denied_env_vars = parse_csv(&value);
165        }
166
167        if let Ok(value) = env::var("PTY_MCP_SSH_BIN_PATH") {
168            config.ssh.ssh_bin_path =
169                resolve_bin_path(Some(PathBuf::from(value)), "ssh", ssh_default_paths());
170        }
171
172        if let Ok(value) = env::var("PTY_MCP_SSHFS_BIN_PATH") {
173            config.ssh.sshfs_bin_path =
174                resolve_bin_path(Some(PathBuf::from(value)), "sshfs", sshfs_default_paths());
175        }
176
177        if let Ok(value) = env::var("PTY_MCP_UMOUNT_BIN_PATH") {
178            config.ssh.umount_bin_path =
179                resolve_bin_path(Some(PathBuf::from(value)), "umount", umount_default_paths());
180        }
181
182        if let Ok(value) = env::var("PTY_MCP_DISKUTIL_BIN_PATH") {
183            config.ssh.diskutil_bin_path = resolve_bin_path(
184                Some(PathBuf::from(value)),
185                "diskutil",
186                diskutil_default_paths(),
187            );
188        }
189
190        if let Ok(value) = env::var("PTY_MCP_SSH_MACOS_BLOCK_APPLE_METADATA") {
191            config.ssh.macos_block_apple_metadata =
192                parse_bool("PTY_MCP_SSH_MACOS_BLOCK_APPLE_METADATA", &value)?;
193        }
194
195        if let Ok(value) = env::var("PTY_MCP_SSH_MANAGED_MOUNT_ROOT") {
196            let trimmed = value.trim();
197            if !trimmed.is_empty() {
198                let managed_mount_root = PathBuf::from(trimmed);
199                if !config.allowed_cwd_roots.contains(&managed_mount_root) {
200                    config.allowed_cwd_roots.push(managed_mount_root.clone());
201                }
202                config.ssh.managed_mount_root = Some(managed_mount_root);
203            }
204        }
205
206        if let Ok(value) = env::var("PTY_MCP_SSH_ALLOWED_HOSTS") {
207            config.ssh.allowed_hosts = parse_csv(&value);
208        }
209
210        if let Ok(value) = env::var("PTY_MCP_SSH_DENIED_HOSTS") {
211            config.ssh.denied_hosts = parse_csv(&value);
212        }
213
214        if let Ok(value) = env::var("PTY_MCP_SSH_ALLOWED_USERS") {
215            config.ssh.allowed_users = parse_csv(&value);
216        }
217
218        if let Ok(value) = env::var("PTY_MCP_SSH_ALLOWED_AUTH_KINDS") {
219            config.ssh.allowed_auth_kinds =
220                parse_auth_kinds("PTY_MCP_SSH_ALLOWED_AUTH_KINDS", &value)?;
221        }
222
223        if let Ok(value) = env::var("PTY_MCP_SSH_ALLOWED_TUNNEL_BIND_HOSTS") {
224            config.ssh.allowed_tunnel_bind_hosts = parse_csv(&value);
225        }
226
227        if let Ok(value) = env::var("PTY_MCP_SSH_ALLOW_EXPLICIT_MOUNT_PATHS") {
228            config.ssh.allow_explicit_mount_paths =
229                parse_bool("PTY_MCP_SSH_ALLOW_EXPLICIT_MOUNT_PATHS", &value)?;
230        }
231
232        if let Ok(value) = env::var("PTY_MCP_SSH_ALLOWED_MOUNT_ROOTS") {
233            config.ssh.allowed_mount_roots = parse_path_list(&value);
234        }
235
236        if let Ok(value) = env::var("PTY_MCP_SSH_PORT_MIN") {
237            config.ssh.port_min = parse_u16("PTY_MCP_SSH_PORT_MIN", &value)?;
238        }
239
240        if let Ok(value) = env::var("PTY_MCP_SSH_PORT_MAX") {
241            config.ssh.port_max = parse_u16("PTY_MCP_SSH_PORT_MAX", &value)?;
242        }
243
244        if config.ssh.port_min > config.ssh.port_max {
245            bail!(
246                "invalid SSH port range: PTY_MCP_SSH_PORT_MIN={} is greater than PTY_MCP_SSH_PORT_MAX={}",
247                config.ssh.port_min,
248                config.ssh.port_max
249            );
250        }
251
252        if config.ssh.allowed_mount_roots.is_empty() {
253            config.ssh.allowed_mount_roots = config.allowed_cwd_roots.clone();
254        }
255
256        Ok(config)
257    }
258}
259
/// Splits a comma-separated list and trims each entry; blank entries are dropped.
fn parse_csv(value: &str) -> Vec<String> {
    let mut entries = Vec::new();
    for raw in value.split(',') {
        let entry = raw.trim();
        if entry.is_empty() {
            continue;
        }
        entries.push(entry.to_owned());
    }
    entries
}
268
/// Splits a colon-separated path list and trims each entry; blank entries are dropped.
fn parse_path_list(value: &str) -> Vec<PathBuf> {
    value
        .split(':')
        .filter_map(|raw| {
            let entry = raw.trim();
            (!entry.is_empty()).then(|| PathBuf::from(entry))
        })
        .collect()
}
277
278fn parse_auth_kinds(key: &'static str, value: &str) -> Result<Vec<String>> {
279    let mut parsed = Vec::new();
280    for segment in value
281        .split(',')
282        .map(str::trim)
283        .filter(|segment| !segment.is_empty())
284    {
285        let normalized = normalize_auth_kind(segment)
286            .ok_or_else(|| anyhow::anyhow!("invalid ssh auth kind for {key}: value={segment}"))?;
287        if !parsed.contains(&normalized) {
288            parsed.push(normalized);
289        }
290    }
291    Ok(parsed)
292}
293
/// Maps `value` (case-insensitively, ignoring surrounding whitespace) to its canonical
/// lower-case auth-kind name. Returns `None` if it is not one of the supported kinds.
fn normalize_auth_kind(value: &str) -> Option<String> {
    const SUPPORTED_KINDS: [&str; 3] = ["config_alias", "ssh_agent", "identity_file"];
    let canonical = value.trim().to_ascii_lowercase();
    SUPPORTED_KINDS
        .contains(&canonical.as_str())
        .then_some(canonical)
}
302
303fn parse_usize(key: &'static str, value: &str) -> Result<usize> {
304    value
305        .parse::<usize>()
306        .map_err(|source| anyhow::anyhow!("invalid usize for {key}: value={value}: {source}"))
307}
308
309fn parse_u16(key: &'static str, value: &str) -> Result<u16> {
310    value
311        .parse::<u16>()
312        .map_err(|source| anyhow::anyhow!("invalid u16 for {key}: value={value}: {source}"))
313}
314
315fn parse_bool(key: &'static str, value: &str) -> Result<bool> {
316    match value.trim().to_ascii_lowercase().as_str() {
317        "1" | "true" | "yes" | "on" => Ok(true),
318        "0" | "false" | "no" | "off" => Ok(false),
319        _ => bail!("invalid bool for {key}: value={value}"),
320    }
321}
322
/// Default for `SshConfig::macos_block_apple_metadata`: on when compiled for macOS, off
/// everywhere else.
const fn default_macos_block_apple_metadata() -> bool {
    cfg!(target_os = "macos")
}
332
/// Picks the binary to use for `command_name`.
///
/// Resolution order:
/// 1. `explicit`, when given and not blank. Surrounding whitespace is trimmed when the path is
///    valid UTF-8; any other path is returned unchanged.
/// 2. The first entry of `default_candidates` that exists as a regular file.
/// 3. `find_in_path(command_name)`.
///
/// Returns `None` when nothing matches. Only `is_file` is checked, not the executable bit.
fn resolve_bin_path(
    explicit: Option<PathBuf>,
    command_name: &str,
    default_candidates: &'static [&'static str],
) -> Option<PathBuf> {
    if let Some(path) = explicit {
        match path.to_str() {
            // Trim stray whitespace (e.g. from env vars); a blank value counts as "unset".
            Some(text) => {
                let trimmed = text.trim();
                if !trimmed.is_empty() {
                    return Some(PathBuf::from(trimmed));
                }
            }
            // Not UTF-8. `to_string_lossy` would replace the bad bytes with U+FFFD and name a
            // different file, so keep the path byte-for-byte. (An empty path is valid UTF-8,
            // so this one cannot be blank.)
            None => return Some(path),
        }
    }

    default_candidates
        .iter()
        .map(PathBuf::from)
        .find(|candidate| candidate.is_file())
        .or_else(|| find_in_path(command_name))
}

/// Returns the first `<dir>/<command_name>` that is a regular file, scanning `PATH` in order.
/// Returns `None` if `PATH` is unset or nothing matches.
///
/// Relative `PATH` entries, including the empty entry, are skipped. They resolve against the
/// current working directory, so the result would depend on where the process happens to run
/// and could pick up a binary planted in that directory.
fn find_in_path(command_name: &str) -> Option<PathBuf> {
    let path_var = env::var_os("PATH")?;
    env::split_paths(&path_var)
        .filter(|dir| dir.is_absolute())
        .map(|dir| dir.join(command_name))
        .find(|candidate| candidate.is_file())
}
365
/// Well-known `ssh` locations, checked in order before falling back to `PATH`.
fn ssh_default_paths() -> &'static [&'static str] {
    const MACOS: &[&str] = &[
        "/usr/bin/ssh",
        "/opt/homebrew/bin/ssh",
        "/usr/local/bin/ssh",
    ];
    const OTHER: &[&str] = &["/usr/bin/ssh", "/usr/local/bin/ssh"];

    if cfg!(target_os = "macos") { MACOS } else { OTHER }
}
379
/// Well-known `sshfs` locations. On macOS, Homebrew prefixes come first.
fn sshfs_default_paths() -> &'static [&'static str] {
    const MACOS: &[&str] = &[
        "/opt/homebrew/bin/sshfs",
        "/usr/local/bin/sshfs",
        "/usr/bin/sshfs",
    ];
    const OTHER: &[&str] = &["/usr/bin/sshfs", "/usr/local/bin/sshfs"];

    if cfg!(target_os = "macos") { MACOS } else { OTHER }
}
393
/// Well-known `umount` locations for the target platform.
fn umount_default_paths() -> &'static [&'static str] {
    const MACOS: &[&str] = &["/sbin/umount", "/usr/sbin/umount", "/usr/bin/umount"];
    const OTHER: &[&str] = &["/usr/bin/umount", "/bin/umount", "/usr/local/bin/umount"];

    if cfg!(target_os = "macos") { MACOS } else { OTHER }
}
403
/// Well-known `diskutil` locations. Outside macOS there are none, so resolution relies on
/// an explicit path or `PATH`.
fn diskutil_default_paths() -> &'static [&'static str] {
    const MACOS: &[&str] = &["/usr/sbin/diskutil", "/usr/bin/diskutil"];
    const OTHER: &[&str] = &[];

    if cfg!(target_os = "macos") { MACOS } else { OTHER }
}
413
#[cfg(test)]
mod tests {
    use std::ffi::OsString;
    use std::sync::{Mutex, OnceLock};

    use super::{
        Config, default_macos_block_apple_metadata, parse_auth_kinds, parse_bool, parse_usize,
    };

    /// Lock taken by tests in this module before they touch environment variables, so they do
    /// not observe each other's changes.
    fn env_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Overrides (or removes) one environment variable and restores its previous state on drop.
    ///
    /// Declare the guard *after* taking `env_lock()`. It is then dropped, and the variable
    /// restored, while the lock is still held.
    struct EnvGuard {
        key: &'static str,
        /// Previous value, kept as an `OsString` so a non-Unicode original is restored verbatim
        /// instead of being lost (and the variable removed) on drop.
        original: Option<OsString>,
    }

    impl EnvGuard {
        fn set(key: &'static str, value: Option<&str>) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: `set_var`/`remove_var` are only sound while no other thread reads or
            // writes the environment. Every test in this module that mutates the environment
            // holds `env_lock()` while doing so.
            // NOTE(review): tests elsewhere in the crate that read the environment without this
            // lock could still race. Confirm none exist.
            match value {
                Some(value) => unsafe { std::env::set_var(key, value) },
                None => unsafe { std::env::remove_var(key) },
            }
            Self { key, original }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // SAFETY: same invariant as `EnvGuard::set`. Each guard below is declared after the
            // `env_lock()` guard, so it is dropped first, while the lock is still held.
            match self.original.take() {
                Some(value) => unsafe { std::env::set_var(self.key, value) },
                None => unsafe { std::env::remove_var(self.key) },
            }
        }
    }

    #[test]
    fn parse_usize_error_contains_key_and_value() {
        let error = parse_usize("PTY_MCP_SESSION_LIMIT", "abc").expect_err("parse should fail");
        let text = format!("{error:#}");
        assert!(text.contains("PTY_MCP_SESSION_LIMIT"));
        assert!(text.contains("abc"));
    }

    #[test]
    fn parse_bool_error_contains_key_and_value() {
        let error = parse_bool("PTY_MCP_SSH_ALLOW_EXPLICIT_MOUNT_PATHS", "maybe")
            .expect_err("parse should fail");
        let text = format!("{error:#}");
        assert!(text.contains("PTY_MCP_SSH_ALLOW_EXPLICIT_MOUNT_PATHS"));
        assert!(text.contains("maybe"));
    }

    #[test]
    fn parse_auth_kind_error_contains_key_and_value() {
        let error = parse_auth_kinds("PTY_MCP_SSH_ALLOWED_AUTH_KINDS", "magic")
            .expect_err("parse should fail");
        let text = format!("{error:#}");
        assert!(text.contains("PTY_MCP_SSH_ALLOWED_AUTH_KINDS"));
        assert!(text.contains("magic"));
    }

    #[test]
    fn config_from_env_uses_platform_default_for_macos_metadata_blocking() {
        let _lock = env_lock().lock().expect("env lock poisoned");
        let _guard = EnvGuard::set("PTY_MCP_SSH_MACOS_BLOCK_APPLE_METADATA", None);

        let config = Config::from_env().expect("config should load");

        assert_eq!(
            config.ssh.macos_block_apple_metadata,
            default_macos_block_apple_metadata()
        );
    }

    #[test]
    fn config_from_env_allows_overriding_macos_metadata_blocking() {
        let _lock = env_lock().lock().expect("env lock poisoned");
        let _guard = EnvGuard::set("PTY_MCP_SSH_MACOS_BLOCK_APPLE_METADATA", Some("false"));

        let config = Config::from_env().expect("config should load");

        assert!(!config.ssh.macos_block_apple_metadata);
    }
}