Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11    /// Write the config back to disk.
12    /// Creates a backup before writing and uses atomic write (temp file + rename).
13    /// Resolves symlinks so the rename targets the real file, not the link.
14    /// Acquires an advisory lock to prevent concurrent writes from multiple
15    /// purple processes or background sync threads.
16    pub fn write(&self) -> Result<()> {
17        if crate::demo_flag::is_demo() {
18            debug!("[purple] ssh_config.write skipped (demo mode)");
19            return Ok(());
20        }
21        // Resolve symlinks so we write through to the real file
22        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
23        debug!(
24            "[config] ssh_config.write: target={} elements={}",
25            target_path.display(),
26            self.elements.len()
27        );
28
29        // Acquire advisory lock (blocks until available)
30        let _lock = fs_util::FileLock::acquire(&target_path)
31            .inspect_err(|e| {
32                debug!(
33                    "[config] ssh_config.write: lock acquire failed: {} ({})",
34                    target_path.display(),
35                    e
36                );
37            })
38            .context("Failed to acquire config lock")?;
39
40        // Create backup if the file exists, keep only last 5.
41        // Use the canonical `target_path` so backups land next to the real
42        // file rather than next to a symlink. Without this, a user with
43        // `~/.ssh/config -> /Volumes/dotfiles/ssh_config` would see backups
44        // accumulate in `~/.ssh/` even though the actual file lives on the
45        // dotfiles volume. Recovery becomes confusing.
46        if target_path.exists() {
47            self.create_backup(&target_path)
48                .context("Failed to create backup of SSH config")?;
49            self.prune_backups(&target_path, 5).ok();
50        }
51
52        // Persist with one blank line between every top-level block. serialize()
53        // stays byte-for-byte round-trip faithful for undo and comparison; only
54        // the on-disk file is normalized, healing configs whose blocks were
55        // glued together before purple managed them.
56        let raw = self.serialized_lines();
57        let normalized = ensure_block_separators(&raw);
58        let healed = normalized.len() - raw.len();
59        if healed > 0 {
60            debug!(
61                "[config] ssh_config.write: inserted {healed} block separator(s) in {}",
62                target_path.display()
63            );
64        }
65        let content = self.lines_to_string(&normalized);
66
67        fs_util::atomic_write(&target_path, content.as_bytes())
68            .map_err(|err| {
69                error!(
70                    "[purple] SSH config write failed: {}: {err}",
71                    target_path.display()
72                );
73                err
74            })
75            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
76
77        debug!(
78            "[config] ssh_config.write: wrote {} bytes to {}",
79            content.len(),
80            target_path.display()
81        );
82
83        // Lock released on drop
84        Ok(())
85    }
86
87    /// Serialize the config to a string.
88    /// Collapses consecutive blank lines to prevent accumulation after deletions.
89    /// Round-trip faithful: blank-line layout is preserved exactly as parsed.
90    pub fn serialize(&self) -> String {
91        self.lines_to_string(&self.serialized_lines())
92    }
93
94    /// Flatten the element tree to content lines (no line endings), collapsing
95    /// runs of blank lines to at most one. Shared by `serialize` and `write`.
96    fn serialized_lines(&self) -> Vec<String> {
97        let mut lines = Vec::new();
98
99        for element in &self.elements {
100            match element {
101                ConfigElement::GlobalLine(line) => {
102                    lines.push(line.clone());
103                }
104                ConfigElement::HostBlock(block) => {
105                    lines.push(block.raw_host_line.clone());
106                    for directive in &block.directives {
107                        lines.push(directive.raw_line.clone());
108                    }
109                }
110                ConfigElement::Include(include) => {
111                    lines.push(include.raw_line.clone());
112                }
113            }
114        }
115
116        // Collapse consecutive blank lines (keep at most one)
117        let mut collapsed = Vec::with_capacity(lines.len());
118        let mut prev_blank = false;
119        for line in lines {
120            let is_blank = line.trim().is_empty();
121            if is_blank && prev_blank {
122                continue;
123            }
124            prev_blank = is_blank;
125            collapsed.push(line);
126        }
127        collapsed
128    }
129
130    /// Join content lines with the file's line ending, restoring the BOM and
131    /// guaranteeing exactly one trailing newline.
132    fn lines_to_string(&self, lines: &[String]) -> String {
133        let line_ending = if self.crlf { "\r\n" } else { "\n" };
134        let mut result = String::new();
135        // Restore UTF-8 BOM if the original file had one
136        if self.bom {
137            result.push('\u{FEFF}');
138        }
139        for line in lines {
140            result.push_str(line);
141            result.push_str(line_ending);
142        }
143        // Ensure files always end with exactly one newline
144        // (check lines instead of result, since BOM makes result non-empty)
145        if lines.is_empty() {
146            result.push_str(line_ending);
147        }
148        result
149    }
150
151    /// Create a timestamped backup of the current config file.
152    /// Backup files are created with chmod 600 to match the source file's sensitivity.
153    fn create_backup(&self, target_path: &std::path::Path) -> Result<()> {
154        let timestamp = SystemTime::now()
155            .duration_since(SystemTime::UNIX_EPOCH)
156            .unwrap_or_default()
157            .as_millis();
158        let backup_name = format!(
159            "{}.bak.{}",
160            target_path
161                .file_name()
162                .unwrap_or_default()
163                .to_string_lossy(),
164            timestamp
165        );
166        let backup_path = target_path.with_file_name(backup_name);
167        fs::copy(target_path, &backup_path).with_context(|| {
168            format!(
169                "Failed to copy {} to {}",
170                target_path.display(),
171                backup_path.display()
172            )
173        })?;
174
175        // Set backup permissions to 600 (owner read/write only)
176        #[cfg(unix)]
177        {
178            use std::os::unix::fs::PermissionsExt;
179            if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
180                debug!(
181                    "[config] Failed to set backup permissions on {}: {e}",
182                    backup_path.display()
183                );
184            }
185        }
186
187        Ok(())
188    }
189
190    /// Remove old backups, keeping only the most recent `keep` files.
191    fn prune_backups(&self, target_path: &std::path::Path, keep: usize) -> Result<()> {
192        let parent = target_path.parent().context("No parent directory")?;
193        let prefix = format!(
194            "{}.bak.",
195            target_path
196                .file_name()
197                .unwrap_or_default()
198                .to_string_lossy()
199        );
200        let mut backups: Vec<_> = fs::read_dir(parent)?
201            .filter_map(|e| e.ok())
202            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
203            .collect();
204        // Sort by mtime so prune is robust against future timestamp-digit-width
205        // changes. Filename sort would silently break if the millisecond
206        // suffix length ever grew.
207        backups.sort_by_key(|e| {
208            e.metadata()
209                .and_then(|m| m.modified())
210                .unwrap_or(SystemTime::UNIX_EPOCH)
211        });
212        if backups.len() > keep {
213            for old in &backups[..backups.len() - keep] {
214                if let Err(e) = fs::remove_file(old.path()) {
215                    debug!(
216                        "[config] Failed to prune old backup {}: {e}",
217                        old.path().display()
218                    );
219                }
220            }
221        }
222        Ok(())
223    }
224}
225
/// Reports whether `line` opens a top-level block.
///
/// The line must begin at column 0 (no leading whitespace of any kind). Its
/// first whitespace-separated word must be `Host` or `Match`, compared
/// ASCII-case-insensitively to agree with the parser's keyword detection.
fn is_block_start(line: &str) -> bool {
    if line.chars().next().is_some_and(char::is_whitespace) {
        return false;
    }
    let Some(keyword) = line.split_whitespace().next() else {
        return false;
    };
    ["Host", "Match"]
        .iter()
        .any(|candidate| keyword.eq_ignore_ascii_case(candidate))
}
237
238/// Insert a single blank line before each top-level block that runs directly
239/// into the previous line, so persisted configs never have glued-together Host
240/// blocks. A block kept glued to a preceding top-level comment (a group header
241/// or hand-written label) is left as-is. Operates on collapsed lines, so it can
242/// never create consecutive blanks.
243fn ensure_block_separators(lines: &[String]) -> Vec<String> {
244    let mut out: Vec<String> = Vec::with_capacity(lines.len() + 4);
245    for line in lines {
246        if is_block_start(line) {
247            if let Some(prev) = out.last() {
248                let prev_blank = prev.trim().is_empty();
249                let prev_top_level_comment =
250                    !prev.starts_with(char::is_whitespace) && prev.trim_start().starts_with('#');
251                if !prev_blank && !prev_top_level_comment {
252                    out.push(String::new());
253                }
254            }
255        }
256        out.push(line.clone());
257    }
258    out
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;

    /// Build an in-memory `SshConfigFile` from `content` for serialization tests.
    ///
    /// `path` points at a file inside a fixed, never-created directory under
    /// the system temp dir. The tests using this helper only mutate and
    /// serialize in memory, so nothing needs to exist on disk. This avoids
    /// creating, and permanently leaking via `TempDir::keep`, a fresh temp
    /// directory on every call.
    /// NOTE(review): assumes add_host/delete_host/update_host never touch
    /// `path`. If a test ever needs to `write()`, give it a real
    /// `tempfile::tempdir()` kept alive for the test's duration.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: std::env::temp_dir()
                .join("purple-writer-tests-unwritten")
                .join("test_config"),
            crlf: crate::ssh_config::parser::detect_crlf_majority(content),
            bom: false,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn serialize_collapses_blank_line_runs() {
        // serialize() is round-trip faithful except that runs of blank lines
        // collapse to a single blank line.
        let config = parse_str("Host a\n  HostName 1\n\n\n\nHost b\n  HostName 2\n");
        let output = config.serialize();
        assert!(!output.contains("\n\n\n"), "blank run not collapsed:\n{output}");
        assert!(
            output.contains("HostName 1\n\nHost b"),
            "one blank line must remain:\n{output}"
        );
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    /// Convert string slices into owned lines for the separator tests.
    fn lines(s: &[&str]) -> Vec<String> {
        s.iter().map(|l| (*l).to_string()).collect()
    }

    #[test]
    fn is_block_start_requires_column_zero_host_or_match() {
        assert!(is_block_start("Host a"));
        assert!(is_block_start("match all"));
        assert!(!is_block_start("  Host a"));
        assert!(!is_block_start("\tMatch all"));
        assert!(!is_block_start("HostName a"));
        assert!(!is_block_start("# Host a"));
        assert!(!is_block_start(""));
    }

    #[test]
    fn ensure_block_separators_splits_glued_hosts() {
        let input = lines(&["Host a", "  HostName 1", "Host b", "  HostName 2"]);
        let out = ensure_block_separators(&input);
        assert_eq!(
            out,
            lines(&["Host a", "  HostName 1", "", "Host b", "  HostName 2"])
        );
    }

    #[test]
    fn ensure_block_separators_leaves_separated_hosts() {
        let input = lines(&["Host a", "  HostName 1", "", "Host b", "  HostName 2"]);
        let out = ensure_block_separators(&input);
        assert_eq!(out, input, "already-separated input must be untouched");
    }

    #[test]
    fn ensure_block_separators_keeps_group_header_glue() {
        // A top-level comment (group header) directly above a Host stays glued:
        // that layout is intentional, not the bug being healed.
        let input = lines(&["# purple:group DigitalOcean", "Host a", "  HostName 1"]);
        let out = ensure_block_separators(&input);
        assert_eq!(out, input);
    }

    #[test]
    fn ensure_block_separators_splits_three_glued_hosts() {
        let input = lines(&["Host a", "  HostName 1", "Host b", "  HostName 2", "Host c"]);
        let out = ensure_block_separators(&input);
        assert_eq!(
            out,
            lines(&[
                "Host a",
                "  HostName 1",
                "",
                "Host b",
                "  HostName 2",
                "",
                "Host c",
            ])
        );
    }

    #[test]
    fn ensure_block_separators_splits_glued_match_block() {
        let input = lines(&["Host a", "  HostName 1", "Match host b", "  User x"]);
        let out = ensure_block_separators(&input);
        assert_eq!(
            out,
            lines(&["Host a", "  HostName 1", "", "Match host b", "  User x"])
        );
    }

    #[test]
    fn ensure_block_separators_no_leading_blank() {
        let input = lines(&["Host a", "  HostName 1"]);
        let out = ensure_block_separators(&input);
        assert_eq!(out, input, "must not insert a blank before the first block");
    }

    #[test]
    fn write_normalization_is_idempotent() {
        // Writing a healed config and re-parsing it must produce the same bytes
        // on a second write. Mirrors the fuzz round-trip/idempotency invariant
        // for the glued-hosts mutation class.
        let glued = "Host a\n  HostName 1\nhost b\n  HostName 2\nMatch host c\n  User x\n";
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("config");
        let config = SshConfigFile {
            elements: SshConfigFile::parse_content(glued),
            path: path.clone(),
            crlf: false,
            bom: false,
        };
        config.write().unwrap();
        let first = std::fs::read_to_string(&path).unwrap();

        let reparsed = SshConfigFile {
            elements: SshConfigFile::parse_content(&first),
            path: path.clone(),
            crlf: false,
            bom: false,
        };
        reparsed.write().unwrap();
        let second = std::fs::read_to_string(&path).unwrap();
        assert_eq!(first, second, "write normalization must be idempotent");
        assert!(!first.contains("\n\n\n"), "no triple blanks:\n{first}");
    }

    #[test]
    fn ensure_block_separators_case_insensitive_keyword() {
        // SSH keywords are case-insensitive; lowercase `host`/`match` must heal
        // too, matching the parser's own case-insensitive detection.
        let input = lines(&["host a", "  HostName 1", "MATCH host b", "  User x"]);
        let out = ensure_block_separators(&input);
        assert_eq!(
            out,
            lines(&["host a", "  HostName 1", "", "MATCH host b", "  User x"])
        );
    }

    #[test]
    fn write_normalizes_glued_hosts_on_disk_serialize_stays_pure() {
        // serialize() must stay byte-for-byte round-trip faithful (glued stays
        // glued), but the persisted file gets a blank line between the blocks.
        let glued = "Host a\n  HostName 1.1.1.1\nHost b\n  HostName 2.2.2.2\n";
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("config");
        let config = SshConfigFile {
            elements: SshConfigFile::parse_content(glued),
            path: path.clone(),
            crlf: false,
            bom: false,
        };

        // serialize() is unchanged: still glued.
        assert_eq!(config.serialize(), glued);

        // write() normalizes: blank line between the two blocks on disk.
        config.write().unwrap();
        let on_disk = std::fs::read_to_string(&path).unwrap();
        assert_eq!(
            on_disk,
            "Host a\n  HostName 1.1.1.1\n\nHost b\n  HostName 2.2.2.2\n"
        );
    }

    #[test]
    fn write_normalizes_glued_hosts_preserves_crlf() {
        let glued = "Host a\r\n  HostName 1.1.1.1\r\nHost b\r\n  HostName 2.2.2.2\r\n";
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("config");
        let config = SshConfigFile {
            elements: SshConfigFile::parse_content(glued),
            path: path.clone(),
            crlf: true,
            bom: false,
        };
        config.write().unwrap();
        let on_disk = std::fs::read_to_string(&path).unwrap();
        assert_eq!(
            on_disk,
            "Host a\r\n  HostName 1.1.1.1\r\n\r\nHost b\r\n  HostName 2.2.2.2\r\n"
        );
    }
}