1use arrow::bitmap::MutableBitmap;
2use arrow::legacy::kernels::set::{scatter_single_non_null, set_with_mask};
3
4use crate::prelude::*;
5use crate::utils::align_chunks_binary;
6
// Shared body for the builder-based `scatter_with` implementations.
//
// Arguments (all identifiers):
// * `$self`    - the array being read; must provide `len()` and `into_iter()`.
// * `$builder` - a mutable builder offering `append_option` and `finish`.
// * `$idx`     - an `IntoIterator` of indices; each one is cast to `usize`.
// * `$f`       - callable applied to the current optional value at a hit.
//
// The source array is walked exactly once, front to back, and every element
// is copied into the builder; elements whose position equals the index
// currently being looked for are replaced by `$f(element)`.
//
// Because the walk never rewinds, an index that lies at or before a position
// that was already consumed (unsorted or repeated indices) is never matched
// and therefore has no effect. An index `>= $self.len()` makes the enclosing
// function return an out-of-bounds error.
//
// Evaluates to `Ok($builder.finish())`.
macro_rules! impl_scatter_with {
    ($self:ident, $builder:ident, $idx:ident, $f:ident) => {{
        let mut remaining = $self.into_iter().enumerate();

        for target in $idx.into_iter().map(|i| i as usize) {
            polars_ensure!(target < $self.len(), oob = target, $self.len());
            // Copy everything in front of `target` untouched, then rewrite the hit.
            for (position, current) in remaining.by_ref() {
                if position != target {
                    $builder.append_option(current);
                    continue;
                }
                $builder.append_option($f(current));
                break;
            }
        }
        // Whatever follows the last requested index is copied verbatim.
        for (_, current) in remaining {
            $builder.append_option(current);
        }

        Ok($builder.finish())
    }};
}
31
// Guard used by the `set` implementations in this file: makes the enclosing
// function return a `ShapeMismatch` error unless `$mask` has exactly as many
// elements as `$self`.
//
// The message names the `set` operation because that is the only operation
// this macro is invoked from here (it previously said `get`).
macro_rules! check_bounds {
    ($self:ident, $mask:ident) => {{
        polars_ensure!(
            $self.len() == $mask.len(),
            ShapeMismatch: "invalid mask in `set` operation: shape doesn't match array's shape"
        );
    }};
}
40
41impl<'a, T> ChunkSet<'a, T::Native, T::Native> for ChunkedArray<T>
42where
43 T: PolarsNumericType,
44{
45 fn scatter_single<I: IntoIterator<Item = IdxSize>>(
46 &'a self,
47 idx: I,
48 value: Option<T::Native>,
49 ) -> PolarsResult<Self> {
50 if !self.has_nulls() {
51 if let Some(value) = value {
52 if self.chunks.len() == 1 {
54 let arr = scatter_single_non_null(
55 self.downcast_iter().next().unwrap(),
56 idx,
57 value,
58 T::get_dtype().to_arrow(CompatLevel::newest()),
59 )?;
60 return Ok(Self::with_chunk(self.name().clone(), arr));
61 }
62 else {
64 let mut av = self.into_no_null_iter().collect::<Vec<_>>();
65 let data = av.as_mut_slice();
66
67 idx.into_iter().try_for_each::<_, PolarsResult<_>>(|idx| {
68 let val = data
69 .get_mut(idx as usize)
70 .ok_or_else(|| polars_err!(oob = idx as usize, self.len()))?;
71 *val = value;
72 Ok(())
73 })?;
74 return Ok(Self::from_vec(self.name().clone(), av));
75 }
76 }
77 }
78 self.scatter_with(idx, |_| value)
79 }
80
81 fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
82 &'a self,
83 idx: I,
84 f: F,
85 ) -> PolarsResult<Self>
86 where
87 F: Fn(Option<T::Native>) -> Option<T::Native>,
88 {
89 let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), self.len());
90 impl_scatter_with!(self, builder, idx, f)
91 }
92
93 fn set(&'a self, mask: &BooleanChunked, value: Option<T::Native>) -> PolarsResult<Self> {
94 check_bounds!(self, mask);
95
96 if let (Some(value), false) = (value, mask.has_nulls()) {
98 let (left, mask) = align_chunks_binary(self, mask);
99
100 let chunks = left
102 .downcast_iter()
103 .zip(mask.downcast_iter())
104 .map(|(arr, mask)| {
105 set_with_mask(
106 arr,
107 mask,
108 value,
109 T::get_dtype().to_arrow(CompatLevel::newest()),
110 )
111 });
112 Ok(ChunkedArray::from_chunk_iter(self.name().clone(), chunks))
113 } else {
114 let ca = mask
116 .into_iter()
117 .zip(self)
118 .map(|(mask_val, opt_val)| match mask_val {
119 Some(true) => value,
120 _ => opt_val,
121 })
122 .collect_trusted::<Self>()
123 .with_name(self.name().clone());
124 Ok(ca)
125 }
126 }
127}
128
129impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked {
130 fn scatter_single<I: IntoIterator<Item = IdxSize>>(
131 &'a self,
132 idx: I,
133 value: Option<bool>,
134 ) -> PolarsResult<Self> {
135 self.scatter_with(idx, |_| value)
136 }
137
138 fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
139 &'a self,
140 idx: I,
141 f: F,
142 ) -> PolarsResult<Self>
143 where
144 F: Fn(Option<bool>) -> Option<bool>,
145 {
146 let mut values = MutableBitmap::with_capacity(self.len());
147 let mut validity = MutableBitmap::with_capacity(self.len());
148
149 for a in self.downcast_iter() {
150 values.extend_from_bitmap(a.values());
151 if let Some(v) = a.validity() {
152 validity.extend_from_bitmap(v)
153 } else {
154 validity.extend_constant(a.len(), true);
155 }
156 }
157
158 for i in idx.into_iter().map(|i| i as usize) {
159 let input = validity.get(i).then(|| values.get(i));
160 validity.set(i, f(input).unwrap_or(false));
161 }
162 let arr = BooleanArray::from_data_default(values.into(), Some(validity.into()));
163 Ok(BooleanChunked::with_chunk(self.name().clone(), arr))
164 }
165
166 fn set(&'a self, mask: &BooleanChunked, value: Option<bool>) -> PolarsResult<Self> {
167 check_bounds!(self, mask);
168 let ca = mask
169 .into_iter()
170 .zip(self)
171 .map(|(mask_val, opt_val)| match mask_val {
172 Some(true) => value,
173 _ => opt_val,
174 })
175 .collect_trusted::<Self>()
176 .with_name(self.name().clone());
177 Ok(ca)
178 }
179}
180
181impl<'a> ChunkSet<'a, &'a str, String> for StringChunked {
182 fn scatter_single<I: IntoIterator<Item = IdxSize>>(
183 &'a self,
184 idx: I,
185 opt_value: Option<&'a str>,
186 ) -> PolarsResult<Self>
187 where
188 Self: Sized,
189 {
190 let idx_iter = idx.into_iter();
191 let mut ca_iter = self.into_iter().enumerate();
192 let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
193
194 for current_idx in idx_iter.into_iter().map(|i| i as usize) {
195 polars_ensure!(current_idx < self.len(), oob = current_idx, self.len());
196 for (cnt_idx, opt_val_self) in &mut ca_iter {
197 if cnt_idx == current_idx {
198 builder.append_option(opt_value);
199 break;
200 } else {
201 builder.append_option(opt_val_self);
202 }
203 }
204 }
205 for (_, opt_val_self) in ca_iter {
207 builder.append_option(opt_val_self);
208 }
209
210 let ca = builder.finish();
211 Ok(ca)
212 }
213
214 fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
215 &'a self,
216 idx: I,
217 f: F,
218 ) -> PolarsResult<Self>
219 where
220 Self: Sized,
221 F: Fn(Option<&'a str>) -> Option<String>,
222 {
223 let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
224 impl_scatter_with!(self, builder, idx, f)
225 }
226
227 fn set(&'a self, mask: &BooleanChunked, value: Option<&'a str>) -> PolarsResult<Self>
228 where
229 Self: Sized,
230 {
231 check_bounds!(self, mask);
232 let ca = mask
233 .into_iter()
234 .zip(self)
235 .map(|(mask_val, opt_val)| match mask_val {
236 Some(true) => value,
237 _ => opt_val,
238 })
239 .collect_trusted::<Self>()
240 .with_name(self.name().clone());
241 Ok(ca)
242 }
243}
244
245impl<'a> ChunkSet<'a, &'a [u8], Vec<u8>> for BinaryChunked {
246 fn scatter_single<I: IntoIterator<Item = IdxSize>>(
247 &'a self,
248 idx: I,
249 opt_value: Option<&'a [u8]>,
250 ) -> PolarsResult<Self>
251 where
252 Self: Sized,
253 {
254 let mut ca_iter = self.into_iter().enumerate();
255 let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len());
256
257 for current_idx in idx.into_iter().map(|i| i as usize) {
258 polars_ensure!(current_idx < self.len(), oob = current_idx, self.len());
259 for (cnt_idx, opt_val_self) in &mut ca_iter {
260 if cnt_idx == current_idx {
261 builder.append_option(opt_value);
262 break;
263 } else {
264 builder.append_option(opt_val_self);
265 }
266 }
267 }
268 for (_, opt_val_self) in ca_iter {
270 builder.append_option(opt_val_self);
271 }
272
273 let ca = builder.finish();
274 Ok(ca)
275 }
276
277 fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
278 &'a self,
279 idx: I,
280 f: F,
281 ) -> PolarsResult<Self>
282 where
283 Self: Sized,
284 F: Fn(Option<&'a [u8]>) -> Option<Vec<u8>>,
285 {
286 let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len());
287 impl_scatter_with!(self, builder, idx, f)
288 }
289
290 fn set(&'a self, mask: &BooleanChunked, value: Option<&'a [u8]>) -> PolarsResult<Self>
291 where
292 Self: Sized,
293 {
294 check_bounds!(self, mask);
295 let ca = mask
296 .into_iter()
297 .zip(self)
298 .map(|(mask_val, opt_val)| match mask_val {
299 Some(true) => value,
300 _ => opt_val,
301 })
302 .collect_trusted::<Self>()
303 .with_name(self.name().clone());
304 Ok(ca)
305 }
306}
307
#[cfg(test)]
mod test {
    use crate::prelude::*;

