purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11 pub fn write(&self) -> Result<()> {
17 if crate::demo_flag::is_demo() {
18 debug!("[purple] ssh_config.write skipped (demo mode)");
19 return Ok(());
20 }
21 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
23 debug!(
24 "[config] ssh_config.write: target={} elements={}",
25 target_path.display(),
26 self.elements.len()
27 );
28
29 let _lock = fs_util::FileLock::acquire(&target_path)
31 .inspect_err(|e| {
32 debug!(
33 "[config] ssh_config.write: lock acquire failed: {} ({})",
34 target_path.display(),
35 e
36 );
37 })
38 .context("Failed to acquire config lock")?;
39
40 if self.path.exists() {
42 self.create_backup()
43 .context("Failed to create backup of SSH config")?;
44 self.prune_backups(5).ok();
45 }
46
47 let content = self.serialize();
48
49 fs_util::atomic_write(&target_path, content.as_bytes())
50 .map_err(|err| {
51 error!(
52 "[purple] SSH config write failed: {}: {err}",
53 target_path.display()
54 );
55 err
56 })
57 .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
58
59 debug!(
60 "[config] ssh_config.write: wrote {} bytes to {}",
61 content.len(),
62 target_path.display()
63 );
64
65 Ok(())
67 }
68
69 pub fn serialize(&self) -> String {
72 let mut lines = Vec::new();
73
74 for element in &self.elements {
75 match element {
76 ConfigElement::GlobalLine(line) => {
77 lines.push(line.clone());
78 }
79 ConfigElement::HostBlock(block) => {
80 lines.push(block.raw_host_line.clone());
81 for directive in &block.directives {
82 lines.push(directive.raw_line.clone());
83 }
84 }
85 ConfigElement::Include(include) => {
86 lines.push(include.raw_line.clone());
87 }
88 }
89 }
90
91 let mut collapsed = Vec::with_capacity(lines.len());
93 let mut prev_blank = false;
94 for line in lines {
95 let is_blank = line.trim().is_empty();
96 if is_blank && prev_blank {
97 continue;
98 }
99 prev_blank = is_blank;
100 collapsed.push(line);
101 }
102
103 let line_ending = if self.crlf { "\r\n" } else { "\n" };
104 let mut result = String::new();
105 if self.bom {
107 result.push('\u{FEFF}');
108 }
109 for line in &collapsed {
110 result.push_str(line);
111 result.push_str(line_ending);
112 }
113 if collapsed.is_empty() {
116 result.push_str(line_ending);
117 }
118 result
119 }
120
121 fn create_backup(&self) -> Result<()> {
124 let timestamp = SystemTime::now()
125 .duration_since(SystemTime::UNIX_EPOCH)
126 .unwrap_or_default()
127 .as_millis();
128 let backup_name = format!(
129 "{}.bak.{}",
130 self.path.file_name().unwrap_or_default().to_string_lossy(),
131 timestamp
132 );
133 let backup_path = self.path.with_file_name(backup_name);
134 fs::copy(&self.path, &backup_path).with_context(|| {
135 format!(
136 "Failed to copy {} to {}",
137 self.path.display(),
138 backup_path.display()
139 )
140 })?;
141
142 #[cfg(unix)]
144 {
145 use std::os::unix::fs::PermissionsExt;
146 if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
147 debug!(
148 "[config] Failed to set backup permissions on {}: {e}",
149 backup_path.display()
150 );
151 }
152 }
153
154 Ok(())
155 }
156
157 fn prune_backups(&self, keep: usize) -> Result<()> {
159 let parent = self.path.parent().context("No parent directory")?;
160 let prefix = format!(
161 "{}.bak.",
162 self.path.file_name().unwrap_or_default().to_string_lossy()
163 );
164 let mut backups: Vec<_> = fs::read_dir(parent)?
165 .filter_map(|e| e.ok())
166 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
167 .collect();
168 backups.sort_by_key(|e| e.file_name());
169 if backups.len() > keep {
170 for old in &backups[..backups.len() - keep] {
171 if let Err(e) = fs::remove_file(old.path()) {
172 debug!(
173 "[config] Failed to prune old backup {}: {e}",
174 old.path().display()
175 );
176 }
177 }
178 }
179 Ok(())
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;

    /// Build an in-memory `SshConfigFile` from `content`.
    ///
    /// The backing path points into a fresh temp directory that is kept (not cleaned up).
    /// The tests in this module never call `write`, so nothing is written there by them.
    fn parse_str(content: &str) -> SshConfigFile {
        let dir = tempfile::tempdir().expect("tempdir").keep();
        SshConfigFile {
            path: dir.join("test_config"),
            elements: SshConfigFile::parse_content(content),
            crlf: content.contains("\r\n"),
            bom: false,
        }
    }

    /// Parse `content`, serialize it straight back, and require an exact match.
    fn assert_round_trip(content: &str) {
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    /// Shorthand for a `HostEntry` on port 22 with the given alias, hostname and user.
    fn entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    /// Assert that every fragment in `expected` appears somewhere in `output`.
    fn assert_contains_all(output: &str, expected: &[&str]) {
        for fragment in expected {
            assert!(
                output.contains(fragment),
                "expected {fragment:?} in output:\n{output}"
            );
        }
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trip(concat!(
            "Host myserver\n",
            "  HostName 192.168.1.10\n",
            "  User admin\n",
            "  Port 2222\n",
        ));
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trip(concat!(
            "# My SSH config\n",
            "# Generated by hand\n",
            "\n",
            "Host alpha\n",
            "  HostName alpha.example.com\n",
            "  # Deploy user\n",
            "  User deploy\n",
            "\n",
            "Host beta\n",
            "  HostName beta.example.com\n",
            "  User root\n",
        ));
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trip(concat!(
            "# Global settings\n",
            "Host *\n",
            "  ServerAliveInterval 60\n",
            "  ServerAliveCountMax 3\n",
            "\n",
            "Host production\n",
            "  HostName prod.example.com\n",
            "  User deployer\n",
            "  IdentityFile ~/.ssh/prod_key\n",
        ));
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&entry("newhost", "10.0.0.2", "admin"));

        let output = config.serialize();
        assert_contains_all(&output, &["Host newhost", "HostName 10.0.0.2", "User admin"]);
        // Port 22 is the default and is expected to be left out of the output.
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut config = parse_str(concat!(
            "Host alpha\n",
            "  HostName alpha.example.com\n",
            "\n",
            "Host beta\n",
            "  HostName beta.example.com\n",
        ));
        config.delete_host("alpha");

        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut config = parse_str(concat!(
            "Host myserver\n",
            "  HostName 10.0.0.1\n",
            "  User old_user\n",
        ));
        config.update_host("myserver", &entry("myserver", "10.0.0.2", "new_user"));

        let output = config.serialize();
        assert_contains_all(&output, &["HostName 10.0.0.2", "User new_user"]);
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut config = parse_str(concat!(
            "Host myserver\n",
            "  HostName 10.0.0.1\n",
            "  User admin\n",
            "  ForwardAgent yes\n",
            "  LocalForward 8080 localhost:80\n",
            "  Compression yes\n",
        ));
        config.update_host("myserver", &entry("myserver", "10.0.0.2", "admin"));

        // Directives that the update did not touch must survive untouched.
        assert_contains_all(
            &config.serialize(),
            &[
                "HostName 10.0.0.2",
                "ForwardAgent yes",
                "LocalForward 8080 localhost:80",
                "Compression yes",
            ],
        );
    }
}