1use crate::backend::{BackendClient, BackendConfig, ParamValue};
15#[cfg(test)]
16use crate::transaction_journal::JournalEntry;
17use crate::transaction_journal::{JournalValue, TransactionJournal};
18use crate::{ProxyError, Result};
19use chrono::{DateTime, Utc};
20use std::sync::Arc;
21
22#[derive(Debug, Clone)]
24pub struct TimeTravelRequest {
25 pub from: DateTime<Utc>,
27 pub to: DateTime<Utc>,
29 pub target_host: String,
31 pub target_port: u16,
33 pub target_user: Option<String>,
37 pub target_password: Option<String>,
41 pub target_database: Option<String>,
43}
44
/// Outcome of a single [`ReplayEngine::replay_window`] run.
///
/// Per-statement failures are counted here rather than aborting the replay.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ReplaySummary {
    /// Number of statements the target executed without returning an error.
    pub statements_replayed: u64,
    /// Number of statements the target rejected with an error.
    pub failures: u64,
    /// Wall-clock duration of the replay in milliseconds, including connecting
    /// to and closing the target connection.
    pub elapsed_ms: u64,
    /// Start of the requested window, serialized as a Unix timestamp in whole seconds.
    #[serde(with = "chrono::serde::ts_seconds")]
    pub from: DateTime<Utc>,
    /// End of the requested window, serialized as a Unix timestamp in whole seconds.
    #[serde(with = "chrono::serde::ts_seconds")]
    pub to: DateTime<Utc>,
    /// Description of the first failed statement, formatted as
    /// `tx <tx_id> seq <sequence>: <error>`; `None` if nothing failed.
    pub first_error: Option<String>,
}
63
/// Replays journaled statements from a time window against a caller-chosen target database.
pub struct ReplayEngine {
    /// Source of the statements to replay.
    journal: Arc<TransactionJournal>,
    /// Connection settings cloned for every replay; host and port are always
    /// replaced per request, and user/password/database only when the request
    /// supplies them.
    backend_template: BackendConfig,
}
70
71impl ReplayEngine {
72 pub fn new(journal: Arc<TransactionJournal>, backend_template: BackendConfig) -> Self {
73 Self {
74 journal,
75 backend_template,
76 }
77 }
78
79 pub async fn replay_window(&self, req: &TimeTravelRequest) -> Result<ReplaySummary> {
88 if req.from > req.to {
89 return Err(ProxyError::Internal("replay window: from > to".to_string()));
90 }
91
92 let entries = self.journal.entries_in_window(req.from, req.to).await;
93 let total = entries.len();
94 tracing::info!(
95 total_entries = total,
96 from = %req.from,
97 to = %req.to,
98 target = %format!("{}:{}", req.target_host, req.target_port),
99 "starting time-travel replay"
100 );
101
102 let mut cfg = self.backend_template.clone();
103 cfg.host = req.target_host.clone();
104 cfg.port = req.target_port;
105 if let Some(ref u) = req.target_user {
106 cfg.user = u.clone();
107 }
108 if let Some(ref p) = req.target_password {
109 cfg.password = Some(p.clone());
110 }
111 if let Some(ref d) = req.target_database {
112 cfg.database = Some(d.clone());
113 }
114
115 let start = std::time::Instant::now();
116 let mut client = BackendClient::connect(&cfg)
117 .await
118 .map_err(|e| ProxyError::ReplayFailed(format!("connect to target: {}", e)))?;
119
120 let mut statements_replayed: u64 = 0;
121 let mut failures: u64 = 0;
122 let mut first_error: Option<String> = None;
123
124 for (tx_id, entry) in entries {
125 let params: Vec<ParamValue> = entry
126 .parameters
127 .iter()
128 .map(journal_value_to_param)
129 .collect();
130
131 let outcome = if params.is_empty() {
132 client.simple_query(&entry.statement).await
133 } else {
134 client.query_with_params(&entry.statement, ¶ms).await
135 };
136
137 match outcome {
138 Ok(_) => {
139 statements_replayed += 1;
140 }
141 Err(e) => {
142 failures += 1;
143 if first_error.is_none() {
144 first_error = Some(format!("tx {} seq {}: {}", tx_id, entry.sequence, e));
145 }
146 tracing::warn!(
147 tx = %tx_id,
148 sequence = entry.sequence,
149 error = %e,
150 "replay statement failed"
151 );
152 }
153 }
154 }
155
156 client.close().await;
157
158 Ok(ReplaySummary {
159 statements_replayed,
160 failures,
161 elapsed_ms: start.elapsed().as_millis() as u64,
162 from: req.from,
163 to: req.to,
164 first_error,
165 })
166 }
167}
168
169fn journal_value_to_param(v: &JournalValue) -> ParamValue {
173 match v {
174 JournalValue::Null => ParamValue::Null,
175 JournalValue::Bool(b) => ParamValue::Bool(*b),
176 JournalValue::Int64(i) => ParamValue::Int(*i),
177 JournalValue::Float64(f) => ParamValue::Float(*f),
178 JournalValue::Text(s) => ParamValue::Text(s.clone()),
179 JournalValue::Bytes(b) => {
180 let mut s = String::with_capacity(2 + b.len() * 2);
181 s.push_str("\\x");
182 for byte in b {
183 s.push_str(&format!("{:02x}", byte));
184 }
185 ParamValue::Text(s)
186 }
187 JournalValue::Array(_) => ParamValue::Null,
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{tls::default_client_config, TlsMode};
    use crate::transaction_journal::StatementType;
    use crate::NodeId;
    use std::time::Duration;
    use uuid::Uuid;

    /// Backend template with short timeouts so connection-failure tests finish quickly.
    fn test_template() -> BackendConfig {
        BackendConfig {
            host: "placeholder".into(),
            port: 0,
            user: "postgres".into(),
            password: None,
            database: None,
            application_name: Some("helios-replay".into()),
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_millis(200),
            query_timeout: Duration::from_millis(200),
            tls_config: default_client_config(),
        }
    }

    /// Builds a parameterless `SELECT` journal entry at `timestamp`.
    fn make_entry(sequence: u64, statement: &str, timestamp: DateTime<Utc>) -> JournalEntry {
        JournalEntry {
            sequence,
            statement: statement.to_string(),
            parameters: vec![],
            result_checksum: None,
            rows_affected: None,
            timestamp,
            statement_type: StatementType::Select,
            duration_ms: 1,
        }
    }

    /// Applies a request's optional credential overrides to `cfg`.
    ///
    /// NOTE(review): this mirrors the override logic in
    /// `ReplayEngine::replay_window`; keep the two in sync.
    fn apply_overrides(cfg: &mut BackendConfig, req: &TimeTravelRequest) {
        if let Some(u) = &req.target_user {
            cfg.user = u.clone();
        }
        if let Some(p) = &req.target_password {
            cfg.password = Some(p.clone());
        }
        if let Some(d) = &req.target_database {
            cfg.database = Some(d.clone());
        }
    }

    #[tokio::test]
    async fn test_replay_rejects_inverted_window() {
        let journal = Arc::new(TransactionJournal::new());
        let engine = ReplayEngine::new(journal, test_template());
        let now = Utc::now();
        let req = TimeTravelRequest {
            from: now,
            to: now - chrono::Duration::seconds(1),
            target_host: "127.0.0.1".into(),
            target_port: 1,
            target_user: None,
            target_password: None,
            target_database: None,
        };
        let err = engine.replay_window(&req).await.unwrap_err();
        assert!(matches!(err, ProxyError::Internal(_)));
    }

    #[tokio::test]
    async fn test_replay_empty_window_still_connects() {
        // An empty window must still attempt the connection. Assumes nothing
        // listens on 127.0.0.1:1, so the attempt fails within the short
        // template timeout.
        let journal = Arc::new(TransactionJournal::new());
        let engine = ReplayEngine::new(journal, test_template());
        let now = Utc::now();
        let req = TimeTravelRequest {
            from: now - chrono::Duration::hours(1),
            to: now,
            target_host: "127.0.0.1".into(),
            target_port: 1,
            target_user: None,
            target_password: None,
            target_database: None,
        };
        let err = engine.replay_window(&req).await.unwrap_err();
        match err {
            ProxyError::ReplayFailed(msg) => assert!(msg.contains("connect")),
            other => panic!("expected ReplayFailed, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_entries_in_window_filters_correctly() {
        let journal = Arc::new(TransactionJournal::new());
        let tx_id = Uuid::new_v4();
        let session = Uuid::new_v4();
        let node = NodeId::new();

        journal
            .begin_transaction(tx_id, session, node, 0)
            .await
            .unwrap();

        journal
            .log_statement(tx_id, "SELECT 1".to_string(), vec![], None, None, 1)
            .await
            .unwrap();

        let from = Utc::now() - chrono::Duration::seconds(5);
        let to = Utc::now() + chrono::Duration::seconds(5);
        let entries = journal.entries_in_window(from, to).await;
        assert_eq!(entries.len(), 1, "single in-window entry");

        let far_past_to = Utc::now() - chrono::Duration::hours(1);
        let far_past_from = far_past_to - chrono::Duration::hours(1);
        let entries = journal.entries_in_window(far_past_from, far_past_to).await;
        assert!(entries.is_empty(), "no entries in far-past window");
    }

    #[test]
    fn test_journal_value_to_param_matches_failover_shape() {
        assert!(matches!(
            journal_value_to_param(&JournalValue::Null),
            ParamValue::Null
        ));
        assert!(matches!(
            journal_value_to_param(&JournalValue::Bool(true)),
            ParamValue::Bool(true)
        ));
        assert!(matches!(
            journal_value_to_param(&JournalValue::Int64(-7)),
            ParamValue::Int(-7)
        ));
    }

    #[test]
    fn test_credential_overrides_replace_template_fields() {
        let mut cfg = test_template();
        cfg.user = "default_user".into();
        cfg.password = None;
        cfg.database = None;

        let req = TimeTravelRequest {
            from: Utc::now(),
            to: Utc::now(),
            target_host: "h".into(),
            target_port: 5432,
            target_user: Some("override_user".into()),
            target_password: Some("secret".into()),
            target_database: Some("staging".into()),
        };

        apply_overrides(&mut cfg, &req);

        assert_eq!(cfg.user, "override_user");
        assert_eq!(cfg.password.as_deref(), Some("secret"));
        assert_eq!(cfg.database.as_deref(), Some("staging"));
    }

    #[test]
    fn test_credential_overrides_none_keeps_template_fields() {
        let mut cfg = test_template();
        cfg.user = "default_user".into();
        cfg.password = Some("template_pw".into());
        cfg.database = Some("default_db".into());

        let req = TimeTravelRequest {
            from: Utc::now(),
            to: Utc::now(),
            target_host: "h".into(),
            target_port: 5432,
            target_user: None,
            target_password: None,
            target_database: None,
        };

        // Run every override branch so that all three `None` cases are exercised.
        apply_overrides(&mut cfg, &req);

        assert_eq!(cfg.user, "default_user");
        assert_eq!(cfg.password.as_deref(), Some("template_pw"));
        assert_eq!(cfg.database.as_deref(), Some("default_db"));
    }

    #[test]
    fn test_replay_summary_serializes() {
        let s = ReplaySummary {
            statements_replayed: 5,
            failures: 1,
            elapsed_ms: 42,
            from: Utc::now(),
            to: Utc::now(),
            first_error: Some("oops".into()),
        };
        let j = serde_json::to_string(&s).unwrap();
        assert!(j.contains("\"statements_replayed\":5"));
        assert!(j.contains("\"failures\":1"));
        assert!(j.contains("oops"));
    }
}