polars_arrow/datatypes/
reshape.rs

1use std::fmt;
2use std::hash::Hash;
3use std::num::NonZeroU64;
4
/// A concrete, non-negative dimension size.
///
/// The size is kept off by one (`size + 1`) inside a [`NonZeroU64`]:
/// `Dimension::new` adds one when storing and `Dimension::get` subtracts it
/// again, so a size of `0` is representable while the stored value is never
/// zero.
#[repr(transparent)]
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Dimension(NonZeroU64);
9
10/// A dimension in a reshape.
11///
12/// Any dimension smaller than 0 is seen as an `infer`.
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
14#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
15pub enum ReshapeDimension {
16    Infer,
17    Specified(Dimension),
18}
19
20impl fmt::Debug for Dimension {
21    #[inline]
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        self.get().fmt(f)
24    }
25}
26
27impl fmt::Display for ReshapeDimension {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        match self {
30            Self::Infer => f.write_str("inferred"),
31            Self::Specified(v) => v.get().fmt(f),
32        }
33    }
34}
35
36impl Hash for ReshapeDimension {
37    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
38        self.to_repr().hash(state)
39    }
40}
41
42impl Dimension {
43    #[inline]
44    pub const fn new(v: u64) -> Self {
45        assert!(v <= i64::MAX as u64);
46
47        // SAFETY: Bounds check done before
48        let dim = unsafe { NonZeroU64::new_unchecked(v.wrapping_add(1)) };
49        Self(dim)
50    }
51
52    #[inline]
53    pub const fn get(self) -> u64 {
54        self.0.get() - 1
55    }
56}
57
58impl ReshapeDimension {
59    #[inline]
60    pub const fn new(v: i64) -> Self {
61        if v < 0 {
62            Self::Infer
63        } else {
64            // SAFETY: We have bounds checked for -1
65            let dim = unsafe { NonZeroU64::new_unchecked((v as u64).wrapping_add(1)) };
66            Self::Specified(Dimension(dim))
67        }
68    }
69
70    #[inline]
71    fn to_repr(self) -> u64 {
72        match self {
73            Self::Infer => 0,
74            Self::Specified(dim) => dim.0.get(),
75        }
76    }
77
78    #[inline]
79    pub const fn get(self) -> Option<u64> {
80        match self {
81            ReshapeDimension::Infer => None,
82            ReshapeDimension::Specified(dim) => Some(dim.get()),
83        }
84    }
85
86    #[inline]
87    pub const fn get_or_infer(self, inferred: u64) -> u64 {
88        match self {
89            ReshapeDimension::Infer => inferred,
90            ReshapeDimension::Specified(dim) => dim.get(),
91        }
92    }
93
94    #[inline]
95    pub fn get_or_infer_with(self, f: impl Fn() -> u64) -> u64 {
96        match self {
97            ReshapeDimension::Infer => f(),
98            ReshapeDimension::Specified(dim) => dim.get(),
99        }
100    }
101
102    pub const fn new_dimension(dimension: u64) -> ReshapeDimension {
103        Self::Specified(Dimension::new(dimension))
104    }
105}
106
107impl TryFrom<i64> for Dimension {
108    type Error = ();
109
110    #[inline]
111    fn try_from(value: i64) -> Result<Self, Self::Error> {
112        let ReshapeDimension::Specified(v) = ReshapeDimension::new(value) else {
113            return Err(());
114        };
115
116        Ok(v)
117    }
118}