Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9    /// Write the config back to disk.
10    /// Creates a backup before writing and uses atomic write (temp file + rename).
11    pub fn write(&self) -> Result<()> {
12        // Create backup if the file exists, keep only last 5
13        if self.path.exists() {
14            self.create_backup()
15                .context("Failed to create backup of SSH config")?;
16            self.prune_backups(5).ok();
17        }
18
19        let content = self.serialize();
20
21        // Ensure parent directory exists
22        if let Some(parent) = self.path.parent() {
23            fs::create_dir_all(parent)
24                .with_context(|| format!("Failed to create directory {}", parent.display()))?;
25        }
26
27        // Atomic write: write to temp file, set permissions, then rename
28        let tmp_path = self.path.with_extension(format!("purple_tmp.{}", std::process::id()));
29        fs::write(&tmp_path, &content)
30            .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
31
32        // Set permissions BEFORE rename so the file is never world-readable
33        #[cfg(unix)]
34        {
35            use std::os::unix::fs::PermissionsExt;
36            let perms = fs::Permissions::from_mode(0o600);
37            fs::set_permissions(&tmp_path, perms)
38                .with_context(|| format!("Failed to set permissions on {}", tmp_path.display()))?;
39        }
40
41        fs::rename(&tmp_path, &self.path).with_context(|| {
42            format!(
43                "Failed to rename {} to {}",
44                tmp_path.display(),
45                self.path.display()
46            )
47        })?;
48
49        Ok(())
50    }
51
52    /// Serialize the config to a string.
53    pub fn serialize(&self) -> String {
54        let mut lines = Vec::new();
55
56        for element in &self.elements {
57            match element {
58                ConfigElement::GlobalLine(line) => {
59                    lines.push(line.clone());
60                }
61                ConfigElement::HostBlock(block) => {
62                    lines.push(block.raw_host_line.clone());
63                    for directive in &block.directives {
64                        lines.push(directive.raw_line.clone());
65                    }
66                }
67                ConfigElement::Include(include) => {
68                    lines.push(include.raw_line.clone());
69                }
70            }
71        }
72
73        let mut result = lines.join("\n");
74        // Ensure file ends with a newline
75        if !result.ends_with('\n') {
76            result.push('\n');
77        }
78        result
79    }
80
81    /// Create a timestamped backup of the current config file.
82    fn create_backup(&self) -> Result<()> {
83        let timestamp = SystemTime::now()
84            .duration_since(SystemTime::UNIX_EPOCH)
85            .unwrap_or_default()
86            .as_millis();
87        let backup_name = format!(
88            "{}.bak.{}",
89            self.path.file_name().unwrap_or_default().to_string_lossy(),
90            timestamp
91        );
92        let backup_path = self.path.with_file_name(backup_name);
93        fs::copy(&self.path, &backup_path).with_context(|| {
94            format!(
95                "Failed to copy {} to {}",
96                self.path.display(),
97                backup_path.display()
98            )
99        })?;
100        Ok(())
101    }
102
103    /// Remove old backups, keeping only the most recent `keep` files.
104    fn prune_backups(&self, keep: usize) -> Result<()> {
105        let parent = self.path.parent().context("No parent directory")?;
106        let prefix = format!(
107            "{}.bak.",
108            self.path.file_name().unwrap_or_default().to_string_lossy()
109        );
110        let mut backups: Vec<_> = fs::read_dir(parent)?
111            .filter_map(|e| e.ok())
112            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
113            .collect();
114        backups.sort_by_key(|e| e.file_name());
115        if backups.len() > keep {
116            for old in &backups[..backups.len() - keep] {
117                let _ = fs::remove_file(old.path());
118            }
119        }
120        Ok(())
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content`; its path is never written to.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
        }
    }

    /// Create a fresh directory under the system temp dir.
    ///
    /// The name includes `tag`, the process id and a nanosecond timestamp, so
    /// concurrently running tests do not collide.
    fn scratch_dir(tag: &str) -> PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!(
            "purple_writer_{}_{}_{}",
            tag,
            std::process::id(),
            nanos
        ));
        std::fs::create_dir_all(&dir).expect("create scratch dir");
        dir
    }

    /// List the file names in `dir` as owned strings.
    fn file_names(dir: &PathBuf) -> Vec<String> {
        std::fs::read_dir(dir)
            .expect("list scratch dir")
            .map(|e| e.expect("dir entry").file_name().to_string_lossy().into_owned())
            .collect()
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_write_creates_file_with_serialized_content() {
        let dir = scratch_dir("create");
        let path = dir.join("config");
        let content = "Host myserver\n  HostName 10.0.0.1\n";
        let config = SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: path.clone(),
        };

        config.write().expect("write config");

        assert_eq!(std::fs::read_to_string(&path).unwrap(), content);
        // A brand-new file has nothing to back up and no scratch file survives.
        let names = file_names(&dir);
        assert!(!names.iter().any(|n| n.starts_with("config.bak.")));
        assert!(!names.iter().any(|n| n.contains("purple_tmp")));
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(&path).unwrap().permissions().mode();
            assert_eq!(mode & 0o777, 0o600);
        }
        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn test_write_existing_file_takes_backup_and_leaves_no_temp() {
        let dir = scratch_dir("backup");
        let path = dir.join("config");
        let previous = "Host old\n  HostName 10.0.0.1\n";
        let updated = "Host new\n  HostName 10.0.0.2\n";
        std::fs::write(&path, previous).unwrap();
        let config = SshConfigFile {
            elements: SshConfigFile::parse_content(updated),
            path: path.clone(),
        };

        config.write().expect("write config");

        assert_eq!(std::fs::read_to_string(&path).unwrap(), updated);
        let names = file_names(&dir);
        let backup = names
            .iter()
            .find(|n| n.starts_with("config.bak."))
            .expect("backup of the previous config");
        assert_eq!(std::fs::read_to_string(dir.join(backup)).unwrap(), previous);
        assert!(!names.iter().any(|n| n.contains("purple_tmp")));
        std::fs::remove_dir_all(&dir).ok();
    }
}