//! Encryption helpers for bamboo_agent configuration values
//! (`bamboo_agent/core/encryption.rs`).

use aes_gcm::{
    aead::{Aead, KeyInit},
    Aes256Gcm, Nonce,
};
use anyhow::{anyhow, Result};
use rand::Rng;
use sha2::{Digest, Sha256};
use std::process::Command;
use std::sync::OnceLock;

/// Environment variable that may hold an explicit key, hex-encoded (32 bytes = 64 hex chars).
const KEY_ENV_VAR: &str = "BAMBOO_CONFIG_ENCRYPTION_KEY";
/// Domain-separation prefix hashed in front of the host identifier when deriving a key.
const KEY_DERIVATION_CONTEXT: &[u8] = b"bamboo-config-encryption-v1";

#[cfg(test)]
use std::cell::RefCell;

// Per-thread key override used only by tests. Setting the environment variable
// instead would not be safe: the process environment is shared by every test
// thread, so changing it races with tests running concurrently under
// `cargo test`. A thread-local keeps each test's override isolated.
#[cfg(test)]
thread_local! {
    static TEST_KEY_OVERRIDE: RefCell<Option<Vec<u8>>> = const { RefCell::new(None) };
}

24/// Get the encryption key.
25/// Priority: environment variable, machine-derived key, then random fallback.
26pub fn get_encryption_key() -> Vec<u8> {
27    #[cfg(test)]
28    if let Some(key) = TEST_KEY_OVERRIDE.with(|cell| cell.borrow().clone()) {
29        return key;
30    }
31
32    if let Some(key) = read_env_key() {
33        return key;
34    }
35
36    if let Some(machine_id) = machine_identifier() {
37        return derive_key(machine_id.as_bytes());
38    }
39
40    // Last-resort fallback keeps behavior safe if host identifiers are unavailable.
41    rand::thread_rng().gen::<[u8; 32]>().to_vec()
42}
43
/// RAII guard returned by [`set_test_encryption_key`].
///
/// On drop, it restores the thread-local key override that was active before
/// the guard was created.
#[cfg(test)]
#[must_use = "dropping the guard immediately restores the previous test key override"]
pub struct TestKeyGuard {
    previous: Option<Vec<u8>>,
}

#[cfg(test)]
impl Drop for TestKeyGuard {
    fn drop(&mut self) {
        // The guard is being destroyed, so the saved value can be moved back instead of cloned.
        let previous = self.previous.take();
        TEST_KEY_OVERRIDE.with(|cell| {
            *cell.borrow_mut() = previous;
        });
    }
}

/// Override the encryption key for the current thread until the returned guard is dropped.
///
/// The override replaced here is saved in the guard and reinstated on drop.
#[cfg(test)]
pub fn set_test_encryption_key(key: [u8; 32]) -> TestKeyGuard {
    let previous = TEST_KEY_OVERRIDE.with(|cell| cell.replace(Some(key.to_vec())));
    TestKeyGuard { previous }
}

