1use std::collections::BTreeMap;
16
17use crate::error::ArrayResult;
18use crate::schema::ArraySchema;
19use crate::tile::sparse_tile::{SparseTile, SparseTileBuilder};
20use crate::types::cell_value::value::CellValue;
21use crate::types::coord::value::CoordValue;
22
/// Arithmetic operation that [`elementwise`] applies to each pair of cells.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOp {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `lhs * rhs`
    Mul,
    /// `lhs / rhs`
    Div,
}
30
31pub fn elementwise(
34 schema: &ArraySchema,
35 a: &SparseTile,
36 b: &SparseTile,
37 op: BinaryOp,
38) -> ArrayResult<SparseTile> {
39 let n_attrs = schema.attrs.len();
40 let by_coord_a = index_rows(a);
41 let by_coord_b = index_rows(b);
42 let mut keys: BTreeMap<Vec<CoordKey>, ()> = BTreeMap::new();
43 for k in by_coord_a.keys().chain(by_coord_b.keys()) {
44 keys.insert(k.clone(), ());
45 }
46
47 let mut builder = SparseTileBuilder::new(schema);
48 for key in keys.keys() {
49 let coord = decode_key(key);
50 let lhs = by_coord_a.get(key);
51 let rhs = by_coord_b.get(key);
52 let mut out = Vec::with_capacity(n_attrs);
53 for i in 0..n_attrs {
54 let l = lhs.map(|v| v[i].clone()).unwrap_or(CellValue::Null);
55 let r = rhs.map(|v| v[i].clone()).unwrap_or(CellValue::Null);
56 out.push(apply(&l, &r, op));
57 }
58 builder.push(&coord, &out)?;
59 }
60 Ok(builder.build())
61}
62
63fn apply(l: &CellValue, r: &CellValue, op: BinaryOp) -> CellValue {
64 let (lf, rf) = match (to_f64(l), to_f64(r)) {
65 (Some(a), Some(b)) => (a, b),
66 _ => return passthrough(l, r),
67 };
68 let v = match op {
69 BinaryOp::Add => lf + rf,
70 BinaryOp::Sub => lf - rf,
71 BinaryOp::Mul => lf * rf,
72 BinaryOp::Div => {
73 if rf == 0.0 {
74 return CellValue::Null;
75 }
76 lf / rf
77 }
78 };
79 if let (CellValue::Int64(_), CellValue::Int64(_)) = (l, r)
82 && v.fract() == 0.0
83 && v.is_finite()
84 {
85 return CellValue::Int64(v as i64);
86 }
87 CellValue::Float64(v)
88}
89
90fn passthrough(l: &CellValue, r: &CellValue) -> CellValue {
91 if !l.is_null() {
92 l.clone()
93 } else if !r.is_null() {
94 r.clone()
95 } else {
96 CellValue::Null
97 }
98}
99
100fn to_f64(v: &CellValue) -> Option<f64> {
101 match v {
102 CellValue::Int64(x) => Some(*x as f64),
103 CellValue::Float64(x) => Some(*x),
104 _ => None,
105 }
106}
107
/// Totally ordered stand-in for one coordinate component, so that coordinate
/// vectors can be used as `BTreeMap` keys.
///
/// Variant order matters: the derived `Ord` sorts by variant first
/// (`I < Ts < F < S`) and then by payload.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum CoordKey {
    /// Key for an `Int64` coordinate.
    I(i64),
    /// Key for a `TimestampMs` coordinate.
    Ts(i64),
    /// Key for a `Float64` coordinate, held as the `u64` bit encoding that
    /// `encode_key` produces and `decode_key` reverses.
    F(u64),
    /// Key for a `String` coordinate.
    S(String),
}
116
117fn encode_key(c: &[CoordValue]) -> Vec<CoordKey> {
118 c.iter()
119 .map(|v| match v {
120 CoordValue::Int64(x) => CoordKey::I(*x),
121 CoordValue::TimestampMs(x) => CoordKey::Ts(*x),
122 CoordValue::Float64(x) => CoordKey::F(x.to_bits()),
123 CoordValue::String(s) => CoordKey::S(s.clone()),
124 })
125 .collect()
126}
127
128fn decode_key(k: &[CoordKey]) -> Vec<CoordValue> {
129 k.iter()
130 .map(|v| match v {
131 CoordKey::I(x) => CoordValue::Int64(*x),
132 CoordKey::Ts(x) => CoordValue::TimestampMs(*x),
133 CoordKey::F(b) => CoordValue::Float64(f64::from_bits(*b)),
134 CoordKey::S(s) => CoordValue::String(s.clone()),
135 })
136 .collect()
137}
138
139fn index_rows(tile: &SparseTile) -> BTreeMap<Vec<CoordKey>, Vec<CellValue>> {
140 let mut out = BTreeMap::new();
141 let n = tile.nnz() as usize;
142 for row in 0..n {
143 let coord: Vec<CoordValue> = tile
144 .dim_dicts
145 .iter()
146 .map(|d| d.values[d.indices[row] as usize].clone())
147 .collect();
148 let attrs: Vec<CellValue> = tile.attr_cols.iter().map(|col| col[row].clone()).collect();
149 out.insert(encode_key(&coord), attrs);
150 }
151 out
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::ArraySchemaBuilder;
    use crate::schema::attr_spec::{AttrSpec, AttrType};
    use crate::schema::dim_spec::{DimSpec, DimType};
    use crate::types::domain::{Domain, DomainBound};

    /// Schema "g": one `Int64` dimension `k` with bounds 0 and 15, one
    /// `Int64` attribute `v`, and a tile extent of 16.
    fn schema() -> ArraySchema {
        let key_domain = Domain::new(DomainBound::Int64(0), DomainBound::Int64(15));
        let key_dim = DimSpec::new("k", DimType::Int64, key_domain);
        let value_attr = AttrSpec::new("v", AttrType::Int64, true);
        ArraySchemaBuilder::new("g")
            .dim(key_dim)
            .attr(value_attr)
            .tile_extents(vec![16])
            .build()
            .unwrap()
    }

    /// Builds a tile from `(k, v)` pairs using the test schema.
    fn tile(rows: &[(i64, i64)]) -> SparseTile {
        let test_schema = schema();
        let mut builder = SparseTileBuilder::new(&test_schema);
        for &(coord, value) in rows {
            builder
                .push(&[CoordValue::Int64(coord)], &[CellValue::Int64(value)])
                .unwrap();
        }
        builder.build()
    }

    #[test]
    fn add_aligned_cells() {
        let s = schema();
        let left = tile(&[(0, 1), (1, 2)]);
        let right = tile(&[(0, 10), (1, 20)]);
        let sum = elementwise(&s, &left, &right, BinaryOp::Add).unwrap();
        assert_eq!(sum.nnz(), 2);
        assert_eq!(sum.attr_cols[0][0], CellValue::Int64(11));
        assert_eq!(sum.attr_cols[0][1], CellValue::Int64(22));
    }

    // Despite the name, the assertions below check that each output row
    // carries the value of whichever operand was present at that coordinate.
    #[test]
    fn outer_join_propagates_null() {
        let s = schema();
        let left = tile(&[(0, 1)]);
        let right = tile(&[(1, 2)]);
        let joined = elementwise(&s, &left, &right, BinaryOp::Add).unwrap();
        assert_eq!(joined.nnz(), 2);
        let first = &joined.attr_cols[0][0];
        let second = &joined.attr_cols[0][1];
        assert!(matches!(first, CellValue::Int64(1 | 2)));
        assert!(matches!(second, CellValue::Int64(1 | 2)));
    }

    #[test]
    fn div_by_zero_returns_null() {
        let s = schema();
        let numerator = tile(&[(0, 5)]);
        let denominator = tile(&[(0, 0)]);
        let quotient = elementwise(&s, &numerator, &denominator, BinaryOp::Div).unwrap();
        assert_eq!(quotient.attr_cols[0][0], CellValue::Null);
    }

    #[test]
    fn sub_and_mul() {
        let s = schema();
        let left = tile(&[(0, 10)]);
        let right = tile(&[(0, 3)]);
        let difference = elementwise(&s, &left, &right, BinaryOp::Sub).unwrap();
        let product = elementwise(&s, &left, &right, BinaryOp::Mul).unwrap();
        assert_eq!(difference.attr_cols[0][0], CellValue::Int64(7));
        assert_eq!(product.attr_cols[0][0], CellValue::Int64(30));
    }
}