Skip to main content

heliosdb_proxy/
mirror.rs

//! Continuous traffic mirroring.
//!
//! Replays a sampled share of live (simple-query) statements — only writes
//! when `writes_only` is set — to a secondary backend, **asynchronously and
//! off the client hot path**: the data path does a non-blocking `try_send`
//! into a bounded queue and moves on; a background worker drains the queue
//! and applies each statement to the mirror backend. When the queue is full,
//! statements are dropped (and counted) rather than slowing the client —
//! mirroring is best-effort.
//!
//! This is the on-ramp to the PG->Nano migration mirror (Batch G2): point the
//! mirror at a HeliosDB-Nano instance and its write set tracks the primary.
//! (Result diffing for blue/green validation already lives in
//! `shadow_execute`; this module is the continuous write tail.)
14
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19use serde::Serialize;
20use tokio::sync::mpsc;
21
22use crate::backend::types::TextValue;
23use crate::backend::{
24    tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode,
25};
26use crate::config::MirrorConfig;
27
/// Live mirror counters shared between the data path (which enqueues) and
/// the background worker (which applies), surfaced for observability.
///
/// This module only ever increments them, and reads them with relaxed
/// ordering, so a snapshot taken across several counters is not guaranteed
/// to be mutually consistent.
#[derive(Default)]
pub struct MirrorMetrics {
    /// Number of statements successfully placed on the worker's queue.
    pub enqueued: AtomicU64,
    /// Number of statements the worker applied to the mirror backend without error.
    pub mirrored: AtomicU64,
    /// Number of statements discarded on the data path without being queued
    /// (e.g. because the queue was full).
    pub dropped: AtomicU64,
    /// Number of dequeued statements that failed to apply, including those
    /// discarded because the mirror connection could not be established.
    pub errors: AtomicU64,
}
40
41/// Operator-facing migration status (served at `/api/migration/status`).
42#[derive(Debug, Clone, Serialize)]
43pub struct MigrationStatus {
44    pub enabled: bool,
45    pub target: String,
46    pub writes_only: bool,
47    pub enqueued: u64,
48    pub mirrored: u64,
49    pub dropped: u64,
50    pub errors: u64,
51    /// Statements accepted but not yet applied (queue backlog).
52    pub lag: u64,
53    /// True when the mirror is enabled, the backlog is drained, and nothing
54    /// has been dropped — i.e. the secondary is caught up and a cutover is
55    /// safe with respect to the mirrored write set.
56    pub migration_ready: bool,
57}
58
59/// Compute a migration status snapshot from the live counters.
60pub fn status(target: &str, writes_only: bool, m: &MirrorMetrics) -> MigrationStatus {
61    let enqueued = m.enqueued.load(Ordering::Relaxed);
62    let mirrored = m.mirrored.load(Ordering::Relaxed);
63    let dropped = m.dropped.load(Ordering::Relaxed);
64    let errors = m.errors.load(Ordering::Relaxed);
65    let lag = enqueued.saturating_sub(mirrored).saturating_sub(errors);
66    MigrationStatus {
67        enabled: true,
68        target: target.to_string(),
69        writes_only,
70        enqueued,
71        mirrored,
72        dropped,
73        errors,
74        lag,
75        migration_ready: lag == 0 && dropped == 0,
76    }
77}
78
79/// Handle held by the server: a bounded sender plus the sampling policy.
80pub struct MirrorHandle {
81    tx: mpsc::Sender<String>,
82    sample_rate: f64,
83    writes_only: bool,
84    target: String,
85    pub metrics: Arc<MirrorMetrics>,
86}
87
88impl MirrorHandle {
89    /// A snapshot of migration status for the admin API.
90    pub fn status(&self) -> MigrationStatus {
91        status(&self.target, self.writes_only, &self.metrics)
92    }
93    pub fn target(&self) -> &str {
94        &self.target
95    }
96    pub fn writes_only(&self) -> bool {
97        self.writes_only
98    }
99}
100
101impl MirrorHandle {
102    /// Offer one statement to the mirror. `is_write` is the data path's
103    /// already-computed verb classification (avoids re-parsing). Non-blocking:
104    /// drops (and counts) when the queue is full. Returns immediately.
105    pub fn offer(&self, sql: &str, is_write: bool) {
106        if self.writes_only && !is_write {
107            return;
108        }
109        if self.sample_rate < 1.0 {
110            // Cheap per-call sample without locking a shared RNG.
111            use rand::Rng;
112            if rand::thread_rng().gen::<f64>() >= self.sample_rate {
113                return;
114            }
115        }
116        match self.tx.try_send(sql.to_string()) {
117            Ok(()) => {
118                self.metrics.enqueued.fetch_add(1, Ordering::Relaxed);
119            }
120            Err(mpsc::error::TrySendError::Full(_)) => {
121                self.metrics.dropped.fetch_add(1, Ordering::Relaxed);
122            }
123            Err(mpsc::error::TrySendError::Closed(_)) => {}
124        }
125    }
126}
127
128/// Spawn the mirror worker. Returns a handle for the data path to feed.
129pub fn spawn(config: MirrorConfig) -> MirrorHandle {
130    let (tx, rx) = mpsc::channel::<String>(config.queue_size.max(1));
131    let metrics = Arc::new(MirrorMetrics::default());
132    let handle = MirrorHandle {
133        tx,
134        sample_rate: config.sample_rate.clamp(0.0, 1.0),
135        writes_only: config.writes_only,
136        target: format!("{}:{}", config.backend_host, config.backend_port),
137        metrics: metrics.clone(),
138    };
139    tokio::spawn(worker(config, rx, metrics));
140    handle
141}
142
/// Target the proxy redirects client traffic to after a migration cutover.
/// New connections route here (with these credentials/database substituted
/// for the client's), making the cutover transparent to the application.
///
/// `Debug` is implemented by hand rather than derived so that formatting
/// with `{:?}` (for example in a log or tracing field) never reveals the
/// password. It only shows whether one is set.
#[derive(Clone)]
pub struct CutoverTarget {
    pub addr: String,
    pub user: String,
    pub password: Option<String>,
    pub database: Option<String>,
}

