Skip to main content

rivet_cli/
tuning.rs

1use arrow::datatypes::{DataType, SchemaRef};
2use serde::Deserialize;
3
4#[derive(Debug, Clone)]
5pub struct SourceTuning {
6    pub batch_size: usize,
7    pub batch_size_memory_mb: Option<usize>,
8    pub throttle_ms: u64,
9    pub statement_timeout_s: u64,
10    pub max_retries: u32,
11    pub retry_backoff_ms: u64,
12    pub lock_timeout_s: u64,
13    pub memory_threshold_mb: usize,
14    configured_profile: TuningProfile,
15}
16
17#[derive(Debug, Deserialize, Clone, Copy, PartialEq, Eq)]
18#[serde(rename_all = "lowercase")]
19pub enum TuningProfile {
20    Fast,
21    Balanced,
22    Safe,
23}
24
25#[derive(Debug, Deserialize, Default, Clone)]
26pub struct TuningConfig {
27    pub profile: Option<TuningProfile>,
28    pub batch_size: Option<usize>,
29    /// Target memory per batch in MB. Mutually exclusive with batch_size.
30    pub batch_size_memory_mb: Option<usize>,
31    pub throttle_ms: Option<u64>,
32    pub statement_timeout_s: Option<u64>,
33    pub max_retries: Option<u32>,
34    pub retry_backoff_ms: Option<u64>,
35    pub lock_timeout_s: Option<u64>,
36    pub memory_threshold_mb: Option<usize>,
37}
38
39/// Layer `export` on top of `source`: each field uses export when set, otherwise source.
40/// `None` only when both inputs are `None`.
41pub fn merge_tuning_config(
42    source: Option<&TuningConfig>,
43    export: Option<&TuningConfig>,
44) -> Option<TuningConfig> {
45    match (source, export) {
46        (None, None) => None,
47        (Some(s), None) => Some(s.clone()),
48        (None, Some(e)) => Some(e.clone()),
49        (Some(s), Some(e)) => Some(TuningConfig {
50            profile: e.profile.or(s.profile),
51            batch_size: e.batch_size.or(s.batch_size),
52            batch_size_memory_mb: e.batch_size_memory_mb.or(s.batch_size_memory_mb),
53            throttle_ms: e.throttle_ms.or(s.throttle_ms),
54            statement_timeout_s: e.statement_timeout_s.or(s.statement_timeout_s),
55            max_retries: e.max_retries.or(s.max_retries),
56            retry_backoff_ms: e.retry_backoff_ms.or(s.retry_backoff_ms),
57            lock_timeout_s: e.lock_timeout_s.or(s.lock_timeout_s),
58            memory_threshold_mb: e.memory_threshold_mb.or(s.memory_threshold_mb),
59        }),
60    }
61}
62
63impl SourceTuning {
64    pub fn from_config(config: Option<&TuningConfig>) -> Self {
65        let profile = config
66            .and_then(|c| c.profile)
67            .unwrap_or(TuningProfile::Balanced);
68
69        let mut tuning = Self::from_profile(profile);
70        tuning.configured_profile = profile;
71
72        if let Some(cfg) = config {
73            if let Some(v) = cfg.batch_size {
74                tuning.batch_size = v;
75            }
76            tuning.batch_size_memory_mb = cfg.batch_size_memory_mb;
77            if let Some(v) = cfg.throttle_ms {
78                tuning.throttle_ms = v;
79            }
80            if let Some(v) = cfg.statement_timeout_s {
81                tuning.statement_timeout_s = v;
82            }
83            if let Some(v) = cfg.max_retries {
84                tuning.max_retries = v;
85            }
86            if let Some(v) = cfg.retry_backoff_ms {
87                tuning.retry_backoff_ms = v;
88            }
89            if let Some(v) = cfg.lock_timeout_s {
90                tuning.lock_timeout_s = v;
91            }
92            if let Some(v) = cfg.memory_threshold_mb {
93                tuning.memory_threshold_mb = v;
94            }
95        }
96
97        tuning
98    }
99
100    fn from_profile(profile: TuningProfile) -> Self {
101        match profile {
102            TuningProfile::Fast => Self {
103                batch_size: 50_000,
104                batch_size_memory_mb: None,
105                throttle_ms: 0,
106                statement_timeout_s: 0,
107                max_retries: 1,
108                retry_backoff_ms: 1_000,
109                lock_timeout_s: 0,
110                memory_threshold_mb: 0,
111                configured_profile: TuningProfile::Fast,
112            },
113            TuningProfile::Balanced => Self {
114                batch_size: 10_000,
115                batch_size_memory_mb: None,
116                throttle_ms: 50,
117                statement_timeout_s: 300,
118                max_retries: 3,
119                retry_backoff_ms: 2_000,
120                lock_timeout_s: 30,
121                memory_threshold_mb: 0,
122                configured_profile: TuningProfile::Balanced,
123            },
124            TuningProfile::Safe => Self {
125                batch_size: 2_000,
126                batch_size_memory_mb: None,
127                throttle_ms: 500,
128                statement_timeout_s: 120,
129                max_retries: 5,
130                retry_backoff_ms: 5_000,
131                lock_timeout_s: 10,
132                memory_threshold_mb: 0,
133                configured_profile: TuningProfile::Safe,
134            },
135        }
136    }
137
138    pub fn profile_name(&self) -> &'static str {
139        match self.configured_profile {
140            TuningProfile::Fast => "fast",
141            TuningProfile::Balanced => "balanced",
142            TuningProfile::Safe => "safe",
143        }
144    }
145}
146
147impl std::fmt::Display for SourceTuning {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        write!(
150            f,
151            "profile={}, batch_size={}, throttle={}ms, timeout={}s, retries={}, lock_timeout={}s",
152            self.profile_name(),
153            self.batch_size,
154            self.throttle_ms,
155            self.statement_timeout_s,
156            self.max_retries,
157            self.lock_timeout_s,
158        )
159    }
160}
161
162/// Estimate average row size in bytes from an Arrow schema.
163pub fn estimate_row_bytes(schema: &SchemaRef) -> usize {
164    const STRING_ESTIMATE: usize = 256;
165    let mut total: usize = 0;
166    for field in schema.fields() {
167        total += match field.data_type() {
168            DataType::Boolean | DataType::Int8 | DataType::UInt8 => 1,
169            DataType::Int16 | DataType::UInt16 => 2,
170            DataType::Int32 | DataType::UInt32 | DataType::Float32 | DataType::Date32 => 4,
171            DataType::Int64
172            | DataType::UInt64
173            | DataType::Float64
174            | DataType::Date64
175            | DataType::Timestamp(_, _)
176            | DataType::Time64(_)
177            | DataType::Duration(_) => 8,
178            DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => 16,
179            DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
180                STRING_ESTIMATE
181            }
182            _ => 64,
183        };
184        total += 1; // validity bitmap overhead (rounded up)
185    }
186    total.max(1)
187}
188
189/// Compute batch_size from a memory target in MB and estimated row size.
190pub fn compute_batch_size_from_memory(memory_mb: usize, schema: &SchemaRef) -> usize {
191    let row_bytes = estimate_row_bytes(schema);
192    let target = memory_mb * 1024 * 1024 / row_bytes;
193    target.clamp(1_000, 500_000)
194}
195
196impl SourceTuning {
197    /// If `batch_size_memory_mb` is set, compute and return an adjusted batch_size
198    /// from the schema; otherwise return the configured `batch_size`.
199    pub fn effective_batch_size(&self, schema: Option<&SchemaRef>) -> usize {
200        if let (Some(mem_mb), Some(schema)) = (self.batch_size_memory_mb, schema) {
201            let computed = compute_batch_size_from_memory(mem_mb, schema);
202            log::info!(
203                "batch_size_memory_mb={}: estimated row ~{}B, computed batch_size={}",
204                mem_mb,
205                estimate_row_bytes(schema),
206                computed
207            );
208            computed
209        } else {
210            self.batch_size
211        }
212    }
213}
214
// Unit tests for profile resolution, config merging, Display output and
// memory-based batch sizing.
#[cfg(test)]
mod tests {
    use super::*;

