// polars_pipe/executors/sinks/group_by/generic/thread_local.rs
use std::sync::LazyLock;
2
3use arrow::array::MutableBinaryArray;
4use polars_utils::hashing::hash_to_partition;
5
6use super::*;
7use crate::pipeline::PARTITION_SIZE;
8
/// Capacity of each per-partition spill buffer; once a partition holds this
/// many rows it is flushed as a `SpillPayload`.
const OB_SIZE: usize = 2048;

/// Row threshold passed to the inner aggregation hash table; overridable via
/// the `POLARS_STREAMING_GROUPBY_SPILL_SIZE` environment variable.
static SPILL_SIZE: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("POLARS_STREAMING_GROUPBY_SPILL_SIZE")
        .map(|v| {
            // A set-but-malformed value is a configuration bug; fail loudly
            // with an actionable message instead of a bare unwrap panic.
            v.parse::<usize>()
                .expect("POLARS_STREAMING_GROUPBY_SPILL_SIZE must be a valid usize")
        })
        // Unset (or non-unicode) env var -> default threshold.
        .unwrap_or(10_000)
});
16
/// Thread-local, hash-partitioned overflow buffers for the generic group-by
/// sink. Rows that the in-memory table rejects are routed (by hash) into one
/// of the partitions; a partition that reaches `OB_SIZE` rows is flushed as a
/// `SpillPayload`.
#[derive(Clone)]
struct SpillPartitions {
    // Row-encoded group keys, one opaque binary blob per buffered row.
    keys_partitioned: PartitionVec<MutableBinaryArray<i64>>,
    // One value buffer per aggregation column, for every partition.
    aggs_partitioned: PartitionVec<Vec<AnyValueBufferTrusted<'static>>>,
    // Hash of every buffered row, parallel to the other buffers.
    hash_partitioned: PartitionVec<Vec<u64>>,
    // Source chunk index of every buffered row.
    chunk_index_partitioned: PartitionVec<Vec<IdxSize>>,
    // Set on the first `insert`; `false` means nothing was ever spilled.
    spilled: bool,
    // Payloads flushed by `finish`, grouped per partition.
    finished_payloads: PartitionVec<Vec<SpillPayload>>,
    // Dtypes of the key columns (shared, not per-partition).
    keys_dtypes: Arc<[DataType]>,
    // Dtypes of the aggregation columns.
    agg_dtypes: Arc<[DataType]>,
    // Schema used to name the emitted aggregation series.
    output_schema: SchemaRef,
}
33
34impl SpillPartitions {
35 fn new(keys: Arc<[DataType]>, aggs: Arc<[DataType]>, output_schema: SchemaRef) -> Self {
36 let hash_partitioned = vec![];
37 let chunk_index_partitioned = vec![];
38
39 Self {
41 keys_partitioned: vec![],
42 aggs_partitioned: vec![],
43 hash_partitioned,
44 chunk_index_partitioned,
45 spilled: false,
46 finished_payloads: vec![],
47 keys_dtypes: keys,
48 agg_dtypes: aggs,
49 output_schema,
50 }
51 .split()
52 }
53
54 fn split(&self) -> Self {
55 let n_columns = self.agg_dtypes.as_ref().len();
56
57 let aggs_partitioned = (0..PARTITION_SIZE)
58 .map(|_| {
59 let mut buf = Vec::with_capacity(n_columns);
60 for dtype in self.agg_dtypes.as_ref() {
61 let builder = AnyValueBufferTrusted::new(&dtype.to_physical(), OB_SIZE);
62 buf.push(builder);
63 }
64 buf
65 })
66 .collect();
67
68 let keys_partitioned = (0..PARTITION_SIZE)
69 .map(|_| MutableBinaryArray::with_capacity(OB_SIZE))
70 .collect();
71
72 let hash_partitioned = (0..PARTITION_SIZE)
73 .map(|_| Vec::with_capacity(OB_SIZE))
74 .collect::<Vec<_>>();
75 let chunk_index_partitioned = (0..PARTITION_SIZE)
76 .map(|_| Vec::with_capacity(OB_SIZE))
77 .collect::<Vec<_>>();
78
79 Self {
80 keys_partitioned,
81 aggs_partitioned,
82 hash_partitioned,
83 chunk_index_partitioned,
84 spilled: false,
85 finished_payloads: vec![],
86 keys_dtypes: self.keys_dtypes.clone(),
87 agg_dtypes: self.agg_dtypes.clone(),
88 output_schema: self.output_schema.clone(),
89 }
90 }
91}
92
impl SpillPartitions {
    /// Buffer one row into the spill partition selected by `hash`.
    ///
    /// Returns `Some((partition_idx, payload))` when that partition's buffer
    /// reached `OB_SIZE` rows and was flushed; `None` otherwise.
    fn insert(
        &mut self,
        hash: u64,
        chunk_idx: IdxSize,
        row: &[u8],
        agg_iters: &mut [SeriesPhysIter],
    ) -> Option<(usize, SpillPayload)> {
        let partition = hash_to_partition(hash, self.aggs_partitioned.len());
        self.spilled = true;
        // SAFETY: `hash_to_partition` yields an index < `aggs_partitioned.len()`,
        // and all `*_partitioned` vecs are allocated with the same length
        // (PARTITION_SIZE) in `split`, so the unchecked indexing is in bounds.
        unsafe {
            let agg_values = self.aggs_partitioned.get_unchecked_mut(partition);
            let hashes = self.hash_partitioned.get_unchecked_mut(partition);
            let chunk_indexes = self.chunk_index_partitioned.get_unchecked_mut(partition);
            let key_builder = self.keys_partitioned.get_unchecked_mut(partition);

            hashes.push(hash);
            chunk_indexes.push(chunk_idx);

            // The key row is stored as one opaque binary blob.
            key_builder.push(Some(row));
            // Consume exactly one value from every aggregation iterator.
            // NOTE(review): `unwrap_unchecked` assumes each iterator still
            // yields a value for this row — the caller must guarantee this.
            for (i, agg) in agg_iters.iter_mut().enumerate() {
                let av = agg.next().unwrap_unchecked();
                let buf = agg_values.get_unchecked_mut(i);
                buf.add_unchecked_borrowed_physical(&av);
            }

            if hashes.len() >= OB_SIZE {
                // Partition buffer is full: swap in fresh empty buffers and
                // emit the old contents as a payload.
                let mut new_hashes = Vec::with_capacity(OB_SIZE);
                let mut new_chunk_indexes = Vec::with_capacity(OB_SIZE);
                let mut new_keys_builder = MutableBinaryArray::with_capacity(OB_SIZE);
                std::mem::swap(&mut new_hashes, hashes);
                std::mem::swap(&mut new_chunk_indexes, chunk_indexes);
                std::mem::swap(&mut new_keys_builder, key_builder);

                Some((
                    partition,
                    SpillPayload {
                        hashes: new_hashes,
                        chunk_idx: new_chunk_indexes,
                        keys: new_keys_builder.into(),
                        // Drain each agg builder into a series named after the
                        // output schema. NOTE(review): `reset(OB_SIZE, false)`
                        // presumably re-seeds the builder with OB_SIZE capacity
                        // — confirm the flag's meaning against the builder API.
                        aggs: agg_values
                            .iter_mut()
                            .zip(self.output_schema.iter_names())
                            .map(|(b, name)| {
                                let mut s = b.reset(OB_SIZE, false).unwrap();
                                s.rename(name.clone());
                                s
                            })
                            .collect(),
                    },
                ))
            } else {
                None
            }
        }
    }

