polars_core/chunked_array/ops/set.rs

1use arrow::bitmap::MutableBitmap;
2use arrow::legacy::kernels::set::{scatter_single_non_null, set_with_mask};
3
4use crate::prelude::*;
5use crate::utils::align_chunks_binary;
6
// Shared body of the builder-based `scatter_with` implementations below.
//
// Rebuilds `$self` into `$builder`:
// - every position listed in `$idx` receives `$f(old_value)`;
// - every other position is copied through unchanged.
//
// The rebuild is a single forward pass over `$self`, so the target indices are first sorted
// and de-duplicated. Callers may therefore pass indices in any order, and `$f` runs at most
// once per position.
//
// Evaluates to `PolarsResult<_>` holding the finished array. If any index is
// `>= $self.len()`, it returns an out-of-bounds error from the enclosing function before
// anything is appended.
macro_rules! impl_scatter_with {
    ($self:ident, $builder:ident, $idx:ident, $f:ident) => {{
        let len = $self.len();

        // A forward merge can only hit targets that are ascending and unique.
        let mut targets: Vec<usize> = $idx.into_iter().map(|i| i as usize).collect();
        targets.sort_unstable();
        targets.dedup();
        if let Some(&max_idx) = targets.last() {
            polars_ensure!(max_idx < len, oob = max_idx, len);
        }

        let mut targets = targets.into_iter().peekable();
        for (row, opt_val) in $self.into_iter().enumerate() {
            if targets.next_if_eq(&row).is_some() {
                $builder.append_option($f(opt_val));
            } else {
                $builder.append_option(opt_val);
            }
        }

        let ca = $builder.finish();
        Ok(ca)
    }};
}
31
// Ensure a boolean mask has exactly as many entries as the array it is applied to.
// Only the `set` implementations in this file use it, so the error names that operation.
// Returns a `ShapeMismatch` error from the enclosing function on a length mismatch.
macro_rules! check_bounds {
    ($self:ident, $mask:ident) => {{
        polars_ensure!(
            $self.len() == $mask.len(),
            ShapeMismatch: "invalid mask in `set` operation: shape doesn't match array's shape"
        );
    }};
}
40
41impl<'a, T> ChunkSet<'a, T::Native, T::Native> for ChunkedArray<T>
42where
43    T: PolarsNumericType,
44{
45    fn scatter_single<I: IntoIterator<Item = IdxSize>>(
46        &'a self,
47        idx: I,
48        value: Option<T::Native>,
49    ) -> PolarsResult<Self> {
50        if !self.has_nulls() {
51            if let Some(value) = value {
52                // Fast path uses kernel.
53                if self.chunks.len() == 1 {
54                    let arr = scatter_single_non_null(
55                        self.downcast_iter().next().unwrap(),
56                        idx,
57                        value,
58                        T::get_dtype().to_arrow(CompatLevel::newest()),
59                    )?;
60                    return Ok(Self::with_chunk(self.name().clone(), arr));
61                }
62                // Other fast path. Slightly slower as it does not do a memcpy.
63                else {
64                    let mut av = self.into_no_null_iter().collect::<Vec<_>>();
65                    let data = av.as_mut_slice();
66
67                    idx.into_iter().try_for_each::<_, PolarsResult<_>>(|idx| {
68                        let val = data
69                            .get_mut(idx as usize)
70                            .ok_or_else(|| polars_err!(oob = idx as usize, self.len()))?;
71                        *val = value;
72                        Ok(())
73                    })?;
74                    return Ok(Self::from_vec(self.name().clone(), av));
75                }
76            }
77        }
78        self.scatter_with(idx, |_| value)
79    }
80
81    fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
82        &'a self,
83        idx: I,
84        f: F,
85    ) -> PolarsResult<Self>
86    where
87        F: Fn(Option<T::Native>) -> Option<T::Native>,
88    {
89        let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), self.len());
90        impl_scatter_with!(self, builder, idx, f)
91    }
92
93    fn set(&'a self, mask: &BooleanChunked, value: Option<T::Native>) -> PolarsResult<Self> {
94        check_bounds!(self, mask);
95
96        // Fast path uses the kernel in polars-arrow.
97        if let (Some(value), false) = (value, mask.has_nulls()) {
98            let (left, mask) = align_chunks_binary(self, mask);
99
100            // Apply binary kernel.
101            let chunks = left
102                .downcast_iter()
103                .zip(mask.downcast_iter())
104                .map(|(arr, mask)| {
105                    set_with_mask(
106                        arr,
107                        mask,
108                        value,
109                        T::get_dtype().to_arrow(CompatLevel::newest()),
110                    )
111                });
112            Ok(ChunkedArray::from_chunk_iter(self.name().clone(), chunks))
113        } else {
114            // slow path, could be optimized.
115            let ca = mask
116                .into_iter()
117                .zip(self)
118                .map(|(mask_val, opt_val)| match mask_val {
119                    Some(true) => value,
120                    _ => opt_val,
121                })
122                .collect_trusted::<Self>()
123                .with_name(self.name().clone());
124            Ok(ca)
125        }
126    }
127}
128
129impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked {
130    fn scatter_single<I: IntoIterator<Item = IdxSize>>(
131        &'a self,
132        idx: I,
133        value: Option<bool>,
134    ) -> PolarsResult<Self> {
135        self.scatter_with(idx, |_| value)
136    }
137
138    fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
139        &'a self,
140        idx: I,
141        f: F,
142    ) -> PolarsResult<Self>
143    where
144        F: Fn(Option<bool>) -> Option<bool>,
145    {
146        let mut values = MutableBitmap::with_capacity(self.len());
147        let mut validity = MutableBitmap::with_capacity(self.len());
148
149        for a in self.downcast_iter() {
150            values.extend_from_bitmap(a.values());
151            if let Some(v) = a.validity() {
152                validity.extend_from_bitmap(v)
153            } else {
154                validity.extend_constant(a.len(), true);
155            }
156        }
157
158        for i in idx.into_iter().map(|i| i as usize) {
159            let input = validity.get(i).then(|| values.get(i));
160            validity.set(i, f(input).unwrap_or(false));
161        }
162        let arr = BooleanArray::from_data_default(values.into(), Some(validity.into()));
163        Ok(BooleanChunked::with_chunk(self.name().clone(), arr))
164    }
165
166    fn set(&'a self, mask: &BooleanChunked, value: Option<bool>) -> PolarsResult<Self> {
167        check_bounds!(self, mask);
168        let ca = mask
169            .into_iter()
170            .zip(self)
171            .map(|(mask_val, opt_val)| match mask_val {
172                Some(true) => value,
173                _ => opt_val,
174            })
175            .collect_trusted::<Self>()
176            .with_name(self.name().clone());
177        Ok(ca)
178    }
179}
180
181impl<'a> ChunkSet<'a, &'a str, String> for StringChunked {
182    fn scatter_single<I: IntoIterator<Item = IdxSize>>(
183        &'a self,
184        idx: I,
185        opt_value: Option<&'a str>,
186    ) -> PolarsResult<Self>
187    where
188        Self: Sized,
189    {
190        let idx_iter = idx.into_iter();
191        let mut ca_iter = self.into_iter().enumerate();
192        let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
193
194        for current_idx in idx_iter.into_iter().map(|i| i as usize) {
195            polars_ensure!(current_idx < self.len(), oob = current_idx, self.len());
196            for (cnt_idx, opt_val_self) in &mut ca_iter {
197                if cnt_idx == current_idx {
198                    builder.append_option(opt_value);
199                    break;
200                } else {
201                    builder.append_option(opt_val_self);
202                }
203            }
204        }
205        // the last idx is probably not the last value so we finish the iterator
206        for (_, opt_val_self) in ca_iter {
207            builder.append_option(opt_val_self);
208        }
209
210        let ca = builder.finish();
211        Ok(ca)
212    }
213
214    fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
215        &'a self,
216        idx: I,
217        f: F,
218    ) -> PolarsResult<Self>
219    where
220        Self: Sized,
221        F: Fn(Option<&'a str>) -> Option<String>,
222    {
223        let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
224        impl_scatter_with!(self, builder, idx, f)
225    }
226
227    fn set(&'a self, mask: &BooleanChunked, value: Option<&'a str>) -> PolarsResult<Self>
228    where
229        Self: Sized,
230    {
231        check_bounds!(self, mask);
232        let ca = mask
233            .into_iter()
234            .zip(self)
235            .map(|(mask_val, opt_val)| match mask_val {
236                Some(true) => value,
237                _ => opt_val,
238            })
239            .collect_trusted::<Self>()
240            .with_name(self.name().clone());
241        Ok(ca)
242    }
243}
244
245impl<'a> ChunkSet<'a, &'a [u8], Vec<u8>> for BinaryChunked {
246    fn scatter_single<I: IntoIterator<Item = IdxSize>>(
247        &'a self,
248        idx: I,
249        opt_value: Option<&'a [u8]>,
250    ) -> PolarsResult<Self>
251    where
252        Self: Sized,
253    {
254        let mut ca_iter = self.into_iter().enumerate();
255        let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len());
256
257        for current_idx in idx.into_iter().map(|i| i as usize) {
258            polars_ensure!(current_idx < self.len(), oob = current_idx, self.len());
259            for (cnt_idx, opt_val_self) in &mut ca_iter {
260                if cnt_idx == current_idx {
261                    builder.append_option(opt_value);
262                    break;
263                } else {
264                    builder.append_option(opt_val_self);
265                }
266            }
267        }
268        // the last idx is probably not the last value so we finish the iterator
269        for (_, opt_val_self) in ca_iter {
270            builder.append_option(opt_val_self);
271        }
272
273        let ca = builder.finish();
274        Ok(ca)
275    }
276
277    fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
278        &'a self,
279        idx: I,
280        f: F,
281    ) -> PolarsResult<Self>
282    where
283        Self: Sized,
284        F: Fn(Option<&'a [u8]>) -> Option<Vec<u8>>,
285    {
286        let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len());
287        impl_scatter_with!(self, builder, idx, f)
288    }
289
290    fn set(&'a self, mask: &BooleanChunked, value: Option<&'a [u8]>) -> PolarsResult<Self>
291    where
292        Self: Sized,
293    {
294        check_bounds!(self, mask);
295        let ca = mask
296            .into_iter()
297            .zip(self)
298            .map(|(mask_val, opt_val)| match mask_val {
299                Some(true) => value,
300                _ => opt_val,
301            })
302            .collect_trusted::<Self>()
303            .with_name(self.name().clone());
304        Ok(ca)
305    }
306}
307
#[cfg(test)]
mod test {
    use crate::prelude::*;