    /// Config that sets only the profile, leaving every other field unset.
    fn cfg_with_profile(profile: TuningProfile) -> TuningConfig {
        TuningConfig {
            profile: Some(profile),
            ..Default::default()
        }
    }

    #[test]
    fn default_config_uses_balanced_profile() {
        let t = SourceTuning::from_config(None);
        assert_eq!(t.batch_size, 10_000);
        assert_eq!(t.throttle_ms, 50);
        assert_eq!(t.statement_timeout_s, 300);
        assert_eq!(t.max_retries, 3);
        assert_eq!(t.retry_backoff_ms, 2_000);
        assert_eq!(t.lock_timeout_s, 30);
    }

    #[test]
    fn fast_profile_favors_throughput() {
        let t = SourceTuning::from_config(Some(&cfg_with_profile(TuningProfile::Fast)));
        assert_eq!(t.batch_size, 50_000);
        assert_eq!(t.throttle_ms, 0);
        assert_eq!(t.statement_timeout_s, 0);
        assert_eq!(t.max_retries, 1);
    }

    #[test]
    fn safe_profile_limits_impact() {
        let t = SourceTuning::from_config(Some(&cfg_with_profile(TuningProfile::Safe)));
        assert_eq!(t.batch_size, 2_000);
        assert_eq!(t.throttle_ms, 500);
        assert_eq!(t.statement_timeout_s, 120);
        assert_eq!(t.max_retries, 5);
        assert_eq!(t.retry_backoff_ms, 5_000);
        assert_eq!(t.lock_timeout_s, 10);
    }

    // Explicit fields win; untouched fields keep the chosen profile's defaults.
    #[test]
    fn explicit_fields_override_profile_defaults() {
        let cfg = TuningConfig {
            profile: Some(TuningProfile::Safe),
            batch_size: Some(3_000),
            throttle_ms: Some(250),
            ..Default::default()
        };
        let t = SourceTuning::from_config(Some(&cfg));
        assert_eq!(t.batch_size, 3_000, "explicit batch_size should win");
        assert_eq!(t.throttle_ms, 250, "explicit throttle_ms should win");
        assert_eq!(
            t.statement_timeout_s, 120,
            "non-overridden field stays at safe default"
        );
        assert_eq!(
            t.max_retries, 5,
            "non-overridden field stays at safe default"
        );
    }

