pty-mcp 0.3.0

An MCP server for PTY management with SSH connections, remote sessions, file access, and mounts
Documentation
use std::{
    collections::BTreeSet,
    fs,
    path::{Component, Path, PathBuf},
};

use anyhow::{Result, bail, ensure};

use crate::{Config, ssh::SshAuthKind};

/// Port assumed by [`SshPolicy::validate_target`] when the caller does not supply one.
const DEFAULT_SSH_PORT: u16 = 22;

/// Access policy applied to SSH targets, authentication methods, tunnels and local mounts.
///
/// Built with [`SshPolicy::from_config`]. Every string set is stored trimmed and
/// ASCII-lowercased with empty entries removed. Every path is stored absolute and lexically
/// normalized.
#[derive(Debug, Clone)]
pub struct SshPolicy {
    // Host patterns (`*`, `*.suffix`, or exact). When non-empty, a target host must match one.
    allowed_hosts: BTreeSet<String>,
    // Host patterns checked before the allowlist; any match rejects the target.
    denied_hosts: BTreeSet<String>,
    // When non-empty, a user must be supplied and must be one of these names.
    allowed_users: BTreeSet<String>,
    // Permitted auth kind names (`config_alias`, `ssh_agent`, `identity_file`); empty = all.
    allowed_auth_kinds: BTreeSet<String>,
    // Non-loopback tunnel bind hosts that are permitted (loopback is always allowed).
    allowed_tunnel_bind_hosts: BTreeSet<String>,
    // Whether local mount paths outside `managed_mount_root` may be used at all.
    allow_explicit_mount_paths: bool,
    // Paths at or below this root are always accepted as mount points.
    managed_mount_root: Option<PathBuf>,
    // Roots that explicit mount paths must live under (also contains the managed root).
    allowed_local_mount_roots: Vec<PathBuf>,
    // Inclusive lower bound for the SSH port.
    port_min: u16,
    // Inclusive upper bound for the SSH port.
    port_max: u16,
}

impl SshPolicy {
    /// Builds a policy from the application [`Config`].
    ///
    /// Host, user, auth-kind and tunnel-bind-host lists are trimmed, ASCII-lowercased and
    /// stripped of empty entries. Mount roots are made absolute and lexically normalized;
    /// relative entries resolve against the current directory.
    ///
    /// When `ssh.allowed_mount_roots` is empty, the top-level `allowed_cwd_roots` are used
    /// instead. The managed mount root, if configured, is always added to the allowed roots.
    pub fn from_config(config: &Config) -> Self {
        let managed_mount_root = config
            .ssh
            .managed_mount_root
            .as_ref()
            .map(|path| normalize_path(path.as_path()));
        let mut allowed_local_mount_roots = config
            .ssh
            .allowed_mount_roots
            .iter()
            .map(|path| normalize_path(path.as_path()))
            .collect::<Vec<_>>();
        // No dedicated mount roots configured: fall back to the general working-directory roots.
        if allowed_local_mount_roots.is_empty() {
            allowed_local_mount_roots = config
                .allowed_cwd_roots
                .iter()
                .map(|path| normalize_path(path.as_path()))
                .collect();
        }
        if let Some(root) = managed_mount_root.as_ref()
            && !allowed_local_mount_roots.contains(root)
        {
            allowed_local_mount_roots.push(root.clone());
        }

        Self {
            allowed_hosts: normalize_set(&config.ssh.allowed_hosts),
            denied_hosts: normalize_set(&config.ssh.denied_hosts),
            allowed_users: normalize_set(&config.ssh.allowed_users),
            allowed_auth_kinds: normalize_set(&config.ssh.allowed_auth_kinds),
            allowed_tunnel_bind_hosts: normalize_set(&config.ssh.allowed_tunnel_bind_hosts),
            allow_explicit_mount_paths: config.ssh.allow_explicit_mount_paths,
            managed_mount_root,
            allowed_local_mount_roots,
            port_min: config.ssh.port_min,
            port_max: config.ssh.port_max,
        }
    }

