// polars_core/series/ops/reshape.rs

1use std::borrow::Cow;
2
3use arrow::array::*;
4use arrow::bitmap::Bitmap;
5use arrow::offset::{Offsets, OffsetsBuffer};
6use polars_compute::gather::sublist::list::array_to_unit_list;
7use polars_error::{PolarsResult, polars_bail, polars_ensure};
8use polars_utils::format_tuple;
9
10use crate::chunked_array::builder::get_list_builder;
11use crate::datatypes::{DataType, ListChunked};
12use crate::prelude::{IntoSeries, Series, *};
13
14fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series {
15    let mut ca = ListChunked::from_chunk_iter(
16        name,
17        s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())),
18    );
19
20    ca.set_inner_dtype(s.dtype().clone());
21    ca.set_fast_explode();
22    ca.into_series()
23}
24
25impl Series {
26    /// Recurse nested types until we are at the leaf array.
27    pub fn get_leaf_array(&self) -> Series {
28        let s = self;
29        match s.dtype() {
30            #[cfg(feature = "dtype-array")]
31            DataType::Array(dtype, _) => {
32                let ca = s.array().unwrap();
33                let chunks = ca
34                    .downcast_iter()
35                    .map(|arr| arr.values().clone())
36                    .collect::<Vec<_>>();
37                // Safety: guarded by the type system
38                unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }
39                    .get_leaf_array()
40            },
41            DataType::List(dtype) => {
42                let ca = s.list().unwrap();
43                let chunks = ca
44                    .downcast_iter()
45                    .map(|arr| arr.values().clone())
46                    .collect::<Vec<_>>();
47                // Safety: guarded by the type system
48                unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }
49                    .get_leaf_array()
50            },
51            _ => s.clone(),
52        }
53    }
54
55    /// TODO: Move this somewhere else?
56    pub fn list_offsets_and_validities_recursive(
57        &self,
58    ) -> (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>) {
59        let mut offsets = vec![];
60        let mut validities = vec![];
61
62        let mut s = self.rechunk();
63
64        while let DataType::List(_) = s.dtype() {
65            let ca = s.list().unwrap();
66            offsets.push(ca.offsets().unwrap());
67            validities.push(ca.rechunk_validity());
68            s = ca.get_inner();
69        }
70
71        (offsets, validities)
72    }
73
74    /// Convert the values of this Series to a ListChunked with a length of 1,
75    /// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`.
76    pub fn implode(&self) -> PolarsResult<ListChunked> {
77        let s = self;
78        let s = s.rechunk();
79        let values = s.array_ref(0);
80
81        let offsets = vec![0i64, values.len() as i64];
82        let inner_type = s.dtype();
83
84        let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
85
86        // SAFETY: offsets are correct.
87        let arr = unsafe {
88            ListArray::new(
89                dtype,
90                Offsets::new_unchecked(offsets).into(),
91                values.clone(),
92                None,
93            )
94        };
95
96        let mut ca = ListChunked::with_chunk(s.name().clone(), arr);
97        unsafe { ca.to_logical(inner_type.clone()) };
98        ca.set_fast_explode();
99        Ok(ca)
100    }
101
102    #[cfg(feature = "dtype-array")]
103    pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
104        polars_ensure!(
105            !dimensions.is_empty(),
106            InvalidOperation: "at least one dimension must be specified"
107        );
108
109        let leaf_array = self
110            .trim_lists_to_normalized_offsets()
111            .as_ref()
112            .unwrap_or(self)
113            .get_leaf_array()
114            .rechunk();
115        let size = leaf_array.len();
116
117        let mut total_dim_size = 1;
118        let mut num_infers = 0;
119        for &dim in dimensions {
120            match dim {
121                ReshapeDimension::Infer => num_infers += 1,
122                ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize,
123            }
124        }
125
126        polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");
127
128        if size == 0 {
129            polars_ensure!(
130                num_infers > 0 || total_dim_size == 0,
131                InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}",
132                format_tuple!(dimensions),
133            );
134
135            let mut prev_arrow_dtype = leaf_array
136                .dtype()
137                .to_physical()
138                .to_arrow(CompatLevel::newest());
139            let mut prev_dtype = leaf_array.dtype().clone();
140            let mut prev_array = leaf_array.chunks()[0].clone();
141
142            // @NOTE: We need to collect the iterator here because it is lazily processed.
143            let mut current_length = dimensions[0].get_or_infer(0);
144            let len_iter = dimensions[1..]
145                .iter()
146                .map(|d| {
147                    let length = current_length as usize;
148                    current_length *= d.get_or_infer(0);
149                    length
150                })
151                .collect::<Vec<_>>();
152
153            // We pop the outer dimension as that is the height of the series.
154            for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() {
155                // Infer dimension if needed
156                let dim = dim.get_or_infer(0);
157                prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);
158                prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);
159
160                prev_array =
161                    FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None)
162                        .boxed();
163            }
164
165            return Ok(unsafe {
166                Series::from_chunks_and_dtype_unchecked(
167                    leaf_array.name().clone(),
168                    vec![prev_array],
169                    &prev_dtype,
170                )
171            });
172        }
173
174        polars_ensure!(
175            total_dim_size > 0,
176            InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}",
177            format_tuple!(dimensions)
178        );
179
180        polars_ensure!(
181            size.is_multiple_of(total_dim_size),
182            InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)
183        );
184
185        let leaf_array = leaf_array.rechunk();
186        let mut prev_arrow_dtype = leaf_array
187            .dtype()
188            .to_physical()
189            .to_arrow(CompatLevel::newest());
190        let mut prev_dtype = leaf_array.dtype().clone();
191        let mut prev_array = leaf_array.chunks()[0].clone();
192
193        // We pop the outer dimension as that is the height of the series.
194        for dim in dimensions[1..].iter().rev() {
195            // Infer dimension if needed
196            let dim = dim.get_or_infer((size / total_dim_size) as u64);
197            prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);
198            prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);
199
200            prev_array = FixedSizeListArray::new(
201                prev_arrow_dtype.clone(),
202                prev_array.len() / dim as usize,
203                prev_array,
204                None,
205            )
206            .boxed();
207        }
208        Ok(unsafe {
209            Series::from_chunks_and_dtype_unchecked(
210                leaf_array.name().clone(),
211                vec![prev_array],
212                &prev_dtype,
213            )
214        })
215    }
216
217    pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
218        polars_ensure!(
219            !dimensions.is_empty(),
220            InvalidOperation: "at least one dimension must be specified"
221        );
222
223        let s = self;
224        let s = if let DataType::List(_) = s.dtype() {
225            Cow::Owned(s.explode(true)?)
226        } else {
227            Cow::Borrowed(s)
228        };
229
230        let s_ref = s.as_ref();
231
232        // let dimensions = dimensions.to_vec();
233
234        match dimensions.len() {
235            1 => {
236                polars_ensure!(
237                    dimensions[0].get().is_none_or( |dim| dim as usize == s_ref.len()),
238                    InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
239                );
240                Ok(s_ref.clone())
241            },
242            2 => {
243                let rows = dimensions[0];
244                let cols = dimensions[1];
245
246                if s_ref.is_empty() {
247                    if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 {
248                        let s = reshape_fast_path(s.name().clone(), s_ref);
249                        return Ok(s);
250                    } else {
251                        polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions))
252                    }
253                }
254
255                use ReshapeDimension as RD;
256                // Infer dimension.
257
258                let (rows, cols) = match (rows, cols) {
259                    (RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => {
260                        (s_ref.len() as u64 / cols.get(), cols.get())
261                    },
262                    (RD::Specified(rows), RD::Infer) if rows.get() >= 1 => {
263                        (rows.get(), s_ref.len() as u64 / rows.get())
264                    },
265                    (RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64),
266                    (RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()),
267                    _ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"),
268                };
269
270                // Fast path, we can create a unit list so we only allocate offsets.
271                if rows as usize == s_ref.len() && cols == 1 {
272                    let s = reshape_fast_path(s.name().clone(), s_ref);
273                    return Ok(s);
274                }
275
276                polars_ensure!(
277                    (rows*cols) as usize == s_ref.len() && rows >= 1 && cols >= 1,
278                    InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
279                );
280
281                let mut builder =
282                    get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone());
283
284                let mut offset = 0u64;
285                for _ in 0..rows {
286                    let row = s_ref.slice(offset as i64, cols as usize);
287                    builder.append_series(&row).unwrap();
288                    offset += cols;
289                }
290                Ok(builder.finish().into_series())
291            },
292            _ => {
293                polars_bail!(InvalidOperation: "more than two dimensions not supported in reshaping to List.\n\nConsider reshaping to Array type.");
294            },
295        }
296    }
297}
298
#[cfg(test)]
mod test {
    use super::*;
    use crate::prelude::*;

