1use std::process::Command;
5
6use crate::crypto::EncryptionKey;
7use crate::error::{CrablockError, Result};
8
9#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
10#[serde(tag = "type", rename_all = "snake_case")]
11pub enum KeySource {
12 Env { var: String },
13 File { path: String },
14 Command { cmd: String },
15 Inline { key: String },
16}
17
18impl KeySource {
19 pub fn from_env(var: impl Into<String>) -> Self {
20 Self::Env { var: var.into() }
21 }
22
23 pub fn from_file(path: impl Into<String>) -> Self {
24 Self::File { path: path.into() }
25 }
26
27 pub fn from_command(cmd: impl Into<String>) -> Self {
28 Self::Command { cmd: cmd.into() }
29 }
30
31 pub fn from_inline(key: impl Into<String>) -> Self {
32 Self::Inline { key: key.into() }
33 }
34
35 pub fn retrieve(&self) -> Result<EncryptionKey> {
36 match self {
37 KeySource::Env { var } => {
38 let key_str = std::env::var(var).map_err(|_| {
39 CrablockError::KeySource(format!("Environment variable {var} not set"))
40 })?;
41 Self::parse_key(&key_str)
42 }
43 KeySource::File { path } => {
44 let key_str = std::fs::read_to_string(path).map_err(|e| {
45 CrablockError::KeySource(format!("Failed to read key file {path}: {e}"))
46 })?;
47 Self::parse_key(key_str.trim())
48 }
49 KeySource::Command { cmd } => {
50 let output = Command::new("sh")
51 .arg("-c")
52 .arg(cmd)
53 .output()
54 .map_err(|e| {
55 CrablockError::KeySource(format!("Failed to execute key command: {e}"))
56 })?;
57
58 if !output.status.success() {
59 return Err(CrablockError::KeySource(format!(
60 "Key command failed with exit code: {:?}",
61 output.status.code()
62 )));
63 }
64
65 let key_str = String::from_utf8(output.stdout).map_err(|e| {
66 CrablockError::KeySource(format!("Invalid UTF-8 in key output: {e}"))
67 })?;
68 Self::parse_key(key_str.trim())
69 }
70 KeySource::Inline { key } => Self::parse_key(key),
71 }
72 }
73
74 fn parse_key(key_str: &str) -> Result<EncryptionKey> {
75 if key_str.len() == 64 && key_str.chars().all(|c| c.is_ascii_hexdigit()) {
78 return EncryptionKey::from_hex(key_str);
79 }
80
81 if key_str.len() >= 32 {
83 if let Ok(key) = EncryptionKey::from_base64(key_str) {
84 return Ok(key);
85 }
86 }
87
88 if key_str.len() == 32 {
90 let mut key = [0u8; 32];
91 key.copy_from_slice(key_str.as_bytes());
92 return Ok(EncryptionKey::new(key));
93 }
94
95 Err(CrablockError::InvalidKey(format!(
96 "Key must be 32 bytes (64 hex chars or ~44 base64). Got {} characters",
97 key_str.len()
98 )))
99 }
100}
101
102pub fn parse_key_source(key_str: &str) -> KeySource {
103 if let Some(env) = key_str.strip_prefix("env:") {
106 KeySource::from_env(env)
107 } else if let Some(path) = key_str.strip_prefix("file:") {
108 KeySource::from_file(path)
109 } else if let Some(cmd) = key_str.strip_prefix("cmd:") {
110 KeySource::from_command(cmd)
111 } else {
112 KeySource::from_inline(key_str)
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// A 64-digit hex key placed in the environment round-trips through
    /// `retrieve`.
    #[test]
    fn test_key_from_hex_env() {
        const VAR_NAME: &str = "TEST_HEX_KEY";
        let expected_hex = "aabbccddaabbccddaabbccddaabbccddaabbccddaabbccddaabbccddaabbccdd";
        env::set_var(VAR_NAME, expected_hex);

        let retrieved = KeySource::from_env(VAR_NAME).retrieve().unwrap();

        assert_eq!(hex::encode(retrieved.key), expected_hex);

        env::remove_var(VAR_NAME);
    }

    /// A 64-digit hex key supplied inline round-trips through `retrieve`.
    #[test]
    fn test_key_from_inline() {
        let expected_hex = "11223344556677889900aabbccddeeff11223344556677889900aabbccddeeff";
        let retrieved = KeySource::from_inline(expected_hex).retrieve().unwrap();

        assert_eq!(hex::encode(retrieved.key), expected_hex);
    }

    /// Input that matches none of the accepted key forms is rejected.
    #[test]
    fn test_invalid_key() {
        let outcome = KeySource::from_inline("too_short").retrieve();
        assert!(outcome.is_err());
    }
}