use std::collections::{BTreeMap, HashSet};
use bytes::{Bytes, BytesMut};
use crabka_protocol::ProtocolError;
use crabka_protocol::primitives::array::{
get_array_len, get_nullable_array_len, put_array_len, put_nullable_array_len,
};
use crabka_protocol::primitives::fixed::{
get_i8, get_i16, get_i32, get_i64, put_i8, put_i16, put_i32, put_i64,
};
use crabka_protocol::primitives::string_bytes::{
get_compact_string_owned, get_string_owned, put_compact_string, put_string,
};
use crabka_protocol::tagged_fields::{UnknownTaggedFields, WriteTaggedFields, read_tagged_fields};
use crate::error::BrokerError;
use crate::txn::state::{TopicPartition, TxnEntry, TxnState};
/// Tagged-field tag carrying the previous producer id in a flexible (v1)
/// `TransactionLogValue`; written by `encode_value` only when the id is not
/// `PRODUCER_ID_NONE`.
const TAG_PREV_PRODUCER_ID: u32 = 0;
/// Tagged-field tag carrying the next producer id in a flexible (v1)
/// `TransactionLogValue`; written by `encode_value` only when the id is not
/// `PRODUCER_ID_NONE`.
const TAG_NEXT_PRODUCER_ID: u32 = 1;
/// Tagged-field tag for the client transaction version (an `i16`).
/// `decode_value` reads and discards it; `encode_value` never writes it.
const TAG_CLIENT_TXN_VERSION: u32 = 2;
/// Sentinel producer id meaning "none". Prev/next producer ids equal to this
/// are omitted from the tagged section on encode and are the values used on
/// decode when the corresponding tag is absent (always the case for v0).
const PRODUCER_ID_NONE: i64 = -1;
/// Groups `partitions` by topic for serialisation.
///
/// Returns one `(topic, partition_ids)` pair per distinct topic. Topics come
/// out in ascending `str` order and each topic's ids are sorted ascending, so
/// the result is independent of `HashSet` iteration order. Topic names are
/// borrowed from `partitions`.
fn group_partitions(partitions: &HashSet<TopicPartition>) -> Vec<(&str, Vec<i32>)> {
    let mut grouped = partitions.iter().fold(
        BTreeMap::<&str, Vec<i32>>::new(),
        |mut acc, tp| {
            acc.entry(&tp.topic).or_default().push(tp.partition);
            acc
        },
    );
    // BTreeMap already orders the topics; only the ids need sorting.
    grouped.values_mut().for_each(|ids| ids.sort_unstable());
    grouped.into_iter().collect()
}
/// Serialises a [`TxnEntry`] as a `TransactionLogValue` record body.
///
/// - `flexible == false` writes version 0: classic strings and arrays, no
///   tagged fields.
/// - `flexible == true` writes version 1: compact strings and arrays plus
///   tagged-field sections.
///
/// Partitions are written grouped by topic via `group_partitions`, so the
/// bytes do not depend on `HashSet` iteration order. An empty partition set is
/// written as a null array.
///
/// In version 1, the previous and next producer ids are emitted as tagged
/// fields 0 and 1, each only when it differs from `PRODUCER_ID_NONE`.
///
/// `entry.transactional_id` is not written here; `encode_key` serialises it
/// into the record key.
pub(crate) fn encode_value(entry: &TxnEntry, flexible: bool) -> Vec<u8> {
    let mut out = BytesMut::new();

    // Version: 0 = classic, 1 = flexible.
    put_i16(&mut out, i16::from(flexible));
    put_i64(&mut out, entry.producer_id);
    put_i16(&mut out, entry.producer_epoch);
    put_i32(&mut out, entry.txn_timeout_ms);
    put_i8(&mut out, entry.state.to_kafka_status());

    let topics = group_partitions(&entry.partitions);
    // No partitions => null array rather than an empty one.
    let topic_count = (!topics.is_empty()).then_some(topics.len());
    put_nullable_array_len(&mut out, topic_count, flexible);
    for (topic, partition_ids) in topics {
        if flexible {
            put_compact_string(&mut out, topic);
        } else {
            put_string(&mut out, topic);
        }
        put_array_len(&mut out, partition_ids.len(), flexible);
        for id in partition_ids {
            put_i32(&mut out, id);
        }
        if flexible {
            // Per-topic tagged-field section: always empty.
            WriteTaggedFields::new().write(&mut out, &UnknownTaggedFields::default());
        }
    }

    put_i64(&mut out, entry.last_update_ms);
    put_i64(&mut out, entry.start_ms);

    // Version 0 has no value-level tagged-field section.
    if !flexible {
        return out.to_vec();
    }

    let mut tagged = WriteTaggedFields::new();
    // Listed in ascending tag order.
    let optional_ids = [
        (TAG_PREV_PRODUCER_ID, entry.prev_producer_id),
        (TAG_NEXT_PRODUCER_ID, entry.next_producer_id),
    ];
    for (tag, producer_id) in optional_ids {
        if producer_id != PRODUCER_ID_NONE {
            tagged.add(tag, i64_to_bytes(producer_id));
        }
    }
    tagged.write(&mut out, &UnknownTaggedFields::default());
    out.to_vec()
}
fn i64_to_bytes(v: i64) -> Bytes {
let mut b = BytesMut::with_capacity(8);
put_i64(&mut b, v);
b.freeze()
}
/// Parses a `TransactionLogValue` record body into a [`TxnEntry`].
///
/// Version 0 (classic encoding) and version 1 (compact strings/arrays and
/// tagged-field sections) are accepted. `transactional_id` normally comes from
/// the record key and is stored on the returned entry unchanged. A null
/// partitions array decodes to an empty set.
///
/// Tagged fields (version 1 only):
/// - tags 0 and 1 fill in the previous/next producer ids; when absent, both
///   stay `PRODUCER_ID_NONE`;
/// - tag 2 (client transaction version) is read and discarded;
/// - any other tag is reported back to `read_tagged_fields` as unhandled.
///
/// # Errors
/// Returns `BrokerError::Protocol` when:
/// - the input is truncated;
/// - the version or transaction status is unknown;
/// - bytes remain after the last field.
pub(crate) fn decode_value(
    bytes: &[u8],
    transactional_id: String,
) -> Result<TxnEntry, BrokerError> {
    fn invalid(reason: &'static str) -> BrokerError {
        BrokerError::Protocol(ProtocolError::InvalidValue(reason))
    }

    let mut cursor = bytes;
    let flexible = match get_i16(&mut cursor)? {
        0 => false,
        1 => true,
        _ => return Err(invalid("unsupported TransactionLogValue version")),
    };

    let producer_id = get_i64(&mut cursor)?;
    let producer_epoch = get_i16(&mut cursor)?;
    let txn_timeout_ms = get_i32(&mut cursor)?;
    let state = TxnState::from_kafka_status(get_i8(&mut cursor)?)
        .ok_or_else(|| invalid("unknown TransactionStatus"))?;

    // A null array and an empty array both yield an empty partition set.
    let topic_count = get_nullable_array_len(&mut cursor, flexible)?.unwrap_or(0);
    let mut partitions = HashSet::new();
    for _ in 0..topic_count {
        let topic = if flexible {
            get_compact_string_owned(&mut cursor)?
        } else {
            get_string_owned(&mut cursor)?
        };
        for _ in 0..get_array_len(&mut cursor, flexible)? {
            let partition = get_i32(&mut cursor)?;
            partitions.insert(TopicPartition {
                topic: topic.clone(),
                partition,
            });
        }
        if flexible {
            // Per-topic tagged fields: none are recognised.
            read_tagged_fields(&mut cursor, |_, _| Ok(false))?;
        }
    }

    let last_update_ms = get_i64(&mut cursor)?;
    let start_ms = get_i64(&mut cursor)?;

    let (mut prev_producer_id, mut next_producer_id) = (PRODUCER_ID_NONE, PRODUCER_ID_NONE);
    if flexible {
        read_tagged_fields(&mut cursor, |tag, payload| {
            match tag {
                TAG_PREV_PRODUCER_ID => prev_producer_id = get_i64(payload)?,
                TAG_NEXT_PRODUCER_ID => next_producer_id = get_i64(payload)?,
                TAG_CLIENT_TXN_VERSION => {
                    // Parsed so a malformed payload is rejected; not stored.
                    let _ = get_i16(payload)?;
                }
                _ => return Ok(false),
            }
            Ok(true)
        })?;
    }

    if !cursor.is_empty() {
        return Err(invalid("TransactionLogValue: trailing bytes after decode"));
    }

    Ok(TxnEntry {
        transactional_id,
        producer_id,
        producer_epoch,
        state,
        txn_timeout_ms,
        partitions,
        prev_producer_id,
        next_producer_id,
        last_update_ms,
        start_ms,
    })
}
/// Serialises a `TransactionLogKey`: an `i16` version (always 0) followed by
/// `transactional_id` as a classic string with an `i16` length prefix.
pub(crate) fn encode_key(transactional_id: &str) -> Vec<u8> {
    const KEY_VERSION: i16 = 0;
    // Version (2) + string length prefix (2) + id bytes; only a sizing hint.
    let mut out = BytesMut::with_capacity(2 + 2 + transactional_id.len());
    put_i16(&mut out, KEY_VERSION);
    put_string(&mut out, transactional_id);
    out[..].to_vec()
}
/// Parses a `TransactionLogKey` and returns the transactional id it carries.
///
/// # Errors
/// Returns `BrokerError::Protocol` when:
/// - the input is truncated;
/// - the key version is not 0;
/// - bytes remain after the transactional id.
pub(crate) fn decode_key(bytes: &[u8]) -> Result<String, BrokerError> {
    let mut cursor = bytes;
    let transactional_id = match get_i16(&mut cursor)? {
        0 => get_string_owned(&mut cursor)?,
        _ => {
            return Err(BrokerError::Protocol(ProtocolError::InvalidValue(
                "unsupported TransactionLogKey version",
            )));
        }
    };
    if cursor.is_empty() {
        Ok(transactional_id)
    } else {
        Err(BrokerError::Protocol(ProtocolError::InvalidValue(
            "TransactionLogKey: trailing bytes after decode",
        )))
    }
}
#[cfg(test)]
mod tests {
    use assert2::assert;

