purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9 pub fn write(&self) -> Result<()> {
12 if self.path.exists() {
14 self.create_backup()
15 .context("Failed to create backup of SSH config")?;
16 self.prune_backups(5).ok();
17 }
18
19 let content = self.serialize();
20
21 if let Some(parent) = self.path.parent() {
23 fs::create_dir_all(parent)
24 .with_context(|| format!("Failed to create directory {}", parent.display()))?;
25 }
26
27 let tmp_path = self.path.with_extension(format!("purple_tmp.{}", std::process::id()));
29
30 #[cfg(unix)]
31 {
32 use std::io::Write;
33 use std::os::unix::fs::OpenOptionsExt;
34 let mut file = fs::OpenOptions::new()
35 .write(true)
36 .create(true)
37 .truncate(true)
38 .mode(0o600)
39 .open(&tmp_path)
40 .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
41 file.write_all(content.as_bytes())
42 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
43 }
44
45 #[cfg(not(unix))]
46 fs::write(&tmp_path, &content)
47 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48
49 let result = fs::rename(&tmp_path, &self.path);
50 if result.is_err() {
51 let _ = fs::remove_file(&tmp_path);
52 }
53 result.with_context(|| {
54 format!(
55 "Failed to rename {} to {}",
56 tmp_path.display(),
57 self.path.display()
58 )
59 })?;
60
61 Ok(())
62 }
63
64 pub fn serialize(&self) -> String {
66 let mut lines = Vec::new();
67
68 for element in &self.elements {
69 match element {
70 ConfigElement::GlobalLine(line) => {
71 lines.push(line.clone());
72 }
73 ConfigElement::HostBlock(block) => {
74 lines.push(block.raw_host_line.clone());
75 for directive in &block.directives {
76 lines.push(directive.raw_line.clone());
77 }
78 }
79 ConfigElement::Include(include) => {
80 lines.push(include.raw_line.clone());
81 }
82 }
83 }
84
85 let line_ending = if self.crlf { "\r\n" } else { "\n" };
86 let mut result = lines.join(line_ending);
87 if !result.ends_with('\n') {
89 result.push_str(line_ending);
90 }
91 result
92 }
93
94 fn create_backup(&self) -> Result<()> {
96 let timestamp = SystemTime::now()
97 .duration_since(SystemTime::UNIX_EPOCH)
98 .unwrap_or_default()
99 .as_millis();
100 let backup_name = format!(
101 "{}.bak.{}",
102 self.path.file_name().unwrap_or_default().to_string_lossy(),
103 timestamp
104 );
105 let backup_path = self.path.with_file_name(backup_name);
106 fs::copy(&self.path, &backup_path).with_context(|| {
107 format!(
108 "Failed to copy {} to {}",
109 self.path.display(),
110 backup_path.display()
111 )
112 })?;
113 Ok(())
114 }
115
116 fn prune_backups(&self, keep: usize) -> Result<()> {
118 let parent = self.path.parent().context("No parent directory")?;
119 let prefix = format!(
120 "{}.bak.",
121 self.path.file_name().unwrap_or_default().to_string_lossy()
122 );
123 let mut backups: Vec<_> = fs::read_dir(parent)?
124 .filter_map(|e| e.ok())
125 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
126 .collect();
127 backups.sort_by_key(|e| e.file_name());
128 if backups.len() > keep {
129 for old in &backups[..backups.len() - keep] {
130 let _ = fs::remove_file(old.path());
131 }
132 }
133 Ok(())
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Builds an in-memory `SshConfigFile` from `content`.
    ///
    /// The path is a dummy that is never written. CRLF mode is enabled when the content
    /// contains any `\r\n`.
    fn parse_str(content: &str) -> SshConfigFile {
        let elements = SshConfigFile::parse_content(content);
        let crlf = content.contains("\r\n");
        SshConfigFile {
            path: PathBuf::from("/tmp/test_config"),
            elements,
            crlf,
        }
    }

    /// Asserts that parsing then serializing `content` reproduces it byte-for-byte.
    #[track_caller]
    fn assert_round_trip(content: &str) {
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    /// A `HostEntry` with the given alias, hostname and user, on port 22, with all other
    /// fields defaulted.
    fn entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    /// Asserts, in order, that every needle occurs in `haystack`.
    #[track_caller]
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(haystack.contains(needle), "missing {:?} in:\n{}", needle, haystack);
        }
    }

    /// Asserts, in order, that no needle occurs in `haystack`.
    #[track_caller]
    fn assert_contains_none(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(!haystack.contains(needle), "unexpected {:?} in:\n{}", needle, haystack);
        }
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trip(concat!(
            "Host myserver\n",
            " HostName 192.168.1.10\n",
            " User admin\n",
            " Port 2222\n",
        ));
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trip(concat!(
            "# My SSH config\n",
            "# Generated by hand\n",
            "\n",
            "Host alpha\n",
            " HostName alpha.example.com\n",
            " # Deploy user\n",
            " User deploy\n",
            "\n",
            "Host beta\n",
            " HostName beta.example.com\n",
            " User root\n",
        ));
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trip(concat!(
            "# Global settings\n",
            "Host *\n",
            " ServerAliveInterval 60\n",
            " ServerAliveCountMax 3\n",
            "\n",
            "Host production\n",
            " HostName prod.example.com\n",
            " User deployer\n",
            " IdentityFile ~/.ssh/prod_key\n",
        ));
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n HostName 10.0.0.1\n");
        config.add_host(&entry("newhost", "10.0.0.2", "admin"));

        let rendered = config.serialize();
        assert_contains_all(&rendered, &["Host newhost", "HostName 10.0.0.2", "User admin"]);
        // The default port must not be written out explicitly.
        assert_contains_none(&rendered, &["Port 22"]);
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut config = parse_str(concat!(
            "Host alpha\n",
            " HostName alpha.example.com\n",
            "\n",
            "Host beta\n",
            " HostName beta.example.com\n",
        ));
        config.delete_host("alpha");

        let rendered = config.serialize();
        assert_contains_none(&rendered, &["Host alpha"]);
        assert_contains_all(&rendered, &["Host beta"]);
    }

    #[test]
    fn test_update_host_serializes() {
        let mut config = parse_str(concat!(
            "Host myserver\n",
            " HostName 10.0.0.1\n",
            " User old_user\n",
        ));
        config.update_host("myserver", &entry("myserver", "10.0.0.2", "new_user"));

        let rendered = config.serialize();
        assert_contains_all(&rendered, &["HostName 10.0.0.2", "User new_user"]);
        assert_contains_none(&rendered, &["old_user"]);
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut config = parse_str(concat!(
            "Host myserver\n",
            " HostName 10.0.0.1\n",
            " User admin\n",
            " ForwardAgent yes\n",
            " LocalForward 8080 localhost:80\n",
            " Compression yes\n",
        ));
        config.update_host("myserver", &entry("myserver", "10.0.0.2", "admin"));

        let rendered = config.serialize();
        assert_contains_all(
            &rendered,
            &[
                "HostName 10.0.0.2",
                "ForwardAgent yes",
                "LocalForward 8080 localhost:80",
                "Compression yes",
            ],
        );
    }
}