use std::{collections::HashMap, sync::Arc};
use anyhow::Ok;
use anyhow::Result;
use anyhow::anyhow;
use async_lock::{Mutex, RwLock};
use chacha20poly1305::aead::{OsRng, rand_core::RngCore};
use x25519_dalek::PublicKey;
use crate::{crypto::zero_trust_session_key::SessionKey, time::SystemTime};
/// Store of [`SessionKey`]s kept in two separately locked maps, both keyed by
/// raw id bytes.
///
/// `temp` holds entries created by `create(false)` until `save` or
/// `establish_ends` moves them into `main`; `main` is the map that
/// `with_session`, `encrypt` and `decrypt` look keys up in.
pub struct PairedSessionKey {
    /// Number of random bytes in each id generated by `create`.
    pub length: usize,
    /// Sessions looked up by `with_session` / `encrypt` / `decrypt`.
    pub main: Arc<RwLock<HashMap<Vec<u8>, SessionKey>>>,
    /// Pending sessions awaiting `save` or `establish_ends`.
    pub temp: Arc<Mutex<HashMap<Vec<u8>, SessionKey>>>,
}
impl PairedSessionKey {
    /// Creates an empty store whose generated ids are `length` bytes long.
    pub fn new(length: usize) -> Self {
        Self {
            length,
            main: Arc::new(RwLock::new(HashMap::new())),
            temp: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Generates a random id and a new [`SessionKey`], stores the key under
    /// that id, and returns `(id, ephemeral_public_key)`.
    ///
    /// The key goes into `main` when `is_main` is `true`, otherwise into `temp`.
    ///
    /// NOTE(review): the id is `self.length` bytes from `OsRng` and is inserted
    /// without a collision check, so a small `length` (or `0`) can silently
    /// replace an existing entry — confirm `length` is large enough.
    pub async fn create(&self, is_main: bool) -> (Vec<u8>, PublicKey) {
        let mut session_id = vec![0u8; self.length];
        OsRng.fill_bytes(&mut session_id);

        let session_key = SessionKey::new();
        // `PublicKey` is `Copy`; read it before the key is moved into a map.
        let ephemeral_public = session_key.ephemeral_public;

        if is_main {
            self.main
                .write()
                .await
                .insert(session_id.clone(), session_key);
        } else {
            self.temp
                .lock()
                .await
                .insert(session_id.clone(), session_key);
        }
        (session_id, ephemeral_public)
    }

    /// Moves the pending session stored under `from` in `temp` into `main`
    /// under the id `to`, calling `touch()` on it first.
    ///
    /// # Errors
    /// Returns an error if `temp` has no entry for `from`.
    pub async fn save(&self, from: Vec<u8>, to: Vec<u8>) -> Result<()> {
        // The `temp` guard is a temporary of this statement, so it is released
        // before the `main` write lock is awaited below (same ordering as
        // `establish_ends`): the two locks are never held at the same time.
        let mut session_key = self
            .temp
            .lock()
            .await
            .remove(&from)
            .ok_or_else(|| anyhow!("temp session not found"))?;
        session_key.touch();
        self.main.write().await.insert(to, session_key);
        Ok(())
    }

    /// Drops every session in `temp` and `main` for which
    /// `SystemTime::is_expired(updated_at, ttl_ms)` is true.
    ///
    /// The two maps are locked one after the other, never together.
    pub async fn cleanup(&self, ttl_ms: u128) {
        self.temp
            .lock()
            .await
            .retain(|_, sk| !SystemTime::is_expired(sk.updated_at, ttl_ms));
        self.main
            .write()
            .await
            .retain(|_, sk| !SystemTime::is_expired(sk.updated_at, ttl_ms));
    }

    /// Runs `f` on the `main` session stored under `key`, calling `touch()` on
    /// it before `f` runs. The `main` write lock is held while `f` executes.
    ///
    /// # Errors
    /// Returns an error if `main` has no entry for `key`, or whatever `f` returns.
    pub async fn with_session<R>(
        &self,
        key: &[u8],
        f: impl FnOnce(&mut SessionKey) -> Result<R>,
    ) -> Result<R> {
        let mut sessions = self.main.write().await;
        let sk = Self::session_mut(&mut *sessions, key)?;
        sk.touch();
        f(sk)
    }

    /// Looks up `key` in an already-locked `main` map.
    fn session_mut<'a>(
        sessions: &'a mut HashMap<Vec<u8>, SessionKey>,
        key: &[u8],
    ) -> Result<&'a mut SessionKey> {
        sessions
            .get_mut(key)
            .ok_or_else(|| anyhow!("session not found for address"))
    }

