polars_python/expr/
selector.rs

1use std::hash::{Hash, Hasher};
2use std::sync::Arc;
3
4use polars::prelude::{
5    DataType, DataTypeSelector, Selector, TimeUnit, TimeUnitSet, TimeZone, TimeZoneSet,
6};
7use polars_plan::dsl;
8use pyo3::PyResult;
9use pyo3::exceptions::PyTypeError;
10
11use crate::prelude::Wrap;
12
13#[pyo3::pyclass]
14#[repr(transparent)]
15#[derive(Clone)]
16pub struct PySelector {
17    pub inner: Selector,
18}
19
20impl From<Selector> for PySelector {
21    fn from(inner: Selector) -> Self {
22        Self { inner }
23    }
24}
25
26fn parse_time_unit_set(time_units: Vec<Wrap<TimeUnit>>) -> TimeUnitSet {
27    let mut tu = TimeUnitSet::empty();
28    for v in time_units {
29        match v.0 {
30            TimeUnit::Nanoseconds => tu |= TimeUnitSet::NANO_SECONDS,
31            TimeUnit::Microseconds => tu |= TimeUnitSet::MICRO_SECONDS,
32            TimeUnit::Milliseconds => tu |= TimeUnitSet::MILLI_SECONDS,
33        }
34    }
35    tu
36}
37
38pub fn parse_datatype_selector(selector: PySelector) -> PyResult<DataTypeSelector> {
39    selector.inner.to_dtype_selector().ok_or_else(|| {
40        PyTypeError::new_err(format!(
41            "expected datatype based expression got '{}'",
42            selector.inner
43        ))
44    })
45}
46
// Python-exposed methods of `PySelector`.
//
// Comments in this block are plain `//` comments rather than `///` doc
// comments so that the docstrings PyO3 generates for Python stay unchanged.
#[cfg(feature = "pymethods")]
#[pyo3::pymethods]
impl PySelector {
    // ---- Set algebra: each method clones both operands and combines them
    // ---- with the matching operator implemented on `Selector`.

    // `self | other`
    fn union(&self, other: &Self) -> Self {
        Self::from(self.inner.clone() | other.inner.clone())
    }

    // `self - other`
    fn difference(&self, other: &Self) -> Self {
        Self::from(self.inner.clone() - other.inner.clone())
    }

    // `self ^ other`
    fn exclusive_or(&self, other: &Self) -> Self {
        Self::from(self.inner.clone() ^ other.inner.clone())
    }

    // `self & other`
    fn intersect(&self, other: &Self) -> Self {
        Self::from(self.inner.clone() & other.inner.clone())
    }

    // ---- Constructors based on dtype lists, names, positions and patterns.

    // Unwraps the given dtypes and builds a selector via `dsl::dtype_cols`.
    #[staticmethod]
    fn by_dtype(dtypes: Vec<Wrap<DataType>>) -> Self {
        let unwrapped: Vec<DataType> = dtypes.into_iter().map(|wrapped| wrapped.0).collect();
        Self::from(dsl::dtype_cols(unwrapped).as_selector())
    }

    // Forwards `names` and `strict` to `dsl::by_name`.
    #[staticmethod]
    fn by_name(names: Vec<String>, strict: bool) -> Self {
        Self::from(dsl::by_name(names, strict))
    }

    // Builds `Selector::ByIndex` from the given indices; `strict` is passed through.
    #[staticmethod]
    fn by_index(indices: Vec<i64>, strict: bool) -> Self {
        let by_index = Selector::ByIndex {
            indices: indices.into(),
            strict,
        };
        Self::from(by_index)
    }

    // `Selector::ByIndex` with the single index `0`.
    #[staticmethod]
    fn first(strict: bool) -> Self {
        let by_index = Selector::ByIndex {
            indices: [0].into(),
            strict,
        };
        Self::from(by_index)
    }

    // `Selector::ByIndex` with the single index `-1`.
    #[staticmethod]
    fn last(strict: bool) -> Self {
        let by_index = Selector::ByIndex {
            indices: [-1].into(),
            strict,
        };
        Self::from(by_index)
    }

    // Builds `Selector::Matches` from the given pattern string.
    #[staticmethod]
    fn matches(pattern: String) -> Self {
        Self::from(Selector::Matches(pattern.into()))
    }

    // ---- Constructors that map one-to-one onto a `DataTypeSelector` variant.

    #[staticmethod]
    fn enum_() -> Self {
        Self::from(DataTypeSelector::Enum.as_selector())
    }

    #[staticmethod]
    fn categorical() -> Self {
        Self::from(DataTypeSelector::Categorical.as_selector())
    }

