1use std::cmp::Ordering;
4use std::collections::BTreeMap;
5
6use bytes::Bytes;
7use exoware_proto::{
8 RangeReduceGroup, RangeReduceOp, RangeReduceRequest, RangeReduceResponse, RangeReduceResult,
9};
10use exoware_sdk_rs as exoware_proto;
11use exoware_sdk_rs::keys::Key;
12use exoware_sdk_rs::kv_codec::{
13 canonicalize_reduced_group_values, decode_stored_row, encode_reduced_group_key, eval_expr,
14 eval_predicate, expr_needs_value, predicate_needs_value, KvReducedValue,
15};
16
/// Error returned when a range reduction request is invalid or cannot be
/// evaluated.
#[derive(Debug)]
pub enum RangeError {
    /// A reduction failure; the payload is the human-readable description
    /// that `Display` prints verbatim.
    Reduce(String),
}
21
22impl std::fmt::Display for RangeError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 RangeError::Reduce(s) => write!(f, "{s}"),
26 }
27 }
28}
29
30impl std::error::Error for RangeError {}
31
32#[derive(Debug)]
33enum ReductionState {
34 Count(u64),
35 Sum(Option<KvReducedValue>),
36 Min(Option<KvReducedValue>),
37 Max(Option<KvReducedValue>),
38}
39
40#[derive(Debug)]
41struct GroupedReductionState {
42 group_values: Vec<Option<KvReducedValue>>,
43 states: Vec<ReductionState>,
44}
45
46#[derive(Debug)]
47struct ExtractedReductionRow {
48 group_values: Vec<Option<KvReducedValue>>,
49 reducer_values: Vec<Option<KvReducedValue>>,
50}
51
52impl ReductionState {
53 fn from_op(op: RangeReduceOp) -> Self {
54 match op {
55 RangeReduceOp::CountAll | RangeReduceOp::CountField => Self::Count(0),
56 RangeReduceOp::SumField => Self::Sum(None),
57 RangeReduceOp::MinField => Self::Min(None),
58 RangeReduceOp::MaxField => Self::Max(None),
59 }
60 }
61
62 fn update(
63 &mut self,
64 op: RangeReduceOp,
65 value: Option<KvReducedValue>,
66 ) -> Result<(), RangeError> {
67 match (self, op) {
68 (Self::Count(count), RangeReduceOp::CountAll) => {
69 *count = count.saturating_add(1);
70 Ok(())
71 }
72 (Self::Count(count), RangeReduceOp::CountField) => {
73 if value.is_some() {
74 *count = count.saturating_add(1);
75 }
76 Ok(())
77 }
78 (Self::Sum(sum), RangeReduceOp::SumField) => {
79 let Some(value) = value else {
80 return Ok(());
81 };
82 match sum {
83 Some(existing) => existing
84 .checked_add_assign(&value)
85 .map_err(RangeError::Reduce),
86 None => {
87 *sum = Some(value);
88 Ok(())
89 }
90 }
91 }
92 (Self::Min(current), RangeReduceOp::MinField) => {
93 update_extreme(current, value, Ordering::Less)
94 }
95 (Self::Max(current), RangeReduceOp::MaxField) => {
96 update_extreme(current, value, Ordering::Greater)
97 }
98 _ => Err(RangeError::Reduce(
99 "reduction state/op mismatch".to_string(),
100 )),
101 }
102 }
103
104 fn finish(self) -> Option<KvReducedValue> {
105 match self {
106 Self::Count(count) => Some(KvReducedValue::UInt64(count)),
107 Self::Sum(value) | Self::Min(value) | Self::Max(value) => value,
108 }
109 }
110}
111
112impl GroupedReductionState {
113 fn new(group_values: Vec<Option<KvReducedValue>>, request: &RangeReduceRequest) -> Self {
114 Self {
115 group_values,
116 states: request
117 .reducers
118 .iter()
119 .map(|reducer| ReductionState::from_op(reducer.op))
120 .collect(),
121 }
122 }
123
124 fn update(
125 &mut self,
126 request: &RangeReduceRequest,
127 reducer_values: Vec<Option<KvReducedValue>>,
128 ) -> Result<(), RangeError> {
129 for ((state, reducer), value) in self
130 .states
131 .iter_mut()
132 .zip(request.reducers.iter())
133 .zip(reducer_values.into_iter())
134 {
135 state.update(reducer.op, value)?;
136 }
137 Ok(())
138 }
139
140 fn finish(self) -> RangeReduceGroup {
141 RangeReduceGroup {
142 group_values: self.group_values,
143 results: self
144 .states
145 .into_iter()
146 .map(|state| RangeReduceResult {
147 value: state.finish(),
148 })
149 .collect(),
150 }
151 }
152}
153
154fn update_extreme(
155 current: &mut Option<KvReducedValue>,
156 candidate: Option<KvReducedValue>,
157 replace_when: Ordering,
158) -> Result<(), RangeError> {
159 let Some(candidate) = candidate else {
160 return Ok(());
161 };
162 match current {
163 Some(existing) => {
164 let ordering = candidate
165 .partial_cmp_same_kind(existing)
166 .ok_or_else(|| RangeError::Reduce("min/max type mismatch".to_string()))?;
167 if ordering == replace_when {
168 *current = Some(candidate);
169 }
170 }
171 None => {
172 *current = Some(candidate);
173 }
174 }
175 Ok(())
176}
177
178fn validate_reduce_request(request: &RangeReduceRequest) -> Result<(), RangeError> {
179 if request.reducers.is_empty() && request.group_by.is_empty() {
180 return Err(RangeError::Reduce(
181 "range reduction request requires at least one reducer or group-by field".to_string(),
182 ));
183 }
184 for reducer in &request.reducers {
185 match reducer.op {
186 RangeReduceOp::CountAll => {
187 if reducer.expr.is_some() {
188 return Err(RangeError::Reduce(
189 "count_all reducer must not specify an expression".to_string(),
190 ));
191 }
192 }
193 RangeReduceOp::CountField
194 | RangeReduceOp::SumField
195 | RangeReduceOp::MinField
196 | RangeReduceOp::MaxField => {
197 if reducer.expr.is_none() {
198 return Err(RangeError::Reduce(
199 "expression reducer requires an expression".to_string(),
200 ));
201 }
202 }
203 }
204 }
205 Ok(())
206}
207
208fn reduce_row_into_response(
209 key: &Key,
210 value: &Bytes,
211 request: &RangeReduceRequest,
212 scalar_states: Option<&mut [ReductionState]>,
213 grouped_states: &mut BTreeMap<Vec<u8>, GroupedReductionState>,
214) -> Result<(), RangeError> {
215 let Some(extracted) = extract_reduce_row(key, value, request)? else {
216 return Ok(());
217 };
218
219 if request.group_by.is_empty() {
220 let Some(states) = scalar_states else {
221 return Err(RangeError::Reduce(
222 "missing scalar reduction state for non-grouped request".to_string(),
223 ));
224 };
225 for ((state, reducer), value) in states
226 .iter_mut()
227 .zip(request.reducers.iter())
228 .zip(extracted.reducer_values.into_iter())
229 {
230 state.update(reducer.op, value)?;
231 }
232 return Ok(());
233 }
234
235 let group_key = encode_reduced_group_key(&extracted.group_values);
236 let group = grouped_states
237 .entry(group_key)
238 .or_insert_with(|| GroupedReductionState::new(extracted.group_values.clone(), request));
239 group.update(request, extracted.reducer_values)?;
240 Ok(())
241}
242
243fn extract_reduce_row(
244 key: &Key,
245 value: &Bytes,
246 request: &RangeReduceRequest,
247) -> Result<Option<ExtractedReductionRow>, RangeError> {
248 let needs_value = request
249 .group_by
250 .iter()
251 .chain(
252 request
253 .reducers
254 .iter()
255 .filter_map(|reducer| reducer.expr.as_ref()),
256 )
257 .any(expr_needs_value)
258 || request.filter.as_ref().is_some_and(predicate_needs_value);
259 let decoded = if needs_value {
260 match decode_stored_row(value.as_ref()) {
261 Ok(row) => Some(row),
262 Err(_) => return Ok(None),
263 }
264 } else {
265 None
266 };
267 let archived = decoded.as_ref();
268
269 if let Some(filter) = &request.filter {
270 match eval_predicate(key, archived, filter) {
271 Ok(true) => {}
272 Ok(false) => return Ok(None),
273 Err(_) => return Ok(None),
274 }
275 }
276
277 let mut group_values = Vec::with_capacity(request.group_by.len());
278 for expr in &request.group_by {
279 let extracted_value = match eval_expr(key, archived, expr) {
280 Ok(value) => value,
281 Err(_) => return Ok(None),
282 };
283 group_values.push(extracted_value);
284 }
285 canonicalize_reduced_group_values(&mut group_values);
286
287 let mut reducer_values = Vec::with_capacity(request.reducers.len());
288 for reducer in &request.reducers {
289 let extracted_value = match (&reducer.expr, archived) {
290 (None, _) => None,
291 (Some(expr), _) => match eval_expr(key, archived, expr) {
292 Ok(value) => value,
293 Err(_) => return Ok(None),
294 },
295 };
296 reducer_values.push(extracted_value);
297 }
298
299 Ok(Some(ExtractedReductionRow {
300 group_values,
301 reducer_values,
302 }))
303}
304
305fn finalize_reduce_response(
306 scalar_states: Option<Vec<ReductionState>>,
307 grouped_states: BTreeMap<Vec<u8>, GroupedReductionState>,
308) -> RangeReduceResponse {
309 match scalar_states {
310 Some(states) => RangeReduceResponse {
311 results: states
312 .into_iter()
313 .map(|state| RangeReduceResult {
314 value: state.finish(),
315 })
316 .collect(),
317 groups: Vec::new(),
318 },
319 None => RangeReduceResponse {
320 results: Vec::new(),
321 groups: grouped_states
322 .into_values()
323 .map(GroupedReductionState::finish)
324 .collect(),
325 },
326 }
327}
328
329pub fn reduce_over_rows(
331 rows: &[(Key, Bytes)],
332 request: &RangeReduceRequest,
333) -> Result<RangeReduceResponse, RangeError> {
334 validate_reduce_request(request)?;
335 let mut scalar_states = request.group_by.is_empty().then(|| {
336 request
337 .reducers
338 .iter()
339 .map(|reducer| ReductionState::from_op(reducer.op))
340 .collect::<Vec<_>>()
341 });
342 let mut grouped_states = BTreeMap::<Vec<u8>, GroupedReductionState>::new();
343
344 for (key, value) in rows {
345 reduce_row_into_response(
346 key,
347 value,
348 request,
349 scalar_states.as_deref_mut(),
350 &mut grouped_states,
351 )?;
352 }
353
354 Ok(finalize_reduce_response(scalar_states, grouped_states))
355}
356
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use commonware_codec::Encode as _;
    use exoware_sdk_rs::keys::Key;
    use exoware_sdk_rs::kv_codec::{
        KvExpr, KvFieldKind, KvFieldRef, KvPredicate, KvPredicateCheck, KvPredicateConstraint,
        KvReducedValue, StoredRow, StoredValue,
    };
    use exoware_sdk_rs::{RangeReduceOp, RangeReduceRequest, RangeReducerSpec};

    use super::reduce_over_rows;

    /// Pairs `key` with `values` encoded as a stored row.
    fn make_row(key: &[u8], values: Vec<Option<StoredValue>>) -> (Key, Bytes) {
        (Key::from(key.to_vec()), StoredRow { values }.encode())
    }

    fn reducer(op: RangeReduceOp, expr: Option<KvExpr>) -> RangeReducerSpec {
        RangeReducerSpec { op, expr }
    }

    /// Expression reading the nullable value column `index` as `kind`.
    fn nullable_value_field(index: u16, kind: KvFieldKind) -> KvExpr {
        KvExpr::Field(KvFieldRef::Value {
            index,
            kind,
            nullable: true,
        })
    }

    fn int64_value_field(index: u16) -> KvExpr {
        nullable_value_field(index, KvFieldKind::Int64)
    }

    fn float64_value_field(index: u16) -> KvExpr {
        nullable_value_field(index, KvFieldKind::Float64)
    }

    fn utf8_value_field(index: u16) -> KvExpr {
        nullable_value_field(index, KvFieldKind::Utf8)
    }

    /// Request with no grouping and no filter.
    fn scalar_request(reducers: Vec<RangeReducerSpec>) -> RangeReduceRequest {
        RangeReduceRequest {
            reducers,
            group_by: Vec::new(),
            filter: None,
        }
    }

    /// Scalar request with exactly one expression-based reducer.
    fn single_reducer_request(op: RangeReduceOp, expr: KvExpr) -> RangeReduceRequest {
        scalar_request(vec![reducer(op, Some(expr))])
    }

    /// Runs the reduction and returns the first scalar result.
    fn first_result(rows: &[(Key, Bytes)], request: &RangeReduceRequest) -> Option<KvReducedValue> {
        reduce_over_rows(rows, request).unwrap().results[0]
            .value
            .clone()
    }

    fn result_u64(v: u64) -> Option<KvReducedValue> {
        Some(KvReducedValue::UInt64(v))
    }

    fn result_i64(v: i64) -> Option<KvReducedValue> {
        Some(KvReducedValue::Int64(v))
    }

    fn result_f64(v: f64) -> Option<KvReducedValue> {
        Some(KvReducedValue::Float64(v))
    }

    #[test]
    fn count_all_over_empty_rows() {
        let request = scalar_request(vec![reducer(RangeReduceOp::CountAll, None)]);
        let response = reduce_over_rows(&[], &request).unwrap();
        assert_eq!(response.results.len(), 1);
        assert_eq!(response.results[0].value, result_u64(0));
    }

    #[test]
    fn count_all_over_multiple_rows() {
        let rows = vec![
            make_row(b"a", vec![]),
            make_row(b"b", vec![]),
            make_row(b"c", vec![]),
        ];
        let request = scalar_request(vec![reducer(RangeReduceOp::CountAll, None)]);
        assert_eq!(first_result(&rows, &request), result_u64(3));
    }

    #[test]
    fn count_field_skips_nulls() {
        let rows = vec![
            make_row(b"a", vec![Some(StoredValue::Int64(1))]),
            make_row(b"b", vec![None]),
            make_row(b"c", vec![Some(StoredValue::Int64(3))]),
        ];
        let request = single_reducer_request(RangeReduceOp::CountField, int64_value_field(0));
        assert_eq!(first_result(&rows, &request), result_u64(2));
    }

    #[test]
    fn sum_int64_values() {
        let rows = vec![
            make_row(b"a", vec![Some(StoredValue::Int64(10))]),
            make_row(b"b", vec![Some(StoredValue::Int64(20))]),
            make_row(b"c", vec![Some(StoredValue::Int64(-5))]),
        ];
        let request = single_reducer_request(RangeReduceOp::SumField, int64_value_field(0));
        assert_eq!(first_result(&rows, &request), result_i64(25));
    }

    #[test]
    fn sum_float64_values() {
        let rows = vec![
            make_row(b"a", vec![Some(StoredValue::Float64(1.5))]),
            make_row(b"b", vec![Some(StoredValue::Float64(2.5))]),
        ];
        let request = single_reducer_request(RangeReduceOp::SumField, float64_value_field(0));
        assert_eq!(first_result(&rows, &request), result_f64(4.0));
    }

    #[test]
    fn min_selects_smallest() {
        let rows = vec![
            make_row(b"a", vec![Some(StoredValue::Int64(30))]),
            make_row(b"b", vec![Some(StoredValue::Int64(10))]),
            make_row(b"c", vec![Some(StoredValue::Int64(20))]),
        ];
        let request = single_reducer_request(RangeReduceOp::MinField, int64_value_field(0));
        assert_eq!(first_result(&rows, &request), result_i64(10));
    }

    #[test]
    fn max_selects_largest() {
        let rows = vec![
            make_row(b"a", vec![Some(StoredValue::Int64(30))]),
            make_row(b"b", vec![Some(StoredValue::Int64(10))]),
            make_row(b"c", vec![Some(StoredValue::Int64(50))]),
        ];
        let request = single_reducer_request(RangeReduceOp::MaxField, int64_value_field(0));
        assert_eq!(first_result(&rows, &request), result_i64(50));
    }

    #[test]
    fn grouped_count() {
        let rows: Vec<(Key, Bytes)> = [
            (b"a", "x"),
            (b"b", "y"),
            (b"c", "x"),
            (b"d", "y"),
            (b"e", "x"),
        ]
        .into_iter()
        .map(|(key, label)| make_row(key, vec![Some(StoredValue::Utf8(label.into()))]))
        .collect();
        let request = RangeReduceRequest {
            reducers: vec![reducer(RangeReduceOp::CountAll, None)],
            group_by: vec![utf8_value_field(0)],
            filter: None,
        };
        let response = reduce_over_rows(&rows, &request).unwrap();
        assert!(response.results.is_empty());
        assert_eq!(response.groups.len(), 2);

        // Order the groups by label so the comparison is independent of the
        // order in which the response lists them.
        let mut counts: Vec<(Option<KvReducedValue>, Option<KvReducedValue>)> = response
            .groups
            .iter()
            .map(|g| (g.group_values[0].clone(), g.results[0].value.clone()))
            .collect();
        counts.sort_by_key(|(group, _)| match group {
            Some(KvReducedValue::Utf8(s)) => s.clone(),
            _ => String::new(),
        });
        assert_eq!(
            counts,
            vec![
                (Some(KvReducedValue::Utf8("x".into())), result_u64(3)),
                (Some(KvReducedValue::Utf8("y".into())), result_u64(2)),
            ]
        );
    }

    #[test]
    fn validates_empty_request() {
        let request = scalar_request(Vec::new());
        let err = reduce_over_rows(&[], &request).unwrap_err();
        assert!(
            err.to_string().contains("at least one reducer"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn count_all_rejects_expression() {
        let request = single_reducer_request(RangeReduceOp::CountAll, int64_value_field(0));
        let err = reduce_over_rows(&[], &request).unwrap_err();
        assert!(
            err.to_string()
                .contains("count_all reducer must not specify an expression"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn expression_reducer_requires_expression() {
        for op in [
            RangeReduceOp::SumField,
            RangeReduceOp::MinField,
            RangeReduceOp::MaxField,
            RangeReduceOp::CountField,
        ] {
            let request = scalar_request(vec![reducer(op, None)]);
            let err = reduce_over_rows(&[], &request).unwrap_err();
            assert!(
                err.to_string()
                    .contains("expression reducer requires an expression"),
                "op {op:?} should require an expression, got: {err}"
            );
        }
    }

    #[test]
    fn filter_excludes_rows() {
        let rows = vec![
            make_row(b"a", vec![Some(StoredValue::Int64(10))]),
            make_row(b"b", vec![Some(StoredValue::Int64(20))]),
            make_row(b"c", vec![Some(StoredValue::Int64(30))]),
        ];
        // Keep only rows whose first column is at least 15.
        let at_least_fifteen = KvPredicate {
            checks: vec![KvPredicateCheck {
                field: KvFieldRef::Value {
                    index: 0,
                    kind: KvFieldKind::Int64,
                    nullable: false,
                },
                constraint: KvPredicateConstraint::IntRange {
                    min: Some(15),
                    max: None,
                },
            }],
            contradiction: false,
        };
        let request = RangeReduceRequest {
            reducers: vec![reducer(RangeReduceOp::SumField, Some(int64_value_field(0)))],
            group_by: Vec::new(),
            filter: Some(at_least_fifteen),
        };
        assert_eq!(first_result(&rows, &request), result_i64(50));
    }

    #[test]
    fn mixed_type_min_max_returns_error() {
        use super::ReductionState;

        let mut state = ReductionState::Min(Some(KvReducedValue::Int64(10)));
        let outcome = state.update(
            RangeReduceOp::MinField,
            Some(KvReducedValue::Utf8("hello".into())),
        );
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("type mismatch"),
            "expected type mismatch error"
        );
    }
}