Skip to main content

nodedb_array/query/
elementwise.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Pairwise binary ops between two coord-aligned sparse tiles.
4//!
//! Both inputs must share the same schema (same dim arity, same attr
//! columns, same dtypes). We outer-join on coordinates: cells present
//! in only one operand see a Null on the other side. The op is applied
//! only when both sides of an attr are numeric; otherwise the left value
//! is kept if it is non-null, else the right value, else Null. A cell
//! present in only one operand therefore carries its values through
//! unchanged.
//!
//! Numeric attrs (Int64, Float64) participate in the op; non-numeric
//! attrs (String, Bytes) are passed through by the rule above (left
//! operand first) — the op is undefined for them. Division by zero
//! yields Null rather than infinity so downstream aggregates stay finite.
14
15use std::collections::BTreeMap;
16
17use crate::error::ArrayResult;
18use crate::schema::ArraySchema;
19use crate::tile::sparse_tile::{SparseTile, SparseTileBuilder};
20use crate::types::cell_value::value::CellValue;
21use crate::types::coord::value::CoordValue;
22
/// Arithmetic operator applied pairwise by [`elementwise`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOp {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `lhs * rhs`
    Mul,
    /// `lhs / rhs`
    Div,
}
30
31/// Apply `op` to every aligned pair of cells. The result tile carries
32/// the union of both coord sets.
33pub fn elementwise(
34    schema: &ArraySchema,
35    a: &SparseTile,
36    b: &SparseTile,
37    op: BinaryOp,
38) -> ArrayResult<SparseTile> {
39    let n_attrs = schema.attrs.len();
40    let by_coord_a = index_rows(a);
41    let by_coord_b = index_rows(b);
42    let mut keys: BTreeMap<Vec<CoordKey>, ()> = BTreeMap::new();
43    for k in by_coord_a.keys().chain(by_coord_b.keys()) {
44        keys.insert(k.clone(), ());
45    }
46
47    let mut builder = SparseTileBuilder::new(schema);
48    for key in keys.keys() {
49        let coord = decode_key(key);
50        let lhs = by_coord_a.get(key);
51        let rhs = by_coord_b.get(key);
52        let mut out = Vec::with_capacity(n_attrs);
53        for i in 0..n_attrs {
54            let l = lhs.map(|v| v[i].clone()).unwrap_or(CellValue::Null);
55            let r = rhs.map(|v| v[i].clone()).unwrap_or(CellValue::Null);
56            out.push(apply(&l, &r, op));
57        }
58        builder.push(&coord, &out)?;
59    }
60    Ok(builder.build())
61}
62
63fn apply(l: &CellValue, r: &CellValue, op: BinaryOp) -> CellValue {
64    let (lf, rf) = match (to_f64(l), to_f64(r)) {
65        (Some(a), Some(b)) => (a, b),
66        _ => return passthrough(l, r),
67    };
68    let v = match op {
69        BinaryOp::Add => lf + rf,
70        BinaryOp::Sub => lf - rf,
71        BinaryOp::Mul => lf * rf,
72        BinaryOp::Div => {
73            if rf == 0.0 {
74                return CellValue::Null;
75            }
76            lf / rf
77        }
78    };
79    // Preserve Int64 when both operands were integral and result is
80    // exact — keeps integer attrs from drifting to floats unnecessarily.
81    if let (CellValue::Int64(_), CellValue::Int64(_)) = (l, r)
82        && v.fract() == 0.0
83        && v.is_finite()
84    {
85        return CellValue::Int64(v as i64);
86    }
87    CellValue::Float64(v)
88}
89
90fn passthrough(l: &CellValue, r: &CellValue) -> CellValue {
91    if !l.is_null() {
92        l.clone()
93    } else if !r.is_null() {
94        r.clone()
95    } else {
96        CellValue::Null
97    }
98}
99
100fn to_f64(v: &CellValue) -> Option<f64> {
101    match v {
102        CellValue::Int64(x) => Some(*x as f64),
103        CellValue::Float64(x) => Some(*x),
104        _ => None,
105    }
106}
107
/// Hashable, totally-ordered key for a coordinate tuple.
///
/// `CoordValue::Float64` is stored as its raw bit pattern so the key can
/// implement `Eq`, `Ord` and `Hash`. As a consequence `0.0` and `-0.0`
/// are distinct keys, and float keys sort by bit pattern rather than by
/// numeric value (negative floats sort after positive ones).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum CoordKey {
    /// Mirrors `CoordValue::Int64`.
    I(i64),
    /// Mirrors `CoordValue::TimestampMs`.
    Ts(i64),
    /// Mirrors `CoordValue::Float64`, stored as `f64::to_bits`.
    F(u64),
    /// Mirrors `CoordValue::String`.
    S(String),
}
116
117fn encode_key(c: &[CoordValue]) -> Vec<CoordKey> {
118    c.iter()
119        .map(|v| match v {
120            CoordValue::Int64(x) => CoordKey::I(*x),
121            CoordValue::TimestampMs(x) => CoordKey::Ts(*x),
122            CoordValue::Float64(x) => CoordKey::F(x.to_bits()),
123            CoordValue::String(s) => CoordKey::S(s.clone()),
124        })
125        .collect()
126}
127
128fn decode_key(k: &[CoordKey]) -> Vec<CoordValue> {
129    k.iter()
130        .map(|v| match v {
131            CoordKey::I(x) => CoordValue::Int64(*x),
132            CoordKey::Ts(x) => CoordValue::TimestampMs(*x),
133            CoordKey::F(b) => CoordValue::Float64(f64::from_bits(*b)),
134            CoordKey::S(s) => CoordValue::String(s.clone()),
135        })
136        .collect()
137}
138
139fn index_rows(tile: &SparseTile) -> BTreeMap<Vec<CoordKey>, Vec<CellValue>> {
140    let mut out = BTreeMap::new();
141    let n = tile.nnz() as usize;
142    for row in 0..n {
143        let coord: Vec<CoordValue> = tile
144            .dim_dicts
145            .iter()
146            .map(|d| d.values[d.indices[row] as usize].clone())
147            .collect();
148        let attrs: Vec<CellValue> = tile.attr_cols.iter().map(|col| col[row].clone()).collect();
149        out.insert(encode_key(&coord), attrs);
150    }
151    out
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::ArraySchemaBuilder;
    use crate::schema::attr_spec::{AttrSpec, AttrType};
    use crate::schema::dim_spec::{DimSpec, DimType};
    use crate::types::domain::{Domain, DomainBound};

    /// Schema with one Int64 dim `k` over `[0, 15]` and one Int64 attr `v`.
    fn schema() -> ArraySchema {
        ArraySchemaBuilder::new("g")
            .dim(DimSpec::new(
                "k",
                DimType::Int64,
                Domain::new(DomainBound::Int64(0), DomainBound::Int64(15)),
            ))
            .attr(AttrSpec::new("v", AttrType::Int64, true))
            .tile_extents(vec![16])
            .build()
            .unwrap()
    }

    /// Build a tile from `(k, v)` pairs using [`schema`].
    fn tile(rows: &[(i64, i64)]) -> SparseTile {
        let s = schema();
        let mut b = SparseTileBuilder::new(&s);
        for (k, v) in rows {
            b.push(&[CoordValue::Int64(*k)], &[CellValue::Int64(*v)])
                .unwrap();
        }
        b.build()
    }

    #[test]
    fn add_aligned_cells() {
        let s = schema();
        let a = tile(&[(0, 1), (1, 2)]);
        let b = tile(&[(0, 10), (1, 20)]);
        let out = elementwise(&s, &a, &b, BinaryOp::Add).unwrap();
        assert_eq!(out.nnz(), 2);
        assert_eq!(out.attr_cols[0][0], CellValue::Int64(11));
        assert_eq!(out.attr_cols[0][1], CellValue::Int64(22));
    }

    #[test]
    fn outer_join_propagates_null() {
        let s = schema();
        let a = tile(&[(0, 1)]);
        let b = tile(&[(1, 2)]);
        let out = elementwise(&s, &a, &b, BinaryOp::Add).unwrap();
        assert_eq!(out.nnz(), 2);
        // Each coord is present in exactly one operand, so the other side
        // is Null and passthrough keeps the present operand. Rows are
        // pushed in coord order: k=0 (from `a`), then k=1 (from `b`).
        assert_eq!(out.attr_cols[0][0], CellValue::Int64(1));
        assert_eq!(out.attr_cols[0][1], CellValue::Int64(2));
    }

    #[test]
    fn partial_overlap_yields_union() {
        let s = schema();
        let a = tile(&[(0, 1), (1, 2)]);
        let b = tile(&[(1, 10), (2, 20)]);
        let out = elementwise(&s, &a, &b, BinaryOp::Add).unwrap();
        assert_eq!(out.nnz(), 3);
        // k=0 only in `a`, k=1 in both, k=2 only in `b`.
        assert_eq!(out.attr_cols[0][0], CellValue::Int64(1));
        assert_eq!(out.attr_cols[0][1], CellValue::Int64(12));
        assert_eq!(out.attr_cols[0][2], CellValue::Int64(20));
    }

    #[test]
    fn div_by_zero_returns_null() {
        let s = schema();
        let a = tile(&[(0, 5)]);
        let b = tile(&[(0, 0)]);
        let out = elementwise(&s, &a, &b, BinaryOp::Div).unwrap();
        assert_eq!(out.attr_cols[0][0], CellValue::Null);
    }

    #[test]
    fn exact_div_stays_int() {
        let s = schema();
        let a = tile(&[(0, 6)]);
        let b = tile(&[(0, 3)]);
        let out = elementwise(&s, &a, &b, BinaryOp::Div).unwrap();
        assert_eq!(out.attr_cols[0][0], CellValue::Int64(2));
    }

    #[test]
    fn sub_and_mul() {
        let s = schema();
        let a = tile(&[(0, 10)]);
        let b = tile(&[(0, 3)]);
        let s1 = elementwise(&s, &a, &b, BinaryOp::Sub).unwrap();
        let s2 = elementwise(&s, &a, &b, BinaryOp::Mul).unwrap();
        assert_eq!(s1.attr_cols[0][0], CellValue::Int64(7));
        assert_eq!(s2.attr_cols[0][0], CellValue::Int64(30));
    }
}