use std::fmt::Debug;
use reifydb_type::{
storage::{DataBitVec, DataVec},
util::bitvec::BitVec,
value::{
Value,
container::{
bool::BoolContainer, number::NumberContainer, temporal::TemporalContainer, uuid::UuidContainer,
},
date::Date,
datetime::DateTime,
duration::Duration,
is::{IsNumber, IsTemporal, IsUuid},
time::Time,
uuid::{Uuid4, Uuid7},
},
};
use crate::value::column::ColumnData;
impl ColumnData {
    /// Merges two branch columns into a single column of `total_len` rows.
    ///
    /// Row `i` of the result is read from row `i` of `self` when bit `i` of
    /// `then_mask` is set, otherwise from row `i` of `other` when bit `i` of
    /// `else_mask` is set. A row selected by neither mask carries no value: it
    /// is either pushed as a none value (generic path) or flagged invalid in the
    /// validity bitvec of a `ColumnData::Option` wrapper (typed paths).
    ///
    /// Dispatch order:
    /// 1. Both inputs are `ColumnData::Option`: the inner columns are merged
    ///    recursively and the two validity bitvecs are merged row by row.
    /// 2. Otherwise the typed kernels are tried (`scatter_merge_typed`).
    /// 3. Otherwise the value-by-value fallback is used (`scatter_merge_generic`).
    pub fn scatter_merge(
        &self,
        other: &ColumnData,
        then_mask: &BitVec,
        else_mask: &BitVec,
        total_len: usize,
    ) -> ColumnData {
        match (self, other) {
            (
                ColumnData::Option {
                    inner: then_inner,
                    bitvec: then_validity,
                },
                ColumnData::Option {
                    inner: else_inner,
                    bitvec: else_validity,
                },
            ) => {
                let merged = then_inner.scatter_merge(else_inner, then_mask, else_mask, total_len);
                let validity =
                    merge_validity_bitvecs(then_validity, else_validity, then_mask, else_mask, total_len);
                // The recursive merge may itself come back wrapped in `Option` (it does so
                // when some row is unmapped). Flatten that instead of nesting two wrappers,
                // combining both validity bitvecs with a bitwise AND.
                let (inner, bitvec) = match merged {
                    ColumnData::Option {
                        inner,
                        bitvec: nested,
                    } => (inner, validity.and(&nested)),
                    plain => (Box::new(plain), validity),
                };
                ColumnData::Option {
                    inner,
                    bitvec,
                }
            }
            // Any other combination: typed fast path first, generic fallback otherwise.
            _ => scatter_merge_typed(self, other, then_mask, else_mask, total_len)
                .unwrap_or_else(|| scatter_merge_generic(self, other, then_mask, else_mask, total_len)),
        }
    }
}
/// Builds the validity bitvec for the merge of two `Option` columns.
///
/// For every row in `0..total_len` the resulting bit is:
/// - the `then_bv` bit when the `then_mask` bit is set,
/// - otherwise the `else_bv` bit when the `else_mask` bit is set,
/// - otherwise `false`.
///
/// A row index past the end of the consulted validity bitvec counts as `false`.
fn merge_validity_bitvecs(
    then_bv: &BitVec,
    else_bv: &BitVec,
    then_mask: &BitVec,
    else_mask: &BitVec,
    total_len: usize,
) -> BitVec {
    // Bounds-checked read: anything beyond the bitvec's length is treated as invalid.
    let valid_at = |bits: &BitVec, row: usize| row < DataBitVec::len(bits) && DataBitVec::get(bits, row);

    let mut merged = BitVec::with_capacity(total_len);
    for row in 0..total_len {
        let valid = if DataBitVec::get(then_mask, row) {
            valid_at(then_bv, row)
        } else {
            // `else_mask` is only consulted when the row is not claimed by `then_mask`.
            DataBitVec::get(else_mask, row) && valid_at(else_bv, row)
        };
        DataBitVec::push(&mut merged, valid);
    }
    merged
}
/// Value-by-value fallback used when no typed kernel handles the column pair.
///
/// The output column is created with the type reported by `self_col`. For each
/// row in `0..total_len` it receives `self_col`'s value when the `then_mask`
/// bit is set, otherwise `other`'s value when the `else_mask` bit is set,
/// otherwise a none value of the result type.
fn scatter_merge_generic(
    self_col: &ColumnData,
    other: &ColumnData,
    then_mask: &BitVec,
    else_mask: &BitVec,
    total_len: usize,
) -> ColumnData {
    let ty = self_col.get_type();
    let mut merged = ColumnData::with_capacity(ty.clone(), total_len);
    for row in 0..total_len {
        // Pick the value first, then push exactly once per row.
        let value = if DataBitVec::get(then_mask, row) {
            self_col.get_value(row)
        } else if DataBitVec::get(else_mask, row) {
            other.get_value(row)
        } else {
            Value::none_of(ty.clone())
        };
        merged.push_value(value);
    }
    merged
}
fn scatter_merge_typed(
self_col: &ColumnData,
other: &ColumnData,
then_mask: &BitVec,
else_mask: &BitVec,
total_len: usize,
) -> Option<ColumnData> {
macro_rules! number_kernel {
($variant:ident, $t:ty) => {
if let (ColumnData::$variant(a), ColumnData::$variant(b)) = (self_col, other) {
let (data, validity) = number_scatter::<$t>(a, b, then_mask, else_mask, total_len);
let inner = ColumnData::$variant(NumberContainer::new(data));
return Some(finalize(inner, validity));
}
};
}
macro_rules! temporal_kernel {
($variant:ident, $t:ty) => {
if let (ColumnData::$variant(a), ColumnData::$variant(b)) = (self_col, other) {
let (data, validity) = temporal_scatter::<$t>(a, b, then_mask, else_mask, total_len);
let inner = ColumnData::$variant(TemporalContainer::new(data));
return Some(finalize(inner, validity));
}
};
}
macro_rules! uuid_kernel {
($variant:ident, $t:ty) => {
if let (ColumnData::$variant(a), ColumnData::$variant(b)) = (self_col, other) {
let (data, validity) = uuid_scatter::<$t>(a, b, then_mask, else_mask, total_len);
let inner = ColumnData::$variant(UuidContainer::new(data));
return Some(finalize(inner, validity));
}
};
}
if let (ColumnData::Bool(a), ColumnData::Bool(b)) = (self_col, other) {
let (data, validity) = bool_scatter(a, b, then_mask, else_mask, total_len);
let inner = ColumnData::Bool(BoolContainer::from_parts(data));
return Some(finalize(inner, validity));
}
number_kernel!(Float4, f32);
number_kernel!(Float8, f64);
number_kernel!(Int1, i8);
number_kernel!(Int2, i16);
number_kernel!(Int4, i32);
number_kernel!(Int8, i64);
number_kernel!(Int16, i128);
number_kernel!(Uint1, u8);
number_kernel!(Uint2, u16);
number_kernel!(Uint4, u32);
number_kernel!(Uint8, u64);
number_kernel!(Uint16, u128);
temporal_kernel!(Date, Date);
temporal_kernel!(DateTime, DateTime);
temporal_kernel!(Time, Time);
temporal_kernel!(Duration, Duration);
uuid_kernel!(Uuid4, Uuid4);
uuid_kernel!(Uuid7, Uuid7);
None
}
/// Wraps `inner` in `ColumnData::Option` when a validity bitvec is supplied;
/// returns `inner` untouched when `validity` is `None`.
fn finalize(inner: ColumnData, validity: Option<BitVec>) -> ColumnData {
    if let Some(bitvec) = validity {
        return ColumnData::Option {
            inner: Box::new(inner),
            bitvec,
        };
    }
    inner
}
/// Scatter kernel for bit-packed boolean containers.
///
/// Returns `(data, validity)` for `total_len` rows:
/// - `data` bit `i` comes from `a` when the `then_mask` bit is set, otherwise
///   from `b` when the `else_mask` bit is set, otherwise it is `false`. A
///   selected row whose index is past the end of its source also yields `false`.
/// - `validity` stays `None` as long as every row is selected by one of the
///   masks. It is created on the first unselected row (back-filled with `true`
///   for all earlier rows) and from then on holds `false` for unselected rows
///   and `true` for selected ones.
fn bool_scatter(
    a: &BoolContainer,
    b: &BoolContainer,
    then_mask: &BitVec,
    else_mask: &BitVec,
    total_len: usize,
) -> (BitVec, Option<BitVec>) {
    let then_bits = a.data();
    let else_bits = b.data();
    let mut merged = BitVec::with_capacity(total_len);
    let mut validity: Option<BitVec> = None;

    for row in 0..total_len {
        let take_then = DataBitVec::get(then_mask, row);
        // `else_mask` is only consulted when `then_mask` did not claim the row.
        let take_else = !take_then && DataBitVec::get(else_mask, row);

        let bit = match (take_then, take_else) {
            (true, _) => row < DataBitVec::len(then_bits) && DataBitVec::get(then_bits, row),
            (false, true) => row < DataBitVec::len(else_bits) && DataBitVec::get(else_bits, row),
            (false, false) => false,
        };
        DataBitVec::push(&mut merged, bit);

        let mapped = take_then || take_else;
        if let Some(bits) = validity.as_mut() {
            // Validity is already being tracked: record this row as-is.
            DataBitVec::push(bits, mapped);
        } else if !mapped {
            // First unmapped row: every earlier row was mapped, so back-fill `true`.
            let mut bits = BitVec::with_capacity(total_len);
            for _ in 0..row {
                DataBitVec::push(&mut bits, true);
            }
            DataBitVec::push(&mut bits, false);
            validity = Some(bits);
        }
    }

    (merged, validity)
}
fn number_scatter<T>(
a: &NumberContainer<T>,
b: &NumberContainer<T>,
then_mask: &BitVec,
else_mask: &BitVec,
total_len: usize,
) -> (Vec<T>, Option<BitVec>)
where
T: IsNumber + Clone + Default + Debug,
{
let a_data = a.data();
let b_data = b.data();
let mut out: Vec<T> = Vec::with_capacity(total_len);
let mut validity: Option<BitVec> = None;
for i in 0..total_len {
let in_then = DataBitVec::get(then_mask, i);
let in_else = !in_then && DataBitVec::get(else_mask, i);
let value = if in_then {
DataVec::get(a_data, i).cloned().unwrap_or_default()
} else if in_else {
DataVec::get(b_data, i).cloned().unwrap_or_default()
} else {
T::default()
};
out.push(value);
if !in_then && !in_else {
let v = validity.get_or_insert_with(|| {
let mut bv = BitVec::with_capacity(total_len);
for _ in 0..i {
DataBitVec::push(&mut bv, true);
}
bv
});
DataBitVec::push(v, false);
} else if let Some(v) = validity.as_mut() {
DataBitVec::push(v, true);
}
}
(out, validity)
}
fn temporal_scatter<T>(
a: &TemporalContainer<T>,
b: &TemporalContainer<T>,
then_mask: &BitVec,
else_mask: &BitVec,
total_len: usize,
) -> (Vec<T>, Option<BitVec>)
where
T: IsTemporal + Clone + Default + Debug,
{
let a_data = a.data();
let b_data = b.data();
let mut out: Vec<T> = Vec::with_capacity(total_len);
let mut validity: Option<BitVec> = None;
for i in 0..total_len {
let in_then = DataBitVec::get(then_mask, i);
let in_else = !in_then && DataBitVec::get(else_mask, i);
let value = if in_then {
DataVec::get(a_data, i).cloned().unwrap_or_default()
} else if in_else {
DataVec::get(b_data, i).cloned().unwrap_or_default()
} else {
T::default()
};
out.push(value);
if !in_then && !in_else {
let v = validity.get_or_insert_with(|| {
let mut bv = BitVec::with_capacity(total_len);
for _ in 0..i {
DataBitVec::push(&mut bv, true);
}
bv
});
DataBitVec::push(v, false);
} else if let Some(v) = validity.as_mut() {
DataBitVec::push(v, true);
}
}
(out, validity)
}
fn uuid_scatter<T>(
a: &UuidContainer<T>,
b: &UuidContainer<T>,
then_mask: &BitVec,
else_mask: &BitVec,
total_len: usize,
) -> (Vec<T>, Option<BitVec>)
where
T: IsUuid + Clone + Default + Debug,
{
let a_data = a.data();
let b_data = b.data();
let mut out: Vec<T> = Vec::with_capacity(total_len);
let mut validity: Option<BitVec> = None;
for i in 0..total_len {
let in_then = DataBitVec::get(then_mask, i);
let in_else = !in_then && DataBitVec::get(else_mask, i);
let value = if in_then {
DataVec::get(a_data, i).cloned().unwrap_or_default()
} else if in_else {
DataVec::get(b_data, i).cloned().unwrap_or_default()
} else {
T::default()
};
out.push(value);
if !in_then && !in_else {
let v = validity.get_or_insert_with(|| {
let mut bv = BitVec::with_capacity(total_len);
for _ in 0..i {
DataBitVec::push(&mut bv, true);
}
bv
});
DataBitVec::push(v, false);
} else if let Some(v) = validity.as_mut() {
DataBitVec::push(v, true);
}
}
(out, validity)
}
#[cfg(test)]
mod tests {
    use reifydb_type::{util::bitvec::BitVec, value::Value};
    use crate::value::column::ColumnData;

