// polars_pipe/executors/sinks/group_by/generic/thread_local.rs

1use std::sync::LazyLock;
2
3use arrow::array::MutableBinaryArray;
4use polars_utils::hashing::hash_to_partition;
5
6use super::*;
7use crate::pipeline::PARTITION_SIZE;
8
/// Capacity of each per-partition overflow buffer; when a partition's buffer
/// reaches this many rows it is flushed out as a `SpillPayload`.
const OB_SIZE: usize = 2048;

/// Number of rows passed as the spill limit to the thread-local
/// `AggHashTable`. Defaults to 10_000; overridable via the
/// `POLARS_STREAMING_GROUPBY_SPILL_SIZE` environment variable.
static SPILL_SIZE: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("POLARS_STREAMING_GROUPBY_SPILL_SIZE")
        .map(|v| {
            // A set-but-unparsable value is a configuration bug; fail loudly
            // with context instead of a bare `unwrap` panic.
            v.parse::<usize>()
                .expect("POLARS_STREAMING_GROUPBY_SPILL_SIZE must be a valid usize")
        })
        .unwrap_or(10_000)
});
16
/// Per-partition overflow buffers for tuples that the thread-local
/// `AggHashTable` refuses. Rows are routed by hash into `PARTITION_SIZE`
/// buckets and flushed as `SpillPayload`s once a bucket holds `OB_SIZE` rows.
#[derive(Clone)]
struct SpillPartitions {
    // Binary key rows (`&[u8]`), one builder per partition.
    keys_partitioned: PartitionVec<MutableBinaryArray<i64>>,
    // Outer vec: one entry per partition (PARTITION_SIZE of them, see
    // `split`); inner vec: one buffer per aggregated column.
    aggs_partitioned: PartitionVec<Vec<AnyValueBufferTrusted<'static>>>,
    // Row hashes, one vec per partition.
    hash_partitioned: PartitionVec<Vec<u64>>,
    // Source chunk index of every spilled row, one vec per partition.
    chunk_index_partitioned: PartitionVec<Vec<IdxSize>>,
    // Set on the first `insert`; checked by `finish`/`combine`/`finalize`.
    spilled: bool,
    // this only fills during the reduce phase IFF
    // there are spilled tuples
    finished_payloads: PartitionVec<Vec<SpillPayload>>,
    keys_dtypes: Arc<[DataType]>,
    agg_dtypes: Arc<[DataType]>,
    output_schema: SchemaRef,
}
33
34impl SpillPartitions {
35    fn new(keys: Arc<[DataType]>, aggs: Arc<[DataType]>, output_schema: SchemaRef) -> Self {
36        let hash_partitioned = vec![];
37        let chunk_index_partitioned = vec![];
38
39        // construct via split so that pre-allocation succeeds
40        Self {
41            keys_partitioned: vec![],
42            aggs_partitioned: vec![],
43            hash_partitioned,
44            chunk_index_partitioned,
45            spilled: false,
46            finished_payloads: vec![],
47            keys_dtypes: keys,
48            agg_dtypes: aggs,
49            output_schema,
50        }
51        .split()
52    }
53
54    fn split(&self) -> Self {
55        let n_columns = self.agg_dtypes.as_ref().len();
56
57        let aggs_partitioned = (0..PARTITION_SIZE)
58            .map(|_| {
59                let mut buf = Vec::with_capacity(n_columns);
60                for dtype in self.agg_dtypes.as_ref() {
61                    let builder = AnyValueBufferTrusted::new(&dtype.to_physical(), OB_SIZE);
62                    buf.push(builder);
63                }
64                buf
65            })
66            .collect();
67
68        let keys_partitioned = (0..PARTITION_SIZE)
69            .map(|_| MutableBinaryArray::with_capacity(OB_SIZE))
70            .collect();
71
72        let hash_partitioned = (0..PARTITION_SIZE)
73            .map(|_| Vec::with_capacity(OB_SIZE))
74            .collect::<Vec<_>>();
75        let chunk_index_partitioned = (0..PARTITION_SIZE)
76            .map(|_| Vec::with_capacity(OB_SIZE))
77            .collect::<Vec<_>>();
78
79        Self {
80            keys_partitioned,
81            aggs_partitioned,
82            hash_partitioned,
83            chunk_index_partitioned,
84            spilled: false,
85            finished_payloads: vec![],
86            keys_dtypes: self.keys_dtypes.clone(),
87            agg_dtypes: self.agg_dtypes.clone(),
88            output_schema: self.output_schema.clone(),
89        }
90    }
91}
92
impl SpillPartitions {
    /// Route one overflowing row (its hash, source chunk index, binary key
    /// row and one value per aggregation iterator) into its hash partition.
    ///
    /// Returns `Some((partition, payload))` when that partition's buffers
    /// just reached `OB_SIZE` and were flushed; `None` otherwise.
    fn insert(
        &mut self,
        hash: u64,
        chunk_idx: IdxSize,
        row: &[u8],
        agg_iters: &mut [SeriesPhysIter],
    ) -> Option<(usize, SpillPayload)> {
        let partition = hash_to_partition(hash, self.aggs_partitioned.len());
        self.spilled = true;
        // SAFETY: `hash_to_partition` returns an index below the given len,
        // and all partition vecs are built with the same length in `split`.
        // The caller guarantees `agg_iters` are not depleted (see
        // `ThreadLocalTable::insert`); `i` stays below `agg_values.len()`
        // only if `agg_iters` has one iterator per aggregated column —
        // caller contract, TODO confirm at call sites.
        unsafe {
            let agg_values = self.aggs_partitioned.get_unchecked_mut(partition);
            let hashes = self.hash_partitioned.get_unchecked_mut(partition);
            let chunk_indexes = self.chunk_index_partitioned.get_unchecked_mut(partition);
            let key_builder = self.keys_partitioned.get_unchecked_mut(partition);

            hashes.push(hash);
            chunk_indexes.push(chunk_idx);

            // amortize the loop counter
            key_builder.push(Some(row));
            for (i, agg) in agg_iters.iter_mut().enumerate() {
                let av = agg.next().unwrap_unchecked();
                let buf = agg_values.get_unchecked_mut(i);
                buf.add_unchecked_borrowed_physical(&av);
            }

            // Partition buffer is full: swap in fresh hash/chunk/key buffers
            // and drain the agg buffers, handing everything back as a payload.
            if hashes.len() >= OB_SIZE {
                let mut new_hashes = Vec::with_capacity(OB_SIZE);
                let mut new_chunk_indexes = Vec::with_capacity(OB_SIZE);
                let mut new_keys_builder = MutableBinaryArray::with_capacity(OB_SIZE);
                std::mem::swap(&mut new_hashes, hashes);
                std::mem::swap(&mut new_chunk_indexes, chunk_indexes);
                std::mem::swap(&mut new_keys_builder, key_builder);

                Some((
                    partition,
                    SpillPayload {
                        hashes: new_hashes,
                        chunk_idx: new_chunk_indexes,
                        keys: new_keys_builder.into(),
                        // NOTE(review): zips the agg buffers with the LEADING
                        // names of `output_schema` — verify the schema's
                        // column order puts the aggregated columns first.
                        aggs: agg_values
                            .iter_mut()
                            .zip(self.output_schema.iter_names())
                            .map(|(b, name)| {
                                // `reset` drains the buffer into a Series and
                                // re-reserves `OB_SIZE` capacity.
                                let mut s = b.reset(OB_SIZE, false).unwrap();
                                s.rename(name.clone());
                                s
                            })
                            .collect(),
                    },
                ))
            } else {
                None
            }
        }
    }

