datafusion_common/utils/memory.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! This module provides utilities for estimating memory usage: the size of a
//! hash table prior to allocation, and the buffer memory held by a `RecordBatch`
19
20use crate::error::_exec_datafusion_err;
21use crate::{HashSet, Result};
22use arrow::array::ArrayData;
23use arrow::record_batch::RecordBatch;
24use std::{mem::size_of, ptr::NonNull};
25
26/// Estimates the memory size required for a hash table prior to allocation.
27///
28/// # Parameters
29/// - `num_elements`: The number of elements expected in the hash table.
30/// - `fixed_size`: A fixed overhead size associated with the collection
31/// (e.g., HashSet or HashTable).
32/// - `T`: The type of elements stored in the hash table.
33///
34/// # Details
35/// This function calculates the estimated memory size by considering:
36/// - An overestimation of buckets to keep approximately 1/8 of them empty.
37/// - The total memory size is computed as:
38/// - The size of each entry (`T`) multiplied by the estimated number of
39/// buckets.
40/// - One byte overhead for each bucket.
41/// - The fixed size overhead of the collection.
42/// - If the estimation overflows, we return a [`crate::error::DataFusionError`]
43///
44/// # Examples
45/// ---
46///
47/// ## From within a struct
48///
49/// ```rust
50/// # use datafusion_common::utils::memory::estimate_memory_size;
51/// # use datafusion_common::Result;
52///
53/// struct MyStruct<T> {
54/// values: Vec<T>,
55/// other_data: usize,
56/// }
57///
58/// impl<T> MyStruct<T> {
59/// fn size(&self) -> Result<usize> {
60/// let num_elements = self.values.len();
61/// let fixed_size =
62/// std::mem::size_of_val(self) + std::mem::size_of_val(&self.values);
63///
64/// estimate_memory_size::<T>(num_elements, fixed_size)
65/// }
66/// }
67/// ```
68/// ---
69/// ## With a simple collection
70///
71/// ```rust
72/// # use datafusion_common::utils::memory::estimate_memory_size;
73/// # use std::collections::HashMap;
74///
75/// let num_rows = 100;
76/// let fixed_size = std::mem::size_of::<HashMap<u64, u64>>();
77/// let estimated_hashtable_size =
78/// estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)
79/// .expect("Size estimation failed");
80/// ```
81pub fn estimate_memory_size<T>(num_elements: usize, fixed_size: usize) -> Result<usize> {
82 // For the majority of cases hashbrown overestimates the bucket quantity
83 // to keep ~1/8 of them empty. We take this factor into account by
84 // multiplying the number of elements with a fixed ratio of 8/7 (~1.14).
85 // This formula leads to over-allocation for small tables (< 8 elements)
86 // but should be fine overall.
87 num_elements
88 .checked_mul(8)
89 .and_then(|overestimate| {
90 let estimated_buckets = (overestimate / 7).next_power_of_two();
91 // + size of entry * number of buckets
92 // + 1 byte for each bucket
93 // + fixed size of collection (HashSet/HashTable)
94 size_of::<T>()
95 .checked_mul(estimated_buckets)?
96 .checked_add(estimated_buckets)?
97 .checked_add(fixed_size)
98 })
99 .ok_or_else(|| {
100 _exec_datafusion_err!("usize overflow while estimating the number of buckets")
101 })
102}
103
104/// Calculate total used memory of this batch.
105///
106/// This function is used to estimate the physical memory usage of the `RecordBatch`.
107/// It only counts the memory of large data `Buffer`s, and ignores metadata like
108/// types and pointers.
109/// The implementation will add up all unique `Buffer`'s memory
110/// size, due to:
111/// - The data pointer inside `Buffer` are memory regions returned by global memory
112/// allocator, those regions can't have overlap.
113/// - The actual used range of `ArrayRef`s inside `RecordBatch` can have overlap
114/// or reuse the same `Buffer`. For example: taking a slice from `Array`.
115///
116/// Example:
117/// For a `RecordBatch` with two columns: `col1` and `col2`, two columns are pointing
118/// to a sub-region of the same buffer.
119///
120/// {xxxxxxxxxxxxxxxxxxx} <--- buffer
121/// ^ ^ ^ ^
122/// | | | |
123/// col1->{ } | |
124/// col2--------->{ }
125///
126/// In the above case, `get_record_batch_memory_size` will return the size of
127/// the buffer, instead of the sum of `col1` and `col2`'s actual memory size.
128///
129/// Note: Current `RecordBatch`.get_array_memory_size()` will double count the
130/// buffer memory size if multiple arrays within the batch are sharing the same
131/// `Buffer`. This method provides temporary fix until the issue is resolved:
132/// <https://github.com/apache/arrow-rs/issues/6439>
133pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize {
134 // Store pointers to `Buffer`'s start memory address (instead of actual
135 // used data region's pointer represented by current `Array`)
136 let mut counted_buffers: HashSet<NonNull<u8>> = HashSet::new();
137 let mut total_size = 0;
138
139 for array in batch.columns() {
140 let array_data = array.to_data();
141 count_array_data_memory_size(&array_data, &mut counted_buffers, &mut total_size);
142 }
143
144 total_size
145}
146
147/// Count the memory usage of `array_data` and its children recursively.
148fn count_array_data_memory_size(
149 array_data: &ArrayData,
150 counted_buffers: &mut HashSet<NonNull<u8>>,
151 total_size: &mut usize,
152) {
153 // Count memory usage for `array_data`
154 for buffer in array_data.buffers() {
155 if counted_buffers.insert(buffer.data_ptr()) {
156 *total_size += buffer.capacity();
157 } // Otherwise the buffer's memory is already counted
158 }
159
160 if let Some(null_buffer) = array_data.nulls()
161 && counted_buffers.insert(null_buffer.inner().inner().data_ptr())
162 {
163 *total_size += null_buffer.inner().inner().capacity();
164 }
165
166 // Count all children `ArrayData` recursively
167 for child in array_data.child_data() {
168 count_array_data_memory_size(child, counted_buffers, total_size);
169 }
170}
171
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, mem::size_of};

    use super::estimate_memory_size;

    /// Checks the estimate against the documented formula:
    /// `size_of::<T>() * buckets + buckets + fixed_size`.
    ///
    /// The expected values are derived from `fixed_size` instead of being
    /// hard-coded, because `size_of::<HashSet<u32>>()` depends on the target's
    /// pointer width and on the standard library's layout (it is 48 bytes on
    /// typical 64-bit targets, which gives totals of 128 and 368 below).
    #[test]
    fn test_estimate_memory() {
        let fixed_size = size_of::<HashSet<u32>>();
        let entry_size = size_of::<u32>();

        // (num_elements, expected number of buckets)
        let cases: [(usize, usize); 2] = [
            // 16 = (8 * 8 / 7).next_power_of_two()
            (8, 16),
            // 64 = (40 * 8 / 7).next_power_of_two()
            (40, 64),
        ];

        for (num_elements, expected_buckets) in cases {
            let estimated =
                estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
            // entries + 1 byte per bucket + fixed overhead of the collection
            let expected = expected_buckets * entry_size + expected_buckets + fixed_size;
            assert_eq!(estimated, expected);
        }
    }

    /// An element count whose bucket estimate cannot fit in `usize` must be
    /// reported as an error.
    #[test]
    fn test_estimate_memory_overflow() {
        let num_elements = usize::MAX;
        let fixed_size = size_of::<HashSet<u32>>();
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size);

        assert!(estimated.is_err());
    }
}
205
#[cfg(test)]
mod record_batch_tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, ListArray};
    use arrow::datatypes::{DataType, Field, Int32Type, Schema};
    use std::sync::Arc;

    /// Schema shared by the flat-column tests: a nullable `Int32` column
    /// followed by a non-nullable `Float64` column.
    fn int_and_float_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]))
    }

    /// Two flat columns without nulls.
    #[test]
    fn test_get_record_batch_memory_size() {
        let ints = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]);
        let floats = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let batch = RecordBatch::try_new(
            int_and_float_schema(),
            vec![Arc::new(ints), Arc::new(floats)],
        )
        .unwrap();

        assert_eq!(get_record_batch_memory_size(&batch), 60);
    }

    /// Two flat columns where the integer column contains a null.
    #[test]
    fn test_get_record_batch_memory_size_with_null() {
        let ints = Int32Array::from(vec![None, Some(2), Some(3)]);
        let floats = Float64Array::from(vec![1.0, 2.0, 3.0]);

        let batch = RecordBatch::try_new(
            int_and_float_schema(),
            vec![Arc::new(ints), Arc::new(floats)],
        )
        .unwrap();

        assert_eq!(get_record_batch_memory_size(&batch), 100);
    }

    /// A batch whose only column has no rows.
    #[test]
    fn test_get_record_batch_memory_size_empty() {
        let fields = vec![Field::new("ints", DataType::Int32, false)];
        let schema = Arc::new(Schema::new(fields));

        let no_rows: Int32Array = Int32Array::from(Vec::<i32>::new());
        let batch = RecordBatch::try_new(schema, vec![Arc::new(no_rows)]).unwrap();

        let size = get_record_batch_memory_size(&batch);
        assert_eq!(size, 0, "Empty batch should have 0 memory size");
    }

    /// Two overlapping slices of one array must be reported with the same
    /// size as a batch holding the original array.
    #[test]
    fn test_get_record_batch_memory_size_shared_buffer() {
        let original = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let head = original.slice(0, 3);
        let tail = original.slice(2, 3);

        let origin_schema = Arc::new(Schema::new(vec![Field::new(
            "origin_col",
            DataType::Int32,
            false,
        )]));
        let origin_batch =
            RecordBatch::try_new(origin_schema, vec![Arc::new(original)]).unwrap();

        let sliced_schema = Arc::new(Schema::new(vec![
            Field::new("slice1", DataType::Int32, false),
            Field::new("slice2", DataType::Int32, false),
        ]));
        let sliced_batch =
            RecordBatch::try_new(sliced_schema, vec![Arc::new(head), Arc::new(tail)])
                .unwrap();

        let origin_size = get_record_batch_memory_size(&origin_batch);
        let sliced_size = get_record_batch_memory_size(&sliced_batch);

        assert_eq!(origin_size, sliced_size);
    }

    /// Two list columns, so the child arrays are visited as well.
    #[test]
    fn test_get_record_batch_memory_size_nested_array() {
        let int_list_type =
            || DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));

        let schema = Arc::new(Schema::new(vec![
            Field::new("nested_int", int_list_type(), false),
            Field::new("nested_int2", int_list_type(), false),
        ]));

        let first_list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
        ]);
        let second_list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(4), Some(5), Some(6)]),
        ]);

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(first_list), Arc::new(second_list)],
        )
        .unwrap();

        assert_eq!(get_record_batch_memory_size(&batch), 8208);
    }
}