//! Serializing an SSH config and writing it back to disk (`purple_ssh/ssh_config/writer.rs`).
use std::fs;
use std::path::Path;
use std::time::SystemTime;

use anyhow::{Context, Result};
use log::{error, warn};

use super::model::{ConfigElement, SshConfigFile};
use crate::fs_util;
9
10impl SshConfigFile {
11 pub fn write(&self) -> Result<()> {
17 if crate::demo_flag::is_demo() {
18 return Ok(());
19 }
20 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
22
23 let _lock =
25 fs_util::FileLock::acquire(&target_path).context("Failed to acquire config lock")?;
26
27 if self.path.exists() {
29 self.create_backup()
30 .context("Failed to create backup of SSH config")?;
31 self.prune_backups(5).ok();
32 }
33
34 let content = self.serialize();
35
36 fs_util::atomic_write(&target_path, content.as_bytes())
37 .map_err(|err| {
38 error!(
39 "[purple] SSH config write failed: {}: {err}",
40 target_path.display()
41 );
42 err
43 })
44 .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
45
46 Ok(())
48 }
49
50 pub fn serialize(&self) -> String {
53 let mut lines = Vec::new();
54
55 for element in &self.elements {
56 match element {
57 ConfigElement::GlobalLine(line) => {
58 lines.push(line.clone());
59 }
60 ConfigElement::HostBlock(block) => {
61 lines.push(block.raw_host_line.clone());
62 for directive in &block.directives {
63 lines.push(directive.raw_line.clone());
64 }
65 }
66 ConfigElement::Include(include) => {
67 lines.push(include.raw_line.clone());
68 }
69 }
70 }
71
72 let mut collapsed = Vec::with_capacity(lines.len());
74 let mut prev_blank = false;
75 for line in lines {
76 let is_blank = line.trim().is_empty();
77 if is_blank && prev_blank {
78 continue;
79 }
80 prev_blank = is_blank;
81 collapsed.push(line);
82 }
83
84 let line_ending = if self.crlf { "\r\n" } else { "\n" };
85 let mut result = String::new();
86 if self.bom {
88 result.push('\u{FEFF}');
89 }
90 for line in &collapsed {
91 result.push_str(line);
92 result.push_str(line_ending);
93 }
94 if collapsed.is_empty() {
97 result.push_str(line_ending);
98 }
99 result
100 }
101
102 fn create_backup(&self) -> Result<()> {
105 let timestamp = SystemTime::now()
106 .duration_since(SystemTime::UNIX_EPOCH)
107 .unwrap_or_default()
108 .as_millis();
109 let backup_name = format!(
110 "{}.bak.{}",
111 self.path.file_name().unwrap_or_default().to_string_lossy(),
112 timestamp
113 );
114 let backup_path = self.path.with_file_name(backup_name);
115 fs::copy(&self.path, &backup_path).with_context(|| {
116 format!(
117 "Failed to copy {} to {}",
118 self.path.display(),
119 backup_path.display()
120 )
121 })?;
122
123 #[cfg(unix)]
125 {
126 use std::os::unix::fs::PermissionsExt;
127 let _ = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600));
128 }
129
130 Ok(())
131 }
132
133 fn prune_backups(&self, keep: usize) -> Result<()> {
135 let parent = self.path.parent().context("No parent directory")?;
136 let prefix = format!(
137 "{}.bak.",
138 self.path.file_name().unwrap_or_default().to_string_lossy()
139 );
140 let mut backups: Vec<_> = fs::read_dir(parent)?
141 .filter_map(|e| e.ok())
142 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
143 .collect();
144 backups.sort_by_key(|e| e.file_name());
145 if backups.len() > keep {
146 for old in &backups[..backups.len() - keep] {
147 let _ = fs::remove_file(old.path());
148 }
149 }
150 Ok(())
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Parses `content` into an in-memory config; CRLF mode is inferred from the input.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
            crlf: content.contains("\r\n"),
            bom: false,
        }
    }

    /// Builds a config directly from elements, bypassing the parser.
    fn file_with(elements: Vec<ConfigElement>, crlf: bool, bom: bool) -> SshConfigFile {
        SshConfigFile {
            elements,
            path: PathBuf::from("/tmp/test_config"),
            crlf,
            bom,
        }
    }

    /// Wraps each string in a `ConfigElement::GlobalLine`.
    fn global_lines(lines: &[&str]) -> Vec<ConfigElement> {
        lines
            .iter()
            .map(|line| ConfigElement::GlobalLine(line.to_string()))
            .collect()
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
    HostName 192.168.1.10
    User admin
    Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
    HostName alpha.example.com
    # Deploy user
    User deploy

Host beta
    HostName beta.example.com
    User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
    ServerAliveInterval 60
    ServerAliveCountMax 3

Host production
    HostName prod.example.com
    User deployer
    IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n    HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // The default SSH port is expected to be omitted from the written block.
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
    HostName alpha.example.com

Host beta
    HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
    HostName 10.0.0.1
    User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
    HostName 10.0.0.1
    User admin
    ForwardAgent yes
    LocalForward 8080 localhost:80
    Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    #[test]
    fn test_serialize_empty_config_emits_single_line_ending() {
        assert_eq!(file_with(Vec::new(), false, false).serialize(), "\n");
        assert_eq!(file_with(Vec::new(), true, false).serialize(), "\r\n");
    }

    #[test]
    fn test_serialize_collapses_consecutive_blank_lines() {
        let config = file_with(global_lines(&["# a", "", "   ", "", "# b"]), false, false);
        assert_eq!(config.serialize(), "# a\n\n# b\n");
    }

    #[test]
    fn test_serialize_uses_crlf_and_bom() {
        let config = file_with(global_lines(&["# a", "# b"]), true, true);
        assert_eq!(config.serialize(), "\u{FEFF}# a\r\n# b\r\n");
    }
}