use rialo_tee_secret_sharing::{encrypt_user_payload, USER_SECRET_AAD};
use rialo_types::PublicKey;
use crate::{
error::{Result, RialoError},
rpc::types::Pubkey,
};
/// Upper bound on the size of a secret, in bytes (64 KiB).
///
/// The limit applies to the UTF-8 byte length of the secret, not to its
/// character count.
pub const MAX_SECRET_LENGTH: usize = 64 * 1024;
/// Encrypts a user secret for the holder of `secret_sharing_pubkey`.
///
/// The additional authenticated data passed to the cipher is
/// `USER_SECRET_AAD` followed by the bytes of `creator_pubkey`, so the
/// creator's key is part of the input handed to `encrypt_user_payload`.
///
/// # Errors
///
/// * [`RialoError::InvalidInput`] if `secret` is empty or its UTF-8 encoding
///   is longer than [`MAX_SECRET_LENGTH`] bytes.
/// * [`RialoError::Encryption`] if `encrypt_user_payload` reports a failure.
pub fn encrypt_secret(
    secret: String,
    creator_pubkey: &Pubkey,
    secret_sharing_pubkey: &PublicKey,
) -> Result<Vec<u8>> {
    let plaintext = secret.as_bytes();

    // Validate the size up front; the length is measured in bytes, not chars.
    match plaintext.len() {
        0 => {
            return Err(RialoError::InvalidInput(
                "Secret cannot be empty".to_string(),
            ));
        }
        actual if actual > MAX_SECRET_LENGTH => {
            return Err(RialoError::InvalidInput(format!(
                "Secret exceeds maximum length of {} bytes (got {} bytes)",
                MAX_SECRET_LENGTH, actual
            )));
        }
        _ => {}
    }

    // AAD layout: fixed domain tag first, then the creator's public key bytes.
    let creator_bytes = creator_pubkey.to_bytes();
    let aad = [USER_SECRET_AAD, &creator_bytes].concat();

    encrypt_user_payload(secret_sharing_pubkey, plaintext, &aad)
        .map_err(|e| RialoError::Encryption(format!("Encryption failed: {}", e)))
}
#[cfg(test)]
mod tests {
    use rialo_tee_secret_sharing::{decrypt_user_message, initialize_secret_key, SecretSharingKey};
    use zeroize::Zeroizing;

    use super::*;

    /// Builds a deterministic secret-sharing key from a fixed 32-byte pattern.
    fn create_test_key() -> SecretSharingKey {
        SecretSharingKey::from_private_key_bytes(Zeroizing::new([42u8; 32]))
    }

    /// A secret encrypted for a creator decrypts back to the original bytes.
    #[test]
    fn test_encrypt_secret_basic() {
        let key = create_test_key();
        let creator = Pubkey::from([1u8; 32]);
        // The outcome of the initialization is intentionally discarded.
        // NOTE(review): presumably it fails harmlessly when another test has
        // already installed the key — confirm against the library.
        let _ = initialize_secret_key(key.clone());

        let plaintext = "Bearer test-token-12345".to_string();
        let ciphertext = encrypt_secret(plaintext.clone(), &creator, key.public_key()).unwrap();

        let creator_key = PublicKey::from_bytes(creator.to_bytes());
        let recovered = decrypt_user_message(&ciphertext, &creator_key).unwrap();
        assert_eq!(recovered.as_bytes(), plaintext.as_bytes());
    }

    /// An empty secret is rejected with a message mentioning "empty".
    #[test]
    fn test_encrypt_secret_empty() {
        let key = create_test_key();
        let creator = Pubkey::from([3u8; 32]);

        let message = encrypt_secret(String::new(), &creator, key.public_key())
            .unwrap_err()
            .to_string();
        assert!(message.contains("empty"));
    }

    /// A secret one byte over the limit is rejected with a "maximum length" message.
    #[test]
    fn test_encrypt_secret_oversized() {
        let key = create_test_key();
        let creator = Pubkey::from([4u8; 32]);
        let too_long = "x".repeat(MAX_SECRET_LENGTH + 1);

        let message = encrypt_secret(too_long, &creator, key.public_key())
            .unwrap_err()
            .to_string();
        assert!(message.contains("maximum length"));
    }

    /// The same secret encrypted for two creators yields distinct ciphertexts,
    /// each of which decrypts correctly with its own creator key.
    #[test]
    fn test_encrypt_secret_with_different_creator() {
        let key = create_test_key();
        let creator_a = Pubkey::from([10u8; 32]);
        let creator_b = Pubkey::from([20u8; 32]);
        // Result intentionally ignored, as in the basic round-trip test.
        let _ = initialize_secret_key(key.clone());

        let plaintext = "test-secret-with-creator".to_string();
        let ciphertext_a =
            encrypt_secret(plaintext.clone(), &creator_a, key.public_key()).unwrap();
        let ciphertext_b =
            encrypt_secret(plaintext.clone(), &creator_b, key.public_key()).unwrap();
        assert_ne!(ciphertext_a, ciphertext_b);

        // Round-trip each ciphertext with the creator it was produced for.
        for (ciphertext, creator) in [(&ciphertext_a, &creator_a), (&ciphertext_b, &creator_b)] {
            let creator_key = PublicKey::from_bytes(creator.to_bytes());
            let recovered = decrypt_user_message(ciphertext, &creator_key).unwrap();
            assert_eq!(recovered.as_bytes(), plaintext.as_bytes());
        }
    }
}