use std::collections::HashMap;
use std::sync::Mutex;
use std::sync::atomic::{AtomicI32, AtomicU64, AtomicUsize, Ordering};
use std::time::Instant;
use crabka_protocol::owned::fetch_request::FetchRequest;
use crabka_protocol::primitives::uuid::Uuid as WireUuid;
use crate::codes;
/// Session id `0`: a request carrying it does not refer to any cached session, and
/// `FetchSessionCache::try_allocate` returns it when no session could be created.
pub const INVALID_SESSION_ID: i32 = 0;
/// Epoch `0`: paired with [`INVALID_SESSION_ID`] it is classified as a request to start a
/// new session.
pub const INITIAL_EPOCH: i32 = 0;
/// Epoch `-1`: classified as sessionless when paired with [`INVALID_SESSION_ID`], and as a
/// close request when paired with the id of a cached session.
pub const FINAL_EPOCH: i32 = -1;
/// Returns the session epoch that follows `prev`.
///
/// The result is always in `1..=i32::MAX`: whenever `prev + 1` would overflow or would not
/// be strictly positive, the sequence restarts at `1`, so the sentinel epochs `0` and `-1`
/// are never produced.
#[must_use]
pub fn next_epoch(prev: i32) -> i32 {
    match prev.checked_add(1) {
        Some(bumped) if bumped > 0 => bumped,
        // Overflow past `i32::MAX`, or a non-positive successor: restart at 1.
        _ => 1,
    }
}
/// Identifies one topic partition inside a fetch session.
///
/// A key carries both a topic name and a topic id. Equality and hashing (derived) use all
/// three fields, but `FetchSessionCache::classify` additionally matches request entries
/// against cached keys by name *or* by id, so either may be empty/zero.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FetchSessionKey {
/// Topic name; an empty string is never used for name-based matching.
pub topic_name: String,
/// Topic id; `WireUuid::ZERO` is never used for id-based matching.
pub topic_id: WireUuid,
/// Partition index within the topic.
pub partition: i32,
}
/// Per-partition state remembered by a fetch session.
///
/// The first five fields are overwritten from the request's partition entry each time
/// `FetchSessionCache::classify` accepts an incremental request. The `last_*` fields are
/// overwritten by `FetchSessionCache::finalize_incremental` from the entries the caller
/// passes in (presumably what was last sent to the client — confirm against the handler).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CachedPartitionState {
/// `fetch_offset` from the most recent request entry.
pub fetch_offset: i64,
/// `last_fetched_epoch` from the most recent request entry.
pub last_fetched_epoch: i32,
/// `current_leader_epoch` from the most recent request entry.
pub current_leader_epoch: i32,
/// `partition_max_bytes` from the most recent request entry.
pub max_bytes: i32,
/// `log_start_offset` from the most recent request entry.
pub log_start_offset: i64,
/// High watermark recorded by `finalize_incremental`.
pub last_high_watermark: i64,
/// Last stable offset recorded by `finalize_incremental`.
pub last_last_stable_offset: i64,
/// Log start offset recorded by `finalize_incremental`.
pub last_log_start_offset: i64,
/// Preferred read replica recorded by `finalize_incremental`.
pub last_preferred_read_replica: i32,
/// Hash of the aborted-transactions list recorded by `finalize_incremental`
/// (how the hash is computed is not visible in this file).
pub last_aborted_txns_hash: u64,
/// Error code recorded by `finalize_incremental`.
pub last_error_code: i16,
}
/// One cached fetch session, stored in the cache keyed by its `id`.
pub struct FetchSession {
/// Session id; always positive for sessions created by `FetchSessionCache::try_allocate`.
pub id: i32,
/// Epoch the next incremental request must carry. Starts at `1` and is advanced with
/// [`next_epoch`] every time an incremental request is accepted.
pub next_epoch: i32,
/// Privileged sessions can only be evicted to make room for another privileged session.
pub privileged: bool,
/// Principal supplied at allocation time; the cache logic in this file never reads it.
pub creator_principal: String,
/// Partitions tracked by this session and their cached state.
pub partitions: HashMap<FetchSessionKey, CachedPartitionState>,
/// Set at allocation and refreshed on every accepted incremental request; the session with
/// the smallest value is the eviction candidate.
pub last_used: Instant,
}
/// Outcome of `FetchSessionCache::classify` for one fetch request.
#[derive(Debug)]
pub enum SessionDecision {
/// The request neither uses nor creates a session.
Sessionless,
/// The request asks for a new session; the caller is expected to allocate one (for
/// example via `FetchSessionCache::try_allocate`).
NewSession,
/// The request was accepted as an incremental fetch on an existing session. The session
/// has already been updated when this is returned.
Incremental {
/// Id of the session the request was applied to.
session_id: i32,
/// The session's epoch after the bump, i.e. what the next request must carry.
new_epoch: i32,
/// Snapshot (clone) of every partition the session tracks after applying the request.
partitions: Vec<(FetchSessionKey, CachedPartitionState)>,
},
/// The request asks to close an existing session. `classify` does not remove it; the
/// caller is expected to call `FetchSessionCache::close`.
Close { session_id: i32 },
/// The request's session id / epoch combination was rejected with the given error code.
Error { code: i16 },
}
/// Mutable cache state; only accessed through the `Mutex` in `FetchSessionCache`.
struct Inner {
/// Live sessions keyed by session id.
sessions: HashMap<i32, FetchSession>,
}
/// Bounded cache of fetch sessions with least-recently-used eviction.
///
/// The session map lives behind a mutex; the counters are atomics so that the metric
/// getters (`len`, `total_partitions_cached`, `evictions_total`) can be read without
/// taking the lock.
pub struct FetchSessionCache {
/// Session map; every mutation of it happens with this lock held.
inner: Mutex<Inner>,
/// Next candidate session id; non-positive candidates are skipped.
next_id: AtomicI32,
/// Maximum number of sessions kept at once; `0` disables session creation.
max_slots: usize,
/// Number of sessions evicted to make room for new ones.
evictions: AtomicU64,
/// Mirror of the session map's length.
num_sessions: AtomicUsize,
/// Sum of the partition counts of all cached sessions.
num_partitions: AtomicUsize,
}
impl FetchSessionCache {
#[must_use]
pub fn new(max_slots: usize) -> Self {
Self {
inner: Mutex::new(Inner {
sessions: HashMap::new(),
}),
next_id: AtomicI32::new(1),
max_slots,
evictions: AtomicU64::new(0),
num_sessions: AtomicUsize::new(0),
num_partitions: AtomicUsize::new(0),
}
}
#[must_use]
pub fn len(&self) -> usize {
self.num_sessions.load(Ordering::Relaxed)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub fn total_partitions_cached(&self) -> usize {
self.num_partitions.load(Ordering::Relaxed)
}
#[must_use]
pub fn evictions_total(&self) -> u64 {
self.evictions.load(Ordering::Relaxed)
}
pub fn classify(&self, req: &FetchRequest) -> SessionDecision {
let sid = req.session_id;
let epoch = req.session_epoch;
if sid == INVALID_SESSION_ID {
return match epoch {
FINAL_EPOCH => SessionDecision::Sessionless,
INITIAL_EPOCH => SessionDecision::NewSession,
_ => SessionDecision::Error {
code: codes::INVALID_FETCH_SESSION_EPOCH,
},
};
}
let mut guard = self.inner.lock().expect("poisoned");
if epoch == FINAL_EPOCH {
if !guard.sessions.contains_key(&sid) {
return SessionDecision::Error {
code: codes::FETCH_SESSION_ID_NOT_FOUND,
};
}
return SessionDecision::Close { session_id: sid };
}
let Some(session) = guard.sessions.get_mut(&sid) else {
return SessionDecision::Error {
code: codes::FETCH_SESSION_ID_NOT_FOUND,
};
};
if epoch != session.next_epoch {
return SessionDecision::Error {
code: codes::INVALID_FETCH_SESSION_EPOCH,
};
}
session.last_used = Instant::now();
let partitions_before = session.partitions.len();
for ft in &req.forgotten_topics_data {
session.partitions.retain(|k, _| {
let topic_match = (!ft.topic.is_empty() && k.topic_name == ft.topic)
|| (ft.topic_id != WireUuid::ZERO && k.topic_id == ft.topic_id);
if !topic_match {
return true;
}
!ft.partitions.contains(&k.partition)
});
}
for t in &req.topics {
for fp in &t.partitions {
let existing_key = session
.partitions
.keys()
.find(|k| {
k.partition == fp.partition
&& ((!t.topic.is_empty() && k.topic_name == t.topic)
|| (t.topic_id != WireUuid::ZERO && k.topic_id == t.topic_id))
})
.cloned();
let key = existing_key.unwrap_or_else(|| FetchSessionKey {
topic_name: t.topic.clone(),
topic_id: t.topic_id,
partition: fp.partition,
});
let entry = session.partitions.entry(key).or_default();
entry.fetch_offset = fp.fetch_offset;
entry.max_bytes = fp.partition_max_bytes;
entry.current_leader_epoch = fp.current_leader_epoch;
entry.last_fetched_epoch = fp.last_fetched_epoch;
entry.log_start_offset = fp.log_start_offset;
}
}
let partitions_after = session.partitions.len();
if partitions_after >= partitions_before {
self.num_partitions
.fetch_add(partitions_after - partitions_before, Ordering::Relaxed);
} else {
self.num_partitions
.fetch_sub(partitions_before - partitions_after, Ordering::Relaxed);
}
let new_epoch = next_epoch(session.next_epoch);
session.next_epoch = new_epoch;
let partitions: Vec<(FetchSessionKey, CachedPartitionState)> = session
.partitions
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
SessionDecision::Incremental {
session_id: sid,
new_epoch,
partitions,
}
}
pub fn try_allocate(
&self,
privileged: bool,
creator_principal: String,
partitions: Vec<(FetchSessionKey, CachedPartitionState)>,
) -> i32 {
if self.max_slots == 0 {
return INVALID_SESSION_ID;
}
let mut guard = self.inner.lock().expect("poisoned");
if guard.sessions.len() >= self.max_slots {
let victim: Option<i32> = guard
.sessions
.iter()
.filter(|(_, s)| if privileged { true } else { !s.privileged })
.min_by_key(|(_, s)| s.last_used)
.map(|(id, _)| *id);
match victim {
Some(id) => {
let evicted = guard.sessions.remove(&id).expect("victim present");
self.num_sessions.fetch_sub(1, Ordering::Relaxed);
self.num_partitions
.fetch_sub(evicted.partitions.len(), Ordering::Relaxed);
self.evictions.fetch_add(1, Ordering::Relaxed);
}
None => return INVALID_SESSION_ID,
}
}
let id = loop {
let candidate = self.next_id.fetch_add(1, Ordering::Relaxed);
if candidate <= 0 {
self.next_id.store(1, Ordering::Relaxed);
continue;
}
if !guard.sessions.contains_key(&candidate) {
break candidate;
}
};
let partitions: HashMap<FetchSessionKey, CachedPartitionState> =
partitions.into_iter().collect();
let session = FetchSession {
id,
next_epoch: 1,
privileged,
creator_principal,
partitions,
last_used: Instant::now(),
};
let added_partitions = session.partitions.len();
guard.sessions.insert(id, session);
self.num_sessions.fetch_add(1, Ordering::Relaxed);
self.num_partitions
.fetch_add(added_partitions, Ordering::Relaxed);
id
}
pub fn finalize_incremental(
&self,
session_id: i32,
sent: &[(FetchSessionKey, CachedPartitionState)],
) {
let mut guard = self.inner.lock().expect("poisoned");
let Some(session) = guard.sessions.get_mut(&session_id) else {
return;
};
for (k, s) in sent {
if let Some(state) = session.partitions.get_mut(k) {
state.last_high_watermark = s.last_high_watermark;
state.last_last_stable_offset = s.last_last_stable_offset;
state.last_log_start_offset = s.last_log_start_offset;
state.last_preferred_read_replica = s.last_preferred_read_replica;
state.last_aborted_txns_hash = s.last_aborted_txns_hash;
state.last_error_code = s.last_error_code;
}
}
}
pub fn close(&self, session_id: i32) {
let mut guard = self.inner.lock().expect("poisoned");
if let Some(session) = guard.sessions.remove(&session_id) {
self.num_sessions.fetch_sub(1, Ordering::Relaxed);
self.num_partitions
.fetch_sub(session.partitions.len(), Ordering::Relaxed);
}
}
}
// Unit tests for request classification, session allocation/eviction and the counters.
#[cfg(test)]
mod tests {
use super::*;
use assert2::assert;
use crabka_protocol::owned::fetch_request::{FetchPartition, FetchTopic, ForgottenTopic};
// Builds a fetch request with the given session id/epoch, topics and forgotten topics;
// every other field keeps its default.
fn req(
session_id: i32,
session_epoch: i32,
topics: Vec<FetchTopic>,
forgotten: Vec<ForgottenTopic>,
) -> FetchRequest {
FetchRequest {
session_id,
session_epoch,
topics,
forgotten_topics_data: forgotten,
..Default::default()
}
}
// Builds a name-addressed topic (zero topic id) with one entry per partition index, each
// with fetch offset 0 and a 1024-byte partition limit.
fn topic(name: &str, partitions: &[i32]) -> FetchTopic {
FetchTopic {
topic: name.to_string(),
topic_id: WireUuid::ZERO,
partitions: partitions
.iter()
.map(|&p| FetchPartition {
partition: p,
fetch_offset: 0,
partition_max_bytes: 1024,
..Default::default()
})
.collect(),
..Default::default()
}
}
// The epoch sequence never yields 0 or -1, including on overflow.
#[test]
fn next_epoch_wraps_skipping_sentinels() {
assert!(next_epoch(0) == 1);
assert!(next_epoch(1) == 2);
assert!(next_epoch(i32::MAX) == 1);
assert!(next_epoch(-1) == 1);
}
// Session id 0 with the final epoch is a sessionless fetch.
#[test]
fn sessionless_request_is_classified_correctly() {
let cache = FetchSessionCache::new(10);
let r = req(0, FINAL_EPOCH, vec![], vec![]);
assert!(matches!(cache.classify(&r), SessionDecision::Sessionless));
}
// Session id 0 with the initial epoch asks for a new session.
#[test]
fn new_session_request_is_classified_correctly() {
let cache = FetchSessionCache::new(10);
let r = req(0, INITIAL_EPOCH, vec![topic("t", &[0])], vec![]);
assert!(matches!(cache.classify(&r), SessionDecision::NewSession));
}
// Allocated ids are positive and distinct, and each allocation is counted.
#[test]
fn allocate_returns_nonzero_monotonic_ids() {
let cache = FetchSessionCache::new(10);
let a = cache.try_allocate(false, "alice".into(), vec![]);
let b = cache.try_allocate(false, "alice".into(), vec![]);
assert!(a > 0);
assert!(b > 0);
assert!(a != b);
assert!(cache.len() == 2);
}
// Forces the id counter to 0 to check that the reserved id is skipped.
#[test]
fn allocate_skips_zero_on_wrap() {
let cache = FetchSessionCache::new(10);
cache.next_id.store(0, Ordering::Relaxed);
let id = cache.try_allocate(false, "alice".into(), vec![]);
assert!(id > 0);
}
// A cache with no slots never creates a session.
#[test]
fn allocate_returns_zero_when_max_slots_zero() {
let cache = FetchSessionCache::new(0);
let id = cache.try_allocate(false, "alice".into(), vec![]);
assert!(id == INVALID_SESSION_ID);
}
// An incremental request for an id that was never allocated is rejected as not found.
#[test]
fn unknown_session_id_returns_not_found() {
let cache = FetchSessionCache::new(10);
let r = req(12345, 1, vec![], vec![]);
match cache.classify(&r) {
SessionDecision::Error { code } => {
assert!(code == codes::FETCH_SESSION_ID_NOT_FOUND);
}
other => panic!("expected Error, got {other:?}"),
}
}
// A fresh session expects epoch 1, so epoch 99 is rejected.
#[test]
fn stale_epoch_returns_invalid_epoch() {
let cache = FetchSessionCache::new(10);
let id = cache.try_allocate(false, "alice".into(), vec![]);
let r = req(id, 99, vec![], vec![]);
match cache.classify(&r) {
SessionDecision::Error { code } => {
assert!(code == codes::INVALID_FETCH_SESSION_EPOCH);
}
other => panic!("expected Error, got {other:?}"),
}
}
// classify only reports the close; the session disappears once close() is called, after
// which the id is unknown.
#[test]
fn close_request_returns_close_then_handler_drops() {
let cache = FetchSessionCache::new(10);
let id = cache.try_allocate(false, "alice".into(), vec![]);
let r = req(id, FINAL_EPOCH, vec![], vec![]);
match cache.classify(&r) {
SessionDecision::Close { session_id } => assert!(session_id == id),
other => panic!("expected Close, got {other:?}"),
}
cache.close(id);
assert!(cache.len() == 0);
let r2 = req(id, 1, vec![], vec![]);
match cache.classify(&r2) {
SessionDecision::Error { code } => {
assert!(code == codes::FETCH_SESSION_ID_NOT_FOUND);
}
other => panic!("expected Error, got {other:?}"),
}
}
// Session id 0 only makes sense with the initial or final epoch.
#[test]
fn invalid_session_id_zero_with_stray_epoch_is_error() {
let cache = FetchSessionCache::new(10);
let r = req(0, 5, vec![], vec![]);
match cache.classify(&r) {
SessionDecision::Error { code } => {
assert!(code == codes::INVALID_FETCH_SESSION_EPOCH);
}
other => panic!("expected Error, got {other:?}"),
}
}
// An accepted incremental request adds the new partition and bumps the epoch, so replaying
// the old epoch afterwards is rejected.
#[test]
fn incremental_merges_request_topics_and_bumps_epoch() {
let cache = FetchSessionCache::new(10);
let initial = vec![(
FetchSessionKey {
topic_name: "t".into(),
topic_id: WireUuid::ZERO,
partition: 0,
},
CachedPartitionState {
fetch_offset: 100,
max_bytes: 1024,
..Default::default()
},
)];
let id = cache.try_allocate(false, "alice".into(), initial);
let r = req(id, 1, vec![topic("t", &[0, 1])], vec![]);
match cache.classify(&r) {
SessionDecision::Incremental {
session_id,
new_epoch,
partitions,
} => {
assert!(session_id == id);
assert!(new_epoch == 2);
assert!(partitions.len() == 2);
}
other => panic!("expected Incremental, got {other:?}"),
}
let r2 = req(id, 1, vec![], vec![]);
match cache.classify(&r2) {
SessionDecision::Error { code } => {
assert!(code == codes::INVALID_FETCH_SESSION_EPOCH);
}
other => panic!("expected Error, got {other:?}"),
}
}
// A request that addresses the topic by id only (empty name) updates the cached entry
// instead of creating a second one, and the cached key keeps its name.
#[test]
fn incremental_merge_matches_cached_key_by_topic_id_only() {
let cache = FetchSessionCache::new(10);
let tid = WireUuid([7u8; 16]);
let cached_key = FetchSessionKey {
topic_name: "t".into(),
topic_id: tid,
partition: 0,
};
let id = cache.try_allocate(
false,
"alice".into(),
vec![(
cached_key.clone(),
CachedPartitionState {
fetch_offset: 5,
max_bytes: 1024,
..Default::default()
},
)],
);
let r = req(
id,
1,
vec![FetchTopic {
topic: String::new(),
topic_id: tid,
partitions: vec![FetchPartition {
partition: 0,
fetch_offset: 42,
partition_max_bytes: 2048,
..Default::default()
}],
..Default::default()
}],
vec![],
);
let SessionDecision::Incremental { partitions, .. } = cache.classify(&r) else {
panic!("expected Incremental");
};
assert!(partitions.len() == 1, "no duplicate entry created");
let (k, s) = &partitions[0];
assert!(k.topic_name == "t", "cached name preserved");
assert!(k.topic_id == tid);
assert!(s.fetch_offset == 42, "fetch_offset updated");
assert!(s.max_bytes == 2048, "max_bytes updated");
}
// The mirror case: a request that addresses the topic by name only (zero id) updates the
// cached entry that carries both name and id.
#[test]
fn incremental_merge_matches_cached_key_by_topic_name_only() {
let cache = FetchSessionCache::new(10);
let tid = WireUuid([9u8; 16]);
let cached_key = FetchSessionKey {
topic_name: "t".into(),
topic_id: tid,
partition: 0,
};
let id = cache.try_allocate(
false,
"alice".into(),
vec![(
cached_key.clone(),
CachedPartitionState {
fetch_offset: 5,
max_bytes: 1024,
..Default::default()
},
)],
);
let r = req(
id,
1,
vec![FetchTopic {
topic: "t".into(),
topic_id: WireUuid::ZERO,
partitions: vec![FetchPartition {
partition: 0,
fetch_offset: 99,
partition_max_bytes: 4096,
..Default::default()
}],
..Default::default()
}],
vec![],
);
let SessionDecision::Incremental { partitions, .. } = cache.classify(&r) else {
panic!("expected Incremental");
};
assert!(partitions.len() == 1);
let (_, s) = &partitions[0];
assert!(s.fetch_offset == 99);
assert!(s.max_bytes == 4096);
}
// Forgetting partition 1 of a three-partition session leaves partitions 0 and 2.
#[test]
fn forgotten_topics_drop_partitions_from_cache() {
let cache = FetchSessionCache::new(10);
let initial = vec![
(
FetchSessionKey {
topic_name: "t".into(),
topic_id: WireUuid::ZERO,
partition: 0,
},
CachedPartitionState::default(),
),
(
FetchSessionKey {
topic_name: "t".into(),
topic_id: WireUuid::ZERO,
partition: 1,
},
CachedPartitionState::default(),
),
(
FetchSessionKey {
topic_name: "t".into(),
topic_id: WireUuid::ZERO,
partition: 2,
},
CachedPartitionState::default(),
),
];
let id = cache.try_allocate(false, "alice".into(), initial);
let forgotten = vec![ForgottenTopic {
topic: "t".into(),
topic_id: WireUuid::ZERO,
partitions: vec![1],
..Default::default()
}];
let r = req(id, 1, vec![], forgotten);
match cache.classify(&r) {
SessionDecision::Incremental { partitions, .. } => {
assert!(partitions.len() == 2);
let parts: Vec<i32> = partitions.iter().map(|(k, _)| k.partition).collect();
assert!(parts.contains(&0));
assert!(parts.contains(&2));
assert!(!parts.contains(&1));
}
other => panic!("expected Incremental, got {other:?}"),
}
}
// With two slots, a third allocation evicts the oldest session. The short sleeps keep the
// sessions' last_used timestamps strictly ordered.
#[test]
fn lru_eviction_drops_oldest_non_privileged() {
let cache = FetchSessionCache::new(2);
let a = cache.try_allocate(false, "a".into(), vec![]);
std::thread::sleep(std::time::Duration::from_millis(2));
let b = cache.try_allocate(false, "b".into(), vec![]);
std::thread::sleep(std::time::Duration::from_millis(2));
let c = cache.try_allocate(false, "c".into(), vec![]);
assert!(cache.len() == 2);
assert!(cache.evictions_total() == 1);
let g = cache.inner.lock().unwrap();
assert!(!g.sessions.contains_key(&a));
assert!(g.sessions.contains_key(&b));
assert!(g.sessions.contains_key(&c));
}
// A non-privileged allocation fails rather than evicting a privileged session.
#[test]
fn non_privileged_cannot_evict_privileged() {
let cache = FetchSessionCache::new(1);
let p = cache.try_allocate(true, "follower".into(), vec![]);
assert!(p > 0);
let c = cache.try_allocate(false, "consumer".into(), vec![]);
assert!(c == INVALID_SESSION_ID);
assert!(cache.evictions_total() == 0);
assert!(cache.len() == 1);
}
// A privileged allocation may evict another privileged session.
#[test]
fn privileged_can_evict_privileged() {
let cache = FetchSessionCache::new(1);
let p1 = cache.try_allocate(true, "f1".into(), vec![]);
std::thread::sleep(std::time::Duration::from_millis(2));
let p2 = cache.try_allocate(true, "f2".into(), vec![]);
assert!(p2 > 0);
assert!(cache.len() == 1);
assert!(cache.evictions_total() == 1);
let g = cache.inner.lock().unwrap();
assert!(!g.sessions.contains_key(&p1));
assert!(g.sessions.contains_key(&p2));
}
// finalize_incremental copies the reported last_* values onto the cached partition.
#[test]
fn finalize_incremental_updates_last_state() {
let cache = FetchSessionCache::new(10);
let key = FetchSessionKey {
topic_name: "t".into(),
topic_id: WireUuid::ZERO,
partition: 0,
};
let id = cache.try_allocate(
false,
"a".into(),
vec![(key.clone(), CachedPartitionState::default())],
);
let sent = vec![(
key.clone(),
CachedPartitionState {
last_high_watermark: 42,
last_log_start_offset: 7,
..Default::default()
},
)];
cache.finalize_incremental(id, &sent);
let g = cache.inner.lock().unwrap();
let s = g.sessions.get(&id).unwrap().partitions.get(&key).unwrap();
assert!(s.last_high_watermark == 42);
assert!(s.last_log_start_offset == 7);
}
// The partition counter is the sum over all sessions (2 + 3).
#[test]
fn total_partitions_cached_sums_across_sessions() {
let cache = FetchSessionCache::new(10);
let mk = |p| {
(
FetchSessionKey {
topic_name: "t".into(),
topic_id: WireUuid::ZERO,
partition: p,
},
CachedPartitionState::default(),
)
};
cache.try_allocate(false, "a".into(), vec![mk(0), mk(1)]);
cache.try_allocate(false, "b".into(), vec![mk(2), mk(3), mk(4)]);
assert!(cache.total_partitions_cached() == 5);
}
// Starting from partitions {0, 1}: forgetting 1 and requesting {0, 2, 3} leaves {0, 2, 3},
// and closing the session resets both counters.
#[test]
fn counters_track_merge_forget_and_close() {
let cache = FetchSessionCache::new(10);
let mk = |p| {
(
FetchSessionKey {
topic_name: "t".into(),
topic_id: WireUuid::ZERO,
partition: p,
},
CachedPartitionState::default(),
)
};
let id = cache.try_allocate(false, "a".into(), vec![mk(0), mk(1)]);
assert!(cache.len() == 1);
assert!(cache.total_partitions_cached() == 2);
let forgotten = vec![ForgottenTopic {
topic: "t".into(),
topic_id: WireUuid::ZERO,
partitions: vec![1],
..Default::default()
}];
let r = req(id, 1, vec![topic("t", &[0, 2, 3])], forgotten);
assert!(matches!(
cache.classify(&r),
SessionDecision::Incremental { .. }
));
assert!(cache.total_partitions_cached() == 3);
cache.close(id);
assert!(cache.len() == 0);
assert!(cache.total_partitions_cached() == 0);
}
// Evicting a session also removes its partitions from the partition counter.
#[test]
fn counters_track_eviction() {
let cache = FetchSessionCache::new(1);
let mk = |p| {
(
FetchSessionKey {
topic_name: "t".into(),
topic_id: WireUuid::ZERO,
partition: p,
},
CachedPartitionState::default(),
)
};
cache.try_allocate(false, "a".into(), vec![mk(0), mk(1)]);
assert!(cache.total_partitions_cached() == 2);
cache.try_allocate(false, "b".into(), vec![mk(0)]);
assert!(cache.len() == 1);
assert!(cache.total_partitions_cached() == 1);
}
}