Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9    /// Write the config back to disk.
10    /// Creates a backup before writing and uses atomic write (temp file + rename).
11    /// Resolves symlinks so the rename targets the real file, not the link.
12    pub fn write(&self) -> Result<()> {
13        // Resolve symlinks so we write through to the real file
14        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
15
16        // Create backup if the file exists, keep only last 5
17        if self.path.exists() {
18            self.create_backup()
19                .context("Failed to create backup of SSH config")?;
20            self.prune_backups(5).ok();
21        }
22
23        let content = self.serialize();
24
25        // Ensure parent directory exists
26        if let Some(parent) = target_path.parent() {
27            fs::create_dir_all(parent)
28                .with_context(|| format!("Failed to create directory {}", parent.display()))?;
29        }
30
31        // Atomic write: write to temp file (created with 0o600), then rename
32        let tmp_path =
33            target_path.with_extension(format!("purple_tmp.{}", std::process::id()));
34
35        #[cfg(unix)]
36        {
37            use std::io::Write;
38            use std::os::unix::fs::OpenOptionsExt;
39            let mut file = fs::OpenOptions::new()
40                .write(true)
41                .create(true)
42                .truncate(true)
43                .mode(0o600)
44                .open(&tmp_path)
45                .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
46            file.write_all(content.as_bytes())
47                .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48        }
49
50        #[cfg(not(unix))]
51        fs::write(&tmp_path, &content)
52            .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
53
54        let result = fs::rename(&tmp_path, &target_path);
55        if result.is_err() {
56            let _ = fs::remove_file(&tmp_path);
57        }
58        result.with_context(|| {
59            format!(
60                "Failed to rename {} to {}",
61                tmp_path.display(),
62                target_path.display()
63            )
64        })?;
65
66        Ok(())
67    }
68
69    /// Serialize the config to a string.
70    pub fn serialize(&self) -> String {
71        let mut lines = Vec::new();
72
73        for element in &self.elements {
74            match element {
75                ConfigElement::GlobalLine(line) => {
76                    lines.push(line.clone());
77                }
78                ConfigElement::HostBlock(block) => {
79                    lines.push(block.raw_host_line.clone());
80                    for directive in &block.directives {
81                        lines.push(directive.raw_line.clone());
82                    }
83                }
84                ConfigElement::Include(include) => {
85                    lines.push(include.raw_line.clone());
86                }
87            }
88        }
89
90        let line_ending = if self.crlf { "\r\n" } else { "\n" };
91        let mut result = lines.join(line_ending);
92        // Ensure file ends with a newline
93        if !result.ends_with('\n') {
94            result.push_str(line_ending);
95        }
96        result
97    }
98
99    /// Create a timestamped backup of the current config file.
100    fn create_backup(&self) -> Result<()> {
101        let timestamp = SystemTime::now()
102            .duration_since(SystemTime::UNIX_EPOCH)
103            .unwrap_or_default()
104            .as_millis();
105        let backup_name = format!(
106            "{}.bak.{}",
107            self.path.file_name().unwrap_or_default().to_string_lossy(),
108            timestamp
109        );
110        let backup_path = self.path.with_file_name(backup_name);
111        fs::copy(&self.path, &backup_path).with_context(|| {
112            format!(
113                "Failed to copy {} to {}",
114                self.path.display(),
115                backup_path.display()
116            )
117        })?;
118        Ok(())
119    }
120
121    /// Remove old backups, keeping only the most recent `keep` files.
122    fn prune_backups(&self, keep: usize) -> Result<()> {
123        let parent = self.path.parent().context("No parent directory")?;
124        let prefix = format!(
125            "{}.bak.",
126            self.path.file_name().unwrap_or_default().to_string_lossy()
127        );
128        let mut backups: Vec<_> = fs::read_dir(parent)?
129            .filter_map(|e| e.ok())
130            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
131            .collect();
132        backups.sort_by_key(|e| e.file_name());
133        if backups.len() > keep {
134            for old in &backups[..backups.len() - keep] {
135                let _ = fs::remove_file(old.path());
136            }
137        }
138        Ok(())
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content`. The path is a placeholder; these tests
    /// only serialize and never write to it.
    fn parse_str(content: &str) -> SshConfigFile {
        let crlf = content.contains("\r\n");
        let elements = SshConfigFile::parse_content(content);
        SshConfigFile {
            elements,
            path: PathBuf::from("/tmp/test_config"),
            crlf,
        }
    }

    /// A host entry on port 22 with only alias, hostname and user filled in.
    fn entry_on_default_port(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    /// Assert that parsing `content` and serializing it again reproduces it exactly.
    fn assert_round_trip(content: &str) {
        let reserialized = parse_str(content).serialize();
        assert_eq!(reserialized, content);
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        assert_round_trip(content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        assert_round_trip(content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        assert_round_trip(content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&entry_on_default_port("newhost", "10.0.0.2", "admin"));

        let output = config.serialize();
        for expected in ["Host newhost", "HostName 10.0.0.2", "User admin"] {
            assert!(output.contains(expected), "missing {expected:?} in:\n{output}");
        }
        // Port 22 is the default and should not be written out.
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");

        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &entry_on_default_port("myserver", "10.0.0.2", "new_user"),
        );

        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &entry_on_default_port("myserver", "10.0.0.2", "admin"),
        );

        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        for untouched in [
            "ForwardAgent yes",
            "LocalForward 8080 localhost:80",
            "Compression yes",
        ] {
            assert!(output.contains(untouched), "lost {untouched:?} in:\n{output}");
        }
    }
}