datafusion_common/utils/memory.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This module provides functions to estimate the memory size of a hash table prior to allocation and to measure the memory used by a `RecordBatch`

use crate::error::_exec_datafusion_err;
use crate::{HashSet, Result};
use arrow::array::ArrayData;
use arrow::record_batch::RecordBatch;
use std::{mem::size_of, ptr::NonNull};

/// Estimates the memory size required for a hash table prior to allocation.
///
/// # Parameters
/// - `num_elements`: The number of elements expected in the hash table.
/// - `fixed_size`: A fixed overhead size associated with the collection
///   (e.g., HashSet or HashTable).
/// - `T`: The type of elements stored in the hash table.
///
/// # Details
/// This function calculates the estimated memory size by considering:
/// - An overestimation of buckets to keep approximately 1/8 of them empty.
/// - The total memory size is computed as:
///   - The size of each entry (`T`) multiplied by the estimated number of
///     buckets.
///   - One byte overhead for each bucket.
///   - The fixed size overhead of the collection.
/// - If the estimation overflows, we return a [`crate::error::DataFusionError`]
///
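/// For illustration (not an API guarantee): with 8 elements of a 4-byte `T`
/// and a 48-byte `fixed_size` (e.g. `size_of::<HashSet<u32>>()` on a typical
/// 64-bit target), the estimated bucket count is
/// `(8 * 8 / 7).next_power_of_two() = 16`, giving an estimate of
/// `16 * 4 + 16 + 48 = 128` bytes.
///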
/// # Examples
/// ---
///
/// ## From within a struct
///
/// ```rust
/// # use datafusion_common::utils::memory::estimate_memory_size;
/// # use datafusion_common::Result;
///
/// struct MyStruct<T> {
///     values: Vec<T>,
///     other_data: usize,
/// }
///
/// impl<T> MyStruct<T> {
///     fn size(&self) -> Result<usize> {
///         let num_elements = self.values.len();
///         let fixed_size =
///             std::mem::size_of_val(self) + std::mem::size_of_val(&self.values);
///
///         estimate_memory_size::<T>(num_elements, fixed_size)
///     }
/// }
/// ```
/// ---
/// ## With a simple collection
///
/// ```rust
/// # use datafusion_common::utils::memory::estimate_memory_size;
/// # use std::collections::HashMap;
///
/// let num_rows = 100;
/// let fixed_size = std::mem::size_of::<HashMap<u64, u64>>();
/// let estimated_hashtable_size =
///     estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)
///         .expect("Size estimation failed");
/// ```
pub fn estimate_memory_size<T>(num_elements: usize, fixed_size: usize) -> Result<usize> {
    // For the majority of cases, hashbrown overestimates the bucket quantity
    // to keep ~1/8 of them empty. We take this factor into account by
    // multiplying the number of elements by a fixed ratio of 8/7 (~1.14).
    // This formula leads to over-allocation for small tables (< 8 elements)
    // but should be fine overall.
    num_elements
        .checked_mul(8)
        .and_then(|overestimate| {
            let estimated_buckets = (overestimate / 7).next_power_of_two();
            // + size of entry * number of buckets
            // + 1 byte for each bucket
            // + fixed size of collection (HashSet/HashTable)
            size_of::<T>()
                .checked_mul(estimated_buckets)?
                .checked_add(estimated_buckets)?
                .checked_add(fixed_size)
        })
        .ok_or_else(|| {
            _exec_datafusion_err!("usize overflow while estimating the number of buckets")
        })
}

/// Calculate the total memory used by this batch.
///
/// This function estimates the physical memory usage of the `RecordBatch`. It
/// only counts the memory of the large data `Buffer`s and ignores metadata such
/// as types and pointers. It adds up the memory size of each unique `Buffer`
/// exactly once, because:
/// - The data pointers inside `Buffer`s are memory regions returned by the
///   global memory allocator; those regions cannot overlap.
/// - The ranges actually used by the `ArrayRef`s inside the `RecordBatch` may
///   overlap or reuse the same `Buffer`, for example when taking a slice of an
///   `Array`.
///
/// Example:
/// For a `RecordBatch` with two columns, `col1` and `col2`, both columns point
/// to a sub-region of the same buffer.
///
/// ```text
/// {xxxxxxxxxxxxxxxxxxx} <--- buffer
///       ^    ^  ^    ^
///       |    |  |    |
/// col1->{    }  |    |
/// col2--------->{    }
/// ```
///
/// In the above case, `get_record_batch_memory_size` will return the size of
/// the buffer, instead of the sum of `col1` and `col2`'s actual memory sizes.
///
/// Note: the current `RecordBatch::get_array_memory_size()` will double-count
/// the buffer memory size if multiple arrays within the batch share the same
/// `Buffer`. This method provides a temporary fix until the issue is resolved:
/// <https://github.com/apache/arrow-rs/issues/6439>
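///
/// # Example
///
/// A minimal sketch of how this can be used; it assumes this function is
/// reachable at `datafusion_common::utils::memory::get_record_batch_memory_size`,
/// mirroring the public path used for [`estimate_memory_size`] above.
///
/// ```rust
/// # use std::sync::Arc;
/// # use arrow::array::{ArrayRef, Int32Array};
/// # use arrow::record_batch::RecordBatch;
/// # use datafusion_common::utils::memory::get_record_batch_memory_size;
/// let original = Int32Array::from(vec![1, 2, 3, 4, 5]);
/// // Two columns backed by sub-regions of the same underlying buffer
/// let slice1 = original.slice(0, 3);
/// let slice2 = original.slice(2, 3);
/// let batch = RecordBatch::try_from_iter(vec![
///     ("slice1", Arc::new(slice1) as ArrayRef),
///     ("slice2", Arc::new(slice2) as ArrayRef),
/// ])
/// .unwrap();
///
/// // The shared buffer is counted only once, not once per column
/// let size = get_record_batch_memory_size(&batch);
/// assert!(size > 0);
/// ```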
pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize {
    // Store pointers to each `Buffer`'s start memory address (rather than the
    // pointer to the data region currently used by the `Array`)
    let mut counted_buffers: HashSet<NonNull<u8>> = HashSet::new();
    let mut total_size = 0;

    for array in batch.columns() {
        let array_data = array.to_data();
        count_array_data_memory_size(&array_data, &mut counted_buffers, &mut total_size);
    }

    total_size
}

/// Count the memory usage of `array_data` and its children recursively.
fn count_array_data_memory_size(
    array_data: &ArrayData,
    counted_buffers: &mut HashSet<NonNull<u8>>,
    total_size: &mut usize,
) {
    // Count memory usage for `array_data`
    for buffer in array_data.buffers() {
        if counted_buffers.insert(buffer.data_ptr()) {
            *total_size += buffer.capacity();
        } // Otherwise the buffer's memory is already counted
    }

    if let Some(null_buffer) = array_data.nulls()
        && counted_buffers.insert(null_buffer.inner().inner().data_ptr())
    {
        *total_size += null_buffer.inner().inner().capacity();
    }

    // Count all children `ArrayData` recursively
    for child in array_data.child_data() {
        count_array_data_memory_size(child, counted_buffers, total_size);
    }
}

#[cfg(test)]
mod tests {
    use std::{collections::HashSet, mem::size_of};

    use super::estimate_memory_size;

    #[test]
    fn test_estimate_memory() {
        // size (bytes): 48
        let fixed_size = size_of::<HashSet<u32>>();

        // estimated buckets: 16 = (8 * 8 / 7).next_power_of_two()
        let num_elements = 8;
        // size (bytes): 128 = 16 * 4 + 16 + 48
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
        assert_eq!(estimated, 128);

        // estimated buckets: 64 = (40 * 8 / 7).next_power_of_two()
        let num_elements = 40;
        // size (bytes): 368 = 64 * 4 + 64 + 48
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
        assert_eq!(estimated, 368);
    }

    #[test]
    fn test_estimate_memory_overflow() {
        let num_elements = usize::MAX;
        let fixed_size = size_of::<HashSet<u32>>();
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size);

        assert!(estimated.is_err());
    }
}

#[cfg(test)]
mod record_batch_tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, ListArray};
    use arrow::datatypes::{DataType, Field, Int32Type, Schema};
    use std::sync::Arc;

    #[test]
    fn test_get_record_batch_memory_size() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));

        let int_array =
            Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]);
        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(int_array), Arc::new(float64_array)],
        )
        .unwrap();

        let size = get_record_batch_memory_size(&batch);
        assert_eq!(size, 60);
    }

    #[test]
    fn test_get_record_batch_memory_size_with_null() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));

        let int_array = Int32Array::from(vec![None, Some(2), Some(3)]);
        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0]);

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(int_array), Arc::new(float64_array)],
        )
        .unwrap();

        let size = get_record_batch_memory_size(&batch);
        assert_eq!(size, 100);
    }

    #[test]
    fn test_get_record_batch_memory_size_empty() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ints",
            DataType::Int32,
            false,
        )]));

        let int_array: Int32Array = Int32Array::from(vec![] as Vec<i32>);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(int_array)]).unwrap();

        let size = get_record_batch_memory_size(&batch);
        assert_eq!(size, 0, "Empty batch should have 0 memory size");
    }

    #[test]
    fn test_get_record_batch_memory_size_shared_buffer() {
        let original = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let slice1 = original.slice(0, 3);
        let slice2 = original.slice(2, 3);

        let schema_origin = Arc::new(Schema::new(vec![Field::new(
            "origin_col",
            DataType::Int32,
            false,
        )]));
        let batch_origin =
            RecordBatch::try_new(schema_origin, vec![Arc::new(original)]).unwrap();

        let schema = Arc::new(Schema::new(vec![
            Field::new("slice1", DataType::Int32, false),
            Field::new("slice2", DataType::Int32, false),
        ]));

        let batch_sliced =
            RecordBatch::try_new(schema, vec![Arc::new(slice1), Arc::new(slice2)])
                .unwrap();

        let size_origin = get_record_batch_memory_size(&batch_origin);
        let size_sliced = get_record_batch_memory_size(&batch_sliced);

        assert_eq!(size_origin, size_sliced);
    }

    #[test]
    fn test_get_record_batch_memory_size_nested_array() {
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "nested_int",
                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
                false,
            ),
            Field::new(
                "nested_int2",
                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
                false,
            ),
        ]));

        let int_list_array = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
        ]);

        let int_list_array2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(4), Some(5), Some(6)]),
        ]);

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(int_list_array), Arc::new(int_list_array2)],
        )
        .unwrap();

        let size = get_record_batch_memory_size(&batch);
        assert_eq!(size, 8208);
    }
}