purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11 pub fn write(&self) -> Result<()> {
17 if crate::demo_flag::is_demo() {
18 return Ok(());
19 }
20 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
22
23 let _lock =
25 fs_util::FileLock::acquire(&target_path).context("Failed to acquire config lock")?;
26
27 if self.path.exists() {
29 self.create_backup()
30 .context("Failed to create backup of SSH config")?;
31 self.prune_backups(5).ok();
32 }
33
34 let content = self.serialize();
35
36 fs_util::atomic_write(&target_path, content.as_bytes())
37 .map_err(|err| {
38 error!(
39 "[purple] SSH config write failed: {}: {err}",
40 target_path.display()
41 );
42 err
43 })
44 .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
45
46 Ok(())
48 }
49
50 pub fn serialize(&self) -> String {
53 let mut lines = Vec::new();
54
55 for element in &self.elements {
56 match element {
57 ConfigElement::GlobalLine(line) => {
58 lines.push(line.clone());
59 }
60 ConfigElement::HostBlock(block) => {
61 lines.push(block.raw_host_line.clone());
62 for directive in &block.directives {
63 lines.push(directive.raw_line.clone());
64 }
65 }
66 ConfigElement::Include(include) => {
67 lines.push(include.raw_line.clone());
68 }
69 }
70 }
71
72 let mut collapsed = Vec::with_capacity(lines.len());
74 let mut prev_blank = false;
75 for line in lines {
76 let is_blank = line.trim().is_empty();
77 if is_blank && prev_blank {
78 continue;
79 }
80 prev_blank = is_blank;
81 collapsed.push(line);
82 }
83
84 let line_ending = if self.crlf { "\r\n" } else { "\n" };
85 let mut result = String::new();
86 if self.bom {
88 result.push('\u{FEFF}');
89 }
90 for line in &collapsed {
91 result.push_str(line);
92 result.push_str(line_ending);
93 }
94 if collapsed.is_empty() {
97 result.push_str(line_ending);
98 }
99 result
100 }
101
102 fn create_backup(&self) -> Result<()> {
105 let timestamp = SystemTime::now()
106 .duration_since(SystemTime::UNIX_EPOCH)
107 .unwrap_or_default()
108 .as_millis();
109 let backup_name = format!(
110 "{}.bak.{}",
111 self.path.file_name().unwrap_or_default().to_string_lossy(),
112 timestamp
113 );
114 let backup_path = self.path.with_file_name(backup_name);
115 fs::copy(&self.path, &backup_path).with_context(|| {
116 format!(
117 "Failed to copy {} to {}",
118 self.path.display(),
119 backup_path.display()
120 )
121 })?;
122
123 #[cfg(unix)]
125 {
126 use std::os::unix::fs::PermissionsExt;
127 if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
128 debug!(
129 "[config] Failed to set backup permissions on {}: {e}",
130 backup_path.display()
131 );
132 }
133 }
134
135 Ok(())
136 }
137
138 fn prune_backups(&self, keep: usize) -> Result<()> {
140 let parent = self.path.parent().context("No parent directory")?;
141 let prefix = format!(
142 "{}.bak.",
143 self.path.file_name().unwrap_or_default().to_string_lossy()
144 );
145 let mut backups: Vec<_> = fs::read_dir(parent)?
146 .filter_map(|e| e.ok())
147 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
148 .collect();
149 backups.sort_by_key(|e| e.file_name());
150 if backups.len() > keep {
151 for old in &backups[..backups.len() - keep] {
152 if let Err(e) = fs::remove_file(old.path()) {
153 debug!(
154 "[config] Failed to prune old backup {}: {e}",
155 old.path().display()
156 );
157 }
158 }
159 }
160 Ok(())
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Path given to in-memory test configs.
    ///
    /// These tests only call `serialize`, which never reads `path`, so the
    /// path does not need to exist. A fixed, non-existent location avoids
    /// leaking a kept temp directory on every call.
    fn unused_path() -> PathBuf {
        PathBuf::from("/nonexistent/purple_ssh_writer_tests/test_config")
    }

    /// Parses `content` into a config. CRLF mode is detected from the content;
    /// BOM is always off.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: unused_path(),
            crlf: content.contains("\r\n"),
            bom: false,
        }
    }

    /// Builds a config made only of global lines, to exercise `serialize`'s
    /// formatting rules without going through the parser.
    fn from_global_lines(lines: &[&str], crlf: bool, bom: bool) -> SshConfigFile {
        SshConfigFile {
            elements: lines
                .iter()
                .map(|line| ConfigElement::GlobalLine((*line).into()))
                .collect(),
            path: unused_path(),
            crlf,
            bom,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_serialize_collapses_blank_line_runs() {
        // The first line of each blank run survives, including a trailing run.
        let config = from_global_lines(&["# a", "", "   ", "", "# b", ""], false, false);
        assert_eq!(config.serialize(), "# a\n\n# b\n\n");
    }

    #[test]
    fn test_serialize_uses_crlf_when_requested() {
        let config = from_global_lines(&["# a", "# b"], true, false);
        assert_eq!(config.serialize(), "# a\r\n# b\r\n");
    }

    #[test]
    fn test_serialize_prepends_bom_when_requested() {
        let config = from_global_lines(&["# a"], false, true);
        assert_eq!(config.serialize(), "\u{FEFF}# a\n");
    }

    #[test]
    fn test_serialize_empty_config_is_single_line_ending() {
        assert_eq!(from_global_lines(&[], false, false).serialize(), "\n");
        assert_eq!(from_global_lines(&[], true, false).serialize(), "\r\n");
    }
}