    /// `implode` must equal a list built by appending the whole Series once.
    #[test]
    fn test_to_list() -> PolarsResult<()> {
        let values = Series::new("a".into(), &[1, 2, 3]);

        let expected = {
            let mut builder =
                get_list_builder(values.dtype(), values.len(), 1, values.name().clone());
            builder.append_series(&values).unwrap();
            builder.finish().into_series()
        };

        let imploded = values.implode()?.into_series();
        assert!(expected.equals(&imploded));

        Ok(())
    }

    /// Reshaping four values to two dimensions yields the expected number of
    /// list rows and keeps all four values.
    #[test]
    fn test_reshape() -> PolarsResult<()> {
        let values = Series::new("a".into(), &[1, 2, 3, 4]);

        // (requested shape, expected number of rows)
        let cases = [
            (&[-1, 1], 4),
            (&[4, 1], 4),
            (&[2, 2], 2),
            (&[-1, 2], 2),
            (&[2, -1], 2),
        ];

        for (shape, expected_rows) in cases {
            let dims: Vec<_> = shape
                .iter()
                .map(|&dim| ReshapeDimension::new(dim))
                .collect();

            let reshaped = values.reshape_list(&dims)?;
            assert_eq!(reshaped.len(), expected_rows);
            assert!(matches!(reshaped.dtype(), DataType::List(_)));
            assert_eq!(reshaped.explode(false)?.len(), 4);
        }

        Ok(())
    }
}
341}