use super::transport::{LogicalEntryPayload, LogicalOperation, LogicalValue};
use super::wal_replicator::{Lsn, WalEntry, WalEntryType};
use super::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Table-level replication rule: a table-name pattern paired with the action
/// to take for tables that match it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableFilter {
/// Table-name pattern; `*` acts as a wildcard (see `TableFilter::matches`).
pub pattern: String,
/// Whether matching tables are included in or excluded from replication.
pub action: FilterAction,
}
impl TableFilter {
    /// Creates a filter that marks tables matching `pattern` for inclusion.
    pub fn include(pattern: impl Into<String>) -> Self {
        Self {
            pattern: pattern.into(),
            action: FilterAction::Include,
        }
    }

    /// Creates a filter that marks tables matching `pattern` for exclusion.
    pub fn exclude(pattern: impl Into<String>) -> Self {
        Self {
            pattern: pattern.into(),
            action: FilterAction::Exclude,
        }
    }

    /// Returns `true` when `table` matches this filter's pattern.
    ///
    /// Every `*` in the pattern matches any run of characters, including an
    /// empty one; all other characters must match literally. A pattern without
    /// `*` must equal the table name exactly. The literal text before the first
    /// `*` and after the last `*` may not overlap inside `table`, so `ab*ba`
    /// matches `abba` but not `aba`.
    pub fn matches(&self, table: &str) -> bool {
        let mut segments = self.pattern.split('*');
        // `split` always yields at least one segment: the text before the
        // first `*`, or the whole pattern when it contains no wildcard.
        let first = segments.next().unwrap_or("");
        let last = match segments.next_back() {
            Some(segment) => segment,
            // No wildcard at all: fall back to an exact comparison.
            None => return self.pattern == table,
        };

        // Peel the anchored prefix and suffix off the table name. Stripping
        // (rather than `starts_with` + `ends_with`) guarantees they do not
        // share characters.
        let anchored = table
            .strip_prefix(first)
            .and_then(|rest| rest.strip_suffix(last));
        let mut remainder = match anchored {
            Some(rest) => rest,
            None => return false,
        };

        // Remaining segments sit between two wildcards; taking the leftmost
        // occurrence of each, in order, is sufficient for `*`-only globs.
        for segment in segments {
            match remainder.find(segment) {
                Some(at) => remainder = &remainder[at + segment.len()..],
                None => return false,
            }
        }
        true
    }
}
/// Action a `TableFilter` applies to the tables its pattern matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FilterAction {
/// Matching tables are candidates for replication.
Include,
/// Matching tables are kept out of replication.
Exclude,
}
/// Row-level replication rule: a textual predicate evaluated against the rows
/// of one table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RowFilter {
/// Table the predicate applies to; the pipeline also treats `"*"` as "every table".
pub table: String,
/// Predicate text of the form `<field> <operator> <literal>`, e.g. `status != 'deleted'`.
pub predicate: String,
}
impl RowFilter {
    /// Creates a row filter for `table` from the textual `predicate`.
    pub fn new(table: impl Into<String>, predicate: impl Into<String>) -> Self {
        Self {
            table: table.into(),
            predicate: predicate.into(),
        }
    }

    /// Returns `true` when `row` satisfies the predicate and should be kept.
    ///
    /// Evaluation is permissive: a predicate without a recognised operator, a
    /// field missing from the row, or an ordering comparison between
    /// non-numeric values all yield `true`.
    pub fn evaluate(&self, row: &ChangeRow) -> bool {
        self.evaluate_predicate(row)
    }

    /// Evaluates `<field> <op> <literal>` where `<op>` is one of
    /// `=`, `!=`, `>`, `>=`, `<`, `<=`.
    ///
    /// `=` and `!=` compare the field's string form with the literal (single
    /// quotes around the literal are stripped). The ordering operators compare
    /// numerically, as integers when both sides parse as `i64` and as floats
    /// otherwise.
    fn evaluate_predicate(&self, row: &ChangeRow) -> bool {
        let (field, operator, literal) = match Self::split_predicate(self.predicate.trim()) {
            Some(parts) => parts,
            None => return true,
        };
        let actual = match row.get_field(field) {
            Some(value) => value,
            // Unknown columns never filter a row out.
            None => return true,
        };

        match operator {
            "=" => actual == literal.trim_matches('\''),
            "!=" => actual != literal.trim_matches('\''),
            _ => match Self::compare_numeric(&actual, literal) {
                Some(ordering) => match operator {
                    ">" => ordering == std::cmp::Ordering::Greater,
                    ">=" => ordering != std::cmp::Ordering::Less,
                    "<" => ordering == std::cmp::Ordering::Less,
                    "<=" => ordering != std::cmp::Ordering::Greater,
                    _ => true,
                },
                // Values that cannot be ordered (non-numeric, NaN) pass.
                None => true,
            },
        }
    }

