use alloc::{collections::BTreeMap, sync::Arc};
use core::fmt;
use async_lock::Mutex;
use future_form::Sendable;
use rand::{RngCore, rngs::OsRng};
use subduction_core::{authenticated::Authenticated, peer::id::PeerId};
use crate::transport::HttpLongPollTransport;
/// Opaque 128-bit session identifier.
///
/// Equality, ordering and hashing are all derived, so they operate bytewise on
/// the 16 raw bytes.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct SessionId([u8; 16]);
impl SessionId {
    /// Generates a fresh identifier from 16 bytes of operating-system randomness.
    ///
    /// # Panics
    ///
    /// `rand_core` documents `OsRng::fill_bytes` as panicking if the OS random
    /// source fails. NOTE(review): confirm against the pinned `rand` version.
    #[must_use]
    pub fn random() -> Self {
        let mut bytes = [0u8; 16];
        OsRng.fill_bytes(&mut bytes);
        Self(bytes)
    }

    /// Renders the id as exactly 32 lowercase hexadecimal characters.
    #[must_use]
    pub fn to_hex(&self) -> alloc::string::String {
        use core::fmt::Write;

        let mut buf = alloc::string::String::with_capacity(32);
        for byte in &self.0 {
            // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
            let _ = write!(buf, "{byte:02x}");
        }
        buf
    }

    /// Parses a 32-character hexadecimal string back into an id.
    ///
    /// Both lowercase and uppercase digits are accepted, so the output of
    /// [`SessionId::to_hex`] always round-trips.
    ///
    /// Returns `None` if `s` is not exactly 32 bytes long or contains any
    /// byte that is not an ASCII hex digit.
    #[must_use]
    pub fn from_hex(s: &str) -> Option<Self> {
        let digits = s.as_bytes();
        if digits.len() != 32 {
            return None;
        }
        let mut bytes = [0u8; 16];
        // Pair each output byte with its two input digits. Iterating instead of
        // indexing removes every panicking index operation from this function.
        for (dst, pair) in bytes.iter_mut().zip(digits.chunks_exact(2)) {
            match *pair {
                [hi, lo] => *dst = (hex_digit(hi)? << 4) | hex_digit(lo)?,
                // `chunks_exact(2)` only ever yields 2-element slices.
                _ => return None,
            }
        }
        Some(Self(bytes))
    }
}
impl fmt::Display for SessionId {
    /// Writes the id as 32 lowercase hex digits, the same text as `to_hex`.
    ///
    /// Bytes are streamed directly into the formatter, so no intermediate
    /// `String` is allocated on every format call. Width and fill flags are
    /// ignored, as before.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.iter().try_for_each(|byte| write!(f, "{byte:02x}"))
    }
}
/// Decodes one ASCII hex digit, in either case, to its value in `0..=15`.
///
/// Returns `None` for any byte outside `0-9`, `a-f` and `A-F`.
const fn hex_digit(b: u8) -> Option<u8> {
    if b.is_ascii_digit() {
        return Some(b - b'0');
    }
    // Folding to lowercase only changes `A-Z`, so exactly `a-f`/`A-F` reach the letter arm.
    match b.to_ascii_lowercase() {
        lower @ b'a'..=b'f' => Some(lower - b'a' + 10),
        _ => None,
    }
}
/// Map from [`SessionId`] to [`SessionEntry`], guarded by an async mutex.
///
/// Cloning is cheap: every clone shares the same underlying map through an `Arc`.
#[derive(Clone, Debug)]
pub struct SessionStore {
    /// Shared session table. Access it only through the lock.
    pub(crate) sessions: Arc<Mutex<BTreeMap<SessionId, SessionEntry>>>,
}
/// Per-session state stored in a [`SessionStore`].
#[derive(Clone, Debug)]
pub struct SessionEntry {
    /// Peer associated with this session.
    pub peer_id: PeerId,
    /// Long-poll transport for this session.
    pub connection: HttpLongPollTransport,
    /// Optional authenticated wrapper over an `HttpLongPollTransport`.
    /// NOTE(review): presumably `None` until authentication completes; confirm with callers.
    pub authenticated: Option<Authenticated<HttpLongPollTransport, Sendable>>,
}
impl SessionStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        let table = BTreeMap::new();
        Self {
            sessions: Arc::new(Mutex::new(table)),
        }
    }

    /// Stores `entry` under `id`.
    ///
    /// Any entry previously stored under the same id is replaced and dropped
    /// while the lock is still held.
    pub async fn insert(&self, id: SessionId, entry: SessionEntry) {
        let mut table = self.sessions.lock().await;
        table.insert(id, entry);
    }

    /// Returns a clone of the entry stored under `id`, if any.
    ///
    /// The lock is released before the caller sees the clone.
    pub async fn get(&self, id: &SessionId) -> Option<SessionEntry> {
        let table = self.sessions.lock().await;
        table.get(id).cloned()
    }

    /// Removes the entry stored under `id` and hands it back, if it existed.
    pub async fn remove(&self, id: &SessionId) -> Option<SessionEntry> {
        let mut table = self.sessions.lock().await;
        table.remove(id)
    }
}
impl Default for SessionStore {
    /// Same as [`SessionStore::new`]: a store with no sessions.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(all(test, feature = "std", feature = "bolero"))]
// `expect` is the idiomatic way to fail a property with a readable message.
#[allow(clippy::expect_used)]
mod proptests {
    use super::*;

