use crate::attestation::report::AttestationReport;
use crate::errors::TeeError;
use chacha20poly1305::ChaCha20Poly1305;
use chacha20poly1305::aead::{Aead, KeyInit};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use x25519_dalek::{PublicKey, StaticSecret};
use zeroize::Zeroize;
/// Server-side state for one X25519 key-exchange session.
///
/// `Debug` is implemented by hand so that the private key bytes never end up
/// in logs or panic messages; it is printed as `"<redacted>"`.
pub struct KeyExchangeSession {
    /// Random session identifier (16 OS-RNG bytes, hex-encoded by `new`).
    pub session_id: String,
    /// X25519 public key bytes for this session.
    pub public_key: Vec<u8>,
    /// X25519 secret key bytes; wiped by the `Drop` impl.
    private_key: Vec<u8>,
    /// Creation time in whole seconds since the Unix epoch.
    pub created_at: u64,
    /// Session lifetime in seconds, measured from `created_at`.
    pub ttl_secs: u64,
}

impl std::fmt::Debug for KeyExchangeSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyExchangeSession")
            .field("session_id", &self.session_id)
            .field("public_key", &self.public_key)
            // Never render secret key material.
            .field("private_key", &"<redacted>")
            .field("created_at", &self.created_at)
            .field("ttl_secs", &self.ttl_secs)
            .finish()
    }
}
impl KeyExchangeSession {
    /// Returns the current wall-clock time in whole seconds since the Unix
    /// epoch, or `0` if the system clock reports a time before the epoch.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// Creates a new session with a freshly generated X25519 key pair.
    ///
    /// The session id is 16 bytes from the OS RNG, hex-encoded (32 characters).
    /// `created_at` is set to the current Unix time in seconds.
    pub fn new(ttl_secs: u64) -> Self {
        let secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
        let public_key = PublicKey::from(&secret).as_bytes().to_vec();

        // `to_bytes` yields a stack copy of the secret scalar. Copy it into the
        // heap field (wiped by `Drop`), then clear the stack copy right away.
        let mut secret_bytes = secret.to_bytes();
        let private_key = secret_bytes.to_vec();
        secret_bytes.zeroize();

        let mut session_id_bytes = [0u8; 16];
        rand::rngs::OsRng.fill_bytes(&mut session_id_bytes);

        Self {
            session_id: hex::encode(session_id_bytes),
            public_key,
            private_key,
            created_at: Self::unix_now_secs(),
            ttl_secs,
        }
    }

    /// Returns `true` once strictly more than `ttl_secs` seconds have elapsed
    /// since `created_at`.
    ///
    /// A clock that has moved backwards counts as zero elapsed time. At exactly
    /// `ttl_secs` elapsed this still returns `false`, while
    /// [`remaining_ttl`](Self::remaining_ttl) already reports zero.
    pub fn is_expired(&self) -> bool {
        Self::unix_now_secs().saturating_sub(self.created_at) > self.ttl_secs
    }

    /// Hex-encoded SHA-256 digest of this session's public key bytes.
    pub fn public_key_digest(&self) -> String {
        hex::encode(Sha256::digest(&self.public_key))
    }

    /// Time left before the TTL runs out, saturating at zero.
    pub fn remaining_ttl(&self) -> Duration {
        let elapsed = Self::unix_now_secs().saturating_sub(self.created_at);
        Duration::from_secs(self.ttl_secs.saturating_sub(elapsed))
    }

