pass-ssh-unpack 0.5.1

A utility for unpacking SSH keys stored in Proton Pass (via Proton's `pass-cli`) into usable SSH and rclone configurations.
use anyhow::{Context, Result};
use sanitize_filename::Options as SanitizeOptions;
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
use std::process::Command;

use crate::config::SyncPublicKey;
use crate::platform::{self, set_private_permissions};
use crate::proton_pass::{ProtonPass, SshItem};
use crate::rclone::RcloneEntry;

/// Turn an arbitrary item title or host string into a safe filename / rclone remote name.
///
/// Characters that `sanitize_filename` considers invalid in filenames are replaced with `-`.
/// Afterwards every space becomes `_`, and `(` / `)` are dropped entirely.
pub fn sanitize_name(name: &str) -> String {
    let options = SanitizeOptions {
        replacement: "-",
        ..Default::default()
    };
    let cleaned = sanitize_filename::sanitize_with_options(name, options);

    cleaned
        .chars()
        .filter(|&ch| ch != '(' && ch != ')')
        .map(|ch| if ch == ' ' { '_' } else { ch })
        .collect()
}

/// Banner written at the top of the generated SSH config by `SshManager::write_config`.
///
/// It tells readers that the file is regenerated on every run, and shows how to
/// `Include` it from `~/.ssh/config`.
const CONFIG_HEADER: &str = r#"# =============================================================================
# DO NOT EDIT THIS FILE - IT IS AUTO-GENERATED BY pass-ssh-unpack
# =============================================================================
# Any manual changes will be lost on the next run.
#
# To use these keys, add the following to your ~/.ssh/config:
#     Include ~/.ssh/proton-pass/config
#
# To regenerate: pass-ssh-unpack
# To regenerate fully: pass-ssh-unpack --full
# ============================================================================="#;

/// Manages SSH key extraction and generation of the SSH config file.
///
/// Host entries are tracked as `sanitized host name -> rendered config block`.
pub struct SshManager {
    // ---- Locations on disk ----
    /// Root directory holding one sub-directory per vault plus the generated config.
    base_dir: PathBuf,
    /// The generated SSH config file, `<base_dir>/config`.
    config_path: PathBuf,

    // ---- Run options ----
    /// Rebuild from scratch (the base directory is wiped on construction).
    full_mode: bool,
    /// Report what would happen without touching the filesystem.
    dry_run: bool,
    /// Policy for writing a derived public key back to the Proton Pass item.
    sync_public_key: SyncPublicKey,

    // ---- Host blocks ----
    /// Blocks parsed from a previous config (incremental mode only).
    existing_hosts: HashMap<String, String>,
    /// Blocks produced during this run; these override `existing_hosts`.
    new_hosts: HashMap<String, String>,
}

impl SshManager {
    /// Create a new SSH manager
    pub fn new(
        base_dir: &Path,
        full_mode: bool,
        dry_run: bool,
        sync_public_key: SyncPublicKey,
    ) -> Result<Self> {
        let config_path = base_dir.join("config");

        if !dry_run {
            // Full mode: delete entire folder and start fresh
            if full_mode && base_dir.exists() {
                fs::remove_dir_all(base_dir)
                    .with_context(|| format!("Failed to remove {}", base_dir.display()))?;
            }

            fs::create_dir_all(base_dir)
                .with_context(|| format!("Failed to create {}", base_dir.display()))?;
        }

        // Load existing config for incremental updates
        let existing_hosts = if !full_mode && config_path.exists() {
            Self::parse_existing_config(&config_path)?
        } else {
            HashMap::new()
        };

        Ok(Self {
            base_dir: base_dir.to_path_buf(),
            config_path,
            existing_hosts,
            new_hosts: HashMap::new(),
            full_mode,
            dry_run,
            sync_public_key,
        })
    }

    /// Get the path to the SSH config file
    pub fn config_path(&self) -> &Path {
        &self.config_path
    }

    /// Process an SSH item, extracting keys and building config entries
    /// Returns an RcloneEntry if successful and has a host field
    pub fn process_item(
        &mut self,
        proton_pass: &ProtonPass,
        vault: &str,
        item: &SshItem,
        log: &impl Fn(&str),
    ) -> Result<Option<RcloneEntry>> {
        let host_field = match &item.host {
            Some(h) => h.clone(),
            None => {
                log("    -> skipped (no Host field)");
                return Ok(None);
            }
        };

        // Sanitize title for filename
        let safe_title = sanitize_name(&item.title);
        let vault_dir = self.base_dir.join(vault);

        if !self.dry_run {
            fs::create_dir_all(&vault_dir)?;
        }

        let privkey_path = vault_dir.join(&safe_title);
        let pubkey_path = vault_dir.join(format!("{}.pub", safe_title));

        let mut has_key = false;
        let mut identity_path = String::new();

        // Process private key if present
        if let Some(ref private_key) = item.private_key {
            if !private_key.is_empty() {
                if self.dry_run {
                    // In dry run, check if key already exists
                    has_key = true;
                    identity_path = format!(
                        "{}/.ssh/proton-pass/{}/{}",
                        platform::ssh_home_placeholder(),
                        vault,
                        safe_title
                    );
                    if privkey_path.exists() {
                        log(&format!("    -> {} (exists)", safe_title));
                    } else {
                        log(&format!("    -> {} (would write key)", safe_title));
                    }
                } else {
                    // Write private key
                    let mut file = File::create(&privkey_path)?;
                    writeln!(file, "{}", private_key)?;
                    drop(file);

                    // Set permissions
                    set_private_permissions(&privkey_path)?;

                    // Generate public key
                    let keygen_output = Command::new("ssh-keygen")
                        .args(["-y", "-f"])
                        .arg(&privkey_path)
                        .output()
                        .context("Failed to run ssh-keygen")?;

                    if keygen_output.status.success() {
                        let generated_pubkey = String::from_utf8_lossy(&keygen_output.stdout)
                            .trim()
                            .to_string();

                        fs::write(&pubkey_path, &generated_pubkey)?;
                        has_key = true;
                        identity_path = format!(
                            "{}/.ssh/proton-pass/{}/{}",
                            platform::ssh_home_placeholder(),
                            vault,
                            safe_title
                        );

                        // Determine if we should sync public key to Proton Pass
                        let pubkey_is_empty = item.public_key.is_none()
                            || item
                                .public_key
                                .as_ref()
                                .map(|s| s.is_empty())
                                .unwrap_or(true);

                        let should_sync = match self.sync_public_key {
                            SyncPublicKey::Never => false,
                            SyncPublicKey::IfEmpty => pubkey_is_empty,
                            SyncPublicKey::Always => true,
                        };

                        if should_sync {
                            match proton_pass.update_item_field(
                                vault,
                                &item.title,
                                "public_key",
                                &generated_pubkey,
                            ) {
                                Ok(_) => log(&format!(
                                    "    -> {} (saved pubkey to Proton Pass)",
                                    safe_title
                                )),
                                Err(_) => log(&format!(
                                    "    -> {} (failed to save pubkey to Proton Pass)",
                                    safe_title
                                )),
                            }
                        } else {
                            log(&format!("    -> {}", safe_title));
                        }
                    } else {
                        log(&format!(
                            "    -> {} (failed to generate public key)",
                            safe_title
                        ));
                        fs::remove_file(&privkey_path).ok();
                    }
                }
            }
        } else {
            log(&format!("    -> {} (no key, password auth)", safe_title));
        }

        // Build config entries
        let sanitized_host = sanitize_name(&host_field);
        let mut config_block = format!("Host {}", sanitized_host);
        if has_key {
            config_block.push_str(&format!(
                "\n    IdentityFile \"{}\"\n    IdentitiesOnly yes",
                identity_path
            ));
        }
        if let Some(ref username) = item.username {
            config_block.push_str(&format!("\n    User {}", username));
        }
        if let Some(ref jump) = item.jump {
            config_block.push_str(&format!("\n    ProxyJump {}", jump));
        }
        self.new_hosts.insert(sanitized_host.clone(), config_block);

        // Build alias entries
        let aliases_list: Vec<String> = if let Some(ref aliases) = item.aliases {
            aliases
                .split(',')
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
                .collect()
        } else {
            vec![item.title.clone()]
        };

        for alias_entry in &aliases_list {
            if alias_entry == &host_field {
                continue;
            }

            let sanitized_alias = sanitize_name(alias_entry);
            let mut alias_block =
                format!("# Alias of {}\nHost {}", sanitized_host, sanitized_alias);
            if has_key {
                alias_block.push_str(&format!(
                    "\n    IdentityFile \"{}\"\n    IdentitiesOnly yes",
                    identity_path
                ));
            }
            if let Some(ref username) = item.username {
                alias_block.push_str(&format!("\n    User {}", username));
            }
            if let Some(ref jump) = item.jump {
                alias_block.push_str(&format!("\n    ProxyJump {}", jump));
            }
            self.new_hosts.insert(sanitized_alias, alias_block);
        }

        // Build rclone entry
        let rclone_key_file = if has_key {
            format!("~/.ssh/proton-pass/{}/{}", vault, safe_title)
        } else {
            String::new()
        };

        // First alias is the remote name, rest are other_aliases
        let (remote_name, other_aliases) = if !aliases_list.is_empty() {
            let remote_name = sanitize_name(&aliases_list[0]);
            let other_aliases = if aliases_list.len() > 1 {
                aliases_list[1..]
                    .iter()
                    .map(|s| sanitize_name(s))
                    .collect::<Vec<_>>()
                    .join(",")
            } else {
                String::new()
            };
            (remote_name, other_aliases)
        } else {
            (sanitize_name(&item.title), String::new())
        };

        // Check if this is a valid entry for rclone/ssh:
        // Must have at least one of:
        // 1. A key file (private_key was present and generated)
        // 2. An SSH command ("ssh" field)
        // 3. A server command ("server_command" field)
        let is_valid = has_key || item.ssh.is_some() || item.server_command.is_some();

        if !is_valid {
            return Ok(None);
        }

        Ok(Some(RcloneEntry {
            remote_name,
            host: host_field,
            user: item.username.clone().unwrap_or_default(),
            key_file: rclone_key_file,
            other_aliases,
            ssh: item.ssh.clone(),
            server_command: item.server_command.clone(),
        }))
    }

    /// Write the final SSH config file
    /// Returns (primary_count, alias_count)
    pub fn write_config(&self) -> Result<(usize, usize)> {
        // Merge: new hosts override existing, keep existing if not touched
        let mut final_hosts = if self.full_mode {
            HashMap::new()
        } else {
            self.existing_hosts.clone()
        };

        // Override/add new hosts
        for (host, block) in &self.new_hosts {
            final_hosts.insert(host.clone(), block.clone());
        }

        // Write final config (skip in dry run)
        if !self.dry_run {
            let mut file = File::create(&self.config_path)?;
            writeln!(file, "{}", CONFIG_HEADER)?;

            // Sort hosts for consistent output
            let mut sorted_hosts: Vec<_> = final_hosts.keys().collect();
            sorted_hosts.sort();

            for host in sorted_hosts {
                writeln!(file)?;
                writeln!(file, "{}", final_hosts[host])?;
            }
        }

        // Count primaries and aliases
        let total_hosts = final_hosts.len();
        let alias_count = final_hosts
            .values()
            .filter(|block| block.contains("# Alias of"))
            .count();
        let primary_count = total_hosts - alias_count;

        Ok((primary_count, alias_count))
    }

    /// Parse existing SSH config file into host -> block map
    fn parse_existing_config(path: &Path) -> Result<HashMap<String, String>> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);

        let mut hosts = HashMap::new();
        let mut current_host = String::new();
        let mut current_block = String::new();

        for line in reader.lines() {
            let line = line?;

            // Skip header comments
            if line.contains("DO NOT EDIT")
                || line.contains("=====")
                || line.contains("Include")
                || line.contains("regenerate")
                || line.contains("To use")
            {
                continue;
            }

            if line.starts_with("Host ") {
                // Save previous block
                if !current_host.is_empty() {
                    hosts.insert(current_host.clone(), current_block.clone());
                }

                current_host = line.strip_prefix("Host ").unwrap_or("").to_string();
                current_block = line.clone();
            } else if !current_host.is_empty() && !line.is_empty() {
                current_block.push('\n');
                current_block.push_str(&line);
            }
        }

        // Save last block
        if !current_host.is_empty() {
            hosts.insert(current_host, current_block);
        }

        Ok(hosts)
    }
}