Skip to main content

exoware_server/
reduce.rs

1//! Range aggregation over decoded KV rows (same semantics as the public `Reduce` RPC).
2
3use std::cmp::Ordering;
4use std::collections::BTreeMap;
5
6use bytes::Bytes;
7use exoware_proto::{
8    RangeReduceGroup, RangeReduceOp, RangeReduceRequest, RangeReduceResponse, RangeReduceResult,
9};
10use exoware_sdk_rs as exoware_proto;
11use exoware_sdk_rs::keys::Key;
12use exoware_sdk_rs::kv_codec::{
13    canonicalize_reduced_group_values, decode_stored_row, encode_reduced_group_key, eval_expr,
14    eval_predicate, expr_needs_value, predicate_needs_value, KvReducedValue,
15};
16
/// Error produced while validating or executing a range reduction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RangeError {
    /// The request was malformed, or a reducer could not combine its inputs
    /// (for example a min/max type mismatch or a failed checked sum). The
    /// payload is the human-readable message.
    Reduce(String),
}

impl std::fmt::Display for RangeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The display form is exactly the message so it can be surfaced verbatim.
            RangeError::Reduce(message) => f.write_str(message),
        }
    }
}

impl std::error::Error for RangeError {}
31
32#[derive(Debug)]
33enum ReductionState {
34    Count(u64),
35    Sum(Option<KvReducedValue>),
36    Min(Option<KvReducedValue>),
37    Max(Option<KvReducedValue>),
38}
39
40#[derive(Debug)]
41struct GroupedReductionState {
42    group_values: Vec<Option<KvReducedValue>>,
43    states: Vec<ReductionState>,
44}
45
46#[derive(Debug)]
47struct ExtractedReductionRow {
48    group_values: Vec<Option<KvReducedValue>>,
49    reducer_values: Vec<Option<KvReducedValue>>,
50}
51
52impl ReductionState {
53    fn from_op(op: RangeReduceOp) -> Self {
54        match op {
55            RangeReduceOp::CountAll | RangeReduceOp::CountField => Self::Count(0),
56            RangeReduceOp::SumField => Self::Sum(None),
57            RangeReduceOp::MinField => Self::Min(None),
58            RangeReduceOp::MaxField => Self::Max(None),
59        }
60    }
61
62    fn update(
63        &mut self,
64        op: RangeReduceOp,
65        value: Option<KvReducedValue>,
66    ) -> Result<(), RangeError> {
67        match (self, op) {
68            (Self::Count(count), RangeReduceOp::CountAll) => {
69                *count = count.saturating_add(1);
70                Ok(())
71            }
72            (Self::Count(count), RangeReduceOp::CountField) => {
73                if value.is_some() {
74                    *count = count.saturating_add(1);
75                }
76                Ok(())
77            }
78            (Self::Sum(sum), RangeReduceOp::SumField) => {
79                let Some(value) = value else {
80                    return Ok(());
81                };
82                match sum {
83                    Some(existing) => existing
84                        .checked_add_assign(&value)
85                        .map_err(RangeError::Reduce),
86                    None => {
87                        *sum = Some(value);
88                        Ok(())
89                    }
90                }
91            }
92            (Self::Min(current), RangeReduceOp::MinField) => {
93                update_extreme(current, value, Ordering::Less)
94            }
95            (Self::Max(current), RangeReduceOp::MaxField) => {
96                update_extreme(current, value, Ordering::Greater)
97            }
98            _ => Err(RangeError::Reduce(
99                "reduction state/op mismatch".to_string(),
100            )),
101        }
102    }
103
104    fn finish(self) -> Option<KvReducedValue> {
105        match self {
106            Self::Count(count) => Some(KvReducedValue::UInt64(count)),
107            Self::Sum(value) | Self::Min(value) | Self::Max(value) => value,
108        }
109    }
110}
111
112impl GroupedReductionState {
113    fn new(group_values: Vec<Option<KvReducedValue>>, request: &RangeReduceRequest) -> Self {
114        Self {
115            group_values,
116            states: request
117                .reducers
118                .iter()
119                .map(|reducer| ReductionState::from_op(reducer.op))
120                .collect(),
121        }
122    }
123
124    fn update(
125        &mut self,
126        request: &RangeReduceRequest,
127        reducer_values: Vec<Option<KvReducedValue>>,
128    ) -> Result<(), RangeError> {
129        for ((state, reducer), value) in self
130            .states
131            .iter_mut()
132            .zip(request.reducers.iter())
133            .zip(reducer_values.into_iter())
134        {
135            state.update(reducer.op, value)?;
136        }
137        Ok(())
138    }
139
140    fn finish(self) -> RangeReduceGroup {
141        RangeReduceGroup {
142            group_values: self.group_values,
143            results: self
144                .states
145                .into_iter()
146                .map(|state| RangeReduceResult {
147                    value: state.finish(),
148                })
149                .collect(),
150        }
151    }
152}
153
154fn update_extreme(
155    current: &mut Option<KvReducedValue>,
156    candidate: Option<KvReducedValue>,
157    replace_when: Ordering,
158) -> Result<(), RangeError> {
159    let Some(candidate) = candidate else {
160        return Ok(());
161    };
162    match current {
163        Some(existing) => {
164            let ordering = candidate
165                .partial_cmp_same_kind(existing)
166                .ok_or_else(|| RangeError::Reduce("min/max type mismatch".to_string()))?;
167            if ordering == replace_when {
168                *current = Some(candidate);
169            }
170        }
171        None => {
172            *current = Some(candidate);
173        }
174    }
175    Ok(())
176}
177
178fn validate_reduce_request(request: &RangeReduceRequest) -> Result<(), RangeError> {
179    if request.reducers.is_empty() && request.group_by.is_empty() {
180        return Err(RangeError::Reduce(
181            "range reduction request requires at least one reducer or group-by field".to_string(),
182        ));
183    }
184    for reducer in &request.reducers {
185        match reducer.op {
186            RangeReduceOp::CountAll => {
187                if reducer.expr.is_some() {
188                    return Err(RangeError::Reduce(
189                        "count_all reducer must not specify an expression".to_string(),
190                    ));
191                }
192            }
193            RangeReduceOp::CountField
194            | RangeReduceOp::SumField
195            | RangeReduceOp::MinField
196            | RangeReduceOp::MaxField => {
197                if reducer.expr.is_none() {
198                    return Err(RangeError::Reduce(
199                        "expression reducer requires an expression".to_string(),
200                    ));
201                }
202            }
203        }
204    }
205    Ok(())
206}
207
208fn reduce_row_into_response(
209    key: &Key,
210    value: &Bytes,
211    request: &RangeReduceRequest,
212    scalar_states: Option<&mut [ReductionState]>,
213    grouped_states: &mut BTreeMap<Vec<u8>, GroupedReductionState>,
214) -> Result<(), RangeError> {
215    let Some(extracted) = extract_reduce_row(key, value, request)? else {
216        return Ok(());
217    };
218
219    if request.group_by.is_empty() {
220        let Some(states) = scalar_states else {
221            return Err(RangeError::Reduce(
222                "missing scalar reduction state for non-grouped request".to_string(),
223            ));
224        };
225        for ((state, reducer), value) in states
226            .iter_mut()
227            .zip(request.reducers.iter())
228            .zip(extracted.reducer_values.into_iter())
229        {
230            state.update(reducer.op, value)?;
231        }
232        return Ok(());
233    }
234
235    let group_key = encode_reduced_group_key(&extracted.group_values);
236    let group = grouped_states
237        .entry(group_key)
238        .or_insert_with(|| GroupedReductionState::new(extracted.group_values.clone(), request));
239    group.update(request, extracted.reducer_values)?;
240    Ok(())
241}
242
243fn extract_reduce_row(
244    key: &Key,
245    value: &Bytes,
246    request: &RangeReduceRequest,
247) -> Result<Option<ExtractedReductionRow>, RangeError> {
248    let needs_value = request
249        .group_by
250        .iter()
251        .chain(
252            request
253                .reducers
254                .iter()
255                .filter_map(|reducer| reducer.expr.as_ref()),
256        )
257        .any(expr_needs_value)
258        || request.filter.as_ref().is_some_and(predicate_needs_value);
259    let decoded = if needs_value {
260        match decode_stored_row(value.as_ref()) {
261            Ok(row) => Some(row),
262            Err(_) => return Ok(None),
263        }
264    } else {
265        None
266    };
267    let archived = decoded.as_ref();
268
269    if let Some(filter) = &request.filter {
270        match eval_predicate(key, archived, filter) {
271            Ok(true) => {}
272            Ok(false) => return Ok(None),
273            Err(_) => return Ok(None),
274        }
275    }
276
277    let mut group_values = Vec::with_capacity(request.group_by.len());
278    for expr in &request.group_by {
279        let extracted_value = match eval_expr(key, archived, expr) {
280            Ok(value) => value,
281            Err(_) => return Ok(None),
282        };
283        group_values.push(extracted_value);
284    }
285    canonicalize_reduced_group_values(&mut group_values);
286
287    let mut reducer_values = Vec::with_capacity(request.reducers.len());
288    for reducer in &request.reducers {
289        let extracted_value = match (&reducer.expr, archived) {
290            (None, _) => None,
291            (Some(expr), _) => match eval_expr(key, archived, expr) {
292                Ok(value) => value,
293                Err(_) => return Ok(None),
294            },
295        };
296        reducer_values.push(extracted_value);
297    }
298
299    Ok(Some(ExtractedReductionRow {
300        group_values,
301        reducer_values,
302    }))
303}
304
305fn finalize_reduce_response(
306    scalar_states: Option<Vec<ReductionState>>,
307    grouped_states: BTreeMap<Vec<u8>, GroupedReductionState>,
308) -> RangeReduceResponse {
309    match scalar_states {
310        Some(states) => RangeReduceResponse {
311            results: states
312                .into_iter()
313                .map(|state| RangeReduceResult {
314                    value: state.finish(),
315                })
316                .collect(),
317            groups: Vec::new(),
318        },
319        None => RangeReduceResponse {
320            results: Vec::new(),
321            groups: grouped_states
322                .into_values()
323                .map(GroupedReductionState::finish)
324                .collect(),
325        },
326    }
327}
328
329/// Run a grouped or scalar reduction over materialized rows.
330pub fn reduce_over_rows(
331    rows: &[(Key, Bytes)],
332    request: &RangeReduceRequest,
333) -> Result<RangeReduceResponse, RangeError> {
334    validate_reduce_request(request)?;
335    let mut scalar_states = request.group_by.is_empty().then(|| {
336        request
337            .reducers
338            .iter()
339            .map(|reducer| ReductionState::from_op(reducer.op))
340            .collect::<Vec<_>>()
341    });
342    let mut grouped_states = BTreeMap::<Vec<u8>, GroupedReductionState>::new();
343
344    for (key, value) in rows {
345        reduce_row_into_response(
346            key,
347            value,
348            request,
349            scalar_states.as_deref_mut(),
350            &mut grouped_states,
351        )?;
352    }
353
354    Ok(finalize_reduce_response(scalar_states, grouped_states))
355}
356
#[cfg(test)]
mod tests {
    //! Unit tests for `reduce_over_rows` and the internal reduction state machine.

