1use crate::backend::{BackendClient, BackendConfig, ParamValue};
15use crate::transaction_journal::{JournalEntry, JournalValue, TransactionJournal};
16use crate::{ProxyError, Result};
17use chrono::{DateTime, Utc};
18use std::sync::Arc;
19use uuid::Uuid;
20
21#[derive(Debug, Clone)]
23pub struct TimeTravelRequest {
24 pub from: DateTime<Utc>,
26 pub to: DateTime<Utc>,
28 pub target_host: String,
30 pub target_port: u16,
32 pub target_user: Option<String>,
36 pub target_password: Option<String>,
40 pub target_database: Option<String>,
42}
43
44#[derive(Debug, Clone, serde::Serialize)]
46pub struct ReplaySummary {
47 pub statements_replayed: u64,
49 pub failures: u64,
51 pub elapsed_ms: u64,
53 #[serde(with = "chrono::serde::ts_seconds")]
55 pub from: DateTime<Utc>,
56 #[serde(with = "chrono::serde::ts_seconds")]
57 pub to: DateTime<Utc>,
58 pub first_error: Option<String>,
61}
62
63pub struct ReplayEngine {
65 journal: Arc<TransactionJournal>,
66 backend_template: BackendConfig,
68}
69
70impl ReplayEngine {
71 pub fn new(journal: Arc<TransactionJournal>, backend_template: BackendConfig) -> Self {
72 Self {
73 journal,
74 backend_template,
75 }
76 }
77
78 pub async fn replay_window(
87 &self,
88 req: &TimeTravelRequest,
89 ) -> Result<ReplaySummary> {
90 if req.from > req.to {
91 return Err(ProxyError::Internal(
92 "replay window: from > to".to_string(),
93 ));
94 }
95
96 let entries = self.journal.entries_in_window(req.from, req.to).await;
97 let total = entries.len();
98 tracing::info!(
99 total_entries = total,
100 from = %req.from,
101 to = %req.to,
102 target = %format!("{}:{}", req.target_host, req.target_port),
103 "starting time-travel replay"
104 );
105
106 let mut cfg = self.backend_template.clone();
107 cfg.host = req.target_host.clone();
108 cfg.port = req.target_port;
109 if let Some(ref u) = req.target_user {
110 cfg.user = u.clone();
111 }
112 if let Some(ref p) = req.target_password {
113 cfg.password = Some(p.clone());
114 }
115 if let Some(ref d) = req.target_database {
116 cfg.database = Some(d.clone());
117 }
118
119 let start = std::time::Instant::now();
120 let mut client = BackendClient::connect(&cfg).await.map_err(|e| {
121 ProxyError::ReplayFailed(format!("connect to target: {}", e))
122 })?;
123
124 let mut statements_replayed: u64 = 0;
125 let mut failures: u64 = 0;
126 let mut first_error: Option<String> = None;
127
128 for (tx_id, entry) in entries {
129 let params: Vec<ParamValue> =
130 entry.parameters.iter().map(journal_value_to_param).collect();
131
132 let outcome = if params.is_empty() {
133 client.simple_query(&entry.statement).await
134 } else {
135 client.query_with_params(&entry.statement, ¶ms).await
136 };
137
138 match outcome {
139 Ok(_) => {
140 statements_replayed += 1;
141 }
142 Err(e) => {
143 failures += 1;
144 if first_error.is_none() {
145 first_error = Some(format!(
146 "tx {} seq {}: {}",
147 tx_id, entry.sequence, e
148 ));
149 }
150 tracing::warn!(
151 tx = %tx_id,
152 sequence = entry.sequence,
153 error = %e,
154 "replay statement failed"
155 );
156 }
157 }
158 }
159
160 client.close().await;
161
162 Ok(ReplaySummary {
163 statements_replayed,
164 failures,
165 elapsed_ms: start.elapsed().as_millis() as u64,
166 from: req.from,
167 to: req.to,
168 first_error,
169 })
170 }
171}
172
173fn journal_value_to_param(v: &JournalValue) -> ParamValue {
177 match v {
178 JournalValue::Null => ParamValue::Null,
179 JournalValue::Bool(b) => ParamValue::Bool(*b),
180 JournalValue::Int64(i) => ParamValue::Int(*i),
181 JournalValue::Float64(f) => ParamValue::Float(*f),
182 JournalValue::Text(s) => ParamValue::Text(s.clone()),
183 JournalValue::Bytes(b) => {
184 let mut s = String::with_capacity(2 + b.len() * 2);
185 s.push_str("\\x");
186 for byte in b {
187 s.push_str(&format!("{:02x}", byte));
188 }
189 ParamValue::Text(s)
190 }
191 JournalValue::Array(_) => ParamValue::Null,
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{tls::default_client_config, TlsMode};
    use crate::transaction_journal::StatementType;
    use crate::NodeId;
    use std::time::Duration;

    /// Backend template with placeholder host/port and 200 ms
    /// connect/query timeouts.
    fn test_template() -> BackendConfig {
        BackendConfig {
            host: "placeholder".into(),
            port: 0,
            user: "postgres".into(),
            password: None,
            database: None,
            application_name: Some("helios-replay".into()),
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_millis(200),
            query_timeout: Duration::from_millis(200),
            tls_config: default_client_config(),
        }
    }

    /// Applies the request's user/password/database overrides to `cfg`.
    /// Mirrors the override logic in `ReplayEngine::replay_window`; keep the
    /// two in sync.
    fn apply_overrides(cfg: &mut BackendConfig, req: &TimeTravelRequest) {
        if let Some(ref u) = req.target_user {
            cfg.user = u.clone();
        }
        if let Some(ref p) = req.target_password {
            cfg.password = Some(p.clone());
        }
        if let Some(ref d) = req.target_database {
            cfg.database = Some(d.clone());
        }
    }

    /// Builds a parameterless `Select` journal entry.
    // Not referenced by the current tests; kept as a fixture builder for
    // replay-loop tests.
    #[allow(dead_code)]
    fn make_entry(sequence: u64, statement: &str, timestamp: DateTime<Utc>) -> JournalEntry {
        JournalEntry {
            sequence,
            statement: statement.to_string(),
            parameters: vec![],
            result_checksum: None,
            rows_affected: None,
            timestamp,
            statement_type: StatementType::Select,
            duration_ms: 1,
        }
    }

    #[tokio::test]
    async fn test_replay_rejects_inverted_window() {
        let journal = Arc::new(TransactionJournal::new());
        let engine = ReplayEngine::new(journal, test_template());
        let now = Utc::now();
        let req = TimeTravelRequest {
            from: now,
            to: now - chrono::Duration::seconds(1),
            target_host: "127.0.0.1".into(),
            target_port: 1,
            target_user: None,
            target_password: None,
            target_database: None,
        };
        let err = engine.replay_window(&req).await.unwrap_err();
        assert!(matches!(err, ProxyError::Internal(_)));
    }

    #[tokio::test]
    async fn test_replay_empty_window_still_connects() {
        let journal = Arc::new(TransactionJournal::new());
        let engine = ReplayEngine::new(journal, test_template());
        let now = Utc::now();
        let req = TimeTravelRequest {
            from: now - chrono::Duration::hours(1),
            to: now,
            target_host: "127.0.0.1".into(),
            target_port: 1,
            target_user: None,
            target_password: None,
            target_database: None,
        };
        let err = engine.replay_window(&req).await.unwrap_err();
        match err {
            ProxyError::ReplayFailed(msg) => assert!(msg.contains("connect")),
            other => panic!("expected ReplayFailed, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_entries_in_window_filters_correctly() {
        let journal = Arc::new(TransactionJournal::new());
        let tx_id = Uuid::new_v4();
        let session = Uuid::new_v4();
        let node = NodeId::new();

        journal
            .begin_transaction(tx_id, session, node, 0)
            .await
            .unwrap();

        journal
            .log_statement(tx_id, "SELECT 1".to_string(), vec![], None, None, 1)
            .await
            .unwrap();

        let from = Utc::now() - chrono::Duration::seconds(5);
        let to = Utc::now() + chrono::Duration::seconds(5);
        let entries = journal.entries_in_window(from, to).await;
        assert_eq!(entries.len(), 1, "single in-window entry");

        let far_past_to = Utc::now() - chrono::Duration::hours(1);
        let far_past_from = far_past_to - chrono::Duration::hours(1);
        let entries = journal.entries_in_window(far_past_from, far_past_to).await;
        assert!(entries.is_empty(), "no entries in far-past window");
    }

    #[test]
    fn test_journal_value_to_param_matches_failover_shape() {
        assert!(matches!(
            journal_value_to_param(&JournalValue::Null),
            ParamValue::Null
        ));
        assert!(matches!(
            journal_value_to_param(&JournalValue::Bool(true)),
            ParamValue::Bool(true)
        ));
        assert!(matches!(
            journal_value_to_param(&JournalValue::Int64(-7)),
            ParamValue::Int(-7)
        ));
    }

    #[test]
    fn test_credential_overrides_replace_template_fields() {
        let mut cfg = test_template();
        cfg.user = "default_user".into();
        cfg.password = None;
        cfg.database = None;

        let req = TimeTravelRequest {
            from: Utc::now(),
            to: Utc::now(),
            target_host: "h".into(),
            target_port: 5432,
            target_user: Some("override_user".into()),
            target_password: Some("secret".into()),
            target_database: Some("staging".into()),
        };

        apply_overrides(&mut cfg, &req);

        assert_eq!(cfg.user, "override_user");
        assert_eq!(cfg.password.as_deref(), Some("secret"));
        assert_eq!(cfg.database.as_deref(), Some("staging"));
    }

    #[test]
    fn test_credential_overrides_none_keeps_template_fields() {
        let mut cfg = test_template();
        cfg.user = "default_user".into();
        cfg.password = Some("template_pw".into());
        cfg.database = Some("default_db".into());

        let req = TimeTravelRequest {
            from: Utc::now(),
            to: Utc::now(),
            target_host: "h".into(),
            target_port: 5432,
            target_user: None,
            target_password: None,
            target_database: None,
        };

        // Run every override branch, so a None that wrongly clobbers the
        // password or database is caught.
        apply_overrides(&mut cfg, &req);

        assert_eq!(cfg.user, "default_user");
        assert_eq!(cfg.password.as_deref(), Some("template_pw"));
        assert_eq!(cfg.database.as_deref(), Some("default_db"));
    }

    #[test]
    fn test_replay_summary_serializes() {
        let s = ReplaySummary {
            statements_replayed: 5,
            failures: 1,
            elapsed_ms: 42,
            from: Utc::now(),
            to: Utc::now(),
            first_error: Some("oops".into()),
        };
        let j = serde_json::to_string(&s).unwrap();
        assert!(j.contains("\"statements_replayed\":5"));
        assert!(j.contains("\"failures\":1"));
        assert!(j.contains("oops"));
    }
}