use arrow::datatypes::{DataType, SchemaRef};
/// Estimates the in-memory size of one row of `schema`, in bytes.
///
/// - Fixed-width types contribute their native value width.
/// - Variable-length string/binary columns contribute a flat `STRING_ESTIMATE` guess.
/// - Any other type (nested, dictionary, interval, ...) contributes a conservative
///   `OTHER_ESTIMATE`.
/// - Every field additionally adds `PER_FIELD_OVERHEAD` bytes.
///
/// The result is never 0 (an empty schema yields 1), so callers may safely divide by it.
pub fn estimate_row_bytes(schema: &SchemaRef) -> usize {
    // Assumed average payload for a variable-length string/binary value.
    const STRING_ESTIMATE: usize = 256;
    // Fallback for types without a simple fixed width (lists, structs, maps, ...).
    const OTHER_ESTIMATE: usize = 64;
    // Flat per-field overhead added on top of the value estimate.
    // NOTE(review): presumably meant to cover validity/offset bookkeeping — confirm.
    const PER_FIELD_OVERHEAD: usize = 1;

    let total: usize = schema
        .fields()
        .iter()
        .map(|field| {
            let value_bytes = match field.data_type() {
                DataType::Boolean | DataType::Int8 | DataType::UInt8 => 1,
                DataType::Int16 | DataType::UInt16 | DataType::Float16 => 2,
                DataType::Int32
                | DataType::UInt32
                | DataType::Float32
                | DataType::Date32
                | DataType::Time32(_) => 4,
                DataType::Int64
                | DataType::UInt64
                | DataType::Float64
                | DataType::Date64
                | DataType::Timestamp(_, _)
                | DataType::Time64(_)
                | DataType::Duration(_) => 8,
                DataType::Decimal128(_, _) => 16,
                // A 256-bit decimal occupies 32 bytes, twice the width of a Decimal128.
                DataType::Decimal256(_, _) => 32,
                // Each value is exactly `width` bytes; a (malformed) negative width counts as 0.
                DataType::FixedSizeBinary(width) => (*width).max(0) as usize,
                DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
                    STRING_ESTIMATE
                }
                _ => OTHER_ESTIMATE,
            };
            value_bytes + PER_FIELD_OVERHEAD
        })
        .sum();

    // Guarantee a non-zero divisor for callers, even for a schema with no fields.
    total.max(1)
}
/// Picks a batch row count whose estimated footprint is about `memory_mb` MiB for `schema`.
///
/// The per-row size comes from [`estimate_row_bytes`]. The resulting row count is clamped
/// to `1_000..=500_000`:
/// - tiny budgets or very wide rows still produce usable batches;
/// - huge budgets do not produce unbounded ones.
///
/// A `memory_mb` of 0 therefore yields the minimum of 1_000 rows.
pub fn compute_batch_size_from_memory(memory_mb: usize, schema: &SchemaRef) -> usize {
    const BYTES_PER_MB: usize = 1024 * 1024;
    const MIN_BATCH_ROWS: usize = 1_000;
    const MAX_BATCH_ROWS: usize = 500_000;

    // Floor at 1 so the division below can never divide by zero.
    let row_bytes = estimate_row_bytes(schema).max(1);
    // Saturate instead of overflowing. A huge `memory_mb` (or any budget >= 4096 MiB on a
    // 32-bit target) would otherwise panic in debug builds. In release builds it would wrap
    // to a tiny budget; saturating just hits the upper clamp.
    let budget_bytes = memory_mb.saturating_mul(BYTES_PER_MB);
    (budget_bytes / row_bytes).clamp(MIN_BATCH_ROWS, MAX_BATCH_ROWS)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;

    /// Wraps `fields` into the shared `SchemaRef` form the functions under test accept.
    fn schema_of(fields: Vec<Field>) -> SchemaRef {
        Arc::new(Schema::new(fields))
    }

    #[test]
    fn estimate_row_bytes_basic() {
        // Int64 (8) + Utf8 (256) + one byte of per-field overhead for each of the two columns.
        let schema = schema_of(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]);
        assert_eq!(estimate_row_bytes(&schema), 266);
    }

    #[test]
    fn compute_batch_size_clamped() {
        // A lone boolean column is ~2 bytes/row, so 256 MiB would allow far more rows
        // than the upper clamp permits.
        let narrow = schema_of(vec![Field::new("flag", DataType::Boolean, false)]);
        assert_eq!(compute_batch_size_from_memory(256, &narrow), 500_000);

        // One hundred string columns make each row so wide that 1 MiB lands below the
        // lower clamp.
        let wide_fields: Vec<Field> = (0..100)
            .map(|i| Field::new(format!("c{i}"), DataType::Utf8, true))
            .collect();
        let wide = schema_of(wide_fields);
        assert_eq!(compute_batch_size_from_memory(1, &wide), 1_000);
    }
}