#![allow(dead_code)]
use crate::encryption::{EncryptionKey, decrypt, encrypt, generate_nonce};
use blake3::Hasher;
use rand::RngExt as _;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum AbeError {
#[error("Encryption failed: {0}")]
EncryptionFailed(String),
#[error("Decryption failed: {0}")]
DecryptionFailed(String),
#[error("Policy evaluation failed: {0}")]
PolicyFailed(String),
#[error("Invalid attributes: {0}")]
InvalidAttributes(String),
#[error("Key derivation failed: {0}")]
KeyDerivationFailed(String),
#[error("Serialization error: {0}")]
SerializationError(String),
}
pub type AbeResult<T> = Result<T, AbeError>;
/// One node of an access-policy tree.
///
/// Leaves name a single attribute; interior nodes combine the outcomes of
/// their children.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyNode {
    /// Satisfied when the attribute set contains exactly this name.
    Attribute(String),
    /// Satisfied when every child is satisfied (vacuously true when empty).
    And(Vec<PolicyNode>),
    /// Satisfied when at least one child is satisfied (false when empty).
    Or(Vec<PolicyNode>),
    /// Satisfied when at least `k` of the children are satisfied.
    Threshold {
        /// Minimum number of satisfied children required.
        k: usize,
        /// Candidate sub-policies.
        children: Vec<PolicyNode>,
    },
}
impl PolicyNode {
pub fn evaluate(&self, attributes: &HashSet<String>) -> bool {
match self {
PolicyNode::Attribute(attr) => attributes.contains(attr),
PolicyNode::And(children) => children.iter().all(|c| c.evaluate(attributes)),
PolicyNode::Or(children) => children.iter().any(|c| c.evaluate(attributes)),
PolicyNode::Threshold { k, children } => {
let satisfied = children.iter().filter(|c| c.evaluate(attributes)).count();
satisfied >= *k
}
}
}
pub fn get_attributes(&self) -> HashSet<String> {
let mut attrs = HashSet::new();
self.collect_attributes(&mut attrs);
attrs
}
fn collect_attributes(&self, attrs: &mut HashSet<String>) {
match self {
PolicyNode::Attribute(attr) => {
attrs.insert(attr.clone());
}
PolicyNode::And(children) | PolicyNode::Or(children) => {
for child in children {
child.collect_attributes(attrs);
}
}
PolicyNode::Threshold { children, .. } => {
for child in children {
child.collect_attributes(attrs);
}
}
}
}
}
/// An access policy: a thin wrapper around the root of a [`PolicyNode`] tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessPolicy {
    /// Root node evaluated against a user's attribute set.
    root: PolicyNode,
}
impl AccessPolicy {
pub fn new(root: PolicyNode) -> Self {
Self { root }
}
pub fn and(nodes: Vec<PolicyNode>) -> Self {
Self::new(PolicyNode::And(nodes))
}
pub fn or(nodes: Vec<PolicyNode>) -> Self {
Self::new(PolicyNode::Or(nodes))
}
pub fn threshold(k: usize, children: Vec<PolicyNode>) -> Self {
Self::new(PolicyNode::Threshold { k, children })
}
pub fn evaluate(&self, attributes: &HashSet<String>) -> bool {
self.root.evaluate(attributes)
}
pub fn get_attributes(&self) -> HashSet<String> {
self.root.get_attributes()
}
}
/// Root secret of an [`AbeAuthority`]; every attribute key is derived from it.
#[derive(Clone)]
pub struct MasterSecretKey {
    /// 32-byte seed, either freshly random or imported by the caller.
    seed: [u8; 32],
}

impl MasterSecretKey {
    /// Creates a key whose seed is 32 bytes drawn from `rand::rng()`.
    fn new() -> Self {
        let mut key = Self { seed: [0u8; 32] };
        rand::rng().fill(&mut key.seed);
        key
    }