    /// Checks that an SSH target (`host`, optional `user`, optional `port`) is permitted.
    ///
    /// The denylist is consulted before the allowlist, so a denied host is rejected even if it
    /// is also allowlisted. A missing port defaults to [`DEFAULT_SSH_PORT`].
    ///
    /// # Errors
    ///
    /// Returns an error when any of the following holds:
    /// * the host is empty or malformed;
    /// * the host is denied or not allowlisted;
    /// * the user is empty or malformed;
    /// * a user allowlist is configured and the user is missing or not in it;
    /// * the port is `0` or outside `port_min..=port_max`.
    pub fn validate_target(&self, host: &str, user: Option<&str>, port: Option<u16>) -> Result<()> {
        let normalized_host =
            normalize_value(host).ok_or_else(|| anyhow::anyhow!("ssh host cannot be empty"))?;
        Self::ensure_not_option_like("host", &normalized_host, host)?;
        // No valid hostname or ssh_config alias contains whitespace or '@'. OpenSSH splits a
        // destination at the *last* '@', so `x@denied.example` would otherwise bypass a
        // denylist entry for `denied.example`.
        ensure!(
            !normalized_host
                .chars()
                .any(|c| c.is_whitespace() || c == '@'),
            "ssh host must not contain whitespace or '@' (pass the user separately): host={host:?}"
        );

        if self
            .denied_hosts
            .iter()
            .any(|pattern| host_matches(pattern, &normalized_host))
        {
            bail!("ssh host is denied by policy: host={host}");
        }

        if !self.allowed_hosts.is_empty()
            && !self
                .allowed_hosts
                .iter()
                .any(|pattern| host_matches(pattern, &normalized_host))
        {
            bail!("ssh host is not in the allowlist: host={host}");
        }

        if let Some(user) = user {
            let normalized_user =
                normalize_value(user).ok_or_else(|| anyhow::anyhow!("ssh user cannot be empty"))?;
            Self::ensure_not_option_like("user", &normalized_user, user)?;

            if !self.allowed_users.is_empty() && !self.allowed_users.contains(&normalized_user) {
                bail!("ssh user is not in the allowlist: user={user}");
            }
        } else if !self.allowed_users.is_empty() {
            bail!("ssh user is required by policy");
        }

        let port = port.unwrap_or(DEFAULT_SSH_PORT);
        ensure!(port != 0, "ssh port must be greater than 0: port={port}");

        if port < self.port_min || port > self.port_max {
            bail!(
                "ssh port is outside the allowed policy range: port={port} port_min={} port_max={}",
                self.port_min,
                self.port_max
            );
        }

        Ok(())
    }

    /// Rejects host or user values that could never be legitimate names: a leading `-` or any
    /// control character (e.g. an embedded newline).
    ///
    /// This is defence in depth. Whether these values end up on an ssh command line is not
    /// visible in this module, but a leading `-` there would be parsed as an option.
    fn ensure_not_option_like(field: &str, normalized: &str, raw: &str) -> Result<()> {
        ensure!(
            !normalized.starts_with('-'),
            "ssh {field} must not start with '-': {field}={raw}"
        );
        ensure!(
            !normalized.chars().any(char::is_control),
            "ssh {field} must not contain control characters: {field}={raw:?}"
        );
        Ok(())
    }

