1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11 pub fn write(&self) -> Result<()> {
17 if crate::demo_flag::is_demo() {
18 debug!("[purple] ssh_config.write skipped (demo mode)");
19 return Ok(());
20 }
21 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
23 debug!(
24 "[config] ssh_config.write: target={} elements={}",
25 target_path.display(),
26 self.elements.len()
27 );
28
29 let _lock = fs_util::FileLock::acquire(&target_path)
31 .inspect_err(|e| {
32 debug!(
33 "[config] ssh_config.write: lock acquire failed: {} ({})",
34 target_path.display(),
35 e
36 );
37 })
38 .context("Failed to acquire config lock")?;
39
40 if target_path.exists() {
47 self.create_backup(&target_path)
48 .context("Failed to create backup of SSH config")?;
49 self.prune_backups(&target_path, 5).ok();
50 }
51
52 let raw = self.serialized_lines();
57 let normalized = ensure_block_separators(&raw);
58 let healed = normalized.len() - raw.len();
59 if healed > 0 {
60 debug!(
61 "[config] ssh_config.write: inserted {healed} block separator(s) in {}",
62 target_path.display()
63 );
64 }
65 let content = self.lines_to_string(&normalized);
66
67 fs_util::atomic_write(&target_path, content.as_bytes())
68 .map_err(|err| {
69 error!(
70 "[purple] SSH config write failed: {}: {err}",
71 target_path.display()
72 );
73 err
74 })
75 .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
76
77 debug!(
78 "[config] ssh_config.write: wrote {} bytes to {}",
79 content.len(),
80 target_path.display()
81 );
82
83 Ok(())
85 }
86
87 pub fn serialize(&self) -> String {
91 self.lines_to_string(&self.serialized_lines())
92 }
93
94 fn serialized_lines(&self) -> Vec<String> {
97 let mut lines = Vec::new();
98
99 for element in &self.elements {
100 match element {
101 ConfigElement::GlobalLine(line) => {
102 lines.push(line.clone());
103 }
104 ConfigElement::HostBlock(block) => {
105 lines.push(block.raw_host_line.clone());
106 for directive in &block.directives {
107 lines.push(directive.raw_line.clone());
108 }
109 }
110 ConfigElement::Include(include) => {
111 lines.push(include.raw_line.clone());
112 }
113 }
114 }
115
116 let mut collapsed = Vec::with_capacity(lines.len());
118 let mut prev_blank = false;
119 for line in lines {
120 let is_blank = line.trim().is_empty();
121 if is_blank && prev_blank {
122 continue;
123 }
124 prev_blank = is_blank;
125 collapsed.push(line);
126 }
127 collapsed
128 }
129
130 fn lines_to_string(&self, lines: &[String]) -> String {
133 let line_ending = if self.crlf { "\r\n" } else { "\n" };
134 let mut result = String::new();
135 if self.bom {
137 result.push('\u{FEFF}');
138 }
139 for line in lines {
140 result.push_str(line);
141 result.push_str(line_ending);
142 }
143 if lines.is_empty() {
146 result.push_str(line_ending);
147 }
148 result
149 }
150
151 fn create_backup(&self, target_path: &std::path::Path) -> Result<()> {
154 let timestamp = SystemTime::now()
155 .duration_since(SystemTime::UNIX_EPOCH)
156 .unwrap_or_default()
157 .as_millis();
158 let backup_name = format!(
159 "{}.bak.{}",
160 target_path
161 .file_name()
162 .unwrap_or_default()
163 .to_string_lossy(),
164 timestamp
165 );
166 let backup_path = target_path.with_file_name(backup_name);
167 fs::copy(target_path, &backup_path).with_context(|| {
168 format!(
169 "Failed to copy {} to {}",
170 target_path.display(),
171 backup_path.display()
172 )
173 })?;
174
175 #[cfg(unix)]
177 {
178 use std::os::unix::fs::PermissionsExt;
179 if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
180 debug!(
181 "[config] Failed to set backup permissions on {}: {e}",
182 backup_path.display()
183 );
184 }
185 }
186
187 Ok(())
188 }
189
190 fn prune_backups(&self, target_path: &std::path::Path, keep: usize) -> Result<()> {
192 let parent = target_path.parent().context("No parent directory")?;
193 let prefix = format!(
194 "{}.bak.",
195 target_path
196 .file_name()
197 .unwrap_or_default()
198 .to_string_lossy()
199 );
200 let mut backups: Vec<_> = fs::read_dir(parent)?
201 .filter_map(|e| e.ok())
202 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
203 .collect();
204 backups.sort_by_key(|e| {
208 e.metadata()
209 .and_then(|m| m.modified())
210 .unwrap_or(SystemTime::UNIX_EPOCH)
211 });
212 if backups.len() > keep {
213 for old in &backups[..backups.len() - keep] {
214 if let Err(e) = fs::remove_file(old.path()) {
215 debug!(
216 "[config] Failed to prune old backup {}: {e}",
217 old.path().display()
218 );
219 }
220 }
221 }
222 Ok(())
223 }
224}
225
/// Returns `true` when `line` opens a new `Host` or `Match` block.
///
/// Only unindented lines qualify: a line starting with whitespace never
/// counts. The keyword is matched ASCII case-insensitively. It may be
/// followed either by whitespace or by `=`, because ssh_config(5) also
/// accepts the `Keyword=value` form (e.g. `Host=example`).
fn is_block_start(line: &str) -> bool {
    if line.starts_with(char::is_whitespace) {
        return false;
    }
    // `split` always yields at least one (possibly empty) piece.
    let keyword = line
        .split(|c: char| c.is_whitespace() || c == '=')
        .next()
        .unwrap_or("");
    keyword.eq_ignore_ascii_case("Host") || keyword.eq_ignore_ascii_case("Match")
}
237
238fn ensure_block_separators(lines: &[String]) -> Vec<String> {
244 let mut out: Vec<String> = Vec::with_capacity(lines.len() + 4);
245 for line in lines {
246 if is_block_start(line) {
247 if let Some(prev) = out.last() {
248 let prev_blank = prev.trim().is_empty();
249 let prev_top_level_comment =
250 !prev.starts_with(char::is_whitespace) && prev.trim_start().starts_with('#');
251 if !prev_blank && !prev_top_level_comment {
252 out.push(String::new());
253 }
254 }
255 }
256 out.push(line.clone());
257 }
258 out
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;

    /// Builds an in-memory config from `content` for serialization tests.
    ///
    /// `path` points into a directory that is never created, so repeated calls
    /// do not leave persisted temp directories behind (as `TempDir::keep()`
    /// would). The tests using this helper only inspect `serialize()` output.
    /// An accidental disk write should fail on the missing parent directory
    /// rather than silently touch the filesystem.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: std::env::temp_dir()
                .join("purple-test-nonexistent")
                .join("test_config"),
            crlf: crate::ssh_config::parser::detect_crlf_majority(content),
            bom: false,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
 HostName 192.168.1.10
 User admin
 Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
 HostName alpha.example.com
 # Deploy user
 User deploy

Host beta
 HostName beta.example.com
 User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
 ServerAliveInterval 60
 ServerAliveCountMax 3

Host production
 HostName prod.example.com
 User deployer
 IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // The default port is expected to be omitted from the output.
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
 HostName alpha.example.com

Host beta
 HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
 HostName 10.0.0.1
 User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
 HostName 10.0.0.1
 User admin
 ForwardAgent yes
 LocalForward 8080 localhost:80
 Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    /// Converts string literals into the owned lines `ensure_block_separators` takes.
    fn lines(s: &[&str]) -> Vec<String> {
        s.iter().map(|l| (*l).to_string()).collect()
    }

    #[test]
    fn ensure_block_separators_splits_glued_hosts() {
        let input = lines(&["Host a", " HostName 1", "Host b", " HostName 2"]);
        let out = ensure_block_separators(&input);
        assert_eq!(
            out,
            lines(&["Host a", " HostName 1", "", "Host b", " HostName 2"])
        );
    }

    #[test]
    fn ensure_block_separators_leaves_separated_hosts() {
        let input = lines(&["Host a", " HostName 1", "", "Host b", " HostName 2"]);
        let out = ensure_block_separators(&input);
        assert_eq!(out, input, "already-separated input must be untouched");
    }

    #[test]
    fn ensure_block_separators_keeps_group_header_glue() {
        // A top-level comment directly above a block stays attached to it.
        let input = lines(&["# purple:group DigitalOcean", "Host a", " HostName 1"]);
        let out = ensure_block_separators(&input);
        assert_eq!(out, input);
    }

    #[test]
    fn ensure_block_separators_splits_after_indented_comment() {
        // An indented comment belongs to the previous block, so it does not
        // count as a header and a separator is still inserted.
        let input = lines(&["Host a", " # note", "Host b"]);
        let out = ensure_block_separators(&input);
        assert_eq!(out, lines(&["Host a", " # note", "", "Host b"]));
    }

    #[test]
    fn ensure_block_separators_empty_input() {
        let out = ensure_block_separators(&[]);
        assert!(out.is_empty());
    }

    #[test]
    fn ensure_block_separators_splits_three_glued_hosts() {
        let input = lines(&["Host a", " HostName 1", "Host b", " HostName 2", "Host c"]);
        let out = ensure_block_separators(&input);
        assert_eq!(
            out,
            lines(&[
                "Host a",
                " HostName 1",
                "",
                "Host b",
                " HostName 2",
                "",
                "Host c",
            ])
        );
    }

    #[test]
    fn ensure_block_separators_splits_glued_match_block() {
        let input = lines(&["Host a", " HostName 1", "Match host b", " User x"]);
        let out = ensure_block_separators(&input);
        assert_eq!(
            out,
            lines(&["Host a", " HostName 1", "", "Match host b", " User x"])
        );
    }

    #[test]
    fn ensure_block_separators_no_leading_blank() {
        let input = lines(&["Host a", " HostName 1"]);
        let out = ensure_block_separators(&input);
        assert_eq!(out, input, "must not insert a blank before the first block");
    }

    #[test]
    fn write_normalization_is_idempotent() {
        // Mixed-case keywords and a Match block, all glued together.
        let glued = "Host a\n HostName 1\nhost b\n HostName 2\nMatch host c\n User x\n";
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("config");
        let config = SshConfigFile {
            elements: SshConfigFile::parse_content(glued),
            path: path.clone(),
            crlf: false,
            bom: false,
        };
        config.write().unwrap();
        let first = std::fs::read_to_string(&path).unwrap();

        let reparsed = SshConfigFile {
            elements: SshConfigFile::parse_content(&first),
            path: path.clone(),
            crlf: false,
            bom: false,
        };
        reparsed.write().unwrap();
        let second = std::fs::read_to_string(&path).unwrap();
        assert_eq!(first, second, "write normalization must be idempotent");
        assert!(!first.contains("\n\n\n"), "no triple blanks:\n{first}");
    }

    #[test]
    fn ensure_block_separators_case_insensitive_keyword() {
        let input = lines(&["host a", " HostName 1", "MATCH host b", " User x"]);
        let out = ensure_block_separators(&input);
        assert_eq!(
            out,
            lines(&["host a", " HostName 1", "", "MATCH host b", " User x"])
        );
    }

    #[test]
    fn write_normalizes_glued_hosts_on_disk_serialize_stays_pure() {
        let glued = "Host a\n HostName 1.1.1.1\nHost b\n HostName 2.2.2.2\n";
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("config");
        let config = SshConfigFile {
            elements: SshConfigFile::parse_content(glued),
            path: path.clone(),
            crlf: false,
            bom: false,
        };

        // serialize() must not normalize: it reproduces the glued input.
        assert_eq!(config.serialize(), glued);

        config.write().unwrap();
        // write() inserts the missing separator on disk.
        let on_disk = std::fs::read_to_string(&path).unwrap();
        assert_eq!(
            on_disk,
            "Host a\n HostName 1.1.1.1\n\nHost b\n HostName 2.2.2.2\n"
        );
    }

    #[test]
    fn write_normalizes_glued_hosts_preserves_crlf() {
        let glued = "Host a\r\n HostName 1.1.1.1\r\nHost b\r\n HostName 2.2.2.2\r\n";
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("config");
        let config = SshConfigFile {
            elements: SshConfigFile::parse_content(glued),
            path: path.clone(),
            crlf: true,
            bom: false,
        };
        config.write().unwrap();
        let on_disk = std::fs::read_to_string(&path).unwrap();
        assert_eq!(
            on_disk,
            "Host a\r\n HostName 1.1.1.1\r\n\r\nHost b\r\n HostName 2.2.2.2\r\n"
        );
    }
}