1use super::{NodeId, ProxyError, Result};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12#[derive(Debug, Clone)]
14pub struct JournalEntry {
15 pub sequence: u64,
17 pub statement: String,
19 pub parameters: Vec<JournalValue>,
21 pub result_checksum: Option<u64>,
23 pub rows_affected: Option<u64>,
25 pub timestamp: chrono::DateTime<chrono::Utc>,
27 pub statement_type: StatementType,
29 pub duration_ms: u64,
31}
32
/// A parameter value captured alongside a journaled statement.
#[derive(Clone, Debug)]
pub enum JournalValue {
    /// Absence of a value.
    Null,
    /// Boolean value.
    Bool(bool),
    /// Signed 64-bit integer.
    Int64(i64),
    /// 64-bit floating point number.
    Float64(f64),
    /// UTF-8 text.
    Text(String),
    /// Raw bytes.
    Bytes(Vec<u8>),
    /// Ordered list of nested values.
    Array(Vec<JournalValue>),
}
44
/// Coarse category of a SQL statement, decided from its leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// `SELECT`.
    Select,
    /// `INSERT`.
    Insert,
    /// `UPDATE`.
    Update,
    /// `DELETE`.
    Delete,
    /// `CREATE`, `ALTER`, `DROP` or `TRUNCATE`.
    Ddl,
    /// `BEGIN`, `COMMIT`, `ROLLBACK` or `SAVEPOINT`.
    Transaction,
    /// `SET`.
    Set,
    /// Anything that does not start with one of the keywords above.
    Other,
}