64fn read_env_key() -> Option<Vec<u8>> {
65    let key_hex = std::env::var(KEY_ENV_VAR).ok()?;
66    let key = hex::decode(key_hex).ok()?;
67    (key.len() == 32).then_some(key)
68}
69
70fn derive_key(material: &[u8]) -> Vec<u8> {
71    let mut hasher = Sha256::new();
72    hasher.update(KEY_DERIVATION_CONTEXT);
73    hasher.update(material);
74    hasher.finalize().to_vec()
75}
76
77fn machine_identifier() -> Option<String> {
78    read_machine_id().or_else(derived_fallback_identifier)
79}
80
81fn read_machine_id() -> Option<String> {
82    for path in ["/etc/machine-id", "/var/lib/dbus/machine-id"] {
83        if let Some(machine_id) = read_trimmed_file(path) {
84            return Some(machine_id);
85        }
86    }
87
88    #[cfg(target_os = "macos")]
89    {
90        if let Some(machine_id) = read_macos_platform_uuid() {
91            return Some(machine_id);
92        }
93    }
94
95    None
96}
97
98fn derived_fallback_identifier() -> Option<String> {
99    let mut parts = vec![
100        format!("os={}", std::env::consts::OS),
101        format!("arch={}", std::env::consts::ARCH),
102    ];
103
104    if let Some(hostname) = system_hostname() {
105        parts.push(format!("host={hostname}"));
106    }
107    if let Some(username) = read_first_env_var(&["USER", "USERNAME"]) {
108        parts.push(format!("user={username}"));
109    }
110    if let Some(home) = read_first_env_path(&["HOME", "USERPROFILE"]) {
111        parts.push(format!("home={home}"));
112    }
113    if let Ok(exe_path) = std::env::current_exe() {
114        parts.push(format!("exe={}", exe_path.display()));
115    }
116
117    (parts.len() > 2).then(|| parts.join("|"))
118}
119
120fn system_hostname() -> Option<String> {
121    if let Some(hostname) = read_first_env_var(&["HOSTNAME", "COMPUTERNAME"]) {
122        return Some(hostname);
123    }
124
125    if let Some(hostname) = read_trimmed_file("/etc/hostname") {
126        return Some(hostname);
127    }
128
129    #[cfg(target_os = "macos")]
130    {
131        if let Some(hostname) = run_command_first_line("scutil", &["--get", "ComputerName"]) {
132            return Some(hostname);
133        }
134        if let Some(hostname) = run_command_first_line("scutil", &["--get", "LocalHostName"]) {
135            return Some(hostname);
136        }
137    }
138
139    run_command_first_line("hostname", &[])
140}
141
/// Return the trimmed value of the first variable in `keys` that is set, valid
/// Unicode, and not blank.
fn read_first_env_var(keys: &[&str]) -> Option<String> {
    for key in keys {
        if let Ok(raw) = std::env::var(key) {
            let value = raw.trim();
            if !value.is_empty() {
                return Some(value.to_owned());
            }
        }
    }
    None
}

/// Like `read_first_env_var`, but accepts non-Unicode values by converting them lossily.
fn read_first_env_path(keys: &[&str]) -> Option<String> {
    keys.iter()
        .filter_map(|key| std::env::var_os(key))
        .map(|raw| raw.to_string_lossy().trim().to_string())
        .find(|value| !value.is_empty())
}

/// Read `path` as UTF-8 and return its trimmed contents.
///
/// Returns `None` if the file is unreadable or contains only whitespace.
fn read_trimmed_file(path: &str) -> Option<String> {
    match std::fs::read_to_string(path) {
        Ok(contents) => match contents.trim() {
            "" => None,
            text => Some(text.to_owned()),
        },
        Err(_) => None,
    }
}

/// Run `program args…` and return the trimmed first line of its stdout.
///
/// Returns `None` in any of these cases: the program cannot be spawned, it exits
/// unsuccessfully, its stdout is not UTF-8, or the first line is empty.
fn run_command_first_line(program: &str, args: &[&str]) -> Option<String> {
    let output = match Command::new(program).args(args).output() {
        Ok(output) if output.status.success() => output,
        _ => return None,
    };

    let stdout = String::from_utf8(output.stdout).ok()?;
    stdout
        .lines()
        .next()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(String::from)
}

/// Query `ioreg` for the `IOPlatformUUID` of the `IOPlatformExpertDevice` (macOS only).
#[cfg(target_os = "macos")]
fn read_macos_platform_uuid() -> Option<String> {
    let output = Command::new("ioreg")
        .arg("-rd1")
        .arg("-c")
        .arg("IOPlatformExpertDevice")
        .output()
        .ok()
        .filter(|out| out.status.success())?;

    let listing = String::from_utf8(output.stdout).ok()?;
    extract_quoted_property(&listing, "IOPlatformUUID")
}

/// Find a line of the form `"key" = "value"` and return the trimmed `value`.
///
/// Only the first two double-quoted segments of each candidate line are examined.
#[cfg(target_os = "macos")]
fn extract_quoted_property(content: &str, key: &str) -> Option<String> {
    for line in content.lines().filter(|line| line.contains(key)) {
        // Odd-indexed pieces of a split on '"' are the quoted segments.
        let mut quoted = line.split('"').skip(1).step_by(2);
        if let (Some(name), Some(value)) = (quoted.next(), quoted.next()) {
            if name == key {
                return Some(value.trim().to_string());
            }
        }
    }
    None
}