    /// Deterministically derives the 32-byte key for `attribute`.
    ///
    /// Computes BLAKE3 over `seed || b"attribute:" || attribute`, so the same
    /// seed and name always reproduce the same key.
    fn derive_attribute_key(&self, attribute: &str) -> [u8; 32] {
        let digest = Hasher::new()
            .update(&self.seed)
            .update(b"attribute:")
            .update(attribute.as_bytes())
            .finalize();
        *digest.as_bytes()
    }
}
/// Per-user key material: one derived key per granted attribute.
///
/// The type is deserializable, so its contents are only trustworthy once
/// re-checked against the issuing authority's master seed.
#[derive(Clone, Serialize, Deserialize)]
pub struct UserSecretKey {
    /// Attribute name -> 32-byte attribute key.
    attribute_keys: HashMap<String, [u8; 32]>,
}
impl UserSecretKey {
fn new(attribute_keys: HashMap<String, [u8; 32]>) -> Self {
Self { attribute_keys }
}
pub fn get_attributes(&self) -> HashSet<String> {
self.attribute_keys.keys().cloned().collect()
}
pub fn has_attribute(&self, attribute: &str) -> bool {
self.attribute_keys.contains_key(attribute)
}
}
/// A data-encryption key (DEK) wrapped under a single attribute key.
#[derive(Clone, Serialize, Deserialize)]
struct EncryptedDek {
    /// Output of `encrypt` over the raw 32-byte DEK.
    ciphertext: Vec<u8>,
    /// Nonce used for this particular wrapping.
    nonce: [u8; 12],
}

/// Result of `AbeAuthority::encrypt`: the payload sealed under a random DEK,
/// plus one wrapped copy of that DEK per attribute the policy references.
///
/// Field order is significant: it defines the encoded byte layout.
#[derive(Clone, Serialize, Deserialize)]
pub struct AbeCiphertext {
    /// Policy the payload was sealed under (stored in the clear).
    policy: AccessPolicy,
    /// Attribute name -> DEK wrapped under that attribute's key.
    encrypted_keys: HashMap<String, EncryptedDek>,
    /// The encrypted payload.
    ciphertext: Vec<u8>,
    /// Nonce used to encrypt the payload.
    nonce: [u8; 12],
}
impl AbeCiphertext {
    /// Borrows the access policy this ciphertext was sealed under.
    pub fn policy(&self) -> &AccessPolicy {
        let Self { policy, .. } = self;
        policy
    }

    /// Encodes this ciphertext using `crate::codec::encode`.
    ///
    /// # Errors
    /// Returns [`AbeError::SerializationError`] if the codec fails.
    pub fn to_bytes(&self) -> AbeResult<Vec<u8>> {
        crate::codec::encode(self)
            .map_err(|err| AbeError::SerializationError(format!("Serialization failed: {err}")))
    }

    /// Decodes a ciphertext previously produced by [`AbeCiphertext::to_bytes`].
    ///
    /// # Errors
    /// Returns [`AbeError::SerializationError`] if the bytes cannot be decoded.
    pub fn from_bytes(bytes: &[u8]) -> AbeResult<Self> {
        crate::codec::decode(bytes)
            .map_err(|err| AbeError::SerializationError(format!("Deserialization failed: {err}")))
    }
}
pub struct AbeAuthority {
master_key: MasterSecretKey,
}
impl AbeAuthority {
pub fn new() -> Self {
Self {
master_key: MasterSecretKey::new(),
}
}
pub fn from_master_key(seed: [u8; 32]) -> Self {
Self {
master_key: MasterSecretKey { seed },
}
}
pub fn generate_user_key(&self, attributes: &[String]) -> AbeResult<UserSecretKey> {
if attributes.is_empty() {
return Err(AbeError::InvalidAttributes(
"Attributes list cannot be empty".to_string(),
));
}
let mut attribute_keys = HashMap::new();
for attr in attributes {
let key = self.master_key.derive_attribute_key(attr);
attribute_keys.insert(attr.clone(), key);
}
Ok(UserSecretKey::new(attribute_keys))
}
pub fn encrypt(&self, policy: &AccessPolicy, plaintext: &[u8]) -> AbeResult<AbeCiphertext> {
let mut dek = [0u8; 32];
rand::rng().fill(&mut dek);
let nonce = generate_nonce();
let encryption_key = EncryptionKey::from(dek);
let ciphertext = encrypt(plaintext, &encryption_key, &nonce)
.map_err(|e| AbeError::EncryptionFailed(format!("Failed to encrypt: {}", e)))?;
let mut encrypted_keys = HashMap::new();
for attr in policy.get_attributes() {
let attr_key = self.master_key.derive_attribute_key(&attr);
let attr_enc_key = EncryptionKey::from(attr_key);
let attr_nonce = generate_nonce();
let encrypted_dek_bytes = encrypt(&dek, &attr_enc_key, &attr_nonce)
.map_err(|e| AbeError::EncryptionFailed(format!("Failed to encrypt DEK: {}", e)))?;
encrypted_keys.insert(
attr,
EncryptedDek {
ciphertext: encrypted_dek_bytes,
nonce: attr_nonce,
},
);
}
Ok(AbeCiphertext {
policy: policy.clone(),
encrypted_keys,
ciphertext,
nonce,
})
}
pub fn decrypt(
&self,
user_key: &UserSecretKey,
ciphertext: &AbeCiphertext,
) -> AbeResult<Vec<u8>> {
let user_attrs = user_key.get_attributes();
if !ciphertext.policy.evaluate(&user_attrs) {
return Err(AbeError::DecryptionFailed(
"User attributes do not satisfy access policy".to_string(),
));
}
let mut dek = None;
for (attr, attr_key) in &user_key.attribute_keys {
if let Some(encrypted_dek) = ciphertext.encrypted_keys.get(attr) {
let attr_enc_key = EncryptionKey::from(*attr_key);
let decrypted = decrypt(
&encrypted_dek.ciphertext,
&attr_enc_key,
&encrypted_dek.nonce,
);
if let Ok(dek_bytes) = decrypted {
if dek_bytes.len() == 32 {
let mut dek_arr = [0u8; 32];
dek_arr.copy_from_slice(&dek_bytes);
dek = Some(dek_arr);
break;
}
}
}
}
let dek = dek.ok_or_else(|| {
AbeError::DecryptionFailed("Could not recover data encryption key".to_string())
})?;
let encryption_key = EncryptionKey::from(dek);
decrypt(&ciphertext.ciphertext, &encryption_key, &ciphertext.nonce)
.map_err(|e| AbeError::DecryptionFailed(format!("Failed to decrypt payload: {}", e)))
}
pub fn export_master_key(&self) -> [u8; 32] {
self.master_key.seed
}
}
impl Default for AbeAuthority {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a leaf policy node.
    fn attr(name: &str) -> PolicyNode {
        PolicyNode::Attribute(name.to_string())
    }