    /// Checks that `auth_kind` is permitted and that its required inputs are present.
    ///
    /// * `config_alias` requires a non-blank `host_alias`.
    /// * `identity_file` requires an absolute `identity_path` that refers to an existing
    ///   regular file. Symlinks are followed by `fs::metadata`.
    /// * `ssh_agent` needs nothing extra.
    ///
    /// # Errors
    ///
    /// Returns an error when the auth kind is not in a non-empty `allowed_auth_kinds` set, or
    /// when the kind-specific requirement above is not met.
    pub fn validate_auth(
        &self,
        auth_kind: SshAuthKind,
        host_alias: Option<&str>,
        identity_path: Option<&Path>,
    ) -> Result<()> {
        // These names are already in normalized (trimmed, lowercase) form, so they can be
        // compared against the normalized policy set directly.
        let auth_kind_name = match auth_kind {
            SshAuthKind::ConfigAlias => "config_alias",
            SshAuthKind::SshAgent => "ssh_agent",
            SshAuthKind::IdentityFile => "identity_file",
        };
        if !self.allowed_auth_kinds.is_empty() && !self.allowed_auth_kinds.contains(auth_kind_name)
        {
            bail!("ssh auth kind is not allowed by policy: auth_kind={auth_kind_name}");
        }

        match auth_kind {
            SshAuthKind::ConfigAlias => {
                ensure!(
                    host_alias.is_some_and(|alias| !alias.trim().is_empty()),
                    "host_alias is required when auth_kind=config_alias"
                );
            }
            SshAuthKind::IdentityFile => {
                let Some(identity_path) = identity_path else {
                    bail!("identity_path is required when auth_kind=identity_file");
                };

                ensure!(
                    identity_path.is_absolute(),
                    "identity_path must be absolute: identity_path={}",
                    identity_path.display()
                );

                // Only a genuinely missing file is reported as "does not exist". Other
                // failures, such as permission denied, keep their cause so they can be diagnosed.
                let metadata = fs::metadata(identity_path).map_err(|err| {
                    if err.kind() == std::io::ErrorKind::NotFound {
                        anyhow::anyhow!(
                            "identity_path does not exist: identity_path={}",
                            identity_path.display()
                        )
                    } else {
                        anyhow::anyhow!(
                            "identity_path cannot be accessed: identity_path={} error={err}",
                            identity_path.display()
                        )
                    }
                })?;
                ensure!(
                    metadata.is_file(),
                    "identity_path must be a file: identity_path={}",
                    identity_path.display()
                );
            }
            SshAuthKind::SshAgent => {}
        }

        Ok(())
    }

    /// Checks that `remote_path` is non-blank and absolute (starts with `/` after trimming).
    ///
    /// Only the shape is checked. `..` components are not rejected, and no remote root
    /// restriction is applied here.
    pub fn validate_remote_path(&self, remote_path: &str) -> Result<()> {
        let trimmed = remote_path.trim();
        ensure!(!trimmed.is_empty(), "remote_path cannot be empty");

        ensure!(
            trimmed.starts_with('/'),
            "remote_path must be absolute: remote_path={remote_path}"
        );

        Ok(())
    }

    /// Checks that a tunnel may bind on `bind_host`.
    ///
    /// Loopback addresses are always allowed: `localhost`, any IPv4 address in `127.0.0.0/8`,
    /// and `::1`. Any other host must appear (case-insensitively) in
    /// `allowed_tunnel_bind_hosts`.
    ///
    /// # Errors
    ///
    /// Returns an error when `bind_host` is blank or is neither loopback nor allowlisted.
    pub fn validate_tunnel_bind_host(&self, bind_host: &str) -> Result<()> {
        let normalized = normalize_value(bind_host)
            .ok_or_else(|| anyhow::anyhow!("ssh tunnel bind_host cannot be empty"))?;
        let is_loopback = normalized == "localhost"
            || normalized
                .parse::<std::net::IpAddr>()
                .is_ok_and(|addr| addr.is_loopback());
        if is_loopback {
            return Ok(());
        }

        if self.allowed_tunnel_bind_hosts.contains(&normalized) {
            return Ok(());
        }

        bail!(
            "ssh tunnel bind_host is not allowed by policy: bind_host={bind_host} allowed_tunnel_bind_hosts={:?}",
            self.allowed_tunnel_bind_hosts
        )
    }

