purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7use crate::fs_util;
8
9impl SshConfigFile {
10 pub fn write(&self) -> Result<()> {
16 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
18
19 let _lock = fs_util::FileLock::acquire(&target_path)
21 .context("Failed to acquire config lock")?;
22
23 if self.path.exists() {
25 self.create_backup()
26 .context("Failed to create backup of SSH config")?;
27 self.prune_backups(5).ok();
28 }
29
30 let content = self.serialize();
31
32 fs_util::atomic_write(&target_path, content.as_bytes())
33 .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
34
35 Ok(())
37 }
38
39 pub fn serialize(&self) -> String {
42 let mut lines = Vec::new();
43
44 for element in &self.elements {
45 match element {
46 ConfigElement::GlobalLine(line) => {
47 lines.push(line.clone());
48 }
49 ConfigElement::HostBlock(block) => {
50 lines.push(block.raw_host_line.clone());
51 for directive in &block.directives {
52 lines.push(directive.raw_line.clone());
53 }
54 }
55 ConfigElement::Include(include) => {
56 lines.push(include.raw_line.clone());
57 }
58 }
59 }
60
61 let mut collapsed = Vec::with_capacity(lines.len());
63 let mut prev_blank = false;
64 for line in lines {
65 let is_blank = line.trim().is_empty();
66 if is_blank && prev_blank {
67 continue;
68 }
69 prev_blank = is_blank;
70 collapsed.push(line);
71 }
72
73 let line_ending = if self.crlf { "\r\n" } else { "\n" };
74 let mut result = String::new();
75 if self.bom {
77 result.push('\u{FEFF}');
78 }
79 for line in &collapsed {
80 result.push_str(line);
81 result.push_str(line_ending);
82 }
83 if collapsed.is_empty() {
86 result.push_str(line_ending);
87 }
88 result
89 }
90
91 fn create_backup(&self) -> Result<()> {
94 let timestamp = SystemTime::now()
95 .duration_since(SystemTime::UNIX_EPOCH)
96 .unwrap_or_default()
97 .as_millis();
98 let backup_name = format!(
99 "{}.bak.{}",
100 self.path.file_name().unwrap_or_default().to_string_lossy(),
101 timestamp
102 );
103 let backup_path = self.path.with_file_name(backup_name);
104 fs::copy(&self.path, &backup_path).with_context(|| {
105 format!(
106 "Failed to copy {} to {}",
107 self.path.display(),
108 backup_path.display()
109 )
110 })?;
111
112 #[cfg(unix)]
114 {
115 use std::os::unix::fs::PermissionsExt;
116 let _ = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600));
117 }
118
119 Ok(())
120 }
121
122 fn prune_backups(&self, keep: usize) -> Result<()> {
124 let parent = self.path.parent().context("No parent directory")?;
125 let prefix = format!(
126 "{}.bak.",
127 self.path.file_name().unwrap_or_default().to_string_lossy()
128 );
129 let mut backups: Vec<_> = fs::read_dir(parent)?
130 .filter_map(|e| e.ok())
131 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
132 .collect();
133 backups.sort_by_key(|e| e.file_name());
134 if backups.len() > keep {
135 for old in &backups[..backups.len() - keep] {
136 let _ = fs::remove_file(old.path());
137 }
138 }
139 Ok(())
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Builds an in-memory config from `content`, pointing at a fixed dummy path.
    ///
    /// CRLF mode is switched on when the text contains `\r\n`; the BOM flag is always off.
    fn parse_str(content: &str) -> SshConfigFile {
        let elements = SshConfigFile::parse_content(content);
        let crlf = content.contains("\r\n");
        SshConfigFile {
            elements,
            path: PathBuf::from("/tmp/test_config"),
            crlf,
            bom: false,
        }
    }

    /// Builds a `HostEntry` on port 22 with every remaining field left at its default.
    fn host_entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    /// Asserts that parsing `content` and serializing it again reproduces it exactly.
    fn assert_round_trip(content: &str) {
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trip(concat!(
            "Host myserver\n",
            " HostName 192.168.1.10\n",
            " User admin\n",
            " Port 2222\n",
        ));
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trip(concat!(
            "# My SSH config\n",
            "# Generated by hand\n",
            "\n",
            "Host alpha\n",
            " HostName alpha.example.com\n",
            " # Deploy user\n",
            " User deploy\n",
            "\n",
            "Host beta\n",
            " HostName beta.example.com\n",
            " User root\n",
        ));
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trip(concat!(
            "# Global settings\n",
            "Host *\n",
            " ServerAliveInterval 60\n",
            " ServerAliveCountMax 3\n",
            "\n",
            "Host production\n",
            " HostName prod.example.com\n",
            " User deployer\n",
            " IdentityFile ~/.ssh/prod_key\n",
        ));
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n HostName 10.0.0.1\n");
        config.add_host(&host_entry("newhost", "10.0.0.2", "admin"));

        let rendered = config.serialize();
        for expected in ["Host newhost", "HostName 10.0.0.2", "User admin"] {
            assert!(rendered.contains(expected));
        }
        // The entry was added with port 22; no `Port 22` directive is expected in the output.
        assert!(!rendered.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut config = parse_str(concat!(
            "Host alpha\n",
            " HostName alpha.example.com\n",
            "\n",
            "Host beta\n",
            " HostName beta.example.com\n",
        ));
        config.delete_host("alpha");

        let rendered = config.serialize();
        assert!(!rendered.contains("Host alpha"));
        assert!(rendered.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut config = parse_str(concat!(
            "Host myserver\n",
            " HostName 10.0.0.1\n",
            " User old_user\n",
        ));
        config.update_host("myserver", &host_entry("myserver", "10.0.0.2", "new_user"));

        let rendered = config.serialize();
        for expected in ["HostName 10.0.0.2", "User new_user"] {
            assert!(rendered.contains(expected));
        }
        assert!(!rendered.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut config = parse_str(concat!(
            "Host myserver\n",
            " HostName 10.0.0.1\n",
            " User admin\n",
            " ForwardAgent yes\n",
            " LocalForward 8080 localhost:80\n",
            " Compression yes\n",
        ));
        config.update_host("myserver", &host_entry("myserver", "10.0.0.2", "admin"));

        let rendered = config.serialize();
        // The updated hostname plus every directive the update did not touch must be present.
        for expected in [
            "HostName 10.0.0.2",
            "ForwardAgent yes",
            "LocalForward 8080 localhost:80",
            "Compression yes",
        ] {
            assert!(rendered.contains(expected));
        }
    }
}