1use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19use serde::Serialize;
20use tokio::sync::mpsc;
21
22use crate::backend::types::TextValue;
23use crate::backend::{
24 tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode,
25};
26use crate::config::MirrorConfig;
27
/// Counters describing the traffic mirror's progress.
///
/// One instance is shared (via `Arc`) between the [`MirrorHandle`] that queues
/// statements and the background worker that applies them; this module only
/// touches the counters with relaxed atomic loads and `fetch_add`s.
#[derive(Debug, Default)]
pub struct MirrorMetrics {
    /// Statements successfully placed on the mirror queue.
    pub enqueued: AtomicU64,
    /// Statements the worker applied to the mirror target without error.
    pub mirrored: AtomicU64,
    /// Statements dropped instead of being queued (for example because the
    /// queue was full).
    pub dropped: AtomicU64,
    /// Statements the worker took off the queue but could not apply, either
    /// because connecting failed or because the query itself failed.
    pub errors: AtomicU64,
}
40
41#[derive(Debug, Clone, Serialize)]
43pub struct MigrationStatus {
44 pub enabled: bool,
45 pub target: String,
46 pub writes_only: bool,
47 pub enqueued: u64,
48 pub mirrored: u64,
49 pub dropped: u64,
50 pub errors: u64,
51 pub lag: u64,
53 pub migration_ready: bool,
57}
58
59pub fn status(target: &str, writes_only: bool, m: &MirrorMetrics) -> MigrationStatus {
61 let enqueued = m.enqueued.load(Ordering::Relaxed);
62 let mirrored = m.mirrored.load(Ordering::Relaxed);
63 let dropped = m.dropped.load(Ordering::Relaxed);
64 let errors = m.errors.load(Ordering::Relaxed);
65 let lag = enqueued.saturating_sub(mirrored).saturating_sub(errors);
66 MigrationStatus {
67 enabled: true,
68 target: target.to_string(),
69 writes_only,
70 enqueued,
71 mirrored,
72 dropped,
73 errors,
74 lag,
75 migration_ready: lag == 0 && dropped == 0,
76 }
77}
78
79pub struct MirrorHandle {
81 tx: mpsc::Sender<String>,
82 sample_rate: f64,
83 writes_only: bool,
84 target: String,
85 pub metrics: Arc<MirrorMetrics>,
86}
87
88impl MirrorHandle {
89 pub fn status(&self) -> MigrationStatus {
91 status(&self.target, self.writes_only, &self.metrics)
92 }
93 pub fn target(&self) -> &str {
94 &self.target
95 }
96 pub fn writes_only(&self) -> bool {
97 self.writes_only
98 }
99}
100
101impl MirrorHandle {
102 pub fn offer(&self, sql: &str, is_write: bool) {
106 if self.writes_only && !is_write {
107 return;
108 }
109 if self.sample_rate < 1.0 {
110 use rand::Rng;
112 if rand::thread_rng().gen::<f64>() >= self.sample_rate {
113 return;
114 }
115 }
116 match self.tx.try_send(sql.to_string()) {
117 Ok(()) => {
118 self.metrics.enqueued.fetch_add(1, Ordering::Relaxed);
119 }
120 Err(mpsc::error::TrySendError::Full(_)) => {
121 self.metrics.dropped.fetch_add(1, Ordering::Relaxed);
122 }
123 Err(mpsc::error::TrySendError::Closed(_)) => {}
124 }
125 }
126}
127
128pub fn spawn(config: MirrorConfig) -> MirrorHandle {
130 let (tx, rx) = mpsc::channel::<String>(config.queue_size.max(1));
131 let metrics = Arc::new(MirrorMetrics::default());
132 let handle = MirrorHandle {
133 tx,
134 sample_rate: config.sample_rate.clamp(0.0, 1.0),
135 writes_only: config.writes_only,
136 target: format!("{}:{}", config.backend_host, config.backend_port),
137 metrics: metrics.clone(),
138 };
139 tokio::spawn(worker(config, rx, metrics));
140 handle
141}
142
/// Connection parameters for a cutover target database.
///
/// `Debug` is implemented by hand instead of derived so the password never
/// ends up in logs or error messages; a set password prints as
/// `Some("<redacted>")`.
#[derive(Clone)]
pub struct CutoverTarget {
    /// Server address; presumably `host:port` — confirm against callers.
    pub addr: String,
    /// Login role.
    pub user: String,
    /// Login password, if any. Never printed by `Debug`.
    pub password: Option<String>,
    /// Database to connect to, if any.
    pub database: Option<String>,
}