    /// Asserts that `column.get_value(row)` equals `expected[row]` for every
    /// row listed in `expected`.
    fn assert_rows(column: &ColumnData, expected: &[Value]) {
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(&column.get_value(row), want, "row {}", row);
        }
    }

    /// Every row is claimed by exactly one mask, so the result stays a plain `Int4` column.
    #[test]
    fn scatter_merge_all_mapped_int4() {
        let then_col = ColumnData::int4([10, 20, 30, 40]);
        let else_col = ColumnData::int4([90, 80, 70, 60]);
        let then_mask = BitVec::from_slice(&[true, false, true, false]);
        let else_mask = BitVec::from_slice(&[false, true, false, true]);

        let merged = then_col.scatter_merge(&else_col, &then_mask, &else_mask, 4);

        assert!(matches!(merged, ColumnData::Int4(_)));
        assert_rows(&merged, &[Value::Int4(10), Value::Int4(80), Value::Int4(30), Value::Int4(60)]);
    }

    /// Row 1 is claimed by neither mask, which forces the `Option` wrapper.
    #[test]
    fn scatter_merge_unmapped_promotes_to_option() {
        let then_col = ColumnData::int4([10, 20, 30]);
        let else_col = ColumnData::int4([90, 80, 70]);
        let then_mask = BitVec::from_slice(&[true, false, true]);
        let else_mask = BitVec::from_slice(&[false, false, false]);

        let merged = then_col.scatter_merge(&else_col, &then_mask, &else_mask, 3);

        assert!(matches!(merged, ColumnData::Option { .. }));
        assert_rows(&merged, &[Value::Int4(10), Value::none(), Value::Int4(30)]);
    }

    /// Fully mapped boolean columns stay a plain `Bool` column.
    #[test]
    fn scatter_merge_bool_all_mapped() {
        let then_col = ColumnData::bool([true, true, false, false]);
        let else_col = ColumnData::bool([false, false, true, true]);
        let then_mask = BitVec::from_slice(&[true, false, true, false]);
        let else_mask = BitVec::from_slice(&[false, true, false, true]);

        let merged = then_col.scatter_merge(&else_col, &then_mask, &else_mask, 4);

        assert!(matches!(merged, ColumnData::Bool(_)));
        assert_rows(
            &merged,
            &[
                Value::Boolean(true),
                Value::Boolean(false),
                Value::Boolean(false),
                Value::Boolean(true),
            ],
        );
    }

    /// `Utf8` has no typed kernel, so this goes through the generic fallback.
    #[test]
    fn scatter_merge_utf8_uses_generic_fallback() {
        let then_col = ColumnData::utf8(["a", "b", "c"]);
        let else_col = ColumnData::utf8(["x", "y", "z"]);
        let then_mask = BitVec::from_slice(&[true, false, true]);
        let else_mask = BitVec::from_slice(&[false, true, false]);

        let merged = then_col.scatter_merge(&else_col, &then_mask, &else_mask, 3);

        assert_rows(
            &merged,
            &[
                Value::Utf8("a".to_string()),
                Value::Utf8("y".to_string()),
                Value::Utf8("c".to_string()),
            ],
        );
    }
}