    /// Splits a predicate at its first operator into
    /// `(field, operator, literal)`, with field and literal trimmed.
    ///
    /// Scanning left to right means operator characters inside the literal
    /// (e.g. `status = 'a!=b'`) are not mistaken for the comparison, and
    /// two-character operators are tried first so `>=` is not read as `>`.
    fn split_predicate(predicate: &str) -> Option<(&str, &str, &str)> {
        const OPERATORS: [&str; 6] = ["!=", ">=", "<=", "=", ">", "<"];
        predicate.char_indices().find_map(|(idx, _)| {
            let rest = &predicate[idx..];
            OPERATORS
                .iter()
                .find(|op| rest.starts_with(**op))
                .map(|op| (predicate[..idx].trim(), *op, rest[op.len()..].trim()))
        })
    }

    /// Orders two textual numbers, or returns `None` when either side is not
    /// numeric (or the comparison involves NaN).
    fn compare_numeric(left: &str, right: &str) -> Option<std::cmp::Ordering> {
        // Prefer exact integer comparison; floats lose precision above 2^53.
        if let (Ok(l), Ok(r)) = (left.parse::<i64>(), right.parse::<i64>()) {
            return Some(l.cmp(&r));
        }
        match (left.parse::<f64>(), right.parse::<f64>()) {
            (Ok(l), Ok(r)) => l.partial_cmp(&r),
            _ => None,
        }
    }
}
/// Describes one transformation applied to a single column of a table's rows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnMapping {
/// Table the mapping applies to; the pipeline also treats `"*"` as "every table".
pub table: String,
/// Column the transformation reads from.
pub source_column: String,
/// Column the result is written to. Set by `ColumnMapping::rename`; the
/// other constructors leave it `None`.
pub target_column: Option<String>,
/// Transformation to perform on the column.
pub transform: ColumnTransform,
}
impl ColumnMapping {
pub fn rename(
table: impl Into<String>,
source: impl Into<String>,
target: impl Into<String>,
) -> Self {
Self {
table: table.into(),
source_column: source.into(),
target_column: Some(target.into()),
transform: ColumnTransform::Rename,
}
}
pub fn drop(table: impl Into<String>, column: impl Into<String>) -> Self {
Self {
table: table.into(),
source_column: column.into(),
target_column: None,
transform: ColumnTransform::Drop,
}
}
pub fn cast(
table: impl Into<String>,
column: impl Into<String>,
target_type: DataType,
) -> Self {
Self {
table: table.into(),
source_column: column.into(),
target_column: None,
transform: ColumnTransform::Cast(target_type),
}
}
pub fn multiply(table: impl Into<String>, column: impl Into<String>, factor: f64) -> Self {
Self {
table: table.into(),
source_column: column.into(),
target_column: None,
transform: ColumnTransform::Multiply(factor),
}
}
pub fn mask(table: impl Into<String>, column: impl Into<String>, mask_char: char) -> Self {
Self {
table: table.into(),
source_column: column.into(),
target_column: None,
transform: ColumnTransform::Mask(mask_char),
}
}
}
/// Transformation a `ColumnMapping` performs on its source column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColumnTransform {
/// Move the value from the source column to the target column.
Rename,
/// Remove the source column from the row.
Drop,
/// Convert the value to the given data type.
Cast(DataType),
/// Multiply a numeric value by the given factor.
Multiply(f64),
/// Replace every character of a string value with the given character.
Mask(char),
/// Replace the value with the hex form of its BLAKE3 hash.
Hash,
/// Apply a function-style expression such as `UPPER(col)` to the value.
Expression(String),
}
/// Target type of a `ColumnTransform::Cast`.
///
/// NOTE(review): `cast_value` in this file only implements conversions to
/// `Integer`, `Float` and `String`; casting to the remaining variants leaves
/// the value unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
Integer,
Float,
String,
Boolean,
Timestamp,
Json,
Bytes,
}
/// A decoded row-level change produced from a WAL entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChangeEvent {
/// LSN copied from the WAL entry the event was decoded from.
pub lsn: Lsn,
/// Transaction id taken from the decoded WAL payload, if it carried one.
pub tx_id: Option<u64>,
/// Name of the table the change applies to.
pub table: String,
/// Schema of the table, if the decoded payload carried one.
pub schema: Option<String>,
/// Kind of change (insert, update or delete).
pub operation: ChangeOperation,
/// Row contents taken from the decoded payload's `row`.
pub row: ChangeRow,
/// Contents of the decoded payload's `old_row`, when present.
pub old_row: Option<ChangeRow>,
/// Microsecond timestamp; `decode_entry` fills it from the wall clock at
/// decode time, not from the WAL entry.
pub timestamp: u64,
}
/// Kind of row change carried by a `ChangeEvent`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChangeOperation {
/// A row was inserted.
Insert,
/// A row was updated.
Update,
/// A row was deleted.
Delete,
}
/// A single row represented as a map from column name to value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChangeRow {
/// Column values keyed by column name; iteration order is unspecified
/// because this is a `HashMap`.
pub fields: HashMap<String, FieldValue>,
}
impl ChangeRow {
    /// Creates a row with no columns.
    pub fn new() -> Self {
        Default::default()
    }

