Skip to main content

rivet/tuning/
memory.rs

1//! Schema-based memory estimation.
2//!
3//! Pure functions that convert an Arrow schema into:
4//! - a per-row byte estimate (`estimate_row_bytes`)
5//! - a `batch_size` count from a target memory budget in MB (`compute_batch_size_from_memory`)
6//!
7//! No DB connection required; used during plan resolution and as a fall-back
8//! when a fetch loop hasn't observed real row sizes yet.
9
10use arrow::datatypes::{DataType, SchemaRef};
11
12/// Estimate average row size in bytes from an Arrow schema.
13pub fn estimate_row_bytes(schema: &SchemaRef) -> usize {
14    const STRING_ESTIMATE: usize = 256;
15    let mut total: usize = 0;
16    for field in schema.fields() {
17        total += match field.data_type() {
18            DataType::Boolean | DataType::Int8 | DataType::UInt8 => 1,
19            DataType::Int16 | DataType::UInt16 => 2,
20            DataType::Int32 | DataType::UInt32 | DataType::Float32 | DataType::Date32 => 4,
21            DataType::Int64
22            | DataType::UInt64
23            | DataType::Float64
24            | DataType::Date64
25            | DataType::Timestamp(_, _)
26            | DataType::Time64(_)
27            | DataType::Duration(_) => 8,
28            DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => 16,
29            DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
30                STRING_ESTIMATE
31            }
32            _ => 64,
33        };
34        total += 1; // validity bitmap overhead (rounded up)
35    }
36    total.max(1)
37}
38
39/// Compute batch_size from a memory target in MB and estimated row size.
40pub fn compute_batch_size_from_memory(memory_mb: usize, schema: &SchemaRef) -> usize {
41    let row_bytes = estimate_row_bytes(schema);
42    let target = memory_mb * 1024 * 1024 / row_bytes;
43    target.clamp(1_000, 500_000)
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;

    /// Wrap a list of fields into the shared schema handle the estimators take.
    fn schema_of(fields: Vec<Field>) -> SchemaRef {
        Arc::new(Schema::new(fields))
    }

    #[test]
    fn estimate_row_bytes_basic() {
        let two_columns = schema_of(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]);

        // (8 + 1) for the Int64 column, (256 + 1) for the Utf8 column.
        let expected = (8 + 1) + (256 + 1);
        assert_eq!(estimate_row_bytes(&two_columns), expected);
    }

    #[test]
    fn compute_batch_size_clamped() {
        // Upper bound: a single narrow column would allow a huge batch,
        // which must be capped at 500_000 rows.
        let narrow = schema_of(vec![Field::new("flag", DataType::Boolean, false)]);
        let upper = compute_batch_size_from_memory(256, &narrow);
        assert_eq!(upper, 500_000);

        // Lower bound: 100 wide string columns under a 1 MB budget would
        // allow only a handful of rows, which must be raised to 1_000.
        let mut wide_fields: Vec<Field> = Vec::with_capacity(100);
        for i in 0..100 {
            wide_fields.push(Field::new(format!("c{i}"), DataType::Utf8, true));
        }
        let wide = schema_of(wide_fields);
        let lower = compute_batch_size_from_memory(1, &wide);
        assert_eq!(lower, 1_000);
    }
}