//! SSH config writer (`purple_ssh/ssh_config/writer.rs`): serializes an
//! `SshConfigFile` back to text and writes it to disk with backups.

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::error;
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11    /// Write the config back to disk.
12    /// Creates a backup before writing and uses atomic write (temp file + rename).
13    /// Resolves symlinks so the rename targets the real file, not the link.
14    /// Acquires an advisory lock to prevent concurrent writes from multiple
15    /// purple processes or background sync threads.
16    pub fn write(&self) -> Result<()> {
17        if crate::demo_flag::is_demo() {
18            return Ok(());
19        }
20        // Resolve symlinks so we write through to the real file
21        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
22
23        // Acquire advisory lock (blocks until available)
24        let _lock =
25            fs_util::FileLock::acquire(&target_path).context("Failed to acquire config lock")?;
26
27        // Create backup if the file exists, keep only last 5
28        if self.path.exists() {
29            self.create_backup()
30                .context("Failed to create backup of SSH config")?;
31            self.prune_backups(5).ok();
32        }
33
34        let content = self.serialize();
35
36        fs_util::atomic_write(&target_path, content.as_bytes())
37            .map_err(|err| {
38                error!(
39                    "[purple] SSH config write failed: {}: {err}",
40                    target_path.display()
41                );
42                err
43            })
44            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
45
46        // Lock released on drop
47        Ok(())
48    }
49
50    /// Serialize the config to a string.
51    /// Collapses consecutive blank lines to prevent accumulation after deletions.
52    pub fn serialize(&self) -> String {
53        let mut lines = Vec::new();
54
55        for element in &self.elements {
56            match element {
57                ConfigElement::GlobalLine(line) => {
58                    lines.push(line.clone());
59                }
60                ConfigElement::HostBlock(block) => {
61                    lines.push(block.raw_host_line.clone());
62                    for directive in &block.directives {
63                        lines.push(directive.raw_line.clone());
64                    }
65                }
66                ConfigElement::Include(include) => {
67                    lines.push(include.raw_line.clone());
68                }
69            }
70        }
71
72        // Collapse consecutive blank lines (keep at most one)
73        let mut collapsed = Vec::with_capacity(lines.len());
74        let mut prev_blank = false;
75        for line in lines {
76            let is_blank = line.trim().is_empty();
77            if is_blank && prev_blank {
78                continue;
79            }
80            prev_blank = is_blank;
81            collapsed.push(line);
82        }
83
84        let line_ending = if self.crlf { "\r\n" } else { "\n" };
85        let mut result = String::new();
86        // Restore UTF-8 BOM if the original file had one
87        if self.bom {
88            result.push('\u{FEFF}');
89        }
90        for line in &collapsed {
91            result.push_str(line);
92            result.push_str(line_ending);
93        }
94        // Ensure files always end with exactly one newline
95        // (check collapsed instead of result, since BOM makes result non-empty)
96        if collapsed.is_empty() {
97            result.push_str(line_ending);
98        }
99        result
100    }
101
102    /// Create a timestamped backup of the current config file.
103    /// Backup files are created with chmod 600 to match the source file's sensitivity.
104    fn create_backup(&self) -> Result<()> {
105        let timestamp = SystemTime::now()
106            .duration_since(SystemTime::UNIX_EPOCH)
107            .unwrap_or_default()
108            .as_millis();
109        let backup_name = format!(
110            "{}.bak.{}",
111            self.path.file_name().unwrap_or_default().to_string_lossy(),
112            timestamp
113        );
114        let backup_path = self.path.with_file_name(backup_name);
115        fs::copy(&self.path, &backup_path).with_context(|| {
116            format!(
117                "Failed to copy {} to {}",
118                self.path.display(),
119                backup_path.display()
120            )
121        })?;
122
123        // Set backup permissions to 600 (owner read/write only)
124        #[cfg(unix)]
125        {
126            use std::os::unix::fs::PermissionsExt;
127            let _ = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600));
128        }
129
130        Ok(())
131    }
132
133    /// Remove old backups, keeping only the most recent `keep` files.
134    fn prune_backups(&self, keep: usize) -> Result<()> {
135        let parent = self.path.parent().context("No parent directory")?;
136        let prefix = format!(
137            "{}.bak.",
138            self.path.file_name().unwrap_or_default().to_string_lossy()
139        );
140        let mut backups: Vec<_> = fs::read_dir(parent)?
141            .filter_map(|e| e.ok())
142            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
143            .collect();
144        backups.sort_by_key(|e| e.file_name());
145        if backups.len() > keep {
146            for old in &backups[..backups.len() - keep] {
147                let _ = fs::remove_file(old.path());
148            }
149        }
150        Ok(())
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
            crlf: content.contains("\r\n"),
            bom: false,
        }
    }

    /// Build a config directly from raw global lines, bypassing the parser, so
    /// serializer behavior can be tested in isolation.
    fn from_global_lines(lines: &[&str], crlf: bool, bom: bool) -> SshConfigFile {
        SshConfigFile {
            elements: lines
                .iter()
                .map(|line| ConfigElement::GlobalLine((*line).into()))
                .collect(),
            path: PathBuf::from("/tmp/test_config"),
            crlf,
            bom,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_serialize_empty_config_emits_single_newline() {
        assert_eq!(from_global_lines(&[], false, false).serialize(), "\n");
    }

    #[test]
    fn test_serialize_empty_config_keeps_bom_and_newline() {
        assert_eq!(from_global_lines(&[], false, true).serialize(), "\u{FEFF}\n");
    }

    #[test]
    fn test_serialize_prepends_bom() {
        assert_eq!(
            from_global_lines(&["# a"], false, true).serialize(),
            "\u{FEFF}# a\n"
        );
    }

    #[test]
    fn test_serialize_uses_crlf_line_endings() {
        assert_eq!(
            from_global_lines(&["# a", "# b"], true, false).serialize(),
            "# a\r\n# b\r\n"
        );
    }

    #[test]
    fn test_serialize_collapses_consecutive_blank_lines() {
        // Whitespace-only lines count as blank; only the first of a run is kept.
        let config = from_global_lines(&["# a", "", "   ", "", "# b"], false, false);
        assert_eq!(config.serialize(), "# a\n\n# b\n");
    }
}