    /// Stores `value` under column `name`, replacing any previous value.
    pub fn set_field(&mut self, name: impl Into<String>, value: FieldValue) {
        let column: String = name.into();
        self.fields.insert(column, value);
    }

    /// Returns the display form of column `name`, or `None` when the row has
    /// no such column.
    pub fn get_field(&self, name: &str) -> Option<String> {
        let value = self.fields.get(name)?;
        Some(value.to_string())
    }

    /// Borrows the typed value of column `name`, if present.
    pub fn get(&self, name: &str) -> Option<&FieldValue> {
        self.fields.get(name)
    }
}
/// A typed column value inside a `ChangeRow`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FieldValue {
/// Absent value; displayed as `NULL`.
Null,
/// Signed 64-bit integer.
Integer(i64),
/// 64-bit floating point number.
Float(f64),
/// UTF-8 text.
String(String),
/// Boolean flag.
Boolean(bool),
/// Raw bytes; displayed only as a length summary.
Bytes(Vec<u8>),
/// Unsigned 64-bit timestamp (unit not established in this file — TODO confirm).
Timestamp(u64),
}
impl std::fmt::Display for FieldValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FieldValue::Null => write!(f, "NULL"),
FieldValue::Integer(v) => write!(f, "{}", v),
FieldValue::Float(v) => write!(f, "{}", v),
FieldValue::String(v) => write!(f, "{}", v),
FieldValue::Boolean(v) => write!(f, "{}", v),
FieldValue::Bytes(v) => write!(f, "<{} bytes>", v.len()),
FieldValue::Timestamp(v) => write!(f, "{}", v),
}
}
}
/// Configuration for a `LogicalReplicationPipeline`.
///
/// NOTE(review): the derived `Default` yields `batch_size == 0`, whereas
/// `LogicalReplicationConfig::new()` sets it to 1000 — confirm callers use the
/// constructor they intend.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LogicalReplicationConfig {
/// Table include/exclude rules; an empty list replicates every table.
pub table_filters: Vec<TableFilter>,
/// Row predicates; a row must satisfy every filter that targets its table.
pub row_filters: Vec<RowFilter>,
/// Column transformations, applied in list order.
pub column_mappings: Vec<ColumnMapping>,
/// Not read by the pipeline code visible in this file — presumably enables DDL replication; TODO confirm.
pub replicate_ddl: bool,
/// Not read by the pipeline code visible in this file — presumably enables TRUNCATE replication; TODO confirm.
pub replicate_truncate: bool,
/// Batch size; not read by the pipeline code visible in this file.
pub batch_size: usize,
}
impl LogicalReplicationConfig {
    /// Creates a configuration with no filters or mappings and a batch size
    /// of 1000; every other field takes its `Default` value.
    pub fn new() -> Self {
        Self {
            batch_size: 1000,
            ..Self::default()
        }
    }

    /// Builder step: appends a table filter and returns the configuration.
    pub fn add_table_filter(self, filter: TableFilter) -> Self {
        let mut updated = self;
        updated.table_filters.push(filter);
        updated
    }

    /// Builder step: appends a row filter and returns the configuration.
    pub fn add_row_filter(self, filter: RowFilter) -> Self {
        let mut updated = self;
        updated.row_filters.push(filter);
        updated
    }

