1use arrow::datatypes::{DataType, SchemaRef};
11
12pub fn estimate_row_bytes(schema: &SchemaRef) -> usize {
14 const STRING_ESTIMATE: usize = 256;
15 let mut total: usize = 0;
16 for field in schema.fields() {
17 total += match field.data_type() {
18 DataType::Boolean | DataType::Int8 | DataType::UInt8 => 1,
19 DataType::Int16 | DataType::UInt16 => 2,
20 DataType::Int32 | DataType::UInt32 | DataType::Float32 | DataType::Date32 => 4,
21 DataType::Int64
22 | DataType::UInt64
23 | DataType::Float64
24 | DataType::Date64
25 | DataType::Timestamp(_, _)
26 | DataType::Time64(_)
27 | DataType::Duration(_) => 8,
28 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => 16,
29 DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
30 STRING_ESTIMATE
31 }
32 _ => 64,
33 };
34 total += 1; }
36 total.max(1)
37}
38
39pub fn compute_batch_size_from_memory(memory_mb: usize, schema: &SchemaRef) -> usize {
48 let row_bytes = estimate_row_bytes(schema);
49 let target = memory_mb * 1024 * 1024 / row_bytes;
50 target.clamp(1_000, 150_000)
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;

    /// A two-column schema (one Int64, one Utf8) is expected to estimate to
    /// exactly 266 bytes per row.
    #[test]
    fn estimate_row_bytes_basic() {
        let columns = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ];
        let schema = Arc::new(Schema::new(columns));

        assert_eq!(estimate_row_bytes(&schema), 266);
    }

    /// Exercises both ends of the clamp: a very narrow schema with a large
    /// budget hits the upper bound, a very wide schema with a small budget
    /// hits the lower bound.
    #[test]
    fn compute_batch_size_clamped() {
        // Single boolean column, 256 MiB budget: expect the upper bound.
        let narrow_columns = vec![Field::new("flag", DataType::Boolean, false)];
        let narrow = Arc::new(Schema::new(narrow_columns));
        assert_eq!(compute_batch_size_from_memory(256, &narrow), 150_000);

        // One hundred string columns, 1 MiB budget: expect the lower bound.
        let mut wide_columns: Vec<Field> = Vec::with_capacity(100);
        for i in 0..100 {
            wide_columns.push(Field::new(format!("c{i}"), DataType::Utf8, true));
        }
        let wide = Arc::new(Schema::new(wide_columns));
        assert_eq!(compute_batch_size_from_memory(1, &wide), 1_000);
    }
}