// ralph_workflow/runtime/environment.rs
use std::collections::HashMap;

/// Read access to environment variables, abstracted behind a trait.
///
/// The `Send + Sync` bound lets implementations be shared across threads.
pub trait Environment: Send + Sync {
    /// Returns the value of `key`, or `None` when it is not available.
    ///
    /// [`RealEnvironment`] also returns `None` when the value is not valid Unicode.
    fn var(&self, key: &str) -> Option<String>;

    /// Returns a snapshot of all environment variables as owned strings.
    ///
    /// [`RealEnvironment`] omits any entry whose name or value is not valid Unicode.
    fn vars(&self) -> HashMap<String, String>;
}

/// [`Environment`] backed by the current process's environment (`std::env`).
pub struct RealEnvironment;

impl Environment for RealEnvironment {
    fn var(&self, key: &str) -> Option<String> {
        // `std::env::var` errors both for "not present" and "not Unicode";
        // both cases collapse to `None`.
        std::env::var(key).ok()
    }

    fn vars(&self) -> HashMap<String, String> {
        // `std::env::vars()` panics during iteration if any name or value is not
        // valid Unicode, which would crash on an environment we do not control.
        // Walk the raw `OsString` pairs instead and drop undecodable entries,
        // matching how `var` reports such values as absent.
        std::env::vars_os()
            .filter_map(|(key, value)| {
                let key = key.into_string().ok()?;
                let value = value.into_string().ok()?;
                Some((key, value))
            })
            .collect()
    }
}

30pub trait GitEnvironment: Send + Sync {
35 fn configure_git_ssh_command(&self, key_path: &str) -> Result<(), GitEnvError>;
37
38 fn disable_git_terminal_prompt(&self) -> Result<(), GitEnvError>;
40}
41
/// Error returned by [`GitEnvironment`] operations.
///
/// Wraps a human-readable message; `Display` prints that message verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GitEnvError(String);

impl GitEnvError {
    /// Builds an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(msg: impl Into<String>) -> Self {
        let message: String = msg.into();
        GitEnvError(message)
    }
}

impl std::fmt::Display for GitEnvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for GitEnvError {}

/// `GitEnvironment` implementation that applies settings by writing them into
/// the current process's environment variables with `std::env::set_var`.
pub struct RealGitEnvironment;

