// heliosdb_proxy/transaction_journal.rs

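//! Transaction Journal - TR (Transaction Replay)
//!
//! Logs every statement executed within a transaction so that the transaction
//! can be replayed after a failover. Intended to provide functionality
//! comparable to Oracle's TAF (Transparent Application Failover) combined
//! with TAC (Transparent Application Continuity).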
5
6use super::{NodeId, ProxyError, Result};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12/// Journal entry for a single statement
13#[derive(Debug, Clone)]
14pub struct JournalEntry {
15    /// Entry sequence number
16    pub sequence: u64,
17    /// SQL statement text
18    pub statement: String,
19    /// Bound parameters
20    pub parameters: Vec<JournalValue>,
21    /// Result checksum (for verification after replay)
22    pub result_checksum: Option<u64>,
23    /// Number of rows affected
24    pub rows_affected: Option<u64>,
25    /// Timestamp
26    pub timestamp: chrono::DateTime<chrono::Utc>,
27    /// Statement type
28    pub statement_type: StatementType,
29    /// Execution duration (ms)
30    pub duration_ms: u64,
31}
32
/// Owned representation of a single bound parameter value, as stored in the
/// journal.
#[derive(Clone, Debug)]
pub enum JournalValue {
    /// Absent value (NULL).
    Null,
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit floating point number.
    Float64(f64),
    /// UTF-8 text.
    Text(String),
    /// Raw bytes.
    Bytes(Vec<u8>),
    /// Ordered list of nested values.
    Array(Vec<JournalValue>),
}
44
/// Statement type classification
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// SELECT query
    Select,
    /// INSERT statement
    Insert,
    /// UPDATE statement
    Update,
    /// DELETE statement
    Delete,
    /// DDL (CREATE, ALTER, DROP, TRUNCATE)
    Ddl,
    /// Transaction control (BEGIN, START TRANSACTION, COMMIT, END, ROLLBACK,
    /// ABORT, SAVEPOINT, RELEASE)
    Transaction,
    /// SET statement
    Set,
    /// Other/unknown
    Other,
}

impl StatementType {
    /// Classifies a statement by its leading keyword.
    ///
    /// Leading whitespace is skipped, then the first run of ASCII
    /// alphanumeric/underscore characters is taken as the keyword and
    /// compared case-insensitively against the known keywords. Comparing the
    /// whole keyword (rather than a raw prefix) keeps text such as
    /// `SETTINGS` from being classified as `SET`, and avoids allocating an
    /// upper-cased copy of the entire statement.
    ///
    /// Anything that does not start with a known keyword (including the
    /// empty string) is [`StatementType::Other`].
    pub fn from_sql(sql: &str) -> Self {
        const KEYWORDS: &[(&str, StatementType)] = &[
            ("SELECT", StatementType::Select),
            ("INSERT", StatementType::Insert),
            ("UPDATE", StatementType::Update),
            ("DELETE", StatementType::Delete),
            ("CREATE", StatementType::Ddl),
            ("ALTER", StatementType::Ddl),
            ("DROP", StatementType::Ddl),
            // TRUNCATE removes rows, so it must count as a mutation.
            ("TRUNCATE", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("START", StatementType::Transaction),
            ("COMMIT", StatementType::Transaction),
            ("END", StatementType::Transaction),
            ("ROLLBACK", StatementType::Transaction),
            ("ABORT", StatementType::Transaction),
            ("SAVEPOINT", StatementType::Transaction),
            ("RELEASE", StatementType::Transaction),
            ("SET", StatementType::Set),
        ];

        let trimmed = sql.trim_start();
        // `find` returns a byte offset on a char boundary, so slicing is safe.
        let keyword_end = trimmed
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(trimmed.len());
        let keyword = &trimmed[..keyword_end];

        KEYWORDS
            .iter()
            .find(|(candidate, _)| keyword.eq_ignore_ascii_case(candidate))
            .map(|&(_, statement_type)| statement_type)
            .unwrap_or(StatementType::Other)
    }

    /// Is this a read-only statement?
    ///
    /// Only [`StatementType::Select`] is treated as read-only.
    pub fn is_read_only(&self) -> bool {
        matches!(self, StatementType::Select)
    }

    /// Is this a mutating statement?
    ///
    /// True for INSERT, UPDATE, DELETE and DDL.
    pub fn is_mutation(&self) -> bool {
        matches!(
            self,
            StatementType::Insert
                | StatementType::Update
                | StatementType::Delete
                | StatementType::Ddl
        )
    }
}
112
113/// Transaction journal for a single transaction
114#[derive(Debug, Clone)]
115pub struct TransactionJournalEntry {
116    /// Transaction ID
117    pub tx_id: Uuid,
118    /// Session ID
119    pub session_id: Uuid,
120    /// Node where transaction started
121    pub node_id: NodeId,
122    /// Transaction start time
123    pub started_at: chrono::DateTime<chrono::Utc>,
124    /// Start LSN (for WAL synchronization)
125    pub start_lsn: u64,
126    /// Journal entries
127    pub entries: Vec<JournalEntry>,
128    /// Current sequence
129    pub current_sequence: u64,
130    /// Is transaction active
131    pub active: bool,
132    /// Has mutations
133    pub has_mutations: bool,
134    /// Savepoints
135    pub savepoints: Vec<Savepoint>,
136}
137
138/// Savepoint information
139#[derive(Debug, Clone)]
140pub struct Savepoint {
141    /// Savepoint name
142    pub name: String,
143    /// Sequence at savepoint
144    pub sequence: u64,
145    /// Created timestamp
146    pub created_at: chrono::DateTime<chrono::Utc>,
147}
148
149impl TransactionJournalEntry {
150    /// Create a new transaction journal entry
151    pub fn new(tx_id: Uuid, session_id: Uuid, node_id: NodeId, start_lsn: u64) -> Self {
152        Self {
153            tx_id,
154            session_id,
155            node_id,
156            started_at: chrono::Utc::now(),
157            start_lsn,
158            entries: Vec::new(),
159            current_sequence: 0,
160            active: true,
161            has_mutations: false,
162            savepoints: Vec::new(),
163        }
164    }
165
166    /// Add an entry to the journal
167    pub fn add_entry(&mut self, entry: JournalEntry) {
168        if entry.statement_type.is_mutation() {
169            self.has_mutations = true;
170        }
171        self.current_sequence = entry.sequence;
172        self.entries.push(entry);
173    }
174
175    /// Create a savepoint
176    pub fn create_savepoint(&mut self, name: String) {
177        self.savepoints.push(Savepoint {
178            name,
179            sequence: self.current_sequence,
180            created_at: chrono::Utc::now(),
181        });
182    }
183
184    /// Rollback to savepoint
185    pub fn rollback_to_savepoint(&mut self, name: &str) -> Option<u64> {
186        if let Some(idx) = self.savepoints.iter().position(|s| s.name == name) {
187            let savepoint = &self.savepoints[idx];
188            let sequence = savepoint.sequence;
189
190            // Truncate entries after savepoint
191            self.entries.retain(|e| e.sequence <= sequence);
192
193            // Remove later savepoints
194            self.savepoints.truncate(idx + 1);
195
196            Some(sequence)
197        } else {
198            None
199        }
200    }
201
202    /// Get entries for replay
203    pub fn entries_for_replay(&self) -> Vec<&JournalEntry> {
204        self.entries.iter().collect()
205    }
206
207    /// Get only mutation entries
208    pub fn mutation_entries(&self) -> Vec<&JournalEntry> {
209        self.entries
210            .iter()
211            .filter(|e| e.statement_type.is_mutation())
212            .collect()
213    }
214
215    /// Calculate total size of journal
216    pub fn total_size(&self) -> usize {
217        self.entries
218            .iter()
219            .map(|e| e.statement.len() + estimate_params_size(&e.parameters))
220            .sum()
221    }
222}
223
224fn estimate_params_size(params: &[JournalValue]) -> usize {
225    params
226        .iter()
227        .map(|p| match p {
228            JournalValue::Null => 1,
229            JournalValue::Bool(_) => 1,
230            JournalValue::Int64(_) => 8,
231            JournalValue::Float64(_) => 8,
232            JournalValue::Text(s) => s.len(),
233            JournalValue::Bytes(b) => b.len(),
234            JournalValue::Array(a) => estimate_params_size(a),
235        })
236        .sum()
237}
238
239/// Transaction Journal Manager
240pub struct TransactionJournal {
241    /// Active transaction journals
242    journals: Arc<RwLock<HashMap<Uuid, TransactionJournalEntry>>>,
243    /// Maximum entries per journal
244    max_entries: usize,
245    /// Maximum journal size (bytes)
246    max_size: usize,
247    /// Whether journaling is enabled
248    enabled: bool,
249}
250
251impl TransactionJournal {
252    /// Create a new transaction journal manager
253    pub fn new() -> Self {
254        Self {
255            journals: Arc::new(RwLock::new(HashMap::new())),
256            max_entries: 10000,
257            max_size: 64 * 1024 * 1024, // 64MB
258            enabled: true,
259        }
260    }
261
262    /// Configure maximum entries
263    pub fn with_max_entries(mut self, max: usize) -> Self {
264        self.max_entries = max;
265        self
266    }
267
268    /// Configure maximum size
269    pub fn with_max_size(mut self, max: usize) -> Self {
270        self.max_size = max;
271        self
272    }
273
274    /// Enable or disable journaling
275    pub fn set_enabled(&mut self, enabled: bool) {
276        self.enabled = enabled;
277    }
278
279    /// Collect every journal entry across every active transaction
280    /// whose `timestamp` falls within the inclusive window
281    /// `[from, to]`. Results are sorted in timestamp order so the
282    /// caller can replay them chronologically regardless of which
283    /// transaction they came from.
284    ///
285    /// Used by the time-travel replay engine (`src/replay/`) to
286    /// reconstruct "what happened at the source between these two
287    /// timestamps" against a staging target.
288    pub async fn entries_in_window(
289        &self,
290        from: chrono::DateTime<chrono::Utc>,
291        to: chrono::DateTime<chrono::Utc>,
292    ) -> Vec<(Uuid, JournalEntry)> {
293        let journals = self.journals.read().await;
294        let mut out: Vec<(Uuid, JournalEntry)> = Vec::new();
295        for (tx_id, j) in journals.iter() {
296            for entry in &j.entries {
297                if entry.timestamp >= from && entry.timestamp <= to {
298                    out.push((*tx_id, entry.clone()));
299                }
300            }
301        }
302        out.sort_by_key(|(_, e)| e.timestamp);
303        out
304    }
305
306    /// Start journaling a transaction
307    pub async fn begin_transaction(
308        &self,
309        tx_id: Uuid,
310        session_id: Uuid,
311        node_id: NodeId,
312        start_lsn: u64,
313    ) -> Result<()> {
314        if !self.enabled {
315            return Ok(());
316        }
317
318        let journal = TransactionJournalEntry::new(tx_id, session_id, node_id, start_lsn);
319        self.journals.write().await.insert(tx_id, journal);
320
321        tracing::debug!("Started journaling transaction {:?}", tx_id);
322        Ok(())
323    }
324
325    /// Log a statement
326    pub async fn log_statement(
327        &self,
328        tx_id: Uuid,
329        statement: String,
330        parameters: Vec<JournalValue>,
331        result_checksum: Option<u64>,
332        rows_affected: Option<u64>,
333        duration_ms: u64,
334    ) -> Result<()> {
335        if !self.enabled {
336            return Ok(());
337        }
338
339        let mut journals = self.journals.write().await;
340        let journal = journals.get_mut(&tx_id).ok_or_else(|| {
341            ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
342        })?;
343
344        // Check limits
345        if journal.entries.len() >= self.max_entries {
346            return Err(ProxyError::Internal(
347                "Transaction journal entries limit exceeded".to_string(),
348            ));
349        }
350
351        if journal.total_size() >= self.max_size {
352            return Err(ProxyError::Internal(
353                "Transaction journal size limit exceeded".to_string(),
354            ));
355        }
356
357        let sequence = journal.current_sequence + 1;
358        let statement_type = StatementType::from_sql(&statement);
359
360        let entry = JournalEntry {
361            sequence,
362            statement,
363            parameters,
364            result_checksum,
365            rows_affected,
366            timestamp: chrono::Utc::now(),
367            statement_type,
368            duration_ms,
369        };
370
371        journal.add_entry(entry);
372
373        Ok(())
374    }
375
376    /// Create a savepoint
377    pub async fn create_savepoint(&self, tx_id: Uuid, name: String) -> Result<()> {
378        if !self.enabled {
379            return Ok(());
380        }
381
382        let mut journals = self.journals.write().await;
383        let journal = journals.get_mut(&tx_id).ok_or_else(|| {
384            ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
385        })?;
386
387        journal.create_savepoint(name);
388        Ok(())
389    }
390
391    /// Rollback to savepoint
392    pub async fn rollback_to_savepoint(&self, tx_id: Uuid, name: &str) -> Result<()> {
393        if !self.enabled {
394            return Ok(());
395        }
396
397        let mut journals = self.journals.write().await;
398        let journal = journals.get_mut(&tx_id).ok_or_else(|| {
399            ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
400        })?;
401
402        journal
403            .rollback_to_savepoint(name)
404            .ok_or_else(|| ProxyError::Internal(format!("Savepoint '{}' not found", name)))?;
405
406        Ok(())
407    }
408
409    /// Commit transaction (clear journal)
410    pub async fn commit_transaction(&self, tx_id: Uuid) -> Result<()> {
411        self.journals.write().await.remove(&tx_id);
412        tracing::debug!("Committed and cleared journal for transaction {:?}", tx_id);
413        Ok(())
414    }
415
416    /// Rollback transaction (clear journal)
417    pub async fn rollback_transaction(&self, tx_id: Uuid) -> Result<()> {
418        self.journals.write().await.remove(&tx_id);
419        tracing::debug!(
420            "Rolled back and cleared journal for transaction {:?}",
421            tx_id
422        );
423        Ok(())
424    }
425
426    /// Get journal for a transaction (for replay)
427    pub async fn get_journal(&self, tx_id: &Uuid) -> Option<TransactionJournalEntry> {
428        self.journals.read().await.get(tx_id).cloned()
429    }
430
431    /// Get active transaction count
432    pub async fn active_count(&self) -> usize {
433        self.journals.read().await.len()
434    }
435
436    /// Get statistics
437    pub async fn stats(&self) -> JournalStats {
438        let journals = self.journals.read().await;
439        let total_entries: usize = journals.values().map(|j| j.entries.len()).sum();
440        let total_size: usize = journals.values().map(|j| j.total_size()).sum();
441
442        JournalStats {
443            active_transactions: journals.len(),
444            total_entries,
445            total_size_bytes: total_size,
446            enabled: self.enabled,
447        }
448    }
449
450    /// Get all active transaction journals (for failover replay)
451    pub async fn get_all_active(&self) -> Vec<TransactionJournalEntry> {
452        self.journals.read().await.values().cloned().collect()
453    }
454
455    /// Get the maximum start LSN across all active transactions
456    /// Used to determine how far the standby needs to catch up
457    pub async fn get_max_start_lsn(&self) -> Option<u64> {
458        let journals = self.journals.read().await;
459        journals.values().map(|j| j.start_lsn).max()
460    }
461
462    /// Get transactions that started on a specific node
463    /// Useful for replaying only transactions affected by a node failure
464    pub async fn get_transactions_for_node(&self, node_id: NodeId) -> Vec<TransactionJournalEntry> {
465        self.journals
466            .read()
467            .await
468            .values()
469            .filter(|j| j.node_id == node_id)
470            .cloned()
471            .collect()
472    }
473}
474
475impl Default for TransactionJournal {
476    fn default() -> Self {
477        Self::new()
478    }
479}
480
/// Aggregate figures describing the journals currently held by the manager.
#[derive(Clone, Debug)]
pub struct JournalStats {
    /// How many transactions currently have a journal.
    pub active_transactions: usize,
    /// Entry count summed over all journals.
    pub total_entries: usize,
    /// Estimated size in bytes summed over all journals.
    pub total_size_bytes: usize,
    /// Whether journaling is enabled.
    pub enabled: bool,
}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a manager and begins one transaction on it (start LSN 0).
    async fn journal_with_open_tx() -> (TransactionJournal, Uuid) {
        let journal = TransactionJournal::new();
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let node_id = NodeId::new();

        journal
            .begin_transaction(tx_id, session_id, node_id, 0)
            .await
            .unwrap();

        (journal, tx_id)
    }

    /// Logs one `INSERT INTO t VALUES (i)` statement per value in `values`.
    async fn log_inserts(journal: &TransactionJournal, tx_id: Uuid, values: std::ops::Range<i32>) {
        for i in values {
            journal
                .log_statement(
                    tx_id,
                    format!("INSERT INTO t VALUES ({})", i),
                    vec![],
                    None,
                    Some(1),
                    1,
                )
                .await
                .unwrap();
        }
    }

    #[test]
    fn test_statement_type_detection() {
        let cases = [
            ("SELECT * FROM users", StatementType::Select),
            ("INSERT INTO users VALUES (1)", StatementType::Insert),
            ("UPDATE users SET name = 'x'", StatementType::Update),
            ("DELETE FROM users", StatementType::Delete),
            ("CREATE TABLE foo (id INT)", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("SET search_path = public", StatementType::Set),
        ];

        for &(sql, expected) in cases.iter() {
            assert_eq!(StatementType::from_sql(sql), expected);
        }
    }

    #[test]
    fn test_statement_type_properties() {
        assert!(StatementType::Select.is_read_only());
        assert!(!StatementType::Insert.is_read_only());

        assert!(StatementType::Insert.is_mutation());
        assert!(StatementType::Update.is_mutation());
        assert!(!StatementType::Select.is_mutation());
    }

    #[tokio::test]
    async fn test_journal_lifecycle() {
        let (journal, tx_id) = journal_with_open_tx().await;

        // One read followed by one write.
        journal
            .log_statement(
                tx_id,
                "SELECT * FROM users".to_string(),
                vec![],
                Some(12345),
                None,
                10,
            )
            .await
            .unwrap();

        journal
            .log_statement(
                tx_id,
                "INSERT INTO users (name) VALUES ($1)".to_string(),
                vec![JournalValue::Text("test".to_string())],
                None,
                Some(1),
                5,
            )
            .await
            .unwrap();

        // Both statements are journaled and the write is noticed.
        let snapshot = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(snapshot.entries.len(), 2);
        assert!(snapshot.has_mutations);

        // Committing drops the journal.
        journal.commit_transaction(tx_id).await.unwrap();
        assert!(journal.get_journal(&tx_id).await.is_none());
    }

    #[tokio::test]
    async fn test_savepoints() {
        let (journal, tx_id) = journal_with_open_tx().await;

        // Three statements, a savepoint, then two more statements.
        log_inserts(&journal, tx_id, 0..3).await;

        journal
            .create_savepoint(tx_id, "sp1".to_string())
            .await
            .unwrap();

        log_inserts(&journal, tx_id, 3..5).await;

        let before = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(before.entries.len(), 5);

        // Rolling back keeps only what was logged before the savepoint.
        journal.rollback_to_savepoint(tx_id, "sp1").await.unwrap();

        let after = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(after.entries.len(), 3);
    }

    #[tokio::test]
    async fn test_stats() {
        let (journal, tx_id) = journal_with_open_tx().await;

        journal
            .log_statement(tx_id, "SELECT 1".to_string(), vec![], None, None, 1)
            .await
            .unwrap();

        let stats = journal.stats().await;
        assert_eq!(stats.active_transactions, 1);
        assert_eq!(stats.total_entries, 1);
        assert!(stats.enabled);
    }
}