    #[test]
    fn profile_name_fast() {
        let t = SourceTuning::from_config(Some(&cfg_with_profile(TuningProfile::Fast)));
        assert_eq!(t.profile_name(), "fast");
    }

    #[test]
    fn profile_name_balanced() {
        let t = SourceTuning::from_config(None);
        assert_eq!(t.profile_name(), "balanced");
    }

    #[test]
    fn profile_name_safe() {
        let t = SourceTuning::from_config(Some(&cfg_with_profile(TuningProfile::Safe)));
        assert_eq!(t.profile_name(), "safe");
    }

    // Despite the name, this checks every field the Display impl renders. Display omits
    // batch_size_memory_mb, retry_backoff_ms and memory_threshold_mb.
    #[test]
    fn display_contains_all_fields() {
        let t = SourceTuning::from_config(None);
        let s = t.to_string();
        assert!(s.contains("profile=balanced"), "missing profile in: {s}");
        assert!(s.contains("batch_size=10000"), "missing batch_size in: {s}");
        assert!(s.contains("throttle=50ms"), "missing throttle in: {s}");
        assert!(s.contains("timeout=300s"), "missing timeout in: {s}");
        assert!(s.contains("retries=3"), "missing retries in: {s}");
        assert!(
            s.contains("lock_timeout=30s"),
            "missing lock_timeout in: {s}"
        );
    }

    #[test]
    fn estimate_row_bytes_basic() {
        use arrow::datatypes::{Field, Schema};
        use std::sync::Arc;
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", arrow::datatypes::DataType::Int64, false),
            Field::new("name", arrow::datatypes::DataType::Utf8, true),
        ]));
        let est = estimate_row_bytes(&schema);
        // Int64=8+1, Utf8=256+1 = 266
        assert_eq!(est, 266);
    }

    #[test]
    fn compute_batch_size_clamped() {
        use arrow::datatypes::{Field, Schema};
        use std::sync::Arc;
        // 1 tiny column -> huge batch, clamped to 500_000
        let schema = Arc::new(Schema::new(vec![Field::new(
            "flag",
            arrow::datatypes::DataType::Boolean,
            false,
        )]));
        assert_eq!(compute_batch_size_from_memory(256, &schema), 500_000);

        // 100 large string columns -> small batch, clamped to 1_000
        // (100 * 257B = 25_700B per row; 1 MiB / 25_700B ≈ 40 rows before clamping)
        let fields: Vec<Field> = (0..100)
            .map(|i| Field::new(format!("c{i}"), arrow::datatypes::DataType::Utf8, true))
            .collect();
        let schema = Arc::new(Schema::new(fields));
        assert_eq!(compute_batch_size_from_memory(1, &schema), 1_000);
    }

    // Export values win field by field; fields the export leaves unset fall back to the source.
    #[test]
    fn merge_tuning_export_overrides_source_fields() {
        let source = TuningConfig {
            profile: Some(TuningProfile::Fast),
            batch_size: Some(1_000),
            throttle_ms: Some(0),
            ..Default::default()
        };
        let export = TuningConfig {
            profile: Some(TuningProfile::Safe),
            batch_size: None,
            ..Default::default()
        };
        let m = merge_tuning_config(Some(&source), Some(&export)).expect("merged");
        assert_eq!(m.profile, Some(TuningProfile::Safe));
        assert_eq!(
            m.batch_size,
            Some(1_000),
            "export omitted batch_size -> keep source"
        );
        assert_eq!(m.throttle_ms, Some(0));
    }

    #[test]
    fn merge_tuning_export_only() {
        let e = cfg_with_profile(TuningProfile::Fast);
        let m = merge_tuning_config(None, Some(&e)).expect("merged");
        assert_eq!(m.profile, Some(TuningProfile::Fast));
    }

    // With no memory target configured, the fixed batch_size is used even without a schema.
    #[test]
    fn effective_batch_size_without_memory() {
        let t = SourceTuning::from_config(None);
        assert_eq!(t.effective_batch_size(None), 10_000);
    }

    #[test]
    fn effective_batch_size_with_memory() {
        use arrow::datatypes::{Field, Schema};
        use std::sync::Arc;
        let cfg = TuningConfig {
            batch_size_memory_mb: Some(256),
            ..Default::default()
        };
        let t = SourceTuning::from_config(Some(&cfg));
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", arrow::datatypes::DataType::Int64, false),
            Field::new("name", arrow::datatypes::DataType::Utf8, true),
        ]));
        let bs = t.effective_batch_size(Some(&schema));
        assert!((1_000..=500_000).contains(&bs), "got {bs}");
        // 256 MiB (268_435_456 B) / 266 B ≈ 1_009_155, clamped to 500_000
        assert_eq!(bs, 500_000);
    }
}
398}