use crate::encryption::algorithms::{EncryptionAlgorithm, EncryptionKey};
use crate::encryption::errors::{EncryptionError, EncryptionResult};
use argon2::{
password_hash::{PasswordHasher, SaltString},
Algorithm, Argon2, Params, Version,
};
use pbkdf2::pbkdf2_hmac;
use rand::{thread_rng, Rng};
use sha2::Sha256;
/// Selects the key-derivation function, together with its cost parameters,
/// that `KeyDerivationManager::derive_key` should run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KeyDerivationMethod {
    /// Argon2 with an explicit memory cost (KiB), pass count and lane count.
    Argon2 {
        memory_kib: u32,
        iterations: u32,
        parallelism: u32,
    },
    /// PBKDF2 with the given iteration count.
    // allow: presumably not yet constructed anywhere in the crate, but kept as
    // a supported option for callers of `derive_key`.
    #[allow(dead_code)]
    Pbkdf2 { iterations: u32 },
}
/// Derives encryption keys from passwords and other secret material.
///
/// The only state it carries is the Argon2 cost configuration used by the
/// master-key helpers.
#[derive(Clone, Debug)]
pub struct KeyDerivationManager {
    /// Argon2 costs applied when deriving or verifying a master key.
    default_argon2_params: Argon2Params,
}
/// Cost parameters for Argon2 key derivation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Argon2Params {
    /// Memory cost in KiB.
    pub memory_kib: u32,
    /// Number of passes over memory (Argon2 time cost).
    pub iterations: u32,
    /// Degree of parallelism (number of Argon2 lanes).
    pub parallelism: u32,
}
impl Default for Argon2Params {
fn default() -> Self {
Self {
memory_kib: 19456, iterations: 2,
parallelism: 1,
}
}
}
impl KeyDerivationManager {
    /// Creates a manager whose master-key helpers use `Argon2Params::default()`.
    pub fn new() -> Self {
        Self {
            default_argon2_params: Argon2Params::default(),
        }
    }

    /// Derives the AES-256-GCM master key from `password`.
    ///
    /// Runs Argon2id with this manager's default cost parameters.
    ///
    /// # Errors
    ///
    /// Returns a key-derivation error if Argon2 rejects the parameters or
    /// hashing fails. Also propagates any error from `EncryptionKey::new`.
    pub fn derive_master_key(&self, password: &str) -> EncryptionResult<EncryptionKey> {
        self.derive_key(
            password.as_bytes(),
            KeyDerivationMethod::Argon2 {
                memory_kib: self.default_argon2_params.memory_kib,
                iterations: self.default_argon2_params.iterations,
                parallelism: self.default_argon2_params.parallelism,
            },
            "master_key_salt",
            EncryptionAlgorithm::Aes256Gcm,
        )
    }

    /// Async variant of `derive_master_key` that runs the derivation on
    /// tokio's blocking thread pool.
    ///
    /// # Errors
    ///
    /// Same as `derive_master_key`. Additionally returns a key-derivation
    /// error if the blocking task fails to join.
    #[allow(dead_code)]
    pub async fn derive_master_key_async(
        &self,
        password: String,
    ) -> EncryptionResult<EncryptionKey> {
        // Argon2 is deliberately CPU- and memory-heavy; keep it off the
        // async executor threads.
        let params = self.default_argon2_params.clone();
        tokio::task::spawn_blocking(move || {
            // `&self` cannot move into the 'static closure, so rebuild an
            // equivalent manager from the cloned parameters.
            let manager = KeyDerivationManager {
                default_argon2_params: params,
            };
            manager.derive_master_key(&password)
        })
        .await
        .map_err(|e| EncryptionError::key_derivation_failed(format!("Task join error: {}", e)))?
    }

