use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::oneshot;
use crate::network::BrokerConnection;
use crate::types::TopicPartitionKey;
use crate::{Error, Result, TransactionStateError};
/// A transaction coordinator's identity bundled with a live connection to it.
pub struct TransactionCoordinatorConnection {
    /// Broker id of the coordinator.
    pub broker_id: i32,
    /// Address of the coordinator broker (format not established in this file).
    pub address: String,
    /// Connection used to send requests to the coordinator.
    pub connection: BrokerConnection,
}
/// Identity of the broker acting as transaction coordinator, without a connection.
#[derive(Debug, Clone)]
pub struct TransactionCoordinator {
    /// Broker id of the coordinator.
    pub broker_id: i32,
    /// Address of the coordinator broker (format not established in this file).
    pub address: String,
}
/// Producer id and epoch currently used by the transactional producer.
#[derive(Debug, Clone, Copy)]
pub struct ProducerIdentity {
    /// Producer id.
    pub id: i64,
    /// Producer epoch paired with `id`.
    pub epoch: i16,
}
/// A requested commit or abort that has been accepted but not yet finished.
pub struct PendingTransactionCompletion {
    /// The outcome that was requested.
    pub result: TransactionResult,
    /// Channel used to report the completion result back to the requester.
    /// `None` for internally queued aborts (e.g. `ensure_shutdown_abort`).
    pub reply: Option<oneshot::Sender<Result<()>>>,
}
/// Tracks the lifecycle of a single transactional id: the current state, the producer
/// identity, per-partition sequence numbers and any in-flight commit/abort.
pub struct TransactionManager {
    /// Transactional id this manager drives.
    pub transactional_id: String,
    /// Configured transaction timeout (stored here; not read within this file).
    pub transaction_timeout: Duration,
    /// Coordinator for `transactional_id`, once known.
    pub coordinator: Option<TransactionCoordinator>,
    /// Producer identity, once assigned.
    pub producer: Option<ProducerIdentity>,
    /// Per-partition sequence numbers; cleared whenever a new producer identity is installed.
    pub sequences: HashMap<TopicPartitionKey, i32>,
    /// Current lifecycle state; only mutated through the methods on this type.
    state: TransactionState,
    /// Commit/abort accepted by `request_completion` or `ensure_shutdown_abort`, if any.
    pub pending_completion: Option<PendingTransactionCompletion>,
}
impl TransactionManager {
    /// Creates a manager for `transactional_id` with no transaction open.
    ///
    /// The manager starts in the `Ready` state with no coordinator, no producer identity,
    /// no tracked partition sequences and no pending completion. `transaction_timeout` is
    /// stored as-is.
    pub fn new(transactional_id: String, transaction_timeout: Duration) -> Self {
        Self {
            transactional_id,
            transaction_timeout,
            coordinator: None,
            producer: None,
            sequences: HashMap::new(),
            state: TransactionState::Ready,
            pending_completion: None,
        }
    }

    /// Opens a new transaction, moving from `Ready` to `InTransaction`.
    ///
    /// # Errors
    ///
    /// - `AlreadyInProgress` if a transaction is already open.
    /// - `Completing` while a commit or abort is in flight.
    /// - `MustAbortBeforeReuse` if the manager has been marked abort-only.
    /// - `Fatal` once the manager has been marked fatal.
    pub fn begin(&mut self) -> std::result::Result<(), TransactionStateError> {
        match &self.state {
            TransactionState::Ready => {
                self.state = TransactionState::InTransaction;
                Ok(())
            }
            TransactionState::InTransaction => Err(TransactionStateError::AlreadyInProgress),
            TransactionState::Completing(result) => {
                Err(TransactionStateError::Completing(result.label()))
            }
            TransactionState::AbortOnly(reason) => {
                Err(TransactionStateError::MustAbortBeforeReuse(reason.clone()))
            }
            TransactionState::Fatal(reason) => Err(TransactionStateError::Fatal(reason.clone())),
        }
    }

    /// Checks that records may be appended, i.e. that a transaction is open.
    ///
    /// # Errors
    ///
    /// `AppendWithoutBegin` in `Ready`; otherwise the shared state errors described on
    /// `ensure_in_transaction`.
    pub fn ensure_append_allowed(&self) -> std::result::Result<(), TransactionStateError> {
        self.ensure_in_transaction(TransactionStateError::AppendWithoutBegin)
    }

    /// Checks that consumer offsets may be added to the transaction, i.e. that one is open.
    ///
    /// # Errors
    ///
    /// `SendOffsetsWithoutBegin` in `Ready`; otherwise the shared state errors described
    /// on `ensure_in_transaction`.
    pub fn ensure_send_offsets_allowed(&self) -> std::result::Result<(), TransactionStateError> {
        self.ensure_in_transaction(TransactionStateError::SendOffsetsWithoutBegin)
    }

