purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7use crate::fs_util;
8
9impl SshConfigFile {
10 pub fn write(&self) -> Result<()> {
16 if crate::demo_flag::is_demo() {
17 return Ok(());
18 }
19 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
21
22 let _lock =
24 fs_util::FileLock::acquire(&target_path).context("Failed to acquire config lock")?;
25
26 if self.path.exists() {
28 self.create_backup()
29 .context("Failed to create backup of SSH config")?;
30 self.prune_backups(5).ok();
31 }
32
33 let content = self.serialize();
34
35 fs_util::atomic_write(&target_path, content.as_bytes())
36 .with_context(|| format!("Failed to write SSH config to {}", target_path.display()))?;
37
38 Ok(())
40 }
41
42 pub fn serialize(&self) -> String {
45 let mut lines = Vec::new();
46
47 for element in &self.elements {
48 match element {
49 ConfigElement::GlobalLine(line) => {
50 lines.push(line.clone());
51 }
52 ConfigElement::HostBlock(block) => {
53 lines.push(block.raw_host_line.clone());
54 for directive in &block.directives {
55 lines.push(directive.raw_line.clone());
56 }
57 }
58 ConfigElement::Include(include) => {
59 lines.push(include.raw_line.clone());
60 }
61 }
62 }
63
64 let mut collapsed = Vec::with_capacity(lines.len());
66 let mut prev_blank = false;
67 for line in lines {
68 let is_blank = line.trim().is_empty();
69 if is_blank && prev_blank {
70 continue;
71 }
72 prev_blank = is_blank;
73 collapsed.push(line);
74 }
75
76 let line_ending = if self.crlf { "\r\n" } else { "\n" };
77 let mut result = String::new();
78 if self.bom {
80 result.push('\u{FEFF}');
81 }
82 for line in &collapsed {
83 result.push_str(line);
84 result.push_str(line_ending);
85 }
86 if collapsed.is_empty() {
89 result.push_str(line_ending);
90 }
91 result
92 }
93
94 fn create_backup(&self) -> Result<()> {
97 let timestamp = SystemTime::now()
98 .duration_since(SystemTime::UNIX_EPOCH)
99 .unwrap_or_default()
100 .as_millis();
101 let backup_name = format!(
102 "{}.bak.{}",
103 self.path.file_name().unwrap_or_default().to_string_lossy(),
104 timestamp
105 );
106 let backup_path = self.path.with_file_name(backup_name);
107 fs::copy(&self.path, &backup_path).with_context(|| {
108 format!(
109 "Failed to copy {} to {}",
110 self.path.display(),
111 backup_path.display()
112 )
113 })?;
114
115 #[cfg(unix)]
117 {
118 use std::os::unix::fs::PermissionsExt;
119 let _ = fs::set_permissions(&backup_path, fs::Permissions::from_mode(0o600));
120 }
121
122 Ok(())
123 }
124
125 fn prune_backups(&self, keep: usize) -> Result<()> {
127 let parent = self.path.parent().context("No parent directory")?;
128 let prefix = format!(
129 "{}.bak.",
130 self.path.file_name().unwrap_or_default().to_string_lossy()
131 );
132 let mut backups: Vec<_> = fs::read_dir(parent)?
133 .filter_map(|e| e.ok())
134 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
135 .collect();
136 backups.sort_by_key(|e| e.file_name());
137 if backups.len() > keep {
138 for old in &backups[..backups.len() - keep] {
139 let _ = fs::remove_file(old.path());
140 }
141 }
142 Ok(())
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content`; the path is a placeholder and is never read.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
            crlf: content.contains("\r\n"),
            bom: false,
        }
    }

    #[test]
    fn test_round_trip_basic() {
        let content = "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_comments() {
        let content = "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        let content = "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
";
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let content = "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
";
        let mut config = parse_str(content);
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User old_user
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let content = "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
";
        let mut config = parse_str(content);
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }

    /// Runs of blank lines are collapsed to a single blank line on output.
    #[test]
    fn test_serialize_collapses_consecutive_blank_lines() {
        let config = parse_str("Host alpha\n  HostName a.example.com\n\n\n\nHost beta\n  HostName b.example.com\n");
        assert_eq!(
            config.serialize(),
            "Host alpha\n  HostName a.example.com\n\nHost beta\n  HostName b.example.com\n"
        );
    }

    /// The `crlf` flag switches every line terminator to `\r\n`.
    #[test]
    fn test_serialize_uses_crlf_when_flagged() {
        let mut config = parse_str("Host alpha\n  HostName a.example.com\n");
        config.crlf = true;
        assert_eq!(
            config.serialize(),
            "Host alpha\r\n  HostName a.example.com\r\n"
        );
    }

    /// The `bom` flag prepends a U+FEFF byte-order mark.
    #[test]
    fn test_serialize_prepends_bom_when_flagged() {
        let mut config = parse_str("Host alpha\n");
        config.bom = true;
        assert_eq!(config.serialize(), "\u{FEFF}Host alpha\n");
    }

    /// An empty config still serializes to a single line ending.
    #[test]
    fn test_serialize_empty_config_is_single_line_ending() {
        let config = parse_str("");
        assert_eq!(config.serialize(), "\n");
    }
}