    /// Builds an x25519 public key from the first 32 bytes of `bytes`.
    ///
    /// # Errors
    /// Returns an error if `bytes` is shorter than 32 bytes.
    ///
    /// NOTE(review): input longer than 32 bytes is accepted and the trailing
    /// bytes are ignored, although the error text says "expected 32". Confirm
    /// whether callers rely on that before making the check exact.
    fn parse_public_key(bytes: &[u8]) -> Result<PublicKey> {
        let array: [u8; 32] = bytes
            .get(..32)
            .and_then(|slice| slice.try_into().ok())
            .ok_or_else(|| {
                anyhow!(
                    "Invalid public key length: expected 32, got {}",
                    bytes.len()
                )
            })?;
        Ok(PublicKey::from(array))
    }

    /// Responder side: creates a new [`SessionKey`], calls
    /// `SessionKey::establish` with the peer key in `remote`, and on success
    /// stores it in `main` under `id` and returns our ephemeral public key.
    ///
    /// Returns `Ok(None)` if `establish` fails.
    ///
    /// # Errors
    /// Returns an error if `remote` is shorter than 32 bytes.
    pub async fn establish_begins(
        &self,
        id: Vec<u8>,
        remote: &[u8],
    ) -> Result<Option<PublicKey>> {
        // Validate the untrusted input before generating any key material.
        let client_pub = Self::parse_public_key(remote)?;

        let mut session_key = SessionKey::new();
        let ephemeral_public = session_key.ephemeral_public;
        if session_key.establish(&client_pub).is_err() {
            return Ok(None);
        }
        session_key.touch();
        self.main.write().await.insert(id, session_key);
        Ok(Some(ephemeral_public))
    }

    /// Initiator side: takes the pending session `id` out of `temp`, calls
    /// `SessionKey::establish` with the peer key in `remote`, and on success
    /// moves it into `main` under the same `id`.
    ///
    /// Returns `Ok(false)` if `id` is not in `temp` or `establish` fails; in
    /// the latter case the pending session is discarded.
    ///
    /// # Errors
    /// Returns an error if `remote` is shorter than 32 bytes. The pending
    /// session has already been removed from `temp` at that point and is
    /// discarded as well.
    pub async fn establish_ends(&self, id: Vec<u8>, remote: &[u8]) -> Result<bool> {
        let mut temp_sessions = self.temp.lock().await;
        let mut session = match temp_sessions.remove(&id) {
            Some(s) => s,
            None => return Ok(false),
        };
        let peer_pub = Self::parse_public_key(remote)?;
        if session.establish(&peer_pub).is_err() {
            return Ok(false);
        }
        session.touch();
        // Release `temp` before taking `main` so both locks are never held together.
        drop(temp_sessions);
        self.main.write().await.insert(id, session);
        Ok(true)
    }

    /// Encrypts `plaintext` with the `main` session stored under `key`.
    ///
    /// NOTE(review): unlike `with_session`, this does not call `touch()`, so
    /// traffic alone does not keep a session alive for `cleanup` unless
    /// `SessionKey::encrypt` refreshes it internally — confirm intent.
    ///
    /// # Errors
    /// Returns an error if there is no session for `key` or encryption fails.
    pub async fn encrypt(&self, key: &[u8], plaintext: &[u8]) -> Result<Vec<u8>> {
        let mut sessions = self.main.write().await;
        let sk = Self::session_mut(&mut *sessions, key)?;
        let ct = sk.encrypt(plaintext)?;
        Ok(ct)
    }

    /// Decrypts `data` with the `main` session stored under `key`.
    ///
    /// NOTE(review): does not call `touch()` (see `encrypt`).
    ///
    /// # Errors
    /// Returns an error if there is no session for `key` or decryption fails.
    pub async fn decrypt(&self, key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
        let mut sessions = self.main.write().await;
        let sk = Self::session_mut(&mut *sessions, key)?;
        let pt = sk.decrypt(data)?;
        Ok(pt)
    }
}