use crate::error::_exec_datafusion_err;
use crate::{HashSet, Result};
use arrow::array::ArrayData;
use arrow::record_batch::RecordBatch;
use std::{mem::size_of, ptr::NonNull};
/// Estimates the number of bytes used by a bucket-based collection holding
/// `num_elements` entries of type `T`.
///
/// The estimate is built in three steps:
///
/// 1. The bucket count is `num_elements * 8 / 7`, rounded up to the next
///    power of two.
/// 2. Each bucket is charged `size_of::<T>()` bytes plus one extra byte.
/// 3. `fixed_size` (the size of the collection struct itself, e.g.
///    `size_of::<HashSet<T>>()`) is added on top.
///
/// NOTE(review): the 8/7 factor and the one-byte-per-bucket term presumably
/// model a hashbrown-style table (load factor and control bytes) — confirm
/// against the hash table implementation actually in use.
///
/// # Errors
///
/// Returns an execution error if any intermediate multiplication or addition
/// overflows `usize`.
pub fn estimate_memory_size<T>(num_elements: usize, fixed_size: usize) -> Result<usize> {
    // Every fallible step yields `None` on overflow; `?` short-circuits the
    // closure so the error is produced in exactly one place below.
    let estimate = || -> Option<usize> {
        let scaled = num_elements.checked_mul(8)?;
        let bucket_count = (scaled / 7).next_power_of_two();

        // Payload bytes: one `T` per bucket.
        let entry_bytes = size_of::<T>().checked_mul(bucket_count)?;
        // Plus one byte per bucket.
        let table_bytes = entry_bytes.checked_add(bucket_count)?;
        // Plus the fixed footprint of the collection itself.
        table_bytes.checked_add(fixed_size)
    };

    match estimate() {
        Some(bytes) => Ok(bytes),
        None => Err(_exec_datafusion_err!(
            "usize overflow while estimating the number of buckets"
        )),
    }
}
/// Returns the memory footprint of `batch` in bytes, computed as the sum of
/// the capacities of all buffers reachable from its columns.
///
/// A set of already-visited buffer pointers is threaded through the
/// traversal, so a buffer referenced by several columns is counted only once.
pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize {
    // Buffers whose capacity has already been added to `bytes`.
    let mut seen: HashSet<NonNull<u8>> = HashSet::new();
    let mut bytes = 0;

    batch.columns().iter().for_each(|column| {
        let data = column.to_data();
        count_array_data_memory_size(&data, &mut seen, &mut bytes);
    });

    bytes
}
fn count_array_data_memory_size(
array_data: &ArrayData,
counted_buffers: &mut HashSet<NonNull<u8>>,
total_size: &mut usize,
) {
for buffer in array_data.buffers() {
if counted_buffers.insert(buffer.data_ptr()) {
*total_size += buffer.capacity();
} }
if let Some(null_buffer) = array_data.nulls()
&& counted_buffers.insert(null_buffer.inner().inner().data_ptr())
{
*total_size += null_buffer.inner().inner().capacity();
}
for child in array_data.child_data() {
count_array_data_memory_size(child, counted_buffers, total_size);
}
}
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, mem::size_of};

    use super::estimate_memory_size;

    /// Checks the estimate for two element counts against hand-computed sizes.
    #[test]
    fn test_estimate_memory() {
        // The expected totals below imply `fixed_size == 48` bytes.
        let fixed_size = size_of::<HashSet<u32>>();

        // (num_elements, expected bytes)
        let cases: [(usize, usize); 2] = [
            // buckets: 16 = (8 * 8 / 7).next_power_of_two()
            // bytes:   128 = 16 * 4 + 16 + 48
            (8, 128),
            // buckets: 64 = (40 * 8 / 7).next_power_of_two()
            // bytes:   368 = 64 * 4 + 64 + 48
            (40, 368),
        ];

        for (num_elements, expected) in cases {
            let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
            assert_eq!(estimated, expected);
        }
    }

    /// `usize::MAX` elements must overflow the initial multiplication and
    /// surface as an error rather than a wrapped value.
    #[test]
    fn test_estimate_memory_overflow() {
        let fixed_size = size_of::<HashSet<u32>>();
        let outcome = estimate_memory_size::<u32>(usize::MAX, fixed_size);
        assert!(outcome.is_err());
    }
}
#[cfg(test)]
mod record_batch_tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, ListArray};
    use arrow::datatypes::{DataType, Field, Int32Type, Schema};
    use std::sync::Arc;

    /// Schema with a nullable Int32 column `ints` and a non-nullable Float64
    /// column `float64`.
    fn int_float_schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Schema made of non-nullable Int32 columns with the given names.
    fn int32_schema(names: &[&str]) -> Arc<Schema> {
        let fields: Vec<Field> = names
            .iter()
            .map(|name| Field::new(*name, DataType::Int32, false))
            .collect();
        Arc::new(Schema::new(fields))
    }

    /// Non-nullable `List<Int32>` field (with nullable items) named `name`.
    fn int32_list_field(name: &str) -> Field {
        let item = Arc::new(Field::new_list_field(DataType::Int32, true));
        Field::new(name, DataType::List(item), false)
    }

    /// Two flat columns of five values each, none of them null.
    #[test]
    fn test_get_record_batch_memory_size() {
        let ints = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]);
        let floats = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let batch =
            RecordBatch::try_new(int_float_schema(), vec![Arc::new(ints), Arc::new(floats)])
                .unwrap();

        assert_eq!(get_record_batch_memory_size(&batch), 60);
    }

    /// Same layout as above, but the Int32 column contains a null entry.
    #[test]
    fn test_get_record_batch_memory_size_with_null() {
        let ints = Int32Array::from(vec![None, Some(2), Some(3)]);
        let floats = Float64Array::from(vec![1.0, 2.0, 3.0]);

        let batch =
            RecordBatch::try_new(int_float_schema(), vec![Arc::new(ints), Arc::new(floats)])
                .unwrap();

        assert_eq!(get_record_batch_memory_size(&batch), 100);
    }

    /// A batch whose only column has zero rows.
    #[test]
    fn test_get_record_batch_memory_size_empty() {
        let no_values: Int32Array = Int32Array::from(Vec::<i32>::new());
        let batch =
            RecordBatch::try_new(int32_schema(&["ints"]), vec![Arc::new(no_values)]).unwrap();

        let measured = get_record_batch_memory_size(&batch);
        assert_eq!(measured, 0, "Empty batch should have 0 memory size");
    }

    /// Two overlapping slices of one array must be measured the same as the
    /// array they were sliced from.
    #[test]
    fn test_get_record_batch_memory_size_shared_buffer() {
        let source = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let head = source.slice(0, 3);
        let tail = source.slice(2, 3);

        let whole_batch =
            RecordBatch::try_new(int32_schema(&["origin_col"]), vec![Arc::new(source)]).unwrap();
        let sliced_batch = RecordBatch::try_new(
            int32_schema(&["slice1", "slice2"]),
            vec![Arc::new(head), Arc::new(tail)],
        )
        .unwrap();

        let whole_size = get_record_batch_memory_size(&whole_batch);
        let sliced_size = get_record_batch_memory_size(&sliced_batch);
        assert_eq!(whole_size, sliced_size);
    }

    /// Two independent list columns, exercising the recursion into child data.
    #[test]
    fn test_get_record_batch_memory_size_nested_array() {
        let schema = Arc::new(Schema::new(vec![
            int32_list_field("nested_int"),
            int32_list_field("nested_int2"),
        ]));

        let first = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
        ]);
        let second = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(4), Some(5), Some(6)]),
        ]);

        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(first), Arc::new(second)]).unwrap();

        assert_eq!(get_record_batch_memory_size(&batch), 8208);
    }
}