Skip to main content

kimberlite_query/
executor.rs

1//! Query executor: executes query plans against a projection store.
2
3#![allow(clippy::ref_option)]
4#![allow(clippy::trivially_copy_pass_by_ref)]
5#![allow(clippy::items_after_statements)]
6
7use std::cmp::Ordering;
8use std::ops::Bound;
9
// ============================================================================
// Query Execution Constants
// ============================================================================

/// Multiplier applied to a query's LIMIT when sizing store scans for range and
/// index scans that also carry a client-side ORDER BY.
///
/// **Rationale**: Client-side sorting needs candidate rows beyond the first
/// LIMIT rows in key order before it can pick the first LIMIT rows in sort
/// order. The scan request is over-sized by 10x to cover common cases where the
/// ORDER BY columns have high cardinality, while still bounding each request.
const SCAN_LIMIT_MULTIPLIER_WITH_SORT: usize = 10;

/// Multiplier applied to a query's LIMIT when sizing store scans for range and
/// index scans without a client-side ORDER BY.
///
/// **Rationale**: Without sorting, LIMIT can be applied as rows are produced.
/// Fetching 2x the limit leaves headroom for rows that are skipped: rows the
/// filter rejects, deleted rows, and index entries whose base row is missing.
const SCAN_LIMIT_MULTIPLIER_NO_SORT: usize = 2;

/// Scan size used by range and index scans when the query has no LIMIT clause.
///
/// **Rationale**: Prevents unbounded memory allocation for large tables.
/// Set to 10K based on these estimates:
/// - Avg row size ~1KB → ~10MB memory footprint
/// - p99 query latency < 50ms for a 10K-row scan
/// - Sufficient for most analytical queries
///
/// Table scans use their own, larger default (100K) rather than this constant.
const DEFAULT_SCAN_LIMIT: usize = 10_000;

