1use arrow::datatypes::{DataType, SchemaRef};
11
12pub fn estimate_row_bytes(schema: &SchemaRef) -> usize {
14 const STRING_ESTIMATE: usize = 256;
15 let mut total: usize = 0;
16 for field in schema.fields() {
17 total += match field.data_type() {
18 DataType::Boolean | DataType::Int8 | DataType::UInt8 => 1,
19 DataType::Int16 | DataType::UInt16 => 2,
20 DataType::Int32 | DataType::UInt32 | DataType::Float32 | DataType::Date32 => 4,
21 DataType::Int64
22 | DataType::UInt64
23 | DataType::Float64
24 | DataType::Date64
25 | DataType::Timestamp(_, _)
26 | DataType::Time64(_)
27 | DataType::Duration(_) => 8,
28 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => 16,
29 DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
30 STRING_ESTIMATE
31 }
32 _ => 64,
33 };
34 total += 1; }
36 total.max(1)
37}
38
39pub fn compute_batch_size_from_memory(memory_mb: usize, schema: &SchemaRef) -> usize {
41 let row_bytes = estimate_row_bytes(schema);
42 let target = memory_mb * 1024 * 1024 / row_bytes;
43 target.clamp(1_000, 500_000)
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;

    /// An Int64 column (8 + 1) plus a Utf8 column (256 + 1) totals 266 bytes per row.
    #[test]
    fn estimate_row_bytes_basic() {
        let columns = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ];
        let schema = Arc::new(Schema::new(columns));

        assert_eq!(estimate_row_bytes(&schema), 266);
    }

    /// The computed batch size is pinned to both ends of the allowed range.
    #[test]
    fn compute_batch_size_clamped() {
        // A very narrow row with a generous budget hits the upper bound.
        let narrow = Arc::new(Schema::new(vec![Field::new(
            "flag",
            DataType::Boolean,
            false,
        )]));
        assert_eq!(compute_batch_size_from_memory(256, &narrow), 500_000);

        // A very wide row with a tiny budget hits the lower bound.
        let mut wide_fields: Vec<Field> = Vec::with_capacity(100);
        for i in 0..100 {
            wide_fields.push(Field::new(format!("c{i}"), DataType::Utf8, true));
        }
        let wide = Arc::new(Schema::new(wide_fields));
        assert_eq!(compute_batch_size_from_memory(1, &wide), 1_000);
    }
}