Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11    /// Write the config back to disk.
12    /// Creates a backup before writing and uses atomic write (temp file + rename).
13    /// Resolves symlinks so the rename targets the real file, not the link.
14    /// Acquires an advisory lock to prevent concurrent writes from multiple
15    /// purple processes or background sync threads.
16    pub fn write(&self) -> Result<()> {
17        if crate::demo_flag::is_demo() {
18            return Ok(());
19        }
20        // Resolve symlinks so we write through to the real file
21        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
22
23        // Acquire advisory lock (blocks until available)
24        let _lock =
25            fs_util::FileLock::acquire(&target_path).context("Failed to acquire config lock")?;
26
27        // Create backup if the file exists, keep only last 5
28        if self.path.exists() {
29            self.create_backup()
30                .context("Failed to create backup of SSH config")?;
31            self.prune_backups(5).ok();
32        }
33
34        let content = self.serialize();
35
36        fs_util::atomic_write(&target_path, content.as_bytes())
37            .map_err(|err| {
38                error!(
39                    "[purple] SSH config write failed: {}: {err}",
40                    target_path.display()
41                );
42                err
43            })
44            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
45
46        // Lock released on drop
47        Ok(())
48    }
49
50    /// Serialize the config to a string.
51    /// Collapses consecutive blank lines to prevent accumulation after deletions.
52    pub fn serialize(&self) -> String {
53        let mut lines = Vec::new();
54
55        for element in &self.elements {
56            match element {
57                ConfigElement::GlobalLine(line) => {
58                    lines.push(line.clone());
59                }
60                ConfigElement::HostBlock(block) => {
61                    lines.push(block.raw_host_line.clone());
62                    for directive in &block.directives {
63                        lines.push(directive.raw_line.clone());
64                    }
65                }
66                ConfigElement::Include(include) => {
67                    lines.push(include.raw_line.clone());
68                }
69            }
70        }
71
72        // Collapse consecutive blank lines (keep at most one)
73        let mut collapsed = Vec::with_capacity(lines.len());
74        let mut prev_blank = false;
75        for line in lines {
76            let is_blank = line.trim().is_empty();
77            if is_blank && prev_blank {
78                continue;
79            }
80            prev_blank = is_blank;
81            collapsed.push(line);
82        }
83
84        let line_ending = if self.crlf { "\r\n" } else { "\n" };
85        let mut result = String::new();
86        // Restore UTF-8 BOM if the original file had one
87        if self.bom {
88            result.push('\u{FEFF}');
89        }
90        for line in &collapsed {
91            result.push_str(line);
92            result.push_str(line_ending);
93        }
94        // Ensure files always end with exactly one newline
95        // (check collapsed instead of result, since BOM makes result non-empty)
96        if collapsed.is_empty() {
97            result.push_str(line_ending);
98        }
99        result
100    }
101
102    /// Create a timestamped backup of the current config file.
103    /// Backup files are created with chmod 600 to match the source file's sensitivity.
104    fn create_backup(&self) -> Result<()> {
105        let timestamp = SystemTime::now()
106            .duration_since(SystemTime::UNIX_EPOCH)
107            .unwrap_or_default()
108            .as_millis();
109        let backup_name = format!(
110            "{}.bak.{}",
111            self.path.file_name().unwrap_or_default().to_string_lossy(),
112            timestamp
113        );
114        let backup_path = self.path.with_file_name(backup_name);
115        fs::copy(&self.path, &backup_path).with_context(|| {
116            format!(
117                "Failed to copy {} to {}",
118                self.path.display(),
119                backup_path.display()
120            )
121        })?;
122
123        // Set backup permissions to 600 (owner read/write only)
124        #[cfg(unix)]
125        {
126            use std::os::unix::fs::PermissionsExt;
127            if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
128                debug!(
129                    "[config] Failed to set backup permissions on {}: {e}",
130                    backup_path.display()
131                );
132            }
133        }
134
135        Ok(())
136    }
137
138    /// Remove old backups, keeping only the most recent `keep` files.
139    fn prune_backups(&self, keep: usize) -> Result<()> {
140        let parent = self.path.parent().context("No parent directory")?;
141        let prefix = format!(
142            "{}.bak.",
143            self.path.file_name().unwrap_or_default().to_string_lossy()
144        );
145        let mut backups: Vec<_> = fs::read_dir(parent)?
146            .filter_map(|e| e.ok())
147            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
148            .collect();
149        backups.sort_by_key(|e| e.file_name());
150        if backups.len() > keep {
151            for old in &backups[..backups.len() - keep] {
152                if let Err(e) = fs::remove_file(old.path()) {
153                    debug!(
154                        "[config] Failed to prune old backup {}: {e}",
155                        old.path().display()
156                    );
157                }
158            }
159        }
160        Ok(())
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;

    /// Build an in-memory `SshConfigFile` from `content` for serialization tests.
    ///
    /// These tests never write to disk, so `path` is only a nominal location.
    /// Creating and `keep()`-ing a real temp directory here would leak one
    /// directory per call on every test run.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: std::env::temp_dir().join("purple_writer_test_config"),
            crlf: content.contains("\r\n"),
            bom: false,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_empty_config_serializes_to_single_newline() {
        let config = parse_str("");
        assert_eq!(config.serialize(), "\n");
    }

    #[test]
    fn test_consecutive_blank_lines_are_collapsed() {
        let content = "\
Host alpha
  HostName alpha.example.com



Host beta
  HostName beta.example.com
";
        let config = parse_str(content);
        assert_eq!(
            config.serialize(),
            "Host alpha\n  HostName alpha.example.com\n\nHost beta\n  HostName beta.example.com\n"
        );
    }

    #[test]
    fn test_bom_is_restored() {
        let mut config = parse_str("Host alpha\n  HostName alpha.example.com\n");
        config.bom = true;
        assert_eq!(
            config.serialize(),
            "\u{FEFF}Host alpha\n  HostName alpha.example.com\n"
        );
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }
}