rivet/tuning/memory.rs
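//! Schema-based memory estimation.
//!
//! Pure functions that convert an Arrow schema into:
//! - a per-row byte estimate (`estimate_row_bytes`)
//! - a `batch_size` row count from a target memory budget in MB
//!   (`compute_batch_size_from_memory`)
//!
//! No DB connection is required; they are used during plan resolution and as a
//! fallback when a fetch loop hasn't observed real row sizes yet.
//!
//! The module also holds the catalog-width RSS model (`per_worker_rss_mb`,
//! `estimate_peak_rss_mb`) and the shared `DEFAULT_MEM_BUDGET_MB` budget used by
//! `init` (scaffold) and `preflight::check`.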
9
10use arrow::datatypes::{DataType, SchemaRef};
11
12/// Estimate average row size in bytes from an Arrow schema.
13pub fn estimate_row_bytes(schema: &SchemaRef) -> usize {
14 const STRING_ESTIMATE: usize = 256;
15 let mut total: usize = 0;
16 for field in schema.fields() {
17 total += match field.data_type() {
18 DataType::Boolean | DataType::Int8 | DataType::UInt8 => 1,
19 DataType::Int16 | DataType::UInt16 => 2,
20 DataType::Int32 | DataType::UInt32 | DataType::Float32 | DataType::Date32 => 4,
21 DataType::Int64
22 | DataType::UInt64
23 | DataType::Float64
24 | DataType::Date64
25 | DataType::Timestamp(_, _)
26 | DataType::Time64(_)
27 | DataType::Duration(_) => 8,
28 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => 16,
29 DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
30 STRING_ESTIMATE
31 }
32 _ => 64,
33 };
34 total += 1; // validity bitmap overhead (rounded up)
35 }
36 total.max(1)
37}
38
39/// Compute batch_size from a memory target in MB and estimated row size.
40///
41/// The 150k upper clamp bounds the *raw-row accumulator* an engine holds
42/// alongside the Arrow batch (mysql/tiberius `Vec<Row>`): for narrow rows that
43/// raw buffer is several× the compact Arrow form, so it — not the MB target —
44/// drives peak RSS. 150k still gives ~15× fewer pipeline flushes than the old
45/// static 10k (most of the throughput win) at a fraction of the peak RSS that
46/// a 500k cap incurred on narrow tables.
47pub fn compute_batch_size_from_memory(memory_mb: usize, schema: &SchemaRef) -> usize {
48 let row_bytes = estimate_row_bytes(schema);
49 let target = memory_mb * 1024 * 1024 / row_bytes;
50 target.clamp(1_000, 150_000)
51}
52
/// Shared default RSS budget, in MB, used in two places so they can never
/// disagree.
///
/// - The scaffold sizes `parallel:` to stay under it.
/// - `preflight::check` compares an unindexed large-table scan against it: over
///   budget is UNSAFE, within budget is merely DEGRADED.
pub(crate) const DEFAULT_MEM_BUDGET_MB: u64 = 2 * 1024;
58
59/// Per-worker peak RSS (MB) under the default *adaptive* batching, fitted to the
60/// sweep in `docs/bench/reports/REPORT_full_vs_parallel.md`. Anchored on measured
61/// points — ~19 MB/worker at ~40 B/row (narrow), ~105 MB at ~4 KB/row (wide) —
62/// and clamped to a ceiling of ≈ 2× the adaptive batch target. The driver is
63/// **row width × in-flight batch, not chunk_size** (chunk_size only sets file
64/// count). An explicit large `tuning.batch_size` overrides adaptive batching and
65/// raises this beyond the model.
66///
67/// This is the *catalog-width* estimate (pre-schema `avg_row_bytes`), shared by
68/// `init` (scaffold) and `preflight::check`. Its schema-resolved sibling is
69/// [`compute_batch_size_from_memory`] above; the ceiling is single-sourced from
70/// [`crate::source::batch_controller::DEFAULT_BATCH_TARGET_MB`] so the two can't
71/// silently drift on the batch target. The slope/floor remain empirical (the
72/// bench sweep, not derivable from the schema-bytes model).
73pub(crate) fn per_worker_rss_mb(avg_row_bytes: i64) -> u64 {
74 const FLOOR_MB: u64 = 18;
75 // ~2× the adaptive batch target (Arrow builders + parquet row-group + zstd
76 // hold roughly twice the raw in-flight batch). Linked to the source of truth
77 // so it tracks the batch target instead of drifting.
78 const CEIL_MB: u64 = 2 * crate::source::batch_controller::DEFAULT_BATCH_TARGET_MB as u64;
79 let b = avg_row_bytes.max(0) as u64;
80 (FLOOR_MB + b * 87 / 4096).clamp(FLOOR_MB, CEIL_MB)
81}
82
83/// Predicted peak process RSS (MB) for a chunked export with `parallel` workers.
84/// `peak ≈ 16 (process base) + parallel × per_worker_rss_mb(width)`. Linear in
85/// `parallel`; slightly *over*-estimates past ~4 workers (allocator reuse) — the
86/// safe direction for a budget. Validated against the sweep (par 4 wide: est 436
87/// vs measured 444 MB; par 8 narrow: est 166 vs 169 MB).
88pub(crate) fn estimate_peak_rss_mb(parallel: usize, avg_row_bytes: i64) -> u64 {
89 const PROCESS_BASE_MB: u64 = 16;
90 PROCESS_BASE_MB + parallel as u64 * per_worker_rss_mb(avg_row_bytes)
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;

    /// Wrap a list of fields into the `SchemaRef` the estimators take.
    fn schema_of(fields: Vec<Field>) -> SchemaRef {
        Arc::new(Schema::new(fields))
    }

    #[test]
    fn estimate_row_bytes_basic() {
        let schema = schema_of(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]);
        // Int64 value (8) + Utf8 guess (256), each plus one validity byte.
        let expected = (8 + 1) + (256 + 1);
        assert_eq!(estimate_row_bytes(&schema), expected);
    }

    #[test]
    fn compute_batch_size_clamped() {
        // One Boolean column is 2 B/row; 256 MiB of those is far above the cap,
        // so the result pins to the 150_000 upper clamp.
        let narrow = schema_of(vec![Field::new("flag", DataType::Boolean, false)]);
        assert_eq!(compute_batch_size_from_memory(256, &narrow), 150_000);

        // 100 string columns is ~25.7 KB/row; 1 MiB buys only ~40 rows, so the
        // result pins to the 1_000 lower clamp.
        let wide = schema_of(
            (0..100)
                .map(|i| Field::new(format!("c{i}"), DataType::Utf8, true))
                .collect(),
        );
        assert_eq!(compute_batch_size_from_memory(1, &wide), 1_000);
    }
}