    #[staticmethod]
    fn nested() -> Self {
        Self::from(DataTypeSelector::Nested.as_selector())
    }

    // `DataTypeSelector::List`, optionally constrained by an inner selector.
    // Fails with a `TypeError` if `inner_dst` cannot be turned into a
    // `DataTypeSelector` (see `parse_datatype_selector`).
    #[staticmethod]
    fn list(inner_dst: Option<Self>) -> PyResult<Self> {
        let inner = inner_dst
            .map(parse_datatype_selector)
            .transpose()?
            .map(Arc::new);
        Ok(Self::from(DataTypeSelector::List(inner).as_selector()))
    }

    // `DataTypeSelector::Array`, optionally constrained by an inner selector
    // and/or a width. Fails with a `TypeError` under the same condition as `list`.
    #[staticmethod]
    fn array(inner_dst: Option<Self>, width: Option<usize>) -> PyResult<Self> {
        let inner = inner_dst
            .map(parse_datatype_selector)
            .transpose()?
            .map(Arc::new);
        Ok(Self::from(
            DataTypeSelector::Array(inner, width).as_selector(),
        ))
    }

    #[staticmethod]
    fn struct_() -> Self {
        Self::from(DataTypeSelector::Struct.as_selector())
    }

    #[staticmethod]
    fn integer() -> Self {
        Self::from(DataTypeSelector::Integer.as_selector())
    }

    #[staticmethod]
    fn signed_integer() -> Self {
        Self::from(DataTypeSelector::SignedInteger.as_selector())
    }

    #[staticmethod]
    fn unsigned_integer() -> Self {
        Self::from(DataTypeSelector::UnsignedInteger.as_selector())
    }

    #[staticmethod]
    fn float() -> Self {
        Self::from(DataTypeSelector::Float.as_selector())
    }

    #[staticmethod]
    fn decimal() -> Self {
        Self::from(DataTypeSelector::Decimal.as_selector())
    }

    #[staticmethod]
    fn numeric() -> Self {
        Self::from(DataTypeSelector::Numeric.as_selector())
    }

    #[staticmethod]
    fn temporal() -> Self {
        Self::from(DataTypeSelector::Temporal.as_selector())
    }

    // `DataTypeSelector::Datetime` for the given time units and time zones.
    //
    // Each `tz` entry is classified as:
    //   * `None`       -> columns without a time zone are allowed,
    //   * `Some("*")`  -> any set time zone is allowed,
    //   * `Some(zone)` -> that specific zone is allowed.
    // The wildcard takes precedence over explicitly listed zones.
    #[staticmethod]
    fn datetime(tu: Vec<Wrap<TimeUnit>>, tz: Vec<Wrap<Option<TimeZone>>>) -> Self {
        let unit_set = parse_time_unit_set(tu);

        let mut saw_none = false;
        let mut saw_wildcard = false;
        let mut named_zones: Vec<TimeZone> = Vec::new();
        for entry in tz {
            match entry.0 {
                Some(zone) if zone.as_str() == "*" => saw_wildcard = true,
                Some(zone) => named_zones.push(zone),
                None => saw_none = true,
            }
        }

        let zone_set = if saw_wildcard {
            // Listed zones are irrelevant once any set zone is accepted.
            if saw_none {
                TimeZoneSet::Any
            } else {
                TimeZoneSet::AnySet
            }
        } else if !saw_none {
            TimeZoneSet::AnyOf(named_zones.into())
        } else if named_zones.is_empty() {
            TimeZoneSet::Unset
        } else {
            TimeZoneSet::UnsetOrAnyOf(named_zones.into())
        };

        Self::from(DataTypeSelector::Datetime(unit_set, zone_set).as_selector())
    }

    // `DataTypeSelector::Duration` for the given time units.
    #[staticmethod]
    fn duration(tu: Vec<Wrap<TimeUnit>>) -> Self {
        let unit_set = parse_time_unit_set(tu);
        Self::from(DataTypeSelector::Duration(unit_set).as_selector())
    }

    #[staticmethod]
    fn object() -> Self {
        Self::from(DataTypeSelector::Object.as_selector())
    }

    // Wraps the selector returned by `dsl::empty()`.
    #[staticmethod]
    fn empty() -> Self {
        Self::from(dsl::empty())
    }

    // Wraps the selector returned by `dsl::all()`.
    #[staticmethod]
    fn all() -> Self {
        Self::from(dsl::all())
    }

    // Hashes the wrapped selector with the standard library's default hasher
    // and returns the resulting 64-bit digest.
    fn hash(&self) -> u64 {
        let mut state = std::hash::DefaultHasher::new();
        Hash::hash(&self.inner, &mut state);
        state.finish()
    }
}