    use super::*;

    /// Byte offset of the `TransactionPartitions` array in an encoded value:
    /// version (2) + producer_id (8) + producer_epoch (2) + txn_timeout_ms (4)
    /// + status (1).
    const PARTITIONS_OFFSET: usize = 17;

    /// A flexible (version 1) `TransactionLogValue` for an `Ongoing`
    /// transaction on the single partition `txtest`/0. The previous/next
    /// producer ids are `-1`, so the value-level tagged section is empty.
    #[rustfmt::skip]
    const SAMPLE: &[u8] = &[
        0x00, 0x01,                                     // version = 1
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // producer_id = 0
        0x00, 0x00,                                     // producer_epoch = 0
        0x00, 0x00, 0xea, 0x60,                         // txn_timeout_ms = 60_000
        0x01,                                           // status = 1 (Ongoing)
        0x02,                                           // compact array length: 1 topic
        0x07, b't', b'x', b't', b'e', b's', b't',       // compact string "txtest"
        0x02,                                           // compact array length: 1 partition id
        0x00, 0x00, 0x00, 0x00,                         // partition 0
        0x00,                                           // per-topic tagged fields: none
        0x00, 0x00, 0x01, 0x9e, 0x7b, 0x4b, 0x36, 0x7a, // last_update_ms = SAMPLE_TS
        0x00, 0x00, 0x01, 0x9e, 0x7b, 0x4b, 0x36, 0x7a, // start_ms = SAMPLE_TS
        0x00,                                           // value tagged fields: none
    ];
    const SAMPLE_TS: i64 = 0x0000_019e_7b4b_367a;

