use std::fs;
use std::time::SystemTime;
use anyhow::{Context, Result};
use super::model::{ConfigElement, SshConfigFile};
use crate::fs_util;
impl SshConfigFile {
    /// Persists the configuration to disk, backing up any existing file first.
    ///
    /// Steps, in order:
    /// 1. Resolve `self.path` with `fs::canonicalize` so that, when the path is a
    ///    symlink, the write lands on the link's target. If resolution fails
    ///    (for example because the file does not exist yet) the raw path is used.
    /// 2. If `self.path` exists, copy it to a timestamped backup next to it and
    ///    then prune old backups.
    /// 3. Serialize the elements and hand the bytes to `fs_util::atomic_write`.
    ///
    /// # Errors
    /// Returns an error if the backup cannot be created or if the write fails.
    /// Failures while pruning old backups are deliberately ignored.
    pub fn write(&self) -> Result<()> {
        // Number of timestamped backups retained after each successful backup.
        const MAX_BACKUPS: usize = 5;

        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
        if self.path.exists() {
            self.create_backup()
                .context("Failed to create backup of SSH config")?;
            // Best-effort: a pruning failure must not block saving the config.
            let _ = self.prune_backups(MAX_BACKUPS);
        }
        let content = self.serialize();
        // NOTE(review): presumably `atomic_write` writes a temp file and renames
        // it over `target_path` — confirm against `fs_util`.
        fs_util::atomic_write(&target_path, content.as_bytes())
            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
        Ok(())
    }

    /// Renders all elements back into the text of an SSH config file.
    ///
    /// Every element contributes its stored raw line(s) unchanged, in order.
    /// Runs of consecutive blank (whitespace-only) lines are collapsed to a
    /// single line, including runs that span element boundaries. Each emitted
    /// line is terminated with `"\r\n"` when `self.crlf` is set and `"\n"`
    /// otherwise. If nothing at all would be emitted, the result is a single
    /// line ending.
    pub fn serialize(&self) -> String {
        // Flatten the element tree into borrowed lines; no per-line allocation.
        let mut lines: Vec<&str> = Vec::new();
        for element in &self.elements {
            match element {
                ConfigElement::GlobalLine(line) => lines.push(line),
                ConfigElement::HostBlock(block) => {
                    lines.push(&block.raw_host_line);
                    for directive in &block.directives {
                        lines.push(&directive.raw_line);
                    }
                }
                ConfigElement::Include(include) => lines.push(&include.raw_line),
            }
        }

        let line_ending = if self.crlf { "\r\n" } else { "\n" };
        let mut result = String::new();
        let mut prev_blank = false;
        for line in lines {
            let is_blank = line.trim().is_empty();
            // Skip the second and later blank lines of a run.
            if is_blank && prev_blank {
                continue;
            }
            prev_blank = is_blank;
            result.push_str(line);
            result.push_str(line_ending);
        }
        if result.is_empty() {
            result.push_str(line_ending);
        }
        result
    }

    /// Returns the file-name prefix shared by every backup of this config:
    /// `<file name>.bak.`.
    ///
    /// Used by both `create_backup` and `prune_backups` so the name that is
    /// written and the name that is searched for cannot drift apart.
    fn backup_prefix(&self) -> String {
        format!(
            "{}.bak.",
            self.path.file_name().unwrap_or_default().to_string_lossy()
        )
    }

