use std::{
num::NonZeroUsize,
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use p256::{NistP256, elliptic_curve::SecretKey};
use crate::{JwtUser, KeyError, PrivyHpke, generated::types::AuthenticateBody};
/// Safety margin: a cached key is treated as expired this long before its real expiry.
const EXPIRY_BUFFER: Duration = Duration::new(60, 0);

/// One cached value: the authorization key's expiry time paired with the key itself.
type CacheEntry = (SystemTime, SecretKey<NistP256>);

/// LRU cache of authorization keys, keyed by the user JWT string.
type JwtCache = lru::LruCache<String, CacheEntry>;
#[derive(Debug, Clone)]
pub struct JwtExchange {
cache: Arc<Mutex<JwtCache>>,
}
impl JwtExchange {
pub fn new(capacity: NonZeroUsize) -> Self {
JwtExchange {
cache: Arc::new(Mutex::new(lru::LruCache::new(capacity))),
}
}
pub async fn exchange_jwt_for_authorization_key(
&self,
jwt_user: &JwtUser,
) -> Result<SecretKey<NistP256>, KeyError> {
let client = &jwt_user.0;
let jwt = &jwt_user.1;
{
let mut cache = self.cache.lock().expect("lock poisoned");
let expired = if let Some((expiry, key)) = cache.get(jwt) {
let buffer = *expiry - EXPIRY_BUFFER;
if buffer > SystemTime::now() {
return Ok(key.clone());
}
true
} else {
false
};
if expired {
cache.demote(jwt);
}
}
#[cfg(all(feature = "unsafe_debug", debug_assertions))]
{
tracing::debug!("Starting HPKE JWT exchange for user JWT: {}", jwt);
}
let hpke_manager = PrivyHpke::new();
let public_key_b64 = hpke_manager.public_key()?;
tracing::debug!(
"Generated HPKE public key for authentication request {}",
public_key_b64
);
let body = AuthenticateBody {
user_jwt: jwt.clone(),
encryption_type: Some(crate::generated::types::AuthenticateBodyEncryptionType::Hpke),
recipient_public_key: Some(public_key_b64),
};
let auth = match client.wallets().authenticate_with_jwt(&body).await {
Ok(r) => r.into_inner(),
Err(e) => {
tracing::error!("failed to fetch authorization key: {:?}", e);
return Err(KeyError::Other(Box::new(e)));
}
};
let (key, expiry) = match auth {
crate::generated::types::AuthenticateResponse::WithEncryption {
encrypted_authorization_key,
expires_at,
..
} => {
tracing::debug!("Received encrypted authorization key, starting HPKE decryption");
let key = hpke_manager.decrypt_p256(
&encrypted_authorization_key.encapsulated_key,
&encrypted_authorization_key.ciphertext,
)?;
let expiry = SystemTime::UNIX_EPOCH + Duration::from_secs_f64(expires_at);
(key, expiry)
}
crate::generated::types::AuthenticateResponse::WithoutEncryption { .. } => {
tracing::warn!("Received unencrypted authorization key (fallback mode)");
unimplemented!()
}
};
tokio::time::sleep(Duration::from_millis(1000)).await;
{
let mut cache = self.cache.lock().expect("lock poisoned");
cache.push(jwt.clone(), (expiry, key.clone()));
}
tracing::info!("Successfully obtained and parsed authorization key");
Ok(key)
}
}