purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11 pub fn write(&self) -> Result<()> {
17 if crate::demo_flag::is_demo() {
18 debug!("[purple] ssh_config.write skipped (demo mode)");
19 return Ok(());
20 }
21 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
23 debug!(
24 "[config] ssh_config.write: target={} elements={}",
25 target_path.display(),
26 self.elements.len()
27 );
28
29 let _lock = fs_util::FileLock::acquire(&target_path)
31 .inspect_err(|e| {
32 debug!(
33 "[config] ssh_config.write: lock acquire failed: {} ({})",
34 target_path.display(),
35 e
36 );
37 })
38 .context("Failed to acquire config lock")?;
39
40 if target_path.exists() {
47 self.create_backup(&target_path)
48 .context("Failed to create backup of SSH config")?;
49 self.prune_backups(&target_path, 5).ok();
50 }
51
52 let content = self.serialize();
53
54 fs_util::atomic_write(&target_path, content.as_bytes())
55 .map_err(|err| {
56 error!(
57 "[purple] SSH config write failed: {}: {err}",
58 target_path.display()
59 );
60 err
61 })
62 .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
63
64 debug!(
65 "[config] ssh_config.write: wrote {} bytes to {}",
66 content.len(),
67 target_path.display()
68 );
69
70 Ok(())
72 }
73
74 pub fn serialize(&self) -> String {
77 let mut lines = Vec::new();
78
79 for element in &self.elements {
80 match element {
81 ConfigElement::GlobalLine(line) => {
82 lines.push(line.clone());
83 }
84 ConfigElement::HostBlock(block) => {
85 lines.push(block.raw_host_line.clone());
86 for directive in &block.directives {
87 lines.push(directive.raw_line.clone());
88 }
89 }
90 ConfigElement::Include(include) => {
91 lines.push(include.raw_line.clone());
92 }
93 }
94 }
95
96 let mut collapsed = Vec::with_capacity(lines.len());
98 let mut prev_blank = false;
99 for line in lines {
100 let is_blank = line.trim().is_empty();
101 if is_blank && prev_blank {
102 continue;
103 }
104 prev_blank = is_blank;
105 collapsed.push(line);
106 }
107
108 let line_ending = if self.crlf { "\r\n" } else { "\n" };
109 let mut result = String::new();
110 if self.bom {
112 result.push('\u{FEFF}');
113 }
114 for line in &collapsed {
115 result.push_str(line);
116 result.push_str(line_ending);
117 }
118 if collapsed.is_empty() {
121 result.push_str(line_ending);
122 }
123 result
124 }
125
126 fn create_backup(&self, target_path: &std::path::Path) -> Result<()> {
129 let timestamp = SystemTime::now()
130 .duration_since(SystemTime::UNIX_EPOCH)
131 .unwrap_or_default()
132 .as_millis();
133 let backup_name = format!(
134 "{}.bak.{}",
135 target_path
136 .file_name()
137 .unwrap_or_default()
138 .to_string_lossy(),
139 timestamp
140 );
141 let backup_path = target_path.with_file_name(backup_name);
142 fs::copy(target_path, &backup_path).with_context(|| {
143 format!(
144 "Failed to copy {} to {}",
145 target_path.display(),
146 backup_path.display()
147 )
148 })?;
149
150 #[cfg(unix)]
152 {
153 use std::os::unix::fs::PermissionsExt;
154 if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
155 debug!(
156 "[config] Failed to set backup permissions on {}: {e}",
157 backup_path.display()
158 );
159 }
160 }
161
162 Ok(())
163 }
164
165 fn prune_backups(&self, target_path: &std::path::Path, keep: usize) -> Result<()> {
167 let parent = target_path.parent().context("No parent directory")?;
168 let prefix = format!(
169 "{}.bak.",
170 target_path
171 .file_name()
172 .unwrap_or_default()
173 .to_string_lossy()
174 );
175 let mut backups: Vec<_> = fs::read_dir(parent)?
176 .filter_map(|e| e.ok())
177 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
178 .collect();
179 backups.sort_by_key(|e| {
183 e.metadata()
184 .and_then(|m| m.modified())
185 .unwrap_or(SystemTime::UNIX_EPOCH)
186 });
187 if backups.len() > keep {
188 for old in &backups[..backups.len() - keep] {
189 if let Err(e) = fs::remove_file(old.path()) {
190 debug!(
191 "[config] Failed to prune old backup {}: {e}",
192 old.path().display()
193 );
194 }
195 }
196 }
197 Ok(())
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;

    /// Builds an in-memory config from `content`.
    ///
    /// These tests only exercise `serialize` and the in-memory mutators, which
    /// never touch the disk. The path is therefore a placeholder that is never
    /// created, so no temporary directories are left behind.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: std::env::temp_dir()
                .join("purple-writer-tests")
                .join("test_config"),
            crlf: crate::ssh_config::parser::detect_crlf_majority(content),
            bom: false,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
    HostName 192.168.1.10
    User admin
    Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
    HostName alpha.example.com
    # Deploy user
    User deploy

Host beta
    HostName beta.example.com
    User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
    ServerAliveInterval 60
    ServerAliveCountMax 3

Host production
    HostName prod.example.com
    User deployer
    IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n    HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // The default port is not written out.
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
    HostName alpha.example.com

Host beta
    HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
    HostName 10.0.0.1
    User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
    HostName 10.0.0.1
    User admin
    ForwardAgent yes
    LocalForward 8080 localhost:80
    Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_serialize_collapses_consecutive_blank_lines() {
        let config = parse_str("Host a\n    HostName 10.0.0.1\n\n\n\nHost b\n    HostName 10.0.0.2\n");
        assert_eq!(
            config.serialize(),
            "Host a\n    HostName 10.0.0.1\n\nHost b\n    HostName 10.0.0.2\n"
        );
    }

    #[test]
    fn test_serialize_empty_config_is_single_line_ending() {
        let mut config = parse_str("");
        config.crlf = false;
        assert_eq!(config.serialize(), "\n");
    }

    #[test]
    fn test_serialize_honors_bom_and_crlf() {
        let mut config = parse_str("Host a\n    User root\n");
        config.bom = true;
        config.crlf = true;
        assert_eq!(config.serialize(), "\u{FEFF}Host a\r\n    User root\r\n");
    }
}