    /// Shared guard for operations that require an open transaction.
    ///
    /// Returns `Ok` only in `InTransaction`. `not_begun` is the caller-specific error for
    /// the `Ready` state; the other states map to `Completing`, `AbortRequired` or `Fatal`.
    fn ensure_in_transaction(
        &self,
        not_begun: TransactionStateError,
    ) -> std::result::Result<(), TransactionStateError> {
        match &self.state {
            TransactionState::InTransaction => Ok(()),
            TransactionState::Ready => Err(not_begun),
            TransactionState::Completing(result) => {
                Err(TransactionStateError::Completing(result.label()))
            }
            TransactionState::AbortOnly(reason) => {
                Err(TransactionStateError::AbortRequired(reason.clone()))
            }
            TransactionState::Fatal(reason) => Err(TransactionStateError::Fatal(reason.clone())),
        }
    }

    /// Requests a commit or abort of the current transaction.
    ///
    /// On success the state becomes `Completing(result)` and `result` plus `reply` are
    /// stored in `pending_completion`. From an abort-only state only
    /// [`TransactionResult::Abort`] is accepted. On error, `reply` is dropped without a
    /// value being sent on it.
    ///
    /// # Errors
    ///
    /// - `CommitRequiresAbort` when committing an abort-only transaction.
    /// - `NoActiveTransaction` in `Ready`.
    /// - `Completing` if a completion is already in flight.
    /// - `Fatal` once the manager has been marked fatal.
    pub fn request_completion(
        &mut self,
        result: TransactionResult,
        reply: Option<oneshot::Sender<Result<()>>>,
    ) -> std::result::Result<(), TransactionStateError> {
        // Decide on the accepted outcome first so the borrow of `self.state` ends before
        // the state is mutated.
        let accepted = match (&self.state, result) {
            (TransactionState::InTransaction, requested) => requested,
            (TransactionState::AbortOnly(_), TransactionResult::Abort) => TransactionResult::Abort,
            (TransactionState::AbortOnly(reason), TransactionResult::Commit) => {
                return Err(TransactionStateError::CommitRequiresAbort(reason.clone()));
            }
            (TransactionState::Ready, _) => {
                return Err(TransactionStateError::NoActiveTransaction);
            }
            (TransactionState::Completing(existing), _) => {
                return Err(TransactionStateError::Completing(existing.label()));
            }
            (TransactionState::Fatal(reason), _) => {
                return Err(TransactionStateError::Fatal(reason.clone()));
            }
        };
        self.start_completion(accepted, reply);
        Ok(())
    }

    /// Enters `Completing(result)` and records the pending completion.
    fn start_completion(
        &mut self,
        result: TransactionResult,
        reply: Option<oneshot::Sender<Result<()>>>,
    ) {
        self.state = TransactionState::Completing(result);
        self.pending_completion = Some(PendingTransactionCompletion { result, reply });
    }

    /// Queues an abort, with no reply channel, for an open or abort-only transaction.
    ///
    /// Does nothing if a completion is already pending or if the state is `Ready`,
    /// `Completing` or `Fatal`.
    pub fn ensure_shutdown_abort(&mut self) {
        if self.pending_completion.is_some() {
            return;
        }
        if matches!(
            self.state,
            TransactionState::InTransaction | TransactionState::AbortOnly(_)
        ) {
            self.start_completion(TransactionResult::Abort, None);
        }
    }

    /// Marks the transaction as abort-only with `message` as the reason.
    ///
    /// Ignored once the manager is `Fatal` or while a completion is in flight. It also
    /// applies from `Ready`, after which `begin` is rejected until an abort is requested.
    pub fn mark_abort_only(&mut self, message: String) {
        if matches!(self.state, TransactionState::Fatal(_)) {
            return;
        }
        if !matches!(self.state, TransactionState::Completing(_)) {
            self.state = TransactionState::AbortOnly(message);
        }
    }

    /// Moves the manager into the `Fatal` state with `message` as the reason.
    ///
    /// Any pending completion is resolved with an error carrying `message` instead of
    /// being dropped. Otherwise its waiter would only observe a closed channel and lose
    /// the reason. `pending_completion` is `None` afterwards.
    pub fn mark_fatal(&mut self, message: String) {
        self.fail_pending(&message);
        self.state = TransactionState::Fatal(message);
    }

    /// Records a successful completion and returns to `Ready`.
    ///
    /// Installs `producer`, clears partition sequences, and discards the pending
    /// completion without sending on its reply channel (see `reset_with_new_producer`).
    pub fn finish_success(&mut self, producer: ProducerIdentity) {
        self.reset_with_new_producer(producer);
        self.state = TransactionState::Ready;
    }