impl std::fmt::Debug for CutoverTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CutoverTarget")
            .field("addr", &self.addr)
            .field("user", &self.user)
            // Redact the secret but keep the Some/None distinction for debugging.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("database", &self.database)
            .finish()
    }
}
153
154/// Per-table snapshot result.
155#[derive(Debug, Clone, Serialize)]
156pub struct TableSnapshot {
157    pub table: String,
158    pub source_rows: u64,
159    pub copied: u64,
160}
161
162fn backend_cfg(
163    host: &str,
164    port: u16,
165    user: &str,
166    pass: Option<&str>,
167    db: Option<&str>,
168    app: &str,
169) -> BackendConfig {
170    BackendConfig {
171        host: host.to_string(),
172        port,
173        user: user.to_string(),
174        password: pass.map(|s| s.to_string()),
175        database: db.map(|s| s.to_string()),
176        application_name: Some(app.to_string()),
177        tls_mode: TlsMode::Disable,
178        connect_timeout: Duration::from_secs(5),
179        query_timeout: Duration::from_secs(60),
180        tls_config: default_client_config(),
181    }
182}
183
/// Quote `name` as a PostgreSQL delimited identifier: wrap it in double
/// quotes and double every embedded `"`. The whole input becomes one
/// identifier, so a dotted `schema.table` is not split.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            // An embedded quote is escaped by doubling it.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
187
188/// Encode one value into a buffer in PostgreSQL `COPY ... FROM STDIN` text
189/// format: `NULL` is `\N`; backslash/tab/newline/carriage-return are escaped.
190/// Source values arrive as their text representation already, so this is just
191/// the COPY-level escaping.
192fn encode_copy_field(out: &mut Vec<u8>, v: &TextValue) {
193    match v {
194        TextValue::Null => out.extend_from_slice(b"\\N"),
195        TextValue::Text(s) => {
196            for &b in s.as_bytes() {
197                match b {
198                    b'\\' => out.extend_from_slice(b"\\\\"),
199                    b'\t' => out.extend_from_slice(b"\\t"),
200                    b'\n' => out.extend_from_slice(b"\\n"),
201                    b'\r' => out.extend_from_slice(b"\\r"),
202                    _ => out.push(b),
203                }
204            }
205        }
206    }
207}
208
/// Map a PostgreSQL type OID to a column type for the snapshot's
/// `CREATE TABLE` on the secondary. Unknown OIDs fall back to `text`.
///
/// `timestamptz` (OID 1184) maps to `timestamp with time zone`, not plain
/// `timestamp`. PostgreSQL silently ignores a zone offset supplied to a
/// `timestamp without time zone` column. The source's text values, which
/// carry an offset, would therefore lose it and no longer denote the same
/// instant.
fn oid_type(oid: u32) -> &'static str {
    match oid {
        16 => "boolean",                    // bool
        20 => "bigint",                     // int8
        21 | 23 => "integer",               // int2, int4
        700 => "real",                      // float4
        701 => "double precision",          // float8
        1700 => "numeric",
        1082 => "date",
        1114 => "timestamp",                // timestamp without time zone
        1184 => "timestamp with time zone", // timestamptz
        2950 => "uuid",
        114 | 3802 => "jsonb",              // json, jsonb
        _ => "text",
    }
}
226
227/// Snapshot-bootstrap the secondary: for each table, read all existing rows
228/// from the source (primary) and copy them into the mirror target, creating
229/// the table there if needed. Returns a per-table report. Used by
230/// `POST /api/migration/snapshot` to seed a migration with existing data
231/// before/alongside the continuous write tail.
232pub async fn snapshot_tables(
233    cfg: &MirrorConfig,
234    tables: &[String],
235) -> Result<Vec<TableSnapshot>, String> {
236    let src_cfg = backend_cfg(
237        &cfg.source_host,
238        cfg.source_port,
239        &cfg.source_user,
240        cfg.source_password.as_deref(),
241        cfg.source_database.as_deref(),
242        "heliosproxy-snapshot-src",
243    );
244    let tgt_cfg = backend_cfg(
245        &cfg.backend_host,
246        cfg.backend_port,
247        &cfg.backend_user,
248        cfg.backend_password.as_deref(),
249        cfg.backend_database.as_deref(),
250        "heliosproxy-snapshot-tgt",
251    );
252    let mut src = BackendClient::connect(&src_cfg)
253        .await
254        .map_err(|e| format!("source connect: {}", e))?;
255    let mut tgt = BackendClient::connect(&tgt_cfg)
256        .await
257        .map_err(|e| format!("target connect: {}", e))?;
258
259    // Idempotency fence (default, non-destructive): refuse the whole snapshot if
260    // ANY target table already has rows — silently appending would duplicate.
261    // Pre-flight all targets first so we never leave a partial load. A probe
262    // error (table absent, etc.) means "not populated" → eligible to load. We do
263    // NOT truncate; the operator points at an empty target or clears it first.
264    let mut blocked: Vec<String> = Vec::new();
265    for table in tables {
266        let qt = quote_ident(table);
267        if let Ok(res) = tgt
268            .simple_query(&format!("SELECT 1 FROM {} LIMIT 1", qt))
269            .await
270        {
271            if !res.rows.is_empty() {
272                blocked.push(table.clone());
273            }
274        }
275    }
276    if !blocked.is_empty() {
277        src.close().await;
278        tgt.close().await;
279        return Err(format!(
280            "refusing snapshot: target table(s) already contain rows (snapshot would \
281             duplicate): {}. Snapshot into an empty target, or remove the rows first.",
282            blocked.join(", ")
283        ));
284    }
285
286    let mut report = Vec::new();
287    for table in tables {
288        let qt = quote_ident(table);
289        let res = src
290            .simple_query(&format!("SELECT * FROM {}", qt))
291            .await
292            .map_err(|e| format!("read {}: {}", table, e))?;
293
294        // CREATE TABLE IF NOT EXISTS on the target from the source columns.
295        let cols_ddl: Vec<String> = res
296            .columns
297            .iter()
298            .map(|c| format!("{} {}", quote_ident(&c.name), oid_type(c.type_oid)))
299            .collect();
300        let create = format!(
301            "CREATE TABLE IF NOT EXISTS {} ({})",
302            qt,
303            cols_ddl.join(", ")
304        );
305        tgt.execute(&create)
306            .await
307            .map_err(|e| format!("create {} on target: {}", table, e))?;
308
309        let col_list = res
310            .columns
311            .iter()
312            .map(|c| quote_ident(&c.name))
313            .collect::<Vec<_>>()
314            .join(", ");
315
316        // Primary path: a single COPY ... FROM STDIN bulk-load. `HELIOS_SNAPSHOT_USE_COPY=0`
317        // forces the INSERT path (ops kill-switch / fallback test).
318        let use_copy = std::env::var("HELIOS_SNAPSHOT_USE_COPY")
319            .map(|v| v != "0")
320            .unwrap_or(true);
321
322        let mut copied: Option<u64> = None;
323        if use_copy {
324            let mut copy_buf: Vec<u8> = Vec::new();
325            for row in &res.rows {
326                for (i, v) in row.iter().enumerate() {
327                    if i > 0 {
328                        copy_buf.push(b'\t');
329                    }
330                    encode_copy_field(&mut copy_buf, v);
331                }
332                copy_buf.push(b'\n');
333            }
334            let copy_sql = format!("COPY {} ({}) FROM STDIN", qt, col_list);
335            match tgt.copy_in(&copy_sql, &copy_buf).await {
336                Ok(n) => copied = Some(n),
337                Err(e) => {
338                    // COPY rejected/unsupported leaves the connection clean and
339                    // zero rows loaded (COPY is atomic) — fall through to INSERT.
340                    tracing::warn!(
341                        table = %table,
342                        error = %e,
343                        "COPY snapshot failed; falling back to per-row INSERT"
344                    );
345                }
346            }
347        }
348
349        // Fallback path (preserved): per-row parameterised INSERTs.
350        let copied = match copied {
351            Some(n) => n,
352            None => {
353                let placeholders = (1..=res.columns.len())
354                    .map(|i| format!("${}", i))
355                    .collect::<Vec<_>>()
356                    .join(", ");
357                let insert = format!(
358                    "INSERT INTO {} ({}) VALUES ({})",
359                    qt, col_list, placeholders
360                );
361                let mut copied = 0u64;
362                for row in &res.rows {
363                    let params: Vec<ParamValue> = row
364                        .iter()
365                        .map(|v| match v {
366                            TextValue::Null => ParamValue::Null,
367                            TextValue::Text(s) => ParamValue::Text(s.clone()),
368                        })
369                        .collect();
370                    tgt.query_with_params(&insert, &params)
371                        .await
372                        .map_err(|e| format!("insert into {}: {}", table, e))?;
373                    copied += 1;
374                }
375                copied
376            }
377        };
378        report.push(TableSnapshot {
379            table: table.clone(),
380            source_rows: res.rows.len() as u64,
381            copied,
382        });
383    }
384    src.close().await;
385    tgt.close().await;
386    Ok(report)
387}
388
389async fn worker(config: MirrorConfig, mut rx: mpsc::Receiver<String>, metrics: Arc<MirrorMetrics>) {
390    let bcfg = BackendConfig {
391        host: config.backend_host.clone(),
392        port: config.backend_port,
393        user: config.backend_user.clone(),
394        password: config.backend_password.clone(),
395        database: config.backend_database.clone(),
396        application_name: Some("heliosproxy-mirror".to_string()),
397        tls_mode: TlsMode::Disable,
398        connect_timeout: Duration::from_secs(5),
399        query_timeout: Duration::from_secs(30),
400        tls_config: default_client_config(),
401    };
402    tracing::info!(target = %bcfg.address(), "traffic mirror worker started");
403
404    let mut client: Option<BackendClient> = None;
405    while let Some(sql) = rx.recv().await {
406        // (Re)connect lazily so a temporarily-down mirror doesn't crash-loop.
407        if client.is_none() {
408            match BackendClient::connect(&bcfg).await {
409                Ok(c) => client = Some(c),
410                Err(e) => {
411                    metrics.errors.fetch_add(1, Ordering::Relaxed);
412                    tracing::debug!(error = %e, "mirror connect failed; dropping statement");
413                    continue;
414                }
415            }
416        }
417        let c = client.as_mut().unwrap();
418        if let Err(e) = c.simple_query(&sql).await {
419            metrics.errors.fetch_add(1, Ordering::Relaxed);
420            tracing::debug!(error = %e, "mirror apply failed; will reconnect");
421            // Drop the connection so the next statement reconnects.
422            if let Some(c) = client.take() {
423                c.close().await;
424            }
425        } else {
426            metrics.mirrored.fetch_add(1, Ordering::Relaxed);
427        }
428    }
429    tracing::info!("traffic mirror worker stopped");
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `encode_copy_field` on one value and return the bytes as a string.
    fn encode(value: &TextValue) -> String {
        let mut buf = Vec::new();
        encode_copy_field(&mut buf, value);
        String::from_utf8(buf).expect("COPY escaping of UTF-8 input stays UTF-8")
    }

    /// Shorthand for a non-null text value.
    fn text(s: &str) -> TextValue {
        TextValue::Text(s.to_string())
    }

    #[test]
    fn copy_field_encoding() {
        // NULL gets the `\N` marker, which must differ from an empty string.
        assert_eq!(encode(&TextValue::Null), "\\N");
        assert_eq!(encode(&text("")), "");

        // (input, expected COPY text). Tab / newline / CR / backslash must be
        // escaped so they can't break the tab-delimited, newline-terminated
        // framing.
        let cases = [
            ("alice", "alice"), // plain text is untouched
            ("a\tb", "a\\tb"),
            ("a\nb", "a\\nb"),
            ("a\rb", "a\\rb"),
            ("a\\b", "a\\\\b"),
            ("\\N", "\\\\N"), // a literal backslash-N must not read as NULL
        ];
        for &(input, expected) in &cases {
            assert_eq!(encode(&text(input)), expected, "input {:?}", input);
        }
    }
}