#![deny(clippy::disallowed_types)]
use arrow::array::{BooleanArray, RecordBatch, StringArray};
use arrow::datatypes::DataType;
use crate::error::DbError;
/// Splits out the retraction rows (`"D"` delete and `"U-"` update-before
/// ops) of a changelog batch.
///
/// Returns `Ok(None)` when the batch has no `_op` column, when that column
/// is not `Utf8`, or when no retraction rows are present; otherwise returns
/// a batch containing only the negative events. Null `_op` values are
/// treated as non-negative and excluded.
///
/// # Errors
/// Returns `DbError::Pipeline` if the `_op` column fails to downcast to
/// `StringArray` or the Arrow filter kernel fails.
pub(crate) fn extract_negative_events(batch: &RecordBatch) -> Result<Option<RecordBatch>, DbError> {
    // `RecordBatch::schema()` clones the SchemaRef Arc; fetch it once
    // instead of once per lookup.
    let schema = batch.schema();
    let Ok(op_idx) = schema.index_of("_op") else {
        return Ok(None);
    };
    if !matches!(schema.field(op_idx).data_type(), DataType::Utf8) {
        return Ok(None);
    }
    let op_col = batch
        .column(op_idx)
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| DbError::Pipeline("_op column is not Utf8".into()))?;
    // Mask is all-Some so the filter kernel never sees null mask slots.
    let mask: BooleanArray = op_col
        .iter()
        .map(|v| Some(matches!(v, Some("D" | "U-"))))
        .collect();
    let filtered = arrow::compute::filter_record_batch(batch, &mask)
        .map_err(|e| DbError::Pipeline(format!("changelog negative filter: {e}")))?;
    if filtered.num_rows() == 0 {
        Ok(None)
    } else {
        Ok(Some(filtered))
    }
}
/// Keeps only the additive changelog rows: inserts (`"I"`) and update-after
/// images (`"U+"` / `"U"`).
///
/// Batches without a `Utf8` `_op` column are passed through unchanged
/// (treated as plain append-only data). Null `_op` values are excluded.
///
/// # Errors
/// Returns `DbError::Pipeline` if the `_op` column fails to downcast to
/// `StringArray` or the Arrow filter kernel fails.
pub(crate) fn filter_positive_events(batch: &RecordBatch) -> Result<RecordBatch, DbError> {
    // `RecordBatch::schema()` clones the SchemaRef Arc; fetch it once
    // instead of once per lookup.
    let schema = batch.schema();
    let Ok(op_idx) = schema.index_of("_op") else {
        return Ok(batch.clone());
    };
    if !matches!(schema.field(op_idx).data_type(), DataType::Utf8) {
        return Ok(batch.clone());
    }
    let op_col = batch
        .column(op_idx)
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| DbError::Pipeline("_op column is not Utf8".into()))?;
    // Mask is all-Some so the filter kernel never sees null mask slots.
    let mask: BooleanArray = op_col
        .iter()
        .map(|v| Some(matches!(v, Some("I" | "U+" | "U"))))
        .collect();
    arrow::compute::filter_record_batch(batch, &mask)
        .map_err(|e| DbError::Pipeline(format!("changelog filter: {e}")))
}
/// Shapes an internal weighted batch for delivery to a sink.
///
/// Changelog sinks receive the batch as-is. For plain sinks, rows with a
/// non-positive weight are dropped and the internal weight column is
/// projected away. This function is deliberately infallible: every fallible
/// step falls back to returning the batch (or the partially processed
/// batch) unchanged, on the theory that delivering extra columns/rows beats
/// dropping data on the floor.
pub(crate) fn prepare_for_sink(batch: &RecordBatch, changelog_sink: bool) -> RecordBatch {
    if changelog_sink {
        return batch.clone();
    }
    let Ok(idx) = batch
        .schema()
        .index_of(crate::aggregate_state::WEIGHT_COLUMN)
    else {
        // No weight column: nothing to strip.
        return batch.clone();
    };
    let Some(weights) = batch
        .column(idx)
        .as_any()
        .downcast_ref::<arrow::array::Int64Array>()
    else {
        // Unexpected weight type; pass the batch through untouched.
        return batch.clone();
    };
    // Null weights count as non-positive and are filtered out, consistent
    // with the changelog helpers treating null `_op` as non-matching.
    let mask: BooleanArray = weights
        .iter()
        .map(|w| Some(w.is_some_and(|w| w > 0)))
        .collect();
    let Ok(filtered) = arrow::compute::filter_record_batch(batch, &mask) else {
        return batch.clone();
    };
    // NOTE: `index_of` succeeding guarantees at least one column, and
    // filtering preserves column count, so `filtered` is never empty here
    // (the old `num_columns() == 0` early return was unreachable).
    // If the weight column is the *only* column, `project(&[])` fails
    // (a RecordBatch needs at least one column) and we keep it via the
    // `unwrap_or` fallback.
    let indices: Vec<usize> = (0..filtered.num_columns()).filter(|&i| i != idx).collect();
    filtered.project(&indices).unwrap_or(filtered)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array};
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;

    /// Four-row CDC batch covering every op kind: I, D, U+, U-.
    fn cdc_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, false),
            Field::new("_op", DataType::Utf8, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0])),
                Arc::new(StringArray::from(vec!["I", "D", "U+", "U-"])),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_filter_positive_keeps_inserts_and_updates() {
        let result = filter_positive_events(&cdc_batch()).unwrap();
        assert_eq!(result.num_rows(), 2);
        let ids = result
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 1);
        assert_eq!(ids.value(1), 3);
    }

    #[test]
    fn test_extract_negative_keeps_deletes_and_retractions() {
        let result = extract_negative_events(&cdc_batch()).unwrap().unwrap();
        assert_eq!(result.num_rows(), 2);
        let ids = result
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 2);
        assert_eq!(ids.value(1), 4);
    }

    #[test]
    fn test_no_op_column_passthrough() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2]))]).unwrap();
        let result = filter_positive_events(&batch).unwrap();
        assert_eq!(result.num_rows(), 2);
        // Without an `_op` column there is nothing negative to extract.
        assert!(extract_negative_events(&batch).unwrap().is_none());
    }
}