use crate::backend::{BackendClient, BackendConfig, ParamValue};
use crate::transaction_journal::{JournalEntry, JournalValue, TransactionJournal};
use crate::{ProxyError, Result};
use chrono::{DateTime, Utc};
use std::sync::Arc;
use uuid::Uuid;
#[derive(Debug, Clone)]
pub struct TimeTravelRequest {
pub from: DateTime<Utc>,
pub to: DateTime<Utc>,
pub target_host: String,
pub target_port: u16,
pub target_user: Option<String>,
pub target_password: Option<String>,
pub target_database: Option<String>,
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct ReplaySummary {
pub statements_replayed: u64,
pub failures: u64,
pub elapsed_ms: u64,
#[serde(with = "chrono::serde::ts_seconds")]
pub from: DateTime<Utc>,
#[serde(with = "chrono::serde::ts_seconds")]
pub to: DateTime<Utc>,
pub first_error: Option<String>,
}
pub struct ReplayEngine {
journal: Arc<TransactionJournal>,
backend_template: BackendConfig,
}
impl ReplayEngine {
pub fn new(journal: Arc<TransactionJournal>, backend_template: BackendConfig) -> Self {
Self {
journal,
backend_template,
}
}
pub async fn replay_window(
&self,
req: &TimeTravelRequest,
) -> Result<ReplaySummary> {
if req.from > req.to {
return Err(ProxyError::Internal(
"replay window: from > to".to_string(),
));
}
let entries = self.journal.entries_in_window(req.from, req.to).await;
let total = entries.len();
tracing::info!(
total_entries = total,
from = %req.from,
to = %req.to,
target = %format!("{}:{}", req.target_host, req.target_port),
"starting time-travel replay"
);
let mut cfg = self.backend_template.clone();
cfg.host = req.target_host.clone();
cfg.port = req.target_port;
if let Some(ref u) = req.target_user {
cfg.user = u.clone();
}
if let Some(ref p) = req.target_password {
cfg.password = Some(p.clone());
}
if let Some(ref d) = req.target_database {
cfg.database = Some(d.clone());
}
let start = std::time::Instant::now();
let mut client = BackendClient::connect(&cfg).await.map_err(|e| {
ProxyError::ReplayFailed(format!("connect to target: {}", e))
})?;
let mut statements_replayed: u64 = 0;
let mut failures: u64 = 0;
let mut first_error: Option<String> = None;
for (tx_id, entry) in entries {
let params: Vec<ParamValue> =
entry.parameters.iter().map(journal_value_to_param).collect();
let outcome = if params.is_empty() {
client.simple_query(&entry.statement).await
} else {
client.query_with_params(&entry.statement, ¶ms).await
};
match outcome {
Ok(_) => {
statements_replayed += 1;
}
Err(e) => {
failures += 1;
if first_error.is_none() {
first_error = Some(format!(
"tx {} seq {}: {}",
tx_id, entry.sequence, e
));
}
tracing::warn!(
tx = %tx_id,
sequence = entry.sequence,
error = %e,
"replay statement failed"
);
}
}
}
client.close().await;
Ok(ReplaySummary {
statements_replayed,
failures,
elapsed_ms: start.elapsed().as_millis() as u64,
from: req.from,
to: req.to,
first_error,
})
}
}
fn journal_value_to_param(v: &JournalValue) -> ParamValue {
match v {
JournalValue::Null => ParamValue::Null,
JournalValue::Bool(b) => ParamValue::Bool(*b),
JournalValue::Int64(i) => ParamValue::Int(*i),
JournalValue::Float64(f) => ParamValue::Float(*f),
JournalValue::Text(s) => ParamValue::Text(s.clone()),
JournalValue::Bytes(b) => {
let mut s = String::with_capacity(2 + b.len() * 2);
s.push_str("\\x");
for byte in b {
s.push_str(&format!("{:02x}", byte));
}
ParamValue::Text(s)
}
JournalValue::Array(_) => ParamValue::Null,
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::{tls::default_client_config, TlsMode};
use crate::transaction_journal::StatementType;
use crate::NodeId;
use std::time::Duration;
fn test_template() -> BackendConfig {
BackendConfig {
host: "placeholder".into(),
port: 0,
user: "postgres".into(),
password: None,
database: None,
application_name: Some("helios-replay".into()),
tls_mode: TlsMode::Disable,
connect_timeout: Duration::from_millis(200),
query_timeout: Duration::from_millis(200),
tls_config: default_client_config(),
}
}
fn make_entry(
sequence: u64,
statement: &str,
timestamp: DateTime<Utc>,
) -> JournalEntry {
JournalEntry {
sequence,
statement: statement.to_string(),
parameters: vec![],
result_checksum: None,
rows_affected: None,
timestamp,
statement_type: StatementType::Select,
duration_ms: 1,
}
}
#[tokio::test]
async fn test_replay_rejects_inverted_window() {
let journal = Arc::new(TransactionJournal::new());
let engine = ReplayEngine::new(journal, test_template());
let now = Utc::now();
let req = TimeTravelRequest {
from: now,
to: now - chrono::Duration::seconds(1),
target_host: "127.0.0.1".into(),
target_port: 1,
target_user: None,
target_password: None,
target_database: None,
};
let err = engine.replay_window(&req).await.unwrap_err();
assert!(matches!(err, ProxyError::Internal(_)));
}
#[tokio::test]
async fn test_replay_empty_window_still_connects() {
let journal = Arc::new(TransactionJournal::new());
let engine = ReplayEngine::new(journal, test_template());
let now = Utc::now();
let req = TimeTravelRequest {
from: now - chrono::Duration::hours(1),
to: now,
target_host: "127.0.0.1".into(),
target_port: 1, target_user: None,
target_password: None,
target_database: None,
};
let err = engine.replay_window(&req).await.unwrap_err();
match err {
ProxyError::ReplayFailed(msg) => assert!(msg.contains("connect")),
other => panic!("expected ReplayFailed, got {:?}", other),
}
}
#[tokio::test]
async fn test_entries_in_window_filters_correctly() {
let journal = Arc::new(TransactionJournal::new());
let tx_id = Uuid::new_v4();
let session = Uuid::new_v4();
let node = NodeId::new();
let base = Utc::now();
journal
.begin_transaction(tx_id, session, node, 0)
.await
.unwrap();
let _ = base; journal
.log_statement(
tx_id,
"SELECT 1".to_string(),
vec![],
None,
None,
1,
)
.await
.unwrap();
let from = Utc::now() - chrono::Duration::seconds(5);
let to = Utc::now() + chrono::Duration::seconds(5);
let entries = journal.entries_in_window(from, to).await;
assert_eq!(entries.len(), 1, "single in-window entry");
let far_past_to = Utc::now() - chrono::Duration::hours(1);
let far_past_from = far_past_to - chrono::Duration::hours(1);
let entries = journal.entries_in_window(far_past_from, far_past_to).await;
assert!(entries.is_empty(), "no entries in far-past window");
}
#[test]
fn test_journal_value_to_param_matches_failover_shape() {
assert!(matches!(
journal_value_to_param(&JournalValue::Null),
ParamValue::Null
));
assert!(matches!(
journal_value_to_param(&JournalValue::Bool(true)),
ParamValue::Bool(true)
));
assert!(matches!(
journal_value_to_param(&JournalValue::Int64(-7)),
ParamValue::Int(-7)
));
}
#[test]
fn test_credential_overrides_replace_template_fields() {
let mut cfg = test_template();
cfg.user = "default_user".into();
cfg.password = None;
cfg.database = None;
let req = TimeTravelRequest {
from: Utc::now(),
to: Utc::now(),
target_host: "h".into(),
target_port: 5432,
target_user: Some("override_user".into()),
target_password: Some("secret".into()),
target_database: Some("staging".into()),
};
if let Some(ref u) = req.target_user {
cfg.user = u.clone();
}
if let Some(ref p) = req.target_password {
cfg.password = Some(p.clone());
}
if let Some(ref d) = req.target_database {
cfg.database = Some(d.clone());
}
assert_eq!(cfg.user, "override_user");
assert_eq!(cfg.password.as_deref(), Some("secret"));
assert_eq!(cfg.database.as_deref(), Some("staging"));
}
#[test]
fn test_credential_overrides_none_keeps_template_fields() {
let mut cfg = test_template();
cfg.user = "default_user".into();
cfg.password = Some("template_pw".into());
cfg.database = Some("default_db".into());
let req = TimeTravelRequest {
from: Utc::now(),
to: Utc::now(),
target_host: "h".into(),
target_port: 5432,
target_user: None,
target_password: None,
target_database: None,
};
if let Some(ref u) = req.target_user {
cfg.user = u.clone();
}
let _ = req;
assert_eq!(cfg.user, "default_user");
assert_eq!(cfg.password.as_deref(), Some("template_pw"));
assert_eq!(cfg.database.as_deref(), Some("default_db"));
}
#[test]
fn test_replay_summary_serializes() {
let s = ReplaySummary {
statements_replayed: 5,
failures: 1,
elapsed_ms: 42,
from: Utc::now(),
to: Utc::now(),
first_error: Some("oops".into()),
};
let j = serde_json::to_string(&s).unwrap();
assert!(j.contains("\"statements_replayed\":5"));
assert!(j.contains("\"failures\":1"));
assert!(j.contains("oops"));
}
}