    /// Owned attribute names, as `generate_user_key` expects.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|s| s.to_string()).collect()
    }

    /// Attribute set for evaluating a policy directly.
    fn attr_set(list: &[&str]) -> HashSet<String> {
        list.iter().map(|s| s.to_string()).collect()
    }

    /// Issues a key for `grants`, encrypts `plaintext` under `policy`, and returns
    /// the decryption outcome. Key issuance and encryption are expected to succeed.
    fn seal_and_open(
        authority: &AbeAuthority,
        policy: &AccessPolicy,
        grants: &[&str],
        plaintext: &[u8],
    ) -> AbeResult<Vec<u8>> {
        let user_key = authority.generate_user_key(&names(grants)).unwrap();
        let sealed = authority.encrypt(policy, plaintext).unwrap();
        authority.decrypt(&user_key, &sealed)
    }

    #[test]
    fn test_policy_evaluation_single_attribute() {
        let policy = attr("admin");
        assert!(policy.evaluate(&attr_set(&["admin"])));
        assert!(!policy.evaluate(&attr_set(&["user"])));
    }

    #[test]
    fn test_policy_evaluation_and() {
        let policy = PolicyNode::And(vec![attr("admin"), attr("premium")]);
        assert!(policy.evaluate(&attr_set(&["admin", "premium"])));
        assert!(!policy.evaluate(&attr_set(&["admin"])));
    }

    #[test]
    fn test_policy_evaluation_or() {
        let policy = PolicyNode::Or(vec![attr("admin"), attr("moderator")]);
        assert!(policy.evaluate(&attr_set(&["admin"])));
        assert!(policy.evaluate(&attr_set(&["moderator"])));
        assert!(!policy.evaluate(&attr_set(&["user"])));
    }

    #[test]
    fn test_policy_evaluation_threshold() {
        let policy = PolicyNode::Threshold {
            k: 2,
            children: vec![attr("admin"), attr("moderator"), attr("premium")],
        };
        assert!(policy.evaluate(&attr_set(&["admin", "moderator"])));
        assert!(!policy.evaluate(&attr_set(&["admin"])));
    }

    #[test]
    fn test_user_key_generation() {
        let authority = AbeAuthority::new();
        let user_key = authority
            .generate_user_key(&names(&["admin", "premium"]))
            .unwrap();
        assert_eq!(user_key.get_attributes().len(), 2);
        assert!(user_key.has_attribute("admin"));
        assert!(user_key.has_attribute("premium"));
        assert!(!user_key.has_attribute("user"));
    }

    #[test]
    fn test_encrypt_decrypt_simple() {
        let authority = AbeAuthority::new();
        let policy = AccessPolicy::new(attr("premium"));
        let plaintext = b"Secret premium content";
        let decrypted = seal_and_open(&authority, &policy, &["premium"], plaintext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_and_policy() {
        let authority = AbeAuthority::new();
        let policy = AccessPolicy::and(vec![attr("premium"), attr("us-region")]);
        let plaintext = b"Premium US content";
        let decrypted =
            seal_and_open(&authority, &policy, &["premium", "us-region"], plaintext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_decrypt_fails_without_attributes() {
        let authority = AbeAuthority::new();
        let policy = AccessPolicy::new(attr("premium"));
        let outcome = seal_and_open(&authority, &policy, &["basic"], b"Secret premium content");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_decrypt_fails_partial_and() {
        let authority = AbeAuthority::new();
        let policy = AccessPolicy::and(vec![attr("premium"), attr("us-region")]);
        let outcome = seal_and_open(&authority, &policy, &["premium"], b"Premium US content");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_or_policy_decryption() {
        let authority = AbeAuthority::new();
        let policy = AccessPolicy::or(vec![attr("admin"), attr("premium")]);
        let plaintext = b"Admin or Premium content";
        let sealed = authority.encrypt(&policy, plaintext).unwrap();
        for grant in ["admin", "premium"] {
            let user_key = authority.generate_user_key(&names(&[grant])).unwrap();
            assert_eq!(authority.decrypt(&user_key, &sealed).unwrap(), plaintext);
        }
    }

    #[test]
    fn test_threshold_policy() {
        let authority = AbeAuthority::new();
        let policy = AccessPolicy::threshold(2, vec![attr("attr1"), attr("attr2"), attr("attr3")]);
        let plaintext = b"Threshold content";
        let decrypted =
            seal_and_open(&authority, &policy, &["attr1", "attr2"], plaintext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_complex_nested_policy() {
        let authority = AbeAuthority::new();
        let policy = AccessPolicy::and(vec![
            PolicyNode::Or(vec![attr("admin"), attr("moderator")]),
            attr("premium"),
        ]);
        let plaintext = b"Complex policy content";
        let decrypted =
            seal_and_open(&authority, &policy, &["moderator", "premium"], plaintext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_ciphertext_serialization() {
        let authority = AbeAuthority::new();
        let policy = AccessPolicy::new(attr("test"));
        let plaintext = b"Serialization test";
        let bytes = authority
            .encrypt(&policy, plaintext)
            .unwrap()
            .to_bytes()
            .unwrap();
        let restored = AbeCiphertext::from_bytes(&bytes).unwrap();
        let user_key = authority.generate_user_key(&names(&["test"])).unwrap();
        assert_eq!(authority.decrypt(&user_key, &restored).unwrap(), plaintext);
    }

    #[test]
    fn test_empty_attributes_fails() {
        assert!(AbeAuthority::new().generate_user_key(&[]).is_err());
    }

    #[test]
    fn test_master_key_export_import() {
        let original = AbeAuthority::new();
        let imported = AbeAuthority::from_master_key(original.export_master_key());
        let user_key = original.generate_user_key(&names(&["test"])).unwrap();
        let policy = AccessPolicy::new(attr("test"));
        let plaintext = b"Cross-authority test";
        let sealed = imported.encrypt(&policy, plaintext).unwrap();
        assert_eq!(original.decrypt(&user_key, &sealed).unwrap(), plaintext);
    }

    #[test]
    fn test_multiple_plaintexts_same_policy() {
        let authority = AbeAuthority::new();
        let policy = AccessPolicy::new(attr("premium"));
        let user_key = authority.generate_user_key(&names(&["premium"])).unwrap();
        let messages: [&[u8]; 2] = [b"First message", b"Second message"];
        let sealed: Vec<AbeCiphertext> = messages
            .iter()
            .map(|message| authority.encrypt(&policy, message).unwrap())
            .collect();
        for (message, ciphertext) in messages.iter().zip(&sealed) {
            assert_eq!(authority.decrypt(&user_key, ciphertext).unwrap(), *message);
        }
    }
}