    /// Copies the current config file to `<file name>.bak.<unix millis>` in the
    /// same directory.
    ///
    /// If the system clock reports a time before the Unix epoch, the timestamp
    /// falls back to `0`.
    ///
    /// # Errors
    /// Returns an error if the copy fails.
    fn create_backup(&self) -> Result<()> {
        let timestamp = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis();
        let backup_name = format!("{}{}", self.backup_prefix(), timestamp);
        let backup_path = self.path.with_file_name(backup_name);
        fs::copy(&self.path, &backup_path).with_context(|| {
            format!(
                "Failed to copy {} to {}",
                self.path.display(),
                backup_path.display()
            )
        })?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            // Restrict the backup to the owner. Best-effort: a failure here is
            // ignored rather than failing the whole save.
            let _ = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600));
        }
        Ok(())
    }

    /// Deletes the oldest backups so that at most `keep` remain.
    ///
    /// Only directory entries named `<file name>.bak.<digits>` are considered
    /// backups; anything else that merely shares the prefix (for example a
    /// hand-made `config.bak.manual`) is left untouched. Backups are ordered by
    /// the numeric value of their timestamp suffix, so suffixes with different
    /// digit counts are still ranked correctly.
    ///
    /// # Errors
    /// Returns an error if the path has no parent or the directory cannot be
    /// listed. Failures to delete an individual backup are ignored.
    fn prune_backups(&self, keep: usize) -> Result<()> {
        let parent = self.path.parent().context("No parent directory")?;
        // `Path::parent` yields an empty path for a bare relative file name such
        // as `config`; `read_dir("")` fails, so list the current directory.
        let dir = if parent.as_os_str().is_empty() {
            std::path::Path::new(".")
        } else {
            parent
        };
        let prefix = self.backup_prefix();

        let mut backups: Vec<_> = fs::read_dir(dir)?
            .filter_map(|entry| entry.ok())
            .filter_map(|entry| {
                let name = entry.file_name();
                let name = name.to_string_lossy();
                let suffix = name.strip_prefix(prefix.as_str())?;
                // Accept only the pure-digit suffixes that `create_backup` writes.
                if suffix.is_empty() || !suffix.bytes().all(|b| b.is_ascii_digit()) {
                    return None;
                }
                let timestamp: u128 = suffix.parse().ok()?;
                Some((timestamp, entry.path()))
            })
            .collect();

        // Oldest first: by timestamp, then by path as a deterministic tie-break.
        backups.sort_unstable();

        let excess = backups.len().saturating_sub(keep);
        for (_, stale) in backups.iter().take(excess) {
            // Best-effort: one undeletable backup should not stop the rest.
            let _ = fs::remove_file(stale);
        }
        Ok(())
    }
}
// Unit tests for `SshConfigFile::serialize`: round-tripping parsed content and
// serializing after add/delete/update operations. Nothing in this module calls
// `write`, so no file is created at the placeholder path used below.
#[cfg(test)]
mod tests {
use super::*;
use crate::ssh_config::model::HostEntry;
use std::path::PathBuf;
/// Builds an in-memory `SshConfigFile` from `content`.
///
/// Elements come from `SshConfigFile::parse_content`, the path is the fixed
/// placeholder `/tmp/test_config`, and `crlf` is set when `content` contains
/// a `"\r\n"` sequence.
fn parse_str(content: &str) -> SshConfigFile {
SshConfigFile {
elements: SshConfigFile::parse_content(content),
path: PathBuf::from("/tmp/test_config"),
crlf: content.contains("\r\n"),
}
}
/// Parsing a single host block and serializing it must reproduce the input exactly.
#[test]
fn test_round_trip_basic() {
let content = "\
Host myserver
HostName 192.168.1.10
User admin
Port 2222
";
let config = parse_str(content);
assert_eq!(config.serialize(), content);
}
/// Comment lines, both before and inside host blocks, must survive a
/// parse/serialize round trip unchanged.
#[test]
fn test_round_trip_with_comments() {
let content = "\
# My SSH config
# Generated by hand
Host alpha
HostName alpha.example.com
# Deploy user
User deploy
Host beta
HostName beta.example.com
User root
";
let config = parse_str(content);
assert_eq!(config.serialize(), content);
}
/// A wildcard `Host *` block followed by a regular host block must round-trip unchanged.
#[test]
fn test_round_trip_with_globals_and_wildcards() {
let content = "\
# Global settings
Host *
ServerAliveInterval 60
ServerAliveCountMax 3
Host production
HostName prod.example.com
User deployer
IdentityFile ~/.ssh/prod_key
";
let config = parse_str(content);
assert_eq!(config.serialize(), content);
}
/// After `add_host`, the serialized output must contain the new host's
/// alias, hostname and user.
#[test]
fn test_add_host_serializes() {
let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
config.add_host(&HostEntry {
alias: "newhost".to_string(),
hostname: "10.0.0.2".to_string(),
user: "admin".to_string(),
port: 22,
..Default::default()
});
let output = config.serialize();
assert!(output.contains("Host newhost"));
assert!(output.contains("HostName 10.0.0.2"));
assert!(output.contains("User admin"));
// NOTE(review): presumably port 22 is treated as the default and therefore
// not written out by `add_host` — confirm against its implementation.
assert!(!output.contains("Port 22"));
}
/// After `delete_host("alpha")`, the serialized output must no longer
/// contain that host, while the other host remains.
#[test]
fn test_delete_host_serializes() {
let content = "\
Host alpha
HostName alpha.example.com
Host beta
HostName beta.example.com
";
let mut config = parse_str(content);
config.delete_host("alpha");
let output = config.serialize();
assert!(!output.contains("Host alpha"));
assert!(output.contains("Host beta"));
}
/// After `update_host`, the serialized output must carry the new hostname
/// and user, and the old user value must be gone.
#[test]
fn test_update_host_serializes() {
let content = "\
Host myserver
HostName 10.0.0.1
User old_user
";
let mut config = parse_str(content);
config.update_host(
"myserver",
&HostEntry {
alias: "myserver".to_string(),
hostname: "10.0.0.2".to_string(),
user: "new_user".to_string(),
port: 22,
..Default::default()
},
);
let output = config.serialize();
assert!(output.contains("HostName 10.0.0.2"));
assert!(output.contains("User new_user"));
assert!(!output.contains("old_user"));
}
/// Directives that the `HostEntry` passed to `update_host` does not set
/// (`ForwardAgent`, `LocalForward`, `Compression`) must still be present
/// in the output after the update.
#[test]
fn test_update_host_preserves_unknown_directives() {
let content = "\
Host myserver
HostName 10.0.0.1
User admin
ForwardAgent yes
LocalForward 8080 localhost:80
Compression yes
";
let mut config = parse_str(content);
config.update_host(
"myserver",
&HostEntry {
alias: "myserver".to_string(),
hostname: "10.0.0.2".to_string(),
user: "admin".to_string(),
port: 22,
..Default::default()
},
);
let output = config.serialize();
assert!(output.contains("HostName 10.0.0.2"));
assert!(output.contains("ForwardAgent yes"));
assert!(output.contains("LocalForward 8080 localhost:80"));
assert!(output.contains("Compression yes"));
}
}