    /// Flush any remaining buffered rows into `finished_payloads`.
    /// No-op when nothing was ever spilled.
    fn finish(&mut self) {
        if self.spilled {
            let all_spilled = self.get_all_spilled().collect::<Vec<_>>();
            for (partition_i, payload) in all_spilled {
                // `get_all_spilled` yields partitions in ascending order
                // starting at 0, so a missing `partition_i` always equals the
                // current length and the push lands at the correct index.
                let buf = if let Some(buf) = self.finished_payloads.get_mut(partition_i) {
                    buf
                } else {
                    self.finished_payloads.push(vec![]);
                    self.finished_payloads.last_mut().unwrap()
                };
                buf.push(payload)
            }
        }
    }

    /// Merge `other`'s spilled state into `self`.
    fn combine(&mut self, other: &mut Self) {
        match (self.spilled, other.spilled) {
            // Only `other` has spilled data: adopt its state wholesale.
            (false, true) => std::mem::swap(self, other),
            (true, false) => {},
            (false, false) => {},
            (true, true) => {
                // Flush both sides, then append `other`'s payloads partition-wise.
                self.finish();
                other.finish();
                let other_payloads = std::mem::take(&mut other.finished_payloads);

                for (part_self, part_other) in self.finished_payloads.iter_mut().zip(other_payloads)
                {
                    part_self.extend(part_other)
                }
            },
        }
    }