    use bytes::Bytes;
    use commonware_codec::Encode as _;
    use exoware_sdk_rs::keys::Key;
    use exoware_sdk_rs::kv_codec::{
        KvExpr, KvFieldKind, KvFieldRef, KvPredicate, KvPredicateCheck, KvPredicateConstraint,
        KvReducedValue, StoredRow, StoredValue,
    };
    use exoware_sdk_rs::{RangeReduceOp, RangeReduceRequest, RangeReducerSpec};

    use super::reduce_over_rows;

    /// Encodes `values` as a stored row and pairs it with `key`.
    fn encoded_row(key: &[u8], values: Vec<Option<StoredValue>>) -> (Key, Bytes) {
        let row_key = Key::from(key.to_vec());
        let row_value = StoredRow { values }.encode();
        (row_key, row_value)
    }

    /// Rows keyed `a`, `b`, `c`, … each holding one non-null int64 in column 0.
    fn int64_rows(values: &[i64]) -> Vec<(Key, Bytes)> {
        values
            .iter()
            .zip(b'a'..)
            .map(|(&v, key)| encoded_row(&[key], vec![Some(StoredValue::Int64(v))]))
            .collect()
    }

    /// Shorthand for a reducer spec.
    fn spec(op: RangeReduceOp, expr: Option<KvExpr>) -> RangeReducerSpec {
        RangeReducerSpec { op, expr }
    }

