Skip to main content

purple_ssh/ssh_config/
writer.rs

1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7use crate::fs_util;
8
9impl SshConfigFile {
10    /// Write the config back to disk.
11    /// Creates a backup before writing and uses atomic write (temp file + rename).
12    /// Resolves symlinks so the rename targets the real file, not the link.
13    pub fn write(&self) -> Result<()> {
14        // Resolve symlinks so we write through to the real file
15        let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
16
17        // Create backup if the file exists, keep only last 5
18        if self.path.exists() {
19            self.create_backup()
20                .context("Failed to create backup of SSH config")?;
21            self.prune_backups(5).ok();
22        }
23
24        let content = self.serialize();
25
26        fs_util::atomic_write(&target_path, content.as_bytes())
27            .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
28
29        Ok(())
30    }
31
32    /// Serialize the config to a string.
33    /// Collapses consecutive blank lines to prevent accumulation after deletions.
34    pub fn serialize(&self) -> String {
35        let mut lines = Vec::new();
36
37        for element in &self.elements {
38            match element {
39                ConfigElement::GlobalLine(line) => {
40                    lines.push(line.clone());
41                }
42                ConfigElement::HostBlock(block) => {
43                    lines.push(block.raw_host_line.clone());
44                    for directive in &block.directives {
45                        lines.push(directive.raw_line.clone());
46                    }
47                }
48                ConfigElement::Include(include) => {
49                    lines.push(include.raw_line.clone());
50                }
51            }
52        }
53
54        // Collapse consecutive blank lines (keep at most one)
55        let mut collapsed = Vec::with_capacity(lines.len());
56        let mut prev_blank = false;
57        for line in lines {
58            let is_blank = line.trim().is_empty();
59            if is_blank && prev_blank {
60                continue;
61            }
62            prev_blank = is_blank;
63            collapsed.push(line);
64        }
65
66        let line_ending = if self.crlf { "\r\n" } else { "\n" };
67        let mut result = String::new();
68        for line in &collapsed {
69            result.push_str(line);
70            result.push_str(line_ending);
71        }
72        // Ensure non-empty files end with exactly one newline
73        if result.is_empty() {
74            result.push_str(line_ending);
75        }
76        result
77    }
78
79    /// Create a timestamped backup of the current config file.
80    /// Backup files are created with chmod 600 to match the source file's sensitivity.
81    fn create_backup(&self) -> Result<()> {
82        let timestamp = SystemTime::now()
83            .duration_since(SystemTime::UNIX_EPOCH)
84            .unwrap_or_default()
85            .as_millis();
86        let backup_name = format!(
87            "{}.bak.{}",
88            self.path.file_name().unwrap_or_default().to_string_lossy(),
89            timestamp
90        );
91        let backup_path = self.path.with_file_name(backup_name);
92        fs::copy(&self.path, &backup_path).with_context(|| {
93            format!(
94                "Failed to copy {} to {}",
95                self.path.display(),
96                backup_path.display()
97            )
98        })?;
99
100        // Set backup permissions to 600 (owner read/write only)
101        #[cfg(unix)]
102        {
103            use std::os::unix::fs::PermissionsExt;
104            let _ = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600));
105        }
106
107        Ok(())
108    }
109
110    /// Remove old backups, keeping only the most recent `keep` files.
111    fn prune_backups(&self, keep: usize) -> Result<()> {
112        let parent = self.path.parent().context("No parent directory")?;
113        let prefix = format!(
114            "{}.bak.",
115            self.path.file_name().unwrap_or_default().to_string_lossy()
116        );
117        let mut backups: Vec<_> = fs::read_dir(parent)?
118            .filter_map(|e| e.ok())
119            .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
120            .collect();
121        backups.sort_by_key(|e| e.file_name());
122        if backups.len() > keep {
123            for old in &backups[..backups.len() - keep] {
124                let _ = fs::remove_file(old.path());
125            }
126        }
127        Ok(())
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content` without touching the disk.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
            crlf: content.contains("\r\n"),
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_crlf() {
        // Files read with CRLF endings must be written back with CRLF endings.
        let content = "Host myserver\r\n  HostName 10.0.0.1\r\n  User admin\r\n";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_serialize_collapses_consecutive_blank_lines() {
        let content = "\
Host alpha
  HostName alpha.example.com



Host beta
  HostName beta.example.com
";
        let config = parse_str(content);
        assert_eq!(
            config.serialize(),
            "Host alpha\n  HostName alpha.example.com\n\nHost beta\n  HostName beta.example.com\n"
        );
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is default, should not be written
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }
}