Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7use crate::fs_util;
8
9impl SshConfigFile {
10    /// Write the config back to disk.
11    /// Creates a backup before writing and uses atomic write (temp file + rename).
12    /// Resolves symlinks so the rename targets the real file, not the link.
13    /// Acquires an advisory lock to prevent concurrent writes from multiple
14    /// purple processes or background sync threads.
15    pub fn write(&self) -> Result<()> {
16        if crate::demo_flag::is_demo() {
17            return Ok(());
18        }
19        // Resolve symlinks so we write through to the real file
20        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
21
22        // Acquire advisory lock (blocks until available)
23        let _lock =
24            fs_util::FileLock::acquire(&target_path).context("Failed to acquire config lock")?;
25
26        // Create backup if the file exists, keep only last 5
27        if self.path.exists() {
28            self.create_backup()
29                .context("Failed to create backup of SSH config")?;
30            self.prune_backups(5).ok();
31        }
32
33        let content = self.serialize();
34
35        fs_util::atomic_write(&target_path, content.as_bytes())
36            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
37
38        // Lock released on drop
39        Ok(())
40    }
41
42    /// Serialize the config to a string.
43    /// Collapses consecutive blank lines to prevent accumulation after deletions.
44    pub fn serialize(&self) -> String {
45        let mut lines = Vec::new();
46
47        for element in &self.elements {
48            match element {
49                ConfigElement::GlobalLine(line) => {
50                    lines.push(line.clone());
51                }
52                ConfigElement::HostBlock(block) => {
53                    lines.push(block.raw_host_line.clone());
54                    for directive in &block.directives {
55                        lines.push(directive.raw_line.clone());
56                    }
57                }
58                ConfigElement::Include(include) => {
59                    lines.push(include.raw_line.clone());
60                }
61            }
62        }
63
64        // Collapse consecutive blank lines (keep at most one)
65        let mut collapsed = Vec::with_capacity(lines.len());
66        let mut prev_blank = false;
67        for line in lines {
68            let is_blank = line.trim().is_empty();
69            if is_blank && prev_blank {
70                continue;
71            }
72            prev_blank = is_blank;
73            collapsed.push(line);
74        }
75
76        let line_ending = if self.crlf { "\r\n" } else { "\n" };
77        let mut result = String::new();
78        // Restore UTF-8 BOM if the original file had one
79        if self.bom {
80            result.push('\u{FEFF}');
81        }
82        for line in &collapsed {
83            result.push_str(line);
84            result.push_str(line_ending);
85        }
86        // Ensure files always end with exactly one newline
87        // (check collapsed instead of result, since BOM makes result non-empty)
88        if collapsed.is_empty() {
89            result.push_str(line_ending);
90        }
91        result
92    }
93
94    /// Create a timestamped backup of the current config file.
95    /// Backup files are created with chmod 600 to match the source file's sensitivity.
96    fn create_backup(&self) -> Result<()> {
97        let timestamp = SystemTime::now()
98            .duration_since(SystemTime::UNIX_EPOCH)
99            .unwrap_or_default()
100            .as_millis();
101        let backup_name = format!(
102            "{}.bak.{}",
103            self.path.file_name().unwrap_or_default().to_string_lossy(),
104            timestamp
105        );
106        let backup_path = self.path.with_file_name(backup_name);
107        fs::copy(&self.path, &backup_path).with_context(|| {
108            format!(
109                "Failed to copy {} to {}",
110                self.path.display(),
111                backup_path.display()
112            )
113        })?;
114
115        // Set backup permissions to 600 (owner read/write only)
116        #[cfg(unix)]
117        {
118            use std::os::unix::fs::PermissionsExt;
119            let _ = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600));
120        }
121
122        Ok(())
123    }
124
125    /// Remove old backups, keeping only the most recent `keep` files.
126    fn prune_backups(&self, keep: usize) -> Result<()> {
127        let parent = self.path.parent().context("No parent directory")?;
128        let prefix = format!(
129            "{}.bak.",
130            self.path.file_name().unwrap_or_default().to_string_lossy()
131        );
132        let mut backups: Vec<_> = fs::read_dir(parent)?
133            .filter_map(|e| e.ok())
134            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
135            .collect();
136        backups.sort_by_key(|e| e.file_name());
137        if backups.len() > keep {
138            for old in &backups[..backups.len() - keep] {
139                let _ = fs::remove_file(old.path());
140            }
141        }
142        Ok(())
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Parse `text` into an `SshConfigFile` whose path is the placeholder
    /// `/tmp/test_config`. CRLF mode is enabled when `text` contains "\r\n"; BOM is off.
    fn config_from(text: &str) -> SshConfigFile {
        let elements = SshConfigFile::parse_content(text);
        let crlf = text.contains("\r\n");
        SshConfigFile {
            path: PathBuf::from("/tmp/test_config"),
            elements,
            crlf,
            bom: false,
        }
    }

    /// Assert that parsing `text` and serializing it again reproduces `text` exactly.
    fn assert_round_trips(text: &str) {
        assert_eq!(config_from(text).serialize(), text);
    }

    /// A `HostEntry` on port 22 with every other field left at its default.
    fn host_entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.into(),
            hostname: hostname.into(),
            user: user.into(),
            port: 22,
            ..Default::default()
        }
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trips(concat!(
            "Host myserver\n",
            "  HostName 192.168.1.10\n",
            "  User admin\n",
            "  Port 2222\n",
        ));
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trips(concat!(
            "# My SSH config\n",
            "# Generated by hand\n",
            "\n",
            "Host alpha\n",
            "  HostName alpha.example.com\n",
            "  # Deploy user\n",
            "  User deploy\n",
            "\n",
            "Host beta\n",
            "  HostName beta.example.com\n",
            "  User root\n",
        ));
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trips(concat!(
            "# Global settings\n",
            "Host *\n",
            "  ServerAliveInterval 60\n",
            "  ServerAliveCountMax 3\n",
            "\n",
            "Host production\n",
            "  HostName prod.example.com\n",
            "  User deployer\n",
            "  IdentityFile ~/.ssh/prod_key\n",
        ));
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = config_from("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&host_entry("newhost", "10.0.0.2", "admin"));

        let rendered = config.serialize();
        for &needle in &["Host newhost", "HostName 10.0.0.2", "User admin"] {
            assert!(rendered.contains(needle));
        }
        // Port 22 is the default, so it must not be written out.
        assert!(!rendered.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut config = config_from(concat!(
            "Host alpha\n",
            "  HostName alpha.example.com\n",
            "\n",
            "Host beta\n",
            "  HostName beta.example.com\n",
        ));
        config.delete_host("alpha");

        let rendered = config.serialize();
        assert!(!rendered.contains("Host alpha"));
        assert!(rendered.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut config = config_from(concat!(
            "Host myserver\n",
            "  HostName 10.0.0.1\n",
            "  User old_user\n",
        ));
        config.update_host("myserver", &host_entry("myserver", "10.0.0.2", "new_user"));

        let rendered = config.serialize();
        for &needle in &["HostName 10.0.0.2", "User new_user"] {
            assert!(rendered.contains(needle));
        }
        assert!(!rendered.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut config = config_from(concat!(
            "Host myserver\n",
            "  HostName 10.0.0.1\n",
            "  User admin\n",
            "  ForwardAgent yes\n",
            "  LocalForward 8080 localhost:80\n",
            "  Compression yes\n",
        ));
        config.update_host("myserver", &host_entry("myserver", "10.0.0.2", "admin"));

        let rendered = config.serialize();
        let expected = [
            "HostName 10.0.0.2",
            "ForwardAgent yes",
            "LocalForward 8080 localhost:80",
            "Compression yes",
        ];
        for &needle in &expected {
            assert!(rendered.contains(needle));
        }
    }
}