    /// The entry that `SAMPLE` encodes.
    fn sample_entry() -> TxnEntry {
        let mut partitions = HashSet::new();
        partitions.insert(TopicPartition {
            topic: "txtest".into(),
            partition: 0,
        });
        TxnEntry {
            transactional_id: "my-txn-id".into(),
            producer_id: 0,
            producer_epoch: 0,
            state: TxnState::Ongoing,
            txn_timeout_ms: 60_000,
            partitions,
            prev_producer_id: -1,
            next_producer_id: -1,
            last_update_ms: SAMPLE_TS,
            start_ms: SAMPLE_TS,
        }
    }

    #[test]
    fn sample_bytes_decode() {
        let entry = decode_value(SAMPLE, "my-txn-id".into()).unwrap();
        assert!(entry.transactional_id == "my-txn-id");
        assert!(entry.producer_id == 0);
        assert!(entry.producer_epoch == 0);
        assert!(entry.txn_timeout_ms == 60_000);
        assert!(entry.state == TxnState::Ongoing);
        assert!(entry.prev_producer_id == -1);
        assert!(entry.next_producer_id == -1);
        assert!(entry.last_update_ms == SAMPLE_TS);
        assert!(entry.start_ms == SAMPLE_TS);
        let expected: HashSet<TopicPartition> = [TopicPartition {
            topic: "txtest".into(),
            partition: 0,
        }]
        .into_iter()
        .collect();
        assert!(entry.partitions == expected);
    }