    /// Flush whatever is still sitting in the live partition buffers into
    /// `finished_payloads` (one bucket of payloads per partition).
    fn finish(&mut self) {
        if self.spilled {
            let all_spilled = self.get_all_spilled().collect::<Vec<_>>();
            for (partition_i, payload) in all_spilled {
                // `get_all_spilled` yields partitions in ascending order from
                // 0, so a missing index is exactly one past the end and a
                // push lands at position `partition_i`.
                let buf = if let Some(buf) = self.finished_payloads.get_mut(partition_i) {
                    buf
                } else {
                    self.finished_payloads.push(vec![]);
                    self.finished_payloads.last_mut().unwrap()
                };
                buf.push(payload)
            }
        }
    }

    /// Merge `other`'s spilled state into `self`; `other` is drained.
    fn combine(&mut self, other: &mut Self) {
        match (self.spilled, other.spilled) {
            // Only `other` spilled: adopt its state wholesale.
            (false, true) => std::mem::swap(self, other),
            (true, false) => {},
            (false, false) => {},
            (true, true) => {
                // Materialize both sides' in-flight buffers first so the
                // payload vecs line up partition-by-partition, then append.
                self.finish();
                other.finish();
                let other_payloads = std::mem::take(&mut other.finished_payloads);

                for (part_self, part_other) in self.finished_payloads.iter_mut().zip(other_payloads)
                {
                    part_self.extend(part_other)
                }
            },
        }
    }

