// purple_ssh/ssh_config/writer.rs
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::time::SystemTime;

use anyhow::{Context, Result};

use super::model::{ConfigElement, SshConfigFile};

8impl SshConfigFile {
9 pub fn write(&self) -> Result<()> {
12 if self.path.exists() {
14 self.create_backup()
15 .context("Failed to create backup of SSH config")?;
16 self.prune_backups(5).ok();
17 }
18
19 let content = self.serialize();
20
21 if let Some(parent) = self.path.parent() {
23 fs::create_dir_all(parent)
24 .with_context(|| format!("Failed to create directory {}", parent.display()))?;
25 }
26
27 let tmp_path = self.path.with_extension(format!("purple_tmp.{}", std::process::id()));
29
30 #[cfg(unix)]
31 {
32 use std::io::Write;
33 use std::os::unix::fs::OpenOptionsExt;
34 let mut file = fs::OpenOptions::new()
35 .write(true)
36 .create(true)
37 .truncate(true)
38 .mode(0o600)
39 .open(&tmp_path)
40 .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
41 file.write_all(content.as_bytes())
42 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
43 }
44
45 #[cfg(not(unix))]
46 fs::write(&tmp_path, &content)
47 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48
49 let result = fs::rename(&tmp_path, &self.path);
50 if result.is_err() {
51 let _ = fs::remove_file(&tmp_path);
52 }
53 result.with_context(|| {
54 format!(
55 "Failed to rename {} to {}",
56 tmp_path.display(),
57 self.path.display()
58 )
59 })?;
60
61 Ok(())
62 }
63
64 pub fn serialize(&self) -> String {
66 let mut lines = Vec::new();
67
68 for element in &self.elements {
69 match element {
70 ConfigElement::GlobalLine(line) => {
71 lines.push(line.clone());
72 }
73 ConfigElement::HostBlock(block) => {
74 lines.push(block.raw_host_line.clone());
75 for directive in &block.directives {
76 lines.push(directive.raw_line.clone());
77 }
78 }
79 ConfigElement::Include(include) => {
80 lines.push(include.raw_line.clone());
81 }
82 }
83 }
84
85 let mut result = lines.join("\n");
86 if !result.ends_with('\n') {
88 result.push('\n');
89 }
90 result
91 }
92
93 fn create_backup(&self) -> Result<()> {
95 let timestamp = SystemTime::now()
96 .duration_since(SystemTime::UNIX_EPOCH)
97 .unwrap_or_default()
98 .as_millis();
99 let backup_name = format!(
100 "{}.bak.{}",
101 self.path.file_name().unwrap_or_default().to_string_lossy(),
102 timestamp
103 );
104 let backup_path = self.path.with_file_name(backup_name);
105 fs::copy(&self.path, &backup_path).with_context(|| {
106 format!(
107 "Failed to copy {} to {}",
108 self.path.display(),
109 backup_path.display()
110 )
111 })?;
112 Ok(())
113 }
114
115 fn prune_backups(&self, keep: usize) -> Result<()> {
117 let parent = self.path.parent().context("No parent directory")?;
118 let prefix = format!(
119 "{}.bak.",
120 self.path.file_name().unwrap_or_default().to_string_lossy()
121 );
122 let mut backups: Vec<_> = fs::read_dir(parent)?
123 .filter_map(|e| e.ok())
124 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
125 .collect();
126 backups.sort_by_key(|e| e.file_name());
127 if backups.len() > keep {
128 for old in &backups[..backups.len() - keep] {
129 let _ = fs::remove_file(old.path());
130 }
131 }
132 Ok(())
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Builds an in-memory config from `content`; none of these tests write it to disk.
    fn parse_str(content: &str) -> SshConfigFile {
        let elements = SshConfigFile::parse_content(content);
        SshConfigFile {
            path: PathBuf::from("/tmp/test_config"),
            elements,
        }
    }

    /// Asserts that parsing then serializing `content` gives back exactly `content`.
    fn assert_round_trip(content: &str) {
        let cfg = parse_str(content);
        assert_eq!(cfg.serialize(), content);
    }

    /// Builds a `HostEntry` on port 22 with the given alias, hostname and user.
    fn entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trip(
            "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
",
        );
    }

    #[test]
    fn test_round_trip_with_comments() {
        // Top-level comments, blank lines and comments inside a host block all survive.
        assert_round_trip(
            "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
",
        );
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trip(
            "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
",
        );
    }

    #[test]
    fn test_add_host_serializes() {
        let mut cfg = parse_str("Host existing\n  HostName 10.0.0.1\n");
        cfg.add_host(&entry("newhost", "10.0.0.2", "admin"));
        let rendered = cfg.serialize();

        for needle in ["Host newhost", "HostName 10.0.0.2", "User admin"] {
            assert!(rendered.contains(needle));
        }
        // A port of 22 is expected not to be written out explicitly.
        assert!(!rendered.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut cfg = parse_str(
            "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
",
        );
        cfg.delete_host("alpha");
        let rendered = cfg.serialize();

        assert!(!rendered.contains("Host alpha"));
        assert!(rendered.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut cfg = parse_str(
            "\
Host myserver
  HostName 10.0.0.1
  User old_user
",
        );
        cfg.update_host("myserver", &entry("myserver", "10.0.0.2", "new_user"));
        let rendered = cfg.serialize();

        assert!(rendered.contains("HostName 10.0.0.2"));
        assert!(rendered.contains("User new_user"));
        assert!(!rendered.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut cfg = parse_str(
            "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
",
        );
        cfg.update_host("myserver", &entry("myserver", "10.0.0.2", "admin"));
        let rendered = cfg.serialize();

        // The managed field changes; directives the model does not manage stay intact.
        let expected = [
            "HostName 10.0.0.2",
            "ForwardAgent yes",
            "LocalForward 8080 localhost:80",
            "Compression yes",
        ];
        for needle in expected {
            assert!(rendered.contains(needle));
        }
    }
}