use super::{NodeId, ProxyError, Result};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// One statement recorded in a transaction's journal.
///
/// Entries are built by `TransactionJournal::log_statement`, which fills in
/// `sequence`, `timestamp` and `statement_type`; the remaining fields are
/// stored exactly as the caller supplied them.
#[derive(Debug, Clone)]
pub struct JournalEntry {
/// Position of the statement within its transaction (previous
/// `current_sequence` plus one when logged through `log_statement`).
pub sequence: u64,
/// Statement text as passed to `log_statement`.
pub statement: String,
/// Parameter values recorded alongside the statement.
pub parameters: Vec<JournalValue>,
/// Optional caller-supplied checksum of the statement's result.
pub result_checksum: Option<u64>,
/// Optional caller-supplied count of affected rows.
pub rows_affected: Option<u64>,
/// Wall-clock time (UTC) at which the entry was created.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Classification computed with `StatementType::from_sql` on `statement`.
pub statement_type: StatementType,
/// Caller-supplied duration; the field name indicates milliseconds.
pub duration_ms: u64,
}
/// A parameter value stored with a journaled statement.
///
/// `estimate_params_size` uses the variant to approximate how much space a
/// parameter list occupies.
#[derive(Debug, Clone)]
pub enum JournalValue {
/// No value.
Null,
/// A boolean value.
Bool(bool),
/// A signed 64-bit integer.
Int64(i64),
/// A 64-bit floating point number.
Float64(f64),
/// An owned string.
Text(String),
/// An owned byte buffer.
Bytes(Vec<u8>),
/// A list of nested values (may itself contain arrays).
Array(Vec<JournalValue>),
}
/// Coarse classification of a SQL statement, derived from its leading keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementType {
    /// `SELECT` statements.
    Select,
    /// `INSERT` statements.
    Insert,
    /// `UPDATE` statements.
    Update,
    /// `DELETE` statements.
    Delete,
    /// `CREATE`, `ALTER` and `DROP` statements.
    Ddl,
    /// `BEGIN`, `COMMIT`, `ROLLBACK` and `SAVEPOINT` statements.
    Transaction,
    /// `SET` statements.
    Set,
    /// Any statement whose leading keyword is not recognised.
    Other,
}

