Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11    /// Write the config back to disk.
12    /// Creates a backup before writing and uses atomic write (temp file + rename).
13    /// Resolves symlinks so the rename targets the real file, not the link.
14    /// Acquires an advisory lock to prevent concurrent writes from multiple
15    /// purple processes or background sync threads.
16    pub fn write(&self) -> Result<()> {
17        if crate::demo_flag::is_demo() {
18            debug!("[purple] ssh_config.write skipped (demo mode)");
19            return Ok(());
20        }
21        // Resolve symlinks so we write through to the real file
22        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
23        debug!(
24            "[config] ssh_config.write: target={} elements={}",
25            target_path.display(),
26            self.elements.len()
27        );
28
29        // Acquire advisory lock (blocks until available)
30        let _lock = fs_util::FileLock::acquire(&target_path)
31            .inspect_err(|e| {
32                debug!(
33                    "[config] ssh_config.write: lock acquire failed: {} ({})",
34                    target_path.display(),
35                    e
36                );
37            })
38            .context("Failed to acquire config lock")?;
39
40        // Create backup if the file exists, keep only last 5
41        if self.path.exists() {
42            self.create_backup()
43                .context("Failed to create backup of SSH config")?;
44            self.prune_backups(5).ok();
45        }
46
47        let content = self.serialize();
48
49        fs_util::atomic_write(&target_path, content.as_bytes())
50            .map_err(|err| {
51                error!(
52                    "[purple] SSH config write failed: {}: {err}",
53                    target_path.display()
54                );
55                err
56            })
57            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
58
59        debug!(
60            "[config] ssh_config.write: wrote {} bytes to {}",
61            content.len(),
62            target_path.display()
63        );
64
65        // Lock released on drop
66        Ok(())
67    }
68
69    /// Serialize the config to a string.
70    /// Collapses consecutive blank lines to prevent accumulation after deletions.
71    pub fn serialize(&self) -> String {
72        let mut lines = Vec::new();
73
74        for element in &self.elements {
75            match element {
76                ConfigElement::GlobalLine(line) => {
77                    lines.push(line.clone());
78                }
79                ConfigElement::HostBlock(block) => {
80                    lines.push(block.raw_host_line.clone());
81                    for directive in &block.directives {
82                        lines.push(directive.raw_line.clone());
83                    }
84                }
85                ConfigElement::Include(include) => {
86                    lines.push(include.raw_line.clone());
87                }
88            }
89        }
90
91        // Collapse consecutive blank lines (keep at most one)
92        let mut collapsed = Vec::with_capacity(lines.len());
93        let mut prev_blank = false;
94        for line in lines {
95            let is_blank = line.trim().is_empty();
96            if is_blank && prev_blank {
97                continue;
98            }
99            prev_blank = is_blank;
100            collapsed.push(line);
101        }
102
103        let line_ending = if self.crlf { "\r\n" } else { "\n" };
104        let mut result = String::new();
105        // Restore UTF-8 BOM if the original file had one
106        if self.bom {
107            result.push('\u{FEFF}');
108        }
109        for line in &collapsed {
110            result.push_str(line);
111            result.push_str(line_ending);
112        }
113        // Ensure files always end with exactly one newline
114        // (check collapsed instead of result, since BOM makes result non-empty)
115        if collapsed.is_empty() {
116            result.push_str(line_ending);
117        }
118        result
119    }
120
121    /// Create a timestamped backup of the current config file.
122    /// Backup files are created with chmod 600 to match the source file's sensitivity.
123    fn create_backup(&self) -> Result<()> {
124        let timestamp = SystemTime::now()
125            .duration_since(SystemTime::UNIX_EPOCH)
126            .unwrap_or_default()
127            .as_millis();
128        let backup_name = format!(
129            "{}.bak.{}",
130            self.path.file_name().unwrap_or_default().to_string_lossy(),
131            timestamp
132        );
133        let backup_path = self.path.with_file_name(backup_name);
134        fs::copy(&self.path, &backup_path).with_context(|| {
135            format!(
136                "Failed to copy {} to {}",
137                self.path.display(),
138                backup_path.display()
139            )
140        })?;
141
142        // Set backup permissions to 600 (owner read/write only)
143        #[cfg(unix)]
144        {
145            use std::os::unix::fs::PermissionsExt;
146            if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
147                debug!(
148                    "[config] Failed to set backup permissions on {}: {e}",
149                    backup_path.display()
150                );
151            }
152        }
153
154        Ok(())
155    }
156
157    /// Remove old backups, keeping only the most recent `keep` files.
158    fn prune_backups(&self, keep: usize) -> Result<()> {
159        let parent = self.path.parent().context("No parent directory")?;
160        let prefix = format!(
161            "{}.bak.",
162            self.path.file_name().unwrap_or_default().to_string_lossy()
163        );
164        let mut backups: Vec<_> = fs::read_dir(parent)?
165            .filter_map(|e| e.ok())
166            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
167            .collect();
168        backups.sort_by_key(|e| e.file_name());
169        if backups.len() > keep {
170            for old in &backups[..backups.len() - keep] {
171                if let Err(e) = fs::remove_file(old.path()) {
172                    debug!(
173                        "[config] Failed to prune old backup {}: {e}",
174                        old.path().display()
175                    );
176                }
177            }
178        }
179        Ok(())
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;

    /// Path assigned to in-memory test configs. It is never created: none of
    /// these tests write to disk, so no temporary directory needs to exist
    /// (and none is left behind).
    fn unused_test_path() -> std::path::PathBuf {
        std::env::temp_dir()
            .join("purple-ssh-writer-tests")
            .join("test_config")
    }

    /// Parse `content` into an in-memory config, detecting CRLF line endings.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: unused_test_path(),
            crlf: content.contains("\r\n"),
            bom: false,
        }
    }

    /// Build a config directly from raw global lines, bypassing the parser so
    /// `serialize` is exercised on its own.
    fn from_global_lines(lines: &[&str], crlf: bool, bom: bool) -> SshConfigFile {
        SshConfigFile {
            elements: lines
                .iter()
                .map(|line| ConfigElement::GlobalLine(String::from(*line)))
                .collect(),
            path: unused_test_path(),
            crlf,
            bom,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_serialize_collapses_blank_runs() {
        // Whitespace-only lines count as blank; only the first of a run stays.
        let config = from_global_lines(&["# a", "", "   ", "", "# b"], false, false);
        assert_eq!(config.serialize(), "# a\n\n# b\n");
    }

    #[test]
    fn test_serialize_empty_config_is_single_line_ending() {
        assert_eq!(from_global_lines(&[], false, false).serialize(), "\n");
        assert_eq!(from_global_lines(&[], true, true).serialize(), "\u{FEFF}\r\n");
    }

    #[test]
    fn test_serialize_restores_crlf_and_bom() {
        let config = from_global_lines(&["Host a", "  User b"], true, true);
        assert_eq!(config.serialize(), "\u{FEFF}Host a\r\n  User b\r\n");
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }
}