    /// Derives a per-workspace AES-256-GCM key from `master_key`.
    ///
    /// Runs PBKDF2-HMAC-SHA256 with the master key bytes as the password and
    /// the UTF-8 bytes of `workspace_id` as the salt. The same inputs always
    /// yield the same key.
    ///
    /// NOTE(review): for an already high-entropy input key, HKDF is the
    /// conventional choice. PBKDF2 is kept because switching would change
    /// every existing workspace key.
    #[allow(dead_code)]
    pub fn derive_workspace_key(
        &self,
        master_key: &EncryptionKey,
        workspace_id: &str,
    ) -> EncryptionResult<EncryptionKey> {
        // Part of the derived-key format: changing it changes every workspace key.
        const ITERATIONS: u32 = 10_000;

        let algorithm = EncryptionAlgorithm::Aes256Gcm;
        let mut derived_key = vec![0u8; Self::key_len_for(&algorithm)];
        pbkdf2_hmac::<Sha256>(
            master_key.as_bytes(),
            workspace_id.as_bytes(),
            ITERATIONS,
            &mut derived_key,
        );
        EncryptionKey::new(derived_key, algorithm)
    }

    /// Derives a key for `algorithm` from `secret` using the given `method`.
    ///
    /// `salt` is used as raw bytes by the PBKDF2 path. The Argon2 path
    /// currently ignores it (see `derive_key_argon2`).
    ///
    /// # Errors
    ///
    /// Propagates errors from the selected derivation function and from
    /// `EncryptionKey::new`.
    pub fn derive_key(
        &self,
        secret: &[u8],
        method: KeyDerivationMethod,
        salt: &str,
        algorithm: EncryptionAlgorithm,
    ) -> EncryptionResult<EncryptionKey> {
        match method {
            KeyDerivationMethod::Argon2 {
                memory_kib,
                iterations,
                parallelism,
            } => {
                self.derive_key_argon2(secret, salt, memory_kib, iterations, parallelism, algorithm)
            }
            KeyDerivationMethod::Pbkdf2 { iterations } => {
                self.derive_key_pbkdf2(secret, salt, iterations, algorithm)
            }
        }
    }

    /// Runs Argon2id (version 0x13) over `secret` and returns a key sized
    /// for `algorithm`.
    ///
    /// NOTE(review): `_salt` is ignored. Every call hashes with the same
    /// hard-coded salt, so equal secrets and parameters always produce equal
    /// keys regardless of the salt the caller passes. Honoring the argument,
    /// or a random per-key stored salt, would change every previously derived
    /// key. That needs a migration path rather than a silent change.
    fn derive_key_argon2(
        &self,
        secret: &[u8],
        _salt: &str,
        memory_kib: u32,
        iterations: u32,
        parallelism: u32,
        algorithm: EncryptionAlgorithm,
    ) -> EncryptionResult<EncryptionKey> {
        // Fixed salt bytes; read the NOTE above before changing them.
        const FIXED_SALT: &[u8] = b"randomsalt12345678901234567890123456789012";

        let key_len = Self::key_len_for(&algorithm);
        let salt = SaltString::encode_b64(FIXED_SALT)
            .map_err(|e| EncryptionError::key_derivation_failed(e.to_string()))?;
        // Request exactly the number of output bytes the algorithm needs.
        let params = Params::new(memory_kib, iterations, parallelism, Some(key_len))
            .map_err(|e| EncryptionError::key_derivation_failed(e.to_string()))?;
        let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
        let password_hash = argon2
            .hash_password(secret, &salt)
            .map_err(|e| EncryptionError::key_derivation_failed(e.to_string()))?;
        let hash = password_hash.hash.ok_or_else(|| {
            EncryptionError::key_derivation_failed("Argon2 returned no hash output")
        })?;
        let key_bytes = hash.as_bytes().get(..key_len).ok_or_else(|| {
            EncryptionError::key_derivation_failed("Derived key too short for algorithm")
        })?;
        EncryptionKey::new(key_bytes.to_vec(), algorithm)
    }

    /// Runs PBKDF2-HMAC-SHA256 over `secret` with the UTF-8 bytes of `salt`
    /// and returns a key sized for `algorithm`.
    fn derive_key_pbkdf2(
        &self,
        secret: &[u8],
        salt: &str,
        iterations: u32,
        algorithm: EncryptionAlgorithm,
    ) -> EncryptionResult<EncryptionKey> {
        let mut derived_key = vec![0u8; Self::key_len_for(&algorithm)];
        pbkdf2_hmac::<Sha256>(secret, salt.as_bytes(), iterations, &mut derived_key);
        EncryptionKey::new(derived_key, algorithm)
    }