    /// Decrypts a payload produced by [`SealedSecretPayload::seal`] for this
    /// session's public key.
    ///
    /// Steps:
    /// 1. Run X25519 between this session's private key and the payload's
    ///    ephemeral public key.
    /// 2. Hash the shared secret with SHA-256 to get the ChaCha20-Poly1305 key.
    /// 3. Authenticate and decrypt `payload.ciphertext`.
    ///
    /// This method does not check [`is_expired`](Self::is_expired), and it does
    /// not compare `payload.session_id` with `self.session_id`. Callers must do
    /// that themselves if they need it.
    ///
    /// # Errors
    ///
    /// Returns [`TeeError::SealedSecret`] in any of these cases:
    /// - the ephemeral public key is missing, not 32 bytes, or a low-order point
    ///   (the DH result is non-contributory);
    /// - the nonce is missing or not 12 bytes;
    /// - the stored private key is not 32 bytes;
    /// - AEAD decryption or authentication fails.
    pub fn open(&self, payload: &SealedSecretPayload) -> Result<Vec<u8>, TeeError> {
        let client_public_bytes: [u8; 32] = payload
            .ephemeral_public_key
            .as_ref()
            .ok_or_else(|| TeeError::SealedSecret("missing ephemeral public key".into()))?
            .as_slice()
            .try_into()
            .map_err(|_| TeeError::SealedSecret("ephemeral public key must be 32 bytes".into()))?;

        let nonce_bytes = payload
            .nonce
            .as_ref()
            .ok_or_else(|| TeeError::SealedSecret("missing nonce".into()))?;
        // `GenericArray::from_slice` panics on a length mismatch; reject early instead.
        if nonce_bytes.len() != 12 {
            return Err(TeeError::SealedSecret("nonce must be 12 bytes".into()));
        }

        let mut secret_bytes: [u8; 32] = self
            .private_key
            .as_slice()
            .try_into()
            .map_err(|_| TeeError::SealedSecret("invalid private key length".into()))?;
        let secret = StaticSecret::from(secret_bytes);
        // `[u8; 32]` is `Copy`, so the local still holds the scalar; wipe it.
        secret_bytes.zeroize();

        let shared = secret.diffie_hellman(&PublicKey::from(client_public_bytes));
        // A low-order peer key forces an all-zero shared secret, which makes the
        // derived key public knowledge. Refuse it.
        if !shared.was_contributory() {
            return Err(TeeError::SealedSecret(
                "ephemeral public key is a low-order point".into(),
            ));
        }

        let mut enc_key = Sha256::digest(shared.as_bytes());
        let cipher = ChaCha20Poly1305::new_from_slice(&enc_key);
        // The cipher holds its own copy of the key; clear the digest buffer.
        enc_key.as_mut_slice().zeroize();
        let cipher =
            cipher.map_err(|e| TeeError::SealedSecret(format!("cipher init failed: {e}")))?;

        let nonce = chacha20poly1305::aead::generic_array::GenericArray::from_slice(nonce_bytes);
        cipher
            .decrypt(nonce, payload.ciphertext.as_ref())
            .map_err(|_| {
                TeeError::SealedSecret("decryption failed: invalid ciphertext or key".into())
            })
    }
}
impl Drop for KeyExchangeSession {
fn drop(&mut self) {
self.private_key.zeroize();
}
}
/// Request message for starting a key exchange.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KeyExchangeRequest {
    /// Optional caller-supplied nonce. Left out of serialized output when `None`.
    /// NOTE(review): presumably intended for attestation freshness; confirm with the handler.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,
}
/// Response message for a key exchange: session id, public key and attestation.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KeyExchangeResponse {
    /// Identifier the client must echo back in its sealed payload.
    pub session_id: String,
    /// Hex-encoded session public key.
    /// NOTE(review): presumably `hex::encode(KeyExchangeSession::public_key)`; confirm at the construction site.
    pub public_key_hex: String,
    /// Attestation report accompanying the session key.
    pub attestation: AttestationReport,
}
/// Encrypted secret addressed to a key-exchange session.
///
/// `nonce` and `ephemeral_public_key` are optional on the wire, but
/// `KeyExchangeSession::open` rejects a payload that lacks either one.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SealedSecretPayload {
    /// Session the payload is addressed to.
    pub session_id: String,
    /// ChaCha20-Poly1305 ciphertext, with the authentication tag appended.
    pub ciphertext: Vec<u8>,
    /// 12-byte AEAD nonce.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub nonce: Option<Vec<u8>>,
    /// Sender's 32-byte X25519 ephemeral public key.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ephemeral_public_key: Option<Vec<u8>>,
}
impl SealedSecretPayload {
pub fn seal(
session_id: String,
plaintext: &[u8],
tee_public_key: &[u8],
) -> Result<Self, TeeError> {
let tee_pk_bytes: [u8; 32] = tee_public_key
.try_into()
.map_err(|_| TeeError::SealedSecret("TEE public key must be 32 bytes".into()))?;
let client_secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
let client_public = PublicKey::from(&client_secret);
let tee_public = PublicKey::from(tee_pk_bytes);
let shared = client_secret.diffie_hellman(&tee_public);
let enc_key = Sha256::digest(shared.as_bytes());
let cipher = ChaCha20Poly1305::new_from_slice(&enc_key)
.map_err(|e| TeeError::SealedSecret(format!("cipher init failed: {e}")))?;
let mut nonce_bytes = [0u8; 12];
rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
let nonce = chacha20poly1305::aead::generic_array::GenericArray::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, plaintext)
.map_err(|_| TeeError::SealedSecret("encryption failed".into()))?;
Ok(Self {
session_id,
ciphertext,
nonce: Some(nonce_bytes.to_vec()),
ephemeral_public_key: Some(client_public.as_bytes().to_vec()),
})
}
}
/// Outcome report for processing a sealed secret.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SealedSecretResult {
    /// Whether the sealed secret was accepted.
    pub success: bool,
    /// Digest string identifying the attestation.
    /// NOTE(review): which attestation data is hashed is not visible here; confirm at the producer.
    pub attestation_digest: String,
    /// Fingerprint of the key involved.
    /// NOTE(review): possibly `KeyExchangeSession::public_key_digest()`; confirm at the producer.
    pub key_fingerprint: String,
}