    #[test]
    fn sample_bytes_encode_byte_identical() {
        let encoded = encode_value(&sample_entry(), true);
        assert!(
            encoded == SAMPLE,
            "encode_value did not byte-match SAMPLE\n expected: {:02x?}\n actual: {:02x?}",
            SAMPLE,
            encoded
        );
    }

    #[test]
    fn v1_round_trip_multi_topic_nondefault_ids() {
        let mut partitions = HashSet::new();
        partitions.insert(TopicPartition {
            topic: "zebra".into(),
            partition: 5,
        });
        partitions.insert(TopicPartition {
            topic: "zebra".into(),
            partition: 1,
        });
        partitions.insert(TopicPartition {
            topic: "alpha".into(),
            partition: 3,
        });
        let entry = TxnEntry {
            transactional_id: "tid".into(),
            producer_id: 42,
            producer_epoch: 7,
            state: TxnState::PrepareCommit,
            txn_timeout_ms: 30_000,
            partitions,
            prev_producer_id: 100,
            next_producer_id: 200,
            last_update_ms: 1_234_567,
            start_ms: 1_000_000,
        };
        let first = encode_value(&entry, true);
        let decoded = decode_value(&first, "tid".into()).unwrap();
        assert!(decoded.producer_id == entry.producer_id);
        assert!(decoded.producer_epoch == entry.producer_epoch);
        assert!(decoded.state == entry.state);
        assert!(decoded.txn_timeout_ms == entry.txn_timeout_ms);
        assert!(decoded.prev_producer_id == 100);
        assert!(decoded.next_producer_id == 200);
        assert!(decoded.last_update_ms == entry.last_update_ms);
        assert!(decoded.start_ms == entry.start_ms);
        assert!(decoded.partitions == entry.partitions);
        let second = encode_value(&decoded, true);
        assert!(first == second);
    }

    #[test]
    fn v0_round_trip_no_tagged_section() {
        let mut partitions = HashSet::new();
        partitions.insert(TopicPartition {
            topic: "t".into(),
            partition: 0,
        });
        let entry = TxnEntry {
            transactional_id: "tid".into(),
            producer_id: 9,
            producer_epoch: 2,
            state: TxnState::Ongoing,
            txn_timeout_ms: 60_000,
            partitions,
            prev_producer_id: 5,
            next_producer_id: 6,
            last_update_ms: 111,
            start_ms: 222,
        };
        let encoded = encode_value(&entry, false);
        assert!(encoded[0] == 0x00 && encoded[1] == 0x00);
        let decoded = decode_value(&encoded, "tid".into()).unwrap();
        assert!(decoded.producer_id == 9);
        assert!(decoded.state == TxnState::Ongoing);
        assert!(decoded.partitions == entry.partitions);
        assert!(decoded.last_update_ms == 111);
        assert!(decoded.start_ms == 222);
        // v0 has no tagged section, so prev/next ids are not carried over.
        assert!(decoded.prev_producer_id == -1);
        assert!(decoded.next_producer_id == -1);
    }

