use std::process::Command;
use crate::crypto::EncryptionKey;
use crate::error::{CrablockError, Result};
/// Describes where an encryption key is obtained from.
///
/// With the serde attributes below, values are (de)serialized as an internally
/// tagged enum: a `type` field names the variant in snake_case (`env`, `file`,
/// `command`, `inline`) next to the variant's own field.
///
/// NOTE(review): the derived `Debug` prints every field verbatim, including the
/// `Inline` key material — avoid logging values of this type.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum KeySource {
/// The key is the value of the environment variable named `var`.
Env { var: String },
/// The key is the content of the file at `path`.
File { path: String },
/// The key is the standard output of the shell command `cmd`.
Command { cmd: String },
/// The key is given directly as the string `key`.
Inline { key: String },
}
impl KeySource {
    /// Creates a source that reads the key from the environment variable `var`.
    pub fn from_env(var: impl Into<String>) -> Self {
        Self::Env { var: var.into() }
    }

    /// Creates a source that reads the key from the file at `path`.
    pub fn from_file(path: impl Into<String>) -> Self {
        Self::File { path: path.into() }
    }

    /// Creates a source that obtains the key from the standard output of `cmd`.
    pub fn from_command(cmd: impl Into<String>) -> Self {
        Self::Command { cmd: cmd.into() }
    }

    /// Creates a source that carries the key text directly.
    pub fn from_inline(key: impl Into<String>) -> Self {
        Self::Inline { key: key.into() }
    }

    /// Resolves this source into an [`EncryptionKey`].
    ///
    /// * `Env` — reads the variable and parses its value as-is (no trimming).
    /// * `File` — reads the whole file as UTF-8 and parses the trimmed content.
    /// * `Command` — runs the command through `sh -c`, requires a successful
    ///   exit status, and parses the trimmed standard output.
    /// * `Inline` — parses the stored string as-is (no trimming).
    ///
    /// # Errors
    ///
    /// Returns `CrablockError::KeySource` when the variable is missing or not
    /// valid Unicode, the file cannot be read, or the command cannot be
    /// started, exits unsuccessfully, or prints non-UTF-8 output. Errors from
    /// parsing the key text are returned unchanged.
    pub fn retrieve(&self) -> Result<EncryptionKey> {
        match self {
            KeySource::Env { var } => {
                // Distinguish "missing" from "present but undecodable" so the
                // message does not claim an existing variable is unset.
                let key_str = std::env::var(var).map_err(|e| match e {
                    std::env::VarError::NotPresent => CrablockError::KeySource(format!(
                        "Environment variable {var} not set"
                    )),
                    std::env::VarError::NotUnicode(_) => CrablockError::KeySource(format!(
                        "Environment variable {var} is not valid Unicode"
                    )),
                })?;
                Self::parse_key(&key_str)
            }
            KeySource::File { path } => {
                let key_str = std::fs::read_to_string(path).map_err(|e| {
                    CrablockError::KeySource(format!("Failed to read key file {path}: {e}"))
                })?;
                // Files commonly end with a newline; strip surrounding whitespace.
                Self::parse_key(key_str.trim())
            }
            KeySource::Command { cmd } => {
                let output = Command::new("sh")
                    .arg("-c")
                    .arg(cmd)
                    .output()
                    .map_err(|e| {
                        CrablockError::KeySource(format!("Failed to execute key command: {e}"))
                    })?;
                if !output.status.success() {
                    let mut message = format!(
                        "Key command failed with exit code: {:?}",
                        output.status.code()
                    );
                    // `output()` captures stderr; surface it so the failure
                    // can be diagnosed instead of being silently dropped.
                    let stderr = String::from_utf8_lossy(&output.stderr);
                    let detail = stderr.trim();
                    if !detail.is_empty() {
                        message.push_str(": ");
                        message.push_str(detail);
                    }
                    return Err(CrablockError::KeySource(message));
                }
                let key_str = String::from_utf8(output.stdout).map_err(|e| {
                    CrablockError::KeySource(format!("Invalid UTF-8 in key output: {e}"))
                })?;
                Self::parse_key(key_str.trim())
            }
            KeySource::Inline { key } => Self::parse_key(key),
        }
    }

