use std::collections::{HashMap, HashSet};
use crate::protocol::{FetchForgottenTopic, FetchPartitionRequest, FetchTopicRequest};
use crate::{BrokerId, PartitionId};
pub const INITIAL_EPOCH: i32 = 0;
const ZERO_UUID: [u8; 16] = [0u8; 16];
/// Identity under which a topic is tracked inside a fetch session.
///
/// A topic that carries a non-zero topic id is keyed by that id. Any other topic (no id, or the
/// all-zero id) is keyed by its name instead.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum SessionKey {
    /// Topic known only by its name.
    Name(String),
    /// Topic known by its 16-byte topic id.
    Uuid([u8; 16]),
}

impl SessionKey {
    /// Derives the session key for one topic entry of a fetch request.
    fn from_request(topic: &FetchTopicRequest) -> Self {
        // The all-zero id is treated the same as "no id at all".
        let usable_id = topic.topic_id.filter(|id| *id != ZERO_UUID);
        usable_id.map_or_else(|| Self::Name(topic.topic.clone()), Self::Uuid)
    }
}
/// One topic as recorded in the session: the name it was last requested under, plus the
/// per-partition state that was last sent for it.
#[derive(Debug)]
struct TopicSession {
    name: String,
    partitions: HashMap<PartitionId, PartitionState>,
}
/// The partition fields the session compares between rounds.
///
/// If either value differs from what the next request asks for, the partition is resent.
#[derive(Clone, Debug, Eq, PartialEq)]
struct PartitionState {
    fetch_offset: i64,
    partition_max_bytes: i32,
}
/// Client-side view of the incremental fetch session held with one broker.
///
/// - `session_id == 0` means no session is established.
/// - `epoch` is the session epoch to send with the next request.
/// - `topics` holds the partitions recorded by the last successful `update_from_response`.
#[derive(Debug)]
pub struct FetchSessionState {
    // Only the test-only `broker_id()` accessor reads this field, so the lint is silenced
    // solely for non-test builds, where the field is genuinely unread.
    #[cfg_attr(not(test), allow(dead_code))]
    broker_id: BrokerId,
    session_id: i32,
    epoch: i32,
    topics: HashMap<SessionKey, TopicSession>,
}
/// Session-related parts of a fetch request, as produced by `FetchSessionState::build_request`.
#[derive(Debug)]
pub struct FetchSessionRequest {
    /// Session id to send; 0 when no session exists yet (full fetch).
    pub session_id: i32,
    /// Session epoch to send alongside `session_id`.
    pub session_epoch: i32,
    /// Partitions to fetch. On a full fetch this is every desired partition; otherwise it is
    /// only the partitions that are new to the session or changed since the last round.
    pub topics: Vec<FetchTopicRequest>,
    /// Partitions the broker should drop from the session (always empty on a full fetch).
    pub forgotten_topics: Vec<FetchForgottenTopic>,
    /// `true` when the request carries the complete partition set.
    pub is_full_fetch: bool,
}
impl FetchSessionState {
pub fn new(broker_id: BrokerId) -> Self {
Self {
broker_id,
session_id: 0,
epoch: INITIAL_EPOCH,
topics: HashMap::new(),
}
}
pub fn has_session(&self) -> bool {
self.session_id != 0
}
#[cfg(test)]
pub fn broker_id(&self) -> BrokerId {
self.broker_id
}
#[cfg(test)]
pub fn session_id(&self) -> i32 {
self.session_id
}
#[cfg(test)]
pub fn epoch(&self) -> i32 {
self.epoch
}
#[cfg(test)]
pub fn partition_count(&self) -> usize {
self.topics.values().map(|t| t.partitions.len()).sum()
}
pub fn build_request(&self, desired: &[FetchTopicRequest]) -> FetchSessionRequest {
if !self.has_session() {
return FetchSessionRequest {
session_id: 0,
session_epoch: INITIAL_EPOCH,
topics: desired.to_vec(),
forgotten_topics: Vec::new(),
is_full_fetch: true,
};
}
let epoch = self.epoch;
let total: usize = desired.iter().map(|t| t.partitions.len()).sum();
let mut desired_map: HashMap<SessionKey, HashMap<PartitionId, &FetchPartitionRequest>> =
HashMap::with_capacity(desired.len());
for topic in desired {
let key = SessionKey::from_request(topic);
let part_map = desired_map.entry(key).or_default();
part_map.reserve(topic.partitions.len());
for part in &topic.partitions {
part_map.insert(part.partition, part);
}
}
let _ = total;
let mut changed: HashMap<&str, Vec<FetchPartitionRequest>> = HashMap::new();
for topic in desired {
let key = SessionKey::from_request(topic);
let session_topic = self.topics.get(&key);
for part in &topic.partitions {
let is_new_or_changed =
match session_topic.and_then(|t| t.partitions.get(&part.partition)) {
None => true,
Some(prev) => {
prev.fetch_offset != part.fetch_offset
|| prev.partition_max_bytes != part.partition_max_bytes
}
};
if is_new_or_changed {
changed
.entry(topic.topic.as_str())
.or_default()
.push(part.clone());
}
}
}
let desired_keys: HashSet<SessionKey> =
desired.iter().map(SessionKey::from_request).collect();
let mut forgotten_map: HashMap<&str, Vec<i32>> = HashMap::new();
for (key, session_topic) in &self.topics {
if desired_keys.contains(key) {
let desired_parts = desired_map.get(key);
for &partition in session_topic.partitions.keys() {
let still_wanted = desired_parts.and_then(|m| m.get(&partition)).is_some();
if !still_wanted {
forgotten_map
.entry(session_topic.name.as_str())
.or_default()
.push(partition);
}
}
} else {
let parts: Vec<i32> = session_topic.partitions.keys().copied().collect();
if !parts.is_empty() {
forgotten_map
.entry(session_topic.name.as_str())
.or_default()
.extend(parts);
}
}
}
let name_to_uuid: HashMap<&str, [u8; 16]> = desired
.iter()
.filter_map(|t| {
t.topic_id
.filter(|id| *id != ZERO_UUID)
.map(|id| (t.topic.as_str(), id))
})
.collect();
let topics: Vec<FetchTopicRequest> = changed
.into_iter()
.map(|(name, partitions)| FetchTopicRequest {
topic: name.to_string(),
topic_id: name_to_uuid.get(name).copied(),
partitions,
})
.collect();
let forgotten_topics: Vec<FetchForgottenTopic> = forgotten_map
.into_iter()
.map(|(name, partitions)| FetchForgottenTopic {
topic: name.to_string(),
topic_id: name_to_uuid.get(name).copied(),
partitions,
})
.collect();
FetchSessionRequest {
session_id: self.session_id,
session_epoch: epoch,
topics,
forgotten_topics,
is_full_fetch: false,
}
}
pub fn update_from_response(
&mut self,
response_session_id: i32,
desired: &[FetchTopicRequest],
) {
if response_session_id == 0 {
self.reset();
return;
}
self.session_id = response_session_id;
self.epoch = self.next_epoch();
self.topics.clear();
for topic in desired {
let key = SessionKey::from_request(topic);
let mut partitions = HashMap::with_capacity(topic.partitions.len());
for part in &topic.partitions {
partitions.insert(
part.partition,
PartitionState {
fetch_offset: part.fetch_offset,
partition_max_bytes: part.partition_max_bytes,
},
);
}
self.topics.insert(
key,
TopicSession {
name: topic.topic.clone(),
partitions,
},
);
}
}
pub fn reset(&mut self) {
self.session_id = 0;
self.epoch = INITIAL_EPOCH;
self.topics.clear();
}
fn next_epoch(&self) -> i32 {
if self.epoch == i32::MAX {
1
} else {
self.epoch + 1
}
}
}
/// Per-broker collection of fetch session states.
#[derive(Debug, Default)]
pub struct FetchSessionCache {
    sessions: HashMap<BrokerId, FetchSessionState>,
}
impl FetchSessionCache {
pub fn new() -> Self {
Self {
sessions: HashMap::new(),
}
}
pub fn get_or_create(&mut self, broker_id: BrokerId) -> &mut FetchSessionState {
self.sessions
.entry(broker_id)
.or_insert_with(|| FetchSessionState::new(broker_id))
}
pub fn reset_broker(&mut self, broker_id: BrokerId) {
if let Some(session) = self.sessions.get_mut(&broker_id) {
session.reset();
}
}
pub fn reset_all(&mut self) {
for session in self.sessions.values_mut() {
session.reset();
}
}
pub(crate) fn retain_brokers(&mut self, broker_ids: &[BrokerId]) {
let broker_set: HashSet<_> = broker_ids.iter().copied().collect();
self.sessions.retain(|id, _| broker_set.contains(id));
}
}
#[cfg(test)]
// Tests report failures by panicking, so the panic-related restriction lints are relaxed here.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Builds one partition entry. Fields the session does not inspect get placeholder
    /// values (-1 / None).
    fn partition_request(
        partition: i32,
        fetch_offset: i64,
        partition_max_bytes: i32,
        current_leader_epoch: i32,
    ) -> FetchPartitionRequest {
        FetchPartitionRequest {
            partition,
            current_leader_epoch,
            fetch_offset,
            last_fetched_epoch: -1,
            log_start_offset: -1,
            partition_max_bytes,
            replica_directory_id: None,
            high_watermark: None,
        }
    }

    /// Builds a topic entry from already-constructed partitions.
    fn topic_request<I>(topic: &str, topic_id: Option<[u8; 16]>, partitions: I) -> FetchTopicRequest
    where
        I: IntoIterator<Item = FetchPartitionRequest>,
    {
        FetchTopicRequest {
            topic: topic.to_string(),
            topic_id,
            partitions: partitions.into_iter().collect(),
        }
    }

    /// Name-keyed topic from `(partition, offset, max_bytes)` tuples.
    fn make_topic_request(topic: &str, partitions: &[(i32, i64, i32)]) -> FetchTopicRequest {
        topic_request(
            topic,
            None,
            partitions
                .iter()
                .map(|&(p, offset, max_bytes)| partition_request(p, offset, max_bytes, -1)),
        )
    }

    /// Name-keyed topic from `(partition, offset, max_bytes, leader_epoch)` tuples.
    fn make_topic_request_with_epoch(
        topic: &str,
        partitions: &[(i32, i64, i32, i32)],
    ) -> FetchTopicRequest {
        topic_request(
            topic,
            None,
            partitions.iter().map(|&(p, offset, max_bytes, epoch)| {
                partition_request(p, offset, max_bytes, epoch)
            }),
        )
    }

    /// Topic with an explicit topic id from `(partition, offset, max_bytes)` tuples.
    fn make_uuid_topic_request(
        topic: &str,
        uuid: [u8; 16],
        partitions: &[(i32, i64, i32)],
    ) -> FetchTopicRequest {
        topic_request(
            topic,
            Some(uuid),
            partitions
                .iter()
                .map(|&(p, offset, max_bytes)| partition_request(p, offset, max_bytes, -1)),
        )
    }

    #[test]
    fn test_new_session_starts_with_no_session() {
        let state = FetchSessionState::new(1);
        assert_eq!(state.session_id(), 0);
        assert_eq!(state.epoch(), INITIAL_EPOCH);
        assert!(!state.has_session());
        assert_eq!(state.partition_count(), 0);
    }

    #[test]
    fn test_first_fetch_is_full() {
        let state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        let req = state.build_request(&desired);
        assert!(req.is_full_fetch);
        assert_eq!(req.session_id, 0);
        assert_eq!(req.session_epoch, INITIAL_EPOCH);
        assert_eq!(req.topics.len(), 1);
        assert!(req.forgotten_topics.is_empty());
    }

    #[test]
    fn test_session_established_after_response() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        state.update_from_response(42, &desired);
        assert!(state.has_session());
        assert_eq!(state.session_id(), 42);
        assert_eq!(state.epoch(), 1);
        assert_eq!(state.partition_count(), 1);
    }

    #[test]
    fn test_incremental_fetch_no_changes() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        state.update_from_response(42, &desired);
        let req = state.build_request(&desired);
        assert!(!req.is_full_fetch);
        assert_eq!(req.session_id, 42);
        assert_eq!(req.session_epoch, 1);
        assert!(req.topics.is_empty());
        assert!(req.forgotten_topics.is_empty());
    }

    #[test]
    fn test_incremental_fetch_offset_changed() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_topic_request("topic-a", &[(0, 200, 1048576)])];
        let req = state.build_request(&desired2);
        assert!(!req.is_full_fetch);
        assert_eq!(req.topics.len(), 1);
        assert_eq!(req.topics[0].partitions[0].fetch_offset, 200);
        assert!(req.forgotten_topics.is_empty());
    }

    #[test]
    fn test_incremental_fetch_partition_added() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_topic_request(
            "topic-a",
            &[(0, 100, 1048576), (1, 0, 1048576)],
        )];
        let req = state.build_request(&desired2);
        assert!(!req.is_full_fetch);
        assert_eq!(req.topics.len(), 1);
        assert_eq!(req.topics[0].partitions.len(), 1);
        assert_eq!(req.topics[0].partitions[0].partition, 1);
        assert!(req.forgotten_topics.is_empty());
    }

    #[test]
    fn test_incremental_fetch_partition_removed() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request(
            "topic-a",
            &[(0, 100, 1048576), (1, 50, 1048576)],
        )];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        let req = state.build_request(&desired2);
        assert!(!req.is_full_fetch);
        assert!(req.topics.is_empty());
        assert_eq!(req.forgotten_topics.len(), 1);
        assert_eq!(req.forgotten_topics[0].topic, "topic-a");
        assert_eq!(req.forgotten_topics[0].partitions, vec![1]);
    }

    #[test]
    fn test_session_reset() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        state.update_from_response(42, &desired);
        assert!(state.has_session());
        state.reset();
        assert!(!state.has_session());
        assert_eq!(state.session_id(), 0);
        assert_eq!(state.epoch(), INITIAL_EPOCH);
        assert_eq!(state.partition_count(), 0);
    }

    #[test]
    fn test_broker_returns_zero_session_id_closes_session() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        state.update_from_response(42, &desired);
        assert!(state.has_session());
        state.update_from_response(0, &desired);
        assert!(!state.has_session());
    }

    #[test]
    fn test_epoch_advances_with_each_response() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        state.update_from_response(42, &desired);
        state.update_from_response(42, &desired);
        assert_eq!(state.epoch(), 2);
        assert_eq!(state.build_request(&desired).session_epoch, 2);
    }

    #[test]
    fn test_epoch_wraps_at_max() {
        let mut state = FetchSessionState::new(1);
        state.session_id = 42;
        state.epoch = i32::MAX;
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        let req = state.build_request(&desired);
        assert_eq!(req.session_epoch, i32::MAX);
        state.update_from_response(42, &desired);
        assert_eq!(state.epoch(), 1);
    }

    #[test]
    fn test_mixed_changes_and_removals() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![
            make_topic_request("topic-a", &[(0, 100, 1048576), (1, 50, 1048576)]),
            make_topic_request("topic-b", &[(0, 200, 1048576)]),
        ];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_topic_request(
            "topic-a",
            &[(0, 300, 1048576), (1, 50, 1048576), (2, 0, 1048576)],
        )];
        let req = state.build_request(&desired2);
        assert!(!req.is_full_fetch);
        let changed_partitions: Vec<i32> = req
            .topics
            .iter()
            .flat_map(|t| t.partitions.iter().map(|p| p.partition))
            .collect();
        assert!(changed_partitions.contains(&0));
        assert!(changed_partitions.contains(&2));
        assert!(!changed_partitions.contains(&1));
        assert_eq!(req.forgotten_topics.len(), 1);
        assert_eq!(req.forgotten_topics[0].topic, "topic-b");
    }

    #[test]
    fn test_leader_epoch_change_not_tracked() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request_with_epoch(
            "topic-a",
            &[(0, 100, 1048576, 5)],
        )];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_topic_request_with_epoch(
            "topic-a",
            &[(0, 100, 1048576, 6)],
        )];
        let req = state.build_request(&desired2);
        assert!(!req.is_full_fetch);
        assert!(req.topics.is_empty());
        assert!(req.forgotten_topics.is_empty());
    }

    #[test]
    fn test_cache_get_or_create() {
        let mut cache = FetchSessionCache::new();
        {
            let session = cache.get_or_create(1);
            assert_eq!(session.broker_id(), 1);
            assert!(!session.has_session());
        }
        {
            let session = cache.get_or_create(1);
            session.session_id = 42;
        }
        {
            let session = cache.get_or_create(1);
            assert_eq!(session.session_id(), 42);
        }
    }

    #[test]
    fn test_cache_reset_broker() {
        let mut cache = FetchSessionCache::new();
        let desired = vec![make_topic_request("t", &[(0, 0, 1048576)])];
        cache.get_or_create(1).update_from_response(42, &desired);
        cache.get_or_create(2).update_from_response(43, &desired);
        cache.reset_broker(1);
        assert!(!cache.get_or_create(1).has_session());
        assert!(cache.get_or_create(2).has_session());
    }

    #[test]
    fn test_cache_reset_all() {
        let mut cache = FetchSessionCache::new();
        let desired = vec![make_topic_request("t", &[(0, 0, 1048576)])];
        cache.get_or_create(1).update_from_response(42, &desired);
        cache.get_or_create(2).update_from_response(43, &desired);
        cache.reset_all();
        assert!(!cache.get_or_create(1).has_session());
        assert!(!cache.get_or_create(2).has_session());
    }

    #[test]
    fn test_cache_retain_brokers() {
        let mut cache = FetchSessionCache::new();
        let desired = vec![make_topic_request("t", &[(0, 0, 1048576)])];
        cache.get_or_create(1).update_from_response(42, &desired);
        cache.get_or_create(2).update_from_response(43, &desired);
        cache.get_or_create(3).update_from_response(44, &desired);
        cache.retain_brokers(&[1, 3]);
        assert!(cache.get_or_create(1).has_session());
        // Broker 2 was dropped, so get_or_create re-creates it without a session.
        assert!(!cache.get_or_create(2).has_session());
        assert!(cache.get_or_create(3).has_session());
    }

    #[test]
    fn test_full_fetch_after_reset() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        state.update_from_response(42, &desired);
        state.reset();
        let req = state.build_request(&desired);
        assert!(req.is_full_fetch);
        assert_eq!(req.session_id, 0);
        assert_eq!(req.session_epoch, INITIAL_EPOCH);
    }

    #[test]
    fn test_uuid_keyed_incremental_no_changes() {
        let uuid: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_uuid_topic_request(
            "topic-a",
            uuid,
            &[(0, 100, 1048576)],
        )];
        state.update_from_response(42, &desired);
        assert_eq!(state.partition_count(), 1);
        let req = state.build_request(&desired);
        assert!(!req.is_full_fetch);
        assert!(req.topics.is_empty());
        assert!(req.forgotten_topics.is_empty());
    }

    #[test]
    fn test_uuid_keyed_offset_change() {
        let uuid: [u8; 16] = [0xAA; 16];
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_uuid_topic_request(
            "topic-a",
            uuid,
            &[(0, 100, 1048576)],
        )];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_uuid_topic_request(
            "topic-a",
            uuid,
            &[(0, 200, 1048576)],
        )];
        let req = state.build_request(&desired2);
        assert!(!req.is_full_fetch);
        assert_eq!(req.topics.len(), 1);
        assert_eq!(req.topics[0].topic, "topic-a");
        assert_eq!(req.topics[0].partitions[0].fetch_offset, 200);
    }

    #[test]
    fn test_uuid_keyed_partition_removed() {
        let uuid: [u8; 16] = [0xBB; 16];
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_uuid_topic_request(
            "topic-a",
            uuid,
            &[(0, 100, 1048576), (1, 50, 1048576)],
        )];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_uuid_topic_request(
            "topic-a",
            uuid,
            &[(0, 100, 1048576)],
        )];
        let req = state.build_request(&desired2);
        assert!(!req.is_full_fetch);
        assert!(req.topics.is_empty());
        assert_eq!(req.forgotten_topics.len(), 1);
        assert_eq!(req.forgotten_topics[0].topic, "topic-a");
        assert_eq!(req.forgotten_topics[0].partitions, vec![1]);
    }

    #[test]
    fn test_uuid_keyed_requests_carry_topic_id() {
        let uuid: [u8; 16] = [0xDD; 16];
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_uuid_topic_request(
            "topic-a",
            uuid,
            &[(0, 100, 1048576), (1, 50, 1048576)],
        )];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_uuid_topic_request(
            "topic-a",
            uuid,
            &[(0, 200, 1048576)],
        )];
        let req = state.build_request(&desired2);
        assert_eq!(req.topics.len(), 1);
        assert_eq!(req.topics[0].topic_id, Some(uuid));
        assert_eq!(req.forgotten_topics.len(), 1);
        assert_eq!(req.forgotten_topics[0].topic_id, Some(uuid));
        assert_eq!(req.forgotten_topics[0].partitions, vec![1]);
    }

    #[test]
    fn test_zero_uuid_falls_back_to_name_key() {
        let zero_uuid = [0u8; 16];
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_uuid_topic_request(
            "topic-a",
            zero_uuid,
            &[(0, 100, 1048576)],
        )];
        state.update_from_response(42, &desired);
        let req = state.build_request(&desired);
        assert!(!req.is_full_fetch);
        assert!(req.topics.is_empty());
        assert!(req.forgotten_topics.is_empty());
    }

    #[test]
    fn test_zero_uuid_changes_are_sent_without_topic_id() {
        let zero_uuid = [0u8; 16];
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_uuid_topic_request(
            "topic-a",
            zero_uuid,
            &[(0, 100, 1048576)],
        )];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_uuid_topic_request(
            "topic-a",
            zero_uuid,
            &[(0, 200, 1048576)],
        )];
        let req = state.build_request(&desired2);
        assert_eq!(req.topics.len(), 1);
        assert_eq!(req.topics[0].topic_id, None);
    }

    #[test]
    fn test_mixed_uuid_and_name_keying() {
        let uuid: [u8; 16] = [0xCC; 16];
        let mut state = FetchSessionState::new(1);
        let desired = vec![
            make_uuid_topic_request("topic-uuid", uuid, &[(0, 100, 1048576)]),
            make_topic_request("topic-name", &[(0, 200, 1048576)]),
        ];
        state.update_from_response(42, &desired);
        assert_eq!(state.partition_count(), 2);
        let req = state.build_request(&desired);
        assert!(!req.is_full_fetch);
        assert!(req.topics.is_empty());
        assert!(req.forgotten_topics.is_empty());
    }

    #[test]
    fn test_empty_desired_produces_all_forgotten() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request(
            "topic-a",
            &[(0, 100, 1048576), (1, 50, 1048576)],
        )];
        state.update_from_response(42, &desired);
        let desired2: Vec<FetchTopicRequest> = vec![];
        let req = state.build_request(&desired2);
        assert!(!req.is_full_fetch);
        assert!(req.topics.is_empty());
        assert_eq!(req.forgotten_topics.len(), 1);
        assert_eq!(req.forgotten_topics[0].partitions.len(), 2);
    }

    #[test]
    fn test_max_bytes_change_detected() {
        let mut state = FetchSessionState::new(1);
        let desired = vec![make_topic_request("topic-a", &[(0, 100, 1048576)])];
        state.update_from_response(42, &desired);
        let desired2 = vec![make_topic_request("topic-a", &[(0, 100, 2097152)])];
        let req = state.build_request(&desired2);
        assert!(!req.is_full_fetch);
        assert_eq!(req.topics.len(), 1);
        assert_eq!(req.topics[0].partitions[0].partition_max_bytes, 2097152);
    }
}