polars_core/chunked_array/ops/bit_repr.rs

1use arrow::buffer::Buffer;
2use polars_error::feature_gated;
3
4use crate::prelude::*;
5use crate::series::BitRepr;
6
7/// Reinterprets the type of a [`ChunkedArray`]. T and U must have the same size
8/// and alignment.
9fn reinterpret_chunked_array<T: PolarsNumericType, U: PolarsNumericType>(
10    ca: &ChunkedArray<T>,
11) -> ChunkedArray<U> {
12    assert!(size_of::<T::Native>() == size_of::<U::Native>());
13    assert!(align_of::<T::Native>() == align_of::<U::Native>());
14
15    let chunks = ca.downcast_iter().map(|array| {
16        let buf = array.values().clone();
17        // SAFETY: we checked that the size and alignment matches.
18        #[allow(clippy::transmute_undefined_repr)]
19        let reinterpreted_buf =
20            unsafe { std::mem::transmute::<Buffer<T::Native>, Buffer<U::Native>>(buf) };
21        PrimitiveArray::from_data_default(reinterpreted_buf, array.validity().cloned())
22    });
23
24    ChunkedArray::from_chunk_iter(ca.name().clone(), chunks)
25}
26
/// Reinterprets the primitive inner values of a [`ListChunked`] from `T` to `U`,
/// reusing the value buffers, the list offsets and both validity masks.
///
/// # Panics
/// - If `T::Native` and `U::Native` differ in size or alignment.
/// - If a chunk's values array is not a `PrimitiveArray<T::Native>`.
#[cfg(feature = "reinterpret")]
fn reinterpret_list_chunked<T: PolarsNumericType, U: PolarsNumericType>(
    ca: &ListChunked,
) -> ListChunked {
    assert!(size_of::<T::Native>() == size_of::<U::Native>());
    assert!(align_of::<T::Native>() == align_of::<U::Native>());

    // Every output chunk has the same outer arrow dtype, so build it only once.
    let list_arrow_dtype =
        DataType::List(Box::new(U::get_static_dtype())).to_arrow(CompatLevel::newest());

    let retyped = ca.downcast_iter().map(|list_arr| {
        let old_values = list_arr
            .values()
            .as_any()
            .downcast_ref::<PrimitiveArray<T::Native>>()
            .unwrap();
        // SAFETY: the two asserts above guarantee that `T::Native` and
        // `U::Native` have identical size and alignment.
        #[allow(clippy::transmute_undefined_repr)]
        let new_buf = unsafe {
            std::mem::transmute::<Buffer<T::Native>, Buffer<U::Native>>(
                old_values.values().clone(),
            )
        };
        let new_values =
            PrimitiveArray::from_data_default(new_buf, old_values.validity().cloned());
        LargeListArray::new(
            list_arrow_dtype.clone(),
            list_arr.offsets().clone(),
            new_values.to_boxed(),
            list_arr.validity().cloned(),
        )
    });

    ListChunked::from_chunk_iter(ca.name().clone(), retyped)
}
59
#[cfg(all(feature = "reinterpret", feature = "dtype-i16", feature = "dtype-u16"))]
impl Reinterpret for Int16Chunked {
    /// The values are already signed, so this returns a clone of `self`.
    fn reinterpret_signed(&self) -> Series {
        let unchanged = self.clone();
        unchanged.into_series()
    }

    /// Views the `i16` bit patterns as `u16` without copying buffers.
    fn reinterpret_unsigned(&self) -> Series {
        let unsigned: UInt16Chunked = reinterpret_chunked_array(self);
        unsigned.into_series()
    }
}
70
#[cfg(all(feature = "reinterpret", feature = "dtype-u16", feature = "dtype-i16"))]
impl Reinterpret for UInt16Chunked {
    /// Views the `u16` bit patterns as `i16` without copying buffers.
    fn reinterpret_signed(&self) -> Series {
        let signed: Int16Chunked = reinterpret_chunked_array(self);
        signed.into_series()
    }

    /// The values are already unsigned, so this returns a clone of `self`.
    fn reinterpret_unsigned(&self) -> Series {
        let unchanged = self.clone();
        unchanged.into_series()
    }
}
81
#[cfg(all(feature = "reinterpret", feature = "dtype-i8", feature = "dtype-u8"))]
impl Reinterpret for Int8Chunked {
    /// The values are already signed, so this returns a clone of `self`.
    fn reinterpret_signed(&self) -> Series {
        let unchanged = self.clone();
        unchanged.into_series()
    }

    /// Views the `i8` bit patterns as `u8` without copying buffers.
    fn reinterpret_unsigned(&self) -> Series {
        let unsigned: UInt8Chunked = reinterpret_chunked_array(self);
        unsigned.into_series()
    }
}
92
#[cfg(all(feature = "reinterpret", feature = "dtype-u8", feature = "dtype-i8"))]
impl Reinterpret for UInt8Chunked {
    /// Views the `u8` bit patterns as `i8` without copying buffers.
    fn reinterpret_signed(&self) -> Series {
        let signed: Int8Chunked = reinterpret_chunked_array(self);
        signed.into_series()
    }

