Skip to main content

heliosdb_proxy/
mirror.rs

1//! Continuous traffic mirroring.
2//!
3//! Replays a sampled share of live (simple-query) write statements to a
4//! secondary backend, **asynchronously and off the client hot path**: the
5//! data path does a non-blocking `try_send` into a bounded queue and moves
6//! on; a background worker drains the queue and applies each statement to the
7//! mirror backend. When the queue is full, statements are dropped (and
8//! counted) rather than slowing the client — mirroring is best-effort.
9//!
10//! This is the on-ramp to the PG->Nano migration mirror (Batch G2): point the
11//! mirror at a HeliosDB-Nano instance and its write set tracks the primary.
12//! (Result diffing for blue/green validation already lives in
13//! `shadow_execute`; this module is the continuous write tail.)
14
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19use serde::Serialize;
20use tokio::sync::mpsc;
21
22use crate::backend::types::TextValue;
23use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode};
24use crate::config::MirrorConfig;
25
/// Counters surfaced for observability.
///
/// Shared (via `Arc`) between the data path, which offers statements, and the
/// background worker, which applies them. Every counter only ever grows.
#[derive(Debug, Default)]
pub struct MirrorMetrics {
    /// Statements accepted into the queue.
    pub enqueued: AtomicU64,
    /// Statements successfully applied to the mirror backend.
    pub mirrored: AtomicU64,
    /// Statements that could not be enqueued (e.g. the queue was full) and
    /// were therefore never mirrored.
    pub dropped: AtomicU64,
    /// Dequeued statements that were not applied because connecting to, or
    /// executing on, the mirror backend failed.
    pub errors: AtomicU64,
}
38
39/// Operator-facing migration status (served at `/api/migration/status`).
40#[derive(Debug, Clone, Serialize)]
41pub struct MigrationStatus {
42    pub enabled: bool,
43    pub target: String,
44    pub writes_only: bool,
45    pub enqueued: u64,
46    pub mirrored: u64,
47    pub dropped: u64,
48    pub errors: u64,
49    /// Statements accepted but not yet applied (queue backlog).
50    pub lag: u64,
51    /// True when the mirror is enabled, the backlog is drained, and nothing
52    /// has been dropped — i.e. the secondary is caught up and a cutover is
53    /// safe with respect to the mirrored write set.
54    pub migration_ready: bool,
55}
56
57/// Compute a migration status snapshot from the live counters.
58pub fn status(target: &str, writes_only: bool, m: &MirrorMetrics) -> MigrationStatus {
59    let enqueued = m.enqueued.load(Ordering::Relaxed);
60    let mirrored = m.mirrored.load(Ordering::Relaxed);
61    let dropped = m.dropped.load(Ordering::Relaxed);
62    let errors = m.errors.load(Ordering::Relaxed);
63    let lag = enqueued.saturating_sub(mirrored).saturating_sub(errors);
64    MigrationStatus {
65        enabled: true,
66        target: target.to_string(),
67        writes_only,
68        enqueued,
69        mirrored,
70        dropped,
71        errors,
72        lag,
73        migration_ready: lag == 0 && dropped == 0,
74    }
75}
76
77/// Handle held by the server: a bounded sender plus the sampling policy.
78pub struct MirrorHandle {
79    tx: mpsc::Sender<String>,
80    sample_rate: f64,
81    writes_only: bool,
82    target: String,
83    pub metrics: Arc<MirrorMetrics>,
84}
85
86impl MirrorHandle {
87    /// A snapshot of migration status for the admin API.
88    pub fn status(&self) -> MigrationStatus {
89        status(&self.target, self.writes_only, &self.metrics)
90    }
91    pub fn target(&self) -> &str {
92        &self.target
93    }
94    pub fn writes_only(&self) -> bool {
95        self.writes_only
96    }
97}
98
99impl MirrorHandle {
100    /// Offer one statement to the mirror. `is_write` is the data path's
101    /// already-computed verb classification (avoids re-parsing). Non-blocking:
102    /// drops (and counts) when the queue is full. Returns immediately.
103    pub fn offer(&self, sql: &str, is_write: bool) {
104        if self.writes_only && !is_write {
105            return;
106        }
107        if self.sample_rate < 1.0 {
108            // Cheap per-call sample without locking a shared RNG.
109            use rand::Rng;
110            if rand::thread_rng().gen::<f64>() >= self.sample_rate {
111                return;
112            }
113        }
114        match self.tx.try_send(sql.to_string()) {
115            Ok(()) => {
116                self.metrics.enqueued.fetch_add(1, Ordering::Relaxed);
117            }
118            Err(mpsc::error::TrySendError::Full(_)) => {
119                self.metrics.dropped.fetch_add(1, Ordering::Relaxed);
120            }
121            Err(mpsc::error::TrySendError::Closed(_)) => {}
122        }
123    }
124}
125
126/// Spawn the mirror worker. Returns a handle for the data path to feed.
127pub fn spawn(config: MirrorConfig) -> MirrorHandle {
128    let (tx, rx) = mpsc::channel::<String>(config.queue_size.max(1));
129    let metrics = Arc::new(MirrorMetrics::default());
130    let handle = MirrorHandle {
131        tx,
132        sample_rate: config.sample_rate.clamp(0.0, 1.0),
133        writes_only: config.writes_only,
134        target: format!("{}:{}", config.backend_host, config.backend_port),
135        metrics: metrics.clone(),
136    };
137    tokio::spawn(worker(config, rx, metrics));
138    handle
139}
140
/// Target the proxy redirects client traffic to after a migration cutover.
/// New connections route here (with these credentials/database substituted
/// for the client's), making the cutover transparent to the application.
///
/// `Debug` is implemented by hand so that formatting this value (e.g. in a
/// log line) never prints the password.
#[derive(Clone)]
pub struct CutoverTarget {
    /// Backend address new connections are routed to.
    /// NOTE(review): presumably `host:port`, like the mirror target. Confirm
    /// against the router.
    pub addr: String,
    /// User name substituted for the client's.
    pub user: String,
    /// Password for `user`, if the target requires one.
    pub password: Option<String>,
    /// Database substituted for the client's, if set.
    pub database: Option<String>,
}