    /// Exercises `set` with valid, partially null and fully null masks, plus
    /// `scatter_single` on the numeric fast path.
    #[test]
    fn test_set() {
        // `set` borrows its receiver, so one input serves every mask below.
        let ints = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]);

        // Fully valid mask: only the `true` slot is replaced.
        let middle_only =
            BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]);
        let replaced = ints.set(&middle_only, Some(5)).unwrap();
        assert_eq!(Vec::from(&replaced), &[Some(1), Some(5), Some(3)]);

        // Null mask entries behave like `false`.
        let middle_with_nulls =
            BooleanChunked::new(PlSmallStr::from_static("mask"), &[None, Some(true), None]);
        let replaced = ints.set(&middle_with_nulls, Some(5)).unwrap();
        assert_eq!(Vec::from(&replaced), &[Some(1), Some(5), Some(3)]);

        // An all-null mask changes nothing.
        let all_null = BooleanChunked::new(PlSmallStr::from_static("mask"), &[None, None, None]);
        let replaced = ints.set(&all_null, Some(5)).unwrap();
        assert_eq!(Vec::from(&replaced), &[Some(1), Some(2), Some(3)]);

        // Mixed `true` / `false` / null mask.
        let first_only = BooleanChunked::new(
            PlSmallStr::from_static("mask"),
            &[Some(true), Some(false), None],
        );
        let replaced = ints.set(&first_only, Some(5)).unwrap();
        assert_eq!(Vec::from(&replaced), &[Some(5), Some(2), Some(3)]);

        // Scatter by index on top of the previous result.
        let scattered = replaced.scatter_single(vec![0, 1], Some(10)).unwrap();
        assert_eq!(Vec::from(&scattered), &[Some(10), Some(10), Some(3)]);

        // An index past the end is reported as an error.
        assert!(scattered.scatter_single(vec![0, 10], Some(0)).is_err());

        // Boolean array: a `None` value nulls out the selected slot.
        let bools = BooleanChunked::new(PlSmallStr::from_static("a"), &[true, true, true]);
        let nulled = bools.set(&middle_only, None).unwrap();
        assert_eq!(Vec::from(&nulled), &[Some(true), None, Some(true)]);

        // String array.
        let strings = StringChunked::new(PlSmallStr::from_static("a"), &["foo", "foo", "foo"]);
        let renamed = strings.set(&middle_only, Some("bar")).unwrap();
        assert_eq!(
            Vec::from(&renamed),
            &[Some("foo"), Some("bar"), Some("foo")]
        );
    }

    /// `set` on inputs that already contain nulls: a selected null slot is
    /// filled, and a null mask entry leaves its slot alone.
    #[test]
    fn test_set_null_values() {
        let selector = BooleanChunked::new(
            PlSmallStr::from_static("mask"),
            &[Some(false), Some(true), None],
        );

        let ints = Int32Chunked::new(PlSmallStr::from_static("a"), &[Some(1), None, Some(3)]);
        let filled = ints.set(&selector, Some(2)).unwrap();
        assert_eq!(Vec::from(&filled), &[Some(1), Some(2), Some(3)]);

        let strings = StringChunked::new(
            PlSmallStr::from_static("a"),
            &[Some("foo"), None, Some("bar")],
        );
        let filled = strings.set(&selector, Some("foo")).unwrap();
        assert_eq!(
            Vec::from(&filled),
            &[Some("foo"), Some("foo"), Some("bar")]
        );

        let bools = BooleanChunked::new(
            PlSmallStr::from_static("a"),
            &[Some(false), None, Some(true)],
        );
        let filled = bools.set(&selector, Some(true)).unwrap();
        assert_eq!(Vec::from(&filled), &[Some(false), Some(true), Some(true)]);
    }
}