Skip to main content

laminar_core/sink/
transaction.rs

1//! Transaction management for exactly-once sinks
2
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use super::error::SinkError;
6use super::traits::{SinkState, TransactionId};
7use crate::operator::Output;
8
9/// State machine for transaction lifecycle
10#[derive(Debug)]
11pub struct TransactionState {
12    /// Current transaction ID (if any)
13    current_tx: Option<TransactionId>,
14
15    /// Current state
16    state: SinkState,
17
18    /// Transaction counter for ID generation
19    next_tx_id: AtomicU64,
20
21    /// Number of writes in current transaction
22    write_count: u64,
23
24    /// Number of records in current transaction
25    record_count: u64,
26}
27
28impl TransactionState {
29    /// Create a new transaction state
30    #[must_use]
31    pub fn new() -> Self {
32        Self {
33            current_tx: None,
34            state: SinkState::Idle,
35            next_tx_id: AtomicU64::new(1),
36            write_count: 0,
37            record_count: 0,
38        }
39    }
40
41    /// Check if the state is idle (no active transaction)
42    #[must_use]
43    pub fn is_idle(&self) -> bool {
44        self.state == SinkState::Idle
45    }
46
47    /// Check if there's an active transaction
48    #[must_use]
49    pub fn is_active(&self) -> bool {
50        self.state == SinkState::InTransaction
51    }
52
53    /// Get the current state
54    #[must_use]
55    pub fn state(&self) -> SinkState {
56        self.state
57    }
58
59    /// Get the current transaction ID
60    #[must_use]
61    pub fn current_transaction(&self) -> Option<&TransactionId> {
62        self.current_tx.as_ref()
63    }
64
65    /// Get the write count for the current transaction
66    #[must_use]
67    pub fn write_count(&self) -> u64 {
68        self.write_count
69    }
70
71    /// Get the record count for the current transaction
72    #[must_use]
73    pub fn record_count(&self) -> u64 {
74        self.record_count
75    }
76
77    /// Begin a new transaction
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if a transaction is already active.
82    pub fn begin(&mut self, tx_id: TransactionId) -> Result<(), SinkError> {
83        if !self.state.can_begin_transaction() {
84            return Err(SinkError::TransactionAlreadyActive(
85                self.current_tx
86                    .as_ref()
87                    .map_or_else(|| "unknown".to_string(), ToString::to_string),
88            ));
89        }
90
91        self.current_tx = Some(tx_id);
92        self.state = SinkState::InTransaction;
93        self.write_count = 0;
94        self.record_count = 0;
95
96        Ok(())
97    }
98
99    /// Generate a new transaction ID and begin
100    ///
101    /// # Errors
102    ///
103    /// Returns an error if a transaction is already active.
104    pub fn begin_new(&mut self) -> Result<TransactionId, SinkError> {
105        let tx_id = TransactionId::new(self.next_tx_id.fetch_add(1, Ordering::SeqCst));
106        self.begin(tx_id.clone())?;
107        Ok(tx_id)
108    }
109
110    /// Record a write operation
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if there's no active transaction.
115    pub fn record_write(&mut self, record_count: u64) -> Result<(), SinkError> {
116        if !self.is_active() {
117            return Err(SinkError::NoActiveTransaction);
118        }
119
120        self.write_count += 1;
121        self.record_count += record_count;
122
123        Ok(())
124    }
125
126    /// Commit the current transaction
127    ///
128    /// # Errors
129    ///
130    /// Returns an error if the transaction ID doesn't match or if
131    /// there's no active transaction.
132    pub fn commit(&mut self, tx_id: &TransactionId) -> Result<(), SinkError> {
133        self.validate_tx_id(tx_id)?;
134
135        if !self.state.can_commit() {
136            return Err(SinkError::NoActiveTransaction);
137        }
138
139        self.state = SinkState::Committing;
140        // Reset after successful commit
141        self.current_tx = None;
142        self.state = SinkState::Idle;
143        self.write_count = 0;
144        self.record_count = 0;
145
146        Ok(())
147    }
148
149    /// Rollback the current transaction
150    ///
151    /// # Errors
152    ///
153    /// Returns an error if the transaction ID doesn't match.
154    pub fn rollback(&mut self, tx_id: &TransactionId) -> Result<(), SinkError> {
155        self.validate_tx_id(tx_id)?;
156
157        self.current_tx = None;
158        self.state = SinkState::Idle;
159        self.write_count = 0;
160        self.record_count = 0;
161
162        Ok(())
163    }
164
165    /// Force rollback without validation (for recovery)
166    pub fn force_rollback(&mut self) {
167        self.current_tx = None;
168        self.state = SinkState::Idle;
169        self.write_count = 0;
170        self.record_count = 0;
171    }
172
173    /// Mark the state as error
174    pub fn mark_error(&mut self) {
175        self.state = SinkState::Error;
176    }
177
178    /// Validate that the given transaction ID matches the current one
179    fn validate_tx_id(&self, tx_id: &TransactionId) -> Result<(), SinkError> {
180        match &self.current_tx {
181            Some(current) if current == tx_id => Ok(()),
182            Some(current) => Err(SinkError::TransactionIdMismatch {
183                expected: current.to_string(),
184                actual: tx_id.to_string(),
185            }),
186            None => Err(SinkError::NoActiveTransaction),
187        }
188    }
189}
190
191impl Default for TransactionState {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197/// Coordinator for two-phase commit across multiple sinks
198pub struct TransactionCoordinator {
199    /// Current transaction ID
200    current_tx: Option<TransactionId>,
201
202    /// Transaction counter
203    next_tx_id: AtomicU64,
204
205    /// Participating sinks (by ID)
206    participants: Vec<String>,
207
208    /// Sinks that have voted to prepare
209    prepared: Vec<String>,
210
211    /// Sinks that have committed
212    committed: Vec<String>,
213}
214
215impl TransactionCoordinator {
216    /// Create a new transaction coordinator
217    #[must_use]
218    pub fn new() -> Self {
219        Self {
220            current_tx: None,
221            next_tx_id: AtomicU64::new(1),
222            participants: Vec::new(),
223            prepared: Vec::new(),
224            committed: Vec::new(),
225        }
226    }
227
228    /// Register a participant sink
229    pub fn register_participant(&mut self, sink_id: String) {
230        if !self.participants.contains(&sink_id) {
231            self.participants.push(sink_id);
232        }
233    }
234
235    /// Begin a new coordinated transaction
236    ///
237    /// # Errors
238    ///
239    /// Returns an error if a transaction is already active.
240    pub fn begin(&mut self) -> Result<TransactionId, SinkError> {
241        if let Some(ref tx) = self.current_tx {
242            return Err(SinkError::TransactionAlreadyActive(tx.to_string()));
243        }
244
245        let tx_id = TransactionId::new(self.next_tx_id.fetch_add(1, Ordering::SeqCst));
246        self.current_tx = Some(tx_id.clone());
247        self.prepared.clear();
248        self.committed.clear();
249
250        Ok(tx_id)
251    }
252
253    /// Record that a sink has prepared
254    ///
255    /// # Errors
256    ///
257    /// Returns an error if the sink is not registered.
258    pub fn mark_prepared(&mut self, sink_id: &str) -> Result<(), SinkError> {
259        if !self.participants.contains(&sink_id.to_string()) {
260            return Err(SinkError::ConfigurationError(format!(
261                "Unknown sink: {sink_id}"
262            )));
263        }
264
265        if !self.prepared.contains(&sink_id.to_string()) {
266            self.prepared.push(sink_id.to_string());
267        }
268
269        Ok(())
270    }
271
272    /// Check if all participants are prepared
273    #[must_use]
274    pub fn all_prepared(&self) -> bool {
275        self.prepared.len() == self.participants.len()
276    }
277
278    /// Record that a sink has committed
279    pub fn mark_committed(&mut self, sink_id: &str) {
280        if !self.committed.contains(&sink_id.to_string()) {
281            self.committed.push(sink_id.to_string());
282        }
283    }
284
285    /// Check if all participants have committed
286    #[must_use]
287    pub fn all_committed(&self) -> bool {
288        self.committed.len() == self.participants.len()
289    }
290
291    /// Complete the transaction
292    pub fn complete(&mut self) {
293        self.current_tx = None;
294        self.prepared.clear();
295        self.committed.clear();
296    }
297
298    /// Get the current transaction ID
299    #[must_use]
300    pub fn current_transaction(&self) -> Option<&TransactionId> {
301        self.current_tx.as_ref()
302    }
303
304    /// Get the list of participants
305    #[must_use]
306    pub fn participants(&self) -> &[String] {
307        &self.participants
308    }
309}
310
311impl Default for TransactionCoordinator {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
317/// Trait for sinks that support two-phase commit
318pub trait TwoPhaseCommitSink: Send {
319    /// Prepare the transaction for commit (2PC phase 1)
320    ///
321    /// After prepare, the sink should be ready to commit atomically.
322    ///
323    /// # Errors
324    ///
325    /// Returns an error if the sink cannot prepare.
326    fn prepare(&mut self, tx_id: &TransactionId) -> Result<(), SinkError>;
327
328    /// Commit a prepared transaction (2PC phase 2)
329    ///
330    /// # Errors
331    ///
332    /// Returns an error if the commit fails.
333    fn commit_prepared(&mut self, tx_id: &TransactionId) -> Result<(), SinkError>;
334
335    /// Abort a prepared transaction
336    ///
337    /// # Errors
338    ///
339    /// Returns an error if the abort fails.
340    fn abort_prepared(&mut self, tx_id: &TransactionId) -> Result<(), SinkError>;
341
342    /// Recover pending transactions after restart
343    ///
344    /// Returns a list of transaction IDs that were prepared but not committed.
345    ///
346    /// # Errors
347    ///
348    /// Returns an error if recovery fails.
349    fn recover_pending(&mut self) -> Result<Vec<TransactionId>, SinkError>;
350}
351
352/// Buffer for transactional writes
353#[derive(Debug, Default)]
354#[allow(dead_code)] // Public API for Phase 3 connector implementations
355pub struct TransactionBuffer {
356    /// Buffered outputs
357    outputs: Vec<Output>,
358
359    /// Total size estimate
360    size_bytes: usize,
361}
362
363#[allow(dead_code)] // Public API for Phase 3 connector implementations
364impl TransactionBuffer {
365    /// Create a new transaction buffer
366    #[must_use]
367    pub fn new() -> Self {
368        Self::default()
369    }
370
371    /// Create with pre-allocated capacity
372    #[must_use]
373    pub fn with_capacity(capacity: usize) -> Self {
374        Self {
375            outputs: Vec::with_capacity(capacity),
376            size_bytes: 0,
377        }
378    }
379
380    /// Add outputs to the buffer
381    pub fn push(&mut self, outputs: Vec<Output>) {
382        // Estimate size (rough approximation)
383        for output in &outputs {
384            self.size_bytes += match output {
385                Output::Event(e) => e.data.get_array_memory_size(),
386                Output::Changelog(c) => c.event.data.get_array_memory_size() + 32,
387                _ => 32,
388            };
389        }
390        self.outputs.extend(outputs);
391    }
392
393    /// Get the number of outputs
394    #[must_use]
395    pub fn len(&self) -> usize {
396        self.outputs.len()
397    }
398
399    /// Check if the buffer is empty
400    #[must_use]
401    pub fn is_empty(&self) -> bool {
402        self.outputs.is_empty()
403    }
404
405    /// Get the estimated size in bytes
406    #[must_use]
407    pub fn size_bytes(&self) -> usize {
408        self.size_bytes
409    }
410
411    /// Take all buffered outputs
412    pub fn take(&mut self) -> Vec<Output> {
413        self.size_bytes = 0;
414        std::mem::take(&mut self.outputs)
415    }
416
417    /// Clear the buffer
418    pub fn clear(&mut self) {
419        self.outputs.clear();
420        self.size_bytes = 0;
421    }
422
423    /// Get a reference to the buffered outputs
424    #[must_use]
425    pub fn outputs(&self) -> &[Output] {
426        &self.outputs
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transaction_state_new() {
        let state = TransactionState::new();
        assert!(state.is_idle());
        assert!(!state.is_active());
        assert!(state.current_transaction().is_none());
    }

    #[test]
    fn test_transaction_state_begin() {
        let mut state = TransactionState::new();
        let tx_id = TransactionId::new(1);

        state.begin(tx_id.clone()).unwrap();

        assert!(state.is_active());
        assert!(!state.is_idle());
        assert_eq!(state.current_transaction(), Some(&tx_id));
    }

    #[test]
    fn test_transaction_state_begin_new() {
        let mut state = TransactionState::new();

        let tx1 = state.begin_new().unwrap();
        state.commit(&tx1).unwrap();

        let tx2 = state.begin_new().unwrap();
        assert_ne!(tx1.id(), tx2.id());
    }

    #[test]
    fn test_transaction_state_double_begin() {
        let mut state = TransactionState::new();
        state.begin(TransactionId::new(1)).unwrap();

        let result = state.begin(TransactionId::new(2));
        assert!(matches!(
            result,
            Err(SinkError::TransactionAlreadyActive(_))
        ));
    }

    #[test]
    fn test_transaction_state_commit() {
        let mut state = TransactionState::new();
        let tx_id = TransactionId::new(1);

        state.begin(tx_id.clone()).unwrap();
        state.commit(&tx_id).unwrap();

        assert!(state.is_idle());
        assert!(state.current_transaction().is_none());
    }

    #[test]
    fn test_transaction_state_commit_wrong_id() {
        let mut state = TransactionState::new();
        state.begin(TransactionId::new(1)).unwrap();

        let result = state.commit(&TransactionId::new(2));
        assert!(matches!(
            result,
            Err(SinkError::TransactionIdMismatch { .. })
        ));
    }

    #[test]
    fn test_transaction_state_commit_without_active() {
        let mut state = TransactionState::new();

        let result = state.commit(&TransactionId::new(1));
        assert!(matches!(result, Err(SinkError::NoActiveTransaction)));
        assert!(state.is_idle());
    }

    #[test]
    fn test_transaction_state_rollback() {
        let mut state = TransactionState::new();
        let tx_id = TransactionId::new(1);

        state.begin(tx_id.clone()).unwrap();
        state.record_write(100).unwrap();
        state.rollback(&tx_id).unwrap();

        assert!(state.is_idle());
        assert_eq!(state.write_count(), 0);
        assert_eq!(state.record_count(), 0);
    }

    #[test]
    fn test_transaction_state_rollback_without_active() {
        let mut state = TransactionState::new();

        let result = state.rollback(&TransactionId::new(1));
        assert!(matches!(result, Err(SinkError::NoActiveTransaction)));
    }

    #[test]
    fn test_transaction_state_rollback_wrong_id_keeps_transaction() {
        let mut state = TransactionState::new();
        let tx_id = TransactionId::new(1);
        state.begin(tx_id.clone()).unwrap();

        let result = state.rollback(&TransactionId::new(2));
        assert!(matches!(
            result,
            Err(SinkError::TransactionIdMismatch { .. })
        ));
        assert!(state.is_active());
        assert_eq!(state.current_transaction(), Some(&tx_id));
    }

    #[test]
    fn test_transaction_state_force_rollback() {
        let mut state = TransactionState::new();
        state.begin(TransactionId::new(1)).unwrap();
        state.mark_error();

        state.force_rollback();

        assert!(state.is_idle());
    }

    #[test]
    fn test_transaction_state_record_write_requires_active() {
        let mut state = TransactionState::new();

        let result = state.record_write(1);
        assert!(matches!(result, Err(SinkError::NoActiveTransaction)));
        assert_eq!(state.write_count(), 0);
        assert_eq!(state.record_count(), 0);
    }

    #[test]
    fn test_transaction_state_record_write_accumulates() {
        let mut state = TransactionState::new();
        state.begin(TransactionId::new(1)).unwrap();

        state.record_write(10).unwrap();
        state.record_write(5).unwrap();

        assert_eq!(state.write_count(), 2);
        assert_eq!(state.record_count(), 15);
    }

    #[test]
    fn test_transaction_state_begin_resets_counters() {
        let mut state = TransactionState::new();
        let tx1 = TransactionId::new(1);
        state.begin(tx1.clone()).unwrap();
        state.record_write(10).unwrap();
        state.commit(&tx1).unwrap();

        state.begin(TransactionId::new(2)).unwrap();

        assert_eq!(state.write_count(), 0);
        assert_eq!(state.record_count(), 0);
    }

    #[test]
    fn test_transaction_coordinator_basic() {
        let mut coord = TransactionCoordinator::new();

        coord.register_participant("sink-1".to_string());
        coord.register_participant("sink-2".to_string());

        let _tx_id = coord.begin().unwrap();
        assert!(coord.current_transaction().is_some());

        coord.mark_prepared("sink-1").unwrap();
        assert!(!coord.all_prepared());

        coord.mark_prepared("sink-2").unwrap();
        assert!(coord.all_prepared());

        coord.mark_committed("sink-1");
        assert!(!coord.all_committed());

        coord.mark_committed("sink-2");
        assert!(coord.all_committed());

        coord.complete();
        assert!(coord.current_transaction().is_none());
    }

    #[test]
    fn test_transaction_coordinator_double_begin() {
        let mut coord = TransactionCoordinator::new();
        let _tx_id = coord.begin().unwrap();

        let result = coord.begin();
        assert!(matches!(
            result,
            Err(SinkError::TransactionAlreadyActive(_))
        ));
    }

    #[test]
    fn test_transaction_coordinator_unknown_sink() {
        let mut coord = TransactionCoordinator::new();
        coord.register_participant("sink-1".to_string());
        let _tx_id = coord.begin().unwrap();

        let result = coord.mark_prepared("sink-x");
        assert!(matches!(result, Err(SinkError::ConfigurationError(_))));
        assert!(!coord.all_prepared());
    }

    #[test]
    fn test_transaction_coordinator_duplicate_registration_and_votes() {
        let mut coord = TransactionCoordinator::new();
        coord.register_participant("sink-1".to_string());
        coord.register_participant("sink-1".to_string());
        assert_eq!(coord.participants().len(), 1);

        let _tx_id = coord.begin().unwrap();
        coord.mark_prepared("sink-1").unwrap();
        coord.mark_prepared("sink-1").unwrap();
        assert!(coord.all_prepared());

        coord.mark_committed("sink-1");
        coord.mark_committed("sink-1");
        assert!(coord.all_committed());
    }

    #[test]
    fn test_transaction_coordinator_begin_clears_votes() {
        let mut coord = TransactionCoordinator::new();
        coord.register_participant("sink-1".to_string());

        let tx1 = coord.begin().unwrap();
        coord.mark_prepared("sink-1").unwrap();
        coord.mark_committed("sink-1");
        coord.complete();

        let tx2 = coord.begin().unwrap();
        assert_ne!(tx1.id(), tx2.id());
        assert!(!coord.all_prepared());
        assert!(!coord.all_committed());
    }

    #[test]
    fn test_transaction_buffer() {
        use crate::operator::Event;
        use arrow_array::{Int64Array, RecordBatch};
        use std::sync::Arc;

        let mut buffer = TransactionBuffer::new();

        let array = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let batch = RecordBatch::try_from_iter(vec![("col", array as _)]).unwrap();
        let event = Event::new(1000, batch);

        buffer.push(vec![Output::Event(event)]);

        assert_eq!(buffer.len(), 1);
        assert!(!buffer.is_empty());
        assert!(buffer.size_bytes() > 0);

        let outputs = buffer.take();
        assert_eq!(outputs.len(), 1);
        assert!(buffer.is_empty());
        assert_eq!(buffer.size_bytes(), 0);
    }

    #[test]
    fn test_transaction_buffer_clear() {
        use crate::operator::Event;
        use arrow_array::{Int64Array, RecordBatch};
        use std::sync::Arc;

        let mut buffer = TransactionBuffer::with_capacity(4);
        assert!(buffer.is_empty());
        assert_eq!(buffer.size_bytes(), 0);

        let array = Arc::new(Int64Array::from(vec![4, 5, 6]));
        let batch = RecordBatch::try_from_iter(vec![("col", array as _)]).unwrap();
        buffer.push(vec![Output::Event(Event::new(2000, batch))]);
        assert_eq!(buffer.outputs().len(), 1);

        buffer.clear();

        assert!(buffer.is_empty());
        assert_eq!(buffer.len(), 0);
        assert_eq!(buffer.size_bytes(), 0);
    }
}
574}