    /// Installs `producer` as the current identity and clears per-partition sequences.
    ///
    /// The pending completion is discarded; its reply sender, if any, is dropped without
    /// a value being sent. The transaction state is left unchanged.
    pub fn reset_with_new_producer(&mut self, producer: ProducerIdentity) {
        self.producer = Some(producer);
        self.sequences.clear();
        self.pending_completion = None;
    }

    /// Returns the transactional id while a transaction is open, completing or
    /// abort-only, and `None` in the `Ready` and `Fatal` states.
    pub fn transactional_id_if_active(&self) -> Option<&str> {
        match self.state {
            TransactionState::Ready | TransactionState::Fatal(_) => None,
            TransactionState::InTransaction
            | TransactionState::Completing(_)
            | TransactionState::AbortOnly(_) => Some(self.transactional_id.as_str()),
        }
    }

    /// Reports whether the manager can shut down cleanly.
    ///
    /// # Errors
    ///
    /// `Ok` only in `Ready`; each other state maps to its `Shutdown*` error variant.
    pub fn shutdown_result(&self) -> std::result::Result<(), TransactionStateError> {
        match &self.state {
            TransactionState::Ready => Ok(()),
            TransactionState::InTransaction => {
                Err(TransactionStateError::ShutdownWithActiveTransaction)
            }
            TransactionState::Completing(result) => Err(
                TransactionStateError::ShutdownWhileCompleting(result.label()),
            ),
            TransactionState::AbortOnly(reason) => {
                Err(TransactionStateError::ShutdownAbortRequired(reason.clone()))
            }
            TransactionState::Fatal(reason) => {
                Err(TransactionStateError::ShutdownFatal(reason.clone()))
            }
        }
    }

