Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9    /// Write the config back to disk.
10    /// Creates a backup before writing and uses atomic write (temp file + rename).
11    pub fn write(&self) -> Result<()> {
12        // Create backup if the file exists, keep only last 5
13        if self.path.exists() {
14            self.create_backup()
15                .context("Failed to create backup of SSH config")?;
16            self.prune_backups(5).ok();
17        }
18
19        let content = self.serialize();
20
21        // Ensure parent directory exists
22        if let Some(parent) = self.path.parent() {
23            fs::create_dir_all(parent)
24                .with_context(|| format!("Failed to create directory {}", parent.display()))?;
25        }
26
27        // Atomic write: write to temp file (created with 0o600), then rename
28        let tmp_path = self.path.with_extension(format!("purple_tmp.{}", std::process::id()));
29
30        #[cfg(unix)]
31        {
32            use std::io::Write;
33            use std::os::unix::fs::OpenOptionsExt;
34            let mut file = fs::OpenOptions::new()
35                .write(true)
36                .create(true)
37                .truncate(true)
38                .mode(0o600)
39                .open(&tmp_path)
40                .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
41            file.write_all(content.as_bytes())
42                .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
43        }
44
45        #[cfg(not(unix))]
46        fs::write(&tmp_path, &content)
47            .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48
49        let result = fs::rename(&tmp_path, &self.path);
50        if result.is_err() {
51            let _ = fs::remove_file(&tmp_path);
52        }
53        result.with_context(|| {
54            format!(
55                "Failed to rename {} to {}",
56                tmp_path.display(),
57                self.path.display()
58            )
59        })?;
60
61        Ok(())
62    }
63
64    /// Serialize the config to a string.
65    pub fn serialize(&self) -> String {
66        let mut lines = Vec::new();
67
68        for element in &self.elements {
69            match element {
70                ConfigElement::GlobalLine(line) => {
71                    lines.push(line.clone());
72                }
73                ConfigElement::HostBlock(block) => {
74                    lines.push(block.raw_host_line.clone());
75                    for directive in &block.directives {
76                        lines.push(directive.raw_line.clone());
77                    }
78                }
79                ConfigElement::Include(include) => {
80                    lines.push(include.raw_line.clone());
81                }
82            }
83        }
84
85        let line_ending = if self.crlf { "\r\n" } else { "\n" };
86        let mut result = lines.join(line_ending);
87        // Ensure file ends with a newline
88        if !result.ends_with('\n') {
89            result.push_str(line_ending);
90        }
91        result
92    }
93
94    /// Create a timestamped backup of the current config file.
95    fn create_backup(&self) -> Result<()> {
96        let timestamp = SystemTime::now()
97            .duration_since(SystemTime::UNIX_EPOCH)
98            .unwrap_or_default()
99            .as_millis();
100        let backup_name = format!(
101            "{}.bak.{}",
102            self.path.file_name().unwrap_or_default().to_string_lossy(),
103            timestamp
104        );
105        let backup_path = self.path.with_file_name(backup_name);
106        fs::copy(&self.path, &backup_path).with_context(|| {
107            format!(
108                "Failed to copy {} to {}",
109                self.path.display(),
110                backup_path.display()
111            )
112        })?;
113        Ok(())
114    }
115
116    /// Remove old backups, keeping only the most recent `keep` files.
117    fn prune_backups(&self, keep: usize) -> Result<()> {
118        let parent = self.path.parent().context("No parent directory")?;
119        let prefix = format!(
120            "{}.bak.",
121            self.path.file_name().unwrap_or_default().to_string_lossy()
122        );
123        let mut backups: Vec<_> = fs::read_dir(parent)?
124            .filter_map(|e| e.ok())
125            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
126            .collect();
127        backups.sort_by_key(|e| e.file_name());
128        if backups.len() > keep {
129            for old in &backups[..backups.len() - keep] {
130                let _ = fs::remove_file(old.path());
131            }
132        }
133        Ok(())
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content`. CRLF mode is enabled when the
    /// text contains `\r\n`; the path is a placeholder that is never touched.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
            crlf: content.contains("\r\n"),
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_crlf() {
        let content = "Host myserver\r\n  HostName 10.0.0.1\r\n  User admin\r\n";
        let config = parse_str(content);
        assert!(config.crlf);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_serialize_adds_missing_trailing_newline() {
        let config = parse_str("Host myserver\n  HostName 10.0.0.1");
        assert_eq!(config.serialize(), "Host myserver\n  HostName 10.0.0.1\n");
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_write_replaces_content_and_keeps_backup() {
        // Isolated scratch directory, unique per process; cleared first in case
        // a previous run with the same pid crashed.
        let dir = std::env::temp_dir().join(format!(
            "purple_writer_test_{}",
            std::process::id()
        ));
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();

        let path = dir.join("config");
        let old = "Host old\n  HostName 10.0.0.1\n";
        let new = "Host new\n  HostName 10.0.0.2\n";
        fs::write(&path, old).unwrap();

        let mut config = parse_str(new);
        config.path = path.clone();
        config.write().unwrap();

        assert_eq!(fs::read_to_string(&path).unwrap(), new);

        let names: Vec<String> = fs::read_dir(&dir)
            .unwrap()
            .filter_map(|e| e.ok())
            .map(|e| e.file_name().to_string_lossy().into_owned())
            .collect();
        let backups: Vec<&String> = names
            .iter()
            .filter(|n| n.starts_with("config.bak."))
            .collect();
        assert_eq!(backups.len(), 1, "expected exactly one backup, got {:?}", names);
        assert_eq!(fs::read_to_string(dir.join(backups[0])).unwrap(), old);
        assert!(
            !names.iter().any(|n| n.contains("purple_tmp")),
            "temp file left behind: {:?}",
            names
        );

        fs::remove_dir_all(&dir).unwrap();
    }
}