204/// Encrypt data.
205/// Returns: nonce(12 bytes) + ciphertext.
206pub fn encrypt(plaintext: &str) -> Result<String> {
207    let key = get_encryption_key();
208    let cipher =
209        Aes256Gcm::new_from_slice(&key).map_err(|e| anyhow!("Failed to create cipher: {e}"))?;
210
211    let nonce_bytes: [u8; 12] = rand::thread_rng().gen();
212    let nonce = Nonce::from(nonce_bytes);
213
214    let ciphertext = cipher
215        .encrypt(&nonce, plaintext.as_bytes())
216        .map_err(|e| anyhow!("Encryption failed: {e}"))?;
217
218    // Format: hex(nonce) + ":" + hex(ciphertext)
219    let result = format!("{}:{}", hex::encode(nonce_bytes), hex::encode(ciphertext));
220    Ok(result)
221}
222
223/// Decrypt data.
224pub fn decrypt(encrypted: &str) -> Result<String> {
225    let parts: Vec<&str> = encrypted.split(':').collect();
226    if parts.len() != 2 {
227        return Err(anyhow!("Invalid encrypted format"));
228    }
229
230    let nonce_bytes = hex::decode(parts[0]).map_err(|e| anyhow!("Invalid nonce: {e}"))?;
231    let ciphertext = hex::decode(parts[1]).map_err(|e| anyhow!("Invalid ciphertext: {e}"))?;
232
233    if nonce_bytes.len() != 12 {
234        return Err(anyhow!(
235            "Invalid nonce length: expected 12, got {}",
236            nonce_bytes.len()
237        ));
238    }
239
240    let key = get_encryption_key();
241    let cipher =
242        Aes256Gcm::new_from_slice(&key).map_err(|e| anyhow!("Failed to create cipher: {e}"))?;
243
244    let nonce_array: [u8; 12] = nonce_bytes.try_into().expect("nonce length checked above");
245    let nonce = Nonce::from(nonce_array);
246    let plaintext = cipher
247        .decrypt(&nonce, ciphertext.as_ref())
248        .map_err(|e| anyhow!("Decryption failed: {e}"))?;
249
250    String::from_utf8(plaintext).map_err(|e| anyhow!("Invalid UTF-8: {e}"))
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, OnceLock};

    /// Snapshot of one environment variable, restored when the guard is dropped.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Remember the current value of `key`, then set it to `value`.
        fn set(key: &'static str, value: &str) -> Self {
            let guard = Self::snapshot(key);
            std::env::set_var(key, value);
            guard
        }

        /// Remember the current value of `key`, then remove it.
        fn unset(key: &'static str) -> Self {
            let guard = Self::snapshot(key);
            std::env::remove_var(key);
            guard
        }

        fn snapshot(key: &'static str) -> Self {
            Self {
                key,
                previous: std::env::var_os(key),
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            if let Some(value) = self.previous.take() {
                std::env::set_var(self.key, value);
            } else {
                std::env::remove_var(self.key);
            }
        }
    }

    /// Serialises the tests in this module that mutate the process environment.
    /// Poisoning is ignored because the lock protects ordering only, not data.
    fn lock_env() -> MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Two consecutive lookups must both yield the same 32-byte key.
    fn assert_key_is_stable() {
        let first = get_encryption_key();
        let second = get_encryption_key();
        assert_eq!(first.len(), 32);
        assert_eq!(second.len(), 32);
        assert_eq!(first, second);
    }

    #[test]
    fn test_encrypt_decrypt() {
        let _env = lock_env();
        let _key = EnvVarGuard::unset(KEY_ENV_VAR);

        let secret = "my_secret_password";
        let sealed = encrypt(secret).unwrap();
        assert_eq!(secret, decrypt(&sealed).unwrap());
    }

    #[test]
    fn test_get_encryption_key_prefers_valid_env_key() {
        let _env = lock_env();
        let key_hex = "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff";
        let _key = EnvVarGuard::set(KEY_ENV_VAR, key_hex);

        let expected = hex::decode(key_hex).unwrap();
        assert_eq!(get_encryption_key(), expected);
    }

    #[test]
    fn test_get_encryption_key_is_stable_without_env_var() {
        let _env = lock_env();
        let _key = EnvVarGuard::unset(KEY_ENV_VAR);
        assert_key_is_stable();
    }

    #[test]
    fn test_get_encryption_key_ignores_invalid_env_key() {
        let _env = lock_env();
        let _key = EnvVarGuard::set(KEY_ENV_VAR, "abcd");
        assert_key_is_stable();
    }
}