impl StatementType {
    /// Classifies `sql` by its first keyword.
    ///
    /// Leading whitespace is skipped and the keyword is compared without
    /// regard to ASCII case. The keyword is the leading run of ASCII
    /// letters, digits and underscores, so it must match a known keyword as
    /// a whole word: `SELECT(1)` is a `Select`, whereas `select_all()` or
    /// `SETTLE ...` are `Other` rather than being mistaken for
    /// `SELECT` / `SET` by a bare prefix match.
    ///
    /// Only the keyword itself is inspected; the statement is never copied or
    /// upper-cased as a whole, so the cost does not grow with statement size.
    pub fn from_sql(sql: &str) -> Self {
        const KEYWORDS: [(&str, StatementType); 12] = [
            ("SELECT", StatementType::Select),
            ("INSERT", StatementType::Insert),
            ("UPDATE", StatementType::Update),
            ("DELETE", StatementType::Delete),
            ("CREATE", StatementType::Ddl),
            ("ALTER", StatementType::Ddl),
            ("DROP", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("COMMIT", StatementType::Transaction),
            ("ROLLBACK", StatementType::Transaction),
            ("SAVEPOINT", StatementType::Transaction),
            ("SET", StatementType::Set),
        ];

        let trimmed = sql.trim_start();
        // The keyword ends at the first character that cannot be part of an
        // identifier; `find` yields a char boundary, so slicing is safe.
        let keyword_len = trimmed
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(trimmed.len());
        let keyword = &trimmed[..keyword_len];

        KEYWORDS
            .iter()
            .find(|(candidate, _)| keyword.eq_ignore_ascii_case(candidate))
            .map_or(StatementType::Other, |&(_, kind)| kind)
    }

    /// Returns `true` only for [`StatementType::Select`].
    pub fn is_read_only(&self) -> bool {
        matches!(self, StatementType::Select)
    }

    /// Returns `true` for `Insert`, `Update`, `Delete` and `Ddl`.
    ///
    /// `Transaction`, `Set` and `Other` are neither read-only nor mutations.
    pub fn is_mutation(&self) -> bool {
        matches!(
            self,
            StatementType::Insert
                | StatementType::Update
                | StatementType::Delete
                | StatementType::Ddl
        )
    }
}
/// The journal kept for a single transaction: identifying metadata plus the
/// ordered list of statements logged so far and any savepoints taken.
#[derive(Debug, Clone)]
pub struct TransactionJournalEntry {
/// Transaction identifier; also the key under which `TransactionJournal`
/// stores this journal.
pub tx_id: Uuid,
/// Identifier of the session the transaction was started with.
pub session_id: Uuid,
/// Node recorded when the transaction began; used by
/// `get_transactions_for_node` to filter journals.
pub node_id: NodeId,
/// Wall-clock time (UTC) at which this journal was created.
pub started_at: chrono::DateTime<chrono::Utc>,
/// LSN supplied by the caller when the transaction began.
pub start_lsn: u64,
/// Logged statements, in the order they were added.
pub entries: Vec<JournalEntry>,
/// Sequence number of the most recently added entry (0 before any entry).
pub current_sequence: u64,
/// Initialised to `true` by `new`; nothing in this file clears it.
pub active: bool,
/// Set to `true` once an entry whose statement type is a mutation is added.
pub has_mutations: bool,
/// Savepoints in creation order.
pub savepoints: Vec<Savepoint>,
}
/// A named marker inside a transaction journal that entries can be rolled
/// back to.
#[derive(Debug, Clone)]
pub struct Savepoint {
/// Name given when the savepoint was created.
pub name: String,
/// Value of the journal's `current_sequence` when the savepoint was taken;
/// rolling back keeps only entries with a sequence at or below this value.
pub sequence: u64,
/// Wall-clock time (UTC) at which the savepoint was created.
pub created_at: chrono::DateTime<chrono::Utc>,
}
impl TransactionJournalEntry {
    /// Creates an empty, active journal for transaction `tx_id`.
    ///
    /// `started_at` is set to the current UTC time, `current_sequence` starts
    /// at 0 and no entries or savepoints are recorded yet.
    pub fn new(tx_id: Uuid, session_id: Uuid, node_id: NodeId, start_lsn: u64) -> Self {
        Self {
            tx_id,
            session_id,
            node_id,
            started_at: chrono::Utc::now(),
            start_lsn,
            entries: Vec::new(),
            current_sequence: 0,
            active: true,
            has_mutations: false,
            savepoints: Vec::new(),
        }
    }

    /// Appends `entry` and makes its sequence number the current one.
    ///
    /// Marks the journal as containing mutations when the entry's statement
    /// type is a mutation.
    pub fn add_entry(&mut self, entry: JournalEntry) {
        if entry.statement_type.is_mutation() {
            self.has_mutations = true;
        }
        self.current_sequence = entry.sequence;
        self.entries.push(entry);
    }

    /// Records a savepoint called `name` at the current sequence number.
    ///
    /// Names are not required to be unique; an existing savepoint with the
    /// same name is kept and the new one is appended after it.
    pub fn create_savepoint(&mut self, name: String) {
        self.savepoints.push(Savepoint {
            name,
            sequence: self.current_sequence,
            created_at: chrono::Utc::now(),
        });
    }