    /// Builder step: appends a column mapping and returns the configuration.
    pub fn add_column_mapping(self, mapping: ColumnMapping) -> Self {
        let mut updated = self;
        updated.column_mappings.push(mapping);
        updated
    }
}
/// Turns WAL entries into filtered, transformed `ChangeEvent`s according to a
/// `LogicalReplicationConfig`.
pub struct LogicalReplicationPipeline {
/// Filters and column mappings that drive processing.
config: LogicalReplicationConfig,
/// Processing counters, shared behind an async read/write lock.
stats: Arc<RwLock<PipelineStats>>,
/// Initialised to `None` in `new`; no other method visible in this file
/// reads or writes it.
current_tx: Arc<RwLock<Option<TransactionState>>>,
}
/// Counters describing what a `LogicalReplicationPipeline` has done so far.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
/// Number of WAL entries handed to `process`.
pub entries_processed: u64,
/// Number of entries that produced a change event.
pub entries_passed: u64,
/// Number of entries rejected by a table or row filter.
pub entries_filtered: u64,
/// NOTE(review): no code visible in this file increments this counter.
pub transformations_applied: u64,
/// Number of entries whose processing failed.
pub errors: u64,
}
/// Changes buffered for one transaction.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type; it only appears as the payload of
/// `LogicalReplicationPipeline::current_tx`.
struct TransactionState {
/// Identifier of the transaction.
tx_id: u64,
/// Presumably the LSN at which the transaction started — TODO confirm.
start_lsn: Lsn,
/// Change events collected for the transaction.
changes: Vec<ChangeEvent>,
}
impl LogicalReplicationPipeline {
/// Builds a pipeline around `config` with zeroed statistics and no current
/// transaction.
pub fn new(config: LogicalReplicationConfig) -> Self {
    let stats = Arc::new(RwLock::new(PipelineStats::default()));
    let current_tx = Arc::new(RwLock::new(None));
    Self {
        config,
        stats,
        current_tx,
    }
}
/// Runs one WAL entry through decoding, table/row filtering and column
/// transformation, updating the pipeline statistics along the way.
///
/// Returns `Ok(Some(event))` for a change that should be replicated and
/// `Ok(None)` when the entry is not a row change or is rejected by a filter.
///
/// # Errors
///
/// Propagates any error raised while applying the column mappings; such
/// failures are counted in `PipelineStats::errors`.
pub async fn process(&self, entry: &WalEntry) -> Result<Option<ChangeEvent>> {
    // The guard is held for the whole call; nothing below awaits, so the
    // counters for one entry are updated atomically with respect to readers.
    let mut stats = self.stats.write().await;
    stats.entries_processed += 1;

    let event = match self.decode_entry(entry) {
        Some(event) => event,
        // Entries that are not row changes count as processed only.
        None => return Ok(None),
    };

    let wanted = self.should_replicate_table(&event.table)
        && self.should_replicate_row(&event.table, &event.row);
    if !wanted {
        stats.entries_filtered += 1;
        return Ok(None);
    }

    match self.apply_transformations(event) {
        Ok(transformed) => {
            stats.entries_passed += 1;
            Ok(Some(transformed))
        }
        Err(err) => {
            // Record the failure before propagating it so the error counter
            // reflects transformation failures.
            stats.errors += 1;
            Err(err)
        }
    }
}
/// Processes `entries` one after another and collects the resulting events
/// in input order.
///
/// Stops at, and returns, the first error reported by `process`.
pub async fn process_batch(&self, entries: &[WalEntry]) -> Result<Vec<ChangeEvent>> {
    let mut events = Vec::with_capacity(entries.len());
    for wal_entry in entries.iter() {
        // `Option` iterates over zero or one items, so filtered entries
        // simply contribute nothing.
        events.extend(self.process(wal_entry).await?);
    }
    Ok(events)
}
fn decode_entry(&self, entry: &WalEntry) -> Option<ChangeEvent> {
let operation = match entry.entry_type {
WalEntryType::Insert => ChangeOperation::Insert,
WalEntryType::Update => ChangeOperation::Update,
WalEntryType::Delete => ChangeOperation::Delete,
_ => return None,
};
let decoded = self.parse_wal_data(&entry.data)?;
Some(ChangeEvent {
lsn: entry.lsn,
tx_id: decoded.tx_id,
table: decoded.table,
schema: decoded.schema,
operation,
row: decoded.row,
old_row: decoded.old_row,
timestamp: chrono::Utc::now().timestamp_micros() as u64,
})
}
/// Decodes a WAL payload as JSON-encoded `DecodedWalData`.
///
/// NOTE(review): this never returns `None`. A payload that fails to parse is
/// replaced by a placeholder for table `"unknown"` with an empty row, which
/// then flows through the filters like any other change — confirm that is
/// intended.
fn parse_wal_data(&self, data: &[u8]) -> Option<DecodedWalData> {
    let decoded =
        serde_json::from_slice::<DecodedWalData>(data).unwrap_or_else(|_| DecodedWalData {
            tx_id: None,
            table: "unknown".to_string(),
            schema: None,
            row: ChangeRow::new(),
            old_row: None,
        });
    Some(decoded)
}
/// Decides whether changes to `table` should be replicated.
///
/// Rules, in order:
/// 1. a matching `Exclude` filter always rejects the table;
/// 2. if no `Include` filter is configured (including the case of no filters
///    at all), every remaining table is accepted, so an exclude-only
///    configuration acts as a deny list instead of blocking everything;
/// 3. otherwise the table must match at least one `Include` filter.
fn should_replicate_table(&self, table: &str) -> bool {
    let filters = &self.config.table_filters;

    let excluded = filters
        .iter()
        .any(|filter| filter.action == FilterAction::Exclude && filter.matches(table));
    if excluded {
        return false;
    }

    let mut includes = filters
        .iter()
        .filter(|filter| filter.action == FilterAction::Include)
        .peekable();
    // With no include rules there is no allow list to satisfy.
    if includes.peek().is_none() {
        return true;
    }
    includes.any(|filter| filter.matches(table))
}
/// Returns `true` when `row` satisfies every row filter that targets `table`
/// (or the wildcard table `"*"`). With no applicable filter the row passes.
fn should_replicate_row(&self, table: &str, row: &ChangeRow) -> bool {
    self.config
        .row_filters
        .iter()
        .filter(|candidate| candidate.table == table || candidate.table == "*")
        .all(|candidate| candidate.evaluate(row))
}
/// Applies, in configuration order, every column mapping that targets the
/// event's table (or `"*"`) to both the new row and, when present, the old
/// row.
///
/// Returns the first error produced by a mapping.
fn apply_transformations(&self, mut event: ChangeEvent) -> Result<ChangeEvent> {
    for mapping in self.config.column_mappings.iter() {
        let applies = mapping.table == event.table || mapping.table == "*";
        if applies {
            event.row = self.transform_row(&event.row, mapping)?;
            event.old_row = event
                .old_row
                .take()
                .map(|previous| self.transform_row(&previous, mapping))
                .transpose()?;
        }
    }
    Ok(event)
}
/// Returns a copy of `row` with `mapping` applied.
///
/// When the row has no `source_column` the copy is returned unchanged.
/// `Rename` and `Drop` remove the source column (`Rename` re-inserts the
/// value under the target column when one is set). Every other transform
/// computes a replacement value and stores it under the target column, or
/// under the source column when no target is set.
fn transform_row(&self, row: &ChangeRow, mapping: &ColumnMapping) -> Result<ChangeRow> {
    let source = &mapping.source_column;
    let mut output = row.clone();
    let current = match row.get(source) {
        Some(existing) => existing,
        None => return Ok(output),
    };

    // Structural transforms finish early; value transforms fall through with
    // the replacement to store.
    let replacement = match &mapping.transform {
        ColumnTransform::Rename => {
            let moved = output.fields.remove(source);
            if let (Some(target), Some(moved)) = (&mapping.target_column, moved) {
                output.fields.insert(target.clone(), moved);
            }
            return Ok(output);
        }
        ColumnTransform::Drop => {
            output.fields.remove(source);
            return Ok(output);
        }
        ColumnTransform::Cast(target_type) => self.cast_value(current, *target_type)?,
        ColumnTransform::Multiply(factor) => self.multiply_value(current, *factor)?,
        ColumnTransform::Mask(mask_char) => self.mask_value(current, *mask_char),
        ColumnTransform::Hash => self.hash_value(current),
        ColumnTransform::Expression(expr) => self.evaluate_expression(expr, current)?,
    };

    let destination = mapping.target_column.as_ref().unwrap_or(source);
    output.fields.insert(destination.clone(), replacement);
    Ok(output)
}
/// Converts `value` to `target` where a conversion is defined; any other
/// combination returns the value unchanged.
///
/// Defined conversions: integer → float/string, float → integer/string,
/// string → integer/float, boolean → integer/string.
///
/// NOTE(review): a string that does not parse as a number becomes `0` / `0.0`
/// rather than an error — confirm that silent default is intended.
fn cast_value(&self, value: &FieldValue, target: DataType) -> Result<FieldValue> {
    let converted = match value {
        FieldValue::Integer(number) => match target {
            DataType::Float => Some(FieldValue::Float(*number as f64)),
            DataType::String => Some(FieldValue::String(number.to_string())),
            _ => None,
        },
        FieldValue::Float(number) => match target {
            DataType::Integer => Some(FieldValue::Integer(*number as i64)),
            DataType::String => Some(FieldValue::String(number.to_string())),
            _ => None,
        },
        FieldValue::String(text) => match target {
            DataType::Integer => Some(FieldValue::Integer(text.parse().unwrap_or(0))),
            DataType::Float => Some(FieldValue::Float(text.parse().unwrap_or(0.0))),
            _ => None,
        },
        FieldValue::Boolean(flag) => match target {
            DataType::Integer => Some(FieldValue::Integer(i64::from(*flag))),
            DataType::String => Some(FieldValue::String(flag.to_string())),
            _ => None,
        },
        _ => None,
    };
    Ok(converted.unwrap_or_else(|| value.clone()))
}
/// Multiplies a numeric value by `factor`; non-numeric values are returned
/// unchanged.
///
/// Integers are multiplied in floating point and converted back with an `as`
/// cast, so the fractional part of the product is discarded.
fn multiply_value(&self, value: &FieldValue, factor: f64) -> Result<FieldValue> {
    let scaled = if let FieldValue::Integer(whole) = value {
        let product = *whole as f64 * factor;
        FieldValue::Integer(product as i64)
    } else if let FieldValue::Float(real) = value {
        FieldValue::Float(real * factor)
    } else {
        value.clone()
    };
    Ok(scaled)
}
/// Replaces every character of a string value with `mask_char`, preserving
/// the character count; other value types are returned unchanged.
fn mask_value(&self, value: &FieldValue, mask_char: char) -> FieldValue {
    if let FieldValue::String(text) = value {
        let masked: String = std::iter::repeat(mask_char)
            .take(text.chars().count())
            .collect();
        return FieldValue::String(masked);
    }
    value.clone()
}
/// Replaces `value` with the lowercase hex BLAKE3 digest of its bytes.
///
/// Strings and byte blobs are hashed as-is; integers, floats and timestamps
/// are hashed as their little-endian bytes; booleans as a single `0`/`1`
/// byte; `Null` as the empty input. Every variant is handled explicitly so
/// that distinct booleans or timestamps no longer collapse onto the digest of
/// the empty input.
fn hash_value(&self, value: &FieldValue) -> FieldValue {
    let digest = match value {
        FieldValue::String(text) => blake3::hash(text.as_bytes()),
        FieldValue::Bytes(raw) => blake3::hash(raw),
        FieldValue::Integer(number) => blake3::hash(&number.to_le_bytes()),
        FieldValue::Float(number) => blake3::hash(&number.to_le_bytes()),
        FieldValue::Timestamp(stamp) => blake3::hash(&stamp.to_le_bytes()),
        FieldValue::Boolean(flag) => blake3::hash(&[u8::from(*flag)]),
        FieldValue::Null => blake3::hash(&[]),
    };
    FieldValue::String(digest.to_hex().to_string())
}
/// Applies a function-style expression such as `UPPER(col)` to `value`.
///
/// The function name is the text before the first `(`, matched
/// case-insensitively and ignoring surrounding whitespace. Supported
/// functions:
///
/// * `UPPER`, `LOWER`, `TRIM` — string case/whitespace helpers;
/// * `LENGTH` / `LEN` — byte length of a string, `0` for other types;
/// * `ABS`, `ROUND`, `FLOOR`, `CEIL` / `CEILING` — numeric helpers;
/// * `COALESCE(col, 'default')` — replaces `Null` with the literal after the
///   first comma;
/// * `CONCAT(col, 'suffix')` — appends the literal after the first comma;
/// * `SUBSTR` / `SUBSTRING` — keeps at most the first 10 bytes of a string
///   (the expression's own arguments are not interpreted).
///
/// A value whose type a function does not handle is returned unchanged, as is
/// any value when the expression has no `(` or names an unknown function (the
/// latter also logs a warning).
fn evaluate_expression(&self, expr: &str, value: &FieldValue) -> Result<FieldValue> {
    // Byte budget applied by SUBSTR/SUBSTRING.
    const SUBSTR_MAX_BYTES: usize = 10;

    // Text after the first comma of the expression, with surrounding
    // whitespace, single quotes and closing parentheses removed.
    fn literal_argument(expr: &str) -> Option<&str> {
        let (_, tail) = expr.split_once(',')?;
        Some(tail.trim().trim_matches(|c| c == '\'' || c == ')'))
    }

    let upper = expr.to_uppercase();
    let func_name = match upper.trim().split_once('(') {
        // Tolerate `UPPER (col)` as well as `UPPER(col)`.
        Some((name, _)) => name.trim_end(),
        None => return Ok(value.clone()),
    };

    let result = match func_name {
        "UPPER" => match value {
            FieldValue::String(s) => FieldValue::String(s.to_uppercase()),
            other => other.clone(),
        },
        "LOWER" => match value {
            FieldValue::String(s) => FieldValue::String(s.to_lowercase()),
            other => other.clone(),
        },
        "TRIM" => match value {
            FieldValue::String(s) => FieldValue::String(s.trim().to_string()),
            other => other.clone(),
        },
        "LENGTH" | "LEN" => match value {
            FieldValue::String(s) => FieldValue::Integer(s.len() as i64),
            _ => FieldValue::Integer(0),
        },
        "ABS" => match value {
            // `abs()` overflows for `i64::MIN`; saturate instead of
            // panicking (debug) or wrapping back to a negative (release).
            FieldValue::Integer(i) => FieldValue::Integer(i.saturating_abs()),
            FieldValue::Float(f) => FieldValue::Float(f.abs()),
            other => other.clone(),
        },
        "COALESCE" => match (value, literal_argument(expr)) {
            (FieldValue::Null, Some(default)) => FieldValue::String(default.to_string()),
            _ => value.clone(),
        },
        "CONCAT" => match (value, literal_argument(expr)) {
            (FieldValue::String(s), Some(suffix)) => {
                FieldValue::String(format!("{}{}", s, suffix))
            }
            _ => value.clone(),
        },
        "SUBSTR" | "SUBSTRING" => match value {
            FieldValue::String(s) if s.len() > SUBSTR_MAX_BYTES => {
                // Back off to the nearest char boundary so multi-byte UTF-8
                // text cannot trigger a slicing panic. Index 0 is always a
                // boundary, so the search cannot fail.
                let end = (0..=SUBSTR_MAX_BYTES)
                    .rev()
                    .find(|&idx| s.is_char_boundary(idx))
                    .unwrap_or(0);
                FieldValue::String(s[..end].to_string())
            }
            other => other.clone(),
        },
        "ROUND" => match value {
            FieldValue::Float(f) => FieldValue::Float(f.round()),
            other => other.clone(),
        },
        "FLOOR" => match value {
            FieldValue::Float(f) => FieldValue::Float(f.floor()),
            other => other.clone(),
        },
        "CEIL" | "CEILING" => match value {
            FieldValue::Float(f) => FieldValue::Float(f.ceil()),
            other => other.clone(),
        },
        _ => {
            tracing::warn!("Unknown expression function: {}", func_name);
            value.clone()
        }
    };
    Ok(result)
}
/// Converts a `ChangeEvent` into the transport-level `LogicalEntryPayload`.
///
/// `new_values` is always populated from the event's row; `old_values` only
/// when the event carries an old row. A missing schema becomes the default
/// (empty) value.
pub fn to_logical_payload(&self, event: &ChangeEvent) -> LogicalEntryPayload {
    let new_values = event
        .row
        .fields
        .iter()
        .map(|(column, value)| (column.clone(), Self::field_to_logical_value(value)))
        .collect();

    let old_values = match event.old_row.as_ref() {
        Some(previous) => Some(
            previous
                .fields
                .iter()
                .map(|(column, value)| (column.clone(), Self::field_to_logical_value(value)))
                .collect(),
        ),
        None => None,
    };

    LogicalEntryPayload {
        lsn: event.lsn,
        tx_id: event.tx_id,
        schema: event.schema.clone().unwrap_or_default(),
        table: event.table.clone(),
        operation: match event.operation {
            ChangeOperation::Insert => LogicalOperation::Insert,
            ChangeOperation::Update => LogicalOperation::Update,
            ChangeOperation::Delete => LogicalOperation::Delete,
        },
        old_values,
        new_values: Some(new_values),
        timestamp_us: event.timestamp,
    }
}
/// Maps a pipeline `FieldValue` onto the matching transport `LogicalValue`.
///
/// Timestamps are converted from `u64` to `i64` with an `as` cast.
fn field_to_logical_value(value: &FieldValue) -> LogicalValue {
    match *value {
        FieldValue::Null => LogicalValue::Null,
        FieldValue::Integer(number) => LogicalValue::Int(number),
        FieldValue::Float(number) => LogicalValue::Float(number),
        FieldValue::Boolean(flag) => LogicalValue::Bool(flag),
        FieldValue::Timestamp(stamp) => LogicalValue::Timestamp(stamp as i64),
        FieldValue::String(ref text) => LogicalValue::Text(text.clone()),
        FieldValue::Bytes(ref raw) => LogicalValue::Bytes(raw.clone()),
    }
}
pub async fn stats(&self) -> PipelineStats {
self.stats.read().await.clone()
}
/// Resets every pipeline counter to zero.
pub async fn reset_stats(&self) {
    let mut guard = self.stats.write().await;
    *guard = PipelineStats::default();
}
}
/// Shape of the JSON payload `parse_wal_data` deserializes from a WAL entry's
/// data bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DecodedWalData {
/// Transaction id, if the payload carries one.
tx_id: Option<u64>,
/// Name of the table the change applies to.
table: String,
/// Schema of the table, if the payload carries one.
schema: Option<String>,
/// Row contents; copied into `ChangeEvent::row`.
row: ChangeRow,
/// Optional second row; copied into `ChangeEvent::old_row`.
old_row: Option<ChangeRow>,
}
/// Unit tests for the filters, column mappings and pipeline helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_table_filter_exact_match() {
        let filter = TableFilter::include("users");
        assert!(filter.matches("users"));
        assert!(!filter.matches("orders"));
    }

    #[test]
    fn test_table_filter_wildcard() {
        let filter = TableFilter::include("audit_*");
        assert!(filter.matches("audit_log"));
        assert!(filter.matches("audit_events"));
        assert!(!filter.matches("users"));
    }

    #[test]
    fn test_table_filter_all() {
        let filter = TableFilter::include("*");
        assert!(filter.matches("anything"));
        assert!(filter.matches("any_table"));
    }

    #[test]
    fn test_row_filter_equality() {
        let filter = RowFilter::new("users", "status = 'active'");
        let mut row = ChangeRow::new();
        row.set_field("status", FieldValue::String("active".to_string()));
        assert!(filter.evaluate(&row));
        row.set_field("status", FieldValue::String("deleted".to_string()));
        assert!(!filter.evaluate(&row));
    }

    #[test]
    fn test_row_filter_inequality() {
        let filter = RowFilter::new("users", "status != 'deleted'");
        let mut row = ChangeRow::new();
        row.set_field("status", FieldValue::String("active".to_string()));
        assert!(filter.evaluate(&row));
        row.set_field("status", FieldValue::String("deleted".to_string()));
        assert!(!filter.evaluate(&row));
    }

    #[test]
    fn test_row_filter_comparison() {
        let filter = RowFilter::new("orders", "amount > 100");
        let mut row = ChangeRow::new();
        row.set_field("amount", FieldValue::Integer(150));
        assert!(filter.evaluate(&row));
        row.set_field("amount", FieldValue::Integer(50));
        assert!(!filter.evaluate(&row));
    }

    #[test]
    fn test_column_mapping_rename() {
        let mapping = ColumnMapping::rename("users", "email", "user_email");
        assert_eq!(mapping.source_column, "email");
        assert_eq!(mapping.target_column, Some("user_email".to_string()));
    }

    #[test]
    fn test_column_mapping_drop() {
        let mapping = ColumnMapping::drop("users", "password_hash");
        assert!(matches!(mapping.transform, ColumnTransform::Drop));
    }

    #[test]
    fn test_field_value_display() {
        assert_eq!(FieldValue::Null.to_string(), "NULL");
        assert_eq!(FieldValue::Integer(42).to_string(), "42");
        // 2.5 rather than 3.14: the latter trips clippy's deny-by-default
        // `approx_constant` lint and adds nothing to the test.
        assert_eq!(FieldValue::Float(2.5).to_string(), "2.5");
        assert_eq!(FieldValue::String("hello".to_string()).to_string(), "hello");
        assert_eq!(FieldValue::Boolean(true).to_string(), "true");
    }

    // The two filtering tests below never await, so they run as plain
    // synchronous tests without spinning up a Tokio runtime.
    #[test]
    fn test_pipeline_table_filtering() {
        let config = LogicalReplicationConfig::new()
            .add_table_filter(TableFilter::include("users"))
            .add_table_filter(TableFilter::include("orders"))
            .add_table_filter(TableFilter::exclude("audit_*"));
        let pipeline = LogicalReplicationPipeline::new(config);
        assert!(pipeline.should_replicate_table("users"));
        assert!(pipeline.should_replicate_table("orders"));
        assert!(!pipeline.should_replicate_table("audit_log"));
        assert!(!pipeline.should_replicate_table("unknown_table"));
    }

    #[test]
    fn test_pipeline_row_filtering() {
        let config = LogicalReplicationConfig::new()
            .add_row_filter(RowFilter::new("users", "status != 'deleted'"));
        let pipeline = LogicalReplicationPipeline::new(config);
        let mut active_row = ChangeRow::new();
        active_row.set_field("status", FieldValue::String("active".to_string()));
        assert!(pipeline.should_replicate_row("users", &active_row));
        let mut deleted_row = ChangeRow::new();
        deleted_row.set_field("status", FieldValue::String("deleted".to_string()));
        assert!(!pipeline.should_replicate_row("users", &deleted_row));
    }

    #[test]
    fn test_transform_multiply() {
        let config = LogicalReplicationConfig::new();
        let pipeline = LogicalReplicationPipeline::new(config);
        let value = FieldValue::Integer(100);
        let result = pipeline.multiply_value(&value, 1.5).unwrap();
        assert!(matches!(result, FieldValue::Integer(150)));
    }

    #[test]
    fn test_transform_mask() {
        let config = LogicalReplicationConfig::new();
        let pipeline = LogicalReplicationPipeline::new(config);
        let value = FieldValue::String("secret123".to_string());
        let result = pipeline.mask_value(&value, '*');
        assert!(matches!(result, FieldValue::String(s) if s == "*********"));
    }

    #[test]
    fn test_transform_cast() {
        let config = LogicalReplicationConfig::new();
        let pipeline = LogicalReplicationPipeline::new(config);
        let value = FieldValue::Integer(42);
        let result = pipeline.cast_value(&value, DataType::String).unwrap();
        assert!(matches!(result, FieldValue::String(s) if s == "42"));
    }

    #[test]
    fn test_transform_hash() {
        let config = LogicalReplicationConfig::new();
        let pipeline = LogicalReplicationPipeline::new(config);
        let value = FieldValue::String("test".to_string());
        let result = pipeline.hash_value(&value);
        assert!(matches!(result, FieldValue::String(s) if !s.is_empty()));
    }

    #[tokio::test]
    async fn test_pipeline_stats() {
        let config = LogicalReplicationConfig::new();
        let pipeline = LogicalReplicationPipeline::new(config);
        let stats = pipeline.stats().await;
        assert_eq!(stats.entries_processed, 0);
        assert_eq!(stats.entries_passed, 0);
    }
}