Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9    /// Write the config back to disk.
10    /// Creates a backup before writing and uses atomic write (temp file + rename).
11    pub fn write(&self) -> Result<()> {
12        // Create backup if the file exists, keep only last 5
13        if self.path.exists() {
14            self.create_backup()
15                .context("Failed to create backup of SSH config")?;
16            self.prune_backups(5).ok();
17        }
18
19        let content = self.serialize();
20
21        // Ensure parent directory exists
22        if let Some(parent) = self.path.parent() {
23            fs::create_dir_all(parent)
24                .with_context(|| format!("Failed to create directory {}", parent.display()))?;
25        }
26
27        // Atomic write: write to temp file (created with 0o600), then rename
28        let tmp_path = self.path.with_extension(format!("purple_tmp.{}", std::process::id()));
29
30        #[cfg(unix)]
31        {
32            use std::io::Write;
33            use std::os::unix::fs::OpenOptionsExt;
34            let mut file = fs::OpenOptions::new()
35                .write(true)
36                .create(true)
37                .truncate(true)
38                .mode(0o600)
39                .open(&tmp_path)
40                .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
41            file.write_all(content.as_bytes())
42                .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
43        }
44
45        #[cfg(not(unix))]
46        fs::write(&tmp_path, &content)
47            .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48
49        fs::rename(&tmp_path, &self.path).with_context(|| {
50            format!(
51                "Failed to rename {} to {}",
52                tmp_path.display(),
53                self.path.display()
54            )
55        })?;
56
57        Ok(())
58    }
59
60    /// Serialize the config to a string.
61    pub fn serialize(&self) -> String {
62        let mut lines = Vec::new();
63
64        for element in &self.elements {
65            match element {
66                ConfigElement::GlobalLine(line) => {
67                    lines.push(line.clone());
68                }
69                ConfigElement::HostBlock(block) => {
70                    lines.push(block.raw_host_line.clone());
71                    for directive in &block.directives {
72                        lines.push(directive.raw_line.clone());
73                    }
74                }
75                ConfigElement::Include(include) => {
76                    lines.push(include.raw_line.clone());
77                }
78            }
79        }
80
81        let mut result = lines.join("\n");
82        // Ensure file ends with a newline
83        if !result.ends_with('\n') {
84            result.push('\n');
85        }
86        result
87    }
88
89    /// Create a timestamped backup of the current config file.
90    fn create_backup(&self) -> Result<()> {
91        let timestamp = SystemTime::now()
92            .duration_since(SystemTime::UNIX_EPOCH)
93            .unwrap_or_default()
94            .as_millis();
95        let backup_name = format!(
96            "{}.bak.{}",
97            self.path.file_name().unwrap_or_default().to_string_lossy(),
98            timestamp
99        );
100        let backup_path = self.path.with_file_name(backup_name);
101        fs::copy(&self.path, &backup_path).with_context(|| {
102            format!(
103                "Failed to copy {} to {}",
104                self.path.display(),
105                backup_path.display()
106            )
107        })?;
108        Ok(())
109    }
110
111    /// Remove old backups, keeping only the most recent `keep` files.
112    fn prune_backups(&self, keep: usize) -> Result<()> {
113        let parent = self.path.parent().context("No parent directory")?;
114        let prefix = format!(
115            "{}.bak.",
116            self.path.file_name().unwrap_or_default().to_string_lossy()
117        );
118        let mut backups: Vec<_> = fs::read_dir(parent)?
119            .filter_map(|e| e.ok())
120            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
121            .collect();
122        backups.sort_by_key(|e| e.file_name());
123        if backups.len() > keep {
124            for old in &backups[..backups.len() - keep] {
125                let _ = fs::remove_file(old.path());
126            }
127        }
128        Ok(())
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content`; the path is never written to.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
        }
    }

    /// Create a fresh, uniquely named directory under the system temp dir for
    /// tests that touch the filesystem.
    fn scratch_dir(tag: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!(
            "purple_writer_{}_{}_{}",
            tag,
            std::process::id(),
            nanos
        ));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_serialize_adds_missing_trailing_newline() {
        let config = parse_str("Host a\n  HostName b");
        assert_eq!(config.serialize(), "Host a\n  HostName b\n");
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_write_creates_file_and_prunes_backups() {
        let dir = scratch_dir("prune");
        let path = dir.join("config");
        let config = SshConfigFile {
            elements: SshConfigFile::parse_content("Host a\n  HostName 10.0.0.1\n"),
            path: path.clone(),
        };

        // The first write creates the file; every later write backs up the previous one.
        for _ in 0..8 {
            config.write().unwrap();
        }

        assert_eq!(fs::read_to_string(&path).unwrap(), config.serialize());

        let names: Vec<String> = fs::read_dir(&dir)
            .unwrap()
            .map(|e| e.unwrap().file_name().to_string_lossy().into_owned())
            .collect();
        let backups = names
            .iter()
            .filter(|n| n.starts_with("config.bak."))
            .count();
        // Backups taken within the same millisecond share a name, so fewer than five is fine.
        assert!(
            (1..=5).contains(&backups),
            "unexpected backup count: {}",
            backups
        );
        assert!(
            names.iter().all(|n| !n.contains("purple_tmp")),
            "temp file left behind: {:?}",
            names
        );

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = fs::metadata(&path).unwrap().permissions().mode() & 0o777;
            assert_eq!(mode, 0o600);
        }

        let _ = fs::remove_dir_all(&dir);
    }
}