    /// Takes the pending completion and, if it has a reply channel, sends
    /// `Error::Internal(message)` on it.
    ///
    /// A failed send (the receiver was already dropped) is ignored. The transaction state
    /// is not changed.
    pub fn fail_pending(&mut self, message: &str) {
        if let Some(pending) = self.pending_completion.take()
            && let Some(reply) = pending.reply
        {
            let _ = reply.send(Err(Error::Internal(anyhow::anyhow!(message.to_owned()))));
        }
    }
}
/// Lifecycle state of the transaction owned by a `TransactionManager`.
#[derive(Debug, Clone)]
enum TransactionState {
    /// No transaction is open; `begin` is accepted.
    Ready,
    /// A transaction is open; appends and offset sends are accepted.
    InTransaction,
    /// A commit or abort has been requested and not yet finished.
    Completing(TransactionResult),
    /// The transaction may only be aborted; holds the reason.
    AbortOnly(String),
    /// A fatal error occurred; holds the reason. `begin` and the `ensure_*` guards reject it.
    Fatal(String),
}
/// Outcome requested when completing a transaction.
#[derive(Debug, Clone, Copy)]
pub enum TransactionResult {
    /// Commit the transaction.
    Commit,
    /// Abort the transaction.
    Abort,
}
impl TransactionResult {
fn label(self) -> &'static str {
match self {
Self::Commit => "commit",
Self::Abort => "abort",
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh manager for transactional id `tx-a` with a 30 second timeout.
    fn fresh_manager() -> TransactionManager {
        TransactionManager::new("tx-a".to_owned(), Duration::from_secs(30))
    }

    #[test]
    fn append_requires_active_transaction() {
        let tm = fresh_manager();
        let err = tm.ensure_append_allowed().unwrap_err();
        assert!(matches!(err, TransactionStateError::AppendWithoutBegin));
    }

    #[test]
    fn abort_only_state_rejects_commit_and_allows_abort() {
        let mut tm = fresh_manager();
        tm.begin().unwrap();
        tm.mark_abort_only("boom".to_owned());

        // A commit is refused while the transaction is poisoned...
        let err = tm
            .request_completion(TransactionResult::Commit, None)
            .unwrap_err();
        assert!(matches!(
            err,
            TransactionStateError::CommitRequiresAbort(reason) if reason == "boom"
        ));

        // ...but an abort is accepted.
        tm.request_completion(TransactionResult::Abort, None)
            .unwrap();
    }

    #[test]
    fn successful_completion_resets_state_and_sequences() {
        let mut tm = fresh_manager();
        tm.begin().unwrap();
        let key = TopicPartitionKey::new("topic-a".to_owned(), 0);
        tm.sequences.insert(key, 12);

        tm.finish_success(ProducerIdentity { id: 7, epoch: 2 });

        assert!(tm.transactional_id_if_active().is_none());
        assert!(tm.sequences.is_empty());
        let producer = tm.producer.unwrap();
        assert_eq!(producer.id, 7);
        assert_eq!(producer.epoch, 2);
    }

    #[test]
    fn begin_reports_state_specific_errors() {
        let mut tm = fresh_manager();
        tm.begin().unwrap();
        let err = tm.begin().unwrap_err();
        assert!(matches!(err, TransactionStateError::AlreadyInProgress));

        tm.request_completion(TransactionResult::Commit, None)
            .unwrap();
        let err = tm.begin().unwrap_err();
        assert!(matches!(err, TransactionStateError::Completing("commit")));

        // A brand-new manager can be poisoned before any transaction starts.
        let mut tm = fresh_manager();
        tm.mark_abort_only("abort-me".to_owned());
        let err = tm.begin().unwrap_err();
        assert!(matches!(
            err,
            TransactionStateError::MustAbortBeforeReuse(reason) if reason == "abort-me"
        ));

        tm.mark_fatal("fatal".to_owned());
        let err = tm.begin().unwrap_err();
        assert!(matches!(
            err,
            TransactionStateError::Fatal(reason) if reason == "fatal"
        ));
    }

    #[test]
    fn send_offsets_and_completion_require_active_transaction() {
        let mut tm = fresh_manager();

        // Nothing is allowed before `begin`.
        let err = tm.ensure_send_offsets_allowed().unwrap_err();
        assert!(matches!(err, TransactionStateError::SendOffsetsWithoutBegin));
        let err = tm
            .request_completion(TransactionResult::Commit, None)
            .unwrap_err();
        assert!(matches!(err, TransactionStateError::NoActiveTransaction));

        tm.begin().unwrap();
        assert!(tm.ensure_send_offsets_allowed().is_ok());

        // Once an abort is in flight, both offsets and a second completion are refused.
        tm.request_completion(TransactionResult::Abort, None)
            .unwrap();
        let err = tm.ensure_send_offsets_allowed().unwrap_err();
        assert!(matches!(err, TransactionStateError::Completing("abort")));
        let err = tm
            .request_completion(TransactionResult::Commit, None)
            .unwrap_err();
        assert!(matches!(err, TransactionStateError::Completing("abort")));
    }

    #[test]
    fn abort_only_and_fatal_states_control_append_and_shutdown() {
        let mut tm = fresh_manager();
        tm.begin().unwrap();

        tm.mark_abort_only("bad send".to_owned());
        assert_eq!(tm.transactional_id_if_active(), Some("tx-a"));
        let err = tm.ensure_append_allowed().unwrap_err();
        assert!(matches!(
            err,
            TransactionStateError::AbortRequired(reason) if reason == "bad send"
        ));
        let err = tm.shutdown_result().unwrap_err();
        assert!(matches!(
            err,
            TransactionStateError::ShutdownAbortRequired(reason) if reason == "bad send"
        ));

        tm.mark_fatal("coordinator fenced".to_owned());
        assert!(tm.transactional_id_if_active().is_none());
        let err = tm.ensure_append_allowed().unwrap_err();
        assert!(matches!(
            err,
            TransactionStateError::Fatal(reason) if reason == "coordinator fenced"
        ));
        let err = tm.shutdown_result().unwrap_err();
        assert!(matches!(
            err,
            TransactionStateError::ShutdownFatal(reason) if reason == "coordinator fenced"
        ));
    }

    #[test]
    fn shutdown_abort_and_pending_reply_paths_are_explicit() {
        let mut tm = fresh_manager();

        // Idle manager: shutdown is clean and there is nothing to abort.
        assert!(tm.shutdown_result().is_ok());
        tm.ensure_shutdown_abort();
        assert!(tm.pending_completion.is_none());

        // Open transaction: shutdown is refused until an abort has been queued.
        tm.begin().unwrap();
        let err = tm.shutdown_result().unwrap_err();
        assert!(matches!(
            err,
            TransactionStateError::ShutdownWithActiveTransaction
        ));
        tm.ensure_shutdown_abort();
        let err = tm.shutdown_result().unwrap_err();
        assert!(matches!(
            err,
            TransactionStateError::ShutdownWhileCompleting("abort")
        ));
        let queued = tm.pending_completion.as_ref().map(|p| p.result);
        assert!(matches!(queued, Some(TransactionResult::Abort)));

        // A pending completion's waiter receives the failure message.
        let (sender, receiver) = oneshot::channel();
        tm.pending_completion = Some(PendingTransactionCompletion {
            result: TransactionResult::Commit,
            reply: Some(sender),
        });
        tm.fail_pending("failed completion");
        let error = receiver
            .blocking_recv()
            .expect("reply should be sent")
            .unwrap_err();
        assert!(error.to_string().contains("failed completion"));
        assert!(tm.pending_completion.is_none());
    }
}