use self::ParseSessionIdError::*;
use hmac::digest::{Digest, FixedOutput, HashMarker, Update};
use serde::{Deserialize, Serialize};
use std::{error, fmt};
use zino_core::{SharedString, encoding::base64, error::Error, validation::Validation};
/// A session identifier scoped to a realm.
///
/// Its textual form (produced by the `Display` impl and read back by
/// `SessionId::parse`) is `SID:ANON:{realm}:{identifier}`, optionally followed by
/// `-{thread}` and/or `:{count}` in lowercase hexadecimal when those are nonzero.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionId {
    /// Realm the session belongs to; `accepts` treats it as a dot-separated domain.
    realm: SharedString,
    /// Base64-encoded digest of the realm bytes followed by the key bytes
    /// (see `SessionId::new`).
    identifier: String,
    /// Thread number; omitted from the textual form when `0`.
    thread: u8,
    /// Counter bumped by `increment_count` (saturating at `u8::MAX`); omitted from
    /// the textual form when `0`.
    count: u8,
}
impl SessionId {
#[inline]
pub fn new<D>(realm: impl Into<SharedString>, key: impl AsRef<[u8]>) -> Self
where
D: Default + FixedOutput + HashMarker + Update,
{
fn inner<D>(realm: SharedString, key: &[u8]) -> SessionId
where
D: Default + FixedOutput + HashMarker + Update,
{
let data = [realm.as_ref().as_bytes(), key].concat();
let mut hasher = D::new();
hasher.update(data.as_ref());
let identifier = base64::encode(hasher.finalize().as_slice());
SessionId {
realm,
identifier,
thread: 0,
count: 0,
}
}
inner::<D>(realm.into(), key.as_ref())
}
pub fn validate_with<D>(&self, realm: &str, key: impl AsRef<[u8]>) -> Validation
where
D: Default + FixedOutput + HashMarker + Update,
{
fn inner<D>(session_id: &SessionId, realm: &str, key: &[u8]) -> Validation
where
D: Default + FixedOutput + HashMarker + Update,
{
let mut validation = Validation::new();
let identifier = &session_id.identifier;
match base64::decode(identifier) {
Ok(hash) => {
let data = [realm.as_bytes(), key].concat();
let mut hasher = D::new();
hasher.update(data.as_ref());
if hasher.finalize().as_slice() != hash {
validation.record("identifier", "invalid session identifier");
}
}
Err(err) => {
validation.record_fail("identifier", err);
}
}
validation
}
inner::<D>(self, realm, key.as_ref())
}
pub fn accepts(&self, session_id: &SessionId) -> bool {
if self.identifier() != session_id.identifier() {
return false;
}
let realm = self.realm();
let domain = session_id.realm();
if domain == realm {
self.count() <= session_id.count()
} else {
let remainder = if realm.len() > domain.len() {
realm.strip_suffix(domain)
} else {
domain.strip_suffix(realm)
};
remainder.is_some_and(|s| s.ends_with('.'))
}
}
#[inline]
pub fn set_thread(&mut self, thread: u8) {
self.thread = thread;
}
#[inline]
pub fn increment_count(&mut self) {
self.count = self.count.saturating_add(1);
}
#[inline]
pub fn realm(&self) -> &str {
self.realm.as_ref()
}
#[inline]
pub fn identifier(&self) -> &str {
self.identifier.as_ref()
}
#[inline]
pub fn thread(&self) -> u8 {
self.thread
}
#[inline]
pub fn count(&self) -> u8 {
self.count
}
pub fn parse(s: &str) -> Result<SessionId, ParseSessionIdError> {
if let Some(s) = s.strip_prefix("SID:ANON:")
&& let Some((realm, s)) = s.split_once(':')
{
if let Some((identifier, s)) = s.split_once('-') {
if let Some((thread, count)) = s.split_once(':') {
return u8::from_str_radix(thread, 16)
.map_err(|err| ParseThreadError(err.into()))
.and_then(|thread| {
u8::from_str_radix(count, 16)
.map_err(|err| ParseCountError(err.into()))
.map(|count| Self {
realm: realm.to_owned().into(),
identifier: identifier.to_owned(),
thread,
count,
})
});
} else {
return u8::from_str_radix(s, 16)
.map_err(|err| ParseThreadError(err.into()))
.map(|thread| Self {
realm: realm.to_owned().into(),
identifier: identifier.to_owned(),
thread,
count: 0,
});
}
} else if let Some((identifier, count)) = s.split_once(':') {
return u8::from_str_radix(count, 16)
.map_err(|err| ParseCountError(err.into()))
.map(|count| Self {
realm: realm.to_owned().into(),
identifier: identifier.to_owned(),
thread: 0,
count,
});
} else {
return Ok(Self {
realm: realm.to_owned().into(),
identifier: s.to_owned(),
thread: 0,
count: 0,
});
}
}
Err(InvalidFormat)
}
}
impl fmt::Display for SessionId {
    /// Writes `SID:ANON:{realm}:{identifier}`, appending `-{thread}` when the
    /// thread is nonzero and `:{count}` when the count is nonzero, both in
    /// lowercase hexadecimal.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            realm,
            identifier,
            thread,
            count,
        } = self;
        match (*thread, *count) {
            (0, 0) => write!(f, "SID:ANON:{realm}:{identifier}"),
            (0, count) => write!(f, "SID:ANON:{realm}:{identifier}:{count:x}"),
            (thread, 0) => write!(f, "SID:ANON:{realm}:{identifier}-{thread:x}"),
            (thread, count) => {
                write!(f, "SID:ANON:{realm}:{identifier}-{thread:x}:{count:x}")
            }
        }
    }
}
/// Error returned by `SessionId::parse`.
#[derive(Debug)]
pub enum ParseSessionIdError {
    /// The thread field is not a valid hexadecimal `u8`.
    ParseThreadError(Error),
    /// The count field is not a valid hexadecimal `u8`.
    ParseCountError(Error),
    /// The input lacks the `SID:ANON:` prefix or the `:` that terminates the realm.
    InvalidFormat,
}
impl fmt::Display for ParseSessionIdError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ParseThreadError(err) => write!(f, "fail to parse thread: {err}"),
ParseCountError(err) => write!(f, "fail to parse count: {err}"),
InvalidFormat => write!(f, "invalid format"),
}
}
}
// Marker impl: relies on the default trait methods, so `source()` returns `None`.
impl error::Error for ParseSessionIdError {}