purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9 pub fn write(&self) -> Result<()> {
13 let target_path = fs::canonicalize(&self.path).unwrap_or_else(|_| self.path.clone());
15
16 if self.path.exists() {
18 self.create_backup()
19 .context("Failed to create backup of SSH config")?;
20 self.prune_backups(5).ok();
21 }
22
23 let content = self.serialize();
24
25 if let Some(parent) = target_path.parent() {
27 fs::create_dir_all(parent)
28 .with_context(|| format!("Failed to create directory {}", parent.display()))?;
29 }
30
31 let tmp_path =
33 target_path.with_extension(format!("purple_tmp.{}", std::process::id()));
34
35 #[cfg(unix)]
36 {
37 use std::io::Write;
38 use std::os::unix::fs::OpenOptionsExt;
39 let mut file = fs::OpenOptions::new()
40 .write(true)
41 .create(true)
42 .truncate(true)
43 .mode(0o600)
44 .open(&tmp_path)
45 .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
46 file.write_all(content.as_bytes())
47 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48 }
49
50 #[cfg(not(unix))]
51 fs::write(&tmp_path, &content)
52 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
53
54 let result = fs::rename(&tmp_path, &target_path);
55 if result.is_err() {
56 let _ = fs::remove_file(&tmp_path);
57 }
58 result.with_context(|| {
59 format!(
60 "Failed to rename {} to {}",
61 tmp_path.display(),
62 target_path.display()
63 )
64 })?;
65
66 Ok(())
67 }
68
69 pub fn serialize(&self) -> String {
71 let mut lines = Vec::new();
72
73 for element in &self.elements {
74 match element {
75 ConfigElement::GlobalLine(line) => {
76 lines.push(line.clone());
77 }
78 ConfigElement::HostBlock(block) => {
79 lines.push(block.raw_host_line.clone());
80 for directive in &block.directives {
81 lines.push(directive.raw_line.clone());
82 }
83 }
84 ConfigElement::Include(include) => {
85 lines.push(include.raw_line.clone());
86 }
87 }
88 }
89
90 let line_ending = if self.crlf { "\r\n" } else { "\n" };
91 let mut result = lines.join(line_ending);
92 if !result.ends_with('\n') {
94 result.push_str(line_ending);
95 }
96 result
97 }
98
99 fn create_backup(&self) -> Result<()> {
101 let timestamp = SystemTime::now()
102 .duration_since(SystemTime::UNIX_EPOCH)
103 .unwrap_or_default()
104 .as_millis();
105 let backup_name = format!(
106 "{}.bak.{}",
107 self.path.file_name().unwrap_or_default().to_string_lossy(),
108 timestamp
109 );
110 let backup_path = self.path.with_file_name(backup_name);
111 fs::copy(&self.path, &backup_path).with_context(|| {
112 format!(
113 "Failed to copy {} to {}",
114 self.path.display(),
115 backup_path.display()
116 )
117 })?;
118 Ok(())
119 }
120
121 fn prune_backups(&self, keep: usize) -> Result<()> {
123 let parent = self.path.parent().context("No parent directory")?;
124 let prefix = format!(
125 "{}.bak.",
126 self.path.file_name().unwrap_or_default().to_string_lossy()
127 );
128 let mut backups: Vec<_> = fs::read_dir(parent)?
129 .filter_map(|e| e.ok())
130 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
131 .collect();
132 backups.sort_by_key(|e| e.file_name());
133 if backups.len() > keep {
134 for old in &backups[..backups.len() - keep] {
135 let _ = fs::remove_file(old.path());
136 }
137 }
138 Ok(())
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Builds an in-memory config from `content`; nothing is read from or written to disk.
    fn parse_str(content: &str) -> SshConfigFile {
        SshConfigFile {
            elements: SshConfigFile::parse_content(content),
            path: PathBuf::from("/tmp/test_config"),
            crlf: content.contains("\r\n"),
        }
    }

    /// Asserts that parsing and then re-serializing `content` reproduces it exactly.
    fn assert_round_trip(content: &str) {
        assert_eq!(parse_str(content).serialize(), content);
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trip(
            "\
Host myserver
    HostName 192.168.1.10
    User admin
    Port 2222
",
        );
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trip(
            "\
# My SSH config
# Generated by hand

Host alpha
    HostName alpha.example.com
    # Deploy user
    User deploy

Host beta
    HostName beta.example.com
    User root
",
        );
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trip(
            "\
# Global settings
Host *
    ServerAliveInterval 60
    ServerAliveCountMax 3

Host production
    HostName prod.example.com
    User deployer
    IdentityFile ~/.ssh/prod_key
",
        );
    }

    #[test]
    fn test_serialize_crlf_line_endings() {
        let mut config = parse_str("Host alpha\n    HostName alpha.example.com\n");
        config.crlf = true;
        assert_eq!(
            config.serialize(),
            "Host alpha\r\n    HostName alpha.example.com\r\n"
        );
    }

    #[test]
    fn test_serialize_adds_missing_trailing_newline() {
        let config = parse_str("Host alpha\n    HostName alpha.example.com");
        assert_eq!(
            config.serialize(),
            "Host alpha\n    HostName alpha.example.com\n"
        );
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n    HostName 10.0.0.1\n");
        config.add_host(&HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        });
        let output = config.serialize();
        assert!(output.contains("Host newhost"));
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User admin"));
        // Port 22 is the default and should not be written out.
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut config = parse_str(
            "\
Host alpha
    HostName alpha.example.com

Host beta
    HostName beta.example.com
",
        );
        config.delete_host("alpha");
        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut config = parse_str(
            "\
Host myserver
    HostName 10.0.0.1
    User old_user
",
        );
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "new_user".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("User new_user"));
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut config = parse_str(
            "\
Host myserver
    HostName 10.0.0.1
    User admin
    ForwardAgent yes
    LocalForward 8080 localhost:80
    Compression yes
",
        );
        config.update_host(
            "myserver",
            &HostEntry {
                alias: "myserver".to_string(),
                hostname: "10.0.0.2".to_string(),
                user: "admin".to_string(),
                port: 22,
                ..Default::default()
            },
        );
        let output = config.serialize();
        assert!(output.contains("HostName 10.0.0.2"));
        assert!(output.contains("ForwardAgent yes"));
        assert!(output.contains("LocalForward 8080 localhost:80"));
        assert!(output.contains("Compression yes"));
    }
}
289}