pty-mcp 0.2.2

An MCP server for PTY management with SSH connections, remote sessions, file access, and mounts
Documentation
use std::path::{Path, PathBuf};

use anyhow::{Result, bail, ensure};

use crate::config::SshConfig;

use super::{
    model::{SshAuthKind, SshTarget},
    policy::SshPolicy,
};

/// Borrowed inputs for [`SshGuard::validate_connect_request`].
#[derive(Debug, Clone)]
pub struct SshConnectValidationInput<'a> {
    /// Connection target whose host, host alias, user and port are checked.
    pub target: &'a SshTarget,
    /// Explicitly requested auth kind; when `None`, one is derived from the
    /// target's host alias and `identity_path` (see `resolve_auth_kind`).
    pub auth_kind: Option<SshAuthKind>,
    /// Identity file path as supplied by the caller. It is trimmed and must be
    /// absolute when the effective auth kind is `SshAuthKind::IdentityFile`.
    pub identity_path: Option<&'a str>,
}

/// Outcome of a successful [`SshGuard::validate_connect_request`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshConnectValidationResult {
    /// Auth kind that was validated: the explicit one from the input, or the
    /// one derived when none was given.
    pub auth_kind: SshAuthKind,
    /// Trimmed, absolute identity file path when `auth_kind` is
    /// `SshAuthKind::IdentityFile`; `None` for every other auth kind.
    pub identity_path: Option<PathBuf>,
}

/// Borrowed inputs for [`SshGuard::validate_mount_request`].
#[derive(Debug, Clone)]
pub struct SshMountValidationInput<'a> {
    /// Local path of the mount; checked against the policy and the configured
    /// managed/allowed mount roots.
    pub local_path: &'a Path,
    /// Remote path of the mount; after trimming it must start with `/`.
    pub remote_path: &'a str,
}

/// Outcome of a successful [`SshGuard::validate_mount_request`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshMountValidationResult {
    /// Owned copy of the validated local path, exactly as supplied.
    pub local_path: PathBuf,
    /// Remote path with surrounding whitespace trimmed.
    pub remote_path: String,
    /// `true` when `local_path` starts with the configured `managed_mount_root`.
    pub is_managed_path: bool,
}

/// Validates SSH connect and mount requests against an [`SshPolicy`] and,
/// for the `*_request` methods, a caller-supplied [`SshConfig`].
#[derive(Debug, Clone)]
pub struct SshGuard {
    /// Policy consulted for target and local-mount-path validation.
    policy: SshPolicy,
}

impl SshGuard {
    /// Creates a guard that enforces the given policy.
    pub fn new(policy: SshPolicy) -> Self {
        Self { policy }
    }

    /// Returns the policy this guard was built with.
    pub fn policy(&self) -> &SshPolicy {
        &self.policy
    }

    /// Checks the target's host, user and port against the policy.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `SshPolicy::validate_target` returns.
    pub fn validate_connect_target(&self, target: &SshTarget) -> Result<()> {
        self.policy
            .validate_target(&target.host, target.user.as_deref(), target.port)
    }

    /// Validates a full connect request and returns the effective auth settings.
    ///
    /// Checks run in this order and stop at the first failure:
    /// policy target check, config host/user/port rules, auth-kind resolution,
    /// auth-kind allow-list, identity path validation.
    ///
    /// # Errors
    ///
    /// Returns an error when any of the checks above rejects the request.
    pub fn validate_connect_request(
        &self,
        config: &SshConfig,
        input: SshConnectValidationInput<'_>,
    ) -> Result<SshConnectValidationResult> {
        self.validate_connect_target(input.target)?;
        self.validate_host_user_port(config, input.target)?;

        let auth_kind = resolve_auth_kind(input.target, input.auth_kind, input.identity_path)?;
        self.validate_auth_kind(config, &auth_kind, input.target)?;
        let identity_path = validate_identity_path(&auth_kind, input.identity_path)?;

        Ok(SshConnectValidationResult {
            auth_kind,
            identity_path,
        })
    }

