Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11    /// Write the config back to disk.
12    /// Creates a backup before writing and uses atomic write (temp file + rename).
13    /// Resolves symlinks so the rename targets the real file, not the link.
14    /// Acquires an advisory lock to prevent concurrent writes from multiple
15    /// purple processes or background sync threads.
16    pub fn write(&self) -> Result<()> {
17        if crate::demo_flag::is_demo() {
18            return Ok(());
19        }
20        // Resolve symlinks so we write through to the real file
21        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
22
23        // Acquire advisory lock (blocks until available)
24        let _lock =
25            fs_util::FileLock::acquire(&target_path).context("Failed to acquire config lock")?;
26
27        // Create backup if the file exists, keep only last 5
28        if self.path.exists() {
29            self.create_backup()
30                .context("Failed to create backup of SSH config")?;
31            self.prune_backups(5).ok();
32        }
33
34        let content = self.serialize();
35
36        fs_util::atomic_write(&target_path, content.as_bytes())
37            .map_err(|err| {
38                error!(
39                    "[purple] SSH config write failed: {}: {err}",
40                    target_path.display()
41                );
42                err
43            })
44            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
45
46        // Lock released on drop
47        Ok(())
48    }
49
50    /// Serialize the config to a string.
51    /// Collapses consecutive blank lines to prevent accumulation after deletions.
52    pub fn serialize(&self) -> String {
53        let mut lines = Vec::new();
54
55        for element in &self.elements {
56            match element {
57                ConfigElement::GlobalLine(line) => {
58                    lines.push(line.clone());
59                }
60                ConfigElement::HostBlock(block) => {
61                    lines.push(block.raw_host_line.clone());
62                    for directive in &block.directives {
63                        lines.push(directive.raw_line.clone());
64                    }
65                }
66                ConfigElement::Include(include) => {
67                    lines.push(include.raw_line.clone());
68                }
69            }
70        }
71
72        // Collapse consecutive blank lines (keep at most one)
73        let mut collapsed = Vec::with_capacity(lines.len());
74        let mut prev_blank = false;
75        for line in lines {
76            let is_blank = line.trim().is_empty();
77            if is_blank && prev_blank {
78                continue;
79            }
80            prev_blank = is_blank;
81            collapsed.push(line);
82        }
83
84        let line_ending = if self.crlf { "\r\n" } else { "\n" };
85        let mut result = String::new();
86        // Restore UTF-8 BOM if the original file had one
87        if self.bom {
88            result.push('\u{FEFF}');
89        }
90        for line in &collapsed {
91            result.push_str(line);
92            result.push_str(line_ending);
93        }
94        // Ensure files always end with exactly one newline
95        // (check collapsed instead of result, since BOM makes result non-empty)
96        if collapsed.is_empty() {
97            result.push_str(line_ending);
98        }
99        result
100    }
101
102    /// Create a timestamped backup of the current config file.
103    /// Backup files are created with chmod 600 to match the source file's sensitivity.
104    fn create_backup(&self) -> Result<()> {
105        let timestamp = SystemTime::now()
106            .duration_since(SystemTime::UNIX_EPOCH)
107            .unwrap_or_default()
108            .as_millis();
109        let backup_name = format!(
110            "{}.bak.{}",
111            self.path.file_name().unwrap_or_default().to_string_lossy(),
112            timestamp
113        );
114        let backup_path = self.path.with_file_name(backup_name);
115        fs::copy(&self.path, &backup_path).with_context(|| {
116            format!(
117                "Failed to copy {} to {}",
118                self.path.display(),
119                backup_path.display()
120            )
121        })?;
122
123        // Set backup permissions to 600 (owner read/write only)
124        #[cfg(unix)]
125        {
126            use std::os::unix::fs::PermissionsExt;
127            if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
128                debug!(
129                    "[config] Failed to set backup permissions on {}: {e}",
130                    backup_path.display()
131                );
132            }
133        }
134
135        Ok(())
136    }
137
138    /// Remove old backups, keeping only the most recent `keep` files.
139    fn prune_backups(&self, keep: usize) -> Result<()> {
140        let parent = self.path.parent().context("No parent directory")?;
141        let prefix = format!(
142            "{}.bak.",
143            self.path.file_name().unwrap_or_default().to_string_lossy()
144        );
145        let mut backups: Vec<_> = fs::read_dir(parent)?
146            .filter_map(|e| e.ok())
147            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
148            .collect();
149        backups.sort_by_key(|e| e.file_name());
150        if backups.len() > keep {
151            for old in &backups[..backups.len() - keep] {
152                if let Err(e) = fs::remove_file(old.path()) {
153                    debug!(
154                        "[config] Failed to prune old backup {}: {e}",
155                        old.path().display()
156                    );
157                }
158            }
159        }
160        Ok(())
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content` without touching the filesystem.
    /// CRLF mode is detected from the content; no BOM is assumed.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
            crlf: content.contains("\r\n"),
            bom: false,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_serialize_collapses_consecutive_blank_lines() {
        // Runs of blank lines (e.g. left behind by deletions) shrink to a single one.
        let config = parse_str("Host alpha\n  HostName a\n\n\n\nHost beta\n  HostName b\n");
        assert_eq!(
            config.serialize(),
            "Host alpha\n  HostName a\n\nHost beta\n  HostName b\n"
        );
    }

    #[test]
    fn test_serialize_empty_config_is_single_newline() {
        let config = parse_str("");
        assert_eq!(config.serialize(), "\n");
    }

    #[test]
    fn test_serialize_restores_bom_and_crlf() {
        let mut config = parse_str("Host alpha\n  HostName a\n");
        config.bom = true;
        config.crlf = true;
        assert_eq!(
            config.serialize(),
            "\u{FEFF}Host alpha\r\n  HostName a\r\n"
        );
    }
}