    #[test]
    fn key_round_trip() {
        let encoded = encode_key("abc");
        assert!(decode_key(&encoded).unwrap() == "abc");
        assert!(encoded == &[0x00, 0x00, 0x00, 0x03, b'a', b'b', b'c']);
    }

    #[test]
    fn encode_is_deterministic_across_hashset_orders() {
        let make = |order: &[(&str, i32)]| {
            let mut partitions = HashSet::new();
            for (t, p) in order {
                partitions.insert(TopicPartition {
                    topic: (*t).into(),
                    partition: *p,
                });
            }
            TxnEntry {
                transactional_id: "tid".into(),
                producer_id: 1,
                producer_epoch: 0,
                state: TxnState::Ongoing,
                txn_timeout_ms: 60_000,
                partitions,
                prev_producer_id: -1,
                next_producer_id: -1,
                last_update_ms: 1,
                start_ms: 1,
            }
        };
        let a = make(&[("b", 2), ("a", 1), ("b", 0), ("a", 3)]);
        let b = make(&[("a", 3), ("b", 0), ("a", 1), ("b", 2)]);
        assert!(encode_value(&a, true) == encode_value(&b, true));
        assert!(encode_value(&a, false) == encode_value(&b, false));
    }

    #[test]
    fn decode_value_rejects_truncated_input() {
        assert!(decode_value(&SAMPLE[..10], "t".into()).is_err());
        assert!(decode_value(&SAMPLE[..1], "t".into()).is_err());
        assert!(decode_value(&[], "t".into()).is_err());
    }

    #[test]
    fn decode_value_rejects_unknown_version() {
        let mut bad = SAMPLE.to_vec();
        bad[0] = 0x00;
        bad[1] = 0x02;
        assert!(decode_value(&bad, "t".into()).is_err());
    }

    #[test]
    fn decode_value_rejects_trailing_bytes() {
        let mut extra = SAMPLE.to_vec();
        extra.push(0xff);
        assert!(decode_value(&extra, "t".into()).is_err());
    }

    #[test]
    fn decode_value_accepts_client_txn_version_tag() {
        // Replace SAMPLE's empty value-level tagged section (final 0x00) with one
        // field: count = 1, tag = 2 (client txn version), size = 2, value = 2.
        let mut bytes = SAMPLE[..SAMPLE.len() - 1].to_vec();
        bytes.extend_from_slice(&[0x01, 0x02, 0x02, 0x00, 0x02]);
        let entry = decode_value(&bytes, "my-txn-id".into()).unwrap();
        assert!(entry.state == TxnState::Ongoing);
        assert!(entry.prev_producer_id == -1);
        assert!(entry.next_producer_id == -1);
        assert!(entry.start_ms == SAMPLE_TS);
        assert!(entry.partitions == sample_entry().partitions);
    }

    #[test]
    fn decode_key_rejects_unknown_version_and_truncation() {
        let key = encode_key("abc");
        let mut bad = key.clone();
        bad[1] = 0x09;
        assert!(decode_key(&bad).is_err());
        assert!(decode_key(&key[..1]).is_err());
    }

    #[test]
    fn empty_partitions_round_trips_as_null_both_versions() {
        // Sanity-check the offset against the fixture: SAMPLE holds its
        // one-topic compact array length (0x02) there.
        assert!(SAMPLE[PARTITIONS_OFFSET] == 0x02);
        let e = TxnEntry::new_empty("tid".into(), 5, 0, 30_000, 100);
        for flexible in [false, true] {
            let bytes = encode_value(&e, flexible);
            // Null array marker: INT32 -1 in v0, compact length 0 in v1.
            let null_marker: &[u8] = if flexible {
                &[0x00]
            } else {
                &[0xff, 0xff, 0xff, 0xff]
            };
            let end = PARTITIONS_OFFSET + null_marker.len();
            assert!(&bytes[PARTITIONS_OFFSET..end] == null_marker);
            let decoded = decode_value(&bytes, "tid".into()).expect("decode");
            assert!(decoded.partitions.is_empty());
            assert!(decoded.producer_id == 5);
        }
    }
}