1use serde::{Deserialize, Serialize};
6use tracing::info;
7
8pub fn is_retryable_transport_error(err: &anyhow::Error) -> bool {
17 let mut parts = Vec::new();
18 for cause in err.chain() {
19 parts.push(cause.to_string());
20 }
21 is_retryable_transport_error_text(&parts.join(": "))
22}
23
/// Decides from an error message whether a transport (SSH) failure is worth retrying.
///
/// Matching is a case-insensitive substring search. Markers of permanent
/// failures take precedence over transient markers, so a message containing
/// both kinds is non-retryable. Permanent failures include authentication,
/// host-key, key-file and option errors, and unresolvable hostnames.
///
/// DNS special case: OpenSSH reports a *transient* resolver failure as
/// `ssh: Could not resolve hostname <host>: Temporary failure in name resolution`.
/// "could not resolve hostname" is therefore only treated as permanent when the
/// message does not also carry "temporary failure in name resolution". Without
/// this, the transient DNS marker could never match an SSH resolver error.
///
/// Returns `false` for messages matching no known marker (including `""`).
pub fn is_retryable_transport_error_text(message: &str) -> bool {
    // Substrings that indicate a failure that retrying will not fix.
    const PERMANENT_MARKERS: &[&str] = &[
        "permission denied",
        "host key verification failed",
        "no such file or directory",
        "identity file",
        "keyfile",
        "invalid format",
        "unknown option",
    ];
    // Substrings that indicate a transient network / connection failure.
    const TRANSIENT_MARKERS: &[&str] = &[
        // Also covers "connection timed out" and "operation timed out".
        "timed out",
        "connection reset",
        "broken pipe",
        "connection refused",
        "network is unreachable",
        "no route to host",
        "connection closed",
        "connection lost",
        "ssh_exchange_identification",
        "kex_exchange_identification",
        "temporary failure in name resolution",
    ];
    // Hostname-resolution failure prefix; permanent unless the reason is transient.
    const UNRESOLVED_HOST: &str = "could not resolve hostname";
    const TRANSIENT_DNS: &str = "temporary failure in name resolution";

    let message = message.to_lowercase();

    let dns_is_transient = message.contains(TRANSIENT_DNS);
    let is_permanent = PERMANENT_MARKERS.iter().any(|&m| message.contains(m))
        || (message.contains(UNRESOLVED_HOST) && !dns_is_transient);
    if is_permanent {
        return false;
    }

    TRANSIENT_MARKERS.iter().any(|&m| message.contains(m))
}
55
/// Outcome of running a command: exit code, captured output streams, and duration.
///
/// Derives `Serialize`/`Deserialize`, so field order and names are part of the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandResult {
    /// Exit code of the command; `0` is what `CommandResult::success` treats as success.
    pub exit_code: i32,
    /// Standard output of the command, as captured by whoever built this value.
    pub stdout: String,
    /// Standard error of the command, as captured by whoever built this value.
    pub stderr: String,
    /// Duration of the command in milliseconds (how it is measured is up to the producer).
    pub duration_ms: u64,
}
68
69impl CommandResult {
70 pub fn success(&self) -> bool {
72 self.exit_code == 0
73 }
74}
75
/// Result of [`build_env_prefix`]: a shell `KEY=value ` prefix plus which
/// allowlisted keys were used or refused.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnvPrefix {
    /// Space-separated `KEY=<escaped value>` assignments followed by one trailing
    /// space, or an empty string when no variable was applied.
    pub prefix: String,
    /// Keys whose values were found and included in `prefix`, in allowlist order.
    pub applied: Vec<String>,
    /// Keys refused because the name is invalid or the value contains a newline,
    /// carriage return, or NUL. Keys that are simply unset appear in neither list.
    pub rejected: Vec<String>,
}
86
87pub fn build_env_prefix<F>(allowlist: &[String], mut get_env: F) -> EnvPrefix
93where
94 F: FnMut(&str) -> Option<String>,
95{
96 let mut parts = Vec::new();
97 let mut applied = Vec::new();
98 let mut rejected = Vec::new();
99
100 for raw_key in allowlist {
101 let key = raw_key.trim();
102 if key.is_empty() {
103 continue;
104 }
105 if !is_valid_env_key(key) {
106 info!(
107 "Rejecting env var '{}': invalid key name (must start with letter/underscore, contain only alphanumeric/underscore)",
108 key
109 );
110 rejected.push(key.to_string());
111 continue;
112 }
113 let Some(value) = get_env(key) else {
114 continue;
116 };
117 let Some(escaped) = shell_escape_value(&value) else {
118 info!(
119 "Rejecting env var '{}': value contains unsafe characters (newline, carriage return, or NUL)",
120 key
121 );
122 rejected.push(key.to_string());
123 continue;
124 };
125 parts.push(format!("{}={}", key, escaped));
126 applied.push(key.to_string());
127 }
128
129 let prefix = if parts.is_empty() {
130 String::new()
131 } else {
132 format!("{} ", parts.join(" "))
133 };
134
135 EnvPrefix {
136 prefix,
137 applied,
138 rejected,
139 }
140}
141
/// Returns `true` if `key` is a portable environment-variable name.
///
/// A valid name is non-empty and starts with an ASCII letter or `_`. It contains
/// only ASCII letters, digits, and `_`. Any non-ASCII character makes the name
/// invalid.
pub fn is_valid_env_key(key: &str) -> bool {
    // Working on bytes is equivalent here: every non-ASCII char encodes to
    // bytes >= 0x80, which fail both ASCII checks below.
    match key.as_bytes().split_first() {
        None => false,
        Some((&head, tail)) => {
            (head == b'_' || head.is_ascii_alphabetic())
                && tail.iter().all(|&b| b == b'_' || b.is_ascii_alphanumeric())
        }
    }
}
153
/// Quotes `value` so a POSIX shell reads it back as exactly one literal word.
///
/// Returns `None` if the value contains a newline, carriage return, or NUL.
/// These are refused outright rather than embedded in a command line. Otherwise:
/// - a non-empty value made only of ASCII letters, digits, and `_` is returned as-is;
/// - anything else, including the empty string, is wrapped in single quotes. Each
///   embedded `'` is written as `'\''` (close quote, escaped quote, reopen quote).
///   The empty string therefore becomes `''`.
pub fn shell_escape_value(value: &str) -> Option<String> {
    if value.chars().any(|c| matches!(c, '\n' | '\r' | '\0')) {
        return None;
    }

    let is_bare_word =
        !value.is_empty() && value.chars().all(|c| c == '_' || c.is_ascii_alphanumeric());

    if is_bare_word {
        Some(value.to_string())
    } else {
        Some(format!("'{}'", value.replace('\'', "'\\''")))
    }
}
188
189pub fn shell_escape_path_with_home(path: &str) -> Option<String> {
195 if path.contains('\n') || path.contains('\r') || path.contains('\0') {
196 return None;
197 }
198
199 if path == "~" {
200 return Some("\"$HOME\"".to_string());
201 }
202
203 if let Some(suffix) = path.strip_prefix("~/") {
204 let escaped_suffix = escape_for_double_quotes(suffix);
205 return Some(format!("\"$HOME/{}\"", escaped_suffix));
206 }
207
208 shell_escape_value(path)
209}
210
/// Backslash-escapes the characters that stay special inside a POSIX
/// double-quoted string: `\`, `"`, `$`, and `` ` ``. All other characters,
/// including `'`, are copied through unchanged.
fn escape_for_double_quotes(value: &str) -> String {
    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut out, ch| {
            if matches!(ch, '\\' | '"' | '$' | '`') {
                out.push('\\');
            }
            out.push(ch);
            out
        })
}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_guard;

    #[test]
    fn test_retryable_transport_error_text() {
        let _guard = test_guard!();
        assert!(is_retryable_transport_error_text(
            "ssh: connect to host 1.2.3.4 port 22: Connection timed out"
        ));
        assert!(is_retryable_transport_error_text(
            "kex_exchange_identification: Connection reset by peer"
        ));
        assert!(is_retryable_transport_error_text("Broken pipe"));
        assert!(is_retryable_transport_error_text("Network is unreachable"));
    }

    #[test]
    fn test_non_retryable_transport_error_text() {
        let _guard = test_guard!();
        assert!(!is_retryable_transport_error_text(
            "Permission denied (publickey)."
        ));
        assert!(!is_retryable_transport_error_text(
            "Host key verification failed."
        ));
        assert!(!is_retryable_transport_error_text(
            "Could not resolve hostname worker.example.com: Name or service not known"
        ));
        assert!(!is_retryable_transport_error_text(
            "Identity file /nope/id_rsa not accessible: No such file or directory"
        ));
    }

    #[test]
    fn test_retryable_transport_error_walks_cause_chain() {
        let _guard = test_guard!();
        // The transient marker lives only in the inner cause; the outer context
        // must not hide it from classification.
        let err = anyhow::anyhow!("Connection reset by peer")
            .context("failed to run remote command");
        assert!(is_retryable_transport_error(&err));

        // Likewise a permanent marker in an inner cause is still honored.
        let err = anyhow::anyhow!("Permission denied (publickey).")
            .context("failed to run remote command");
        assert!(!is_retryable_transport_error(&err));
    }

    #[test]
    fn test_command_result_success() {
        let _guard = test_guard!();
        let result = CommandResult {
            exit_code: 0,
            stdout: "output".to_string(),
            stderr: String::new(),
            duration_ms: 100,
        };
        assert!(result.success());

        let failed = CommandResult {
            exit_code: 1,
            stdout: String::new(),
            stderr: "error".to_string(),
            duration_ms: 50,
        };
        assert!(!failed.success());
    }

    #[test]
    fn test_shell_escape_value() {
        let _guard = test_guard!();
        // Safe word characters pass through unquoted.
        assert_eq!(shell_escape_value("simple"), Some("simple".to_string()));

        // Empty values must still produce an (empty) shell word.
        assert_eq!(shell_escape_value(""), Some("''".to_string()));

        assert_eq!(
            shell_escape_value("with spaces"),
            Some("'with spaces'".to_string())
        );

        // Any other punctuation forces quoting.
        assert_eq!(shell_escape_value("a-b"), Some("'a-b'".to_string()));

        // Embedded single quotes use the close/escape/reopen idiom.
        assert_eq!(shell_escape_value("it's"), Some("'it'\\''s'".to_string()));

        // Characters that cannot be safely embedded are refused.
        assert!(shell_escape_value("line1\nline2").is_none());
        assert!(shell_escape_value("line1\rline2").is_none());
        assert!(shell_escape_value("line1\0line2").is_none());
    }

    #[test]
    fn test_shell_escape_path_with_home() {
        let _guard = test_guard!();
        assert_eq!(
            shell_escape_path_with_home("~/.local/bin"),
            Some("\"$HOME/.local/bin\"".to_string())
        );
        assert_eq!(
            shell_escape_path_with_home("~"),
            Some("\"$HOME\"".to_string())
        );
        assert_eq!(
            shell_escape_path_with_home("/usr/local/bin"),
            Some("'/usr/local/bin'".to_string())
        );
    }

    #[test]
    fn test_shell_escape_path_with_home_escapes_suffix() {
        let _guard = test_guard!();
        // Inside double quotes only $HOME may expand; specials in the suffix are escaped.
        assert_eq!(
            shell_escape_path_with_home(r#"~/a "b" $c `d` \e"#),
            Some(r#""$HOME/a \"b\" \$c \`d\` \\e""#.to_string())
        );
        // `~user` is not expanded; it is quoted literally.
        assert_eq!(
            shell_escape_path_with_home("~user/bin"),
            Some("'~user/bin'".to_string())
        );
        // Unsafe characters are refused on the `~/` path too.
        assert!(shell_escape_path_with_home("~/a\nb").is_none());
    }

    #[test]
    fn test_is_valid_env_key() {
        let _guard = test_guard!();
        assert!(is_valid_env_key("PATH"));
        assert!(is_valid_env_key("_PRIVATE"));
        assert!(is_valid_env_key("MY_VAR_123"));
        assert!(!is_valid_env_key("123VAR"));
        assert!(!is_valid_env_key("MY-VAR"));
        assert!(!is_valid_env_key(""));
    }

    #[test]
    fn test_build_env_prefix() {
        let _guard = test_guard!();
        use std::collections::HashMap;

        let mut env = HashMap::new();
        env.insert("RUSTFLAGS".to_string(), "-C target-cpu=native".to_string());
        env.insert("QUOTED".to_string(), "a'b".to_string());
        env.insert("BADVAL".to_string(), "line1\nline2".to_string());

        let allowlist = vec![
            "RUSTFLAGS".to_string(),
            "QUOTED".to_string(),
            "MISSING".to_string(),
            "BADVAL".to_string(),
            "BAD=KEY".to_string(),
        ];

        let prefix = build_env_prefix(&allowlist, |key| env.get(key).cloned());

        assert!(prefix.prefix.contains("RUSTFLAGS='-C target-cpu=native'"));
        assert!(prefix.prefix.contains("QUOTED='a'\\''b'"));
        assert!(!prefix.prefix.contains("MISSING="));
        assert!(!prefix.prefix.contains("BADVAL="));
        assert!(prefix.rejected.contains(&"BADVAL".to_string()));
        assert!(prefix.rejected.contains(&"BAD=KEY".to_string()));

        // Exact shape: allowlist order, single-space separated, one trailing space.
        assert_eq!(
            prefix.prefix,
            "RUSTFLAGS='-C target-cpu=native' QUOTED='a'\\''b' "
        );
        assert_eq!(prefix.applied, vec!["RUSTFLAGS", "QUOTED"]);
        // Unset keys (MISSING) are neither applied nor rejected.
        assert_eq!(prefix.rejected, vec!["BADVAL", "BAD=KEY"]);
    }

    #[test]
    fn test_build_env_prefix_blank_and_padded_entries() {
        let _guard = test_guard!();

        // Blank entries are ignored and unset keys are skipped: nothing to apply.
        let allowlist = vec![String::new(), "   ".to_string(), "MISSING".to_string()];
        let prefix = build_env_prefix(&allowlist, |_| None);
        assert_eq!(
            prefix,
            EnvPrefix {
                prefix: String::new(),
                applied: Vec::new(),
                rejected: Vec::new(),
            }
        );

        // Surrounding whitespace in an allowlist entry is trimmed before lookup.
        let allowlist = vec![" HOME ".to_string()];
        let prefix = build_env_prefix(&allowlist, |key| {
            (key == "HOME").then(|| "/root".to_string())
        });
        assert_eq!(prefix.prefix, "HOME='/root' ");
        assert_eq!(prefix.applied, vec!["HOME"]);
        assert!(prefix.rejected.is_empty());
    }
}