    /// Checks a local mount path against the policy.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `SshPolicy::validate_local_mount_path` returns.
    pub fn validate_mount_local_path(&self, local_path: &Path) -> Result<()> {
        self.policy.validate_local_mount_path(local_path)
    }

    /// Validates a mount request against the policy and the configured mount roots.
    ///
    /// The local path is accepted when it lies under `config.managed_mount_root`, or
    /// when `config.allow_explicit_mount_paths` is set and it lies under one of
    /// `config.allowed_mount_roots`.
    ///
    /// # Errors
    ///
    /// Returns an error when the policy rejects the local path, when the trimmed
    /// remote path does not start with `/`, when the local path contains a `..`
    /// component, or when the local path is outside every permitted root.
    pub fn validate_mount_request(
        &self,
        config: &SshConfig,
        input: SshMountValidationInput<'_>,
    ) -> Result<SshMountValidationResult> {
        self.validate_mount_local_path(input.local_path)?;
        let remote_path = input.remote_path.trim();
        ensure!(
            !remote_path.is_empty() && remote_path.starts_with('/'),
            "remote_path must be an absolute path: remote_path={}",
            input.remote_path
        );

        // `Path::starts_with` compares components lexically and never resolves `..`,
        // so `<root>/../elsewhere` would pass the root checks below while pointing
        // outside the root. Reject such paths before comparing against the roots.
        ensure!(
            !input
                .local_path
                .components()
                .any(|component| matches!(component, std::path::Component::ParentDir)),
            "mount local_path must not contain parent-directory ('..') components: local_path={}",
            input.local_path.display()
        );

        let local_path = input.local_path.to_path_buf();
        let is_managed = config
            .managed_mount_root
            .as_ref()
            .is_some_and(|root| local_path.starts_with(root));
        let is_controlled_explicit = config
            .allowed_mount_roots
            .iter()
            .any(|root| local_path.starts_with(root));

        ensure!(
            is_managed || (config.allow_explicit_mount_paths && is_controlled_explicit),
            "mount local_path is outside allowed managed/controlled roots: local_path={} managed_mount_root={:?} allow_explicit_mount_paths={} allowed_mount_roots={:?}",
            local_path.display(),
            config.managed_mount_root,
            config.allow_explicit_mount_paths,
            config.allowed_mount_roots
        );

        Ok(SshMountValidationResult {
            local_path,
            remote_path: remote_path.to_string(),
            is_managed_path: is_managed,
        })
    }

    /// Applies the config's host deny/allow lists, user allow-list and port range.
    ///
    /// Host and user comparisons are trimmed and ASCII case-insensitive. An empty
    /// allow-list places no restriction on that field.
    fn validate_host_user_port(&self, config: &SshConfig, target: &SshTarget) -> Result<()> {
        let host = target.host.trim();
        let host_alias = target.host_alias.as_deref().unwrap_or_default().trim();
        let user = target.user.as_deref().unwrap_or_default().trim();
        // A target without an explicit port is checked as port 22.
        let port = target.port.unwrap_or(22);

        // Deny list: a match on either the host or a non-empty alias blocks the target.
        if contains_ci(&config.denied_hosts, host)
            || (!host_alias.is_empty() && contains_ci(&config.denied_hosts, host_alias))
        {
            bail!("target host is blocked by ssh policy: host={host} host_alias={host_alias}");
        }

        // Allow list: a match on either the host or a non-empty alias is sufficient.
        if !config.allowed_hosts.is_empty()
            && !contains_ci(&config.allowed_hosts, host)
            && (host_alias.is_empty() || !contains_ci(&config.allowed_hosts, host_alias))
        {
            bail!("target host is not allowed by ssh policy: host={host} host_alias={host_alias}");
        }

        // With a user allow-list configured, a missing or blank user is rejected too.
        if !config.allowed_users.is_empty()
            && (user.is_empty() || !contains_ci(&config.allowed_users, user))
        {
            bail!("target user is not allowed by ssh policy: user={user} host={host}");
        }

        if port < config.port_min || port > config.port_max {
            bail!(
                "target port is outside allowed ssh policy range: port={port} port_min={} port_max={} host={host}",
                config.port_min,
                config.port_max
            );
        }

        Ok(())
    }