impl std::fmt::Debug for CutoverTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CutoverTarget")
            .field("addr", &self.addr)
            .field("user", &self.user)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("database", &self.database)
            .finish()
    }
}
153
154#[derive(Debug, Clone, Serialize)]
156pub struct TableSnapshot {
157 pub table: String,
158 pub source_rows: u64,
159 pub copied: u64,
160}
161
162fn backend_cfg(
163 host: &str,
164 port: u16,
165 user: &str,
166 pass: Option<&str>,
167 db: Option<&str>,
168 app: &str,
169) -> BackendConfig {
170 BackendConfig {
171 host: host.to_string(),
172 port,
173 user: user.to_string(),
174 password: pass.map(|s| s.to_string()),
175 database: db.map(|s| s.to_string()),
176 application_name: Some(app.to_string()),
177 tls_mode: TlsMode::Disable,
178 connect_timeout: Duration::from_secs(5),
179 query_timeout: Duration::from_secs(60),
180 tls_config: default_client_config(),
181 }
182}
183
/// Quotes `name` as one SQL identifier: wraps it in double quotes and doubles
/// every embedded double quote.
///
/// A dotted name such as `schema.table` becomes a single identifier, not a
/// schema-qualified reference.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
187
188fn encode_copy_field(out: &mut Vec<u8>, v: &TextValue) {
193 match v {
194 TextValue::Null => out.extend_from_slice(b"\\N"),
195 TextValue::Text(s) => {
196 for &b in s.as_bytes() {
197 match b {
198 b'\\' => out.extend_from_slice(b"\\\\"),
199 b'\t' => out.extend_from_slice(b"\\t"),
200 b'\n' => out.extend_from_slice(b"\\n"),
201 b'\r' => out.extend_from_slice(b"\\r"),
202 _ => out.push(b),
203 }
204 }
205 }
206 }
207}
208
/// Maps a PostgreSQL type OID from a result description to the column type
/// used when recreating the table on the snapshot target.
///
/// Only common built-in types are recognised. Anything else becomes `text`,
/// which keeps the value's text form but not its type. `int2` (21) is widened
/// to `integer`.
fn oid_type(oid: u32) -> &'static str {
    match oid {
        16 => "boolean",
        17 => "bytea",
        20 => "bigint",
        21 | 23 => "integer",
        // `json` keeps the input text verbatim (whitespace, key order,
        // duplicate keys); mapping it to `jsonb` would normalise documents.
        114 => "json",
        700 => "real",
        701 => "double precision",
        1082 => "date",
        1083 => "time",
        1114 => "timestamp",
        // `timestamptz` text carries a UTC offset; loading it into a plain
        // `timestamp` column silently discards that offset.
        1184 => "timestamptz",
        1186 => "interval",
        1700 => "numeric",
        2950 => "uuid",
        3802 => "jsonb",
        _ => "text",
    }
}
226
227pub async fn snapshot_tables(
233 cfg: &MirrorConfig,
234 tables: &[String],
235) -> Result<Vec<TableSnapshot>, String> {
236 let src_cfg = backend_cfg(
237 &cfg.source_host,
238 cfg.source_port,
239 &cfg.source_user,
240 cfg.source_password.as_deref(),
241 cfg.source_database.as_deref(),
242 "heliosproxy-snapshot-src",
243 );
244 let tgt_cfg = backend_cfg(
245 &cfg.backend_host,
246 cfg.backend_port,
247 &cfg.backend_user,
248 cfg.backend_password.as_deref(),
249 cfg.backend_database.as_deref(),
250 "heliosproxy-snapshot-tgt",
251 );
252 let mut src = BackendClient::connect(&src_cfg)
253 .await
254 .map_err(|e| format!("source connect: {}", e))?;
255 let mut tgt = BackendClient::connect(&tgt_cfg)
256 .await
257 .map_err(|e| format!("target connect: {}", e))?;
258
259 let mut blocked: Vec<String> = Vec::new();
265 for table in tables {
266 let qt = quote_ident(table);
267 if let Ok(res) = tgt
268 .simple_query(&format!("SELECT 1 FROM {} LIMIT 1", qt))
269 .await
270 {
271 if !res.rows.is_empty() {
272 blocked.push(table.clone());
273 }
274 }
275 }
276 if !blocked.is_empty() {
277 src.close().await;
278 tgt.close().await;
279 return Err(format!(
280 "refusing snapshot: target table(s) already contain rows (snapshot would \
281 duplicate): {}. Snapshot into an empty target, or remove the rows first.",
282 blocked.join(", ")
283 ));
284 }
285
286 let mut report = Vec::new();
287 for table in tables {
288 let qt = quote_ident(table);
289 let res = src
290 .simple_query(&format!("SELECT * FROM {}", qt))
291 .await
292 .map_err(|e| format!("read {}: {}", table, e))?;
293
294 let cols_ddl: Vec<String> = res
296 .columns
297 .iter()
298 .map(|c| format!("{} {}", quote_ident(&c.name), oid_type(c.type_oid)))
299 .collect();
300 let create = format!(
301 "CREATE TABLE IF NOT EXISTS {} ({})",
302 qt,
303 cols_ddl.join(", ")
304 );
305 tgt.execute(&create)
306 .await
307 .map_err(|e| format!("create {} on target: {}", table, e))?;
308
309 let col_list = res
310 .columns
311 .iter()
312 .map(|c| quote_ident(&c.name))
313 .collect::<Vec<_>>()
314 .join(", ");
315
316 let use_copy = std::env::var("HELIOS_SNAPSHOT_USE_COPY")
319 .map(|v| v != "0")
320 .unwrap_or(true);
321
322 let mut copied: Option<u64> = None;
323 if use_copy {
324 let mut copy_buf: Vec<u8> = Vec::new();
325 for row in &res.rows {
326 for (i, v) in row.iter().enumerate() {
327 if i > 0 {
328 copy_buf.push(b'\t');
329 }
330 encode_copy_field(&mut copy_buf, v);
331 }
332 copy_buf.push(b'\n');
333 }
334 let copy_sql = format!("COPY {} ({}) FROM STDIN", qt, col_list);
335 match tgt.copy_in(©_sql, ©_buf).await {
336 Ok(n) => copied = Some(n),
337 Err(e) => {
338 tracing::warn!(
341 table = %table,
342 error = %e,
343 "COPY snapshot failed; falling back to per-row INSERT"
344 );
345 }
346 }
347 }
348
349 let copied = match copied {
351 Some(n) => n,
352 None => {
353 let placeholders = (1..=res.columns.len())
354 .map(|i| format!("${}", i))
355 .collect::<Vec<_>>()
356 .join(", ");
357 let insert = format!(
358 "INSERT INTO {} ({}) VALUES ({})",
359 qt, col_list, placeholders
360 );
361 let mut copied = 0u64;
362 for row in &res.rows {
363 let params: Vec<ParamValue> = row
364 .iter()
365 .map(|v| match v {
366 TextValue::Null => ParamValue::Null,
367 TextValue::Text(s) => ParamValue::Text(s.clone()),
368 })
369 .collect();
370 tgt.query_with_params(&insert, ¶ms)
371 .await
372 .map_err(|e| format!("insert into {}: {}", table, e))?;
373 copied += 1;
374 }
375 copied
376 }
377 };
378 report.push(TableSnapshot {
379 table: table.clone(),
380 source_rows: res.rows.len() as u64,
381 copied,
382 });
383 }
384 src.close().await;
385 tgt.close().await;
386 Ok(report)
387}
388
389async fn worker(config: MirrorConfig, mut rx: mpsc::Receiver<String>, metrics: Arc<MirrorMetrics>) {
390 let bcfg = BackendConfig {
391 host: config.backend_host.clone(),
392 port: config.backend_port,
393 user: config.backend_user.clone(),
394 password: config.backend_password.clone(),
395 database: config.backend_database.clone(),
396 application_name: Some("heliosproxy-mirror".to_string()),
397 tls_mode: TlsMode::Disable,
398 connect_timeout: Duration::from_secs(5),
399 query_timeout: Duration::from_secs(30),
400 tls_config: default_client_config(),
401 };
402 tracing::info!(target = %bcfg.address(), "traffic mirror worker started");
403
404 let mut client: Option<BackendClient> = None;
405 while let Some(sql) = rx.recv().await {
406 if client.is_none() {
408 match BackendClient::connect(&bcfg).await {
409 Ok(c) => client = Some(c),
410 Err(e) => {
411 metrics.errors.fetch_add(1, Ordering::Relaxed);
412 tracing::debug!(error = %e, "mirror connect failed; dropping statement");
413 continue;
414 }
415 }
416 }
417 let c = client.as_mut().unwrap();
418 if let Err(e) = c.simple_query(&sql).await {
419 metrics.errors.fetch_add(1, Ordering::Relaxed);
420 tracing::debug!(error = %e, "mirror apply failed; will reconnect");
421 if let Some(c) = client.take() {
423 c.close().await;
424 }
425 } else {
426 metrics.mirrored.fetch_add(1, Ordering::Relaxed);
427 }
428 }
429 tracing::info!("traffic mirror worker stopped");
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a single value with `encode_copy_field` and returns it as text.
    fn encode_to_string(value: &TextValue) -> String {
        let mut buf = Vec::new();
        encode_copy_field(&mut buf, value);
        String::from_utf8(buf).expect("COPY encoding of UTF-8 input stays UTF-8")
    }

    #[test]
    fn copy_field_encoding() {
        // NULL has its own marker, distinct from the escaped literal text "\N".
        assert_eq!(encode_to_string(&TextValue::Null), "\\N");

        let cases: [(&str, &str); 7] = [
            ("", ""),
            ("alice", "alice"),
            ("a\tb", "a\\tb"),
            ("a\nb", "a\\nb"),
            ("a\rb", "a\\rb"),
            ("a\\b", "a\\\\b"),
            ("\\N", "\\\\N"),
        ];
        for (input, expected) in cases {
            assert_eq!(
                encode_to_string(&TextValue::Text(input.to_string())),
                expected,
                "input {:?}",
                input
            );
        }
    }
}