polars_arrow/datatypes/
reshape.rs

1use std::fmt;
2use std::hash::Hash;
3use std::num::NonZeroU64;
4
/// The size of a single dimension.
///
/// The wrapped [`NonZeroU64`] holds the size plus one: [`Dimension::new`] adds one before
/// storing and [`Dimension::get`] subtracts it again. Sizes created through
/// [`Dimension::new`] are at most `i64::MAX`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[repr(transparent)]
pub struct Dimension(NonZeroU64);
10
/// A dimension in a reshape.
///
/// A dimension is either given explicitly or left to be inferred.
/// [`ReshapeDimension::new`] turns any value smaller than 0 into
/// [`ReshapeDimension::Infer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum ReshapeDimension {
    /// No size was given; callers supply one through `get_or_infer` or
    /// `get_or_infer_with`.
    Infer,
    /// The size was given explicitly.
    Specified(Dimension),
}
21
22impl fmt::Debug for Dimension {
23    #[inline]
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        self.get().fmt(f)
26    }
27}
28
29impl fmt::Display for ReshapeDimension {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            Self::Infer => f.write_str("inferred"),
33            Self::Specified(v) => v.get().fmt(f),
34        }
35    }
36}
37
38impl Hash for ReshapeDimension {
39    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
40        self.to_repr().hash(state)
41    }
42}
43
44impl Dimension {
45    #[inline]
46    pub const fn new(v: u64) -> Self {
47        assert!(v <= i64::MAX as u64);
48
49        // SAFETY: Bounds check done before
50        let dim = unsafe { NonZeroU64::new_unchecked(v.wrapping_add(1)) };
51        Self(dim)
52    }
53
54    #[inline]
55    pub const fn get(self) -> u64 {
56        self.0.get() - 1
57    }
58}
59
60impl ReshapeDimension {
61    #[inline]
62    pub const fn new(v: i64) -> Self {
63        if v < 0 {
64            Self::Infer
65        } else {
66            // SAFETY: We have bounds checked for -1
67            let dim = unsafe { NonZeroU64::new_unchecked((v as u64).wrapping_add(1)) };
68            Self::Specified(Dimension(dim))
69        }
70    }
71
72    #[inline]
73    fn to_repr(self) -> u64 {
74        match self {
75            Self::Infer => 0,
76            Self::Specified(dim) => dim.0.get(),
77        }
78    }
79
80    #[inline]
81    pub const fn get(self) -> Option<u64> {
82        match self {
83            ReshapeDimension::Infer => None,
84            ReshapeDimension::Specified(dim) => Some(dim.get()),
85        }
86    }
87
88    #[inline]
89    pub const fn get_or_infer(self, inferred: u64) -> u64 {
90        match self {
91            ReshapeDimension::Infer => inferred,
92            ReshapeDimension::Specified(dim) => dim.get(),
93        }
94    }
95
96    #[inline]
97    pub fn get_or_infer_with(self, f: impl Fn() -> u64) -> u64 {
98        match self {
99            ReshapeDimension::Infer => f(),
100            ReshapeDimension::Specified(dim) => dim.get(),
101        }
102    }
103
104    pub const fn new_dimension(dimension: u64) -> ReshapeDimension {
105        Self::Specified(Dimension::new(dimension))
106    }
107}
108
109impl TryFrom<i64> for Dimension {
110    type Error = ();
111
112    #[inline]
113    fn try_from(value: i64) -> Result<Self, Self::Error> {
114        let ReshapeDimension::Specified(v) = ReshapeDimension::new(value) else {
115            return Err(());
116        };
117
118        Ok(v)
119    }
120}