64fn validate_ssh_key_path(key_path: &str) -> Result<(), GitEnvError> {
65 if key_path.trim().is_empty() {
66 return Err(GitEnvError::new("empty SSH key path"));
67 }
68 if key_path.contains('\0') || key_path.contains('\n') || key_path.contains('\r') {
69 return Err(GitEnvError::new("SSH key path contains invalid characters"));
70 }
71 Ok(())
72}
73
74impl GitEnvironment for RealGitEnvironment {
75 fn configure_git_ssh_command(&self, key_path: &str) -> Result<(), GitEnvError> {
76 validate_ssh_key_path(key_path)?;
77 let escaped = shell_escape_posix(key_path);
78 let cmd = format!("ssh -o 'IdentitiesOnly=yes' -i {escaped}");
79 std::env::set_var("GIT_SSH_COMMAND", &cmd);
80 Ok(())
81 }
82
83 fn disable_git_terminal_prompt(&self) -> Result<(), GitEnvError> {
84 std::env::set_var("GIT_TERMINAL_PROMPT", "0");
85 Ok(())
86 }
87}
88
/// Quotes `s` as a single POSIX-shell word.
///
/// The result is wrapped in single quotes, inside which the shell performs no
/// expansion at all. A single quote cannot appear inside such a string, so each
/// `'` in the input is emitted as `'"'"'`: close the quote, emit a
/// double-quoted `'`, then reopen. An empty input yields `''`.
fn shell_escape_posix(s: &str) -> String {
    // A single `replace` pass; the previous version allocated a temporary
    // `Vec<char>` for every input character.
    let body = s.replace('\'', r#"'"'"'"#);
    format!("'{body}'")
}

#[cfg(any(test, feature = "test-utils"))]
/// In-memory test double for the `GitEnvironment` trait.
///
/// Instead of touching the process environment, it records what would have been
/// configured so tests can assert on it.
pub mod mock {
    use super::GitEnvError;
    use std::sync::Mutex;

    /// Records calls made through the `GitEnvironment` trait.
    ///
    /// Accessors call `lock().unwrap()` and therefore panic if a mutex has been
    /// poisoned by a panicking thread.
    #[derive(Default)]
    pub struct MockGitEnvironment {
        /// Every SSH command built by `configure_git_ssh_command`, in call order.
        pub ssh_commands: Mutex<Vec<String>>,
        /// Becomes `true` once `disable_git_terminal_prompt` has been called.
        pub terminal_prompts_disabled: Mutex<bool>,
        /// NOTE(review): nothing in this mock reads or writes this field;
        /// confirm whether callers rely on it.
        pub errors: Mutex<Vec<GitEnvError>>,
    }

    impl Clone for MockGitEnvironment {
        /// Snapshots the current contents of every field into fresh mutexes.
        fn clone(&self) -> Self {
            let ssh_commands = self.ssh_commands.lock().unwrap().clone();
            let prompt_disabled = *self.terminal_prompts_disabled.lock().unwrap();
            let errors = self.errors.lock().unwrap().clone();
            Self {
                ssh_commands: Mutex::new(ssh_commands),
                terminal_prompts_disabled: Mutex::new(prompt_disabled),
                errors: Mutex::new(errors),
            }
        }
    }

    impl MockGitEnvironment {
        /// Creates a mock with nothing recorded and prompts not disabled.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Returns a copy of the recorded SSH commands.
        ///
        /// Despite the name, each entry is the full command string, not just
        /// the key path.
        #[must_use]
        pub fn configured_ssh_keys(&self) -> Vec<String> {
            self.ssh_commands.lock().unwrap().to_vec()
        }

        /// Reports whether `disable_git_terminal_prompt` has been called.
        #[must_use]
        pub fn terminal_prompt_disabled(&self) -> bool {
            let guard = self.terminal_prompts_disabled.lock().unwrap();
            *guard
        }
    }

    impl super::GitEnvironment for MockGitEnvironment {
        /// Validates and quotes `key_path` with the same helpers and command
        /// template as `RealGitEnvironment`, then records the command instead
        /// of exporting it. Rejected paths are not recorded.
        fn configure_git_ssh_command(&self, key_path: &str) -> Result<(), GitEnvError> {
            super::validate_ssh_key_path(key_path).map(|()| {
                let escaped = super::shell_escape_posix(key_path);
                let cmd = format!("ssh -o 'IdentitiesOnly=yes' -i {escaped}");
                self.ssh_commands.lock().unwrap().push(cmd);
            })
        }

        /// Records that prompts were disabled; never returns an error.
        fn disable_git_terminal_prompt(&self) -> Result<(), GitEnvError> {
            let mut disabled = self.terminal_prompts_disabled.lock().unwrap();
            *disabled = true;
            Ok(())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::mock::MockGitEnvironment;
    use super::GitEnvironment;

    #[test]
    fn mock_git_environment_configures_ssh_command() {
        let env = MockGitEnvironment::new();
        env.configure_git_ssh_command("/home/user/.ssh/id_rsa")
            .unwrap();
        let keys = env.configured_ssh_keys();
        assert_eq!(keys.len(), 1);
        assert!(keys[0].contains("id_rsa"));
    }

    #[test]
    fn mock_git_environment_rejects_empty_ssh_key_path() {
        let env = MockGitEnvironment::new();
        let result = env.configure_git_ssh_command("");
        assert!(result.is_err());
    }

    #[test]
    fn mock_git_environment_rejects_whitespace_only_ssh_key_path() {
        let env = MockGitEnvironment::new();
        let err = env.configure_git_ssh_command(" \t ").unwrap_err();
        assert_eq!(err.to_string(), "empty SSH key path");
    }

    #[test]
    fn mock_git_environment_rejects_newline_in_ssh_key_path() {
        let env = MockGitEnvironment::new();
        let result = env.configure_git_ssh_command("/tmp/key\n-oProxyCommand=evil");
        assert!(result.is_err());
    }

    #[test]
    fn mock_git_environment_rejects_carriage_return_in_ssh_key_path() {
        let env = MockGitEnvironment::new();
        let result = env.configure_git_ssh_command("/tmp/key\r-oProxyCommand=evil");
        assert!(result.is_err());
    }

    #[test]
    fn mock_git_environment_rejects_nul_in_ssh_key_path() {
        let env = MockGitEnvironment::new();
        let err = env.configure_git_ssh_command("/tmp/key\0evil").unwrap_err();
        assert_eq!(err.to_string(), "SSH key path contains invalid characters");
    }

    #[test]
    fn mock_git_environment_does_not_record_rejected_paths() {
        let env = MockGitEnvironment::new();
        assert!(env.configure_git_ssh_command("/tmp/key\nx").is_err());
        assert!(env.configured_ssh_keys().is_empty());
    }

    #[test]
    fn mock_git_environment_quotes_key_path_in_command() {
        let env = MockGitEnvironment::new();
        env.configure_git_ssh_command("/tmp/it's key").unwrap();
        assert_eq!(
            env.configured_ssh_keys(),
            vec![String::from(
                "ssh -o 'IdentitiesOnly=yes' -i '/tmp/it'\"'\"'s key'"
            )]
        );
    }

    #[test]
    fn mock_git_environment_disables_terminal_prompt() {
        let env = MockGitEnvironment::new();
        env.disable_git_terminal_prompt().unwrap();
        assert!(env.terminal_prompt_disabled());
    }

    #[test]
    fn mock_git_environment_clone_snapshots_state() {
        let env = MockGitEnvironment::new();
        env.configure_git_ssh_command("/k").unwrap();
        env.disable_git_terminal_prompt().unwrap();
        let copy = env.clone();
        assert_eq!(copy.configured_ssh_keys(), env.configured_ssh_keys());
        assert!(copy.terminal_prompt_disabled());
    }

    #[test]
    fn shell_escape_wraps_in_single_quotes() {
        assert_eq!(super::shell_escape_posix("/a b"), "'/a b'");
    }

    #[test]
    fn shell_escape_handles_single_quotes() {
        assert_eq!(super::shell_escape_posix("a'b"), "'a'\"'\"'b'");
    }

    #[test]
    fn shell_escape_handles_empty_string() {
        assert_eq!(super::shell_escape_posix(""), "''");
    }
}