use ahash::AHashMap;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use parking_lot::RwLock;
use crate::PartitionId;
use crate::error::{KrafkaError, ProtocolErrorKind, Result};
/// Thread-safe holder for a producer's id, epoch, and per-partition sequence bookkeeping.
///
/// The id/epoch/sequence state lives behind an `RwLock`; the `poisoned` flag is a separate
/// atomic so it can be read and written without taking that lock.
#[derive(Debug)]
pub struct ProducerIdentity {
/// Set by `poison`; cleared by `initialize`, `reset`, `restore_from_snapshot`, and by
/// `check_and_reset_if_retryable` when it decides the batch is retryable.
poisoned: AtomicBool,
/// Producer id, producer epoch, and the per-topic/per-partition sequence map.
inner: RwLock<IdentityInner>,
}
/// Lock-protected state of a `ProducerIdentity`.
#[derive(Debug)]
struct IdentityInner {
/// Assigned producer id; `-1` while uninitialized (any negative value counts as uninitialized).
producer_id: i64,
/// Assigned producer epoch; `-1` while uninitialized.
producer_epoch: i16,
/// Sequence state keyed first by topic name, then by partition.
sequences: AHashMap<String, AHashMap<PartitionId, SequenceState>>,
}
impl IdentityInner {
    /// Sentinel stored in `producer_id` while no id has been assigned.
    const UNSET_PRODUCER_ID: i64 = -1;
    /// Sentinel stored in `producer_epoch` while no epoch has been assigned.
    const UNSET_PRODUCER_EPOCH: i16 = -1;

    /// Builds the starting state: sentinel id and epoch, and an empty sequence map.
    fn uninitialized() -> Self {
        let sequences = AHashMap::new();
        Self {
            producer_id: Self::UNSET_PRODUCER_ID,
            producer_epoch: Self::UNSET_PRODUCER_EPOCH,
            sequences,
        }
    }

