Skip to main content

rch_common/
ssh_utils.rs

1//! Platform-independent SSH utilities.
2//!
3//! These utilities work on all platforms and don't depend on openssh.
4
5use serde::{Deserialize, Serialize};
6use tracing::info;
7
8// ============================================================================
9// Retry Classification
10// ============================================================================
11
12/// True if an SSH/transport error looks retryable (transient network / transport).
13///
14/// This is intentionally conservative: false negatives are acceptable (fail-open
15/// to local execution), false positives can cause needless retries.
16pub fn is_retryable_transport_error(err: &anyhow::Error) -> bool {
17    let mut parts = Vec::new();
18    for cause in err.chain() {
19        parts.push(cause.to_string());
20    }
21    is_retryable_transport_error_text(&parts.join(": "))
22}
23
/// Message-only variant of [`is_retryable_transport_error`] (useful for tests).
///
/// Matching is case-insensitive substring search over `message`, evaluated in
/// this order:
///
/// 1. Authentication, host-trust and configuration failures are never
///    retryable and always win.
/// 2. A transient DNS failure (`temporary failure in name resolution`) is
///    retryable. This is checked *before* the generic "could not resolve
///    hostname" rule because OpenSSH reports both in a single line
///    (`Could not resolve hostname X: Temporary failure in name resolution`);
///    otherwise the transient case could never be reached for ssh output.
/// 3. Any other "could not resolve hostname" failure is not retryable.
/// 4. Remaining well-known transient transport failures are retryable.
///
/// Anything unrecognised returns `false`.
pub fn is_retryable_transport_error_text(message: &str) -> bool {
    // Fail-fast: non-retryable authentication / host trust / config issues.
    const FATAL_MARKERS: [&str; 7] = [
        "permission denied",
        "host key verification failed",
        "no such file or directory",
        "identity file",
        "keyfile",
        "invalid format",
        "unknown option",
    ];
    // Resolver reported a temporary condition (EAI_AGAIN); worth retrying.
    const TRANSIENT_DNS_MARKER: &str = "temporary failure in name resolution";
    // Permanent name-resolution failure unless the transient marker is present.
    const UNRESOLVED_HOST_MARKER: &str = "could not resolve hostname";
    // Common transient transport failures ("timed out" also covers
    // "connection timed out").
    const TRANSIENT_MARKERS: [&str; 10] = [
        "timed out",
        "connection reset",
        "broken pipe",
        "connection refused",
        "network is unreachable",
        "no route to host",
        "connection closed",
        "connection lost",
        "ssh_exchange_identification",
        "kex_exchange_identification",
    ];

    let message = message.to_lowercase();
    let mentions = |needle: &str| message.contains(needle);

    if FATAL_MARKERS.iter().any(|needle| mentions(needle)) {
        return false;
    }
    if mentions(TRANSIENT_DNS_MARKER) {
        return true;
    }
    if mentions(UNRESOLVED_HOST_MARKER) {
        return false;
    }
    TRANSIENT_MARKERS.iter().any(|needle| mentions(needle))
}
55
56/// Result of a remote command execution.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct CommandResult {
59    /// Exit code of the command.
60    pub exit_code: i32,
61    /// Standard output.
62    pub stdout: String,
63    /// Standard error.
64    pub stderr: String,
65    /// Execution duration in milliseconds.
66    pub duration_ms: u64,
67}
68
69impl CommandResult {
70    /// Check if the command succeeded (exit code 0).
71    pub fn success(&self) -> bool {
72        self.exit_code == 0
73    }
74}
75
/// Environment-variable assignments to prepend to a remotely executed command.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EnvPrefix {
    /// Shell-safe assignment string; ends with a single space when non-empty.
    pub prefix: String,
    /// Names of the variables that made it into `prefix`.
    pub applied: Vec<String>,
    /// Names that were refused because the key was invalid or the value unsafe.
    pub rejected: Vec<String>,
}
86
87/// Build a shell-safe environment variable prefix from an allowlist.
88///
89/// - Missing variables are ignored silently.
90/// - Unsafe values (newline, carriage return, NUL) are rejected.
91/// - Invalid keys are rejected.
92pub fn build_env_prefix<F>(allowlist: &[String], mut get_env: F) -> EnvPrefix
93where
94    F: FnMut(&str) -> Option<String>,
95{
96    let mut parts = Vec::new();
97    let mut applied = Vec::new();
98    let mut rejected = Vec::new();
99
100    for raw_key in allowlist {
101        let key = raw_key.trim();
102        if key.is_empty() {
103            continue;
104        }
105        if !is_valid_env_key(key) {
106            info!(
107                "Rejecting env var '{}': invalid key name (must start with letter/underscore, contain only alphanumeric/underscore)",
108                key
109            );
110            rejected.push(key.to_string());
111            continue;
112        }
113        let Some(value) = get_env(key) else {
114            // Variable not set - this is normal, don't log
115            continue;
116        };
117        let Some(escaped) = shell_escape_value(&value) else {
118            info!(
119                "Rejecting env var '{}': value contains unsafe characters (newline, carriage return, or NUL)",
120                key
121            );
122            rejected.push(key.to_string());
123            continue;
124        };
125        parts.push(format!("{}={}", key, escaped));
126        applied.push(key.to_string());
127    }
128
129    let prefix = if parts.is_empty() {
130        String::new()
131    } else {
132        format!("{} ", parts.join(" "))
133    };
134
135    EnvPrefix {
136        prefix,
137        applied,
138        rejected,
139    }
140}
141
/// Report whether `key` is an acceptable environment variable name.
///
/// A valid name is non-empty, begins with an ASCII letter or `_`, and
/// continues with only ASCII letters, ASCII digits, or `_`.
pub fn is_valid_env_key(key: &str) -> bool {
    // Working on bytes is equivalent here: every byte of a non-ASCII
    // character is >= 0x80 and therefore fails both ASCII checks.
    match key.as_bytes().split_first() {
        Some((&head, tail)) if head == b'_' || head.is_ascii_alphabetic() => tail
            .iter()
            .all(|&byte| byte == b'_' || byte.is_ascii_alphanumeric()),
        _ => false,
    }
}
153
/// Quote a string so it can be embedded in a shell command.
///
/// - Returns `None` when the value contains a newline, carriage return, or NUL;
///   callers log the rejection together with the variable name.
/// - The empty string becomes `''`.
/// - A value made only of ASCII alphanumerics and `_` is returned unchanged.
/// - Anything else is wrapped in single quotes, with each embedded single
///   quote rewritten as `'\''`.
pub fn shell_escape_value(value: &str) -> Option<String> {
    // Control characters that could break shell parsing are refused outright.
    if value.chars().any(|c| matches!(c, '\n' | '\r' | '\0')) {
        return None;
    }

    let is_plain = |c: char| c == '_' || c.is_ascii_alphanumeric();

    let rendered = if value.is_empty() {
        String::from("''")
    } else if value.chars().all(is_plain) {
        value.to_owned()
    } else {
        // Close the quote, emit an escaped quote, reopen the quote.
        format!("'{}'", value.replace('\'', "'\\''"))
    };
    Some(rendered)
}
188
189/// Escape a path for use in a shell command, allowing `~` to expand to `$HOME`.
190///
191/// For paths that start with `~/` or are exactly `~`, this returns a double-quoted
192/// string that expands `$HOME` while escaping special characters inside the suffix.
193/// For all other paths, this defers to `shell_escape_value`.
194pub fn shell_escape_path_with_home(path: &str) -> Option<String> {
195    if path.contains('\n') || path.contains('\r') || path.contains('\0') {
196        return None;
197    }
198
199    if path == "~" {
200        return Some("\"$HOME\"".to_string());
201    }
202
203    if let Some(suffix) = path.strip_prefix("~/") {
204        let escaped_suffix = escape_for_double_quotes(suffix);
205        return Some(format!("\"$HOME/{}\"", escaped_suffix));
206    }
207
208    shell_escape_value(path)
209}
210
/// Backslash-escape the characters that stay special inside a double-quoted
/// shell string: backslash, double quote, dollar sign, and backtick.
/// All other characters are copied through unchanged.
fn escape_for_double_quotes(value: &str) -> String {
    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut out, ch| {
            if matches!(ch, '\\' | '"' | '$' | '`') {
                out.push('\\');
            }
            out.push(ch);
            out
        })
}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_guard;

    /// Shorthand for an owned `Some(String)` expectation.
    fn some(text: &str) -> Option<String> {
        Some(text.to_owned())
    }

    // Transient transport failures must be classified as retryable.
    #[test]
    fn test_retryable_transport_error_text() {
        let _guard = test_guard!();
        let retryable = [
            "ssh: connect to host 1.2.3.4 port 22: Connection timed out",
            "kex_exchange_identification: Connection reset by peer",
            "Broken pipe",
            "Network is unreachable",
        ];
        for message in retryable.iter().copied() {
            assert!(is_retryable_transport_error_text(message));
        }
    }

    // Authentication, host-trust and configuration failures must fail fast.
    #[test]
    fn test_non_retryable_transport_error_text() {
        let _guard = test_guard!();
        let fatal = [
            "Permission denied (publickey).",
            "Host key verification failed.",
            "Could not resolve hostname worker.example.com: Name or service not known",
            "Identity file /nope/id_rsa not accessible: No such file or directory",
        ];
        for message in fatal.iter().copied() {
            assert!(!is_retryable_transport_error_text(message));
        }
    }

    // `success()` reflects a zero exit code only.
    #[test]
    fn test_command_result_success() {
        let _guard = test_guard!();

        let passed = CommandResult {
            exit_code: 0,
            stdout: String::from("output"),
            stderr: String::new(),
            duration_ms: 100,
        };
        let failed = CommandResult {
            exit_code: 1,
            stdout: String::new(),
            stderr: String::from("error"),
            duration_ms: 50,
        };

        assert!(passed.success());
        assert!(!failed.success());
    }

    #[test]
    fn test_shell_escape_value() {
        let _guard = test_guard!();

        // Escapable inputs: simple value, empty string, spaces, single quote.
        let escapable = [
            ("simple", "simple"),
            ("", "''"),
            ("with spaces", "'with spaces'"),
            ("it's", "'it'\\''s'"),
        ];
        for (input, expected) in escapable.iter().copied() {
            assert_eq!(shell_escape_value(input), some(expected));
        }

        // Unsafe values are refused.
        let refused = ["line1\nline2", "line1\rline2", "line1\0line2"];
        for input in refused.iter().copied() {
            assert!(shell_escape_value(input).is_none());
        }
    }

    #[test]
    fn test_shell_escape_path_with_home() {
        let _guard = test_guard!();
        let cases = [
            ("~/.local/bin", "\"$HOME/.local/bin\""),
            ("~", "\"$HOME\""),
            ("/usr/local/bin", "'/usr/local/bin'"),
        ];
        for (input, expected) in cases.iter().copied() {
            assert_eq!(shell_escape_path_with_home(input), some(expected));
        }
    }

    #[test]
    fn test_is_valid_env_key() {
        let _guard = test_guard!();

        for key in ["PATH", "_PRIVATE", "MY_VAR_123"].iter().copied() {
            assert!(is_valid_env_key(key));
        }
        for key in ["123VAR", "MY-VAR", ""].iter().copied() {
            assert!(!is_valid_env_key(key));
        }
    }

    #[test]
    fn test_build_env_prefix() {
        let _guard = test_guard!();
        use std::collections::HashMap;

        let env: HashMap<String, String> = [
            ("RUSTFLAGS", "-C target-cpu=native"),
            ("QUOTED", "a'b"),
            ("BADVAL", "line1\nline2"),
        ]
        .iter()
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect();

        let allowlist: Vec<String> = ["RUSTFLAGS", "QUOTED", "MISSING", "BADVAL", "BAD=KEY"]
            .iter()
            .map(|name| name.to_string())
            .collect();

        let built = build_env_prefix(&allowlist, |key| env.get(key).cloned());

        // Applied variables appear escaped in the prefix.
        assert!(built.prefix.contains("RUSTFLAGS='-C target-cpu=native'"));
        assert!(built.prefix.contains("QUOTED='a'\\''b'"));
        // Missing and unsafe variables never reach the prefix.
        assert!(!built.prefix.contains("MISSING="));
        assert!(!built.prefix.contains("BADVAL="));
        // Unsafe value and invalid key are both reported as rejected.
        assert!(built.rejected.contains(&"BADVAL".to_string()));
        assert!(built.rejected.contains(&"BAD=KEY".to_string()));
    }
}