impl StatementType {
    /// Classifies `sql` by its first keyword, ignoring leading whitespace and
    /// ASCII case.
    ///
    /// Only the leading identifier-like token (ASCII letters, digits and `_`)
    /// is inspected, so the statement text is never copied and a longer word
    /// that merely begins with a keyword (e.g. `SETTINGS`) is not mistaken for
    /// it. Statements that start with anything else — including an empty
    /// string, a parenthesis or a comment — are reported as
    /// [`StatementType::Other`].
    pub fn from_sql(sql: &str) -> Self {
        // Keyword table; `TRUNCATE` is grouped with DDL so that it counts as a
        // mutation in `is_mutation`.
        const KEYWORDS: &[(&str, StatementType)] = &[
            ("SELECT", StatementType::Select),
            ("INSERT", StatementType::Insert),
            ("UPDATE", StatementType::Update),
            ("DELETE", StatementType::Delete),
            ("CREATE", StatementType::Ddl),
            ("ALTER", StatementType::Ddl),
            ("DROP", StatementType::Ddl),
            ("TRUNCATE", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("COMMIT", StatementType::Transaction),
            ("ROLLBACK", StatementType::Transaction),
            ("SAVEPOINT", StatementType::Transaction),
            ("SET", StatementType::Set),
        ];

        let trimmed = sql.trim_start();
        // The token ends at the first character that cannot be part of a word.
        let end = trimmed
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(trimmed.len());
        let keyword = &trimmed[..end];

        KEYWORDS
            .iter()
            .find(|&&(kw, _)| keyword.eq_ignore_ascii_case(kw))
            .map_or(StatementType::Other, |&(_, ty)| ty)
    }

    /// Returns `true` only for [`StatementType::Select`].
    pub fn is_read_only(&self) -> bool {
        matches!(self, StatementType::Select)
    }

    /// Returns `true` for statements that change data or schema:
    /// `Insert`, `Update`, `Delete` and `Ddl`.
    pub fn is_mutation(&self) -> bool {
        matches!(
            self,
            StatementType::Insert
                | StatementType::Update
                | StatementType::Delete
                | StatementType::Ddl
        )
    }
}
109
110#[derive(Debug, Clone)]
112pub struct TransactionJournalEntry {
113 pub tx_id: Uuid,
115 pub session_id: Uuid,
117 pub node_id: NodeId,
119 pub started_at: chrono::DateTime<chrono::Utc>,
121 pub start_lsn: u64,
123 pub entries: Vec<JournalEntry>,
125 pub current_sequence: u64,
127 pub active: bool,
129 pub has_mutations: bool,
131 pub savepoints: Vec<Savepoint>,
133}
134
135#[derive(Debug, Clone)]
137pub struct Savepoint {
138 pub name: String,
140 pub sequence: u64,
142 pub created_at: chrono::DateTime<chrono::Utc>,
144}
145
146impl TransactionJournalEntry {
147 pub fn new(tx_id: Uuid, session_id: Uuid, node_id: NodeId, start_lsn: u64) -> Self {
149 Self {
150 tx_id,
151 session_id,
152 node_id,
153 started_at: chrono::Utc::now(),
154 start_lsn,
155 entries: Vec::new(),
156 current_sequence: 0,
157 active: true,
158 has_mutations: false,
159 savepoints: Vec::new(),
160 }
161 }
162
163 pub fn add_entry(&mut self, entry: JournalEntry) {
165 if entry.statement_type.is_mutation() {
166 self.has_mutations = true;
167 }
168 self.current_sequence = entry.sequence;
169 self.entries.push(entry);
170 }
171
172 pub fn create_savepoint(&mut self, name: String) {
174 self.savepoints.push(Savepoint {
175 name,
176 sequence: self.current_sequence,
177 created_at: chrono::Utc::now(),
178 });
179 }
180
181 pub fn rollback_to_savepoint(&mut self, name: &str) -> Option<u64> {
183 if let Some(idx) = self.savepoints.iter().position(|s| s.name == name) {
184 let savepoint = &self.savepoints[idx];
185 let sequence = savepoint.sequence;
186
187 self.entries.retain(|e| e.sequence <= sequence);
189
190 self.savepoints.truncate(idx + 1);
192
193 Some(sequence)
194 } else {
195 None
196 }
197 }
198
199 pub fn entries_for_replay(&self) -> Vec<&JournalEntry> {
201 self.entries.iter().collect()
202 }
203
204 pub fn mutation_entries(&self) -> Vec<&JournalEntry> {
206 self.entries
207 .iter()
208 .filter(|e| e.statement_type.is_mutation())
209 .collect()
210 }
211
212 pub fn total_size(&self) -> usize {
214 self.entries
215 .iter()
216 .map(|e| e.statement.len() + estimate_params_size(&e.parameters))
217 .sum()
218 }
219}
220
221fn estimate_params_size(params: &[JournalValue]) -> usize {
222 params
223 .iter()
224 .map(|p| match p {
225 JournalValue::Null => 1,
226 JournalValue::Bool(_) => 1,
227 JournalValue::Int64(_) => 8,
228 JournalValue::Float64(_) => 8,
229 JournalValue::Text(s) => s.len(),
230 JournalValue::Bytes(b) => b.len(),
231 JournalValue::Array(a) => estimate_params_size(a),
232 })
233 .sum()
234}
235
236pub struct TransactionJournal {
238 journals: Arc<RwLock<HashMap<Uuid, TransactionJournalEntry>>>,
240 max_entries: usize,
242 max_size: usize,
244 enabled: bool,
246}
247
248impl TransactionJournal {
249 pub fn new() -> Self {
251 Self {
252 journals: Arc::new(RwLock::new(HashMap::new())),
253 max_entries: 10000,
254 max_size: 64 * 1024 * 1024, enabled: true,
256 }
257 }
258
259 pub fn with_max_entries(mut self, max: usize) -> Self {
261 self.max_entries = max;
262 self
263 }
264
265 pub fn with_max_size(mut self, max: usize) -> Self {
267 self.max_size = max;
268 self
269 }
270
271 pub fn set_enabled(&mut self, enabled: bool) {
273 self.enabled = enabled;
274 }
275
276 pub async fn entries_in_window(
286 &self,
287 from: chrono::DateTime<chrono::Utc>,
288 to: chrono::DateTime<chrono::Utc>,
289 ) -> Vec<(Uuid, JournalEntry)> {
290 let journals = self.journals.read().await;
291 let mut out: Vec<(Uuid, JournalEntry)> = Vec::new();
292 for (tx_id, j) in journals.iter() {
293 for entry in &j.entries {
294 if entry.timestamp >= from && entry.timestamp <= to {
295 out.push((*tx_id, entry.clone()));
296 }
297 }
298 }
299 out.sort_by_key(|(_, e)| e.timestamp);
300 out
301 }
302
303 pub async fn begin_transaction(
305 &self,
306 tx_id: Uuid,
307 session_id: Uuid,
308 node_id: NodeId,
309 start_lsn: u64,
310 ) -> Result<()> {
311 if !self.enabled {
312 return Ok(());
313 }
314
315 let journal = TransactionJournalEntry::new(tx_id, session_id, node_id, start_lsn);
316 self.journals.write().await.insert(tx_id, journal);
317
318 tracing::debug!("Started journaling transaction {:?}", tx_id);
319 Ok(())
320 }
321
322 pub async fn log_statement(
324 &self,
325 tx_id: Uuid,
326 statement: String,
327 parameters: Vec<JournalValue>,
328 result_checksum: Option<u64>,
329 rows_affected: Option<u64>,
330 duration_ms: u64,
331 ) -> Result<()> {
332 if !self.enabled {
333 return Ok(());
334 }
335
336 let mut journals = self.journals.write().await;
337 let journal = journals.get_mut(&tx_id).ok_or_else(|| {
338 ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
339 })?;
340
341 if journal.entries.len() >= self.max_entries {
343 return Err(ProxyError::Internal("Transaction journal entries limit exceeded".to_string()));
344 }
345
346 if journal.total_size() >= self.max_size {
347 return Err(ProxyError::Internal("Transaction journal size limit exceeded".to_string()));
348 }
349
350 let sequence = journal.current_sequence + 1;
351 let statement_type = StatementType::from_sql(&statement);
352
353 let entry = JournalEntry {
354 sequence,
355 statement,
356 parameters,
357 result_checksum,
358 rows_affected,
359 timestamp: chrono::Utc::now(),
360 statement_type,
361 duration_ms,
362 };
363
364 journal.add_entry(entry);
365
366 Ok(())
367 }
368
369 pub async fn create_savepoint(&self, tx_id: Uuid, name: String) -> Result<()> {
371 if !self.enabled {
372 return Ok(());
373 }
374
375 let mut journals = self.journals.write().await;
376 let journal = journals.get_mut(&tx_id).ok_or_else(|| {
377 ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
378 })?;
379
380 journal.create_savepoint(name);
381 Ok(())
382 }
383
384 pub async fn rollback_to_savepoint(&self, tx_id: Uuid, name: &str) -> Result<()> {
386 if !self.enabled {
387 return Ok(());
388 }
389
390 let mut journals = self.journals.write().await;
391 let journal = journals.get_mut(&tx_id).ok_or_else(|| {
392 ProxyError::Internal(format!("No journal for transaction {:?}", tx_id))
393 })?;
394
395 journal
396 .rollback_to_savepoint(name)
397 .ok_or_else(|| ProxyError::Internal(format!("Savepoint '{}' not found", name)))?;
398
399 Ok(())
400 }
401
402 pub async fn commit_transaction(&self, tx_id: Uuid) -> Result<()> {
404 self.journals.write().await.remove(&tx_id);
405 tracing::debug!("Committed and cleared journal for transaction {:?}", tx_id);
406 Ok(())
407 }
408
409 pub async fn rollback_transaction(&self, tx_id: Uuid) -> Result<()> {
411 self.journals.write().await.remove(&tx_id);
412 tracing::debug!("Rolled back and cleared journal for transaction {:?}", tx_id);
413 Ok(())
414 }
415
416 pub async fn get_journal(&self, tx_id: &Uuid) -> Option<TransactionJournalEntry> {
418 self.journals.read().await.get(tx_id).cloned()
419 }
420
421 pub async fn active_count(&self) -> usize {
423 self.journals.read().await.len()
424 }
425
426 pub async fn stats(&self) -> JournalStats {
428 let journals = self.journals.read().await;
429 let total_entries: usize = journals.values().map(|j| j.entries.len()).sum();
430 let total_size: usize = journals.values().map(|j| j.total_size()).sum();
431
432 JournalStats {
433 active_transactions: journals.len(),
434 total_entries,
435 total_size_bytes: total_size,
436 enabled: self.enabled,
437 }
438 }
439
440 pub async fn get_all_active(&self) -> Vec<TransactionJournalEntry> {
442 self.journals.read().await.values().cloned().collect()
443 }
444
445 pub async fn get_max_start_lsn(&self) -> Option<u64> {
448 let journals = self.journals.read().await;
449 journals.values().map(|j| j.start_lsn).max()
450 }
451
452 pub async fn get_transactions_for_node(&self, node_id: NodeId) -> Vec<TransactionJournalEntry> {
455 self.journals
456 .read()
457 .await
458 .values()
459 .filter(|j| j.node_id == node_id)
460 .cloned()
461 .collect()
462 }
463}
464
465impl Default for TransactionJournal {
466 fn default() -> Self {
467 Self::new()
468 }
469}
470
/// Aggregate counters describing the journal registry at one point in time.
#[derive(Clone, Debug)]
pub struct JournalStats {
    /// Number of transactions that have a journal.
    pub active_transactions: usize,
    /// Entries summed over all journals.
    pub total_entries: usize,
    /// Estimated size in bytes summed over all journals.
    pub total_size_bytes: usize,
    /// Whether recording was enabled when the snapshot was taken.
    pub enabled: bool,
}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Begins a journaled transaction at LSN 0 with fresh ids and returns its id.
    async fn open_transaction(journal: &TransactionJournal) -> Uuid {
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let node_id = NodeId::new();
        journal
            .begin_transaction(tx_id, session_id, node_id, 0)
            .await
            .unwrap();
        tx_id
    }