    /// Checks the auth kind against `config.allowed_auth_kinds`.
    ///
    /// An empty allow-list permits every auth kind. Otherwise the kind is matched,
    /// ASCII case-insensitively, under the keys `host_alias`, `ssh_agent` and
    /// `identity_path`.
    fn validate_auth_kind(
        &self,
        config: &SshConfig,
        auth_kind: &SshAuthKind,
        target: &SshTarget,
    ) -> Result<()> {
        if config.allowed_auth_kinds.is_empty() {
            return Ok(());
        }

        // Key under which each auth kind appears in the configured allow-list.
        let key = match auth_kind {
            SshAuthKind::ConfigAlias => "host_alias",
            SshAuthKind::SshAgent => "ssh_agent",
            SshAuthKind::IdentityFile => "identity_path",
        };

        ensure!(
            contains_ci(&config.allowed_auth_kinds, key),
            "ssh auth kind is blocked by policy: auth_kind={key} allowed_auth_kinds={:?} host_alias={:?}",
            config.allowed_auth_kinds,
            target.host_alias
        );

        Ok(())
    }
}

/// Picks the auth kind to use for a connect request.
///
/// An explicitly supplied `auth_kind` always wins. Otherwise the choice falls
/// through, in order: a non-blank host alias on the target selects
/// `ConfigAlias`, a non-blank `identity_path` selects `IdentityFile`, and
/// anything else selects `SshAgent`.
fn resolve_auth_kind(
    target: &SshTarget,
    auth_kind: Option<SshAuthKind>,
    identity_path: Option<&str>,
) -> Result<SshAuthKind> {
    // True when the optional text is present and not just whitespace.
    let has_text = |value: Option<&str>| value.is_some_and(|text| !text.trim().is_empty());

    let resolved = match auth_kind {
        Some(explicit) => explicit,
        None if has_text(target.host_alias.as_deref()) => SshAuthKind::ConfigAlias,
        None if has_text(identity_path) => SshAuthKind::IdentityFile,
        None => SshAuthKind::SshAgent,
    };

    Ok(resolved)
}

/// Validates the identity file path for the given auth kind.
///
/// Returns `Ok(None)` for every auth kind other than `IdentityFile`. For
/// `IdentityFile`, the path is trimmed and returned as an owned `PathBuf`.
///
/// # Errors
///
/// For `IdentityFile`, fails when the path is missing or blank, or when the
/// trimmed path is not absolute.
fn validate_identity_path(
    auth_kind: &SshAuthKind,
    identity_path: Option<&str>,
) -> Result<Option<PathBuf>> {
    // Only identity-file auth carries a path; every other kind yields nothing.
    if !matches!(auth_kind, SshAuthKind::IdentityFile) {
        return Ok(None);
    }

    // A missing path and a whitespace-only path are treated the same way.
    let trimmed = identity_path.map(str::trim).unwrap_or_default();
    ensure!(
        !trimmed.is_empty(),
        "identity_path is required when auth_kind=identity_file"
    );

    let candidate = PathBuf::from(trimmed);
    ensure!(
        candidate.is_absolute(),
        "identity_path must be an absolute path: identity_path={}",
        candidate.display()
    );

    Ok(Some(candidate))
}

/// Returns `true` when `candidate` equals any entry of `items`, comparing
/// after trimming surrounding whitespace on both sides and ignoring ASCII case.
///
/// Non-ASCII characters are compared exactly; no Unicode case folding is applied.
fn contains_ci(items: &[String], candidate: &str) -> bool {
    let candidate = candidate.trim();
    // `eq_ignore_ascii_case` compares in place, so no lowercased copies are allocated.
    items
        .iter()
        .any(|item| item.trim().eq_ignore_ascii_case(candidate))
}