purple_ssh/ssh_config/
writer.rs1use std::fs;
2use std::time::SystemTime;
3
4use anyhow::{Context, Result};
5
6use super::model::{ConfigElement, SshConfigFile};
7
8impl SshConfigFile {
9 pub fn write(&self) -> Result<()> {
12 if self.path.exists() {
14 self.create_backup()
15 .context("Failed to create backup of SSH config")?;
16 self.prune_backups(5).ok();
17 }
18
19 let content = self.serialize();
20
21 if let Some(parent) = self.path.parent() {
23 fs::create_dir_all(parent)
24 .with_context(|| format!("Failed to create directory {}", parent.display()))?;
25 }
26
27 let tmp_path = self.path.with_extension(format!("purple_tmp.{}", std::process::id()));
29 fs::write(&tmp_path, &content)
30 .with_context(|| format!("Failed to write temp file {}", tmp_path.display()))?;
31
32 #[cfg(unix)]
34 {
35 use std::os::unix::fs::PermissionsExt;
36 let perms = fs::Permissions::from_mode(0o600);
37 fs::set_permissions(&tmp_path, perms)
38 .with_context(|| format!("Failed to set permissions on {}", tmp_path.display()))?;
39 }
40
41 fs::rename(&tmp_path, &self.path).with_context(|| {
42 format!(
43 "Failed to rename {} to {}",
44 tmp_path.display(),
45 self.path.display()
46 )
47 })?;
48
49 Ok(())
50 }
51
52 pub fn serialize(&self) -> String {
54 let mut lines = Vec::new();
55
56 for element in &self.elements {
57 match element {
58 ConfigElement::GlobalLine(line) => {
59 lines.push(line.clone());
60 }
61 ConfigElement::HostBlock(block) => {
62 lines.push(block.raw_host_line.clone());
63 for directive in &block.directives {
64 lines.push(directive.raw_line.clone());
65 }
66 }
67 ConfigElement::Include(include) => {
68 lines.push(include.raw_line.clone());
69 }
70 }
71 }
72
73 let mut result = lines.join("\n");
74 if !result.ends_with('\n') {
76 result.push('\n');
77 }
78 result
79 }
80
81 fn create_backup(&self) -> Result<()> {
83 let timestamp = SystemTime::now()
84 .duration_since(SystemTime::UNIX_EPOCH)
85 .unwrap_or_default()
86 .as_millis();
87 let backup_name = format!(
88 "{}.bak.{}",
89 self.path.file_name().unwrap_or_default().to_string_lossy(),
90 timestamp
91 );
92 let backup_path = self.path.with_file_name(backup_name);
93 fs::copy(&self.path, &backup_path).with_context(|| {
94 format!(
95 "Failed to copy {} to {}",
96 self.path.display(),
97 backup_path.display()
98 )
99 })?;
100 Ok(())
101 }
102
103 fn prune_backups(&self, keep: usize) -> Result<()> {
105 let parent = self.path.parent().context("No parent directory")?;
106 let prefix = format!(
107 "{}.bak.",
108 self.path.file_name().unwrap_or_default().to_string_lossy()
109 );
110 let mut backups: Vec<_> = fs::read_dir(parent)?
111 .filter_map(|e| e.ok())
112 .filter(|e| e.file_name().to_string_lossy().starts_with(&prefix))
113 .collect();
114 backups.sort_by_key(|e| e.file_name());
115 if backups.len() > keep {
116 for old in &backups[..backups.len() - keep] {
117 let _ = fs::remove_file(old.path());
118 }
119 }
120 Ok(())
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh_config::model::HostEntry;
    use std::path::PathBuf;

    /// Parses `content` into an in-memory config bound to a placeholder path.
    /// The path is never written to by these tests.
    fn parse_str(content: &str) -> SshConfigFile {
        let elements = SshConfigFile::parse_content(content);
        SshConfigFile {
            elements,
            path: PathBuf::from("/tmp/test_config"),
        }
    }

    /// Builds a host entry on the default port 22 with all other fields defaulted.
    fn host_entry(alias: &str, hostname: &str, user: &str) -> HostEntry {
        HostEntry {
            alias: alias.to_string(),
            hostname: hostname.to_string(),
            user: user.to_string(),
            port: 22,
            ..Default::default()
        }
    }

    /// Asserts that parsing and re-serializing `content` reproduces it exactly.
    fn assert_round_trip(content: &str) {
        let config = parse_str(content);
        assert_eq!(config.serialize(), content);
    }

    /// Asserts that every needle occurs somewhere in `output`.
    fn assert_contains_all(output: &str, needles: &[&str]) {
        for needle in needles {
            assert!(output.contains(*needle));
        }
    }

    #[test]
    fn test_round_trip_basic() {
        assert_round_trip(
            "\
Host myserver
 HostName 192.168.1.10
 User admin
 Port 2222
",
        );
    }

    #[test]
    fn test_round_trip_with_comments() {
        assert_round_trip(
            "\
# My SSH config
# Generated by hand

Host alpha
 HostName alpha.example.com
 # Deploy user
 User deploy

Host beta
 HostName beta.example.com
 User root
",
        );
    }

    #[test]
    fn test_round_trip_with_globals_and_wildcards() {
        assert_round_trip(
            "\
# Global settings
Host *
 ServerAliveInterval 60
 ServerAliveCountMax 3

Host production
 HostName prod.example.com
 User deployer
 IdentityFile ~/.ssh/prod_key
",
        );
    }

    #[test]
    fn test_add_host_serializes() {
        let mut config = parse_str("Host existing\n HostName 10.0.0.1\n");
        let new_host = host_entry("newhost", "10.0.0.2", "admin");
        config.add_host(&new_host);

        let output = config.serialize();
        assert_contains_all(
            &output,
            &["Host newhost", "HostName 10.0.0.2", "User admin"],
        );
        // The default port must not be written out explicitly.
        assert!(!output.contains("Port 22"));
    }

    #[test]
    fn test_delete_host_serializes() {
        let mut config = parse_str(
            "\
Host alpha
 HostName alpha.example.com

Host beta
 HostName beta.example.com
",
        );
        config.delete_host("alpha");

        let output = config.serialize();
        assert!(!output.contains("Host alpha"));
        assert!(output.contains("Host beta"));
    }

    #[test]
    fn test_update_host_serializes() {
        let mut config = parse_str(
            "\
Host myserver
 HostName 10.0.0.1
 User old_user
",
        );
        let updated = host_entry("myserver", "10.0.0.2", "new_user");
        config.update_host("myserver", &updated);

        let output = config.serialize();
        assert_contains_all(&output, &["HostName 10.0.0.2", "User new_user"]);
        assert!(!output.contains("old_user"));
    }

    #[test]
    fn test_update_host_preserves_unknown_directives() {
        let mut config = parse_str(
            "\
Host myserver
 HostName 10.0.0.1
 User admin
 ForwardAgent yes
 LocalForward 8080 localhost:80
 Compression yes
",
        );
        let updated = host_entry("myserver", "10.0.0.2", "admin");
        config.update_host("myserver", &updated);

        let output = config.serialize();
        assert_contains_all(
            &output,
            &[
                "HostName 10.0.0.2",
                "ForwardAgent yes",
                "LocalForward 8080 localhost:80",
                "Compression yes",
            ],
        );
    }
}