use std::collections::HashSet;
use dashmap::DashMap;
use crate::codes;
/// Request epoch that opens a member's share session, or resets an existing one to epoch 1.
const INITIAL_EPOCH: i32 = 0;
/// Request epoch that closes (removes) a member's share session.
const FINAL_EPOCH: i32 = -1;
/// State kept for one open share session (one `(group, member)` pair).
#[derive(Default, Debug)]
struct ShareSession {
    /// The epoch the next incremental request for this session must carry.
    epoch: i32,
    /// Partitions attached to the session. The tuple is presumably
    /// (topic id, partition index); verify against the fetch handler.
    // In this file the set is only ever created empty and cleared, never read,
    // which is why the dead_code lint is silenced here.
    #[allow(dead_code)]
    partitions: HashSet<(uuid::Uuid, i32)>,
}
/// Concurrent cache of open share sessions, keyed by `(group, member)`.
///
/// At most `max` sessions may be open at once. The bound is enforced with an
/// atomic slot counter rather than `DashMap::len`, so that concurrent opens for
/// different members cannot race past it.
#[derive(Debug)]
pub(crate) struct ShareSessionCache {
    sessions: DashMap<(String, String), ShareSession>,
    max: usize,
    /// Number of sessions stored in `sessions` plus slots reserved by opens that
    /// have not inserted yet. It is never below `sessions.len()` and never above `max`.
    reserved: std::sync::atomic::AtomicUsize,
}

impl ShareSessionCache {
    /// Creates an empty cache that admits at most `max` concurrently open sessions.
    pub(crate) fn new(max: usize) -> Self {
        Self {
            sessions: DashMap::new(),
            max,
            reserved: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Validates a request's share-session `epoch` for `(group, member)` and
    /// updates the session accordingly.
    ///
    /// * `FINAL_EPOCH` (-1): removes the session if present. Closing an absent
    ///   session is `Ok`.
    /// * `INITIAL_EPOCH` (0): resets an existing session (epoch 1, partitions
    ///   cleared), or opens a new one at epoch 1. Only a new session consumes capacity.
    /// * Any other value must equal the session's current epoch. The stored epoch
    ///   is then advanced, wrapping from `i32::MAX` back to 1.
    ///
    /// # Errors
    ///
    /// Returns a raw `i16` code from `crate::codes`:
    /// * `SHARE_SESSION_LIMIT_REACHED`: opening a new session while `max`
    ///   sessions are open or reserved.
    /// * `SHARE_SESSION_NOT_FOUND`: an incremental epoch for an unknown session.
    /// * `INVALID_SHARE_SESSION_EPOCH`: an incremental epoch that does not match
    ///   the session's current epoch.
    pub(crate) fn validate(&self, group: &str, member: &str, epoch: i32) -> Result<(), i16> {
        let key = (group.to_string(), member.to_string());
        match epoch {
            FINAL_EPOCH => {
                if self.sessions.remove(&key).is_some() {
                    self.release_slot();
                }
                Ok(())
            }
            INITIAL_EPOCH => self.open_session(key),
            _ => self.advance_session(&key, epoch),
        }
    }

    /// Resets the session for `key` to epoch 1, or creates it if capacity allows.
    fn open_session(&self, key: (String, String)) -> Result<(), i16> {
        // The shard guard lives only inside this `if let`, so it is released
        // before any of the map operations below run.
        if let Some(mut session) = self.sessions.get_mut(&key) {
            session.epoch = 1;
            session.partitions.clear();
            return Ok(());
        }
        // Claim capacity atomically *before* inserting. A separate `len()` check
        // followed by `insert` would let concurrent opens overshoot `max`.
        if !self.try_reserve_slot() {
            return Err(codes::SHARE_SESSION_LIMIT_REACHED);
        }
        let fresh = ShareSession {
            epoch: 1,
            partitions: HashSet::new(),
        };
        if self.sessions.insert(key, fresh).is_some() {
            // A concurrent open for the same member inserted first. This call
            // replaced that session instead of adding one, so it does not need
            // the slot it reserved.
            self.release_slot();
        }
        Ok(())
    }

    /// Checks an incremental `epoch` against the stored session and advances it.
    fn advance_session(&self, key: &(String, String), epoch: i32) -> Result<(), i16> {
        let Some(mut session) = self.sessions.get_mut(key) else {
            return Err(codes::SHARE_SESSION_NOT_FOUND);
        };
        if session.epoch != epoch {
            return Err(codes::INVALID_SHARE_SESSION_EPOCH);
        }
        // Stored epochs stay in 1..=i32::MAX. After i32::MAX the epoch restarts at 1,
        // never 0 (open) or a negative value (which would read as a close).
        session.epoch = match session.epoch.checked_add(1) {
            Some(next) if next > 0 => next,
            _ => 1,
        };
        Ok(())
    }

    /// Atomically claims one of the `max` session slots. Returns `false` when none is free.
    fn try_reserve_slot(&self) -> bool {
        use std::sync::atomic::Ordering;
        self.reserved
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
                (n < self.max).then_some(n + 1)
            })
            .is_ok()
    }

