use crate::config::TeeKeyExchangeConfig;
use crate::errors::TeeError;
use crate::exchange::protocol::KeyExchangeSession;
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Tracks in-flight TEE key exchange sessions, keyed by session id.
///
/// Expired sessions are evicted lazily by `create_session` and periodically by the
/// background task started with `start_cleanup_loop`.
pub struct TeeAuthService {
    /// Provides `max_sessions` (live-session limit) and `session_ttl_secs` (passed to
    /// each new session and used to derive the cleanup interval).
    config: TeeKeyExchangeConfig,
    /// Session table keyed by session id; the `Arc` is shared with the cleanup task.
    sessions: Arc<Mutex<BTreeMap<String, KeyExchangeSession>>>,
    /// Abort handle of the background cleanup task, if one was started; aborted on drop.
    cleanup_handle: std::sync::Mutex<Option<tokio::task::AbortHandle>>,
}
impl TeeAuthService {
pub fn new(config: TeeKeyExchangeConfig) -> Self {
Self {
config,
sessions: Arc::new(Mutex::new(BTreeMap::new())),
cleanup_handle: std::sync::Mutex::new(None),
}
}
pub async fn create_session(&self) -> Result<(String, Vec<u8>), TeeError> {
let mut sessions = self.sessions.lock().await;
sessions.retain(|_, s| !s.is_expired());
if sessions.len() >= self.config.max_sessions {
return Err(TeeError::KeyExchange(format!(
"maximum session limit reached ({})",
self.config.max_sessions
)));
}
let session = KeyExchangeSession::new(self.config.session_ttl_secs);
let session_id = session.session_id.clone();
let public_key = session.public_key.clone();
sessions.insert(session_id.clone(), session);
tracing::debug!(session_id = %session_id, "created key exchange session");
Ok((session_id, public_key))
}
pub async fn consume_session(&self, session_id: &str) -> Result<KeyExchangeSession, TeeError> {
let mut sessions = self.sessions.lock().await;
let session = sessions
.get(session_id)
.ok_or_else(|| TeeError::KeyExchange(format!("session not found: {session_id}")))?;
if session.is_expired() {
sessions.remove(session_id);
return Err(TeeError::KeyExchange(format!(
"session expired: {session_id}"
)));
}
let session = sessions
.remove(session_id)
.expect("session exists; checked above");
tracing::debug!(session_id = %session_id, "consumed key exchange session");
Ok(session)
}
pub async fn active_session_count(&self) -> usize {
let sessions = self.sessions.lock().await;
sessions.values().filter(|s| !s.is_expired()).count()
}
pub async fn get_session_public_key(&self, session_id: &str) -> Result<Vec<u8>, TeeError> {
let sessions = self.sessions.lock().await;
let session = sessions
.get(session_id)
.ok_or_else(|| TeeError::KeyExchange(format!("session not found: {session_id}")))?;
if session.is_expired() {
return Err(TeeError::KeyExchange(format!(
"session expired: {session_id}"
)));
}
Ok(session.public_key.clone())
}
pub fn start_cleanup_loop(&self) -> tokio::task::JoinHandle<()> {
let sessions = self.sessions.clone();
let ttl_secs = self.config.session_ttl_secs;
let handle = tokio::spawn(async move {
tracing::info!("TEE auth service cleanup loop started");
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(ttl_secs.max(30))).await;
let mut sessions = sessions.lock().await;
let before = sessions.len();
sessions.retain(|_, s| !s.is_expired());
let evicted = before - sessions.len();
if evicted > 0 {
tracing::debug!(
evicted = evicted,
remaining = sessions.len(),
"evicted expired key exchange sessions"
);
}
}
});
*self.cleanup_handle.lock().unwrap() = Some(handle.abort_handle());
handle
}
pub fn session_ttl_secs(&self) -> u64 {
self.config.session_ttl_secs
}
pub fn max_sessions(&self) -> usize {
self.config.max_sessions
}
}
impl Drop for TeeAuthService {
    /// Aborts the background cleanup task, if one was started. The task holds its own
    /// `Arc` to the session table and would otherwise keep running after the service is gone.
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so no locking is required. Recover the
        // value even from a poisoned mutex, because panicking inside `drop` during
        // unwinding aborts the process.
        let slot = self
            .cleanup_handle
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(handle) = slot.take() {
            handle.abort();
        }
    }
}