    /// Logs one single-row `INSERT INTO t` per value in `values`.
    async fn log_inserts(journal: &TransactionJournal, tx_id: Uuid, values: std::ops::Range<i32>) {
        for i in values {
            journal
                .log_statement(
                    tx_id,
                    format!("INSERT INTO t VALUES ({})", i),
                    vec![],
                    None,
                    Some(1),
                    1,
                )
                .await
                .unwrap();
        }
    }

    #[test]
    fn test_statement_type_detection() {
        let cases = [
            ("SELECT * FROM users", StatementType::Select),
            ("INSERT INTO users VALUES (1)", StatementType::Insert),
            ("UPDATE users SET name = 'x'", StatementType::Update),
            ("DELETE FROM users", StatementType::Delete),
            ("CREATE TABLE foo (id INT)", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("SET search_path = public", StatementType::Set),
        ];
        for (sql, expected) in cases {
            assert_eq!(StatementType::from_sql(sql), expected);
        }
    }

    #[test]
    fn test_statement_type_properties() {
        assert!(StatementType::Select.is_read_only());
        assert!(!StatementType::Insert.is_read_only());

        assert!(StatementType::Insert.is_mutation());
        assert!(StatementType::Update.is_mutation());
        assert!(!StatementType::Select.is_mutation());
    }

    #[tokio::test]
    async fn test_journal_lifecycle() {
        let journal = TransactionJournal::new();
        let tx_id = open_transaction(&journal).await;

        // A read followed by a write.
        journal
            .log_statement(
                tx_id,
                "SELECT * FROM users".to_string(),
                vec![],
                Some(12345),
                None,
                10,
            )
            .await
            .unwrap();
        journal
            .log_statement(
                tx_id,
                "INSERT INTO users (name) VALUES ($1)".to_string(),
                vec![JournalValue::Text("test".to_string())],
                None,
                Some(1),
                5,
            )
            .await
            .unwrap();

        let recorded = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(recorded.entries.len(), 2);
        assert!(recorded.has_mutations);

        // Committing removes the journal.
        journal.commit_transaction(tx_id).await.unwrap();
        assert!(journal.get_journal(&tx_id).await.is_none());
    }

    #[tokio::test]
    async fn test_savepoints() {
        let journal = TransactionJournal::new();
        let tx_id = open_transaction(&journal).await;

        // Three statements before the savepoint, two after it.
        log_inserts(&journal, tx_id, 0..3).await;
        journal
            .create_savepoint(tx_id, "sp1".to_string())
            .await
            .unwrap();
        log_inserts(&journal, tx_id, 3..5).await;

        let before = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(before.entries.len(), 5);

        // Rolling back drops only what came after the savepoint.
        journal.rollback_to_savepoint(tx_id, "sp1").await.unwrap();

        let after = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(after.entries.len(), 3);
    }

    #[tokio::test]
    async fn test_stats() {
        let journal = TransactionJournal::new();
        let tx_id = open_transaction(&journal).await;

        journal
            .log_statement(tx_id, "SELECT 1".to_string(), vec![], None, None, 1)
            .await
            .unwrap();

        let stats = journal.stats().await;
        assert_eq!(stats.active_transactions, 1);
        assert_eq!(stats.total_entries, 1);
        assert!(stats.enabled);
    }
}