1use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19use serde::Serialize;
20use tokio::sync::mpsc;
21
22use crate::backend::types::TextValue;
23use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode};
24use crate::config::MirrorConfig;
25
/// Counters shared between [`MirrorHandle::offer`] and the background mirror worker.
///
/// Every update in this module uses relaxed atomic increments; `status` reads the
/// counters together to derive lag and migration readiness.
#[derive(Debug, Default)]
pub struct MirrorMetrics {
    /// Statements successfully placed on the worker queue.
    pub enqueued: AtomicU64,
    /// Statements the worker applied on the target without error.
    pub mirrored: AtomicU64,
    /// Statements `offer` could not place on the worker queue.
    pub dropped: AtomicU64,
    /// Statements the worker took off the queue but could not apply
    /// (connection or execution failure).
    pub errors: AtomicU64,
}
38
39#[derive(Debug, Clone, Serialize)]
41pub struct MigrationStatus {
42 pub enabled: bool,
43 pub target: String,
44 pub writes_only: bool,
45 pub enqueued: u64,
46 pub mirrored: u64,
47 pub dropped: u64,
48 pub errors: u64,
49 pub lag: u64,
51 pub migration_ready: bool,
55}
56
57pub fn status(target: &str, writes_only: bool, m: &MirrorMetrics) -> MigrationStatus {
59 let enqueued = m.enqueued.load(Ordering::Relaxed);
60 let mirrored = m.mirrored.load(Ordering::Relaxed);
61 let dropped = m.dropped.load(Ordering::Relaxed);
62 let errors = m.errors.load(Ordering::Relaxed);
63 let lag = enqueued.saturating_sub(mirrored).saturating_sub(errors);
64 MigrationStatus {
65 enabled: true,
66 target: target.to_string(),
67 writes_only,
68 enqueued,
69 mirrored,
70 dropped,
71 errors,
72 lag,
73 migration_ready: lag == 0 && dropped == 0,
74 }
75}
76
77pub struct MirrorHandle {
79 tx: mpsc::Sender<String>,
80 sample_rate: f64,
81 writes_only: bool,
82 target: String,
83 pub metrics: Arc<MirrorMetrics>,
84}
85
86impl MirrorHandle {
87 pub fn status(&self) -> MigrationStatus {
89 status(&self.target, self.writes_only, &self.metrics)
90 }
91 pub fn target(&self) -> &str {
92 &self.target
93 }
94 pub fn writes_only(&self) -> bool {
95 self.writes_only
96 }
97}
98
99impl MirrorHandle {
100 pub fn offer(&self, sql: &str, is_write: bool) {
104 if self.writes_only && !is_write {
105 return;
106 }
107 if self.sample_rate < 1.0 {
108 use rand::Rng;
110 if rand::thread_rng().gen::<f64>() >= self.sample_rate {
111 return;
112 }
113 }
114 match self.tx.try_send(sql.to_string()) {
115 Ok(()) => {
116 self.metrics.enqueued.fetch_add(1, Ordering::Relaxed);
117 }
118 Err(mpsc::error::TrySendError::Full(_)) => {
119 self.metrics.dropped.fetch_add(1, Ordering::Relaxed);
120 }
121 Err(mpsc::error::TrySendError::Closed(_)) => {}
122 }
123 }
124}
125
126pub fn spawn(config: MirrorConfig) -> MirrorHandle {
128 let (tx, rx) = mpsc::channel::<String>(config.queue_size.max(1));
129 let metrics = Arc::new(MirrorMetrics::default());
130 let handle = MirrorHandle {
131 tx,
132 sample_rate: config.sample_rate.clamp(0.0, 1.0),
133 writes_only: config.writes_only,
134 target: format!("{}:{}", config.backend_host, config.backend_port),
135 metrics: metrics.clone(),
136 };
137 tokio::spawn(worker(config, rx, metrics));
138 handle
139}
140
/// Connection parameters for a cutover target.
///
/// `Debug` is implemented by hand so that formatting the value (for example in logs)
/// never prints the password. Only whether one is set is shown.
#[derive(Clone)]
pub struct CutoverTarget {
    /// Target address (presumably `host:port`; confirm against callers).
    pub addr: String,
    /// User name to authenticate as.
    pub user: String,
    /// Password, if any. Redacted in `Debug` output.
    pub password: Option<String>,
    /// Database to connect to, if not the server default.
    pub database: Option<String>,
}

