Skip to main content

krishiv_sql/
analyze.rs

//! ANALYZE TABLE — collect column statistics from record batches.
//!
//! Computes the statistics the CBO needs:
//! - `row_count`
//! - `null_count`
//! - `min_value` / `max_value` (stringified for cross-type safety)
//! - `distinct_count` — exact, or left unset once the [`EXACT_NDV_CAP`]
//!   limit is hit (no approximate estimate is produced)
//!
//! The driver calls [`analyze_batch`] over a single `RecordBatch` or
//! [`analyze_record_batches`] for an aggregate over many batches (e.g.
//! every file behind a table); both produce one table-level
//! [`ColumnStatistics`] record that combines all columns.
//! [`analyze_batch_per_column`] returns one record per column instead.
//! The results are ready to attach to
//! [`TableMetadata`][crate::catalog::TableMetadata] via
//! [`with_stats`][crate::catalog::TableMetadata::with_stats].

18use std::collections::HashSet;
19
20use arrow::array::Array;
21use arrow::datatypes::DataType;
22use arrow::record_batch::RecordBatch;
23
24use crate::catalog::ColumnStatistics;
25
/// Distinct-value limit for exact NDV (number of distinct values) counting.
///
/// NDV is counted exactly by collecting every stringified non-null value
/// into a `HashSet<String>`, so memory grows with the number of unique
/// values seen. Above this threshold the analyzers in this module stop
/// counting and leave `distinct_count` unset (`None`) rather than report
/// an inexact number; no approximate (sketch-based) estimate is produced.
/// The threshold is deliberately generous so typical small/medium tables
/// stay exact.
pub const EXACT_NDV_CAP: usize = 1_000_000;
34
35/// Compute column statistics from a single `RecordBatch`.
36///
37/// `row_count` and `null_count` are exact. `min_value` / `max_value` are
38/// computed by walking the column once; `distinct_count` uses a
39/// `HashSet` up to [`EXACT_NDV_CAP`] and falls back to `None` above the
40/// cap (callers should re-run via [`analyze_record_batches`] with a
41/// larger memory budget if they need approximate NDV).
42pub fn analyze_batch(batch: &RecordBatch) -> ColumnStatistics {
43    analyze_record_batches(std::iter::once(batch))
44}
45
46/// Compute column statistics from an iterator of `RecordBatch`es.
47///
48/// The result's `row_count` is the sum across batches. `min_value` and
49/// `max_value` are taken across the union; `null_count` is the sum;
50/// `distinct_count` is the union of distinct values observed across
51/// all batches, up to [`EXACT_NDV_CAP`].
52pub fn analyze_record_batches<'a, I>(batches: I) -> ColumnStatistics
53where
54    I: IntoIterator<Item = &'a RecordBatch>,
55{
56    let mut row_count: u64 = 0;
57    let mut null_count: u64 = 0;
58    let mut min_value: Option<String> = None;
59    let mut max_value: Option<String> = None;
60    let mut distinct: HashSet<String> = HashSet::new();
61    let mut hit_cap = false;
62
63    for batch in batches {
64        row_count = row_count.saturating_add(batch.num_rows() as u64);
65        // Combine all visible columns into a single stats record (one
66        // `ColumnStatistics` per table; per-column stats live in the
67        // catalog). For the table-level record we take the global
68        // min/max/null/dn across all columns. This matches what the
69        // CBO needs for a small table without per-column metadata.
70        for col_idx in 0..batch.num_columns() {
71            let array = batch.column(col_idx);
72            null_count = null_count.saturating_add(array.null_count() as u64);
73            if let Some((batch_min, batch_max)) = min_max_string(array) {
74                update_min(&mut min_value, batch_min);
75                update_max(&mut max_value, batch_max);
76            }
77            if !hit_cap {
78                for value in string_values(array) {
79                    if distinct.len() >= EXACT_NDV_CAP {
80                        hit_cap = true;
81                        distinct.clear();
82                        break;
83                    }
84                    distinct.insert(value);
85                }
86            }
87        }
88    }
89
90    let now_secs = std::time::SystemTime::now()
91        .duration_since(std::time::UNIX_EPOCH)
92        .map(|d| d.as_secs())
93        .unwrap_or(0);
94
95    let mut stats = ColumnStatistics::new()
96        .with_row_count(row_count)
97        .with_null_count(null_count)
98        .with_collected_at_secs(now_secs);
99    if let Some(m) = min_value {
100        stats = stats.with_min(m);
101    }
102    if let Some(m) = max_value {
103        stats = stats.with_max(m);
104    }
105    if !hit_cap {
106        stats = stats.with_distinct_count(distinct.len() as u64);
107    }
108    stats
109}
110
111/// Compute per-column statistics for every column in `batch`.
112///
113/// Returns a `Vec<ColumnStatistics>` aligned with `batch.schema()` — one
114/// entry per field. NDV is per-column, exact up to [`EXACT_NDV_CAP`].
115pub fn analyze_batch_per_column(batch: &RecordBatch) -> Vec<ColumnStatistics> {
116    let now_secs = std::time::SystemTime::now()
117        .duration_since(std::time::UNIX_EPOCH)
118        .map(|d| d.as_secs())
119        .unwrap_or(0);
120    (0..batch.num_columns())
121        .map(|col_idx| {
122            let array = batch.column(col_idx);
123            let mut stats = ColumnStatistics::new()
124                .with_row_count(batch.num_rows() as u64)
125                .with_null_count(array.null_count() as u64)
126                .with_collected_at_secs(now_secs);
127            if let Some((min, max)) = min_max_string(array) {
128                stats = stats.with_min(min).with_max(max);
129            }
130            if array.len() <= EXACT_NDV_CAP {
131                let distinct: HashSet<String> = string_values(array).collect();
132                stats = stats.with_distinct_count(distinct.len() as u64);
133            }
134            stats
135        })
136        .collect()
137}
138
// ── helpers ──────────────────────────────────────────────────────────────────
140
/// Lower `slot` to `candidate` when the slot is empty or `candidate` sorts
/// strictly before the current value (lexicographic `str` order); ties keep
/// the existing value.
fn update_min(slot: &mut Option<String>, candidate: String) {
    if let Some(current) = slot.as_deref() {
        if current <= candidate.as_str() {
            return;
        }
    }
    *slot = Some(candidate);
}
147
/// Raise `slot` to `candidate` when the slot is empty or `candidate` sorts
/// strictly after the current value (lexicographic `str` order); ties keep
/// the existing value.
fn update_max(slot: &mut Option<String>, candidate: String) {
    let already_at_least = matches!(
        slot.as_deref(),
        Some(current) if current >= candidate.as_str()
    );
    if !already_at_least {
        *slot = Some(candidate);
    }
}
154
155/// Return `(min_string, max_string)` over the visible (non-null) values
156/// of `array`, or `None` if the array is empty / all-null.
157fn min_max_string(array: &dyn Array) -> Option<(String, String)> {
158    let mut min_v: Option<String> = None;
159    let mut max_v: Option<String> = None;
160    for value in string_values(array) {
161        update_min(&mut min_v, value.clone());
162        update_max(&mut max_v, value);
163    }
164    match (min_v, max_v) {
165        (Some(lo), Some(hi)) => Some((lo, hi)),
166        _ => None,
167    }
168}
169
170/// Iterator over the stringified non-null values of `array`.
171fn string_values(array: &dyn Array) -> Box<dyn Iterator<Item = String> + '_> {
172    // Use a concrete path per Arrow DataType. The fall-through uses
173    // `Debug` so the table-level stats work for any column type.
174    let data_type = array.data_type().clone();
175    match data_type {
176        DataType::Int32 => Box::new((0..array.len()).filter_map(move |i| {
177            if array.is_null(i) {
178                None
179            } else {
180                let arr = array.as_any().downcast_ref::<arrow::array::Int32Array>()?;
181                Some(arr.value(i).to_string())
182            }
183        })),
184        DataType::Int64 => Box::new((0..array.len()).filter_map(move |i| {
185            if array.is_null(i) {
186                None
187            } else {
188                let arr = array.as_any().downcast_ref::<arrow::array::Int64Array>()?;
189                Some(arr.value(i).to_string())
190            }
191        })),
192        DataType::Float64 => Box::new((0..array.len()).filter_map(move |i| {
193            if array.is_null(i) {
194                None
195            } else {
196                let arr = array
197                    .as_any()
198                    .downcast_ref::<arrow::array::Float64Array>()?;
199                Some(format!("{}", arr.value(i)))
200            }
201        })),
202        DataType::Utf8 => Box::new((0..array.len()).filter_map(move |i| {
203            if array.is_null(i) {
204                None
205            } else {
206                let arr = array.as_any().downcast_ref::<arrow::array::StringArray>()?;
207                Some(arr.value(i).to_string())
208            }
209        })),
210        DataType::Boolean => Box::new((0..array.len()).filter_map(move |i| {
211            if array.is_null(i) {
212                None
213            } else {
214                let arr = array
215                    .as_any()
216                    .downcast_ref::<arrow::array::BooleanArray>()?;
217                Some(arr.value(i).to_string())
218            }
219        })),
220        _ => Box::new((0..array.len()).filter_map(move |i| {
221            if array.is_null(i) {
222                None
223            } else {
224                Some(format!("{:?}", array.slice(i, 1)))
225            }
226        })),
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;

    /// Build a one-column, nullable batch whose field type mirrors `column`.
    fn one_column(name: &str, column: ArrayRef) -> RecordBatch {
        let field = Field::new(name, column.data_type().clone(), true);
        let schema = Arc::new(Schema::new(vec![field]));
        RecordBatch::try_new(schema, vec![column]).expect("schema matches column")
    }

    /// Single nullable `Int32` column named `k`.
    fn int_batch(values: Vec<Option<i32>>) -> RecordBatch {
        one_column("k", Arc::new(Int32Array::from(values)))
    }

    /// Single nullable `Utf8` column named `name`.
    fn utf8_batch(values: Vec<Option<&str>>) -> RecordBatch {
        one_column("name", Arc::new(StringArray::from(values)))
    }

    #[test]
    fn analyze_batch_records_row_and_null_counts() {
        let stats = analyze_batch(&int_batch(vec![Some(1), None, Some(2), Some(3)]));
        assert_eq!(stats.row_count, Some(4));
        assert_eq!(stats.null_count, Some(1));
    }

    #[test]
    fn analyze_batch_records_min_and_max_stringified() {
        let stats = analyze_batch(&int_batch(vec![Some(3), Some(1), Some(2)]));
        let bounds = (stats.min_value.as_deref(), stats.max_value.as_deref());
        assert_eq!(bounds, (Some("1"), Some("3")));
    }

    #[test]
    fn analyze_batch_counts_distinct_values() {
        let stats = analyze_batch(&int_batch(vec![Some(1), Some(1), Some(2), Some(3)]));
        assert_eq!(stats.distinct_count, Some(3));
    }

    #[test]
    fn analyze_batch_handles_all_nulls() {
        let stats = analyze_batch(&int_batch(vec![None, None]));
        assert_eq!(stats.row_count, Some(2));
        assert_eq!(stats.null_count, Some(2));
        assert!(stats.min_value.is_none());
        assert_eq!(stats.distinct_count, Some(0));
    }

    #[test]
    fn analyze_batch_works_on_string_columns() {
        let stats = analyze_batch(&utf8_batch(vec![Some("b"), Some("a"), Some("a")]));
        assert_eq!(stats.row_count, Some(3));
        assert_eq!(stats.distinct_count, Some(2));
        let bounds = (stats.min_value.as_deref(), stats.max_value.as_deref());
        assert_eq!(bounds, (Some("a"), Some("b")));
    }

    #[test]
    fn analyze_record_batches_aggregates_across_batches() {
        let first = int_batch(vec![Some(1), Some(2)]);
        let second = int_batch(vec![Some(3), None, Some(2)]);
        let stats = analyze_record_batches(vec![&first, &second]);
        assert_eq!(stats.row_count, Some(5));
        assert_eq!(stats.null_count, Some(1));
        assert_eq!(stats.distinct_count, Some(3));
        let bounds = (stats.min_value.as_deref(), stats.max_value.as_deref());
        assert_eq!(bounds, (Some("1"), Some("3")));
    }

    #[test]
    fn analyze_batch_per_column_returns_one_entry_per_field() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("k", DataType::Int32, true),
            Field::new("v", DataType::Utf8, true),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![Some(1), Some(1), Some(2)])),
            Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")])),
        ];
        let batch = RecordBatch::try_new(schema, columns).expect("schema matches columns");
        let per_col = analyze_batch_per_column(&batch);
        assert_eq!(per_col.len(), 2);
        assert_eq!(per_col[0].row_count, Some(3));
        let ndvs: Vec<_> = per_col.iter().map(|s| s.distinct_count).collect();
        assert_eq!(ndvs, vec![Some(2), Some(2)]);
    }

    #[test]
    fn column_statistics_equality_selectivity_uses_ndv() {
        let sel = ColumnStatistics::new()
            .with_distinct_count(10)
            .equality_selectivity()
            .expect("NDV is set");
        assert!((sel - 0.1).abs() < 1e-9);
    }

    #[test]
    fn column_statistics_equality_selectivity_handles_zero_ndv() {
        let selectivity = ColumnStatistics::new()
            .with_distinct_count(0)
            .equality_selectivity();
        assert_eq!(selectivity, Some(0.0));
    }

    #[test]
    fn column_statistics_equality_selectivity_returns_none_without_ndv() {
        assert_eq!(ColumnStatistics::new().equality_selectivity(), None);
    }

    #[test]
    fn column_statistics_freshness_with_no_timestamp_is_fresh() {
        assert!(ColumnStatistics::new().is_fresh(1_000, 60));
    }

    #[test]
    fn column_statistics_freshness_detects_stale_stats() {
        let stats = ColumnStatistics::new().with_collected_at_secs(100);
        // Collected at 100 and checked at 200, so the stats are 100 s old.
        assert!(!stats.is_fresh(200, 60), "100 s old exceeds a 60 s budget");
        assert!(stats.is_fresh(200, 200), "100 s old is within a 200 s budget");
    }
}