1use super::{NodeId, ProxyError, Result};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12#[derive(Debug, Clone)]
14pub struct JournalEntry {
15 pub sequence: u64,
17 pub statement: String,
19 pub parameters: Vec<JournalValue>,
21 pub result_checksum: Option<u64>,
23 pub rows_affected: Option<u64>,
25 pub timestamp: chrono::DateTime<chrono::Utc>,
27 pub statement_type: StatementType,
29 pub duration_ms: u64,
31}
32
/// A statement parameter value as stored in the journal.
///
/// `PartialEq` is derived so journaled parameters can be compared (for
/// example in tests or when checking a replayed statement against the
/// original). `Eq` is intentionally not derived: `Float64` holds an `f64`,
/// and `NaN != NaN`.
#[derive(Debug, Clone, PartialEq)]
pub enum JournalValue {
    /// SQL `NULL`.
    Null,
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit floating point number.
    Float64(f64),
    /// UTF-8 text.
    Text(String),
    /// Raw bytes.
    Bytes(Vec<u8>),
    /// Ordered list of nested values.
    Array(Vec<JournalValue>),
}
44
/// Coarse classification of a SQL statement, derived from its leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// `SELECT`.
    Select,
    /// `INSERT`.
    Insert,
    /// `UPDATE`.
    Update,
    /// `DELETE`.
    Delete,
    /// Schema-level statements: `CREATE`, `ALTER`, `DROP`, `TRUNCATE`.
    Ddl,
    /// Transaction control: `BEGIN`, `START TRANSACTION`, `COMMIT`, `END`,
    /// `ROLLBACK`, `ABORT`, `SAVEPOINT`, `RELEASE`.
    Transaction,
    /// `SET`.
    Set,
    /// Anything that does not start with one of the keywords above.
    Other,
}

impl StatementType {
    /// Classifies `sql` by its first keyword, case-insensitively.
    ///
    /// Leading whitespace is skipped. The keyword must be a complete word, so
    /// `SELECTED` or `SETTINGS` are `Other` rather than `Select` / `Set`.
    /// Only the keyword itself is upper-cased; the rest of the statement is
    /// never copied.
    ///
    /// Statements that start with anything else (for example a leading
    /// comment, an opening parenthesis or `WITH`) are reported as `Other`.
    pub fn from_sql(sql: &str) -> Self {
        let (keyword, rest) = Self::leading_word(sql);
        let keyword = keyword.to_ascii_uppercase();

        match keyword.as_str() {
            "SELECT" => StatementType::Select,
            "INSERT" => StatementType::Insert,
            "UPDATE" => StatementType::Update,
            "DELETE" => StatementType::Delete,
            // TRUNCATE removes data, so it must count as a mutation.
            "CREATE" | "ALTER" | "DROP" | "TRUNCATE" => StatementType::Ddl,
            "BEGIN" | "COMMIT" | "END" | "ROLLBACK" | "ABORT" | "SAVEPOINT" | "RELEASE" => {
                StatementType::Transaction
            }
            // `START` only opens a transaction when followed by `TRANSACTION`.
            "START" if Self::leading_word(rest).0.eq_ignore_ascii_case("TRANSACTION") => {
                StatementType::Transaction
            }
            "SET" => StatementType::Set,
            _ => StatementType::Other,
        }
    }

    /// Returns `true` only for `Select`.
    pub fn is_read_only(&self) -> bool {
        matches!(self, StatementType::Select)
    }

    /// Returns `true` for `Insert`, `Update`, `Delete` and `Ddl`.
    pub fn is_mutation(&self) -> bool {
        matches!(
            self,
            StatementType::Insert
                | StatementType::Update
                | StatementType::Delete
                | StatementType::Ddl
        )
    }