    /// Returns the key length in bytes required by `algorithm`.
    fn key_len_for(algorithm: &EncryptionAlgorithm) -> usize {
        match algorithm {
            // Both ciphers take 256-bit keys.
            EncryptionAlgorithm::Aes256Gcm | EncryptionAlgorithm::ChaCha20Poly1305 => 32,
        }
    }

    /// Compares two byte slices without stopping at the first mismatch.
    ///
    /// This is a best-effort constant-time comparison for equal-length
    /// inputs. Only the length check short-circuits.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }

    /// Returns whether `password` derives (via `derive_master_key`) to
    /// `expected_key`.
    ///
    /// The key bytes are compared in constant time, so the result does not
    /// leak how many leading bytes matched.
    ///
    /// # Errors
    ///
    /// Propagates any error from `derive_master_key`.
    #[allow(dead_code)]
    pub fn verify_password(
        &self,
        password: &str,
        expected_key: &EncryptionKey,
    ) -> EncryptionResult<bool> {
        let derived_key = self.derive_master_key(password)?;
        Ok(Self::constant_time_eq(
            derived_key.as_bytes(),
            expected_key.as_bytes(),
        ))
    }

    /// Async variant of `verify_password` that runs on tokio's blocking
    /// thread pool.
    ///
    /// # Errors
    ///
    /// Same as `verify_password`. Additionally returns a key-derivation
    /// error if the blocking task fails to join.
    #[allow(dead_code)]
    pub async fn verify_password_async(
        &self,
        password: String,
        expected_key: EncryptionKey,
    ) -> EncryptionResult<bool> {
        let params = self.default_argon2_params.clone();
        tokio::task::spawn_blocking(move || {
            let manager = KeyDerivationManager {
                default_argon2_params: params,
            };
            let derived_key = manager.derive_master_key(&password)?;
            Ok(Self::constant_time_eq(
                derived_key.as_bytes(),
                expected_key.as_bytes(),
            ))
        })
        .await
        .map_err(|e| EncryptionError::key_derivation_failed(format!("Task join error: {}", e)))?
    }

    /// Returns 16 random bytes from `thread_rng`, encoded with the standard
    /// (padded) base64 alphabet.
    #[allow(dead_code)]
    pub fn generate_salt() -> String {
        let mut salt = [0u8; 16];
        let mut rng = thread_rng();
        rng.fill(&mut salt);
        base64::Engine::encode(&base64::engine::general_purpose::STANDARD, salt)
    }

    /// Checks `method`'s cost parameters against minimum values.
    ///
    /// Only lower bounds are checked. For Argon2 this includes the
    /// requirement of at least 8 KiB of memory per lane.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_algorithm` error that describes the first
    /// violated bound.
    #[allow(dead_code)]
    pub fn validate_parameters(&self, method: &KeyDerivationMethod) -> EncryptionResult<()> {
        match method {
            KeyDerivationMethod::Argon2 {
                memory_kib,
                iterations,
                parallelism,
            } => {
                if *memory_kib < 8 {
                    return Err(EncryptionError::invalid_algorithm(
                        "Argon2 memory must be at least 8 KiB",
                    ));
                }
                if *iterations < 1 {
                    return Err(EncryptionError::invalid_algorithm(
                        "Argon2 iterations must be at least 1",
                    ));
                }
                if *parallelism < 1 {
                    return Err(EncryptionError::invalid_algorithm(
                        "Argon2 parallelism must be at least 1",
                    ));
                }
                // Argon2 (RFC 9106) needs at least 8 KiB per lane. Without this
                // check, `Params::new` would reject parameters that passed
                // validation. u64 math avoids overflow for huge parallelism.
                if u64::from(*memory_kib) < 8 * u64::from(*parallelism) {
                    return Err(EncryptionError::invalid_algorithm(
                        "Argon2 memory must be at least 8 KiB per lane of parallelism",
                    ));
                }
            }
            KeyDerivationMethod::Pbkdf2 { iterations } => {
                if *iterations < 1000 {
                    return Err(EncryptionError::invalid_algorithm(
                        "PBKDF2 iterations should be at least 1000 for security",
                    ));
                }
            }
        }
        Ok(())
    }
}
impl Default for KeyDerivationManager {
fn default() -> Self {
Self::new()
}
}