    /// Checks that a local mount point is permitted.
    ///
    /// The path must be absolute. A path at or below the managed mount root is always
    /// accepted. Any other path requires `allow_explicit_mount_paths` and must lie at or below
    /// one of the allowed local mount roots.
    ///
    /// Comparison is purely lexical (see `normalize_path`): `.` and `..` are collapsed, but
    /// symlinks are not resolved.
    ///
    /// # Errors
    ///
    /// Returns an error when the path is relative, when explicit paths are disabled, or when
    /// the path is outside every allowed root.
    pub fn validate_local_mount_path(&self, path: &Path) -> Result<()> {
        ensure!(
            path.is_absolute(),
            "ssh mount local_path must be absolute: local_path={}",
            path.display()
        );

        let normalized = normalize_path(path);
        // `Path::starts_with` compares whole components, so `/mnt/rootx` is not under
        // `/mnt/root`.
        if let Some(root) = &self.managed_mount_root
            && normalized.starts_with(root)
        {
            return Ok(());
        }

        if !self.allow_explicit_mount_paths {
            bail!(
                "explicit ssh mount paths are disabled by policy: local_path={}",
                normalized.display()
            );
        }

        if self
            .allowed_local_mount_roots
            .iter()
            .any(|root| normalized.starts_with(root))
        {
            return Ok(());
        }

        bail!(
            "ssh mount local_path is outside allowed roots: local_path={} allowed_roots={:?}",
            normalized.display(),
            self.allowed_local_mount_roots
                .iter()
                .map(|root| root.display().to_string())
                .collect::<Vec<_>>()
        )
    }
}

/// Builds a sorted, de-duplicated set from `values`.
///
/// Each entry goes through `normalize_value` (trim + ASCII lowercase); entries that end up
/// empty are dropped.
fn normalize_set(values: &[String]) -> BTreeSet<String> {
    let mut normalized = BTreeSet::new();
    for raw in values {
        if let Some(entry) = normalize_value(raw) {
            normalized.insert(entry);
        }
    }
    normalized
}

/// Trims surrounding whitespace and ASCII-lowercases `value`.
///
/// Returns `None` when nothing is left after trimming. Non-ASCII characters keep their case.
fn normalize_value(value: &str) -> Option<String> {
    let lowered = value.trim().to_ascii_lowercase();
    (!lowered.is_empty()).then_some(lowered)
}

/// Returns `true` when `host` matches the policy `pattern`.
///
/// Both arguments are expected to be normalized already (trimmed, lowercased). Supported
/// patterns:
/// * `*` matches every host;
/// * `*.suffix` matches `suffix` itself and any subdomain of it;
/// * anything else is an exact match.
///
/// Trailing `.` characters (the DNS root) are ignored on both sides. Without this, the
/// fully-qualified spelling `denied.example.` would slip past a denylist entry for
/// `denied.example`. A bare `*.` pattern has no suffix and matches nothing.
fn host_matches(pattern: &str, host: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    let host = host.trim_end_matches('.');
    if let Some(suffix) = pattern.strip_prefix("*.") {
        let suffix = suffix.trim_end_matches('.');
        if suffix.is_empty() {
            return false;
        }
        // A subdomain match needs a '.' boundary right before the suffix, so
        // `badexample.com` does not match `*.example.com`.
        return host == suffix
            || host
                .strip_suffix(suffix)
                .is_some_and(|rest| rest.ends_with('.'));
    }
    pattern.trim_end_matches('.') == host
}

/// Makes `path` absolute and collapses `.` and `..` components lexically.
///
/// Relative paths are resolved against the current directory, or against `.` if the current
/// directory cannot be read. The filesystem is never touched, so symlinks are not resolved.
/// A `..` at the root stays at the root.
fn normalize_path(path: &Path) -> PathBuf {
    let base = if path.is_absolute() {
        path.to_path_buf()
    } else {
        std::env::current_dir()
            .map(|cwd| cwd.join(path))
            .unwrap_or_else(|_| PathBuf::from(".").join(path))
    };

    base.components().fold(PathBuf::new(), |mut acc, part| {
        match part {
            Component::ParentDir => {
                let _ = acc.pop();
            }
            Component::CurDir => {}
            Component::RootDir | Component::Prefix(_) | Component::Normal(_) => {
                acc.push(part.as_os_str());
            }
        }
        acc
    })
}