    /// A nullable reference to value column `index` of the given kind.
    fn nullable_value(index: u16, kind: KvFieldKind) -> KvExpr {
        KvExpr::Field(KvFieldRef::Value {
            index,
            kind,
            nullable: true,
        })
    }

    fn int64_col(index: u16) -> KvExpr {
        nullable_value(index, KvFieldKind::Int64)
    }

    fn float64_col(index: u16) -> KvExpr {
        nullable_value(index, KvFieldKind::Float64)
    }

    fn utf8_col(index: u16) -> KvExpr {
        nullable_value(index, KvFieldKind::Utf8)
    }

    /// A request with the given reducers, no grouping, and no filter.
    fn ungrouped(reducers: Vec<RangeReducerSpec>) -> RangeReduceRequest {
        RangeReduceRequest {
            reducers,
            group_by: Vec::new(),
            filter: None,
        }
    }

    /// Runs a single ungrouped `op` over int64 column 0 of rows built from `values`.
    fn reduce_int64_column(op: RangeReduceOp, values: &[i64]) -> Option<KvReducedValue> {
        let request = ungrouped(vec![spec(op, Some(int64_col(0)))]);
        let mut response = reduce_over_rows(&int64_rows(values), &request).unwrap();
        response.results.remove(0).value
    }

