Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9    /// Write the config back to disk.
10    /// Creates a backup before writing and uses atomic write (temp file + rename).
11    pub fn write(&self) -> Result<()> {
12        // Create backup if the file exists, keep only last 5
13        if self.path.exists() {
14            self.create_backup()
15                .context("Failed to create backup of SSH config")?;
16            self.prune_backups(5).ok();
17        }
18
19        let content = self.serialize();
20
21        // Ensure parent directory exists
22        if let Some(parent) = self.path.parent() {
23            fs::create_dir_all(parent)
24                .with_context(|| format!("Failed to create directory {}", parent.display()))?;
25        }
26
27        // Atomic write: write to temp file (created with 0o600), then rename
28        let tmp_path = self.path.with_extension(format!("purple_tmp.{}", std::process::id()));
29
30        #[cfg(unix)]
31        {
32            use std::io::Write;
33            use std::os::unix::fs::OpenOptionsExt;
34            let mut file = fs::OpenOptions::new()
35                .write(true)
36                .create(true)
37                .truncate(true)
38                .mode(0o600)
39                .open(&tmp_path)
40                .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
41            file.write_all(content.as_bytes())
42                .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
43        }
44
45        #[cfg(not(unix))]
46        fs::write(&tmp_path, &content)
47            .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48
49        let result = fs::rename(&tmp_path, &self.path);
50        if result.is_err() {
51            let _ = fs::remove_file(&tmp_path);
52        }
53        result.with_context(|| {
54            format!(
55                "Failed to rename {} to {}",
56                tmp_path.display(),
57                self.path.display()
58            )
59        })?;
60
61        Ok(())
62    }
63
64    /// Serialize the config to a string.
65    pub fn serialize(&self) -> String {
66        let mut lines = Vec::new();
67
68        for element in &self.elements {
69            match element {
70                ConfigElement::GlobalLine(line) => {
71                    lines.push(line.clone());
72                }
73                ConfigElement::HostBlock(block) => {
74                    lines.push(block.raw_host_line.clone());
75                    for directive in &block.directives {
76                        lines.push(directive.raw_line.clone());
77                    }
78                }
79                ConfigElement::Include(include) => {
80                    lines.push(include.raw_line.clone());
81                }
82            }
83        }
84
85        let mut result = lines.join("\n");
86        // Ensure file ends with a newline
87        if !result.ends_with('\n') {
88            result.push('\n');
89        }
90        result
91    }
92
93    /// Create a timestamped backup of the current config file.
94    fn create_backup(&self) -> Result<()> {
95        let timestamp = SystemTime::now()
96            .duration_since(SystemTime::UNIX_EPOCH)
97            .unwrap_or_default()
98            .as_millis();
99        let backup_name = format!(
100            "{}.bak.{}",
101            self.path.file_name().unwrap_or_default().to_string_lossy(),
102            timestamp
103        );
104        let backup_path = self.path.with_file_name(backup_name);
105        fs::copy(&self.path, &backup_path).with_context(|| {
106            format!(
107                "Failed to copy {} to {}",
108                self.path.display(),
109                backup_path.display()
110            )
111        })?;
112        Ok(())
113    }
114
115    /// Remove old backups, keeping only the most recent `keep` files.
116    fn prune_backups(&self, keep: usize) -> Result<()> {
117        let parent = self.path.parent().context("No parent directory")?;
118        let prefix = format!(
119            "{}.bak.",
120            self.path.file_name().unwrap_or_default().to_string_lossy()
121        );
122        let mut backups: Vec<_> = fs::read_dir(parent)?
123            .filter_map(|e| e.ok())
124            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
125            .collect();
126        backups.sort_by_key(|e| e.file_name());
127        if backups.len() > keep {
128            for old in &backups[..backups.len() - keep] {
129                let _ = fs::remove_file(old.path());
130            }
131        }
132        Ok(())
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Parse `text` into an in-memory config. The path is only a placeholder;
    /// none of the tests in this module call `write`.
    fn config_from(text: &str) -> SshConfigFile {
        let elements = SshConfigFile::parse_content(text);
        let path = PathBuf::from("/tmp/test_config");
        SshConfigFile { elements, path }
    }

    /// Build a host entry on port 22 with every other field left at its default.
    fn entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    /// Assert that parsing and re-serializing `text` reproduces it exactly.
    fn assert_round_trips(text: &str) {
        assert_eq!(config_from(text).serialize(), text);
    }

    /// Assert that every needle occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(haystack.contains(*needle));
        }
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trips(
            "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
",
        );
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trips(
            "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
",
        );
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trips(
            "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
",
        );
    }

    #[test]
    fn test_add_host_serializes() {
        let mut cfg = config_from("Host existing\n  HostName 10.0.0.1\n");
        cfg.add_host(&entry("newhost", "10.0.0.2", "admin"));

        let rendered = cfg.serialize();
        assert_contains_all(
            &rendered,
            &["Host newhost", "HostName 10.0.0.2", "User admin"],
        );
        // Port 22 is default, should not be written
        assert!(!rendered.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut cfg = config_from(
            "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
",
        );
        cfg.delete_host("alpha");

        let rendered = cfg.serialize();
        assert!(!rendered.contains("Host alpha"));
        assert!(rendered.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut cfg = config_from(
            "\
Host myserver
  HostName 10.0.0.1
  User old_user
",
        );
        cfg.update_host("myserver", &entry("myserver", "10.0.0.2", "new_user"));

        let rendered = cfg.serialize();
        assert_contains_all(&rendered, &["HostName 10.0.0.2", "User new_user"]);
        assert!(!rendered.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut cfg = config_from(
            "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
",
        );
        cfg.update_host("myserver", &entry("myserver", "10.0.0.2", "admin"));

        // Directives the entry does not model must survive the update untouched.
        assert_contains_all(
            &cfg.serialize(),
            &[
                "HostName 10.0.0.2",
                "ForwardAgent yes",
                "LocalForward 8080 localhost:80",
                "Compression yes",
            ],
        );
    }
}
282}