    /// The values are already unsigned, so this returns a clone of `self`.
    fn reinterpret_unsigned(&self) -> Series {
        let unchanged = self.clone();
        unchanged.into_series()
    }
}
103
104impl<T> ToBitRepr for ChunkedArray<T>
105where
106    T: PolarsNumericType,
107{
108    fn to_bit_repr(&self) -> BitRepr {
109        match size_of::<T::Native>() {
110            16 => {
111                feature_gated!("dtype-i128", {
112                    if matches!(self.dtype(), DataType::Int128) {
113                        let ca = self.clone();
114                        // Convince the compiler we are this type. This keeps flags.
115                        return BitRepr::I128(unsafe {
116                            std::mem::transmute::<ChunkedArray<T>, Int128Chunked>(ca)
117                        });
118                    }
119
120                    BitRepr::I128(reinterpret_chunked_array(self))
121                })
122            },
123
124            8 => {
125                if matches!(self.dtype(), DataType::UInt64) {
126                    let ca = self.clone();
127                    // Convince the compiler we are this type. This keeps flags.
128                    return BitRepr::U64(unsafe {
129                        std::mem::transmute::<ChunkedArray<T>, UInt64Chunked>(ca)
130                    });
131                }
132
133                BitRepr::U64(reinterpret_chunked_array(self))
134            },
135
136            byte_size => {
137                assert!(byte_size <= 4);
138
139                BitRepr::U32(if byte_size == 4 {
140                    if matches!(self.dtype(), DataType::UInt32) {
141                        let ca = self.clone();
142                        // Convince the compiler we are this type. This preserves flags.
143                        return BitRepr::U32(unsafe {
144                            std::mem::transmute::<ChunkedArray<T>, UInt32Chunked>(ca)
145                        });
146                    }
147
148                    reinterpret_chunked_array(self)
149                } else {
150                    // SAFETY: an unchecked cast to uint32 (which has no invariants) is
151                    // always sound.
152                    unsafe {
153                        self.cast_unchecked(&DataType::UInt32)
154                            .unwrap()
155                            .u32()
156                            .unwrap()
157                            .clone()
158                    }
159                })
160            },
161        }
162    }
163}
164
#[cfg(feature = "reinterpret")]
impl Reinterpret for UInt64Chunked {
    /// Views the `u64` bit patterns as `i64` without copying buffers.
    fn reinterpret_signed(&self) -> Series {
        reinterpret_chunked_array::<_, Int64Type>(self).into_series()
    }

    /// The values are already unsigned, so this returns a clone of `self`.
    fn reinterpret_unsigned(&self) -> Series {
        let unchanged = self.clone();
        unchanged.into_series()
    }
}
#[cfg(feature = "reinterpret")]
impl Reinterpret for Int64Chunked {
    /// The values are already signed, so this returns a clone of `self`.
    fn reinterpret_signed(&self) -> Series {
        self.clone().into_series()
    }

    /// Views the `i64` bit patterns as `u64` without copying buffers.
    ///
    /// Calls `reinterpret_chunked_array` directly, like the sibling impls. The old
    /// path via `to_bit_repr()` ends up in the same function for an `Int64` array,
    /// but needed an `unreachable!()` arm to unpack the result.
    fn reinterpret_unsigned(&self) -> Series {
        reinterpret_chunked_array::<_, UInt64Type>(self).into_series()
    }
}
189
#[cfg(feature = "reinterpret")]
impl Reinterpret for UInt32Chunked {
    /// Views the `u32` bit patterns as `i32` without copying buffers.
    fn reinterpret_signed(&self) -> Series {
        reinterpret_chunked_array::<_, Int32Type>(self).into_series()
    }

    /// The values are already unsigned, so this returns a clone of `self`.
    fn reinterpret_unsigned(&self) -> Series {
        let unchanged = self.clone();
        unchanged.into_series()
    }
}
201
#[cfg(feature = "reinterpret")]
impl Reinterpret for Int32Chunked {
    /// The values are already signed, so this returns a clone of `self`.
    fn reinterpret_signed(&self) -> Series {
        self.clone().into_series()
    }

    /// Views the `i32` bit patterns as `u32` without copying buffers.
    ///
    /// Calls `reinterpret_chunked_array` directly, like the sibling impls. The old
    /// path via `to_bit_repr()` ends up in the same function for an `Int32` array,
    /// but needed an `unreachable!()` arm to unpack the result.
    fn reinterpret_unsigned(&self) -> Series {
        reinterpret_chunked_array::<_, UInt32Type>(self).into_series()
    }
}
215
#[cfg(feature = "reinterpret")]
impl Reinterpret for Float32Chunked {
    /// Views the IEEE-754 bits of each `f32` as an `i32`.
    fn reinterpret_signed(&self) -> Series {
        let signed: Int32Chunked = reinterpret_chunked_array(self);
        signed.into_series()
    }