    fn uint(v: u64) -> Option<KvReducedValue> {
        Some(KvReducedValue::UInt64(v))
    }

    fn int(v: i64) -> Option<KvReducedValue> {
        Some(KvReducedValue::Int64(v))
    }

    fn float(v: f64) -> Option<KvReducedValue> {
        Some(KvReducedValue::Float64(v))
    }

    #[test]
    fn count_all_over_empty_rows() {
        let request = ungrouped(vec![spec(RangeReduceOp::CountAll, None)]);
        let response = reduce_over_rows(&[], &request).unwrap();
        assert_eq!(response.results.len(), 1);
        assert_eq!(response.results[0].value, uint(0));
    }

    #[test]
    fn count_all_over_multiple_rows() {
        let rows: Vec<(Key, Bytes)> = ["a", "b", "c"]
            .iter()
            .map(|key| encoded_row(key.as_bytes(), Vec::new()))
            .collect();
        let request = ungrouped(vec![spec(RangeReduceOp::CountAll, None)]);
        let response = reduce_over_rows(&rows, &request).unwrap();
        assert_eq!(response.results[0].value, uint(3));
    }

    #[test]
    fn count_field_skips_nulls() {
        let rows = [
            encoded_row(b"a", vec![Some(StoredValue::Int64(1))]),
            encoded_row(b"b", vec![None]),
            encoded_row(b"c", vec![Some(StoredValue::Int64(3))]),
        ];
        let request = ungrouped(vec![spec(RangeReduceOp::CountField, Some(int64_col(0)))]);
        let response = reduce_over_rows(&rows, &request).unwrap();
        assert_eq!(response.results[0].value, uint(2));
    }

    #[test]
    fn sum_int64_values() {
        assert_eq!(
            reduce_int64_column(RangeReduceOp::SumField, &[10, 20, -5]),
            int(25)
        );
    }

    #[test]
    fn sum_float64_values() {
        let rows = [
            encoded_row(b"a", vec![Some(StoredValue::Float64(1.5))]),
            encoded_row(b"b", vec![Some(StoredValue::Float64(2.5))]),
        ];
        let request = ungrouped(vec![spec(RangeReduceOp::SumField, Some(float64_col(0)))]);
        let response = reduce_over_rows(&rows, &request).unwrap();
        assert_eq!(response.results[0].value, float(4.0));
    }

    #[test]
    fn min_selects_smallest() {
        assert_eq!(
            reduce_int64_column(RangeReduceOp::MinField, &[30, 10, 20]),
            int(10)
        );
    }