/// Maximum number of aggregate functions allowed in a single query.
///
/// **Rationale**: Prevents `DoS` via memory exhaustion.
/// Each aggregate keeps per-group state (sum, count, min, max) ≈ 64 bytes.
/// 100 aggregates × 1000 groups ≈ 6.4MB of state, which is reasonable.
/// Note that this constant bounds only the number of aggregates; the number of
/// groups is not limited by it.
const MAX_AGGREGATES_PER_QUERY: usize = 100;
42
43use bytes::Bytes;
44use kimberlite_store::{Key, ProjectionStore, TableId};
45use kimberlite_types::Offset;
46
47use crate::error::{QueryError, Result};
48use crate::key_encoder::successor_key;
49use crate::plan::{QueryPlan, ScanOrder, SortSpec};
50use crate::schema::{ColumnName, TableDef};
51use crate::value::Value;
52
53/// Result of executing a query.
54#[derive(Debug, Clone)]
55pub struct QueryResult {
56    /// Column names in result order.
57    pub columns: Vec<ColumnName>,
58    /// Result rows.
59    pub rows: Vec<Row>,
60}
61
62impl QueryResult {
63    /// Creates an empty result with the given columns.
64    pub fn empty(columns: Vec<ColumnName>) -> Self {
65        Self {
66            columns,
67            rows: vec![],
68        }
69    }
70
71    /// Returns the number of rows.
72    pub fn len(&self) -> usize {
73        self.rows.len()
74    }
75
76    /// Returns true if there are no rows.
77    pub fn is_empty(&self) -> bool {
78        self.rows.is_empty()
79    }
80}
81
82/// A single result row.
83pub type Row = Vec<Value>;
84
85/// Executes an index scan query.
86#[allow(clippy::too_many_arguments)]
87fn execute_index_scan<S: ProjectionStore>(
88    store: &mut S,
89    table_id: TableId,
90    index_id: u64,
91    start: &Bound<Key>,
92    end: &Bound<Key>,
93    filter: &Option<crate::plan::Filter>,
94    limit: &Option<usize>,
95    order: &ScanOrder,
96    order_by: &Option<crate::plan::SortSpec>,
97    columns: &[usize],
98    column_names: &[ColumnName],
99    table_def: &TableDef,
100    position: Option<Offset>,
101) -> Result<QueryResult> {
102    let (start_key, end_key) = bounds_to_range(start, end);
103
104    // Calculate scan limit based on whether client-side sorting is needed
105    let scan_limit = if order_by.is_some() {
106        limit
107            .map(|l| l.saturating_mul(SCAN_LIMIT_MULTIPLIER_WITH_SORT))
108            .unwrap_or(DEFAULT_SCAN_LIMIT)
109    } else {
110        limit
111            .map(|l| l.saturating_mul(SCAN_LIMIT_MULTIPLIER_NO_SORT))
112            .unwrap_or(DEFAULT_SCAN_LIMIT)
113    };
114
115    // Postcondition: scan limit must be positive
116    debug_assert!(scan_limit > 0, "scan_limit must be positive");
117
118    // Calculate index table ID using hash to avoid overflow
119    use std::collections::hash_map::DefaultHasher;
120    use std::hash::{Hash, Hasher};
121
122    let mut hasher = DefaultHasher::new();
123    table_id.as_u64().hash(&mut hasher);
124    index_id.hash(&mut hasher);
125    let index_table_id = TableId::new(hasher.finish());
126
127    // Scan the index table to get composite keys
128    let index_pairs = match position {
129        Some(pos) => store.scan_at(index_table_id, start_key..end_key, scan_limit, pos)?,
130        None => store.scan(index_table_id, start_key..end_key, scan_limit)?,
131    };
132
133    let mut full_rows = Vec::new();
134    let index_iter: Box<dyn Iterator<Item = &(Key, Bytes)>> = match order {
135        ScanOrder::Ascending => Box::new(index_pairs.iter()),
136        ScanOrder::Descending => Box::new(index_pairs.iter().rev()),
137    };
138
139    for (index_key, _) in index_iter {
140        // Extract primary key from the composite index key
141        let pk_key = extract_pk_from_index_key(index_key, table_def);
142
143        // Fetch the actual row from the base table
144        let bytes_opt = match position {
145            Some(pos) => store.get_at(table_id, &pk_key, pos)?,
146            None => store.get(table_id, &pk_key)?,
147        };
148        if let Some(bytes) = bytes_opt {
149            let full_row = decode_row(&bytes, table_def)?;
150
151            // Apply filter
152            if let Some(f) = filter {
153                if !f.matches(&full_row) {
154                    continue;
155                }
156            }
157
158            full_rows.push(full_row);
159
160            // When client-side sorting is needed, don't apply limit during scan
161            if order_by.is_none() {
162                if let Some(lim) = limit {
163                    if full_rows.len() >= *lim {
164                        break;
165                    }
166                }
167            }
168        }
169    }
170
171    // Apply client-side sorting if needed (on full rows before projection)
172    if let Some(sort_spec) = order_by {
173        sort_rows(&mut full_rows, sort_spec);
174    }
175
176    // Apply limit after sorting
177    if let Some(lim) = limit {
178        full_rows.truncate(*lim);
179    }
180
181    // Project columns after sorting and limiting
182    let rows: Vec<Row> = full_rows
183        .iter()
184        .map(|full_row| project_row(full_row, columns))
185        .collect();
186
187    Ok(QueryResult {
188        columns: column_names.to_vec(),
189        rows,
190    })
191}
192
193/// Executes a table scan query.
194#[allow(clippy::too_many_arguments)]
195fn execute_table_scan<S: ProjectionStore>(
196    store: &mut S,
197    table_id: TableId,
198    filter: &Option<crate::plan::Filter>,
199    limit: &Option<usize>,
200    order: &Option<SortSpec>,
201    columns: &[usize],
202    column_names: &[ColumnName],
203    table_def: &TableDef,
204    position: Option<Offset>,
205) -> Result<QueryResult> {
206    // Scan entire table
207    let scan_limit = limit.map(|l| l * 10).unwrap_or(100_000);
208    let pairs = match position {
209        Some(pos) => store.scan_at(table_id, Key::min()..Key::max(), scan_limit, pos)?,
210        None => store.scan(table_id, Key::min()..Key::max(), scan_limit)?,
211    };
212
213    let mut full_rows = Vec::new();
214
215    for (_, bytes) in &pairs {
216        let full_row = decode_row(bytes, table_def)?;
217
218        // Apply filter
219        if let Some(f) = filter {
220            if !f.matches(&full_row) {
221                continue;
222            }
223        }
224
225        full_rows.push(full_row);
226    }
227
228    // Apply sort on full rows (before projection)
229    if let Some(sort_spec) = order {
230        sort_rows(&mut full_rows, sort_spec);
231    }
232
233    // Apply limit
234    if let Some(lim) = limit {
235        full_rows.truncate(*lim);
236    }
237
238    // Project columns after sorting and limiting
239    let rows: Vec<Row> = full_rows
240        .iter()
241        .map(|full_row| project_row(full_row, columns))
242        .collect();
243
244    Ok(QueryResult {
245        columns: column_names.to_vec(),
246        rows,
247    })
248}
249
250/// Executes a range scan query.
251#[allow(clippy::too_many_arguments)]
252fn execute_range_scan<S: ProjectionStore>(
253    store: &mut S,
254    table_id: TableId,
255    start: &Bound<Key>,
256    end: &Bound<Key>,
257    filter: &Option<crate::plan::Filter>,
258    limit: &Option<usize>,
259    order: &ScanOrder,
260    order_by: &Option<crate::plan::SortSpec>,
261    columns: &[usize],
262    column_names: &[ColumnName],
263    table_def: &TableDef,
264    position: Option<Offset>,
265) -> Result<QueryResult> {
266    let (start_key, end_key) = bounds_to_range(start, end);
267
268    // Calculate scan limit based on whether client-side sorting is needed
269    let scan_limit = if order_by.is_some() {
270        limit
271            .map(|l| l.saturating_mul(SCAN_LIMIT_MULTIPLIER_WITH_SORT))
272            .unwrap_or(DEFAULT_SCAN_LIMIT)
273    } else {
274        limit
275            .map(|l| l.saturating_mul(SCAN_LIMIT_MULTIPLIER_NO_SORT))
276            .unwrap_or(DEFAULT_SCAN_LIMIT)
277    };
278
279    // Postcondition: scan limit must be positive
280    debug_assert!(scan_limit > 0, "scan_limit must be positive");
281
282    let pairs = match position {
283        Some(pos) => store.scan_at(table_id, start_key..end_key, scan_limit, pos)?,
284        None => store.scan(table_id, start_key..end_key, scan_limit)?,
285    };
286
287    let mut full_rows = Vec::new();
288    let row_iter: Box<dyn Iterator<Item = &(Key, Bytes)>> = match order {
289        ScanOrder::Ascending => Box::new(pairs.iter()),
290        ScanOrder::Descending => Box::new(pairs.iter().rev()),
291    };
292
293    for (_, bytes) in row_iter {
294        let full_row = decode_row(bytes, table_def)?;
295
296        // Apply filter
297        if let Some(f) = filter {
298            if !f.matches(&full_row) {
299                continue;
300            }
301        }
302
303        full_rows.push(full_row);
304
305        // When client-side sorting is needed, don't apply limit during scan
306        if order_by.is_none() {
307            if let Some(lim) = limit {
308                if full_rows.len() >= *lim {
309                    break;
310                }
311            }
312        }
313    }
314
315    // Apply client-side sorting if needed (on full rows before projection)
316    if let Some(sort_spec) = order_by {
317        sort_rows(&mut full_rows, sort_spec);
318    }
319
320    // Apply limit after sorting
321    if let Some(lim) = limit {
322        full_rows.truncate(*lim);
323    }
324
325    // Project columns after sorting and limiting
326    let rows: Vec<Row> = full_rows
327        .iter()
328        .map(|full_row| project_row(full_row, columns))
329        .collect();
330
331    Ok(QueryResult {
332        columns: column_names.to_vec(),
333        rows,
334    })
335}
336
337/// Executes a point lookup query.
338fn execute_point_lookup<S: ProjectionStore>(
339    store: &mut S,
340    table_id: TableId,
341    key: &Key,
342    columns: &[usize],
343    column_names: &[ColumnName],
344    table_def: &TableDef,
345    position: Option<Offset>,
346) -> Result<QueryResult> {
347    let result = match position {
348        Some(pos) => store.get_at(table_id, key, pos)?,
349        None => store.get(table_id, key)?,
350    };
351    match result {
352        Some(bytes) => {
353            let row = decode_and_project(&bytes, columns, table_def)?;
354            Ok(QueryResult {
355                columns: column_names.to_vec(),
356                rows: vec![row],
357            })
358        }
359        None => Ok(QueryResult::empty(column_names.to_vec())),
360    }
361}
362
363/// Internal execution function that handles both current and point-in-time queries.
364#[allow(clippy::too_many_lines)]
365fn execute_internal<S: ProjectionStore>(
366    store: &mut S,
367    plan: &QueryPlan,
368    table_def: &TableDef,
369    position: Option<Offset>,
370) -> Result<QueryResult> {
371    match plan {
372        QueryPlan::PointLookup {
373            table_id,
374            key,
375            columns,
376            column_names,
377            ..
378        } => execute_point_lookup(
379            store,
380            *table_id,
381            key,
382            columns,
383            column_names,
384            table_def,
385            position,
386        ),
387
388        QueryPlan::RangeScan {
389            table_id,
390            start,
391            end,
392            filter,
393            limit,
394            order,
395            order_by,
396            columns,
397            column_names,
398            ..
399        } => execute_range_scan(
400            store,
401            *table_id,
402            start,
403            end,
404            filter,
405            limit,
406            order,
407            order_by,
408            columns,
409            column_names,
410            table_def,
411            position,
412        ),
413
414        QueryPlan::IndexScan {
415            table_id,
416            index_id,
417            start,
418            end,
419            filter,
420            limit,
421            order,
422            order_by,
423            columns,
424            column_names,
425            ..
426        } => execute_index_scan(
427            store,
428            *table_id,
429            *index_id,
430            start,
431            end,
432            filter,
433            limit,
434            order,
435            order_by,
436            columns,
437            column_names,
438            table_def,
439            position,
440        ),
441
442        QueryPlan::TableScan {
443            table_id,
444            filter,
445            limit,
446            order,
447            columns,
448            column_names,
449            ..
450        } => execute_table_scan(
451            store,
452            *table_id,
453            filter,
454            limit,
455            order,
456            columns,
457            column_names,
458            table_def,
459            position,
460        ),
461
462        QueryPlan::Aggregate {
463            source,
464            group_by_cols,
465            aggregates,
466            column_names,
467            ..
468        } => execute_aggregate(
469            store,
470            source,
471            group_by_cols,
472            aggregates,
473            column_names,
474            table_def,
475            position,
476        ),
477    }
478}
479
480/// Executes a query plan against the current store state.
481pub fn execute<S: ProjectionStore>(
482    store: &mut S,
483    plan: &QueryPlan,
484    table_def: &TableDef,
485) -> Result<QueryResult> {
486    execute_internal(store, plan, table_def, None)
487}
488
489/// Executes a query plan at a specific log position (point-in-time query).
490pub fn execute_at<S: ProjectionStore>(
491    store: &mut S,
492    plan: &QueryPlan,
493    table_def: &TableDef,
494    position: Offset,
495) -> Result<QueryResult> {
496    execute_internal(store, plan, table_def, Some(position))
497}
498
499/// Converts bounds to a range.
500///
501/// The store scan uses a half-open range [start, end), so we need to:
502/// - For Included start: use the key as-is
503/// - For Excluded start: use the successor key (to skip the excluded value)
504/// - For Included end: use successor key (to include the value)
505/// - For Excluded end: use the key as-is
506fn bounds_to_range(start: &Bound<Key>, end: &Bound<Key>) -> (Key, Key) {
507    let start_key = match start {
508        Bound::Included(k) => k.clone(),
509        Bound::Excluded(k) => successor_key(k),
510        Bound::Unbounded => Key::min(),
511    };
512
513    let end_key = match end {
514        Bound::Included(k) => successor_key(k),
515        Bound::Excluded(k) => k.clone(),
516        Bound::Unbounded => Key::max(),
517    };
518
519    (start_key, end_key)
520}
521
522/// Extracts the primary key from a composite index key.
523///
524/// Index keys are structured as: [`index_column_values`...][primary_key_values...]
525/// This function strips the index column values and returns only the primary key portion.
526///
527/// # Assertions
528/// - Index key must be longer than the number of index columns
529/// - Primary key columns must be non-empty
530fn extract_pk_from_index_key(index_key: &Key, table_def: &TableDef) -> Key {
531    use crate::key_encoder::{decode_key, encode_key};
532
533    // Decode the full composite key to get all values
534    let all_values = decode_key(index_key);
535
536    // Get the number of primary key columns
537    let pk_count = table_def.primary_key.len();
538
539    // Assertions
540    debug_assert!(pk_count > 0, "primary key columns must be non-empty");
541    debug_assert!(
542        all_values.len() >= pk_count,
543        "index key must contain at least the primary key values"
544    );
545
546    // Extract the last pk_count values (the primary key)
547    // Index key format: [index_col1, index_col2, ..., pk_col1, pk_col2, ...]
548    let pk_values: Vec<Value> = all_values
549        .iter()
550        .skip(all_values.len() - pk_count)
551        .cloned()
552        .collect();
553
554    debug_assert_eq!(
555        pk_values.len(),
556        pk_count,
557        "extracted primary key must have correct number of columns"
558    );
559
560    // Re-encode as a key
561    encode_key(&pk_values)
562}
563
564/// Decodes a JSON row and projects columns.
565fn decode_and_project(bytes: &Bytes, columns: &[usize], table_def: &TableDef) -> Result<Row> {
566    let full_row = decode_row(bytes, table_def)?;
567    Ok(project_row(&full_row, columns))
568}
569
570/// Decodes a JSON row to values.
571fn decode_row(bytes: &Bytes, table_def: &TableDef) -> Result<Row> {
572    let json: serde_json::Value = serde_json::from_slice(bytes)?;
573
574    let obj = json.as_object().ok_or_else(|| QueryError::TypeMismatch {
575        expected: "object".to_string(),
576        actual: format!("{json:?}"),
577    })?;
578
579    let mut row = Vec::with_capacity(table_def.columns.len());
580
581    for col_def in &table_def.columns {
582        let col_name = col_def.name.as_str();
583        let json_val = obj.get(col_name).unwrap_or(&serde_json::Value::Null);
584        let value = Value::from_json(json_val, col_def.data_type)?;
585        row.push(value);
586    }
587
588    Ok(row)
589}
590
591/// Projects a row to selected columns.
592fn project_row(full_row: &[Value], columns: &[usize]) -> Row {
593    // Precondition: column indices must be valid
594    debug_assert!(
595        columns.iter().all(|&idx| idx < full_row.len()),
596        "column index out of bounds: columns={:?}, row_len={}",
597        columns,
598        full_row.len()
599    );
600
601    if columns.is_empty() {
602        // Empty columns means all columns
603        return full_row.to_vec();
604    }
605
606    let projected: Vec<Value> = columns
607        .iter()
608        .map(|&idx| {
609            full_row.get(idx).cloned().unwrap_or_else(|| {
610                // This should never happen due to precondition
611                panic!(
612                    "column index {} out of bounds (row len {})",
613                    idx,
614                    full_row.len()
615                );
616            })
617        })
618        .collect();
619
620    // Postcondition: result has correct length
621    debug_assert_eq!(
622        projected.len(),
623        columns.len(),
624        "projected row length mismatch"
625    );
626
627    projected
628}
629
630/// Sorts rows according to the sort specification.
631fn sort_rows(rows: &mut [Row], spec: &SortSpec) {
632    rows.sort_by(|a, b| {
633        for (col_idx, order) in &spec.columns {
634            let a_val = a.get(*col_idx);
635            let b_val = b.get(*col_idx);
636
637            let cmp = match (a_val, b_val) {
638                (Some(av), Some(bv)) => av.compare(bv).unwrap_or(Ordering::Equal),
639                (None, None) => Ordering::Equal,
640                (None, Some(_)) => Ordering::Less,
641                (Some(_), None) => Ordering::Greater,
642            };
643
644            if cmp != Ordering::Equal {
645                return match order {
646                    ScanOrder::Ascending => cmp,
647                    ScanOrder::Descending => cmp.reverse(),
648                };
649            }
650        }
651        Ordering::Equal
652    });
653}
654
655/// Executes an aggregate query with optional grouping.
656fn execute_aggregate<S: ProjectionStore>(
657    store: &mut S,
658    source: &QueryPlan,
659    group_by_cols: &[usize],
660    aggregates: &[crate::parser::AggregateFunction],
661    column_names: &[ColumnName],
662    table_def: &TableDef,
663    position: Option<Offset>,
664) -> Result<QueryResult> {
665    use std::collections::HashMap;
666
667    // Execute source plan to get all rows
668    let source_result = execute_internal(store, source, table_def, position)?;
669
670    // Build aggregate state grouped by key
671    let mut groups: HashMap<Vec<Value>, AggregateState> = HashMap::new();
672
673    for row in source_result.rows {
674        // Extract group key (values from GROUP BY columns)
675        let group_key: Vec<Value> = if group_by_cols.is_empty() {
676            // No GROUP BY - all rows in one group
677            vec![]
678        } else {
679            group_by_cols
680                .iter()
681                .map(|&idx| row.get(idx).cloned().unwrap_or(Value::Null))
682                .collect()
683        };
684
685        // Update aggregates for this group
686        let state = groups.entry(group_key).or_insert_with(AggregateState::new);
687        state.update(&row, aggregates, table_def)?;
688    }
689
690    // Convert groups to result rows
691    let mut result_rows = Vec::new();
692    for (group_key, state) in groups {
693        let mut result_row = group_key; // Start with GROUP BY columns
694        result_row.extend(state.finalize(aggregates)); // Add aggregate results
695        result_rows.push(result_row);
696    }
697
698    // If no groups and no GROUP BY, return one row with global aggregates
699    if result_rows.is_empty() && group_by_cols.is_empty() {
700        let state = AggregateState::new();
701        let agg_values = state.finalize(aggregates);
702        result_rows.push(agg_values);
703    }
704
705    Ok(QueryResult {
706        columns: column_names.to_vec(),
707        rows: result_rows,
708    })
709}
710
711/// State for computing aggregates over a group of rows.
712#[derive(Debug, Clone)]
713struct AggregateState {
714    count: i64,
715    non_null_counts: Vec<i64>, // For COUNT(col) - tracks non-NULL values per aggregate
716    sums: Vec<Option<Value>>,
717    mins: Vec<Option<Value>>,
718    maxs: Vec<Option<Value>>,
719}
720
721impl AggregateState {
722    fn new() -> Self {
723        Self {
724            count: 0,
725            non_null_counts: Vec::new(),
726            sums: Vec::new(),
727            mins: Vec::new(),
728            maxs: Vec::new(),
729        }
730    }
731
732    fn update(
733        &mut self,
734        row: &[Value],
735        aggregates: &[crate::parser::AggregateFunction],
736        table_def: &TableDef,
737    ) -> Result<()> {
738        // Precondition: row must have at least one column
739        debug_assert!(!row.is_empty(), "row must have at least one column");
740
741        // Precondition: enforce maximum aggregates limit to prevent DoS
742        // Note: aggregates can be empty for DISTINCT queries (deduplication only)
743        assert!(
744            aggregates.len() <= MAX_AGGREGATES_PER_QUERY,
745            "too many aggregates ({} > {})",
746            aggregates.len(),
747            MAX_AGGREGATES_PER_QUERY
748        );
749
750        self.count += 1;
751
752        // Ensure vectors are sized
753        while self.sums.len() < aggregates.len() {
754            self.non_null_counts.push(0);
755            self.sums.push(None);
756            self.mins.push(None);
757            self.maxs.push(None);
758        }
759
760        // Invariant: all vectors must be same length after sizing
761        debug_assert_eq!(
762            self.sums.len(),
763            self.non_null_counts.len(),
764            "aggregate state vectors out of sync"
765        );
766        debug_assert_eq!(self.sums.len(), self.mins.len());
767        debug_assert_eq!(self.sums.len(), self.maxs.len());
768
769        for (i, agg) in aggregates.iter().enumerate() {
770            match agg {
771                crate::parser::AggregateFunction::CountStar => {
772                    // Already counted above
773                }
774                crate::parser::AggregateFunction::Count(col) => {
775                    // COUNT(col) counts non-NULL values
776                    let col_idx = table_def.find_column(col).map_or(0, |(idx, _)| idx);
777                    if let Some(val) = row.get(col_idx) {
778                        if !val.is_null() {
779                            self.non_null_counts[i] += 1;
780                        }
781                    }
782                }
783                crate::parser::AggregateFunction::Sum(col) => {
784                    let col_idx = table_def.find_column(col).map_or(0, |(idx, _)| idx);
785                    if let Some(val) = row.get(col_idx) {
786                        if !val.is_null() {
787                            self.sums[i] = Some(add_values(&self.sums[i], val)?);
788                        }
789                    }
790                }
791                crate::parser::AggregateFunction::Avg(col) => {
792                    // AVG = SUM / COUNT - compute sum here
793                    let col_idx = table_def.find_column(col).map_or(0, |(idx, _)| idx);
794                    if let Some(val) = row.get(col_idx) {
795                        if !val.is_null() {
796                            self.sums[i] = Some(add_values(&self.sums[i], val)?);
797                        }
798                    }
799                }
800                crate::parser::AggregateFunction::Min(col) => {
801                    let col_idx = table_def.find_column(col).map_or(0, |(idx, _)| idx);
802                    if let Some(val) = row.get(col_idx) {
803                        if !val.is_null() {
804                            self.mins[i] = Some(min_value(&self.mins[i], val));
805                        }
806                    }
807                }
808                crate::parser::AggregateFunction::Max(col) => {
809                    let col_idx = table_def.find_column(col).map_or(0, |(idx, _)| idx);
810                    if let Some(val) = row.get(col_idx) {
811                        if !val.is_null() {
812                            self.maxs[i] = Some(max_value(&self.maxs[i], val));
813                        }
814                    }
815                }
816            }
817        }
818
819        // Postcondition: state must match aggregate count after update
820        debug_assert_eq!(
821            self.sums.len(),
822            aggregates.len(),
823            "aggregate state must match aggregate count after update"
824        );
825
826        Ok(())
827    }
828
829    fn finalize(&self, aggregates: &[crate::parser::AggregateFunction]) -> Vec<Value> {
830        let mut result = Vec::new();
831
832        for (i, agg) in aggregates.iter().enumerate() {
833            let value = match agg {
834                crate::parser::AggregateFunction::CountStar => Value::BigInt(self.count),
835                crate::parser::AggregateFunction::Count(_) => {
836                    // Use non-NULL count for COUNT(col)
837                    Value::BigInt(self.non_null_counts.get(i).copied().unwrap_or(0))
838                }
839                crate::parser::AggregateFunction::Sum(_) => self
840                    .sums
841                    .get(i)
842                    .and_then(std::clone::Clone::clone)
843                    .unwrap_or(Value::Null),
844                crate::parser::AggregateFunction::Avg(_) => {
845                    // AVG = SUM / COUNT
846                    if self.count == 0 {
847                        Value::Null
848                    } else {
849                        match self.sums.get(i).and_then(|v| v.as_ref()) {
850                            Some(sum) => divide_value(sum, self.count).unwrap_or(Value::Null),
851                            None => Value::Null,
852                        }
853                    }
854                }
855                crate::parser::AggregateFunction::Min(_) => self
856                    .mins
857                    .get(i)
858                    .and_then(std::clone::Clone::clone)
859                    .unwrap_or(Value::Null),
860                crate::parser::AggregateFunction::Max(_) => self
861                    .maxs
862                    .get(i)
863                    .and_then(std::clone::Clone::clone)
864                    .unwrap_or(Value::Null),
865            };
866            result.push(value);
867        }
868
869        result
870    }
871}
872
873/// Adds two values for SUM aggregates.
874fn add_values(a: &Option<Value>, b: &Value) -> Result<Value> {
875    match a {
876        None => Ok(b.clone()),
877        Some(a_val) => match (a_val, b) {
878            (Value::BigInt(x), Value::BigInt(y)) => Ok(Value::BigInt(x + y)),
879            (Value::Integer(x), Value::Integer(y)) => Ok(Value::Integer(x + y)),
880            (Value::SmallInt(x), Value::SmallInt(y)) => Ok(Value::SmallInt(x + y)),
881            (Value::TinyInt(x), Value::TinyInt(y)) => Ok(Value::TinyInt(x + y)),
882            (Value::Real(x), Value::Real(y)) => Ok(Value::Real(x + y)),
883            (Value::Decimal(x, sx), Value::Decimal(y, sy)) if sx == sy => {
884                Ok(Value::Decimal(x + y, *sx))
885            }
886            _ => Err(QueryError::TypeMismatch {
887                expected: format!("{a_val:?}"),
888                actual: format!("{b:?}"),
889            }),
890        },
891    }
892}
893
894/// Returns the minimum of two values.
895fn min_value(a: &Option<Value>, b: &Value) -> Value {
896    match a {
897        None => b.clone(),
898        Some(a_val) => {
899            if let Some(ord) = a_val.compare(b) {
900                if ord == Ordering::Less {
901                    a_val.clone()
902                } else {
903                    b.clone()
904                }
905            } else {
906                a_val.clone() // Incomparable types, keep current
907            }
908        }
909    }
910}
911
912/// Returns the maximum of two values.
913fn max_value(a: &Option<Value>, b: &Value) -> Value {
914    match a {
915        None => b.clone(),
916        Some(a_val) => {
917            if let Some(ord) = a_val.compare(b) {
918                if ord == Ordering::Greater {
919                    a_val.clone()
920                } else {
921                    b.clone()
922                }
923            } else {
924                a_val.clone() // Incomparable types, keep current
925            }
926        }
927    }
928}
929
930/// Divides a value by a count for AVG aggregates.
931#[allow(clippy::cast_precision_loss)]
932fn divide_value(val: &Value, count: i64) -> Option<Value> {
933    match val {
934        Value::BigInt(x) => Some(Value::Real(*x as f64 / count as f64)),
935        Value::Integer(x) => Some(Value::Real(f64::from(*x) / count as f64)),
936        Value::SmallInt(x) => Some(Value::Real(f64::from(*x) / count as f64)),
937        Value::TinyInt(x) => Some(Value::Real(f64::from(*x) / count as f64)),
938        Value::Real(x) => Some(Value::Real(x / count as f64)),
939        Value::Decimal(x, scale) => {
940            // Convert to float for division
941            let divisor = 10_i128.pow(u32::from(*scale));
942            let float_val = *x as f64 / divisor as f64;
943            Some(Value::Real(float_val / count as f64))
944        }
945        _ => None,
946    }
947}
948
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plan::{Filter, FilterCondition, FilterOp};

    /// Collects the first value of every row, for order assertions.
    fn first_column(rows: &[Row]) -> Vec<Value> {
        rows.iter().map(|row| row[0].clone()).collect()
    }

    #[test]
    fn test_project_row() {
        let row = vec![
            Value::BigInt(1),
            Value::Text("alice".to_string()),
            Value::BigInt(30),
        ];

        assert_eq!(
            project_row(&row, &[0, 2]),
            vec![Value::BigInt(1), Value::BigInt(30)]
        );
    }

    #[test]
    fn test_project_row_all() {
        let row = vec![Value::BigInt(1), Value::Text("bob".to_string())];
        assert_eq!(project_row(&row, &[]), row);
    }

    #[test]
    fn test_filter_matches() {
        let row = vec![Value::BigInt(42), Value::Text("alice".to_string())];
        let first_column_equals = |value| {
            Filter::single(FilterCondition {
                column_idx: 0,
                op: FilterOp::Eq,
                value,
            })
        };

        assert!(first_column_equals(Value::BigInt(42)).matches(&row));
        assert!(!first_column_equals(Value::BigInt(99)).matches(&row));
    }

    #[test]
    fn test_sort_rows() {
        let mut rows = vec![
            vec![Value::BigInt(3), Value::Text("c".to_string())],
            vec![Value::BigInt(1), Value::Text("a".to_string())],
            vec![Value::BigInt(2), Value::Text("b".to_string())],
        ];
        let spec = SortSpec {
            columns: vec![(0, ScanOrder::Ascending)],
        };

        sort_rows(&mut rows, &spec);

        assert_eq!(
            first_column(&rows),
            vec![Value::BigInt(1), Value::BigInt(2), Value::BigInt(3)]
        );
    }

    #[test]
    fn test_sort_rows_descending() {
        let mut rows = vec![
            vec![Value::BigInt(1)],
            vec![Value::BigInt(3)],
            vec![Value::BigInt(2)],
        ];
        let spec = SortSpec {
            columns: vec![(0, ScanOrder::Descending)],
        };

        sort_rows(&mut rows, &spec);

        assert_eq!(
            first_column(&rows),
            vec![Value::BigInt(3), Value::BigInt(2), Value::BigInt(1)]
        );
    }
}