datafusion_common/utils/memory.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This module provides a function to estimate the memory size of a HashTable prior to allocation

use crate::error::_exec_datafusion_err;
use crate::Result;
use std::mem::size_of;

/// Estimates the memory size required for a hash table prior to allocation.
///
/// # Parameters
/// - `num_elements`: The number of elements expected in the hash table.
/// - `fixed_size`: A fixed overhead size associated with the collection
///   (e.g., HashSet or HashTable).
/// - `T`: The type of elements stored in the hash table.
///
/// # Details
/// This function calculates the estimated memory size by considering:
/// - An overestimation of buckets to keep approximately 1/8 of them empty.
/// - The total memory size is computed as:
///   - The size of each entry (`T`) multiplied by the estimated number of
///     buckets.
///   - One byte overhead for each bucket.
///   - The fixed size overhead of the collection.
/// - If the estimation overflows, we return a [`crate::error::DataFusionError`]
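///
/// As an illustrative walk-through of this formula (concrete numbers only,
/// not an additional API guarantee): for 100 elements of type `u64`, the
/// bucket count is overestimated as `(100 * 8 / 7).next_power_of_two() = 128`,
/// so the estimate is `128 * size_of::<u64>() + 128 + fixed_size`, i.e.
/// `1152 + fixed_size` bytes.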
///
/// # Examples
/// ---
///
/// ## From within a struct
///
/// ```rust
/// # use datafusion_common::utils::memory::estimate_memory_size;
/// # use datafusion_common::Result;
///
/// struct MyStruct<T> {
///     values: Vec<T>,
///     other_data: usize,
/// }
///
/// impl<T> MyStruct<T> {
///     fn size(&self) -> Result<usize> {
///         let num_elements = self.values.len();
///         let fixed_size =
///             std::mem::size_of_val(self) + std::mem::size_of_val(&self.values);
///
///         estimate_memory_size::<T>(num_elements, fixed_size)
///     }
/// }
/// ```
/// ---
/// ## With a simple collection
///
/// ```rust
/// # use datafusion_common::utils::memory::estimate_memory_size;
/// # use std::collections::HashMap;
///
/// let num_rows = 100;
/// let fixed_size = std::mem::size_of::<HashMap<u64, u64>>();
/// let estimated_hashtable_size =
///     estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)
///         .expect("Size estimation failed");
/// ```
pub fn estimate_memory_size<T>(num_elements: usize, fixed_size: usize) -> Result<usize> {
    // For the majority of cases hashbrown overestimates the bucket quantity
    // to keep ~1/8 of them empty. We take this factor into account by
    // multiplying the number of elements with a fixed ratio of 8/7 (~1.14).
    // This formula leads to over-allocation for small tables (< 8 elements)
    // but should be fine overall.
    num_elements
        .checked_mul(8)
        .and_then(|overestimate| {
            let estimated_buckets = (overestimate / 7).next_power_of_two();
            // + size of entry * number of buckets
            // + 1 byte for each bucket
            // + fixed size of collection (HashSet/HashTable)
            size_of::<T>()
                .checked_mul(estimated_buckets)?
                .checked_add(estimated_buckets)?
                .checked_add(fixed_size)
        })
        .ok_or_else(|| {
            _exec_datafusion_err!("usize overflow while estimating the number of buckets")
        })
}

#[cfg(test)]
mod tests {
    use std::{collections::HashSet, mem::size_of};

    use super::estimate_memory_size;

    #[test]
    fn test_estimate_memory() {
        // size (bytes): 48
        let fixed_size = size_of::<HashSet<u32>>();

        // estimated buckets: 16 = (8 * 8 / 7).next_power_of_two()
        let num_elements = 8;
        // size (bytes): 128 = 16 * 4 + 16 + 48
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
        assert_eq!(estimated, 128);

        // estimated buckets: 64 = (40 * 8 / 7).next_power_of_two()
        let num_elements = 40;
        // size (bytes): 368 = 64 * 4 + 64 + 48
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
        assert_eq!(estimated, 368);
    }
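
    // A supplementary sketch of the zero-element edge case (not part of the
    // upstream suite): the bucket count still rounds up to one, since
    // (0 * 8 / 7).next_power_of_two() == 1.
    #[test]
    fn test_estimate_memory_zero_elements() {
        let fixed_size = size_of::<HashSet<u32>>();

        // size (bytes): 53 = 1 * 4 + 1 + 48
        let estimated = estimate_memory_size::<u32>(0, fixed_size).unwrap();
        assert_eq!(estimated, size_of::<u32>() + 1 + fixed_size);
    }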

    #[test]
    fn test_estimate_memory_overflow() {
        let num_elements = usize::MAX;
        let fixed_size = size_of::<HashSet<u32>>();
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size);

        assert!(estimated.is_err());
    }
}