    /// Rolls the journal back to the most recent savepoint called `name`.
    ///
    /// Entries logged after that savepoint are discarded, as are savepoints
    /// created after it; the savepoint itself is kept so it can be rolled
    /// back to again. Returns the savepoint's sequence number, or `None`
    /// (leaving the journal untouched) when no savepoint has that name.
    ///
    /// `current_sequence` and `has_mutations` are left as they were.
    // NOTE(review): `has_mutations` may now be stale if every mutation was
    // discarded — confirm whether callers want it recomputed.
    pub fn rollback_to_savepoint(&mut self, name: &str) -> Option<u64> {
        // Search from the newest savepoint: when a name is reused, the most
        // recently established savepoint is the one a rollback refers to.
        // Matching the oldest one would throw away entries and savepoints
        // that are still valid.
        let idx = self.savepoints.iter().rposition(|s| s.name == name)?;
        let sequence = self.savepoints[idx].sequence;
        self.entries.retain(|e| e.sequence <= sequence);
        self.savepoints.truncate(idx + 1);
        Some(sequence)
    }

    /// Returns references to every entry, in the order they were added.
    pub fn entries_for_replay(&self) -> Vec<&JournalEntry> {
        self.entries.iter().collect()
    }

    /// Returns references to the entries whose statement type is a mutation,
    /// in the order they were added.
    pub fn mutation_entries(&self) -> Vec<&JournalEntry> {
        self.entries
            .iter()
            .filter(|e| e.statement_type.is_mutation())
            .collect()
    }

    /// Approximate size of the journal: for each entry, the statement length
    /// in bytes plus the estimated size of its parameters.
    ///
    /// This walks every entry on each call.
    pub fn total_size(&self) -> usize {
        self.entries
            .iter()
            .map(|e| e.statement.len() + estimate_params_size(&e.parameters))
            .sum()
    }
}
/// Approximates how much space a parameter list occupies.
///
/// Each value contributes a fixed cost (1 for `Null`/`Bool`, 8 for
/// `Int64`/`Float64`) or the byte length of its payload (`Text`, `Bytes`).
/// An `Array` contributes the recursive total of its elements and nothing
/// for the container itself.
fn estimate_params_size(params: &[JournalValue]) -> usize {
    let mut total = 0usize;
    for value in params {
        total += match value {
            JournalValue::Null | JournalValue::Bool(_) => 1,
            JournalValue::Int64(_) | JournalValue::Float64(_) => 8,
            JournalValue::Text(text) => text.len(),
            JournalValue::Bytes(bytes) => bytes.len(),
            JournalValue::Array(items) => estimate_params_size(items),
        };
    }
    total
}
/// Registry of per-transaction journals, keyed by transaction id.
///
/// The map sits behind an `Arc<RwLock<..>>`, so the async methods take `&self`
/// and acquire the lock for reading or writing as needed.
pub struct TransactionJournal {
// One journal per transaction currently being tracked.
journals: Arc<RwLock<HashMap<Uuid, TransactionJournalEntry>>>,
// `log_statement` fails once a journal already holds this many entries.
max_entries: usize,
// `log_statement` fails once a journal's `total_size()` reaches this value.
max_size: usize,
// When false, begin/log/savepoint calls return `Ok(())` without recording.
enabled: bool,
}
impl TransactionJournal {
    /// Builds an enabled journal that tracks no transactions yet, allowing up
    /// to 10000 entries and `64 * 1024 * 1024` estimated bytes per
    /// transaction.
    pub fn new() -> Self {
        let journals = Arc::new(RwLock::new(HashMap::new()));
        Self {
            journals,
            max_entries: 10000,
            max_size: 64 * 1024 * 1024,
            enabled: true,
        }
    }

    /// Builder-style setter for the per-transaction entry limit.
    pub fn with_max_entries(mut self, max: usize) -> Self {
        self.max_entries = max;
        self
    }

    /// Builder-style setter for the per-transaction size limit.
    pub fn with_max_size(mut self, max: usize) -> Self {
        self.max_size = max;
        self
    }

