purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5use log::{debug, error};
6
7use super::model::{ConfigElement, SshConfigFile};
8use crate::fs_util;
9
10impl SshConfigFile {
11 pub fn write(&self) -> Result<()> {
17 if crate::demo_flag::is_demo() {
18 return Ok(());
19 }
20 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
22
23 let _lock =
25 fs_util::FileLock::acquire(&target_path).context("Failed to acquire config lock")?;
26
27 if self.path.exists() {
29 self.create_backup()
30 .context("Failed to create backup of SSH config")?;
31 self.prune_backups(5).ok();
32 }
33
34 let content = self.serialize();
35
36 fs_util::atomic_write(&target_path, content.as_bytes())
37 .map_err(|err| {
38 error!(
39 "[purple] SSH config write failed: {}: {err}",
40 target_path.display()
41 );
42 err
43 })
44 .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
45
46 Ok(())
48 }
49
50 pub fn serialize(&self) -> String {
53 let mut lines = Vec::new();
54
55 for element in &self.elements {
56 match element {
57 ConfigElement::GlobalLine(line) => {
58 lines.push(line.clone());
59 }
60 ConfigElement::HostBlock(block) => {
61 lines.push(block.raw_host_line.clone());
62 for directive in &block.directives {
63 lines.push(directive.raw_line.clone());
64 }
65 }
66 ConfigElement::Include(include) => {
67 lines.push(include.raw_line.clone());
68 }
69 }
70 }
71
72 let mut collapsed = Vec::with_capacity(lines.len());
74 let mut prev_blank = false;
75 for line in lines {
76 let is_blank = line.trim().is_empty();
77 if is_blank && prev_blank {
78 continue;
79 }
80 prev_blank = is_blank;
81 collapsed.push(line);
82 }
83
84 let line_ending = if self.crlf { "\r\n" } else { "\n" };
85 let mut result = String::new();
86 if self.bom {
88 result.push('\u{FEFF}');
89 }
90 for line in &collapsed {
91 result.push_str(line);
92 result.push_str(line_ending);
93 }
94 if collapsed.is_empty() {
97 result.push_str(line_ending);
98 }
99 result
100 }
101
102 fn create_backup(&self) -> Result<()> {
105 let timestamp = SystemTime::now()
106 .duration_since(SystemTime::UNIX_EPOCH)
107 .unwrap_or_default()
108 .as_millis();
109 let backup_name = format!(
110 "{}.bak.{}",
111 self.path.file_name().unwrap_or_default().to_string_lossy(),
112 timestamp
113 );
114 let backup_path = self.path.with_file_name(backup_name);
115 fs::copy(&self.path, &backup_path).with_context(|| {
116 format!(
117 "Failed to copy {} to {}",
118 self.path.display(),
119 backup_path.display()
120 )
121 })?;
122
123 #[cfg(unix)]
125 {
126 use std::os::unix::fs::PermissionsExt;
127 if let Err(e) = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600)) {
128 debug!(
129 "[config] Failed to set backup permissions on {}: {e}",
130 backup_path.display()
131 );
132 }
133 }
134
135 Ok(())
136 }
137
138 fn prune_backups(&self, keep: usize) -> Result<()> {
140 let parent = self.path.parent().context("No parent directory")?;
141 let prefix = format!(
142 "{}.bak.",
143 self.path.file_name().unwrap_or_default().to_string_lossy()
144 );
145 let mut backups: Vec<_> = fs::read_dir(parent)?
146 .filter_map(|e| e.ok())
147 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
148 .collect();
149 backups.sort_by_key(|e| e.file_name());
150 if backups.len() > keep {
151 for old in &backups[..backups.len() - keep] {
152 if let Err(e) = fs::remove_file(old.path()) {
153 debug!(
154 "[config] Failed to prune old backup {}: {e}",
155 old.path().display()
156 );
157 }
158 }
159 }
160 Ok(())
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content` without touching the filesystem.
    ///
    /// The path is a fixed placeholder. CRLF mode is detected from the content, and no BOM
    /// is assumed.
    fn parse_str(content: &str) -> SshConfigFile {
        let crlf = content.contains("\r\n");
        SshConfigFile {
            path: PathBuf::from("/tmp/test_config"),
            elements: SshConfigFile::parse_content(content),
            crlf,
            bom: false,
        }
    }

    /// Parse `content`, serialize it again, and require byte-for-byte equality.
    #[track_caller]
    fn assert_round_trip(content: &str) {
        let reserialized = parse_str(content).serialize();
        assert_eq!(reserialized, content);
    }

    /// Build a host entry with the given alias, hostname and user, and port 22.
    /// Every other field takes its default value.
    fn entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trip(
            "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
",
        );
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trip(
            "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
",
        );
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trip(
            "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
",
        );
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&entry("newhost", "10.0.0.2", "admin"));

        let serialized = config.serialize();
        for expected in ["Host newhost", "HostName 10.0.0.2", "User admin"] {
            assert!(serialized.contains(expected));
        }
        // Port 22 is the default, so it is not expected in the output.
        assert!(!serialized.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut config = parse_str(
            "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
",
        );
        config.delete_host("alpha");

        let serialized = config.serialize();
        assert!(!serialized.contains("Host alpha"));
        assert!(serialized.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut config = parse_str(
            "\
Host myserver
  HostName 10.0.0.1
  User old_user
",
        );
        config.update_host("myserver", &entry("myserver", "10.0.0.2", "new_user"));

        let serialized = config.serialize();
        assert!(serialized.contains("HostName 10.0.0.2"));
        assert!(serialized.contains("User new_user"));
        assert!(!serialized.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut config = parse_str(
            "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
",
        );
        config.update_host("myserver", &entry("myserver", "10.0.0.2", "admin"));

        let serialized = config.serialize();
        let expected = [
            "HostName 10.0.0.2",
            "ForwardAgent yes",
            "LocalForward 8080 localhost:80",
            "Compression yes",
        ];
        for needle in expected {
            assert!(serialized.contains(needle));
        }
    }
}