    /// `to_hex` yields 32 characters and `from_hex` restores the original id.
    #[test]
    fn prop_hex_roundtrip() {
        bolero::check!().with_type::<[u8; 16]>().for_each(|bytes| {
            let id = SessionId(*bytes);
            let hex = id.to_hex();
            assert_eq!(hex.len(), 32);
            let decoded = SessionId::from_hex(&hex).expect("roundtrip should succeed");
            assert_eq!(id, decoded);
        });
    }

    /// `to_hex` output contains only lowercase hexadecimal characters.
    #[test]
    fn prop_hex_output_is_valid_lowercase_hex() {
        bolero::check!().with_type::<[u8; 16]>().for_each(|bytes| {
            let hex = SessionId(*bytes).to_hex();
            assert!(hex.chars().all(|c| c.is_ascii_hexdigit()));
            assert!(hex.chars().all(|c| !c.is_ascii_uppercase()));
        });
    }

    /// `from_hex` accepts uppercase digits and decodes them to the same bytes.
    #[test]
    fn prop_from_hex_accepts_uppercase() {
        bolero::check!().with_type::<[u8; 16]>().for_each(|bytes| {
            let id = SessionId(*bytes);
            let upper = id.to_hex().to_ascii_uppercase();
            assert_eq!(SessionId::from_hex(&upper), Some(id));
        });
    }

    /// `Display` renders exactly the same text as `to_hex`.
    #[test]
    fn prop_display_matches_to_hex() {
        bolero::check!().with_type::<[u8; 16]>().for_each(|bytes| {
            let id = SessionId(*bytes);
            assert_eq!(alloc::format!("{id}"), id.to_hex());
        });
    }

    /// Any input whose byte length is not 32 is rejected.
    #[test]
    fn prop_from_hex_rejects_non_32_char_strings() {
        bolero::check!().with_type::<String>().for_each(|s| {
            if s.len() != 32 {
                assert!(SessionId::from_hex(s).is_none());
            }
        });
    }

    /// Replacing the character at *any* position with a non-hex character
    /// makes parsing fail. This covers every chunk and both nibbles, not just
    /// the first character.
    #[test]
    fn prop_from_hex_rejects_non_hex_chars() {
        bolero::check!()
            .with_type::<([u8; 16], u8)>()
            .for_each(|(bytes, pos)| {
                let mut hex = SessionId(*bytes).to_hex();
                let idx = usize::from(*pos) % hex.len();
                hex.replace_range(idx..=idx, "z");
                assert!(SessionId::from_hex(&hex).is_none());
            });
    }
}