Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7use crate::fs_util;
8
9impl SshConfigFile {
10    /// Write the config back to disk.
11    /// Creates a backup before writing and uses atomic write (temp file + rename).
12    /// Resolves symlinks so the rename targets the real file, not the link.
13    /// Acquires an advisory lock to prevent concurrent writes from multiple
14    /// purple processes or background sync threads.
15    pub fn write(&self) -> Result<()> {
16        // Resolve symlinks so we write through to the real file
17        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
18
19        // Acquire advisory lock (blocks until available)
20        let _lock =
21            fs_util::FileLock::acquire(&target_path).context("Failed to acquire config lock")?;
22
23        // Create backup if the file exists, keep only last 5
24        if self.path.exists() {
25            self.create_backup()
26                .context("Failed to create backup of SSH config")?;
27            self.prune_backups(5).ok();
28        }
29
30        let content = self.serialize();
31
32        fs_util::atomic_write(&target_path, content.as_bytes())
33            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
34
35        // Lock released on drop
36        Ok(())
37    }
38
39    /// Serialize the config to a string.
40    /// Collapses consecutive blank lines to prevent accumulation after deletions.
41    pub fn serialize(&self) -> String {
42        let mut lines = Vec::new();
43
44        for element in &self.elements {
45            match element {
46                ConfigElement::GlobalLine(line) => {
47                    lines.push(line.clone());
48                }
49                ConfigElement::HostBlock(block) => {
50                    lines.push(block.raw_host_line.clone());
51                    for directive in &block.directives {
52                        lines.push(directive.raw_line.clone());
53                    }
54                }
55                ConfigElement::Include(include) => {
56                    lines.push(include.raw_line.clone());
57                }
58            }
59        }
60
61        // Collapse consecutive blank lines (keep at most one)
62        let mut collapsed = Vec::with_capacity(lines.len());
63        let mut prev_blank = false;
64        for line in lines {
65            let is_blank = line.trim().is_empty();
66            if is_blank && prev_blank {
67                continue;
68            }
69            prev_blank = is_blank;
70            collapsed.push(line);
71        }
72
73        let line_ending = if self.crlf { "\r\n" } else { "\n" };
74        let mut result = String::new();
75        // Restore UTF-8 BOM if the original file had one
76        if self.bom {
77            result.push('\u{FEFF}');
78        }
79        for line in &collapsed {
80            result.push_str(line);
81            result.push_str(line_ending);
82        }
83        // Ensure files always end with exactly one newline
84        // (check collapsed instead of result, since BOM makes result non-empty)
85        if collapsed.is_empty() {
86            result.push_str(line_ending);
87        }
88        result
89    }
90
91    /// Create a timestamped backup of the current config file.
92    /// Backup files are created with chmod 600 to match the source file's sensitivity.
93    fn create_backup(&self) -> Result<()> {
94        let timestamp = SystemTime::now()
95            .duration_since(SystemTime::UNIX_EPOCH)
96            .unwrap_or_default()
97            .as_millis();
98        let backup_name = format!(
99            "{}.bak.{}",
100            self.path.file_name().unwrap_or_default().to_string_lossy(),
101            timestamp
102        );
103        let backup_path = self.path.with_file_name(backup_name);
104        fs::copy(&self.path, &backup_path).with_context(|| {
105            format!(
106                "Failed to copy {} to {}",
107                self.path.display(),
108                backup_path.display()
109            )
110        })?;
111
112        // Set backup permissions to 600 (owner read/write only)
113        #[cfg(unix)]
114        {
115            use std::os::unix::fs::PermissionsExt;
116            let _ = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600));
117        }
118
119        Ok(())
120    }
121
122    /// Remove old backups, keeping only the most recent `keep` files.
123    fn prune_backups(&self, keep: usize) -> Result<()> {
124        let parent = self.path.parent().context("No parent directory")?;
125        let prefix = format!(
126            "{}.bak.",
127            self.path.file_name().unwrap_or_default().to_string_lossy()
128        );
129        let mut backups: Vec<_> = fs::read_dir(parent)?
130            .filter_map(|e| e.ok())
131            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
132            .collect();
133        backups.sort_by_key(|e| e.file_name());
134        if backups.len() > keep {
135            for old in &backups[..backups.len() - keep] {
136                let _ = fs::remove_file(old.path());
137            }
138        }
139        Ok(())
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content`; CRLF mode is inferred from
    /// the text and the BOM flag is always off.
    fn parse_str(content: &str) -> SshConfigFile {
        let crlf = content.contains("\r\n");
        let elements = SshConfigFile::parse_content(content);
        SshConfigFile {
            elements,
            path: PathBuf::from("/tmp/test_config"),
            crlf,
            bom: false,
        }
    }

    /// Host entry on the default port 22 with every other field defaulted.
    fn host_entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    /// Parsing `content` and serializing it again must reproduce it exactly.
    fn assert_round_trip(content: &str) {
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trip(concat!(
            "Host myserver\n",
            "  HostName 192.168.1.10\n",
            "  User admin\n",
            "  Port 2222\n",
        ));
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trip(concat!(
            "# My SSH config\n",
            "# Generated by hand\n",
            "\n",
            "Host alpha\n",
            "  HostName alpha.example.com\n",
            "  # Deploy user\n",
            "  User deploy\n",
            "\n",
            "Host beta\n",
            "  HostName beta.example.com\n",
            "  User root\n",
        ));
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trip(concat!(
            "# Global settings\n",
            "Host *\n",
            "  ServerAliveInterval 60\n",
            "  ServerAliveCountMax 3\n",
            "\n",
            "Host production\n",
            "  HostName prod.example.com\n",
            "  User deployer\n",
            "  IdentityFile ~/.ssh/prod_key\n",
        ));
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        let new_host = host_entry("newhost", "10.0.0.2", "admin");
        config.add_host(&new_host);

        let output = config.serialize();
        for needle in &["Host newhost", "HostName 10.0.0.2", "User admin"] {
            assert!(output.contains(*needle));
        }
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut config = parse_str(concat!(
            "Host alpha\n",
            "  HostName alpha.example.com\n",
            "\n",
            "Host beta\n",
            "  HostName beta.example.com\n",
        ));
        config.delete_host("alpha");

        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut config = parse_str(concat!(
            "Host myserver\n",
            "  HostName 10.0.0.1\n",
            "  User old_user\n",
        ));
        let replacement = host_entry("myserver", "10.0.0.2", "new_user");
        config.update_host("myserver", &replacement);

        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut config = parse_str(concat!(
            "Host myserver\n",
            "  HostName 10.0.0.1\n",
            "  User admin\n",
            "  ForwardAgent yes\n",
            "  LocalForward 8080 localhost:80\n",
            "  Compression yes\n",
        ));
        let replacement = host_entry("myserver", "10.0.0.2", "admin");
        config.update_host("myserver", &replacement);

        let output = config.serialize();
        // The changed directive plus every directive the update does not manage.
        let expected = [
            "HostName 10.0.0.2",
            "ForwardAgent yes",
            "LocalForward 8080 localhost:80",
            "Compression yes",
        ];
        for needle in &expected {
            assert!(output.contains(*needle));
        }
    }
}
291}