    /// Drain all partition buffers — and any previously finished payloads —
    /// as `(partition_index, payload)` pairs. Partitions 0..PARTITION_SIZE
    /// are yielded first, in ascending order, followed by the flattened
    /// finished payloads.
    fn get_all_spilled(&mut self) -> impl Iterator<Item = (usize, SpillPayload)> + '_ {
        // Flatten already-finished payloads into (partition, payload) pairs.
        let mut flattened = vec![];
        let finished_payloads = std::mem::take(&mut self.finished_payloads);
        for (part, payloads) in finished_payloads.into_iter().enumerate() {
            for payload in payloads {
                flattened.push((part, payload))
            }
        }

        (0..PARTITION_SIZE)
            // SAFETY: every `*_partitioned` vec has length PARTITION_SIZE (see `split`).
            .map(|partition| unsafe {
                let spilled_aggs = self.aggs_partitioned.get_unchecked_mut(partition);
                let hashes = self.hash_partitioned.get_unchecked_mut(partition);
                let chunk_indexes = self.chunk_index_partitioned.get_unchecked_mut(partition);
                let keys_builder =
                    std::mem::take(self.keys_partitioned.get_unchecked_mut(partition));
                let hashes = std::mem::take(hashes);
                let chunk_idx = std::mem::take(chunk_indexes);

                (
                    partition,
                    SpillPayload {
                        hashes,
                        chunk_idx,
                        keys: keys_builder.into(),
                        aggs: spilled_aggs
                            .iter_mut()
                            .map(|b| b.reset(0, false).unwrap())
                            .collect(),
                    },
                )
            })
            .chain(flattened)
    }
}
221
/// Per-thread group-by state: the primary in-memory aggregation hash table
/// plus hash-partitioned spill buffers that receive rows the table rejects
/// (it is constructed with a `Some(*SPILL_SIZE)` limit).
pub(super) struct ThreadLocalTable {
    // Primary in-memory aggregation table.
    inner_map: AggHashTable<true>,
    // Overflow buffers for rows `inner_map` refuses to absorb.
    spill_partitions: SpillPartitions,
}
226
227impl ThreadLocalTable {
228 pub(super) fn new(
229 agg_constructors: Arc<[AggregateFunction]>,
230 key_dtypes: Arc<[DataType]>,
231 agg_dtypes: Arc<[DataType]>,
232 output_schema: SchemaRef,
233 ) -> Self {
234 let spill_partitions =
235 SpillPartitions::new(key_dtypes.clone(), agg_dtypes, output_schema.clone());
236
237 Self {
238 inner_map: AggHashTable::new(
239 agg_constructors,
240 key_dtypes.as_ref(),
241 output_schema,
242 Some(*SPILL_SIZE),
243 ),
244 spill_partitions,
245 }
246 }
247
248 pub(super) fn split(&self) -> Self {
249 debug_assert!(self.inner_map.is_empty());
251
252 Self {
253 inner_map: self.inner_map.split(),
254 spill_partitions: self.spill_partitions.clone(),
255 }
256 }
257
258 pub(super) fn get_inner_map_mut(&mut self) -> &mut AggHashTable<true> {
259 &mut self.inner_map
260 }
261
262 #[inline]
265 pub(super) unsafe fn insert(
266 &mut self,
267 hash: u64,
268 keys_row: &[u8],
269 agg_iters: &mut [SeriesPhysIter],
270 chunk_index: IdxSize,
271 ) -> Option<(usize, SpillPayload)> {
272 if self
273 .inner_map
274 .insert(hash, keys_row, agg_iters, chunk_index)
275 {
276 self.spill_partitions
277 .insert(hash, chunk_index, keys_row, agg_iters)
278 } else {
279 None
280 }
281 }
282
283 pub(super) fn combine(&mut self, other: &mut Self) {
284 self.inner_map.combine(&other.inner_map);
285 self.spill_partitions.combine(&mut other.spill_partitions);
286 }
287
288 pub(super) fn finalize(&mut self, slice: &mut Option<(i64, usize)>) -> Option<DataFrame> {
289 if !self.spill_partitions.spilled {
290 Some(self.inner_map.finalize(slice))
291 } else {
292 None
293 }
294 }
295
296 pub(super) fn get_all_spilled(&mut self) -> impl Iterator<Item = (usize, SpillPayload)> + '_ {
297 self.spill_partitions.get_all_spilled()
298 }
299}