#![allow(dead_code)]
use std::collections::HashSet;
use std::sync::Arc;
use bytes::Bytes;
use dashmap::DashMap;
use tokio::sync::{Mutex, RwLock};
use tracing::{info, warn};
use crabka_metadata::MetadataImage;
use crabka_protocol::records::{Record, RecordBatch};
use crate::coordinator::unified::classic_state::OffsetEntry;
use crate::error::BrokerError;
use crate::partition_registry::PartitionRegistry;
use crate::txn::bootstrap;
use crate::txn::partitioner::partition_for_tid;
use crate::txn::state::TxnEntry;
/// Key under which a single buffered offset is stored: a `(String, i32)` pair.
/// NOTE(review): presumably `(topic, partition)` — confirm against the callers
/// that build these keys.
pub(crate) type OffsetKey = (String, i32);
/// Buffered `(OffsetKey, OffsetEntry)` pairs, grouped under a `String` key.
/// `TxnCoordinator::buffer_txn_offsets` uses the consumer group id as that key.
pub(crate) type PendingTxnOffsets =
std::collections::HashMap<String, Vec<(OffsetKey, OffsetEntry)>>;
/// In-memory transaction coordinator state, persisted through the partitions of
/// `bootstrap::TOPIC` (see `put` and `recover`).
pub(crate) struct TxnCoordinator {
/// Id of this node; compared against partition leaders in
/// `refresh_leader_partitions`.
pub(crate) node_id: crabka_metadata::NodeId,
/// Registry used to look up local partition handles for `bootstrap::TOPIC`.
pub(crate) partitions: Arc<PartitionRegistry>,
/// Shared producer-id manager handle. It is stored here but not used by any
/// method visible in this part of the file.
pub(crate) producer_ids: Arc<crate::producer_id_manager::ProducerIdManager>,
/// Transactional id -> shared, lockable transaction entry.
state: DashMap<String, Arc<Mutex<TxnEntry>>>,
/// Partition numbers of `bootstrap::TOPIC` whose leader was this node at the
/// last `refresh_leader_partitions` call.
leader_partitions: RwLock<HashSet<i32>>,
/// Producer id -> transactional id, maintained by `put` and `recover`.
pid_to_tid: DashMap<i64, String>,
/// Producer id -> offsets buffered by `buffer_txn_offsets` until
/// `take_txn_offsets` drains them.
pending_txn_offsets: DashMap<i64, PendingTxnOffsets>,
}
impl TxnCoordinator {
/// Creates a coordinator with no transactions, no led partitions and no
/// buffered offsets.
///
/// `node_id`, `partitions` and `producer_ids` are stored as given.
pub(crate) fn new(
    node_id: crabka_metadata::NodeId,
    partitions: Arc<PartitionRegistry>,
    producer_ids: Arc<crate::producer_id_manager::ProducerIdManager>,
) -> Self {
    // Every lookup table starts out empty; `recover` / `put` populate them.
    let state = DashMap::new();
    let pid_to_tid = DashMap::new();
    let pending_txn_offsets = DashMap::new();
    let leader_partitions = RwLock::new(HashSet::new());

    Self {
        node_id,
        partitions,
        producer_ids,
        state,
        leader_partitions,
        pid_to_tid,
        pending_txn_offsets,
    }
}
/// Appends `entries` to the offsets buffered for `producer_id` under
/// `group_id`.
///
/// An empty `entries` vector is ignored, so no empty per-producer or per-group
/// bucket is ever created.
pub(crate) fn buffer_txn_offsets(
    &self,
    producer_id: i64,
    group_id: &str,
    entries: Vec<(OffsetKey, OffsetEntry)>,
) {
    if !entries.is_empty() {
        // The DashMap guard for this producer lives until the end of this block.
        let mut per_producer = self.pending_txn_offsets.entry(producer_id).or_default();
        let per_group = per_producer.entry(group_id.to_string()).or_default();
        per_group.extend(entries);
    }
}
pub(crate) fn take_txn_offsets(&self, producer_id: i64) -> PendingTxnOffsets {
self.pending_txn_offsets
.remove(&producer_id)
.map(|(_, v)| v)
.unwrap_or_default()
}
/// Recomputes which partitions of `bootstrap::TOPIC` this node leads.
///
/// Scans `image` for partitions of that topic whose leader equals
/// `self.node_id` and replaces the stored set wholesale.
pub(crate) async fn refresh_leader_partitions(&self, image: &MetadataImage) {
    // Build the replacement before taking the write lock, so the lock is held
    // only for the swap itself.
    let led: HashSet<i32> = image
        .partitions_of(bootstrap::TOPIC)
        .into_iter()
        .filter(|info| info.leader == self.node_id)
        .map(|info| info.partition)
        .collect();
    *self.leader_partitions.write().await = led;
}
/// Returns the partition number that `tid` maps to, as computed by
/// `partition_for_tid` over `bootstrap::NUM_PARTITIONS` partitions.
// `self` is unused; the lint is silenced so this stays a method.
#[allow(clippy::unused_self)]
pub(crate) fn partition_for(&self, tid: &str) -> i32 {
    let partition_count = bootstrap::NUM_PARTITIONS;
    partition_for_tid(tid, partition_count)
}
/// Reports whether the partition that `tid` maps to is in the set of
/// partitions this node currently records itself as leading.
pub(crate) async fn is_coordinator_for(&self, tid: &str) -> bool {
    let target = self.partition_for(tid);
    let led = self.leader_partitions.read().await;
    led.contains(&target)
}
/// Looks up the shared entry handle for `tid`.
///
/// Returns a new `Arc` pointing at the stored entry, or `None` when `tid` is
/// not present in the in-memory state.
pub(crate) fn get(&self, tid: &str) -> Option<Arc<Mutex<TxnEntry>>> {
    let slot = self.state.get(tid)?;
    Some(Arc::clone(slot.value()))
}
/// Returns a copy of the transactional id recorded for producer id `pid`, or
/// `None` when no mapping exists.
pub(crate) fn tid_for_pid(&self, pid: i64) -> Option<String> {
    let mapping = self.pid_to_tid.get(&pid)?;
    Some(mapping.value().to_owned())
}
/// Drops the mapping for `entry.prev_producer_id` from `pid_to_tid` when the
/// entry records a producer-id roll.
///
/// Nothing is removed when the previous id is negative or equal to the current
/// `producer_id`.
fn evict_rolled_pid(pid_to_tid: &DashMap<i64, String>, entry: &TxnEntry) {
    let previous = entry.prev_producer_id;
    if previous < 0 || previous == entry.producer_id {
        return;
    }
    pid_to_tid.remove(&previous);
}
/// Returns a clone of every transaction entry currently held in memory.
///
/// The entry handles are collected first and locked one at a time afterwards,
/// so no map iterator is alive while a lock is awaited.
pub(crate) async fn snapshot(&self) -> Vec<TxnEntry> {
    let handles: Vec<Arc<Mutex<TxnEntry>>> = self
        .state
        .iter()
        .map(|slot| Arc::clone(slot.value()))
        .collect();

    let mut entries = Vec::with_capacity(handles.len());
    for handle in handles {
        let copy = handle.lock().await.clone();
        entries.push(copy);
    }
    entries
}
/// Persists `entry` to its partition of `bootstrap::TOPIC`, then publishes it
/// in the in-memory maps.
///
/// The record is keyed by the transactional id; `txnv.flexible_records()` is
/// passed through to the value encoder.
///
/// # Errors
/// Returns `BrokerError::Txn` when the target partition is not in the local
/// registry, and propagates any error from `produce_batch`. In both cases the
/// in-memory maps are left untouched.
pub(crate) async fn put(
    &self,
    entry: TxnEntry,
    txnv: crate::txn::version::TxnVersion,
) -> Result<(), BrokerError> {
    let tid = entry.transactional_id.clone();
    let p = self.partition_for(&tid);
    let Some(part) = self.partitions.get(bootstrap::TOPIC, p) else {
        return Err(BrokerError::Txn(format!(
            "__transaction_state-{p} not local"
        )));
    };

    // A single-record batch carrying the encoded key and value.
    let encoded_key = crate::txn::log_record::encode_key(&tid);
    let encoded_value = crate::txn::log_record::encode_value(&entry, txnv.flexible_records());
    let record = Record {
        offset_delta: 0,
        key: Some(Bytes::from(encoded_key)),
        value: Some(Bytes::from(encoded_value)),
        ..Default::default()
    };
    let mut batch = RecordBatch::default();
    batch.records.push(record);
    batch.last_offset_delta = 0;
    part.produce_batch(batch).await?;

    // Memory is updated only after the append succeeded.
    Self::evict_rolled_pid(&self.pid_to_tid, &entry);
    self.pid_to_tid.insert(entry.producer_id, tid.clone());
    self.state.insert(tid, Arc::new(Mutex::new(entry)));
    Ok(())
}
pub(crate) async fn recover(&self, image: &MetadataImage) -> Result<(), BrokerError> {
self.refresh_leader_partitions(image).await;
let local_partitions: Vec<i32> = self
.leader_partitions
.read()
.await
.iter()
.copied()
.collect();
for p in local_partitions {
let Some(part) = self.partitions.get(bootstrap::TOPIC, p) else {
continue;
};
let mut offset = part.log_start_offset();
loop {
let out = match part.read_log(offset, 1 << 20) {
Ok(o) => o,
Err(e) => {
warn!(
partition = p,
error = %e,
"read error during __transaction_state recovery; skipping partition"
);
break;
}
};
if out.batches.is_empty() {
break;
}
for batch in &out.batches {
for rec in &batch.records {
let Some(key_bytes) = rec.key.as_ref() else {
warn!(
partition = p,
"__transaction_state record missing key; skipping"
);
continue;
};
let tid = match crate::txn::log_record::decode_key(key_bytes) {
Ok(t) => t,
Err(e) => {
warn!(
partition = p,
error = %e,
"invalid TransactionLogKey in __transaction_state; skipping"
);
continue;
}
};
let Some(value_bytes) = rec.value.as_ref() else {
self.state.remove(&tid);
continue;
};
let entry = match crate::txn::log_record::decode_value(value_bytes, tid) {
Ok(e) => e,
Err(e) => {
warn!(
partition = p,
error = %e,
"invalid TransactionLogValue in __transaction_state; skipping"
);
continue;
}
};
self.pid_to_tid
.insert(entry.producer_id, entry.transactional_id.clone());
self.state
.insert(entry.transactional_id.clone(), Arc::new(Mutex::new(entry)));
}
offset = batch.base_offset + i64::from(batch.last_offset_delta) + 1;
}
}
}
info!(
tids_loaded = self.state.len(),
"TxnCoordinator recovery complete"
);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TxnEntry` for "tid-a" with the given current and previous
    /// producer ids.
    fn entry(pid: i64, prev: i64) -> TxnEntry {
        let mut built = TxnEntry::new_empty("tid-a".into(), pid, 0, 60_000, 0);
        built.prev_producer_id = prev;
        built
    }

    /// Returns a pid map that already maps `pid` to "tid-a".
    fn seeded(pid: i64) -> DashMap<i64, String> {
        let pids: DashMap<i64, String> = DashMap::new();
        pids.insert(pid, "tid-a".into());
        pids
    }

    #[test]
    fn evict_rolled_pid_drops_only_the_prior_id_on_a_roll() {
        let pids = seeded(1000);
        TxnCoordinator::evict_rolled_pid(&pids, &entry(2000, 1000));
        pids.insert(2000, "tid-a".into());
        assert!(
            !pids.contains_key(&1000),
            "stale pre-roll pid must be evicted"
        );
        let current = pids.get(&2000).map(|mapping| mapping.value().clone());
        assert_eq!(current, Some(String::from("tid-a")));
    }

    #[test]
    fn evict_rolled_pid_is_noop_without_a_roll() {
        let pids = seeded(1000);
        // Neither "no previous id" (-1) nor "previous == current" is a roll.
        for prev in [-1, 1000] {
            TxnCoordinator::evict_rolled_pid(&pids, &entry(1000, prev));
            assert!(pids.contains_key(&1000));
        }
    }

    #[test]
    fn evict_rolled_pid_is_idempotent_after_the_id_is_gone() {
        let pids = seeded(2000);
        let rolled = entry(2000, 1000);
        TxnCoordinator::evict_rolled_pid(&pids, &rolled);
        TxnCoordinator::evict_rolled_pid(&pids, &rolled);
        assert!(!pids.contains_key(&1000));
        assert!(pids.contains_key(&2000));
    }
}