polars_python/expr/
selector.rs1use std::hash::{Hash, Hasher};
2use std::sync::Arc;
3
4use polars::prelude::{
5 DataType, DataTypeSelector, Selector, TimeUnit, TimeUnitSet, TimeZone, TimeZoneSet,
6};
7use polars_plan::dsl;
8use pyo3::PyResult;
9use pyo3::exceptions::PyTypeError;
10
11use crate::prelude::Wrap;
12
13#[pyo3::pyclass]
14#[repr(transparent)]
15#[derive(Clone)]
16pub struct PySelector {
17 pub inner: Selector,
18}
19
20impl From<Selector> for PySelector {
21 fn from(inner: Selector) -> Self {
22 Self { inner }
23 }
24}
25
26fn parse_time_unit_set(time_units: Vec<Wrap<TimeUnit>>) -> TimeUnitSet {
27 let mut tu = TimeUnitSet::empty();
28 for v in time_units {
29 match v.0 {
30 TimeUnit::Nanoseconds => tu |= TimeUnitSet::NANO_SECONDS,
31 TimeUnit::Microseconds => tu |= TimeUnitSet::MICRO_SECONDS,
32 TimeUnit::Milliseconds => tu |= TimeUnitSet::MILLI_SECONDS,
33 }
34 }
35 tu
36}
37
38pub fn parse_datatype_selector(selector: PySelector) -> PyResult<DataTypeSelector> {
39 selector.inner.to_dtype_selector().ok_or_else(|| {
40 PyTypeError::new_err(format!(
41 "expected datatype based expression got '{}'",
42 selector.inner
43 ))
44 })
45}
46
#[cfg(feature = "pymethods")]
#[pyo3::pymethods]
impl PySelector {
    /// Combines clones of both wrapped selectors with the `|` operator.
    fn union(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.inner.clone(), other.inner.clone());
        Self::from(lhs | rhs)
    }

    /// Combines clones of both wrapped selectors with the `-` operator.
    fn difference(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.inner.clone(), other.inner.clone());
        Self::from(lhs - rhs)
    }

    /// Combines clones of both wrapped selectors with the `^` operator.
    fn exclusive_or(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.inner.clone(), other.inner.clone());
        Self::from(lhs ^ rhs)
    }

    /// Combines clones of both wrapped selectors with the `&` operator.
    fn intersect(&self, other: &Self) -> Self {
        let (lhs, rhs) = (self.inner.clone(), other.inner.clone());
        Self::from(lhs & rhs)
    }

    /// Builds a selector from a list of data types via `dsl::dtype_cols`.
    #[staticmethod]
    fn by_dtype(dtypes: Vec<Wrap<DataType>>) -> Self {
        // Strip the `Wrap` newtype so the DSL receives plain `DataType`s.
        let unwrapped: Vec<DataType> = dtypes.into_iter().map(|Wrap(dtype)| dtype).collect();
        Self::from(dsl::dtype_cols(unwrapped).as_selector())
    }

    /// Builds a selector from column names via `dsl::by_name`, forwarding `strict`.
    #[staticmethod]
    fn by_name(names: Vec<String>, strict: bool) -> Self {
        Self::from(dsl::by_name(names, strict))
    }

    /// Builds a `Selector::ByIndex` from the given indices, forwarding `strict`.
    #[staticmethod]
    fn by_index(indices: Vec<i64>, strict: bool) -> Self {
        let selector = Selector::ByIndex {
            indices: indices.into(),
            strict,
        };
        Self::from(selector)
    }

    /// Builds a `Selector::ByIndex` holding the single index `0`.
    #[staticmethod]
    fn first(strict: bool) -> Self {
        let selector = Selector::ByIndex {
            indices: [0].into(),
            strict,
        };
        Self::from(selector)
    }

    /// Builds a `Selector::ByIndex` holding the single index `-1`.
    #[staticmethod]
    fn last(strict: bool) -> Self {
        let selector = Selector::ByIndex {
            indices: [-1].into(),
            strict,
        };
        Self::from(selector)
    }

    /// Builds a `Selector::Matches` from the given pattern string.
    #[staticmethod]
    fn matches(pattern: String) -> Self {
        Self::from(Selector::Matches(pattern.into()))
    }

    /// Selector built from `DataTypeSelector::Enum`.
    #[staticmethod]
    fn enum_() -> Self {
        Self::from(DataTypeSelector::Enum.as_selector())
    }

    /// Selector built from `DataTypeSelector::Categorical`.
    #[staticmethod]
    fn categorical() -> Self {
        Self::from(DataTypeSelector::Categorical.as_selector())
    }

    /// Selector built from `DataTypeSelector::Nested`.
    #[staticmethod]
    fn nested() -> Self {
        Self::from(DataTypeSelector::Nested.as_selector())
    }

    /// Selector built from `DataTypeSelector::List`, with an optional inner
    /// datatype selector.
    ///
    /// Fails with a Python `TypeError` when `inner_dst` is given but cannot be
    /// converted by `parse_datatype_selector`.
    #[staticmethod]
    fn list(inner_dst: Option<Self>) -> PyResult<Self> {
        let inner = inner_dst
            .map(parse_datatype_selector)
            .transpose()?
            .map(Arc::new);
        Ok(Self::from(DataTypeSelector::List(inner).as_selector()))
    }

    /// Selector built from `DataTypeSelector::Array`, with an optional inner
    /// datatype selector and an optional width.
    ///
    /// Fails with a Python `TypeError` when `inner_dst` is given but cannot be
    /// converted by `parse_datatype_selector`.
    #[staticmethod]
    fn array(inner_dst: Option<Self>, width: Option<usize>) -> PyResult<Self> {
        let inner = inner_dst
            .map(parse_datatype_selector)
            .transpose()?
            .map(Arc::new);
        Ok(Self::from(
            DataTypeSelector::Array(inner, width).as_selector(),
        ))
    }

    /// Selector built from `DataTypeSelector::Struct`.
    #[staticmethod]
    fn struct_() -> Self {
        Self::from(DataTypeSelector::Struct.as_selector())
    }

    /// Selector built from `DataTypeSelector::Integer`.
    #[staticmethod]
    fn integer() -> Self {
        Self::from(DataTypeSelector::Integer.as_selector())
    }

    /// Selector built from `DataTypeSelector::SignedInteger`.
    #[staticmethod]
    fn signed_integer() -> Self {
        Self::from(DataTypeSelector::SignedInteger.as_selector())
    }

    /// Selector built from `DataTypeSelector::UnsignedInteger`.
    #[staticmethod]
    fn unsigned_integer() -> Self {
        Self::from(DataTypeSelector::UnsignedInteger.as_selector())
    }

    /// Selector built from `DataTypeSelector::Float`.
    #[staticmethod]
    fn float() -> Self {
        Self::from(DataTypeSelector::Float.as_selector())
    }

    /// Selector built from `DataTypeSelector::Decimal`.
    #[staticmethod]
    fn decimal() -> Self {
        Self::from(DataTypeSelector::Decimal.as_selector())
    }

    /// Selector built from `DataTypeSelector::Numeric`.
    #[staticmethod]
    fn numeric() -> Self {
        Self::from(DataTypeSelector::Numeric.as_selector())
    }

    /// Selector built from `DataTypeSelector::Temporal`.
    #[staticmethod]
    fn temporal() -> Self {
        Self::from(DataTypeSelector::Temporal.as_selector())
    }

    /// Selector built from `DataTypeSelector::Datetime` for the given time
    /// units and time zones.
    ///
    /// Each entry of `tz` is classified as follows:
    /// * `None` permits an unset time zone,
    /// * the string `"*"` permits any set time zone,
    /// * any other value is collected as an explicitly listed time zone.
    ///
    /// Once `"*"` is present the explicitly listed zones are not forwarded,
    /// because the resulting `TimeZoneSet::Any` / `TimeZoneSet::AnySet`
    /// variants carry no zone list.
    #[staticmethod]
    fn datetime(tu: Vec<Wrap<TimeUnit>>, tz: Vec<Wrap<Option<TimeZone>>>) -> Self {
        let units = parse_time_unit_set(tu);

        let mut allow_unset = false;
        let mut allow_any_set = false;
        let mut named: Vec<TimeZone> = Vec::new();
        for Wrap(zone) in tz {
            match zone {
                None => allow_unset = true,
                Some(name) if name.as_str() == "*" => allow_any_set = true,
                Some(name) => named.push(name),
            }
        }

        let zones = if allow_any_set {
            if allow_unset {
                TimeZoneSet::Any
            } else {
                TimeZoneSet::AnySet
            }
        } else if !allow_unset {
            TimeZoneSet::AnyOf(named.into())
        } else if named.is_empty() {
            TimeZoneSet::Unset
        } else {
            TimeZoneSet::UnsetOrAnyOf(named.into())
        };

        Self::from(DataTypeSelector::Datetime(units, zones).as_selector())
    }

    /// Selector built from `DataTypeSelector::Duration` for the given time units.
    #[staticmethod]
    fn duration(tu: Vec<Wrap<TimeUnit>>) -> Self {
        let units = parse_time_unit_set(tu);
        Self::from(DataTypeSelector::Duration(units).as_selector())
    }

    /// Selector built from `DataTypeSelector::Object`.
    #[staticmethod]
    fn object() -> Self {
        Self::from(DataTypeSelector::Object.as_selector())
    }

    /// Selector returned by `dsl::empty`.
    #[staticmethod]
    fn empty() -> Self {
        Self::from(dsl::empty())
    }

    /// Selector returned by `dsl::all`.
    #[staticmethod]
    fn all() -> Self {
        Self::from(dsl::all())
    }

    /// Hashes the wrapped selector with a fresh `DefaultHasher` and returns
    /// the resulting 64-bit digest.
    fn hash(&self) -> u64 {
        let mut state = std::hash::DefaultHasher::new();
        Hash::hash(&self.inner, &mut state);
        state.finish()
    }
}