Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11    /// Write the config back to disk.
12    /// Creates a backup before writing and uses atomic write (temp file + rename).
13    /// Resolves symlinks so the rename targets the real file, not the link.
14    /// Acquires an advisory lock to prevent concurrent writes from multiple
15    /// purple processes or background sync threads.
16    pub fn write(&self) -> Result<()> {
17        if crate::demo_flag::is_demo() {
18            debug!("[purple] ssh_config.write skipped (demo mode)");
19            return Ok(());
20        }
21        // Resolve symlinks so we write through to the real file
22        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
23        debug!(
24            "[config] ssh_config.write: target={} elements={}",
25            target_path.display(),
26            self.elements.len()
27        );
28
29        // Acquire advisory lock (blocks until available)
30        let _lock = fs_util::FileLock::acquire(&target_path)
31            .inspect_err(|e| {
32                debug!(
33                    "[config] ssh_config.write: lock acquire failed: {} ({})",
34                    target_path.display(),
35                    e
36                );
37            })
38            .context("Failed to acquire config lock")?;
39
40        // Create backup if the file exists, keep only last 5.
41        // Use the canonical `target_path` so backups land next to the real
42        // file rather than next to a symlink. Without this, a user with
43        // `~/.ssh/config -> /Volumes/dotfiles/ssh_config` would see backups
44        // accumulate in `~/.ssh/` even though the actual file lives on the
45        // dotfiles volume — recovery becomes confusing.
46        if target_path.exists() {
47            self.create_backup(&target_path)
48                .context("Failed to create backup of SSH config")?;
49            self.prune_backups(&target_path, 5).ok();
50        }
51
52        let content = self.serialize();
53
54        fs_util::atomic_write(&target_path, content.as_bytes())
55            .map_err(|err| {
56                error!(
57                    "[purple] SSH config write failed: {}: {err}",
58                    target_path.display()
59                );
60                err
61            })
62            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
63
64        debug!(
65            "[config] ssh_config.write: wrote {} bytes to {}",
66            content.len(),
67            target_path.display()
68        );
69
70        // Lock released on drop
71        Ok(())
72    }
73
74    /// Serialize the config to a string.
75    /// Collapses consecutive blank lines to prevent accumulation after deletions.
76    pub fn serialize(&self) -> String {
77        let mut lines = Vec::new();
78
79        for element in &self.elements {
80            match element {
81                ConfigElement::GlobalLine(line) => {
82                    lines.push(line.clone());
83                }
84                ConfigElement::HostBlock(block) => {
85                    lines.push(block.raw_host_line.clone());
86                    for directive in &block.directives {
87                        lines.push(directive.raw_line.clone());
88                    }
89                }
90                ConfigElement::Include(include) => {
91                    lines.push(include.raw_line.clone());
92                }
93            }
94        }
95
96        // Collapse consecutive blank lines (keep at most one)
97        let mut collapsed = Vec::with_capacity(lines.len());
98        let mut prev_blank = false;
99        for line in lines {
100            let is_blank = line.trim().is_empty();
101            if is_blank && prev_blank {
102                continue;
103            }
104            prev_blank = is_blank;
105            collapsed.push(line);
106        }
107
108        let line_ending = if self.crlf { "\r\n" } else { "\n" };
109        let mut result = String::new();
110        // Restore UTF-8 BOM if the original file had one
111        if self.bom {
112            result.push('\u{FEFF}');
113        }
114        for line in &collapsed {
115            result.push_str(line);
116            result.push_str(line_ending);
117        }
118        // Ensure files always end with exactly one newline
119        // (check collapsed instead of result, since BOM makes result non-empty)
120        if collapsed.is_empty() {
121            result.push_str(line_ending);
122        }
123        result
124    }
125
126    /// Create a timestamped backup of the current config file.
127    /// Backup files are created with chmod 600 to match the source file's sensitivity.
128    fn create_backup(&self, target_path: &std::path::Path) -> Result<()> {
129        let timestamp = SystemTime::now()
130            .duration_since(SystemTime::UNIX_EPOCH)
131            .unwrap_or_default()
132            .as_millis();
133        let backup_name = format!(
134            "{}.bak.{}",
135            target_path
136                .file_name()
137                .unwrap_or_default()
138                .to_string_lossy(),
139            timestamp
140        );
141        let backup_path = target_path.with_file_name(backup_name);
142        fs::copy(target_path, &backup_path).with_context(|| {
143            format!(
144                "Failed to copy {} to {}",
145                target_path.display(),
146                backup_path.display()
147            )
148        })?;
149
150        // Set backup permissions to 600 (owner read/write only)
151        #[cfg(unix)]
152        {
153            use std::os::unix::fs::PermissionsExt;
154            if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
155                debug!(
156                    "[config] Failed to set backup permissions on {}: {e}",
157                    backup_path.display()
158                );
159            }
160        }
161
162        Ok(())
163    }
164
165    /// Remove old backups, keeping only the most recent `keep` files.
166    fn prune_backups(&self, target_path: &std::path::Path, keep: usize) -> Result<()> {
167        let parent = target_path.parent().context("No parent directory")?;
168        let prefix = format!(
169            "{}.bak.",
170            target_path
171                .file_name()
172                .unwrap_or_default()
173                .to_string_lossy()
174        );
175        let mut backups: Vec<_> = fs::read_dir(parent)?
176            .filter_map(|e| e.ok())
177            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
178            .collect();
179        // Sort by mtime so prune is robust against future timestamp-digit-width
180        // changes. Filename sort would silently break if the millisecond
181        // suffix length ever grew.
182        backups.sort_by_key(|e| {
183            e.metadata()
184                .and_then(|m| m.modified())
185                .unwrap_or(SystemTime::UNIX_EPOCH)
186        });
187        if backups.len() > keep {
188            for old in &backups[..backups.len() - keep] {
189                if let Err(e) = fs::remove_file(old.path()) {
190                    debug!(
191                        "[config] Failed to prune old backup {}: {e}",
192                        old.path().display()
193                    );
194                }
195            }
196        }
197        Ok(())
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;

    /// Path stored in in-memory test configs.
    ///
    /// None of these tests write to disk, so the path only needs to be
    /// well-formed. Its parent directory is never created, so no temporary
    /// directory is left behind on each test run.
    fn unused_config_path() -> std::path::PathBuf {
        std::env::temp_dir()
            .join("purple-ssh-writer-tests-unused")
            .join("test_config")
    }

    /// Parse `content` into a config, detecting CRLF and assuming no BOM.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: unused_config_path(),
            crlf: crate::ssh_config::parser::detect_crlf_majority(content),
            bom: false,
        }
    }

    /// Build a config of plain global lines without going through the parser.
    /// This tests the serializer's own rules (blank-line collapsing, line
    /// endings, BOM) in isolation.
    fn from_global_lines(lines: &[&str], crlf: bool, bom: bool) -> SshConfigFile {
        SshConfigFile {
            elements: lines
                .iter()
                .map(|line| ConfigElement::GlobalLine((*line).to_string()))
                .collect(),
            path: unused_config_path(),
            crlf,
            bom,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_serialize_collapses_consecutive_blank_lines() {
        // Whitespace-only lines count as blank; only the first of a run is kept.
        let config = from_global_lines(&["# a", "", "   ", "", "# b"], false, false);
        assert_eq!(config.serialize(), "# a\n\n# b\n");
    }

    #[test]
    fn test_serialize_empty_config_is_single_newline() {
        let config = from_global_lines(&[], false, false);
        assert_eq!(config.serialize(), "\n");
    }

    #[test]
    fn test_serialize_restores_bom_and_crlf() {
        let config = from_global_lines(&["Host a", "  User root"], true, true);
        assert_eq!(config.serialize(), "\u{FEFF}Host a\r\n  User root\r\n");
    }

    #[test]
    fn test_serialize_empty_config_with_bom_still_ends_with_newline() {
        let config = from_global_lines(&[], true, true);
        assert_eq!(config.serialize(), "\u{FEFF}\r\n");
    }
}