Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9    /// Write the config back to disk.
10    /// Creates a backup before writing and uses atomic write (temp file + rename).
11    /// Resolves symlinks so the rename targets the real file, not the link.
12    pub fn write(&self) -> Result<()> {
13        // Resolve symlinks so we write through to the real file
14        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
15
16        // Create backup if the file exists, keep only last 5
17        if self.path.exists() {
18            self.create_backup()
19                .context("Failed to create backup of SSH config")?;
20            self.prune_backups(5).ok();
21        }
22
23        let content = self.serialize();
24
25        // Ensure parent directory exists
26        if let Some(parent) = target_path.parent() {
27            fs::create_dir_all(parent)
28                .with_context(|| format!("Failed to create directory {}", parent.display()))?;
29        }
30
31        // Atomic write: write to temp file (created with 0o600), then rename
32        let tmp_path =
33            target_path.with_extension(format!("purple_tmp.{}", std::process::id()));
34
35        #[cfg(unix)]
36        {
37            use std::io::Write;
38            use std::os::unix::fs::OpenOptionsExt;
39            let mut file = fs::OpenOptions::new()
40                .write(true)
41                .create(true)
42                .truncate(true)
43                .mode(0o600)
44                .open(&tmp_path)
45                .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
46            file.write_all(content.as_bytes())
47                .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48        }
49
50        #[cfg(not(unix))]
51        fs::write(&tmp_path, &content)
52            .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
53
54        let result = fs::rename(&tmp_path, &target_path);
55        if result.is_err() {
56            let _ = fs::remove_file(&tmp_path);
57        }
58        result.with_context(|| {
59            format!(
60                "Failed to rename {} to {}",
61                tmp_path.display(),
62                target_path.display()
63            )
64        })?;
65
66        Ok(())
67    }
68
69    /// Serialize the config to a string.
70    /// Collapses consecutive blank lines to prevent accumulation after deletions.
71    pub fn serialize(&self) -> String {
72        let mut lines = Vec::new();
73
74        for element in &self.elements {
75            match element {
76                ConfigElement::GlobalLine(line) => {
77                    lines.push(line.clone());
78                }
79                ConfigElement::HostBlock(block) => {
80                    lines.push(block.raw_host_line.clone());
81                    for directive in &block.directives {
82                        lines.push(directive.raw_line.clone());
83                    }
84                }
85                ConfigElement::Include(include) => {
86                    lines.push(include.raw_line.clone());
87                }
88            }
89        }
90
91        // Collapse consecutive blank lines (keep at most one)
92        let mut collapsed = Vec::with_capacity(lines.len());
93        let mut prev_blank = false;
94        for line in lines {
95            let is_blank = line.trim().is_empty();
96            if is_blank && prev_blank {
97                continue;
98            }
99            prev_blank = is_blank;
100            collapsed.push(line);
101        }
102
103        let line_ending = if self.crlf { "\r\n" } else { "\n" };
104        let mut result = String::new();
105        for line in &collapsed {
106            result.push_str(line);
107            result.push_str(line_ending);
108        }
109        // Ensure non-empty files end with exactly one newline
110        if result.is_empty() {
111            result.push_str(line_ending);
112        }
113        result
114    }
115
116    /// Create a timestamped backup of the current config file.
117    fn create_backup(&self) -> Result<()> {
118        let timestamp = SystemTime::now()
119            .duration_since(SystemTime::UNIX_EPOCH)
120            .unwrap_or_default()
121            .as_millis();
122        let backup_name = format!(
123            "{}.bak.{}",
124            self.path.file_name().unwrap_or_default().to_string_lossy(),
125            timestamp
126        );
127        let backup_path = self.path.with_file_name(backup_name);
128        fs::copy(&self.path, &backup_path).with_context(|| {
129            format!(
130                "Failed to copy {} to {}",
131                self.path.display(),
132                backup_path.display()
133            )
134        })?;
135        Ok(())
136    }
137
138    /// Remove old backups, keeping only the most recent `keep` files.
139    fn prune_backups(&self, keep: usize) -> Result<()> {
140        let parent = self.path.parent().context("No parent directory")?;
141        let prefix = format!(
142            "{}.bak.",
143            self.path.file_name().unwrap_or_default().to_string_lossy()
144        );
145        let mut backups: Vec<_> = fs::read_dir(parent)?
146            .filter_map(|e| e.ok())
147            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
148            .collect();
149        backups.sort_by_key(|e| e.file_name());
150        if backups.len() > keep {
151            for old in &backups[..backups.len() - keep] {
152                let _ = fs::remove_file(old.path());
153            }
154        }
155        Ok(())
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content`. The path is only a placeholder:
    /// `serialize` never reads it.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
            crlf: content.contains("\r\n"),
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_serialize_collapses_consecutive_blank_lines() {
        let content = "\
Host alpha
  HostName alpha.example.com



Host beta
  HostName beta.example.com
";
        let config = parse_str(content);
        assert_eq!(
            config.serialize(),
            "Host alpha\n  HostName alpha.example.com\n\nHost beta\n  HostName beta.example.com\n"
        );
    }

    #[test]
    fn test_serialize_empty_config_is_single_newline() {
        let config = parse_str("");
        assert_eq!(config.serialize(), "\n");
    }

    #[test]
    fn test_serialize_uses_crlf_line_endings() {
        let mut config = parse_str("Host myserver\n  HostName 10.0.0.1\n");
        config.crlf = true;
        assert_eq!(config.serialize(), "Host myserver\r\n  HostName 10.0.0.1\r\n");
    }
}