Skip to main content

heliosdb_proxy/
transaction_journal.rs

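//! Transaction Journal - TR (Transaction Replay)
//!
//! Logs every statement executed within a transaction so it can be replayed after a failover.
//! Provides functionality modeled on Oracle's TAF (Transparent Application Failover) and
//! TAC (Transparent Application Continuity), merged into a single mechanism.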
5
6use super::{NodeId, ProxyError, Result};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12/// Journal entry for a single statement
13#[derive(Debug, Clone)]
14pub struct JournalEntry {
15    /// Entry sequence number
16    pub sequence: u64,
17    /// SQL statement text
18    pub statement: String,
19    /// Bound parameters
20    pub parameters: Vec<JournalValue>,
21    /// Result checksum (for verification after replay)
22    pub result_checksum: Option<u64>,
23    /// Number of rows affected
24    pub rows_affected: Option<u64>,
25    /// Timestamp
26    pub timestamp: chrono::DateTime<chrono::Utc>,
27    /// Statement type
28    pub statement_type: StatementType,
29    /// Execution duration (ms)
30    pub duration_ms: u64,
31}
32
/// Owned representation of a bound statement parameter, as stored in the journal.
#[derive(Debug, Clone)]
pub enum JournalValue {
    /// No value (SQL NULL).
    Null,
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit floating point number.
    Float64(f64),
    /// UTF-8 text.
    Text(String),
    /// Raw bytes.
    Bytes(Vec<u8>),
    /// Sequence of nested values.
    Array(Vec<JournalValue>),
}
44
/// Coarse classification of a SQL statement, derived from its leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// SELECT query
    Select,
    /// INSERT statement
    Insert,
    /// UPDATE statement
    Update,
    /// DELETE statement
    Delete,
    /// DDL (CREATE, ALTER, DROP, TRUNCATE)
    Ddl,
    /// Transaction control (BEGIN, COMMIT, ROLLBACK, SAVEPOINT)
    Transaction,
    /// SET statement
    Set,
    /// Other/unknown
    Other,
}

impl StatementType {
    /// Determine the statement type from the leading keyword of `sql`.
    ///
    /// Leading whitespace is skipped and the first run of ASCII letters, digits and
    /// underscores is compared case-insensitively against the known keywords. The keyword
    /// must match as a whole word, so `SETTINGS ...` is not treated as `SET`. Anything that
    /// is not recognised (including empty input) yields [`StatementType::Other`].
    ///
    /// No copy of the statement is made, so classifying a very large statement is cheap.
    pub fn from_sql(sql: &str) -> Self {
        // Keyword -> classification table. TRUNCATE is grouped with DDL so that it counts
        // as a mutation.
        const KEYWORDS: &[(&str, StatementType)] = &[
            ("SELECT", StatementType::Select),
            ("INSERT", StatementType::Insert),
            ("UPDATE", StatementType::Update),
            ("DELETE", StatementType::Delete),
            ("CREATE", StatementType::Ddl),
            ("ALTER", StatementType::Ddl),
            ("DROP", StatementType::Ddl),
            ("TRUNCATE", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("COMMIT", StatementType::Transaction),
            ("ROLLBACK", StatementType::Transaction),
            ("SAVEPOINT", StatementType::Transaction),
            ("SET", StatementType::Set),
        ];

        let trimmed = sql.trim_start();
        // The keyword ends at the first character that cannot be part of an identifier.
        let keyword_len = trimmed
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(trimmed.len());
        let keyword = &trimmed[..keyword_len];

        KEYWORDS
            .iter()
            .find(|(name, _)| keyword.eq_ignore_ascii_case(name))
            .map_or(StatementType::Other, |&(_, kind)| kind)
    }

    /// Is this a read-only statement? Only [`StatementType::Select`] qualifies.
    pub fn is_read_only(&self) -> bool {
        matches!(self, StatementType::Select)
    }

    /// Is this a mutating statement (INSERT, UPDATE, DELETE or DDL)?
    pub fn is_mutation(&self) -> bool {
        matches!(
            self,
            StatementType::Insert | StatementType::Update | StatementType::Delete | StatementType::Ddl
        )
    }
}
109
110/// Transaction journal for a single transaction
111#[derive(Debug, Clone)]
112pub struct TransactionJournalEntry {
113    /// Transaction ID
114    pub tx_id: Uuid,
115    /// Session ID
116    pub session_id: Uuid,
117    /// Node where transaction started
118    pub node_id: NodeId,
119    /// Transaction start time
120    pub started_at: chrono::DateTime<chrono::Utc>,
121    /// Start LSN (for WAL synchronization)
122    pub start_lsn: u64,
123    /// Journal entries
124    pub entries: Vec<JournalEntry>,
125    /// Current sequence
126    pub current_sequence: u64,
127    /// Is transaction active
128    pub active: bool,
129    /// Has mutations
130    pub has_mutations: bool,
131    /// Savepoints
132    pub savepoints: Vec<Savepoint>,
133}
134
135/// Savepoint information
136#[derive(Debug, Clone)]
137pub struct Savepoint {
138    /// Savepoint name
139    pub name: String,
140    /// Sequence at savepoint
141    pub sequence: u64,
142    /// Created timestamp
143    pub created_at: chrono::DateTime<chrono::Utc>,
144}
145
146impl TransactionJournalEntry {
147    /// Create a new transaction journal entry
148    pub fn new(tx_id: Uuid, session_id: Uuid, node_id: NodeId, start_lsn: u64) -> Self {
149        Self {
150            tx_id,
151            session_id,
152            node_id,
153            started_at: chrono::Utc::now(),
154            start_lsn,
155            entries: Vec::new(),
156            current_sequence: 0,
157            active: true,
158            has_mutations: false,
159            savepoints: Vec::new(),
160        }
161    }
162
163    /// Add an entry to the journal
164    pub fn add_entry(&mut self, entry: JournalEntry) {
165        if entry.statement_type.is_mutation() {
166            self.has_mutations = true;
167        }
168        self.current_sequence = entry.sequence;
169        self.entries.push(entry);
170    }
171
172    /// Create a savepoint
173    pub fn create_savepoint(&mut self, name: String) {
174        self.savepoints.push(Savepoint {
175            name,
176            sequence: self.current_sequence,
177            created_at: chrono::Utc::now(),
178        });
179    }
180
181    /// Rollback to savepoint
182    pub fn rollback_to_savepoint(&mut self, name: &str) -> Option<u64> {
183        if let Some(idx) = self.savepoints.iter().position(|s| s.name == name) {
184            let savepoint = &self.savepoints[idx];
185            let sequence = savepoint.sequence;
186
187            // Truncate entries after savepoint
188            self.entries.retain(|e| e.sequence <= sequence);
189
190            // Remove later savepoints
191            self.savepoints.truncate(idx + 1);
192
193            Some(sequence)
194        } else {
195            None
196        }
197    }
198
199    /// Get entries for replay
200    pub fn entries_for_replay(&self) -> Vec<&JournalEntry> {
201        self.entries.iter().collect()
202    }
203
204    /// Get only mutation entries
205    pub fn mutation_entries(&self) -> Vec<&JournalEntry> {
206        self.entries
207            .iter()
208            .filter(|e| e.statement_type.is_mutation())
209            .collect()
210    }
211
212    /// Calculate total size of journal
213    pub fn total_size(&self) -> usize {
214        self.entries
215            .iter()
216            .map(|e| e.statement.len() + estimate_params_size(&e.parameters))
217            .sum()
218    }
219}
220
221fn estimate_params_size(params: &[JournalValue]) -> usize {
222    params
223        .iter()
224        .map(|p| match p {
225            JournalValue::Null => 1,
226            JournalValue::Bool(_) => 1,
227            JournalValue::Int64(_) => 8,
228            JournalValue::Float64(_) => 8,
229            JournalValue::Text(s) => s.len(),
230            JournalValue::Bytes(b) => b.len(),
231            JournalValue::Array(a) => estimate_params_size(a),
232        })
233        .sum()
234}
235
236/// Transaction Journal Manager
237pub struct TransactionJournal {
238    /// Active transaction journals
239    journals: Arc<RwLock<HashMap<Uuid, TransactionJournalEntry>>>,
240    /// Maximum entries per journal
241    max_entries: usize,
242    /// Maximum journal size (bytes)
243    max_size: usize,
244    /// Whether journaling is enabled
245    enabled: bool,
246}
247
248impl TransactionJournal {
249    /// Create a new transaction journal manager
250    pub fn new() -> Self {
251        Self {
252            journals: Arc::new(RwLock::new(HashMap::new())),
253            max_entries: 10000,
254            max_size: 64 * 1024 * 1024, // 64MB
255            enabled: true,
256        }
257    }
258
259    /// Configure maximum entries
260    pub fn with_max_entries(mut self, max: usize) -> Self {
261        self.max_entries = max;
262        self
263    }
264
265    /// Configure maximum size
266    pub fn with_max_size(mut self, max: usize) -> Self {
267        self.max_size = max;
268        self
269    }
270
271    /// Enable or disable journaling
272    pub fn set_enabled(&mut self, enabled: bool) {
273        self.enabled = enabled;
274    }
275
276    /// Collect every journal entry across every active transaction
277    /// whose `timestamp` falls within the inclusive window
278    /// `[from, to]`. Results are sorted in timestamp order so the
279    /// caller can replay them chronologically regardless of which
280    /// transaction they came from.
281    ///
282    /// Used by the time-travel replay engine (`src/replay/`) to
283    /// reconstruct "what happened at the source between these two
284    /// timestamps" against a staging target.
285    pub async fn entries_in_window(
286        &self,
287        from: chrono::DateTime<chrono::Utc>,
288        to: chrono::DateTime<chrono::Utc>,
289    ) -> Vec<(Uuid, JournalEntry)> {
290        let journals = self.journals.read().await;
291        let mut out: Vec<(Uuid, JournalEntry)> = Vec::new();
292        for (tx_id, j) in journals.iter() {
293            for entry in &j.entries {
294                if entry.timestamp >= from && entry.timestamp <= to {
295                    out.push((*tx_id, entry.clone()));
296                }
297            }
298        }
299        out.sort_by_key(|(_, e)| e.timestamp);
300        out
301    }
302
303    /// Start journaling a transaction
304    pub async fn begin_transaction(
305        &self,
306        tx_id: Uuid,
307        session_id: Uuid,
308        node_id: NodeId,
309        start_lsn: u64,
310    ) -> Result<()> {
311        if !self.enabled {
312            return Ok(());
313        }
314
315        let journal = TransactionJournalEntry::new(tx_id, session_id, node_id, start_lsn);
316        self.journals.write().await.insert(tx_id, journal);
317
318        tracing::debug!("Started journaling transaction {:?}", tx_id);
319        Ok(())
320    }
321
322    /// Log a statement
323    pub async fn log_statement(
324        &self,
325        tx_id: Uuid,
326        statement: String,
327        parameters: Vec<JournalValue>,
328        result_checksum: Option<u64>,
329        rows_affected: Option<u64>,
330        duration_ms: u64,
331    ) -> Result<()> {
332        if !self.enabled {
333            return Ok(());
334        }
335
336        let mut journals = self.journals.write().await;
337        let journal = journals.get_mut(&tx_id).ok_or_else(|| {
338            ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
339        })?;
340
341        // Check limits
342        if journal.entries.len() >= self.max_entries {
343            return Err(ProxyError::Internal("Transaction journal entries limit exceeded".to_string()));
344        }
345
346        if journal.total_size() >= self.max_size {
347            return Err(ProxyError::Internal("Transaction journal size limit exceeded".to_string()));
348        }
349
350        let sequence = journal.current_sequence + 1;
351        let statement_type = StatementType::from_sql(&statement);
352
353        let entry = JournalEntry {
354            sequence,
355            statement,
356            parameters,
357            result_checksum,
358            rows_affected,
359            timestamp: chrono::Utc::now(),
360            statement_type,
361            duration_ms,
362        };
363
364        journal.add_entry(entry);
365
366        Ok(())
367    }
368
369    /// Create a savepoint
370    pub async fn create_savepoint(&self, tx_id: Uuid, name: String) -> Result<()> {
371        if !self.enabled {
372            return Ok(());
373        }
374
375        let mut journals = self.journals.write().await;
376        let journal = journals.get_mut(&tx_id).ok_or_else(|| {
377            ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
378        })?;
379
380        journal.create_savepoint(name);
381        Ok(())
382    }
383
384    /// Rollback to savepoint
385    pub async fn rollback_to_savepoint(&self, tx_id: Uuid, name: &str) -> Result<()> {
386        if !self.enabled {
387            return Ok(());
388        }
389
390        let mut journals = self.journals.write().await;
391        let journal = journals.get_mut(&tx_id).ok_or_else(|| {
392            ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
393        })?;
394
395        journal
396            .rollback_to_savepoint(name)
397            .ok_or_else(|| ProxyError::Internal(format!("Savepoint '{}' not found", name)))?;
398
399        Ok(())
400    }
401
402    /// Commit transaction (clear journal)
403    pub async fn commit_transaction(&self, tx_id: Uuid) -> Result<()> {
404        self.journals.write().await.remove(&tx_id);
405        tracing::debug!("Committed and cleared journal for transaction {:?}", tx_id);
406        Ok(())
407    }
408
409    /// Rollback transaction (clear journal)
410    pub async fn rollback_transaction(&self, tx_id: Uuid) -> Result<()> {
411        self.journals.write().await.remove(&tx_id);
412        tracing::debug!("Rolled back and cleared journal for transaction {:?}", tx_id);
413        Ok(())
414    }
415
416    /// Get journal for a transaction (for replay)
417    pub async fn get_journal(&self, tx_id: &Uuid) -> Option<TransactionJournalEntry> {
418        self.journals.read().await.get(tx_id).cloned()
419    }
420
421    /// Get active transaction count
422    pub async fn active_count(&self) -> usize {
423        self.journals.read().await.len()
424    }
425
426    /// Get statistics
427    pub async fn stats(&self) -> JournalStats {
428        let journals = self.journals.read().await;
429        let total_entries: usize = journals.values().map(|j| j.entries.len()).sum();
430        let total_size: usize = journals.values().map(|j| j.total_size()).sum();
431
432        JournalStats {
433            active_transactions: journals.len(),
434            total_entries,
435            total_size_bytes: total_size,
436            enabled: self.enabled,
437        }
438    }
439
440    /// Get all active transaction journals (for failover replay)
441    pub async fn get_all_active(&self) -> Vec<TransactionJournalEntry> {
442        self.journals.read().await.values().cloned().collect()
443    }
444
445    /// Get the maximum start LSN across all active transactions
446    /// Used to determine how far the standby needs to catch up
447    pub async fn get_max_start_lsn(&self) -> Option<u64> {
448        let journals = self.journals.read().await;
449        journals.values().map(|j| j.start_lsn).max()
450    }
451
452    /// Get transactions that started on a specific node
453    /// Useful for replaying only transactions affected by a node failure
454    pub async fn get_transactions_for_node(&self, node_id: NodeId) -> Vec<TransactionJournalEntry> {
455        self.journals
456            .read()
457            .await
458            .values()
459            .filter(|j| j.node_id == node_id)
460            .cloned()
461            .collect()
462    }
463}
464
465impl Default for TransactionJournal {
466    fn default() -> Self {
467        Self::new()
468    }
469}
470
/// Aggregate snapshot of the journal manager's state.
#[derive(Debug, Clone)]
pub struct JournalStats {
    /// How many transactions currently have a journal.
    pub active_transactions: usize,
    /// Sum of the entry counts of all journals.
    pub total_entries: usize,
    /// Sum of the estimated sizes of all journals, in bytes.
    pub total_size_bytes: usize,
    /// Whether journaling was enabled when the snapshot was taken.
    pub enabled: bool,
}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a journal manager with one transaction already begun at LSN 0.
    async fn journal_with_open_tx() -> (TransactionJournal, Uuid) {
        let manager = TransactionJournal::new();
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let node_id = NodeId::new();

        manager
            .begin_transaction(tx_id, session_id, node_id, 0)
            .await
            .unwrap();

        (manager, tx_id)
    }

    /// Logs one parameterless `INSERT INTO t VALUES (i)` for every `i` in `values`.
    async fn log_inserts(manager: &TransactionJournal, tx_id: Uuid, values: std::ops::Range<i32>) {
        for i in values {
            manager
                .log_statement(
                    tx_id,
                    format!("INSERT INTO t VALUES ({})", i),
                    vec![],
                    None,
                    Some(1),
                    1,
                )
                .await
                .unwrap();
        }
    }

    #[test]
    fn test_statement_type_detection() {
        let cases = [
            ("SELECT * FROM users", StatementType::Select),
            ("INSERT INTO users VALUES (1)", StatementType::Insert),
            ("UPDATE users SET name = 'x'", StatementType::Update),
            ("DELETE FROM users", StatementType::Delete),
            ("CREATE TABLE foo (id INT)", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("SET search_path = public", StatementType::Set),
        ];

        for &(sql, expected) in &cases {
            assert_eq!(StatementType::from_sql(sql), expected);
        }
    }

    #[test]
    fn test_statement_type_properties() {
        // Read-only classification.
        assert!(StatementType::Select.is_read_only());
        assert!(!StatementType::Insert.is_read_only());

        // Mutation classification.
        assert!(StatementType::Insert.is_mutation());
        assert!(StatementType::Update.is_mutation());
        assert!(!StatementType::Select.is_mutation());
    }

    #[tokio::test]
    async fn test_journal_lifecycle() {
        let (manager, tx_id) = journal_with_open_tx().await;

        // A read followed by a parameterised write.
        manager
            .log_statement(
                tx_id,
                "SELECT * FROM users".to_string(),
                vec![],
                Some(12345),
                None,
                10,
            )
            .await
            .unwrap();

        manager
            .log_statement(
                tx_id,
                "INSERT INTO users (name) VALUES ($1)".to_string(),
                vec![JournalValue::Text("test".to_string())],
                None,
                Some(1),
                5,
            )
            .await
            .unwrap();

        // Both statements are journaled and the write marks the journal as mutating.
        let snapshot = manager.get_journal(&tx_id).await.unwrap();
        assert_eq!(snapshot.entries.len(), 2);
        assert!(snapshot.has_mutations);

        // Committing removes the journal.
        manager.commit_transaction(tx_id).await.unwrap();
        assert!(manager.get_journal(&tx_id).await.is_none());
    }

    #[tokio::test]
    async fn test_savepoints() {
        let (manager, tx_id) = journal_with_open_tx().await;

        // Three statements before the savepoint, two after it.
        log_inserts(&manager, tx_id, 0..3).await;
        manager
            .create_savepoint(tx_id, "sp1".to_string())
            .await
            .unwrap();
        log_inserts(&manager, tx_id, 3..5).await;

        let before = manager.get_journal(&tx_id).await.unwrap();
        assert_eq!(before.entries.len(), 5);

        // Rolling back discards the two statements logged after the savepoint.
        manager.rollback_to_savepoint(tx_id, "sp1").await.unwrap();

        let after = manager.get_journal(&tx_id).await.unwrap();
        assert_eq!(after.entries.len(), 3);
    }

    #[tokio::test]
    async fn test_stats() {
        let (manager, tx_id) = journal_with_open_tx().await;

        manager
            .log_statement(tx_id, "SELECT 1".to_string(), vec![], None, None, 1)
            .await
            .unwrap();

        let stats = manager.stats().await;
        assert_eq!(stats.active_transactions, 1);
        assert_eq!(stats.total_entries, 1);
        assert!(stats.enabled);
    }
}