    /// Returns a slot previously claimed by `try_reserve_slot`.
    fn release_slot(&self) {
        use std::sync::atomic::Ordering;
        self.reserved.fetch_sub(1, Ordering::SeqCst);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;

    #[test]
    fn open_then_incremental_advances() {
        let cache = ShareSessionCache::new(8);
        assert!(cache.validate("g", "m", 0) == Ok(()));
        assert!(cache.validate("g", "m", 1) == Ok(()));
        assert!(cache.validate("g", "m", 2) == Ok(()));
    }

    #[test]
    fn stale_epoch_is_invalid() {
        let cache = ShareSessionCache::new(8);
        assert!(cache.validate("g", "m", 0) == Ok(()));
        assert!(cache.validate("g", "m", 5) == Err(codes::INVALID_SHARE_SESSION_EPOCH));
    }

    #[test]
    fn unknown_member_non_zero_epoch_not_found() {
        let cache = ShareSessionCache::new(8);
        assert!(cache.validate("g", "ghost", 3) == Err(codes::SHARE_SESSION_NOT_FOUND));
    }

    #[test]
    fn close_removes_session() {
        let cache = ShareSessionCache::new(8);
        assert!(cache.validate("g", "m", 0) == Ok(()));
        assert!(cache.validate("g", "m", -1) == Ok(()));
        assert!(cache.validate("g", "m", 1) == Err(codes::SHARE_SESSION_NOT_FOUND));
    }

    #[test]
    fn close_absent_session_is_ok() {
        let cache = ShareSessionCache::new(8);
        assert!(cache.validate("g", "never", -1) == Ok(()));
    }

    #[test]
    fn over_capacity_is_limit_reached() {
        let cache = ShareSessionCache::new(1);
        assert!(cache.validate("g", "m1", 0) == Ok(()));
        assert!(cache.validate("g", "m2", 0) == Err(codes::SHARE_SESSION_LIMIT_REACHED));
        // Re-opening an existing session does not need a new slot.
        assert!(cache.validate("g", "m1", 0) == Ok(()));
    }

    #[test]
    fn close_frees_capacity_for_another_member() {
        let cache = ShareSessionCache::new(1);
        assert!(cache.validate("g", "m1", 0) == Ok(()));
        assert!(cache.validate("g", "m1", -1) == Ok(()));
        assert!(cache.validate("g", "m2", 0) == Ok(()));
    }

    #[test]
    fn reopen_resets_epoch_to_one() {
        let cache = ShareSessionCache::new(8);
        assert!(cache.validate("g", "m", 0) == Ok(()));
        assert!(cache.validate("g", "m", 1) == Ok(()));
        assert!(cache.validate("g", "m", 0) == Ok(()));
        assert!(cache.validate("g", "m", 2) == Err(codes::INVALID_SHARE_SESSION_EPOCH));
        assert!(cache.validate("g", "m", 1) == Ok(()));
    }

    #[test]
    fn epoch_wraps_from_max_to_one() {
        let cache = ShareSessionCache::new(8);
        assert!(cache.validate("g", "m", 0) == Ok(()));
        let key = ("g".to_string(), "m".to_string());
        cache
            .sessions
            .get_mut(&key)
            .expect("session was just opened")
            .epoch = i32::MAX;
        assert!(cache.validate("g", "m", i32::MAX) == Ok(()));
        assert!(cache.validate("g", "m", 1) == Ok(()));
    }

    #[test]
    fn negative_epoch_other_than_final_is_not_a_close() {
        let cache = ShareSessionCache::new(8);
        assert!(cache.validate("g", "m", -2) == Err(codes::SHARE_SESSION_NOT_FOUND));
        assert!(cache.validate("g", "m", 0) == Ok(()));
        assert!(cache.validate("g", "m", -2) == Err(codes::INVALID_SHARE_SESSION_EPOCH));
    }
}