lance_arrow/
memory.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::collections::HashSet;
5
6use arrow_array::{Array, RecordBatch};
7use arrow_data::ArrayData;
8
/// Counts memory used by buffers of Arrow arrays and RecordBatches.
///
/// This is meant to capture how much memory is being used by the Arrow data
/// structures as they are. It does not represent the memory used if the data
/// were to be serialized and then deserialized. In particular:
///
/// * This does not double count memory used by buffers shared by multiple
///   arrays or batches. Round-tripped data may use more memory because of this.
/// * This counts the **total** size of the buffers, even if the array is a slice.
///   Round-tripped data may use less memory because of this.
///
/// Start from [`Default::default`], feed it arrays or batches, and read the
/// running sum back with `total()`.
#[derive(Debug, Default)]
pub struct MemoryAccumulator {
    /// Addresses identifying the buffers that have already been counted.
    seen: HashSet<usize>,
    /// Running total, in bytes, of every buffer counted so far.
    total: usize,
}
24
25impl MemoryAccumulator {
26    pub fn record_array(&mut self, array: &dyn Array) {
27        let data = array.to_data();
28        self.record_array_data(&data);
29    }
30
31    fn record_array_data(&mut self, data: &ArrayData) {
32        for buffer in data.buffers() {
33            let ptr = buffer.as_ptr();
34            if self.seen.insert(ptr as usize) {
35                self.total += buffer.capacity();
36            }
37        }
38
39        if let Some(nulls) = data.nulls() {
40            let null_buf = nulls.inner().inner();
41            let ptr = null_buf.as_ptr();
42            if self.seen.insert(ptr as usize) {
43                self.total += null_buf.capacity();
44            }
45        }
46
47        for child in data.child_data() {
48            self.record_array_data(child);
49        }
50    }
51
52    pub fn record_batch(&mut self, batch: &RecordBatch) {
53        for array in batch.columns() {
54            self.record_array(array);
55        }
56    }
57
58    pub fn total(&self) -> usize {
59        self.total
60    }
61}
62
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// A sliced batch is charged for its whole backing buffer, and recording
    /// the same slice a second time does not change the total.
    #[test]
    fn test_memory_accumulator() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let column: Arc<dyn Array> = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let full_batch = RecordBatch::try_new(schema, vec![column]).unwrap();
        let sliced = full_batch.slice(1, 2);

        // Three i32 values live in the backing buffer, although the slice
        // only exposes two of them.
        let whole_buffer_bytes = 3 * std::mem::size_of::<i32>();

        let mut accumulator = MemoryAccumulator::default();

        // First pass: the entire buffer is counted, not just the slice.
        accumulator.record_batch(&sliced);
        assert_eq!(accumulator.total(), whole_buffer_bytes);

        // Second pass: the buffer has been seen, so the total is unchanged.
        accumulator.record_batch(&sliced);
        assert_eq!(accumulator.total(), whole_buffer_bytes);
    }
}