    /// Converts key text into an [`EncryptionKey`], trying encodings in order:
    ///
    /// 1. exactly 64 ASCII hex digits → `EncryptionKey::from_hex` (its result,
    ///    success or failure, is returned directly);
    /// 2. at least 32 bytes long → `EncryptionKey::from_base64`, used only if
    ///    it succeeds;
    /// 3. exactly 32 bytes long → the raw bytes of the string are the key.
    ///
    /// # Errors
    ///
    /// Returns `CrablockError::InvalidKey` when none of the forms match. The
    /// length reported in the message is the byte length of `key_str`.
    fn parse_key(key_str: &str) -> Result<EncryptionKey> {
        if key_str.len() == 64 && key_str.chars().all(|c| c.is_ascii_hexdigit()) {
            return EncryptionKey::from_hex(key_str);
        }
        if key_str.len() >= 32 {
            if let Ok(key) = EncryptionKey::from_base64(key_str) {
                return Ok(key);
            }
        }
        if key_str.len() == 32 {
            // `len()` is the byte length, so the slice fits the array exactly.
            let mut key = [0u8; 32];
            key.copy_from_slice(key_str.as_bytes());
            return Ok(EncryptionKey::new(key));
        }
        Err(CrablockError::InvalidKey(format!(
            "Key must be 32 bytes (64 hex chars or ~44 base64). Got {} characters",
            key_str.len()
        )))
    }
}
/// Turns a key specification string into the matching [`KeySource`].
///
/// The prefixes `env:`, `file:` and `cmd:` are checked in that order; the
/// remainder after the prefix becomes the variable name, file path or shell
/// command. A string with none of these prefixes is taken whole as an inline
/// key.
pub fn parse_key_source(key_str: &str) -> KeySource {
    if let Some(var_name) = key_str.strip_prefix("env:") {
        return KeySource::Env {
            var: var_name.to_owned(),
        };
    }
    if let Some(file_path) = key_str.strip_prefix("file:") {
        return KeySource::File {
            path: file_path.to_owned(),
        };
    }
    match key_str.strip_prefix("cmd:") {
        Some(shell_cmd) => KeySource::Command {
            cmd: shell_cmd.to_owned(),
        },
        None => KeySource::Inline {
            key: key_str.to_owned(),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// A 64-character hex key stored in an environment variable round-trips.
    #[test]
    fn test_key_from_hex_env() {
        let hex_key = "aabbccddaabbccddaabbccddaabbccddaabbccddaabbccddaabbccddaabbccdd";
        env::set_var("TEST_HEX_KEY", hex_key);
        let source = KeySource::from_env("TEST_HEX_KEY");
        // Capture the result and clean up *before* unwrapping/asserting, so a
        // failure does not leave the variable set for the rest of the process.
        let result = source.retrieve();
        env::remove_var("TEST_HEX_KEY");
        let key = result.unwrap();
        assert_eq!(hex::encode(key.key), hex_key);
    }

    /// A 64-character hex key given inline round-trips.
    #[test]
    fn test_key_from_inline() {
        let hex_key = "11223344556677889900aabbccddeeff11223344556677889900aabbccddeeff";
        let source = KeySource::from_inline(hex_key);
        let key = source.retrieve().unwrap();
        assert_eq!(hex::encode(key.key), hex_key);
    }

    /// Text that matches none of the accepted key forms is rejected.
    #[test]
    fn test_invalid_key() {
        let source = KeySource::from_inline("too_short");
        let result = source.retrieve();
        assert!(result.is_err());
    }

    /// Each recognized prefix selects its variant; anything else is inline.
    #[test]
    fn test_parse_key_source_prefixes() {
        assert!(matches!(
            parse_key_source("env:MY_KEY"),
            KeySource::Env { var } if var == "MY_KEY"
        ));
        assert!(matches!(
            parse_key_source("file:/tmp/key"),
            KeySource::File { path } if path == "/tmp/key"
        ));
        assert!(matches!(
            parse_key_source("cmd:echo hi"),
            KeySource::Command { cmd } if cmd == "echo hi"
        ));
        assert!(matches!(
            parse_key_source("plain"),
            KeySource::Inline { key } if key == "plain"
        ));
    }
}