impl std::fmt::Debug for CutoverTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CutoverTarget")
            .field("addr", &self.addr)
            .field("user", &self.user)
            // Show whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("database", &self.database)
            .finish()
    }
}
151
152/// Per-table snapshot result.
153#[derive(Debug, Clone, Serialize)]
154pub struct TableSnapshot {
155    pub table: String,
156    pub source_rows: u64,
157    pub copied: u64,
158}
159
160fn backend_cfg(host: &str, port: u16, user: &str, pass: Option<&str>, db: Option<&str>, app: &str) -> BackendConfig {
161    BackendConfig {
162        host: host.to_string(),
163        port,
164        user: user.to_string(),
165        password: pass.map(|s| s.to_string()),
166        database: db.map(|s| s.to_string()),
167        application_name: Some(app.to_string()),
168        tls_mode: TlsMode::Disable,
169        connect_timeout: Duration::from_secs(5),
170        query_timeout: Duration::from_secs(60),
171        tls_config: default_client_config(),
172    }
173}
174
/// Quote `name` as a single SQL identifier. The name is wrapped in double
/// quotes, and each embedded double quote is doubled. Dots are not treated
/// specially, so `a.b` becomes the one identifier `"a.b"`.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
178
179/// Encode one value into a buffer in PostgreSQL `COPY ... FROM STDIN` text
180/// format: `NULL` is `\N`; backslash/tab/newline/carriage-return are escaped.
181/// Source values arrive as their text representation already, so this is just
182/// the COPY-level escaping.
183fn encode_copy_field(out: &mut Vec<u8>, v: &TextValue) {
184    match v {
185        TextValue::Null => out.extend_from_slice(b"\\N"),
186        TextValue::Text(s) => {
187            for &b in s.as_bytes() {
188                match b {
189                    b'\\' => out.extend_from_slice(b"\\\\"),
190                    b'\t' => out.extend_from_slice(b"\\t"),
191                    b'\n' => out.extend_from_slice(b"\\n"),
192                    b'\r' => out.extend_from_slice(b"\\r"),
193                    _ => out.push(b),
194                }
195            }
196        }
197    }
198}
199
/// Translate a PostgreSQL type OID into the column type used by the
/// snapshot's `CREATE TABLE` on the secondary. OIDs not in the table map to
/// `text`.
///
/// NOTE(review): `timestamptz` (1184) maps to plain `timestamp`. On a
/// PostgreSQL target, the UTC offset in the source's text value would then be
/// ignored on load. Likewise `json` (114) becomes `jsonb`, which normalizes
/// key order and duplicate keys. Confirm both are intended for portability.
fn oid_type(oid: u32) -> &'static str {
    const OID_TYPES: [(u32, &str); 13] = [
        (16, "boolean"),
        (20, "bigint"),
        (21, "integer"),
        (23, "integer"),
        (700, "real"),
        (701, "double precision"),
        (1700, "numeric"),
        (1082, "date"),
        (1114, "timestamp"),
        (1184, "timestamp"),
        (2950, "uuid"),
        (114, "jsonb"),
        (3802, "jsonb"),
    ];
    OID_TYPES
        .iter()
        .find(|(known, _)| *known == oid)
        .map_or("text", |&(_, ty)| ty)
}
217
218/// Snapshot-bootstrap the secondary: for each table, read all existing rows
219/// from the source (primary) and copy them into the mirror target, creating
220/// the table there if needed. Returns a per-table report. Used by
221/// `POST /api/migration/snapshot` to seed a migration with existing data
222/// before/alongside the continuous write tail.
223pub async fn snapshot_tables(cfg: &MirrorConfig, tables: &[String]) -> Result<Vec<TableSnapshot>, String> {
224    let src_cfg = backend_cfg(
225        &cfg.source_host, cfg.source_port, &cfg.source_user,
226        cfg.source_password.as_deref(), cfg.source_database.as_deref(), "heliosproxy-snapshot-src",
227    );
228    let tgt_cfg = backend_cfg(
229        &cfg.backend_host, cfg.backend_port, &cfg.backend_user,
230        cfg.backend_password.as_deref(), cfg.backend_database.as_deref(), "heliosproxy-snapshot-tgt",
231    );
232    let mut src = BackendClient::connect(&src_cfg).await.map_err(|e| format!("source connect: {}", e))?;
233    let mut tgt = BackendClient::connect(&tgt_cfg).await.map_err(|e| format!("target connect: {}", e))?;
234
235    // Idempotency fence (default, non-destructive): refuse the whole snapshot if
236    // ANY target table already has rows — silently appending would duplicate.
237    // Pre-flight all targets first so we never leave a partial load. A probe
238    // error (table absent, etc.) means "not populated" → eligible to load. We do
239    // NOT truncate; the operator points at an empty target or clears it first.
240    let mut blocked: Vec<String> = Vec::new();
241    for table in tables {
242        let qt = quote_ident(table);
243        if let Ok(res) = tgt.simple_query(&format!("SELECT 1 FROM {} LIMIT 1", qt)).await {
244            if !res.rows.is_empty() {
245                blocked.push(table.clone());
246            }
247        }
248    }
249    if !blocked.is_empty() {
250        src.close().await;
251        tgt.close().await;
252        return Err(format!(
253            "refusing snapshot: target table(s) already contain rows (snapshot would \
254             duplicate): {}. Snapshot into an empty target, or remove the rows first.",
255            blocked.join(", ")
256        ));
257    }
258
259    let mut report = Vec::new();
260    for table in tables {
261        let qt = quote_ident(table);
262        let res = src
263            .simple_query(&format!("SELECT * FROM {}", qt))
264            .await
265            .map_err(|e| format!("read {}: {}", table, e))?;
266
267        // CREATE TABLE IF NOT EXISTS on the target from the source columns.
268        let cols_ddl: Vec<String> = res
269            .columns
270            .iter()
271            .map(|c| format!("{} {}", quote_ident(&c.name), oid_type(c.type_oid)))
272            .collect();
273        let create = format!("CREATE TABLE IF NOT EXISTS {} ({})", qt, cols_ddl.join(", "));
274        tgt.execute(&create).await.map_err(|e| format!("create {} on target: {}", table, e))?;
275
276        let col_list = res.columns.iter().map(|c| quote_ident(&c.name)).collect::<Vec<_>>().join(", ");
277
278        // Primary path: a single COPY ... FROM STDIN bulk-load. `HELIOS_SNAPSHOT_USE_COPY=0`
279        // forces the INSERT path (ops kill-switch / fallback test).
280        let use_copy =
281            std::env::var("HELIOS_SNAPSHOT_USE_COPY").map(|v| v != "0").unwrap_or(true);
282
283        let mut copied: Option<u64> = None;
284        if use_copy {
285            let mut copy_buf: Vec<u8> = Vec::new();
286            for row in &res.rows {
287                for (i, v) in row.iter().enumerate() {
288                    if i > 0 {
289                        copy_buf.push(b'\t');
290                    }
291                    encode_copy_field(&mut copy_buf, v);
292                }
293                copy_buf.push(b'\n');
294            }
295            let copy_sql = format!("COPY {} ({}) FROM STDIN", qt, col_list);
296            match tgt.copy_in(&copy_sql, &copy_buf).await {
297                Ok(n) => copied = Some(n),
298                Err(e) => {
299                    // COPY rejected/unsupported leaves the connection clean and
300                    // zero rows loaded (COPY is atomic) — fall through to INSERT.
301                    tracing::warn!(
302                        table = %table,
303                        error = %e,
304                        "COPY snapshot failed; falling back to per-row INSERT"
305                    );
306                }
307            }
308        }
309
310        // Fallback path (preserved): per-row parameterised INSERTs.
311        let copied = match copied {
312            Some(n) => n,
313            None => {
314                let placeholders =
315                    (1..=res.columns.len()).map(|i| format!("${}", i)).collect::<Vec<_>>().join(", ");
316                let insert = format!("INSERT INTO {} ({}) VALUES ({})", qt, col_list, placeholders);
317                let mut copied = 0u64;
318                for row in &res.rows {
319                    let params: Vec<ParamValue> = row
320                        .iter()
321                        .map(|v| match v {
322                            TextValue::Null => ParamValue::Null,
323                            TextValue::Text(s) => ParamValue::Text(s.clone()),
324                        })
325                        .collect();
326                    tgt.query_with_params(&insert, &params)
327                        .await
328                        .map_err(|e| format!("insert into {}: {}", table, e))?;
329                    copied += 1;
330                }
331                copied
332            }
333        };
334        report.push(TableSnapshot { table: table.clone(), source_rows: res.rows.len() as u64, copied });
335    }
336    src.close().await;
337    tgt.close().await;
338    Ok(report)
339}
340
341async fn worker(config: MirrorConfig, mut rx: mpsc::Receiver<String>, metrics: Arc<MirrorMetrics>) {
342    let bcfg = BackendConfig {
343        host: config.backend_host.clone(),
344        port: config.backend_port,
345        user: config.backend_user.clone(),
346        password: config.backend_password.clone(),
347        database: config.backend_database.clone(),
348        application_name: Some("heliosproxy-mirror".to_string()),
349        tls_mode: TlsMode::Disable,
350        connect_timeout: Duration::from_secs(5),
351        query_timeout: Duration::from_secs(30),
352        tls_config: default_client_config(),
353    };
354    tracing::info!(target = %bcfg.address(), "traffic mirror worker started");
355
356    let mut client: Option<BackendClient> = None;
357    while let Some(sql) = rx.recv().await {
358        // (Re)connect lazily so a temporarily-down mirror doesn't crash-loop.
359        if client.is_none() {
360            match BackendClient::connect(&bcfg).await {
361                Ok(c) => client = Some(c),
362                Err(e) => {
363                    metrics.errors.fetch_add(1, Ordering::Relaxed);
364                    tracing::debug!(error = %e, "mirror connect failed; dropping statement");
365                    continue;
366                }
367            }
368        }
369        let c = client.as_mut().unwrap();
370        if let Err(e) = c.simple_query(&sql).await {
371            metrics.errors.fetch_add(1, Ordering::Relaxed);
372            tracing::debug!(error = %e, "mirror apply failed; will reconnect");
373            // Drop the connection so the next statement reconnects.
374            if let Some(c) = client.take() {
375                c.close().await;
376            }
377        } else {
378            metrics.mirrored.fetch_add(1, Ordering::Relaxed);
379        }
380    }
381    tracing::info!("traffic mirror worker stopped");
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `encode_copy_field` on one value and return the bytes as a `String`.
    fn enc(v: &TextValue) -> String {
        let mut out = Vec::new();
        encode_copy_field(&mut out, v);
        String::from_utf8(out).unwrap()
    }

    #[test]
    fn copy_field_encoding() {
        // NULL -> \N (distinct from an empty string).
        assert_eq!(enc(&TextValue::Null), "\\N");
        assert_eq!(enc(&TextValue::Text(String::new())), "");
        // Plain text passes through.
        assert_eq!(enc(&TextValue::Text("alice".into())), "alice");
        // Tab / newline / CR / backslash are escaped so they can't break the
        // tab-delimited, newline-terminated COPY framing.
        assert_eq!(enc(&TextValue::Text("a\tb".into())), "a\\tb");
        assert_eq!(enc(&TextValue::Text("a\nb".into())), "a\\nb");
        assert_eq!(enc(&TextValue::Text("a\rb".into())), "a\\rb");
        assert_eq!(enc(&TextValue::Text("a\\b".into())), "a\\\\b");
        // A literal backslash-N in data must not be confused with NULL.
        assert_eq!(enc(&TextValue::Text("\\N".into())), "\\\\N");
    }

    #[test]
    fn quote_ident_doubles_embedded_quotes() {
        assert_eq!(quote_ident("users"), "\"users\"");
        assert_eq!(quote_ident(""), "\"\"");
        assert_eq!(quote_ident("we\"ird"), "\"we\"\"ird\"");
        // Dots are not split: a qualified name becomes one identifier.
        assert_eq!(quote_ident("public.users"), "\"public.users\"");
    }

    #[test]
    fn oid_type_mapping() {
        assert_eq!(oid_type(16), "boolean");
        assert_eq!(oid_type(20), "bigint");
        assert_eq!(oid_type(21), "integer");
        assert_eq!(oid_type(23), "integer");
        assert_eq!(oid_type(701), "double precision");
        assert_eq!(oid_type(1184), "timestamp");
        assert_eq!(oid_type(114), "jsonb");
        assert_eq!(oid_type(3802), "jsonb");
        // Unmapped OIDs (here `text` = 25 and `bytea` = 17) fall back to text.
        assert_eq!(oid_type(25), "text");
        assert_eq!(oid_type(17), "text");
    }
}