    #[test]
    fn max_selects_largest() {
        assert_eq!(
            reduce_int64_column(RangeReduceOp::MaxField, &[30, 10, 50]),
            int(50)
        );
    }

    #[test]
    fn grouped_count() {
        let rows: Vec<(Key, Bytes)> = [("a", "x"), ("b", "y"), ("c", "x"), ("d", "y"), ("e", "x")]
            .iter()
            .map(|(key, label)| {
                encoded_row(
                    key.as_bytes(),
                    vec![Some(StoredValue::Utf8((*label).into()))],
                )
            })
            .collect();
        let request = RangeReduceRequest {
            reducers: vec![spec(RangeReduceOp::CountAll, None)],
            group_by: vec![utf8_col(0)],
            filter: None,
        };

        let response = reduce_over_rows(&rows, &request).unwrap();
        assert!(response.results.is_empty());
        assert_eq!(response.groups.len(), 2);

        // Sort by the group's string label so the assertion is order-independent.
        let label_of = |value: &Option<KvReducedValue>| match value {
            Some(KvReducedValue::Utf8(s)) => s.clone(),
            _ => String::new(),
        };
        let mut counts: Vec<(Option<KvReducedValue>, Option<KvReducedValue>)> = response
            .groups
            .iter()
            .map(|group| (group.group_values[0].clone(), group.results[0].value.clone()))
            .collect();
        counts.sort_by_key(|(group, _)| label_of(group));

        assert_eq!(
            counts,
            vec![
                (Some(KvReducedValue::Utf8("x".into())), uint(3)),
                (Some(KvReducedValue::Utf8("y".into())), uint(2)),
            ]
        );
    }

    #[test]
    fn validates_empty_request() {
        let request = RangeReduceRequest {
            reducers: Vec::new(),
            group_by: Vec::new(),
            filter: None,
        };
        let message = reduce_over_rows(&[], &request).unwrap_err().to_string();
        assert!(
            message.contains("at least one reducer"),
            "unexpected error: {message}"
        );
    }

    #[test]
    fn count_all_rejects_expression() {
        let request = ungrouped(vec![spec(RangeReduceOp::CountAll, Some(int64_col(0)))]);
        let message = reduce_over_rows(&[], &request).unwrap_err().to_string();
        assert!(
            message.contains("count_all reducer must not specify an expression"),
            "unexpected error: {message}"
        );
    }

    #[test]
    fn expression_reducer_requires_expression() {
        let field_ops = [
            RangeReduceOp::SumField,
            RangeReduceOp::MinField,
            RangeReduceOp::MaxField,
            RangeReduceOp::CountField,
        ];
        for op in field_ops {
            let request = ungrouped(vec![spec(op, None)]);
            let message = reduce_over_rows(&[], &request).unwrap_err().to_string();
            assert!(
                message.contains("expression reducer requires an expression"),
                "op {op:?} should require an expression, got: {message}"
            );
        }
    }

    #[test]
    fn filter_excludes_rows() {
        let rows = int64_rows(&[10, 20, 30]);
        let at_least_15 = KvPredicate {
            checks: vec![KvPredicateCheck {
                field: KvFieldRef::Value {
                    index: 0,
                    kind: KvFieldKind::Int64,
                    nullable: false,
                },
                constraint: KvPredicateConstraint::IntRange {
                    min: Some(15),
                    max: None,
                },
            }],
            contradiction: false,
        };
        let request = RangeReduceRequest {
            reducers: vec![spec(RangeReduceOp::SumField, Some(int64_col(0)))],
            group_by: Vec::new(),
            filter: Some(at_least_15),
        };
        let response = reduce_over_rows(&rows, &request).unwrap();
        assert_eq!(response.results[0].value, int(50));
    }

    #[test]
    fn mixed_type_min_max_returns_error() {
        use super::ReductionState;

        let mut state = ReductionState::Min(Some(KvReducedValue::Int64(10)));
        let err = state
            .update(
                RangeReduceOp::MinField,
                Some(KvReducedValue::Utf8("hello".into())),
            )
            .expect_err("mixing int64 and utf8 in a min reducer must fail");
        assert!(
            err.to_string().contains("type mismatch"),
            "expected type mismatch error"
        );
    }
}