    /// Splits `text` (after skipping leading whitespace) into its first run of
    /// ASCII alphanumerics/underscores and the remainder. The first part is
    /// empty when the text does not start with such a character.
    fn leading_word(text: &str) -> (&str, &str) {
        let text = text.trim_start();
        let end = text
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(text.len());
        text.split_at(end)
    }
}
112
113#[derive(Debug, Clone)]
115pub struct TransactionJournalEntry {
116 pub tx_id: Uuid,
118 pub session_id: Uuid,
120 pub node_id: NodeId,
122 pub started_at: chrono::DateTime<chrono::Utc>,
124 pub start_lsn: u64,
126 pub entries: Vec<JournalEntry>,
128 pub current_sequence: u64,
130 pub active: bool,
132 pub has_mutations: bool,
134 pub savepoints: Vec<Savepoint>,
136}
137
138#[derive(Debug, Clone)]
140pub struct Savepoint {
141 pub name: String,
143 pub sequence: u64,
145 pub created_at: chrono::DateTime<chrono::Utc>,
147}
148
149impl TransactionJournalEntry {
150 pub fn new(tx_id: Uuid, session_id: Uuid, node_id: NodeId, start_lsn: u64) -> Self {
152 Self {
153 tx_id,
154 session_id,
155 node_id,
156 started_at: chrono::Utc::now(),
157 start_lsn,
158 entries: Vec::new(),
159 current_sequence: 0,
160 active: true,
161 has_mutations: false,
162 savepoints: Vec::new(),
163 }
164 }
165
166 pub fn add_entry(&mut self, entry: JournalEntry) {
168 if entry.statement_type.is_mutation() {
169 self.has_mutations = true;
170 }
171 self.current_sequence = entry.sequence;
172 self.entries.push(entry);
173 }
174
175 pub fn create_savepoint(&mut self, name: String) {
177 self.savepoints.push(Savepoint {
178 name,
179 sequence: self.current_sequence,
180 created_at: chrono::Utc::now(),
181 });
182 }
183
184 pub fn rollback_to_savepoint(&mut self, name: &str) -> Option<u64> {
186 if let Some(idx) = self.savepoints.iter().position(|s| s.name == name) {
187 let savepoint = &self.savepoints[idx];
188 let sequence = savepoint.sequence;
189
190 self.entries.retain(|e| e.sequence <= sequence);
192
193 self.savepoints.truncate(idx + 1);
195
196 Some(sequence)
197 } else {
198 None
199 }
200 }
201
202 pub fn entries_for_replay(&self) -> Vec<&JournalEntry> {
204 self.entries.iter().collect()
205 }
206
207 pub fn mutation_entries(&self) -> Vec<&JournalEntry> {
209 self.entries
210 .iter()
211 .filter(|e| e.statement_type.is_mutation())
212 .collect()
213 }
214
215 pub fn total_size(&self) -> usize {
217 self.entries
218 .iter()
219 .map(|e| e.statement.len() + estimate_params_size(&e.parameters))
220 .sum()
221 }
222}
223
224fn estimate_params_size(params: &[JournalValue]) -> usize {
225 params
226 .iter()
227 .map(|p| match p {
228 JournalValue::Null => 1,
229 JournalValue::Bool(_) => 1,
230 JournalValue::Int64(_) => 8,
231 JournalValue::Float64(_) => 8,
232 JournalValue::Text(s) => s.len(),
233 JournalValue::Bytes(b) => b.len(),
234 JournalValue::Array(a) => estimate_params_size(a),
235 })
236 .sum()
237}
238
239pub struct TransactionJournal {
241 journals: Arc<RwLock<HashMap<Uuid, TransactionJournalEntry>>>,
243 max_entries: usize,
245 max_size: usize,
247 enabled: bool,
249}
250
251impl TransactionJournal {
252 pub fn new() -> Self {
254 Self {
255 journals: Arc::new(RwLock::new(HashMap::new())),
256 max_entries: 10000,
257 max_size: 64 * 1024 * 1024, enabled: true,
259 }
260 }
261
262 pub fn with_max_entries(mut self, max: usize) -> Self {
264 self.max_entries = max;
265 self
266 }
267
268 pub fn with_max_size(mut self, max: usize) -> Self {
270 self.max_size = max;
271 self
272 }
273
274 pub fn set_enabled(&mut self, enabled: bool) {
276 self.enabled = enabled;
277 }
278
279 pub async fn entries_in_window(
289 &self,
290 from: chrono::DateTime<chrono::Utc>,
291 to: chrono::DateTime<chrono::Utc>,
292 ) -> Vec<(Uuid, JournalEntry)> {
293 let journals = self.journals.read().await;
294 let mut out: Vec<(Uuid, JournalEntry)> = Vec::new();
295 for (tx_id, j) in journals.iter() {
296 for entry in &j.entries {
297 if entry.timestamp >= from && entry.timestamp <= to {
298 out.push((*tx_id, entry.clone()));
299 }
300 }
301 }
302 out.sort_by_key(|(_, e)| e.timestamp);
303 out
304 }
305
306 pub async fn begin_transaction(
308 &self,
309 tx_id: Uuid,
310 session_id: Uuid,
311 node_id: NodeId,
312 start_lsn: u64,
313 ) -> Result<()> {
314 if !self.enabled {
315 return Ok(());
316 }
317
318 let journal = TransactionJournalEntry::new(tx_id, session_id, node_id, start_lsn);
319 self.journals.write().await.insert(tx_id, journal);
320
321 tracing::debug!("Started journaling transaction {:?}", tx_id);
322 Ok(())
323 }
324
325 pub async fn log_statement(
327 &self,
328 tx_id: Uuid,
329 statement: String,
330 parameters: Vec<JournalValue>,
331 result_checksum: Option<u64>,
332 rows_affected: Option<u64>,
333 duration_ms: u64,
334 ) -> Result<()> {
335 if !self.enabled {
336 return Ok(());
337 }
338
339 let mut journals = self.journals.write().await;
340 let journal = journals.get_mut(&tx_id).ok_or_else(|| {
341 ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
342 })?;
343
344 if journal.entries.len() >= self.max_entries {
346 return Err(ProxyError::Internal(
347 "Transaction journal entries limit exceeded".to_string(),
348 ));
349 }
350
351 if journal.total_size() >= self.max_size {
352 return Err(ProxyError::Internal(
353 "Transaction journal size limit exceeded".to_string(),
354 ));
355 }
356
357 let sequence = journal.current_sequence + 1;
358 let statement_type = StatementType::from_sql(&statement);
359
360 let entry = JournalEntry {
361 sequence,
362 statement,
363 parameters,
364 result_checksum,
365 rows_affected,
366 timestamp: chrono::Utc::now(),
367 statement_type,
368 duration_ms,
369 };
370
371 journal.add_entry(entry);
372
373 Ok(())
374 }
375
376 pub async fn create_savepoint(&self, tx_id: Uuid, name: String) -> Result<()> {
378 if !self.enabled {
379 return Ok(());
380 }
381
382 let mut journals = self.journals.write().await;
383 let journal = journals.get_mut(&tx_id).ok_or_else(|| {
384 ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
385 })?;
386
387 journal.create_savepoint(name);
388 Ok(())
389 }
390
391 pub async fn rollback_to_savepoint(&self, tx_id: Uuid, name: &str) -> Result<()> {
393 if !self.enabled {
394 return Ok(());
395 }
396
397 let mut journals = self.journals.write().await;
398 let journal = journals.get_mut(&tx_id).ok_or_else(|| {
399 ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
400 })?;
401
402 journal
403 .rollback_to_savepoint(name)
404 .ok_or_else(|| ProxyError::Internal(format!("Savepoint '{}' not found", name)))?;
405
406 Ok(())
407 }
408
409 pub async fn commit_transaction(&self, tx_id: Uuid) -> Result<()> {
411 self.journals.write().await.remove(&tx_id);
412 tracing::debug!("Committed and cleared journal for transaction {:?}", tx_id);
413 Ok(())
414 }
415
416 pub async fn rollback_transaction(&self, tx_id: Uuid) -> Result<()> {
418 self.journals.write().await.remove(&tx_id);
419 tracing::debug!(
420 "Rolled back and cleared journal for transaction {:?}",
421 tx_id
422 );
423 Ok(())
424 }
425
426 pub async fn get_journal(&self, tx_id: &Uuid) -> Option<TransactionJournalEntry> {
428 self.journals.read().await.get(tx_id).cloned()
429 }
430
431 pub async fn active_count(&self) -> usize {
433 self.journals.read().await.len()
434 }
435
436 pub async fn stats(&self) -> JournalStats {
438 let journals = self.journals.read().await;
439 let total_entries: usize = journals.values().map(|j| j.entries.len()).sum();
440 let total_size: usize = journals.values().map(|j| j.total_size()).sum();
441
442 JournalStats {
443 active_transactions: journals.len(),
444 total_entries,
445 total_size_bytes: total_size,
446 enabled: self.enabled,
447 }
448 }
449
450 pub async fn get_all_active(&self) -> Vec<TransactionJournalEntry> {
452 self.journals.read().await.values().cloned().collect()
453 }
454
455 pub async fn get_max_start_lsn(&self) -> Option<u64> {
458 let journals = self.journals.read().await;
459 journals.values().map(|j| j.start_lsn).max()
460 }
461
462 pub async fn get_transactions_for_node(&self, node_id: NodeId) -> Vec<TransactionJournalEntry> {
465 self.journals
466 .read()
467 .await
468 .values()
469 .filter(|j| j.node_id == node_id)
470 .cloned()
471 .collect()
472 }
473}
474
475impl Default for TransactionJournal {
476 fn default() -> Self {
477 Self::new()
478 }
479}
480
/// Point-in-time aggregate figures reported by `TransactionJournal::stats`.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly, for
/// example in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JournalStats {
    /// Number of transactions that had a journal when the snapshot was taken.
    pub active_transactions: usize,
    /// Total number of entries across those journals.
    pub total_entries: usize,
    /// Sum of the journals' estimated sizes, in bytes.
    pub total_size_bytes: usize,
    /// Whether journaling was enabled when the snapshot was taken.
    pub enabled: bool,
}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a journal and begins one transaction on it (start LSN 0).
    async fn journal_with_open_tx() -> (TransactionJournal, Uuid) {
        let journal = TransactionJournal::new();
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let node_id = NodeId::new();

        journal
            .begin_transaction(tx_id, session_id, node_id, 0)
            .await
            .unwrap();

        (journal, tx_id)
    }

    /// Logs one `INSERT INTO t VALUES (i)` statement per value in `values`.
    async fn log_inserts(journal: &TransactionJournal, tx_id: Uuid, values: std::ops::Range<i32>) {
        for i in values {
            journal
                .log_statement(
                    tx_id,
                    format!("INSERT INTO t VALUES ({})", i),
                    vec![],
                    None,
                    Some(1),
                    1,
                )
                .await
                .unwrap();
        }
    }

    /// Number of entries currently journaled for `tx_id`.
    async fn entry_count(journal: &TransactionJournal, tx_id: Uuid) -> usize {
        journal.get_journal(&tx_id).await.unwrap().entries.len()
    }

    #[test]
    fn test_statement_type_detection() {
        let cases = [
            ("SELECT * FROM users", StatementType::Select),
            ("INSERT INTO users VALUES (1)", StatementType::Insert),
            ("UPDATE users SET name = 'x'", StatementType::Update),
            ("DELETE FROM users", StatementType::Delete),
            ("CREATE TABLE foo (id INT)", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("SET search_path = public", StatementType::Set),
        ];

        for (sql, expected) in cases.iter() {
            assert_eq!(StatementType::from_sql(sql), *expected);
        }
    }

    #[test]
    fn test_statement_type_properties() {
        let select = StatementType::Select;
        let insert = StatementType::Insert;
        let update = StatementType::Update;

        assert!(select.is_read_only());
        assert!(!insert.is_read_only());

        assert!(insert.is_mutation());
        assert!(update.is_mutation());
        assert!(!select.is_mutation());
    }

    #[tokio::test]
    async fn test_journal_lifecycle() {
        let (journal, tx_id) = journal_with_open_tx().await;

        // A read followed by a write.
        journal
            .log_statement(
                tx_id,
                "SELECT * FROM users".to_string(),
                vec![],
                Some(12345),
                None,
                10,
            )
            .await
            .unwrap();
        journal
            .log_statement(
                tx_id,
                "INSERT INTO users (name) VALUES ($1)".to_string(),
                vec![JournalValue::Text("test".to_string())],
                None,
                Some(1),
                5,
            )
            .await
            .unwrap();

        let recorded = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(recorded.entries.len(), 2);
        assert!(recorded.has_mutations);

        // Committing drops the journal.
        journal.commit_transaction(tx_id).await.unwrap();
        assert!(journal.get_journal(&tx_id).await.is_none());
    }

    #[tokio::test]
    async fn test_savepoints() {
        let (journal, tx_id) = journal_with_open_tx().await;

        // Three statements before the savepoint, two after it.
        log_inserts(&journal, tx_id, 0..3).await;
        journal
            .create_savepoint(tx_id, "sp1".to_string())
            .await
            .unwrap();
        log_inserts(&journal, tx_id, 3..5).await;

        assert_eq!(entry_count(&journal, tx_id).await, 5);

        // Rolling back keeps only what preceded the savepoint.
        journal.rollback_to_savepoint(tx_id, "sp1").await.unwrap();

        assert_eq!(entry_count(&journal, tx_id).await, 3);
    }

    #[tokio::test]
    async fn test_stats() {
        let (journal, tx_id) = journal_with_open_tx().await;

        journal
            .log_statement(tx_id, "SELECT 1".to_string(), vec![], None, None, 1)
            .await
            .unwrap();

        let snapshot = journal.stats().await;
        assert_eq!(snapshot.active_transactions, 1);
        assert_eq!(snapshot.total_entries, 1);
        assert!(snapshot.enabled);
    }
}