1use std::collections::HashSet;
5
6use arrow_array::{Array, RecordBatch};
7use arrow_data::ArrayData;
8
/// Accumulates the memory footprint of Arrow arrays and record batches.
///
/// Buffers that have already been recorded are remembered by address so that
/// recording the same data more than once does not inflate the total.
///
/// A fresh accumulator (via [`Default`]) has seen nothing and reports a total
/// of zero.
#[derive(Debug, Default)]
pub struct MemoryAccumulator {
    /// Addresses of the buffers that have already been added to `total`.
    seen: HashSet<usize>,
    /// Running sum, in bytes, of the capacities of all buffers in `seen`.
    total: usize,
}
24
25impl MemoryAccumulator {
26 pub fn record_array(&mut self, array: &dyn Array) {
27 let data = array.to_data();
28 self.record_array_data(&data);
29 }
30
31 fn record_array_data(&mut self, data: &ArrayData) {
32 for buffer in data.buffers() {
33 let ptr = buffer.as_ptr();
34 if self.seen.insert(ptr as usize) {
35 self.total += buffer.capacity();
36 }
37 }
38
39 if let Some(nulls) = data.nulls() {
40 let null_buf = nulls.inner().inner();
41 let ptr = null_buf.as_ptr();
42 if self.seen.insert(ptr as usize) {
43 self.total += null_buf.capacity();
44 }
45 }
46
47 for child in data.child_data() {
48 self.record_array_data(child);
49 }
50 }
51
52 pub fn record_batch(&mut self, batch: &RecordBatch) {
53 for array in batch.columns() {
54 self.record_array(array);
55 }
56 }
57
58 pub fn total(&self) -> usize {
59 self.total
60 }
61}
62
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    #[test]
    fn test_memory_accumulator() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let full_batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap();
        // Only the last two rows are visible through this view.
        let sliced = full_batch.slice(1, 2);

        // The slice still holds on to the whole three-element allocation.
        let expected_bytes = 3 * std::mem::size_of::<i32>();

        let mut accumulator = MemoryAccumulator::default();

        // First sighting: the full backing buffer is counted.
        accumulator.record_batch(&sliced);
        assert_eq!(accumulator.total(), expected_bytes);

        // Recording the same data again must not change the total.
        accumulator.record_batch(&sliced);
        assert_eq!(accumulator.total(), expected_bytes);
    }
}