    /// Drain everything spilled so far: the live per-partition buffers first
    /// (partitions 0..PARTITION_SIZE in order), then any previously finished
    /// payloads. Leaves all buffers empty.
    fn get_all_spilled(&mut self) -> impl Iterator<Item = (usize, SpillPayload)> + '_ {
        // todo! allocate
        let mut flattened = vec![];
        let finished_payloads = std::mem::take(&mut self.finished_payloads);
        for (part, payloads) in finished_payloads.into_iter().enumerate() {
            for payload in payloads {
                flattened.push((part, payload))
            }
        }

        (0..PARTITION_SIZE)
            // SAFETY: every partition vec was built with PARTITION_SIZE
            // entries in `split`, so `partition` is always in bounds.
            .map(|partition| unsafe {
                let spilled_aggs = self.aggs_partitioned.get_unchecked_mut(partition);
                let hashes = self.hash_partitioned.get_unchecked_mut(partition);
                let chunk_indexes = self.chunk_index_partitioned.get_unchecked_mut(partition);
                let keys_builder =
                    std::mem::take(self.keys_partitioned.get_unchecked_mut(partition));
                let hashes = std::mem::take(hashes);
                let chunk_idx = std::mem::take(chunk_indexes);

                (
                    partition,
                    SpillPayload {
                        hashes,
                        chunk_idx,
                        keys: keys_builder.into(),
                        aggs: spilled_aggs
                            .iter_mut()
                            .map(|b| b.reset(0, false).unwrap())
                            .collect(),
                    },
                )
            })
            .chain(flattened)
    }
}
221
/// Per-thread state for the generic streaming group-by sink: a bounded
/// in-memory aggregation hash table plus spill buffers for rows it rejects.
pub(super) struct ThreadLocalTable {
    // Primary hash table; constructed with `Some(*SPILL_SIZE)` as its limit.
    inner_map: AggHashTable<true>,
    // Receives rows whenever `inner_map.insert` reports they must spill.
    spill_partitions: SpillPartitions,
}
226
227impl ThreadLocalTable {
228    pub(super) fn new(
229        agg_constructors: Arc<[AggregateFunction]>,
230        key_dtypes: Arc<[DataType]>,
231        agg_dtypes: Arc<[DataType]>,
232        output_schema: SchemaRef,
233    ) -> Self {
234        let spill_partitions =
235            SpillPartitions::new(key_dtypes.clone(), agg_dtypes, output_schema.clone());
236
237        Self {
238            inner_map: AggHashTable::new(
239                agg_constructors,
240                key_dtypes.as_ref(),
241                output_schema,
242                Some(*SPILL_SIZE),
243            ),
244            spill_partitions,
245        }
246    }
247
248    pub(super) fn split(&self) -> Self {
249        // should be called before any chunk is processed
250        debug_assert!(self.inner_map.is_empty());
251
252        Self {
253            inner_map: self.inner_map.split(),
254            spill_partitions: self.spill_partitions.clone(),
255        }
256    }
257
258    pub(super) fn get_inner_map_mut(&mut self) -> &mut AggHashTable<true> {
259        &mut self.inner_map
260    }
261
262    /// # Safety
263    /// Caller must ensure that `keys` and `agg_iters` are not depleted.
264    #[inline]
265    pub(super) unsafe fn insert(
266        &mut self,
267        hash: u64,
268        keys_row: &[u8],
269        agg_iters: &mut [SeriesPhysIter],
270        chunk_index: IdxSize,
271    ) -> Option<(usize, SpillPayload)> {
272        if self
273            .inner_map
274            .insert(hash, keys_row, agg_iters, chunk_index)
275        {
276            self.spill_partitions
277                .insert(hash, chunk_index, keys_row, agg_iters)
278        } else {
279            None
280        }
281    }
282
283    pub(super) fn combine(&mut self, other: &mut Self) {
284        self.inner_map.combine(&other.inner_map);
285        self.spill_partitions.combine(&mut other.spill_partitions);
286    }
287
288    pub(super) fn finalize(&mut self, slice: &mut Option<(i64, usize)>) -> Option<DataFrame> {
289        if !self.spill_partitions.spilled {
290            Some(self.inner_map.finalize(slice))
291        } else {
292            None
293        }
294    }
295
296    pub(super) fn get_all_spilled(&mut self) -> impl Iterator<Item = (usize, SpillPayload)> + '_ {
297        self.spill_partitions.get_all_spilled()
298    }
299}