use std::fs;
use std::io::{Read, Write};
use std::path::PathBuf;
use crate::crypto::aes_gcm::{aes256_gcm_decrypt, aes256_gcm_encrypt};
use crate::crypto::sha256::sha256;
use crate::crypto::uuid::Uuid;
/// Directory name, under the user's config root, that holds the keyring file.
const SERVICE_NAME: &str = "reddb";
/// File name of the encrypted keyring blob written by `save_to_keyring`
/// (a 12-byte nonce followed by the AES-256-GCM ciphertext and tag).
const KEYRING_FILE: &str = "keyring.enc";
/// Where the database password came from, carrying the password itself.
///
/// `Debug` is implemented by hand instead of derived so that `{:?}` formatting (for
/// example in logs or error messages) shows which source was used but never the secret.
#[derive(Clone)]
pub enum PasswordSource {
    /// Supplied through the `flag_password` argument of `resolve_password`
    /// (presumably a command-line flag).
    Flag(String),
    /// Read from the `REDDB_KEY` or `REDBLUE_DB_KEY` environment variable.
    EnvVar(String),
    /// Decrypted from the local keyring file.
    Keyring(String),
    /// No password was found; `is_encrypted` reports `false` for this variant.
    None,
}

impl std::fmt::Debug for PasswordSource {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the contained password.
        match self {
            PasswordSource::Flag(_) => f.write_str("Flag(<redacted>)"),
            PasswordSource::EnvVar(_) => f.write_str("EnvVar(<redacted>)"),
            PasswordSource::Keyring(_) => f.write_str("Keyring(<redacted>)"),
            PasswordSource::None => f.write_str("None"),
        }
    }
}
impl PasswordSource {
pub fn password(&self) -> Option<&str> {
match self {
PasswordSource::Flag(p) => Some(p),
PasswordSource::EnvVar(p) => Some(p),
PasswordSource::Keyring(p) => Some(p),
PasswordSource::None => None,
}
}
pub fn is_encrypted(&self) -> bool {
!matches!(self, PasswordSource::None)
}
pub fn source_name(&self) -> &'static str {
match self {
PasswordSource::Flag(_) => "flag",
PasswordSource::EnvVar(_) => "env",
PasswordSource::Keyring(_) => "keyring",
PasswordSource::None => "none",
}
}
}
/// Picks the database password from the first source that provides a non-empty value.
///
/// Priority order:
/// 1. `flag_password`, if `Some` and non-empty;
/// 2. the `REDDB_KEY` environment variable, then `REDBLUE_DB_KEY` (first non-empty wins);
/// 3. the password stored in the local keyring file.
///
/// An empty value from any source is treated as "not provided" and the search moves on.
/// This means an empty `REDDB_KEY` no longer hides `REDBLUE_DB_KEY`, and an empty keyring
/// entry yields `PasswordSource::None` rather than an "encrypted" source with an empty key.
/// Returns `PasswordSource::None` when nothing is found.
pub fn resolve_password(flag_password: Option<&str>) -> PasswordSource {
    if let Some(pwd) = flag_password.filter(|p| !p.is_empty()) {
        return PasswordSource::Flag(pwd.to_string());
    }

    // `find` is lazy: the fallback variable is only read when the primary is unset,
    // non-Unicode, or empty.
    let from_env = ["REDDB_KEY", "REDBLUE_DB_KEY"]
        .iter()
        .filter_map(|name| std::env::var(name).ok())
        .find(|value| !value.is_empty());
    if let Some(pwd) = from_env {
        return PasswordSource::EnvVar(pwd);
    }

    if let Some(pwd) = get_from_keyring().filter(|p| !p.is_empty()) {
        return PasswordSource::Keyring(pwd);
    }

    PasswordSource::None
}
/// Loads and decrypts the password stored by `save_to_keyring`.
///
/// The file layout is a 12-byte nonce followed by the AES-256-GCM ciphertext with its
/// 16-byte tag. Returns `None` in each of these cases:
/// - the keyring path cannot be determined, or the file is missing or unreadable;
/// - the file is shorter than nonce + tag;
/// - decryption fails (wrong key or tampered data);
/// - the plaintext is not valid UTF-8.
pub fn get_from_keyring() -> Option<String> {
    const NONCE_LEN: usize = 12;
    const TAG_LEN: usize = 16;

    let path = get_keyring_path()?;
    if !path.exists() {
        return None;
    }

    let mut blob = Vec::new();
    fs::File::open(&path).ok()?.read_to_end(&mut blob).ok()?;
    if blob.len() < NONCE_LEN + TAG_LEN {
        return None;
    }

    let (nonce_bytes, sealed) = blob.split_at(NONCE_LEN);
    let mut nonce = [0u8; NONCE_LEN];
    nonce.copy_from_slice(nonce_bytes);

    let key = derive_keyring_key();
    aes256_gcm_decrypt(&key, &nonce, &[], sealed)
        .ok()
        .and_then(|plain| String::from_utf8(plain).ok())
}
/// Encrypts `password` and writes it to the keyring file, creating parent directories.
///
/// The file contents are the 12-byte nonce followed by the AES-256-GCM ciphertext and tag,
/// which is the layout `get_from_keyring` expects. An existing keyring is overwritten.
///
/// On Unix the file is created with mode `0o600`. A pre-existing file is tightened to
/// `0o600` *before* any secret bytes are written. The file is never world-readable while
/// it holds the new ciphertext.
///
/// # Errors
/// Returns a human-readable message when one of these fails:
/// - resolving the keyring path;
/// - creating the directory;
/// - opening the file;
/// - setting permissions;
/// - writing.
pub fn save_to_keyring(password: &str) -> Result<(), String> {
    let keyring_path = get_keyring_path().ok_or("Failed to determine keyring path")?;
    if let Some(parent) = keyring_path.parent() {
        fs::create_dir_all(parent)
            .map_err(|e| format!("Failed to create keyring directory: {}", e))?;
    }

    let key = derive_keyring_key();
    let nonce = generate_nonce();
    let ciphertext_and_tag = aes256_gcm_encrypt(&key, &nonce, &[], password.as_bytes());
    let mut data = Vec::with_capacity(nonce.len() + ciphertext_and_tag.len());
    data.extend_from_slice(&nonce);
    data.extend_from_slice(&ciphertext_and_tag);

    // Equivalent to `File::create`, but lets us pick the creation mode on Unix.
    let mut options = fs::OpenOptions::new();
    options.write(true).create(true).truncate(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        // Newly created files start owner-only; no umask-dependent readable window.
        options.mode(0o600);
    }
    let mut file = options
        .open(&keyring_path)
        .map_err(|e| format!("Failed to create keyring file: {}", e))?;

    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        // `mode` only applies when the file is newly created; tighten a pre-existing file
        // before the secret is written into it.
        file.set_permissions(fs::Permissions::from_mode(0o600))
            .map_err(|e| format!("Failed to set keyring permissions: {}", e))?;
    }

    file.write_all(&data)
        .map_err(|e| format!("Failed to write keyring: {}", e))?;
    Ok(())
}
/// Deletes the keyring file. This is a no-op success when the file does not exist.
///
/// The file is removed directly and `NotFound` is treated as success. This avoids the
/// race between a separate `exists()` check and the removal.
///
/// # Errors
/// Returns a message if the keyring path cannot be determined, or if removal fails for a
/// reason other than the file being absent (e.g. permissions, or the path is a directory).
pub fn clear_keyring() -> Result<(), String> {
    let keyring_path = get_keyring_path().ok_or("Failed to determine keyring path")?;
    match fs::remove_file(&keyring_path) {
        Ok(()) => Ok(()),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
        Err(e) => Err(format!("Failed to remove keyring: {}", e)),
    }
}
/// Reports whether a keyring file exists and decrypts to a valid UTF-8 password.
pub fn has_keyring_password() -> bool {
    matches!(get_from_keyring(), Some(_))
}
/// Locates the keyring file: `<config root>/reddb/keyring.enc`.
///
/// The config root is the first usable entry of:
/// 1. `$XDG_CONFIG_HOME`, if non-empty and absolute;
/// 2. `$HOME/.config`, if `HOME` is non-empty;
/// 3. `%APPDATA%`, if non-empty.
///
/// Per the XDG Base Directory spec, an empty or relative `XDG_CONFIG_HOME` is ignored.
/// Previously such values (and an empty `HOME`/`APPDATA`) produced a path relative to the
/// current working directory, scattering password files wherever the tool was run.
/// Returns `None` when no candidate is available.
fn get_keyring_path() -> Option<PathBuf> {
    /// Reads an environment variable, treating unset, non-Unicode and empty alike.
    fn non_empty_var(name: &str) -> Option<String> {
        std::env::var(name).ok().filter(|value| !value.is_empty())
    }

    if let Some(config_dir) = non_empty_var("XDG_CONFIG_HOME")
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute())
    {
        return Some(config_dir.join(SERVICE_NAME).join(KEYRING_FILE));
    }
    if let Some(home) = non_empty_var("HOME") {
        return Some(
            PathBuf::from(home)
                .join(".config")
                .join(SERVICE_NAME)
                .join(KEYRING_FILE),
        );
    }
    non_empty_var("APPDATA")
        .map(|appdata| PathBuf::from(appdata).join(SERVICE_NAME).join(KEYRING_FILE))
}
/// Derives the 32-byte AES key that protects the keyring file.
///
/// The key is `sha256("<host>:<user>:<home>:reddb-keyring-v1")`. Each field comes from the
/// first readable variable of a pair:
/// - host: `HOSTNAME`, then `COMPUTERNAME`;
/// - user: `USER`, then `USERNAME`;
/// - home: `HOME`, then `USERPROFILE`.
///
/// A field is left empty when neither variable of its pair is set.
///
/// NOTE(review): every input is an environment variable, so anyone who can read this user's
/// environment can recompute the key. This guards against casual disclosure of the file, not
/// against a local attacker.
///
/// NOTE(review): `HOSTNAME` may be a non-exported shell variable in some shells. If it is
/// visible in one context and not another (cron, services, IDEs), the derived key changes
/// and the keyring silently stops decrypting. Changing the derivation would orphan existing
/// keyring files, so it is intentionally left as-is.
fn derive_keyring_key() -> [u8; 32] {
    /// Value of `primary`, else `fallback`, else the empty string.
    fn env_or(primary: &str, fallback: &str) -> String {
        std::env::var(primary)
            .or_else(|_| std::env::var(fallback))
            .unwrap_or_default()
    }

    let identity = format!(
        "{}:{}:{}:reddb-keyring-v1",
        env_or("HOSTNAME", "COMPUTERNAME"),
        env_or("USER", "USERNAME"),
        env_or("HOME", "USERPROFILE"),
    );
    sha256(identity.as_bytes())
}
/// Produces a fresh 12-byte AES-GCM nonce from a random (v4) UUID.
///
/// An RFC 4122 v4 UUID is not 128 random bits. The high nibble of byte 6 holds the version
/// and the top bits of byte 8 hold the variant, so copying bytes `0..12` fixed 6 of the 96
/// nonce bits. This takes bytes 0-5, 7 and 9-13 instead, so every nonce bit is random.
/// If `Uuid::new_v4` does not follow the RFC layout, the selection is still uniformly random.
///
/// NOTE(review): random GCM nonces are only safe if `Uuid::new_v4` draws from a CSPRNG.
/// Confirm this in `crate::crypto::uuid`.
fn generate_nonce() -> [u8; 12] {
    let uuid = Uuid::new_v4();
    let bytes = uuid.as_bytes();
    let mut nonce = [0u8; 12];
    nonce[..6].copy_from_slice(&bytes[0..6]);
    // Skip byte 6 (version nibble) and byte 8 (variant bits).
    nonce[6] = bytes[7];
    nonce[7..].copy_from_slice(&bytes[9..14]);
    nonce
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises every test that reads or writes process-wide state (environment variables
    /// and the keyring file); the test harness runs tests on parallel threads.
    static KEYRING_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `KEYRING_TEST_LOCK` and resets shared state for a keyring/env test.
    ///
    /// It does three things:
    /// - points `XDG_CONFIG_HOME` at a per-process scratch directory, so tests never read,
    ///   overwrite or delete the developer's real keyring;
    /// - clears both password environment variables, so ambient values cannot leak into
    ///   assertions;
    /// - removes any keyring left behind by an earlier test.
    ///
    /// Hold the returned guard for the whole test.
    fn isolated_env() -> MutexGuard<'static, ()> {
        // Recover from poisoning so one failing test does not make every later one panic.
        let guard = KEYRING_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let scratch =
            std::env::temp_dir().join(format!("reddb-keyring-test-{}", std::process::id()));
        std::env::set_var("XDG_CONFIG_HOME", &scratch);
        std::env::remove_var("REDDB_KEY");
        std::env::remove_var("REDBLUE_DB_KEY");
        let _ = clear_keyring();
        guard
    }

    #[test]
    fn test_password_source_is_encrypted() {
        assert!(PasswordSource::Flag("test".to_string()).is_encrypted());
        assert!(PasswordSource::EnvVar("test".to_string()).is_encrypted());
        assert!(PasswordSource::Keyring("test".to_string()).is_encrypted());
        assert!(!PasswordSource::None.is_encrypted());
    }

    #[test]
    fn test_password_source_name() {
        assert_eq!(PasswordSource::Flag("".to_string()).source_name(), "flag");
        assert_eq!(PasswordSource::EnvVar("".to_string()).source_name(), "env");
        assert_eq!(
            PasswordSource::Keyring("".to_string()).source_name(),
            "keyring"
        );
        assert_eq!(PasswordSource::None.source_name(), "none");
    }

    #[test]
    fn test_password_source_password() {
        assert_eq!(
            PasswordSource::Flag("mypass".to_string()).password(),
            Some("mypass")
        );
        assert_eq!(
            PasswordSource::EnvVar("envpass".to_string()).password(),
            Some("envpass")
        );
        assert_eq!(
            PasswordSource::Keyring("ringpass".to_string()).password(),
            Some("ringpass")
        );
        assert_eq!(PasswordSource::None.password(), None);
    }

    #[test]
    fn test_derive_keyring_key_deterministic() {
        let key1 = derive_keyring_key();
        let key2 = derive_keyring_key();
        assert_eq!(key1, key2);
        assert_eq!(key1.len(), 32);
    }

    #[test]
    fn test_derive_keyring_key_length() {
        let key = derive_keyring_key();
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_generate_nonce_uniqueness() {
        let nonce1 = generate_nonce();
        let nonce2 = generate_nonce();
        assert_ne!(nonce1, nonce2);
        assert_eq!(nonce1.len(), 12);
        assert_eq!(nonce2.len(), 12);
    }

    #[test]
    fn test_resolve_password_flag_priority() {
        // A non-empty flag short-circuits before env or keyring are consulted.
        let result = resolve_password(Some("flag_password"));
        assert!(matches!(result, PasswordSource::Flag(_)));
        if let PasswordSource::Flag(pwd) = result {
            assert_eq!(pwd, "flag_password");
        }
    }

    #[test]
    fn test_resolve_password_empty_flag() {
        let _guard = isolated_env();
        let result = resolve_password(Some(""));
        assert!(!matches!(result, PasswordSource::Flag(_)));
    }

    #[test]
    fn test_resolve_password_env_var() {
        let _guard = isolated_env();
        std::env::set_var("REDDB_KEY", "env_test_password");
        let result = resolve_password(None);
        std::env::remove_var("REDDB_KEY");
        assert!(matches!(result, PasswordSource::EnvVar(_)));
        if let PasswordSource::EnvVar(pwd) = result {
            assert_eq!(pwd, "env_test_password");
        }
    }

    #[test]
    fn test_resolve_password_flag_overrides_env() {
        // Mutates REDDB_KEY, so it must hold the lock like every other env-touching test.
        let _guard = isolated_env();
        std::env::set_var("REDDB_KEY", "env_password");
        let result = resolve_password(Some("flag_password"));
        std::env::remove_var("REDDB_KEY");
        assert!(matches!(result, PasswordSource::Flag(_)));
    }

    #[test]
    fn test_keyring_save_and_retrieve() {
        let _guard = isolated_env();
        let result = save_to_keyring("test_keyring_password_12345");
        assert!(result.is_ok(), "Failed to save to keyring: {:?}", result);
        let retrieved = get_from_keyring();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), "test_keyring_password_12345");
        let _ = clear_keyring();
    }

    #[test]
    fn test_keyring_has_password() {
        let _guard = isolated_env();
        assert!(!has_keyring_password());
        let _ = save_to_keyring("check_password");
        assert!(has_keyring_password());
        let _ = clear_keyring();
        assert!(!has_keyring_password());
    }

    #[test]
    fn test_clear_keyring_nonexistent() {
        let _guard = isolated_env();
        let result = clear_keyring();
        assert!(result.is_ok());
    }

    #[test]
    fn test_keyring_special_characters() {
        let _guard = isolated_env();
        let special_password = "p@$$w0rd!#%&*()[]{}|;':\",./<>?`~";
        let result = save_to_keyring(special_password);
        assert!(result.is_ok());
        let retrieved = get_from_keyring();
        assert_eq!(retrieved, Some(special_password.to_string()));
        let _ = clear_keyring();
    }

    #[test]
    fn test_keyring_unicode_password() {
        let _guard = isolated_env();
        let unicode_password = "пароль🔒密码パスワード";
        let result = save_to_keyring(unicode_password);
        assert!(result.is_ok());
        let retrieved = get_from_keyring();
        assert_eq!(retrieved, Some(unicode_password.to_string()));
        let _ = clear_keyring();
    }

    #[test]
    fn test_keyring_empty_password() {
        let _guard = isolated_env();
        let result = save_to_keyring("");
        assert!(result.is_ok());
        let retrieved = get_from_keyring();
        assert_eq!(retrieved, Some("".to_string()));
        let _ = clear_keyring();
    }

    #[test]
    fn test_keyring_long_password() {
        let _guard = isolated_env();
        let long_password = "x".repeat(10000);
        let result = save_to_keyring(&long_password);
        assert!(result.is_ok());
        let retrieved = get_from_keyring();
        assert_eq!(retrieved, Some(long_password));
        let _ = clear_keyring();
    }

    #[test]
    fn test_resolve_password_keyring_integration() {
        let _guard = isolated_env();
        let _ = save_to_keyring("keyring_test_pwd");
        let result = resolve_password(None);
        assert!(matches!(result, PasswordSource::Keyring(_)));
        if let PasswordSource::Keyring(pwd) = result {
            assert_eq!(pwd, "keyring_test_pwd");
        }
        let _ = clear_keyring();
    }

    #[test]
    fn test_resolve_password_none_when_empty() {
        let _guard = isolated_env();
        let result = resolve_password(None);
        assert!(matches!(result, PasswordSource::None));
    }

    #[test]
    fn test_get_keyring_path_returns_some() {
        let _guard = isolated_env();
        let path = get_keyring_path();
        if let Some(p) = path {
            assert!(p.to_string_lossy().contains("reddb"));
            assert!(p.to_string_lossy().contains("keyring.enc"));
        }
    }
}