// purple_ssh/ssh_config/writer.rs
use std::fs;
use std::time::SystemTime;

use anyhow::{Context, Result};

use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9 pub fn write(&self) -> Result<()> {
13 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
15
16 if self.path.exists() {
18 self.create_backup()
19 .context("Failed to create backup of SSH config")?;
20 self.prune_backups(5).ok();
21 }
22
23 let content = self.serialize();
24
25 if let Some(parent) = target_path.parent() {
27 fs::create_dir_all(parent)
28 .with_context(|| format!("Failed to create directory {}", parent.display()))?;
29 }
30
31 let tmp_path =
33 target_path.with_extension(format!("purple_tmp.{}", std::process::id()));
34
35 #[cfg(unix)]
36 {
37 use std::io::Write;
38 use std::os::unix::fs::OpenOptionsExt;
39 let mut file = fs::OpenOptions::new()
40 .write(true)
41 .create(true)
42 .truncate(true)
43 .mode(0o600)
44 .open(&tmp_path)
45 .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
46 file.write_all(content.as_bytes())
47 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48 }
49
50 #[cfg(not(unix))]
51 fs::write(&tmp_path, &content)
52 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
53
54 let result = fs::rename(&tmp_path, &target_path);
55 if result.is_err() {
56 let _ = fs::remove_file(&tmp_path);
57 }
58 result.with_context(|| {
59 format!(
60 "Failed to rename {} to {}",
61 tmp_path.display(),
62 target_path.display()
63 )
64 })?;
65
66 Ok(())
67 }
68
69 pub fn serialize(&self) -> String {
72 let mut lines = Vec::new();
73
74 for element in &self.elements {
75 match element {
76 ConfigElement::GlobalLine(line) => {
77 lines.push(line.clone());
78 }
79 ConfigElement::HostBlock(block) => {
80 lines.push(block.raw_host_line.clone());
81 for directive in &block.directives {
82 lines.push(directive.raw_line.clone());
83 }
84 }
85 ConfigElement::Include(include) => {
86 lines.push(include.raw_line.clone());
87 }
88 }
89 }
90
91 let mut collapsed = Vec::with_capacity(lines.len());
93 let mut prev_blank = false;
94 for line in lines {
95 let is_blank = line.trim().is_empty();
96 if is_blank && prev_blank {
97 continue;
98 }
99 prev_blank = is_blank;
100 collapsed.push(line);
101 }
102
103 let line_ending = if self.crlf { "\r\n" } else { "\n" };
104 let mut result = String::new();
105 for line in &collapsed {
106 result.push_str(line);
107 result.push_str(line_ending);
108 }
109 if result.is_empty() {
111 result.push_str(line_ending);
112 }
113 result
114 }
115
116 fn create_backup(&self) -> Result<()> {
118 let timestamp = SystemTime::now()
119 .duration_since(SystemTime::UNIX_EPOCH)
120 .unwrap_or_default()
121 .as_millis();
122 let backup_name = format!(
123 "{}.bak.{}",
124 self.path.file_name().unwrap_or_default().to_string_lossy(),
125 timestamp
126 );
127 let backup_path = self.path.with_file_name(backup_name);
128 fs::copy(&self.path, &backup_path).with_context(|| {
129 format!(
130 "Failed to copy {} to {}",
131 self.path.display(),
132 backup_path.display()
133 )
134 })?;
135 Ok(())
136 }
137
138 fn prune_backups(&self, keep: usize) -> Result<()> {
140 let parent = self.path.parent().context("No parent directory")?;
141 let prefix = format!(
142 "{}.bak.",
143 self.path.file_name().unwrap_or_default().to_string_lossy()
144 );
145 let mut backups: Vec<_> = fs::read_dir(parent)?
146 .filter_map(|e| e.ok())
147 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
148 .collect();
149 backups.sort_by_key(|e| e.file_name());
150 if backups.len() > keep {
151 for old in &backups[..backups.len() - keep] {
152 let _ = fs::remove_file(old.path());
153 }
154 }
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content` without touching the filesystem.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
            crlf: content.contains("\r\n"),
        }
    }

    /// Assert that parsing and re-serializing `content` reproduces it exactly.
    fn assert_round_trip(content: &str) {
        assert_eq!(parse_str(content).serialize(), content);
    }

    /// A host entry with the given alias, hostname and user on port 22.
    fn entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        assert_round_trip(content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        assert_round_trip(content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        assert_round_trip(content);
    }

    #[test]
    fn test_collapses_consecutive_blank_lines() {
        let content = "\
Host alpha
  HostName alpha.example.com



Host beta
  HostName beta.example.com
";
        let expected = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        assert_eq!(parse_str(content).serialize(), expected);
    }

    #[test]
    fn test_empty_config_serializes_to_single_newline() {
        assert_eq!(parse_str("").serialize(), "\n");
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&entry("newhost", "10.0.0.2", "admin"));
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // The default port is omitted from the generated block.
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host("myserver", &entry("myserver", "10.0.0.2", "new_user"));
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host("myserver", &entry("myserver", "10.0.0.2", "admin"));
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }
}