impl std::fmt::Debug for CutoverTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CutoverTarget")
            .field("addr", &self.addr)
            .field("user", &self.user)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("database", &self.database)
            .finish()
    }
}
151
152#[derive(Debug, Clone, Serialize)]
154pub struct TableSnapshot {
155 pub table: String,
156 pub source_rows: u64,
157 pub copied: u64,
158}
159
160fn backend_cfg(host: &str, port: u16, user: &str, pass: Option<&str>, db: Option<&str>, app: &str) -> BackendConfig {
161 BackendConfig {
162 host: host.to_string(),
163 port,
164 user: user.to_string(),
165 password: pass.map(|s| s.to_string()),
166 database: db.map(|s| s.to_string()),
167 application_name: Some(app.to_string()),
168 tls_mode: TlsMode::Disable,
169 connect_timeout: Duration::from_secs(5),
170 query_timeout: Duration::from_secs(60),
171 tls_config: default_client_config(),
172 }
173}
174
/// Wraps `name` in double quotes as a single SQL identifier, doubling any embedded
/// double quote.
///
/// The whole string becomes one identifier. On PostgreSQL that makes it case-sensitive,
/// and a dotted name such as `public.users` is not treated as schema-qualified.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
178
179fn encode_copy_field(out: &mut Vec<u8>, v: &TextValue) {
184 match v {
185 TextValue::Null => out.extend_from_slice(b"\\N"),
186 TextValue::Text(s) => {
187 for &b in s.as_bytes() {
188 match b {
189 b'\\' => out.extend_from_slice(b"\\\\"),
190 b'\t' => out.extend_from_slice(b"\\t"),
191 b'\n' => out.extend_from_slice(b"\\n"),
192 b'\r' => out.extend_from_slice(b"\\r"),
193 _ => out.push(b),
194 }
195 }
196 }
197 }
198}
199
/// Maps a type OID reported by the source to the column type used when creating the
/// table on the target. Any OID not listed is created as `text`.
///
/// NOTE(review): OID 1184 (PostgreSQL `timestamptz`) is created as plain `timestamp`,
/// and 114 (`json`) as `jsonb`. If the target behaves like PostgreSQL, plain
/// `timestamp` input ignores a UTC offset and `jsonb` normalises documents. Confirm
/// this collapsing is intended for the target.
fn oid_type(oid: u32) -> &'static str {
    // (source OID, target column type)
    const KNOWN: &[(u32, &str)] = &[
        (16, "boolean"),
        (20, "bigint"),
        (21, "integer"),
        (23, "integer"),
        (114, "jsonb"),
        (700, "real"),
        (701, "double precision"),
        (1082, "date"),
        (1114, "timestamp"),
        (1184, "timestamp"),
        (1700, "numeric"),
        (2950, "uuid"),
        (3802, "jsonb"),
    ];
    KNOWN
        .iter()
        .find(|&&(known, _)| known == oid)
        .map_or("text", |&(_, ty)| ty)
}
217
218pub async fn snapshot_tables(cfg: &MirrorConfig, tables: &[String]) -> Result<Vec<TableSnapshot>, String> {
224 let src_cfg = backend_cfg(
225 &cfg.source_host, cfg.source_port, &cfg.source_user,
226 cfg.source_password.as_deref(), cfg.source_database.as_deref(), "heliosproxy-snapshot-src",
227 );
228 let tgt_cfg = backend_cfg(
229 &cfg.backend_host, cfg.backend_port, &cfg.backend_user,
230 cfg.backend_password.as_deref(), cfg.backend_database.as_deref(), "heliosproxy-snapshot-tgt",
231 );
232 let mut src = BackendClient::connect(&src_cfg).await.map_err(|e| format!("source connect: {}", e))?;
233 let mut tgt = BackendClient::connect(&tgt_cfg).await.map_err(|e| format!("target connect: {}", e))?;
234
235 let mut blocked: Vec<String> = Vec::new();
241 for table in tables {
242 let qt = quote_ident(table);
243 if let Ok(res) = tgt.simple_query(&format!("SELECT 1 FROM {} LIMIT 1", qt)).await {
244 if !res.rows.is_empty() {
245 blocked.push(table.clone());
246 }
247 }
248 }
249 if !blocked.is_empty() {
250 src.close().await;
251 tgt.close().await;
252 return Err(format!(
253 "refusing snapshot: target table(s) already contain rows (snapshot would \
254 duplicate): {}. Snapshot into an empty target, or remove the rows first.",
255 blocked.join(", ")
256 ));
257 }
258
259 let mut report = Vec::new();
260 for table in tables {
261 let qt = quote_ident(table);
262 let res = src
263 .simple_query(&format!("SELECT * FROM {}", qt))
264 .await
265 .map_err(|e| format!("read {}: {}", table, e))?;
266
267 let cols_ddl: Vec<String> = res
269 .columns
270 .iter()
271 .map(|c| format!("{} {}", quote_ident(&c.name), oid_type(c.type_oid)))
272 .collect();
273 let create = format!("CREATE TABLE IF NOT EXISTS {} ({})", qt, cols_ddl.join(", "));
274 tgt.execute(&create).await.map_err(|e| format!("create {} on target: {}", table, e))?;
275
276 let col_list = res.columns.iter().map(|c| quote_ident(&c.name)).collect::<Vec<_>>().join(", ");
277
278 let use_copy =
281 std::env::var("HELIOS_SNAPSHOT_USE_COPY").map(|v| v != "0").unwrap_or(true);
282
283 let mut copied: Option<u64> = None;
284 if use_copy {
285 let mut copy_buf: Vec<u8> = Vec::new();
286 for row in &res.rows {
287 for (i, v) in row.iter().enumerate() {
288 if i > 0 {
289 copy_buf.push(b'\t');
290 }
291 encode_copy_field(&mut copy_buf, v);
292 }
293 copy_buf.push(b'\n');
294 }
295 let copy_sql = format!("COPY {} ({}) FROM STDIN", qt, col_list);
296 match tgt.copy_in(©_sql, ©_buf).await {
297 Ok(n) => copied = Some(n),
298 Err(e) => {
299 tracing::warn!(
302 table = %table,
303 error = %e,
304 "COPY snapshot failed; falling back to per-row INSERT"
305 );
306 }
307 }
308 }
309
310 let copied = match copied {
312 Some(n) => n,
313 None => {
314 let placeholders =
315 (1..=res.columns.len()).map(|i| format!("${}", i)).collect::<Vec<_>>().join(", ");
316 let insert = format!("INSERT INTO {} ({}) VALUES ({})", qt, col_list, placeholders);
317 let mut copied = 0u64;
318 for row in &res.rows {
319 let params: Vec<ParamValue> = row
320 .iter()
321 .map(|v| match v {
322 TextValue::Null => ParamValue::Null,
323 TextValue::Text(s) => ParamValue::Text(s.clone()),
324 })
325 .collect();
326 tgt.query_with_params(&insert, ¶ms)
327 .await
328 .map_err(|e| format!("insert into {}: {}", table, e))?;
329 copied += 1;
330 }
331 copied
332 }
333 };
334 report.push(TableSnapshot { table: table.clone(), source_rows: res.rows.len() as u64, copied });
335 }
336 src.close().await;
337 tgt.close().await;
338 Ok(report)
339}
340
341async fn worker(config: MirrorConfig, mut rx: mpsc::Receiver<String>, metrics: Arc<MirrorMetrics>) {
342 let bcfg = BackendConfig {
343 host: config.backend_host.clone(),
344 port: config.backend_port,
345 user: config.backend_user.clone(),
346 password: config.backend_password.clone(),
347 database: config.backend_database.clone(),
348 application_name: Some("heliosproxy-mirror".to_string()),
349 tls_mode: TlsMode::Disable,
350 connect_timeout: Duration::from_secs(5),
351 query_timeout: Duration::from_secs(30),
352 tls_config: default_client_config(),
353 };
354 tracing::info!(target = %bcfg.address(), "traffic mirror worker started");
355
356 let mut client: Option<BackendClient> = None;
357 while let Some(sql) = rx.recv().await {
358 if client.is_none() {
360 match BackendClient::connect(&bcfg).await {
361 Ok(c) => client = Some(c),
362 Err(e) => {
363 metrics.errors.fetch_add(1, Ordering::Relaxed);
364 tracing::debug!(error = %e, "mirror connect failed; dropping statement");
365 continue;
366 }
367 }
368 }
369 let c = client.as_mut().unwrap();
370 if let Err(e) = c.simple_query(&sql).await {
371 metrics.errors.fetch_add(1, Ordering::Relaxed);
372 tracing::debug!(error = %e, "mirror apply failed; will reconnect");
373 if let Some(c) = client.take() {
375 c.close().await;
376 }
377 } else {
378 metrics.mirrored.fetch_add(1, Ordering::Relaxed);
379 }
380 }
381 tracing::info!("traffic mirror worker stopped");
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `v` with `encode_copy_field` and returns the bytes as a `String`.
    fn enc(v: &TextValue) -> String {
        let mut buf = Vec::new();
        encode_copy_field(&mut buf, v);
        String::from_utf8(buf).expect("escaping UTF-8 input only inserts ASCII bytes")
    }

    #[test]
    fn copy_field_encoding() {
        // NULL is the COPY text-format null marker, distinct from the literal text "\N".
        assert_eq!(enc(&TextValue::Null), "\\N");

        // (input text, expected COPY field)
        let cases: &[(&str, &str)] = &[
            ("", ""),
            ("alice", "alice"),
            ("a\tb", "a\\tb"),
            ("a\nb", "a\\nb"),
            ("a\rb", "a\\rb"),
            ("a\\b", "a\\\\b"),
            ("\\N", "\\\\N"),
        ];
        for &(input, expected) in cases {
            assert_eq!(enc(&TextValue::Text(input.into())), expected, "input {:?}", input);
        }
    }
}