use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use serde::Serialize;
use tokio::sync::mpsc;
use crate::backend::types::TextValue;
use crate::backend::{
tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode,
};
use crate::config::MirrorConfig;
/// Atomic counters shared between `MirrorHandle` (the producer side) and the
/// background mirror worker task.
///
/// Every load and `fetch_add` on these counters in this module uses
/// `Ordering::Relaxed`. Values read one after another are therefore not
/// guaranteed to be mutually consistent.
#[derive(Default)]
pub struct MirrorMetrics {
/// Statements successfully placed on the mirror queue.
pub enqueued: AtomicU64,
/// Statements the worker executed successfully on the mirror target.
pub mirrored: AtomicU64,
/// Statements that were offered for mirroring but could not be queued.
pub dropped: AtomicU64,
/// Statements the worker dequeued but could not apply (connect or execution failure).
pub errors: AtomicU64,
}
/// Serializable point-in-time view of traffic-mirror progress, produced by `status`.
#[derive(Debug, Clone, Serialize)]
pub struct MigrationStatus {
/// Always `true` when built by `status`.
pub enabled: bool,
/// Mirror target address (`host:port` when the handle was created by `spawn`).
pub target: String,
/// Whether only statements flagged as writes are mirrored.
pub writes_only: bool,
/// Value of `MirrorMetrics::enqueued` at the time of the snapshot.
pub enqueued: u64,
/// Value of `MirrorMetrics::mirrored` at the time of the snapshot.
pub mirrored: u64,
/// Value of `MirrorMetrics::dropped` at the time of the snapshot.
pub dropped: u64,
/// Value of `MirrorMetrics::errors` at the time of the snapshot.
pub errors: u64,
/// `enqueued - mirrored - errors`, saturating at zero: statements queued but
/// not yet resolved (mirrored or failed) by the worker.
pub lag: u64,
/// `true` when `lag == 0` and `dropped == 0`.
pub migration_ready: bool,
}
pub fn status(target: &str, writes_only: bool, m: &MirrorMetrics) -> MigrationStatus {
let enqueued = m.enqueued.load(Ordering::Relaxed);
let mirrored = m.mirrored.load(Ordering::Relaxed);
let dropped = m.dropped.load(Ordering::Relaxed);
let errors = m.errors.load(Ordering::Relaxed);
let lag = enqueued.saturating_sub(mirrored).saturating_sub(errors);
MigrationStatus {
enabled: true,
target: target.to_string(),
writes_only,
enqueued,
mirrored,
dropped,
errors,
lag,
migration_ready: lag == 0 && dropped == 0,
}
}
pub struct MirrorHandle {
tx: mpsc::Sender<String>,
sample_rate: f64,
writes_only: bool,
target: String,
pub metrics: Arc<MirrorMetrics>,
}
impl MirrorHandle {
pub fn status(&self) -> MigrationStatus {
status(&self.target, self.writes_only, &self.metrics)
}
pub fn target(&self) -> &str {
&self.target
}
pub fn writes_only(&self) -> bool {
self.writes_only
}
}
impl MirrorHandle {
pub fn offer(&self, sql: &str, is_write: bool) {
if self.writes_only && !is_write {
return;
}
if self.sample_rate < 1.0 {
use rand::Rng;
if rand::thread_rng().gen::<f64>() >= self.sample_rate {
return;
}
}
match self.tx.try_send(sql.to_string()) {
Ok(()) => {
self.metrics.enqueued.fetch_add(1, Ordering::Relaxed);
}
Err(mpsc::error::TrySendError::Full(_)) => {
self.metrics.dropped.fetch_add(1, Ordering::Relaxed);
}
Err(mpsc::error::TrySendError::Closed(_)) => {}
}
}
}
/// Starts the traffic-mirror worker on the current Tokio runtime and returns
/// the handle used to feed it.
///
/// The queue capacity is `config.queue_size`, but never less than one. The
/// sample rate is clamped into `[0.0, 1.0]`. The worker takes ownership of
/// `config`.
///
/// # Panics
/// `tokio::spawn` panics when called outside a Tokio runtime.
pub fn spawn(config: MirrorConfig) -> MirrorHandle {
    // Read everything the handle needs before `config` moves into the worker.
    let sample_rate = config.sample_rate.clamp(0.0, 1.0);
    let writes_only = config.writes_only;
    let target = format!("{}:{}", config.backend_host, config.backend_port);
    // A zero-capacity channel is not allowed, so enforce a minimum of one slot.
    let capacity = std::cmp::max(config.queue_size, 1);

    let metrics = Arc::new(MirrorMetrics::default());
    let (tx, rx) = mpsc::channel::<String>(capacity);
    tokio::spawn(worker(config, rx, Arc::clone(&metrics)));

    MirrorHandle {
        tx,
        sample_rate,
        writes_only,
        target,
        metrics,
    }
}
/// Connection details for a cutover target.
///
/// NOTE(review): this type is not referenced in this section of the file. The
/// field meanings below are inferred from their names; confirm them against
/// callers.
/// NOTE(review): the derived `Debug` prints `password` verbatim. Avoid logging
/// this value with `{:?}`, or consider a manual `Debug` that redacts it.
#[derive(Debug, Clone)]
pub struct CutoverTarget {
/// Presumably the target address (e.g. `host:port`) — TODO confirm.
pub addr: String,
/// Presumably the user name to connect as.
pub user: String,
/// Optional password. `None` presumably means no password is supplied.
pub password: Option<String>,
/// Optional database name. `None` presumably means the server/user default.
pub database: Option<String>,
}
/// Per-table result reported by `snapshot_tables`.
#[derive(Debug, Clone, Serialize)]
pub struct TableSnapshot {
/// Table name as passed by the caller (unquoted).
pub table: String,
/// Number of rows read from the source table.
pub source_rows: u64,
/// Rows written to the target: the count returned by `copy_in`, or the
/// number of successful per-row INSERTs when the INSERT fallback was used.
pub copied: u64,
}
/// Assembles a backend connection configuration for `host:port`, tagged with
/// `app` as the application name.
///
/// TLS is always disabled (`TlsMode::Disable`). The timeouts are fixed: 5 s to
/// connect and 60 s per query.
fn backend_cfg(
    host: &str,
    port: u16,
    user: &str,
    pass: Option<&str>,
    db: Option<&str>,
    app: &str,
) -> BackendConfig {
    let connect_timeout = Duration::from_secs(5);
    let query_timeout = Duration::from_secs(60);
    BackendConfig {
        host: String::from(host),
        port,
        user: String::from(user),
        password: pass.map(String::from),
        database: db.map(str::to_owned),
        application_name: Some(app.to_owned()),
        tls_mode: TlsMode::Disable,
        connect_timeout,
        query_timeout,
        tls_config: default_client_config(),
    }
}
/// Quotes `name` as a SQL identifier.
///
/// The name is wrapped in double quotes, and every embedded double quote is
/// doubled so it cannot terminate the identifier early.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            // `"` inside a quoted identifier is written as `""`.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Appends one field in COPY text format to `out`.
///
/// `Null` is written as `\N`. For text values, the four bytes that are special
/// in this format (backslash, tab, newline and carriage return) become
/// backslash escapes. Every other byte, including non-ASCII UTF-8, is copied
/// through unchanged.
fn encode_copy_field(out: &mut Vec<u8>, v: &TextValue) {
    let text = match v {
        TextValue::Null => {
            out.extend_from_slice(b"\\N");
            return;
        }
        TextValue::Text(s) => s,
    };
    for &byte in text.as_bytes() {
        let escaped: &[u8] = match byte {
            b'\\' => b"\\\\",
            b'\t' => b"\\t",
            b'\n' => b"\\n",
            b'\r' => b"\\r",
            _ => {
                out.push(byte);
                continue;
            }
        };
        out.extend_from_slice(escaped);
    }
}
/// Maps a PostgreSQL built-in type OID from a source result set to the column
/// type used when creating the table on the target.
///
/// Only common scalar types are mapped explicitly. Everything else becomes
/// `text`, which still stores the value's text representation. `int2` is
/// widened to `integer`.
fn oid_type(oid: u32) -> &'static str {
    match oid {
        16 => "boolean",
        20 => "bigint",
        // int2 and int4 both map to integer (int2 is widened).
        21 | 23 => "integer",
        700 => "real",
        701 => "double precision",
        1700 => "numeric",
        1082 => "date",
        1114 => "timestamp",
        // timestamptz must stay zone-aware. Its text form carries a UTC offset
        // that a plain `timestamp` column would silently discard, shifting the
        // stored instant.
        1184 => "timestamptz",
        2950 => "uuid",
        // json keeps the exact source text; mapping it to jsonb would
        // normalize it (whitespace, key order, duplicate keys).
        114 => "json",
        3802 => "jsonb",
        _ => "text",
    }
}
pub async fn snapshot_tables(
cfg: &MirrorConfig,
tables: &[String],
) -> Result<Vec<TableSnapshot>, String> {
let src_cfg = backend_cfg(
&cfg.source_host,
cfg.source_port,
&cfg.source_user,
cfg.source_password.as_deref(),
cfg.source_database.as_deref(),
"heliosproxy-snapshot-src",
);
let tgt_cfg = backend_cfg(
&cfg.backend_host,
cfg.backend_port,
&cfg.backend_user,
cfg.backend_password.as_deref(),
cfg.backend_database.as_deref(),
"heliosproxy-snapshot-tgt",
);
let mut src = BackendClient::connect(&src_cfg)
.await
.map_err(|e| format!("source connect: {}", e))?;
let mut tgt = BackendClient::connect(&tgt_cfg)
.await
.map_err(|e| format!("target connect: {}", e))?;
let mut blocked: Vec<String> = Vec::new();
for table in tables {
let qt = quote_ident(table);
if let Ok(res) = tgt
.simple_query(&format!("SELECT 1 FROM {} LIMIT 1", qt))
.await
{
if !res.rows.is_empty() {
blocked.push(table.clone());
}
}
}
if !blocked.is_empty() {
src.close().await;
tgt.close().await;
return Err(format!(
"refusing snapshot: target table(s) already contain rows (snapshot would \
duplicate): {}. Snapshot into an empty target, or remove the rows first.",
blocked.join(", ")
));
}
let mut report = Vec::new();
for table in tables {
let qt = quote_ident(table);
let res = src
.simple_query(&format!("SELECT * FROM {}", qt))
.await
.map_err(|e| format!("read {}: {}", table, e))?;
let cols_ddl: Vec<String> = res
.columns
.iter()
.map(|c| format!("{} {}", quote_ident(&c.name), oid_type(c.type_oid)))
.collect();
let create = format!(
"CREATE TABLE IF NOT EXISTS {} ({})",
qt,
cols_ddl.join(", ")
);
tgt.execute(&create)
.await
.map_err(|e| format!("create {} on target: {}", table, e))?;
let col_list = res
.columns
.iter()
.map(|c| quote_ident(&c.name))
.collect::<Vec<_>>()
.join(", ");
let use_copy = std::env::var("HELIOS_SNAPSHOT_USE_COPY")
.map(|v| v != "0")
.unwrap_or(true);
let mut copied: Option<u64> = None;
if use_copy {
let mut copy_buf: Vec<u8> = Vec::new();
for row in &res.rows {
for (i, v) in row.iter().enumerate() {
if i > 0 {
copy_buf.push(b'\t');
}
encode_copy_field(&mut copy_buf, v);
}
copy_buf.push(b'\n');
}
let copy_sql = format!("COPY {} ({}) FROM STDIN", qt, col_list);
match tgt.copy_in(©_sql, ©_buf).await {
Ok(n) => copied = Some(n),
Err(e) => {
tracing::warn!(
table = %table,
error = %e,
"COPY snapshot failed; falling back to per-row INSERT"
);
}
}
}
let copied = match copied {
Some(n) => n,
None => {
let placeholders = (1..=res.columns.len())
.map(|i| format!("${}", i))
.collect::<Vec<_>>()
.join(", ");
let insert = format!(
"INSERT INTO {} ({}) VALUES ({})",
qt, col_list, placeholders
);
let mut copied = 0u64;
for row in &res.rows {
let params: Vec<ParamValue> = row
.iter()
.map(|v| match v {
TextValue::Null => ParamValue::Null,
TextValue::Text(s) => ParamValue::Text(s.clone()),
})
.collect();
tgt.query_with_params(&insert, ¶ms)
.await
.map_err(|e| format!("insert into {}: {}", table, e))?;
copied += 1;
}
copied
}
};
report.push(TableSnapshot {
table: table.clone(),
source_rows: res.rows.len() as u64,
copied,
});
}
src.close().await;
tgt.close().await;
Ok(report)
}
/// Background task that drains the mirror queue and replays each statement on
/// the mirror target over a single, lazily opened connection.
///
/// For every dequeued statement exactly one counter moves:
/// - `metrics.mirrored` on success;
/// - `metrics.errors` when connecting or executing fails.
///
/// Failed statements are not retried. After an execution failure the
/// connection is closed, and the next statement opens a fresh one. The task
/// returns once the channel is closed and empty.
async fn worker(config: MirrorConfig, mut rx: mpsc::Receiver<String>, metrics: Arc<MirrorMetrics>) {
    let bcfg = BackendConfig {
        host: config.backend_host.clone(),
        port: config.backend_port,
        user: config.backend_user.clone(),
        password: config.backend_password.clone(),
        database: config.backend_database.clone(),
        application_name: Some("heliosproxy-mirror".to_string()),
        tls_mode: TlsMode::Disable,
        connect_timeout: Duration::from_secs(5),
        query_timeout: Duration::from_secs(30),
        tls_config: default_client_config(),
    };
    tracing::info!(target = %bcfg.address(), "traffic mirror worker started");

    let mut conn: Option<BackendClient> = None;
    while let Some(stmt) = rx.recv().await {
        // Reuse the live connection, or open one on demand.
        let mut active = match conn.take() {
            Some(existing) => existing,
            None => match BackendClient::connect(&bcfg).await {
                Ok(fresh) => fresh,
                Err(e) => {
                    metrics.errors.fetch_add(1, Ordering::Relaxed);
                    tracing::debug!(error = %e, "mirror connect failed; dropping statement");
                    continue;
                }
            },
        };
        match active.simple_query(&stmt).await {
            Ok(_) => {
                metrics.mirrored.fetch_add(1, Ordering::Relaxed);
                // Healthy connection: keep it for the next statement.
                conn = Some(active);
            }
            Err(e) => {
                metrics.errors.fetch_add(1, Ordering::Relaxed);
                tracing::debug!(error = %e, "mirror apply failed; will reconnect");
                // Discard the connection; the next statement reconnects.
                active.close().await;
            }
        }
    }
    tracing::info!("traffic mirror worker stopped");
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a single COPY field and returns it as a `String` for comparison.
    fn enc(v: &TextValue) -> String {
        let mut out = Vec::new();
        encode_copy_field(&mut out, v);
        String::from_utf8(out).unwrap()
    }

    #[test]
    fn copy_field_encoding() {
        assert_eq!(enc(&TextValue::Null), "\\N");
        assert_eq!(enc(&TextValue::Text(String::new())), "");
        assert_eq!(enc(&TextValue::Text("alice".into())), "alice");
        assert_eq!(enc(&TextValue::Text("a\tb".into())), "a\\tb");
        assert_eq!(enc(&TextValue::Text("a\nb".into())), "a\\nb");
        assert_eq!(enc(&TextValue::Text("a\rb".into())), "a\\rb");
        assert_eq!(enc(&TextValue::Text("a\\b".into())), "a\\\\b");
        // A literal "\N" must not be confused with the NULL marker.
        assert_eq!(enc(&TextValue::Text("\\N".into())), "\\\\N");
        // Several escapes in one value.
        assert_eq!(
            enc(&TextValue::Text("a\\\tb\r\n".into())),
            "a\\\\\\tb\\r\\n"
        );
        // Multi-byte UTF-8 passes through untouched.
        assert_eq!(enc(&TextValue::Text("héllo".into())), "héllo");
    }

    #[test]
    fn quote_ident_escapes_embedded_quotes() {
        assert_eq!(quote_ident("users"), "\"users\"");
        assert_eq!(quote_ident("My Table"), "\"My Table\"");
        assert_eq!(quote_ident("a\"b"), "\"a\"\"b\"");
        assert_eq!(quote_ident(""), "\"\"");
    }
}