    // Every column in these tests has a static name.
    fn name(s: &'static str) -> PlSmallStr {
        PlSmallStr::from_static(s)
    }

    #[test]
    fn test_set() {
        // Numeric: only `true` mask entries are overwritten.
        let plain_mask = BooleanChunked::new(name("mask"), &[false, true, false]);
        let out = Int32Chunked::new(name("a"), &[1, 2, 3])
            .set(&plain_mask, Some(5))
            .unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(5), Some(3)]);

        // Null mask entries leave the value untouched.
        let sparse_mask = BooleanChunked::new(name("mask"), &[None, Some(true), None]);
        let out = Int32Chunked::new(name("a"), &[1, 2, 3])
            .set(&sparse_mask, Some(5))
            .unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(5), Some(3)]);

        let null_mask = BooleanChunked::new(name("mask"), &[None, None, None]);
        let out = Int32Chunked::new(name("a"), &[1, 2, 3])
            .set(&null_mask, Some(5))
            .unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(2), Some(3)]);

        let mixed_mask = BooleanChunked::new(name("mask"), &[Some(true), Some(false), None]);
        let ca = Int32Chunked::new(name("a"), &[1, 2, 3])
            .set(&mixed_mask, Some(5))
            .unwrap();
        assert_eq!(Vec::from(&ca), &[Some(5), Some(2), Some(3)]);

        // Scatter into existing positions; out-of-bounds indices are rejected.
        let ca = ca.scatter_single(vec![0, 1], Some(10)).unwrap();
        assert_eq!(Vec::from(&ca), &[Some(10), Some(10), Some(3)]);
        assert!(ca.scatter_single(vec![0, 10], Some(0)).is_err());

        // Booleans: a `None` fill value writes a null.
        let flags = BooleanChunked::new(name("a"), &[true, true, true]);
        let flags = flags.set(&plain_mask, None).unwrap();
        assert_eq!(Vec::from(&flags), &[Some(true), None, Some(true)]);

        // Strings.
        let words = StringChunked::new(name("a"), &["foo", "foo", "foo"]);
        let words = words.set(&plain_mask, Some("bar")).unwrap();
        assert_eq!(Vec::from(&words), &[Some("foo"), Some("bar"), Some("foo")]);
    }

    #[test]
    fn test_set_null_values() {
        // The `true` entry replaces a null; the null entry keeps the original value.
        let mask = BooleanChunked::new(name("mask"), &[Some(false), Some(true), None]);

        let ints = Int32Chunked::new(name("a"), &[Some(1), None, Some(3)]);
        let ints = ints.set(&mask, Some(2)).unwrap();
        assert_eq!(Vec::from(&ints), &[Some(1), Some(2), Some(3)]);

        let strs = StringChunked::new(name("a"), &[Some("foo"), None, Some("bar")]);
        let strs = strs.set(&mask, Some("foo")).unwrap();
        assert_eq!(Vec::from(&strs), &[Some("foo"), Some("foo"), Some("bar")]);

        let bools = BooleanChunked::new(name("a"), &[Some(false), None, Some(true)]);
        let bools = bools.set(&mask, Some(true)).unwrap();
        assert_eq!(Vec::from(&bools), &[Some(false), Some(true), Some(true)]);
    }
}