    /// Views the IEEE-754 bits of each `f32` as a `u32`.
    fn reinterpret_unsigned(&self) -> Series {
        let unsigned: UInt32Chunked = reinterpret_chunked_array(self);
        unsigned.into_series()
    }
}
226
#[cfg(feature = "reinterpret")]
impl Reinterpret for ListChunked {
    /// Reinterprets the list's inner values as signed integers of the same width.
    ///
    /// Supported inner dtypes:
    /// - `Float32` / `UInt32` become `Int32`.
    /// - `Float64` / `UInt64` become `Int64`.
    /// - Lists that are already `Int32` / `Int64` are returned unchanged.
    ///
    /// This mirrors the scalar `Reinterpret` impls for those element types.
    ///
    /// # Panics
    /// Panics for any other inner dtype.
    fn reinterpret_signed(&self) -> Series {
        match self.inner_dtype() {
            DataType::Float32 => reinterpret_list_chunked::<Float32Type, Int32Type>(self),
            DataType::Float64 => reinterpret_list_chunked::<Float64Type, Int64Type>(self),
            DataType::UInt32 => reinterpret_list_chunked::<UInt32Type, Int32Type>(self),
            DataType::UInt64 => reinterpret_list_chunked::<UInt64Type, Int64Type>(self),
            // Already signed: nothing to reinterpret.
            DataType::Int32 | DataType::Int64 => self.clone(),
            dt => unimplemented!("reinterpret_signed is not implemented for list inner dtype {dt:?}"),
        }
        .into_series()
    }

    /// Reinterprets the list's inner values as unsigned integers of the same width.
    ///
    /// Supported inner dtypes:
    /// - `Float32` / `Int32` become `UInt32`.
    /// - `Float64` / `Int64` become `UInt64`.
    /// - Lists that are already `UInt32` / `UInt64` are returned unchanged.
    ///
    /// This mirrors the scalar `Reinterpret` impls for those element types.
    ///
    /// # Panics
    /// Panics for any other inner dtype.
    fn reinterpret_unsigned(&self) -> Series {
        match self.inner_dtype() {
            DataType::Float32 => reinterpret_list_chunked::<Float32Type, UInt32Type>(self),
            DataType::Float64 => reinterpret_list_chunked::<Float64Type, UInt64Type>(self),
            DataType::Int32 => reinterpret_list_chunked::<Int32Type, UInt32Type>(self),
            DataType::Int64 => reinterpret_list_chunked::<Int64Type, UInt64Type>(self),
            // Already unsigned: nothing to reinterpret.
            DataType::UInt32 | DataType::UInt64 => self.clone(),
            dt => {
                unimplemented!("reinterpret_unsigned is not implemented for list inner dtype {dt:?}")
            },
        }
        .into_series()
    }
}
247
#[cfg(feature = "reinterpret")]
impl Reinterpret for Float64Chunked {
    /// Views the IEEE-754 bits of each `f64` as an `i64`.
    fn reinterpret_signed(&self) -> Series {
        let signed: Int64Chunked = reinterpret_chunked_array(self);
        signed.into_series()
    }

    /// Views the IEEE-754 bits of each `f64` as a `u64`.
    fn reinterpret_unsigned(&self) -> Series {
        let unsigned: UInt64Chunked = reinterpret_chunked_array(self);
        unsigned.into_series()
    }
}
258
259impl UInt64Chunked {
260    #[doc(hidden)]
261    pub fn _reinterpret_float(&self) -> Float64Chunked {
262        reinterpret_chunked_array(self)
263    }
264}
265impl UInt32Chunked {
266    #[doc(hidden)]
267    pub fn _reinterpret_float(&self) -> Float32Chunked {
268        reinterpret_chunked_array(self)
269    }
270}
271
272/// Used to save compilation paths. Use carefully. Although this is safe,
273/// if misused it can lead to incorrect results.
274impl Float32Chunked {
275    pub fn apply_as_ints<F>(&self, f: F) -> Series
276    where
277        F: Fn(&Series) -> Series,
278    {
279        let BitRepr::U32(s) = self.to_bit_repr() else {
280            unreachable!()
281        };
282        let s = s.into_series();
283        let out = f(&s);
284        let out = out.u32().unwrap();
285        out._reinterpret_float().into()
286    }
287}
288impl Float64Chunked {
289    pub fn apply_as_ints<F>(&self, f: F) -> Series
290    where
291        F: Fn(&Series) -> Series,
292    {
293        let BitRepr::U64(s) = self.to_bit_repr() else {
294            unreachable!()
295        };
296        let s = s.into_series();
297        let out = f(&s);
298        let out = out.u64().unwrap();
299        out._reinterpret_float().into()
300    }
301}