purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9 pub fn write(&self) -> Result<()> {
12 if self.path.exists() {
14 self.create_backup()
15 .context("Failed to create backup of SSH config")?;
16 self.prune_backups(5).ok();
17 }
18
19 let content = self.serialize();
20
21 if let Some(parent) = self.path.parent() {
23 fs::create_dir_all(parent)
24 .with_context(|| format!("Failed to create directory {}", parent.display()))?;
25 }
26
27 let tmp_path = self.path.with_extension(format!("purple_tmp.{}", std::process::id()));
29
30 #[cfg(unix)]
31 {
32 use std::io::Write;
33 use std::os::unix::fs::OpenOptionsExt;
34 let mut file = fs::OpenOptions::new()
35 .write(true)
36 .create(true)
37 .truncate(true)
38 .mode(0o600)
39 .open(&tmp_path)
40 .with_context(|| format!("Failed to create temp file {}", tmp_path.display()))?;
41 file.write_all(content.as_bytes())
42 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
43 }
44
45 #[cfg(not(unix))]
46 fs::write(&tmp_path, &content)
47 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
48
49 fs::rename(&tmp_path, &self.path).with_context(|| {
50 format!(
51 "Failed to rename {} to {}",
52 tmp_path.display(),
53 self.path.display()
54 )
55 })?;
56
57 Ok(())
58 }
59
60 pub fn serialize(&self) -> String {
62 let mut lines = Vec::new();
63
64 for element in &self.elements {
65 match element {
66 ConfigElement::GlobalLine(line) => {
67 lines.push(line.clone());
68 }
69 ConfigElement::HostBlock(block) => {
70 lines.push(block.raw_host_line.clone());
71 for directive in &block.directives {
72 lines.push(directive.raw_line.clone());
73 }
74 }
75 ConfigElement::Include(include) => {
76 lines.push(include.raw_line.clone());
77 }
78 }
79 }
80
81 let mut result = lines.join("\n");
82 if !result.ends_with('\n') {
84 result.push('\n');
85 }
86 result
87 }
88
89 fn create_backup(&self) -> Result<()> {
91 let timestamp = SystemTime::now()
92 .duration_since(SystemTime::UNIX_EPOCH)
93 .unwrap_or_default()
94 .as_millis();
95 let backup_name = format!(
96 "{}.bak.{}",
97 self.path.file_name().unwrap_or_default().to_string_lossy(),
98 timestamp
99 );
100 let backup_path = self.path.with_file_name(backup_name);
101 fs::copy(&self.path, &backup_path).with_context(|| {
102 format!(
103 "Failed to copy {} to {}",
104 self.path.display(),
105 backup_path.display()
106 )
107 })?;
108 Ok(())
109 }
110
111 fn prune_backups(&self, keep: usize) -> Result<()> {
113 let parent = self.path.parent().context("No parent directory")?;
114 let prefix = format!(
115 "{}.bak.",
116 self.path.file_name().unwrap_or_default().to_string_lossy()
117 );
118 let mut backups: Vec<_> = fs::read_dir(parent)?
119 .filter_map(|e| e.ok())
120 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
121 .collect();
122 backups.sort_by_key(|e| e.file_name());
123 if backups.len() > keep {
124 for old in &backups[..backups.len() - keep] {
125 let _ = fs::remove_file(old.path());
126 }
127 }
128 Ok(())
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Build an in-memory config from `content`. The path is a placeholder;
    /// these tests only call `serialize`, never `write`.
    fn parse_str(content: &str) -> SshConfigFile {
        let elements = SshConfigFile::parse_content(content);
        SshConfigFile {
            elements,
            path: PathBuf::from("/tmp/test_config"),
        }
    }

    /// Assert that parsing and re-serializing `content` reproduces it exactly.
    fn assert_round_trip(content: &str) {
        let rendered = parse_str(content).serialize();
        assert_eq!(rendered, content);
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trip(
            "\
Host myserver
  HostName 192.168.1.10
  User admin
  Port 2222
",
        );
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trip(
            "\
# My SSH config
# Generated by hand

Host alpha
  HostName alpha.example.com
  # Deploy user
  User deploy

Host beta
  HostName beta.example.com
  User root
",
        );
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trip(
            "\
# Global settings
Host *
  ServerAliveInterval 60
  ServerAliveCountMax 3

Host production
  HostName prod.example.com
  User deployer
  IdentityFile ~/.ssh/prod_key
",
        );
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n  HostName 10.0.0.1\n");
        let entry = HostEntry {
            alias: "newhost".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        };
        config.add_host(&entry);

        let rendered = config.serialize();
        for expected in ["Host newhost", "HostName 10.0.0.2", "User admin"] {
            assert!(rendered.contains(expected));
        }
        // Port 22 (the SSH default) is expected to be omitted from the output.
        assert!(!rendered.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut config = parse_str(
            "\
Host alpha
  HostName alpha.example.com

Host beta
  HostName beta.example.com
",
        );
        config.delete_host("alpha");

        let rendered = config.serialize();
        assert!(!rendered.contains("Host alpha"));
        assert!(rendered.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut config = parse_str(
            "\
Host myserver
  HostName 10.0.0.1
  User old_user
",
        );
        let replacement = HostEntry {
            alias: "myserver".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "new_user".to_string(),
            port: 22,
            ..Default::default()
        };
        config.update_host("myserver", &replacement);

        let rendered = config.serialize();
        for expected in ["HostName 10.0.0.2", "User new_user"] {
            assert!(rendered.contains(expected));
        }
        assert!(!rendered.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut config = parse_str(
            "\
Host myserver
  HostName 10.0.0.1
  User admin
  ForwardAgent yes
  LocalForward 8080 localhost:80
  Compression yes
",
        );
        let replacement = HostEntry {
            alias: "myserver".to_string(),
            hostname: "10.0.0.2".to_string(),
            user: "admin".to_string(),
            port: 22,
            ..Default::default()
        };
        config.update_host("myserver", &replacement);

        let rendered = config.serialize();
        // Directives the editor does not model must survive an update untouched.
        for expected in [
            "HostName 10.0.0.2",
            "ForwardAgent yes",
            "LocalForward 8080 localhost:80",
            "Compression yes",
        ] {
            assert!(rendered.contains(expected));
        }
    }
}