    /// Turns journaling on or off for subsequent calls.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Collects a copy of every entry, across all tracked transactions, whose
    /// timestamp lies in the inclusive range `from..=to`.
    ///
    /// Each entry is paired with the id of its transaction and the result is
    /// ordered by timestamp.
    pub async fn entries_in_window(
        &self,
        from: chrono::DateTime<chrono::Utc>,
        to: chrono::DateTime<chrono::Utc>,
    ) -> Vec<(Uuid, JournalEntry)> {
        let guard = self.journals.read().await;
        let mut matches: Vec<(Uuid, JournalEntry)> = Vec::new();
        for (tx_id, journal) in guard.iter() {
            matches.extend(
                journal
                    .entries
                    .iter()
                    .filter(|entry| from <= entry.timestamp && entry.timestamp <= to)
                    .map(|entry| (*tx_id, entry.clone())),
            );
        }
        matches.sort_by_key(|(_, entry)| entry.timestamp);
        matches
    }

    /// Starts a fresh journal for `tx_id`, replacing any journal already
    /// stored under that id. Does nothing when journaling is disabled.
    pub async fn begin_transaction(
        &self,
        tx_id: Uuid,
        session_id: Uuid,
        node_id: NodeId,
        start_lsn: u64,
    ) -> Result<()> {
        if self.enabled {
            let fresh = TransactionJournalEntry::new(tx_id, session_id, node_id, start_lsn);
            self.journals.write().await.insert(tx_id, fresh);
            tracing::debug!("Started journaling transaction {:?}", tx_id);
        }
        Ok(())
    }

    /// Appends a statement to the journal of `tx_id`.
    ///
    /// The entry receives the next sequence number, the current UTC time and
    /// a statement type derived from the SQL text. Does nothing when
    /// journaling is disabled.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Internal` when no journal exists for `tx_id`, or
    /// when the journal has already reached the entry or size limit (both are
    /// checked before the new entry is added).
    pub async fn log_statement(
        &self,
        tx_id: Uuid,
        statement: String,
        parameters: Vec<JournalValue>,
        result_checksum: Option<u64>,
        rows_affected: Option<u64>,
        duration_ms: u64,
    ) -> Result<()> {
        if !self.enabled {
            return Ok(());
        }

        let mut guard = self.journals.write().await;
        let journal = match guard.get_mut(&tx_id) {
            Some(found) => found,
            None => {
                return Err(ProxyError::Internal(format!(
                    "No journal for transaction {:?}",
                    tx_id
                )));
            }
        };

        if journal.entries.len() >= self.max_entries {
            return Err(ProxyError::Internal(
                "Transaction journal entries limit exceeded".to_string(),
            ));
        }
        if journal.total_size() >= self.max_size {
            return Err(ProxyError::Internal(
                "Transaction journal size limit exceeded".to_string(),
            ));
        }

        let next_sequence = journal.current_sequence + 1;
        let kind = StatementType::from_sql(&statement);
        journal.add_entry(JournalEntry {
            sequence: next_sequence,
            statement,
            parameters,
            result_checksum,
            rows_affected,
            timestamp: chrono::Utc::now(),
            statement_type: kind,
            duration_ms,
        });
        Ok(())
    }

    /// Records a savepoint called `name` in the journal of `tx_id`.
    /// Does nothing when journaling is disabled.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Internal` when no journal exists for `tx_id`.
    pub async fn create_savepoint(&self, tx_id: Uuid, name: String) -> Result<()> {
        if !self.enabled {
            return Ok(());
        }

        let mut guard = self.journals.write().await;
        match guard.get_mut(&tx_id) {
            Some(journal) => {
                journal.create_savepoint(name);
                Ok(())
            }
            None => Err(ProxyError::Internal(format!(
                "No journal for transaction {:?}",
                tx_id
            ))),
        }
    }