    /// Reports whether a producer id has been installed (any non-negative id counts).
    fn is_initialized(&self) -> bool {
        !self.producer_id.is_negative()
    }
}
/// Sequence bookkeeping for a single topic-partition.
#[derive(Debug, Clone)]
struct SequenceState {
/// Sequence number the next allocation will start from.
next_sequence: i32,
/// Most recently acknowledged sequence number; `-1` (the `Default` value) means none yet.
last_acked_sequence: i32,
}
/// Number of distinct sequence values: the non-negative `i32` range `0..=i32::MAX`, i.e. 2^31.
/// Sequence arithmetic in this file is reduced modulo this value.
const SEQUENCE_SPACE: u32 = i32::MAX as u32 + 1;
/// Half of `SEQUENCE_SPACE`; `is_newer_sequence` only treats a candidate as newer when its
/// forward distance from the last acknowledged sequence is below this threshold.
const HALF_SEQUENCE_SPACE: u32 = SEQUENCE_SPACE / 2;
/// Returns the sequence number of the final record in a batch of `count` records whose first
/// record carries `base_sequence`, wrapping past `i32::MAX` back to `0`.
///
/// # Errors
///
/// Returns an `InvalidValue` protocol error when `count` is zero or negative.
pub(crate) fn last_sequence_of_batch(base_sequence: i32, count: i32) -> Result<i32> {
    if count < 1 {
        return Err(KrafkaError::protocol_kind(
            ProtocolErrorKind::InvalidValue,
            "count must be positive",
        ));
    }
    // The last record sits `count - 1` positions after the base, modulo the sequence space.
    let offset = (count - 1) as u32;
    let wrapped = (base_sequence as u32).wrapping_add(offset) % SEQUENCE_SPACE;
    Ok(wrapped as i32)
}
/// Returns the sequence number that follows `sequence`.
///
/// Negative inputs (such as the `-1` "nothing acknowledged" sentinel) and `i32::MAX` both map
/// to `0`; every other value maps to `sequence + 1`.
fn next_sequence_after(sequence: i32) -> i32 {
    match sequence {
        // Wrap around at the top of the sequence space.
        i32::MAX => 0,
        // Anything below zero is not a real sequence; start from the beginning.
        negative if negative < 0 => 0,
        in_range => in_range + 1,
    }
}
/// Decides whether `candidate_sequence` should replace `last_acked_sequence`.
///
/// * A negative candidate is never a valid sequence number and is always rejected, even when
///   nothing has been acknowledged yet, so the stored value can only ever be the negative
///   "none" sentinel or a real sequence.
/// * When nothing has been acknowledged yet (`last_acked_sequence < 0`), any valid candidate
///   is newer.
/// * Otherwise the candidate is newer when it lies strictly ahead of the last acknowledged
///   sequence by less than half of the sequence space, measured with wrap-around. An equal
///   candidate is not newer.
fn is_newer_sequence(last_acked_sequence: i32, candidate_sequence: i32) -> bool {
    if candidate_sequence < 0 {
        return false;
    }
    if last_acked_sequence < 0 {
        return true;
    }
    let last = last_acked_sequence as u32;
    let candidate = candidate_sequence as u32;
    // Both values are below `SEQUENCE_SPACE` (2^31), so a wrapping u32 subtraction reduced
    // modulo the sequence space is exactly the forward distance from `last` to `candidate`.
    let forward_distance = candidate.wrapping_sub(last) % SEQUENCE_SPACE;
    forward_distance != 0 && forward_distance < HALF_SEQUENCE_SPACE
}
impl Default for SequenceState {
    /// Starts a partition at sequence `0` with nothing acknowledged (`-1`).
    fn default() -> Self {
        let (next_sequence, last_acked_sequence) = (0, -1);
        Self {
            next_sequence,
            last_acked_sequence,
        }
    }
}
impl ProducerIdentity {
pub fn new() -> Self {
Self {
poisoned: AtomicBool::new(false),
inner: RwLock::new(IdentityInner::uninitialized()),
}
}
pub fn is_initialized(&self) -> bool {
self.inner.read().is_initialized()
}
pub fn producer_id(&self) -> i64 {
self.inner.read().producer_id
}
pub fn producer_epoch(&self) -> i16 {
self.inner.read().producer_epoch
}
pub fn initialize(&self, producer_id: i64, producer_epoch: i16) {
let mut inner = self.inner.write();
inner.producer_id = producer_id;
inner.producer_epoch = producer_epoch;
self.poisoned.store(false, Ordering::Release);
inner.sequences.clear();
}
pub fn reset(&self) {
let mut inner = self.inner.write();
inner.producer_id = -1;
inner.producer_epoch = -1_i16;
self.poisoned.store(false, Ordering::Release);
inner.sequences.clear();
}
pub(crate) fn poison(&self) {
self.poisoned.store(true, Ordering::Release);
}
pub(crate) fn is_poisoned(&self) -> bool {
self.poisoned.load(Ordering::Acquire)
}
pub fn next_sequence(&self, topic: &str, partition: PartitionId) -> Result<i32> {
self.allocate_sequence(topic, partition, 1)
}
pub fn allocate_sequence(
&self,
topic: &str,
partition: PartitionId,
count: i32,
) -> Result<i32> {
if count <= 0 {
return Err(KrafkaError::protocol_kind(
ProtocolErrorKind::InvalidValue,
"count must be positive",
));
}
let mut inner = self.inner.write();
let state = inner
.sequences
.entry(topic.to_string())
.or_default()
.entry(partition)
.or_default();
let base = state.next_sequence;
state.next_sequence = ((base as u32).wrapping_add(count as u32) % SEQUENCE_SPACE) as i32;
Ok(base)
}
pub fn peek_sequence(&self, topic: &str, partition: PartitionId) -> i32 {
let inner = self.inner.read();
inner
.sequences
.get(topic)
.and_then(|parts| parts.get(&partition))
.map(|s| s.next_sequence)
.unwrap_or(0)
}
pub fn acknowledge(&self, topic: &str, partition: PartitionId, sequence: i32) {
let mut inner = self.inner.write();
if let Some(state) = inner
.sequences
.get_mut(topic)
.and_then(|parts| parts.get_mut(&partition))
&& is_newer_sequence(state.last_acked_sequence, sequence)
{
state.last_acked_sequence = sequence;
}
}
pub fn rollback_sequence(&self, topic: &str, partition: PartitionId) -> Result<()> {
self.rollback_sequence_range(topic, partition, 1)
}
pub fn rollback_sequence_range(
&self,
topic: &str,
partition: PartitionId,
count: i32,
) -> Result<()> {
if count <= 0 {
return Err(KrafkaError::protocol_kind(
ProtocolErrorKind::InvalidValue,
"count must be positive",
));
}
let mut inner = self.inner.write();
if let Some(state) = inner
.sequences
.get_mut(topic)
.and_then(|parts| parts.get_mut(&partition))
{
let current = state.next_sequence as u32;
state.next_sequence =
((current + SEQUENCE_SPACE - count as u32) % SEQUENCE_SPACE) as i32;
}
Ok(())
}
pub fn reset_sequence(&self, topic: &str, partition: PartitionId) {
let mut inner = self.inner.write();
if let Some(state) = inner
.sequences
.get_mut(topic)
.and_then(|parts| parts.get_mut(&partition))
{
state.next_sequence = next_sequence_after(state.last_acked_sequence);
}
}
pub fn reset_and_allocate(
&self,
topic: &str,
partition: PartitionId,
count: i32,
) -> Result<i32> {
if count <= 0 {
return Err(KrafkaError::protocol_kind(
ProtocolErrorKind::InvalidValue,
"count must be positive",
));
}
let mut inner = self.inner.write();
let state = inner
.sequences
.entry(topic.to_string())
.or_default()
.entry(partition)
.or_default();
state.next_sequence = next_sequence_after(state.last_acked_sequence);
let base = state.next_sequence;
state.next_sequence = ((base as u32).wrapping_add(count as u32) % SEQUENCE_SPACE) as i32;
Ok(base)
}
pub fn last_acked_sequence(&self, topic: &str, partition: PartitionId) -> i32 {
let inner = self.inner.read();
inner
.sequences
.get(topic)
.and_then(|parts| parts.get(&partition))
.map(|s| s.last_acked_sequence)
.unwrap_or(-1)
}
#[cfg(test)]
pub(crate) fn can_retry_unknown_producer_id(
&self,
topic: &str,
partition: PartitionId,
base_sequence: i32,
count: i32,
) -> Result<bool> {
let last_sequence = last_sequence_of_batch(base_sequence, count)?;
let inner = self.inner.read();
let Some(state) = inner
.sequences
.get(topic)
.and_then(|parts| parts.get(&partition))
else {
return Ok(false);
};
Ok(
base_sequence == next_sequence_after(state.last_acked_sequence)
&& state.next_sequence == next_sequence_after(last_sequence),
)
}
pub(crate) fn check_and_reset_if_retryable(
&self,
topic: &str,
partition: PartitionId,
base_sequence: i32,
count: i32,
) -> Result<bool> {
let last_sequence = last_sequence_of_batch(base_sequence, count)?;
let mut inner = self.inner.write();
let Some(state) = inner
.sequences
.get(topic)
.and_then(|parts| parts.get(&partition))
else {
return Ok(false);
};
let retryable = base_sequence == next_sequence_after(state.last_acked_sequence)
&& state.next_sequence == next_sequence_after(last_sequence);
if retryable {
inner.producer_id = -1;
inner.producer_epoch = -1_i16;
inner.sequences.clear();
self.poisoned.store(false, Ordering::Release);
}
Ok(retryable)
}
pub(crate) fn checked_allocate_sequence(
&self,
topic: &str,
partition: PartitionId,
count: i32,
) -> Result<Option<i32>> {
if count <= 0 {
return Err(KrafkaError::protocol_kind(
ProtocolErrorKind::InvalidValue,
"count must be positive",
));
}
let mut inner = self.inner.write();
if !inner.is_initialized() {
return Ok(None);
}
let state = inner
.sequences
.entry(topic.to_string())
.or_default()
.entry(partition)
.or_default();
let base = state.next_sequence;
state.next_sequence = ((base as u32).wrapping_add(count as u32) % SEQUENCE_SPACE) as i32;
Ok(Some(base))
}
pub fn snapshot(&self) -> ProducerIdentitySnapshot {
let inner = self.inner.read();
let partition_sequences = inner
.sequences
.iter()
.flat_map(|(topic, parts)| {
parts
.iter()
.map(move |(part, state)| PartitionSequenceSnapshot {
topic: topic.clone(),
partition: *part,
next_sequence: state.next_sequence,
last_acked_sequence: state.last_acked_sequence,
})
})
.collect();
ProducerIdentitySnapshot {
producer_id: inner.producer_id,
producer_epoch: inner.producer_epoch,
partition_sequences,
}
}
pub fn remove_partition(&self, topic: &str, partition: PartitionId) {
let mut inner = self.inner.write();
if let Some(parts) = inner.sequences.get_mut(topic) {
parts.remove(&partition);
if parts.is_empty() {
inner.sequences.remove(topic);
}
}
}
pub fn remove_topic(&self, topic: &str) {
self.inner.write().sequences.remove(topic);
}
pub fn retain_partitions(&self, active: &ahash::AHashMap<String, Vec<PartitionId>>) {
let mut inner = self.inner.write();
inner.sequences.retain(|topic, parts| {
if let Some(active_parts) = active.get(topic.as_str()) {
parts.retain(|pid, _| active_parts.contains(pid));
!parts.is_empty()
} else {
false
}
});
}
pub fn restore_from_snapshot(&self, snapshot: &ProducerIdentitySnapshot) {
let mut inner = self.inner.write();
inner.producer_id = snapshot.producer_id;
inner.producer_epoch = snapshot.producer_epoch;
inner.sequences.clear();
for ps in &snapshot.partition_sequences {
inner.sequences.entry(ps.topic.clone()).or_default().insert(
ps.partition,
SequenceState {
next_sequence: ps.next_sequence,
last_acked_sequence: ps.last_acked_sequence,
},
);
}
self.poisoned.store(false, Ordering::Release);
}
#[cfg(test)]
fn set_sequence_state(
&self,
topic: &str,
partition: PartitionId,
next_sequence: i32,
last_acked_sequence: i32,
) {
self.inner
.write()
.sequences
.entry(topic.to_string())
.or_default()
.insert(
partition,
SequenceState {
next_sequence,
last_acked_sequence,
},
);
}
}
impl Default for ProducerIdentity {
fn default() -> Self {
Self::new()
}
}
/// Plain-data copy of a `ProducerIdentity`, produced by `ProducerIdentity::snapshot` and
/// consumed by `ProducerIdentity::restore_from_snapshot`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ProducerIdentitySnapshot {
/// Producer id at the time of the snapshot (`-1` if the identity was uninitialized).
pub producer_id: i64,
/// Producer epoch at the time of the snapshot (`-1` if the identity was uninitialized).
pub producer_epoch: i16,
/// One entry per tracked topic-partition, in unspecified order.
pub partition_sequences: Vec<PartitionSequenceSnapshot>,
}
/// Sequence state of one topic-partition inside a `ProducerIdentitySnapshot`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct PartitionSequenceSnapshot {
/// Topic name.
pub topic: String,
/// Partition within `topic`.
pub partition: PartitionId,
/// Sequence number the next allocation for this partition would start from.
pub next_sequence: i32,
/// Last acknowledged sequence number; `-1` when nothing had been acknowledged.
pub last_acked_sequence: i32,
}
/// Asynchronous storage for a `ProducerIdentitySnapshot`.
///
/// Implementors must be `Send + Sync`, and both methods must return `Send` futures.
pub trait ProducerStateStore: Send + Sync {
/// Loads a snapshot from the store. The `Option` lets an implementation report that it has
/// nothing to return (presumably "no snapshot stored yet" — the exact meaning is up to the
/// implementor).
fn load(
&self,
) -> impl std::future::Future<Output = Result<Option<ProducerIdentitySnapshot>>> + Send;
/// Stores `snapshot`.
fn store(
&self,
snapshot: &ProducerIdentitySnapshot,
) -> impl std::future::Future<Output = Result<()>> + Send;
}
/// Crate-internal mirror of `ProducerStateStore` whose methods return boxed, pinned
/// `dyn Future`s instead of `impl Future`, so the returned future type is nameable.
/// A blanket impl in this file provides it for every `ProducerStateStore`.
pub(crate) trait ErasedProducerStateStore: Send + Sync {
/// Boxed counterpart of `ProducerStateStore::load`; the future borrows `self`.
fn load_erased(
&self,
) -> Pin<
Box<dyn std::future::Future<Output = Result<Option<ProducerIdentitySnapshot>>> + Send + '_>,
>;
/// Boxed counterpart of `ProducerStateStore::store`; the future borrows both `self` and
/// `snapshot` for `'a`.
fn store_erased<'a>(
&'a self,
snapshot: &'a ProducerIdentitySnapshot,
) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>>;
}
impl<T: ProducerStateStore> ErasedProducerStateStore for T {
fn load_erased(
&self,
) -> Pin<
Box<dyn std::future::Future<Output = Result<Option<ProducerIdentitySnapshot>>> + Send + '_>,
> {
Box::pin(self.load())
}
fn store_erased<'a>(
&'a self,
snapshot: &'a ProducerIdentitySnapshot,
) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
Box::pin(self.store(snapshot))
}
}
/// An `Arc` around a store is itself a store: both methods forward to the wrapped value.
impl<T> ProducerStateStore for std::sync::Arc<T>
where
    T: ProducerStateStore,
{
    fn load(
        &self,
    ) -> impl std::future::Future<Output = Result<Option<ProducerIdentitySnapshot>>> + Send {
        let wrapped: &T = self;
        wrapped.load()
    }

    fn store(
        &self,
        snapshot: &ProducerIdentitySnapshot,
    ) -> impl std::future::Future<Output = Result<()>> + Send {
        let wrapped: &T = self;
        wrapped.store(snapshot)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Topic name shared by the single-topic tests.
    const TOPIC: &str = "topic";

    /// Builds an identity and initializes it with the given producer id and epoch.
    fn initialized(producer_id: i64, producer_epoch: i16) -> ProducerIdentity {
        let ids = ProducerIdentity::new();
        ids.initialize(producer_id, producer_epoch);
        ids
    }

    /// Allocates `times` single sequence numbers on `TOPIC` partition 0, discarding them.
    fn allocate_singles(ids: &ProducerIdentity, times: usize) {
        for _ in 0..times {
            ids.next_sequence(TOPIC, 0).unwrap();
        }
    }

    #[test]
    fn test_producer_identity_new() {
        let fresh = ProducerIdentity::new();
        assert!(!fresh.is_initialized());
        assert_eq!(fresh.producer_id(), -1);
        assert_eq!(fresh.producer_epoch(), -1);
    }

    #[test]
    fn test_producer_identity_initialize() {
        let ids = initialized(12345, 0);
        assert!(ids.is_initialized());
        assert_eq!(ids.producer_id(), 12345);
        assert_eq!(ids.producer_epoch(), 0);
    }

    #[test]
    fn test_sequence_numbers() {
        let ids = initialized(1, 0);
        // Consecutive allocations on one partition count up from zero.
        for expected in 0..3 {
            assert_eq!(ids.next_sequence(TOPIC, 0).unwrap(), expected);
        }
        // Other partitions and topics have independent counters.
        assert_eq!(ids.next_sequence(TOPIC, 1).unwrap(), 0);
        assert_eq!(ids.next_sequence("other-topic", 0).unwrap(), 0);
    }

    #[test]
    fn test_peek_sequence() {
        let ids = initialized(1, 0);
        // Peeking twice must not advance the counter.
        assert_eq!(ids.peek_sequence(TOPIC, 0), 0);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 0);
        allocate_singles(&ids, 1);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 1);
    }

    #[test]
    fn test_acknowledge() {
        let ids = initialized(1, 0);
        allocate_singles(&ids, 3);
        // (acknowledged sequence, expected last-acked afterwards); the stale ack of 0 is ignored.
        for (acked, expected) in [(1, 1), (0, 1), (2, 2)] {
            ids.acknowledge(TOPIC, 0, acked);
            assert_eq!(ids.last_acked_sequence(TOPIC, 0), expected);
        }
    }

    #[test]
    fn test_reset_sequence() {
        let ids = initialized(1, 0);
        allocate_singles(&ids, 3);
        ids.acknowledge(TOPIC, 0, 1);
        ids.reset_sequence(TOPIC, 0);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 2);
    }

    #[test]
    fn test_acknowledge_wraps_from_max_to_zero() {
        let ids = initialized(1, 0);
        ids.set_sequence_state(TOPIC, 0, 1, i32::MAX);
        // Zero follows i32::MAX after wrap-around, so it is accepted as newer.
        ids.acknowledge(TOPIC, 0, 0);
        assert_eq!(ids.last_acked_sequence(TOPIC, 0), 0);
        // i32::MAX is now behind zero and must be ignored.
        ids.acknowledge(TOPIC, 0, i32::MAX);
        assert_eq!(ids.last_acked_sequence(TOPIC, 0), 0);
    }

    #[test]
    fn test_reset_sequence_wraps_to_zero_after_max_ack() {
        let ids = initialized(1, 0);
        ids.set_sequence_state(TOPIC, 0, 1, i32::MAX);
        ids.reset_sequence(TOPIC, 0);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 0);
    }

    #[test]
    fn test_reset_identity() {
        let ids = initialized(12345, 5);
        allocate_singles(&ids, 1);
        ids.reset();
        assert!(!ids.is_initialized());
        assert_eq!(ids.producer_id(), -1);
        assert_eq!(ids.producer_epoch(), -1);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 0);
    }

    #[test]
    fn test_snapshot() {
        let ids = initialized(100, 1);
        ids.next_sequence("topic1", 0).unwrap();
        ids.next_sequence("topic1", 0).unwrap();
        ids.acknowledge("topic1", 0, 0);
        ids.next_sequence("topic2", 0).unwrap();
        let snap = ids.snapshot();
        assert_eq!(snap.producer_id, 100);
        assert_eq!(snap.producer_epoch, 1);
        assert_eq!(snap.partition_sequences.len(), 2);
    }

    #[test]
    fn test_can_retry_unknown_producer_id_for_oldest_unacked_batch() {
        let ids = initialized(1, 0);
        assert_eq!(ids.allocate_sequence(TOPIC, 0, 2).unwrap(), 0);
        let retryable = ids.can_retry_unknown_producer_id(TOPIC, 0, 0, 2).unwrap();
        assert!(retryable);
    }

    #[test]
    fn test_cannot_retry_unknown_producer_id_when_newer_batch_exists() {
        let ids = initialized(1, 0);
        assert_eq!(ids.allocate_sequence(TOPIC, 0, 2).unwrap(), 0);
        assert_eq!(ids.allocate_sequence(TOPIC, 0, 1).unwrap(), 2);
        let retryable = ids.can_retry_unknown_producer_id(TOPIC, 0, 0, 2).unwrap();
        assert!(!retryable);
    }

    #[test]
    fn test_poison_flag_clears_on_reset_and_reinitialize() {
        let ids = initialized(1, 0);
        ids.poison();
        assert!(ids.is_poisoned());
        ids.reset();
        assert!(!ids.is_poisoned());
        ids.initialize(2, 1);
        assert!(!ids.is_poisoned());
    }

    #[test]
    fn test_sequence_wrapping() {
        let ids = initialized(1, 0);
        ids.set_sequence_state(TOPIC, 0, i32::MAX, i32::MAX - 1);
        assert_eq!(ids.next_sequence(TOPIC, 0).unwrap(), i32::MAX);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 0);
    }

    #[test]
    fn test_rollback_sequence() {
        let ids = initialized(1, 0);
        assert_eq!(ids.next_sequence(TOPIC, 0).unwrap(), 0);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 1);
        ids.rollback_sequence(TOPIC, 0).unwrap();
        assert_eq!(ids.peek_sequence(TOPIC, 0), 0);
        // The rolled-back number is handed out again.
        assert_eq!(ids.next_sequence(TOPIC, 0).unwrap(), 0);
    }

    #[test]
    fn test_rollback_sequence_wraps_from_zero() {
        let ids = initialized(1, 0);
        ids.set_sequence_state(TOPIC, 0, 0, i32::MAX - 1);
        ids.rollback_sequence(TOPIC, 0).unwrap();
        assert_eq!(ids.peek_sequence(TOPIC, 0), i32::MAX);
    }

    #[test]
    fn test_last_sequence_of_batch_single_record() {
        for base in [0, 5, i32::MAX] {
            assert_eq!(last_sequence_of_batch(base, 1).unwrap(), base);
        }
    }

    #[test]
    fn test_last_sequence_of_batch_multi_record() {
        for (base, count, expected) in [(0, 5, 4), (10, 3, 12), (100, 100, 199)] {
            assert_eq!(last_sequence_of_batch(base, count).unwrap(), expected);
        }
    }

    #[test]
    fn test_last_sequence_of_batch_wrapping() {
        for (base, count, expected) in [(i32::MAX - 2, 5, 1), (i32::MAX, 2, 0)] {
            assert_eq!(last_sequence_of_batch(base, count).unwrap(), expected);
        }
    }

    #[test]
    fn test_multi_record_batch_ack_then_reset() {
        let ids = initialized(1, 0);
        let base = ids.allocate_sequence(TOPIC, 0, 5).unwrap();
        assert_eq!(base, 0);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 5);
        let final_sequence = last_sequence_of_batch(base, 5).unwrap();
        assert_eq!(final_sequence, 4);
        ids.acknowledge(TOPIC, 0, final_sequence);
        assert_eq!(ids.last_acked_sequence(TOPIC, 0), 4);
        ids.reset_sequence(TOPIC, 0);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 5);
    }

    #[test]
    fn test_reset_and_allocate_atomic() {
        let ids = initialized(1, 0);
        ids.allocate_sequence(TOPIC, 0, 3).unwrap();
        ids.acknowledge(TOPIC, 0, 1);
        let base = ids.reset_and_allocate(TOPIC, 0, 5).unwrap();
        assert_eq!(base, 2);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 7);
    }

    #[test]
    fn test_reset_and_allocate_no_prior_ack() {
        let ids = initialized(1, 0);
        ids.allocate_sequence(TOPIC, 0, 3).unwrap();
        let base = ids.reset_and_allocate(TOPIC, 0, 2).unwrap();
        assert_eq!(base, 0);
        assert_eq!(ids.peek_sequence(TOPIC, 0), 2);
    }

    #[test]
    fn test_reset_and_allocate_fresh_partition() {
        let ids = initialized(1, 0);
        let base = ids.reset_and_allocate(TOPIC, 99, 3).unwrap();
        assert_eq!(base, 0);
        assert_eq!(ids.peek_sequence(TOPIC, 99), 3);
    }

    #[test]
    fn test_restore_from_snapshot_replaces_state() {
        fn entry(
            topic: &str,
            partition: PartitionId,
            next_sequence: i32,
            last_acked_sequence: i32,
        ) -> PartitionSequenceSnapshot {
            PartitionSequenceSnapshot {
                topic: topic.to_string(),
                partition,
                next_sequence,
                last_acked_sequence,
            }
        }

        let ids = initialized(100, 2);
        ids.next_sequence("topic1", 0).unwrap();
        ids.next_sequence("topic1", 0).unwrap();
        let snap = ProducerIdentitySnapshot {
            producer_id: 100,
            producer_epoch: 2,
            partition_sequences: vec![entry("topic1", 0, 6, 5), entry("topic2", 1, 10, 9)],
        };
        ids.restore_from_snapshot(&snap);
        assert_eq!(ids.peek_sequence("topic1", 0), 6);
        assert_eq!(ids.peek_sequence("topic2", 1), 10);
    }

    #[test]
    fn test_restore_from_snapshot_clears_poisoned() {
        let ids = initialized(50, 0);
        ids.poison();
        assert!(ids.is_poisoned());
        let empty = ProducerIdentitySnapshot {
            producer_id: 50,
            producer_epoch: 0,
            partition_sequences: Vec::new(),
        };
        ids.restore_from_snapshot(&empty);
        assert!(!ids.is_poisoned());
    }
}