    /// Rolls the journal of `tx_id` back to the savepoint called `name`.
    /// Does nothing when journaling is disabled.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Internal` when no journal exists for `tx_id` or
    /// the journal has no savepoint with that name.
    pub async fn rollback_to_savepoint(&self, tx_id: Uuid, name: &str) -> Result<()> {
        if !self.enabled {
            return Ok(());
        }

        let mut guard = self.journals.write().await;
        let journal = match guard.get_mut(&tx_id) {
            Some(found) => found,
            None => {
                return Err(ProxyError::Internal(format!(
                    "No journal for transaction {:?}",
                    tx_id
                )));
            }
        };

        match journal.rollback_to_savepoint(name) {
            Some(_) => Ok(()),
            None => Err(ProxyError::Internal(format!(
                "Savepoint '{}' not found",
                name
            ))),
        }
    }

    /// Drops the journal of `tx_id` after a commit. Succeeds whether or not a
    /// journal was stored, and regardless of the enabled flag.
    pub async fn commit_transaction(&self, tx_id: Uuid) -> Result<()> {
        let _removed = self.journals.write().await.remove(&tx_id);
        tracing::debug!("Committed and cleared journal for transaction {:?}", tx_id);
        Ok(())
    }

    /// Drops the journal of `tx_id` after a rollback. Succeeds whether or not
    /// a journal was stored, and regardless of the enabled flag.
    pub async fn rollback_transaction(&self, tx_id: Uuid) -> Result<()> {
        let _removed = self.journals.write().await.remove(&tx_id);
        tracing::debug!("Rolled back and cleared journal for transaction {:?}", tx_id);
        Ok(())
    }

    /// Returns a copy of the journal stored for `tx_id`, if there is one.
    pub async fn get_journal(&self, tx_id: &Uuid) -> Option<TransactionJournalEntry> {
        let guard = self.journals.read().await;
        guard.get(tx_id).cloned()
    }

    /// Number of transactions that currently have a journal.
    pub async fn active_count(&self) -> usize {
        let guard = self.journals.read().await;
        guard.len()
    }

    /// Takes a snapshot of aggregate figures: tracked transactions, total
    /// entries, total estimated size and the enabled flag.
    pub async fn stats(&self) -> JournalStats {
        let guard = self.journals.read().await;

        let mut total_entries = 0usize;
        let mut total_size_bytes = 0usize;
        for journal in guard.values() {
            total_entries += journal.entries.len();
            total_size_bytes += journal.total_size();
        }

        JournalStats {
            active_transactions: guard.len(),
            total_entries,
            total_size_bytes,
            enabled: self.enabled,
        }
    }

    /// Returns copies of all stored journals, in no particular order.
    pub async fn get_all_active(&self) -> Vec<TransactionJournalEntry> {
        let guard = self.journals.read().await;
        guard.values().cloned().collect()
    }

    /// Largest `start_lsn` among the stored journals, or `None` when no
    /// transaction is tracked.
    pub async fn get_max_start_lsn(&self) -> Option<u64> {
        let guard = self.journals.read().await;
        guard
            .values()
            .max_by_key(|journal| journal.start_lsn)
            .map(|journal| journal.start_lsn)
    }

    /// Returns copies of the journals whose `node_id` equals `node_id`.
    pub async fn get_transactions_for_node(&self, node_id: NodeId) -> Vec<TransactionJournalEntry> {
        let guard = self.journals.read().await;
        let mut owned = Vec::new();
        for journal in guard.values() {
            if journal.node_id == node_id {
                owned.push(journal.clone());
            }
        }
        owned
    }
}
impl Default for TransactionJournal {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time aggregate figures produced by `TransactionJournal::stats`.
#[derive(Debug, Clone)]
pub struct JournalStats {
/// Number of transactions that had a journal when the snapshot was taken.
pub active_transactions: usize,
/// Sum of the entry counts of all journals.
pub total_entries: usize,
/// Sum of `total_size()` over all journals (an estimate, not an exact
/// memory footprint).
pub total_size_bytes: usize,
/// Whether journaling was enabled when the snapshot was taken.
pub enabled: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a journal with a single transaction already begun at LSN 0 and
    /// returns it together with the transaction id.
    async fn journal_with_transaction() -> (TransactionJournal, Uuid) {
        let journal = TransactionJournal::new();
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let node_id = NodeId::new();
        journal
            .begin_transaction(tx_id, session_id, node_id, 0)
            .await
            .unwrap();
        (journal, tx_id)
    }

    /// Logs one single-row insert into `t` for every value in `values`.
    async fn log_inserts(journal: &TransactionJournal, tx_id: Uuid, values: std::ops::Range<i32>) {
        for value in values {
            journal
                .log_statement(
                    tx_id,
                    format!("INSERT INTO t VALUES ({})", value),
                    Vec::new(),
                    None,
                    Some(1),
                    1,
                )
                .await
                .unwrap();
        }
    }

    // Each recognised leading keyword maps to the expected statement type.
    #[test]
    fn test_statement_type_detection() {
        let cases = [
            ("SELECT * FROM users", StatementType::Select),
            ("INSERT INTO users VALUES (1)", StatementType::Insert),
            ("UPDATE users SET name = 'x'", StatementType::Update),
            ("DELETE FROM users", StatementType::Delete),
            ("CREATE TABLE foo (id INT)", StatementType::Ddl),
            ("BEGIN", StatementType::Transaction),
            ("SET search_path = public", StatementType::Set),
        ];
        for (sql, expected) in cases.iter() {
            assert_eq!(StatementType::from_sql(sql), *expected);
        }
    }

    // Read-only and mutation predicates agree with the variant.
    #[test]
    fn test_statement_type_properties() {
        let select = StatementType::Select;
        let insert = StatementType::Insert;
        let update = StatementType::Update;

        assert!(select.is_read_only());
        assert!(!insert.is_read_only());
        assert!(insert.is_mutation());
        assert!(update.is_mutation());
        assert!(!select.is_mutation());
    }

    // Begin, log a read and a write, inspect the journal, then commit.
    #[tokio::test]
    async fn test_journal_lifecycle() {
        let (journal, tx_id) = journal_with_transaction().await;

        journal
            .log_statement(
                tx_id,
                "SELECT * FROM users".to_string(),
                Vec::new(),
                Some(12345),
                None,
                10,
            )
            .await
            .unwrap();
        journal
            .log_statement(
                tx_id,
                "INSERT INTO users (name) VALUES ($1)".to_string(),
                vec![JournalValue::Text("test".to_string())],
                None,
                Some(1),
                5,
            )
            .await
            .unwrap();

        let snapshot = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(snapshot.entries.len(), 2);
        assert!(snapshot.has_mutations);

        journal.commit_transaction(tx_id).await.unwrap();
        assert!(journal.get_journal(&tx_id).await.is_none());
    }

    // Rolling back to a savepoint discards only the entries logged after it.
    #[tokio::test]
    async fn test_savepoints() {
        let (journal, tx_id) = journal_with_transaction().await;

        log_inserts(&journal, tx_id, 0..3).await;
        journal
            .create_savepoint(tx_id, "sp1".to_string())
            .await
            .unwrap();
        log_inserts(&journal, tx_id, 3..5).await;

        let before = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(before.entries.len(), 5);

        journal.rollback_to_savepoint(tx_id, "sp1").await.unwrap();

        let after = journal.get_journal(&tx_id).await.unwrap();
        assert_eq!(after.entries.len(), 3);
    }

    // Stats reflect one tracked transaction holding one entry.
    #[tokio::test]
    async fn test_stats() {
        let (journal, tx_id) = journal_with_transaction().await;

        journal
            .log_statement(tx_id, "SELECT 1".to_string(), Vec::new(), None, None, 1)
            .await
            .unwrap();

        let stats = journal.stats().await;
        assert_eq!(stats.active_transactions, 1);
        assert_eq!(stats.total_entries, 1);
        assert!(stats.enabled);
    }
}