polars_plan/plans/lit.rs

1use std::hash::{Hash, Hasher};
2
3#[cfg(feature = "temporal")]
4use chrono::{Duration as ChronoDuration, NaiveDate, NaiveDateTime};
5use polars_core::chunked_array::cast::CastOptions;
6use polars_core::prelude::*;
7use polars_core::utils::materialize_dyn_int;
8use polars_utils::hashing::hash_to_partition;
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use crate::constants::get_literal_name;
13use crate::prelude::*;
14
15#[derive(Clone, PartialEq)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub enum DynLiteralValue {
18    Str(PlSmallStr),
19    Int(i128),
20    Float(f64),
21    List(DynListLiteralValue),
22}
23#[derive(Clone, PartialEq)]
24#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
25pub enum DynListLiteralValue {
26    Str(Box<[Option<PlSmallStr>]>),
27    Int(Box<[Option<i128>]>),
28    Float(Box<[Option<f64>]>),
29    List(Box<[Option<DynListLiteralValue>]>),
30}
31
32impl Hash for DynLiteralValue {
33    fn hash<H: Hasher>(&self, state: &mut H) {
34        std::mem::discriminant(self).hash(state);
35        match self {
36            Self::Str(i) => i.hash(state),
37            Self::Int(i) => i.hash(state),
38            Self::Float(i) => i.to_ne_bytes().hash(state),
39            Self::List(i) => i.hash(state),
40        }
41    }
42}
43
44impl Hash for DynListLiteralValue {
45    fn hash<H: Hasher>(&self, state: &mut H) {
46        std::mem::discriminant(self).hash(state);
47        match self {
48            Self::Str(i) => i.hash(state),
49            Self::Int(i) => i.hash(state),
50            Self::Float(i) => i
51                .iter()
52                .for_each(|i| i.map(|i| i.to_ne_bytes()).hash(state)),
53            Self::List(i) => i.hash(state),
54        }
55    }
56}
57
58#[derive(Clone, PartialEq, Hash)]
59#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
60pub struct RangeLiteralValue {
61    pub low: i128,
62    pub high: i128,
63    pub dtype: DataType,
64}
65#[derive(Clone, PartialEq)]
66#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
67pub enum LiteralValue {
68    /// A dynamically inferred literal value. This needs to be materialized into a specific type.
69    Dyn(DynLiteralValue),
70    Scalar(Scalar),
71    Series(SpecialEq<Series>),
72    Range(RangeLiteralValue),
73}
74
75pub enum MaterializedLiteralValue {
76    Scalar(Scalar),
77    Series(Series),
78}
79
80impl DynListLiteralValue {
81    pub fn try_materialize_to_dtype(self, dtype: &DataType) -> PolarsResult<Scalar> {
82        let Some(inner_dtype) = dtype.inner_dtype() else {
83            polars_bail!(InvalidOperation: "conversion from list literal to `{dtype}` failed.");
84        };
85
86        let s = match self {
87            DynListLiteralValue::Str(vs) => {
88                StringChunked::from_iter_options(PlSmallStr::from_static("literal"), vs.into_iter())
89                    .into_series()
90            },
91            DynListLiteralValue::Int(vs) => {
92                #[cfg(feature = "dtype-i128")]
93                {
94                    Int128Chunked::from_iter_options(
95                        PlSmallStr::from_static("literal"),
96                        vs.into_iter(),
97                    )
98                    .into_series()
99                }
100
101                #[cfg(not(feature = "dtype-i128"))]
102                {
103                    Int64Chunked::from_iter_options(
104                        PlSmallStr::from_static("literal"),
105                        vs.into_iter().map(|v| v.map(|v| v as i64)),
106                    )
107                    .into_series()
108                }
109            },
110            DynListLiteralValue::Float(vs) => Float64Chunked::from_iter_options(
111                PlSmallStr::from_static("literal"),
112                vs.into_iter(),
113            )
114            .into_series(),
115            DynListLiteralValue::List(_) => todo!("nested lists"),
116        };
117
118        let s = s.cast_with_options(inner_dtype, CastOptions::Strict)?;
119        let value = match dtype {
120            DataType::List(_) => AnyValue::List(s),
121            #[cfg(feature = "dtype-array")]
122            DataType::Array(_, size) => AnyValue::Array(s, *size),
123            _ => unreachable!(),
124        };
125
126        Ok(Scalar::new(dtype.clone(), value))
127    }
128}
129
130impl DynLiteralValue {
131    pub fn try_materialize_to_dtype(self, dtype: &DataType) -> PolarsResult<Scalar> {
132        match self {
133            DynLiteralValue::Str(s) => {
134                Ok(Scalar::from(s).cast_with_options(dtype, CastOptions::Strict)?)
135            },
136            DynLiteralValue::Int(i) => {
137                Ok(Scalar::from(i).cast_with_options(dtype, CastOptions::Strict)?)
138            },
139            DynLiteralValue::Float(f) => {
140                Ok(Scalar::from(f).cast_with_options(dtype, CastOptions::Strict)?)
141            },
142            DynLiteralValue::List(dyn_list_value) => dyn_list_value.try_materialize_to_dtype(dtype),
143        }
144    }
145}
146
147impl RangeLiteralValue {
148    pub fn try_materialize_to_series(self, dtype: &DataType) -> PolarsResult<Series> {
149        fn handle_range_oob(range: &RangeLiteralValue, to_dtype: &DataType) -> PolarsResult<()> {
150            polars_bail!(
151                InvalidOperation:
152                "conversion from `{}` to `{to_dtype}` failed for range({}, {})",
153                range.dtype, range.low, range.high,
154            )
155        }
156
157        let s = match dtype {
158            DataType::Int32 => {
159                if self.low < i32::MIN as i128 || self.high > i32::MAX as i128 {
160                    handle_range_oob(&self, dtype)?;
161                }
162
163                new_int_range::<Int32Type>(
164                    self.low as i32,
165                    self.high as i32,
166                    1,
167                    PlSmallStr::from_static("range"),
168                )
169                .unwrap()
170            },
171            DataType::Int64 => {
172                if self.low < i64::MIN as i128 || self.high > i64::MAX as i128 {
173                    handle_range_oob(&self, dtype)?;
174                }
175
176                new_int_range::<Int64Type>(
177                    self.low as i64,
178                    self.high as i64,
179                    1,
180                    PlSmallStr::from_static("range"),
181                )
182                .unwrap()
183            },
184            DataType::UInt32 => {
185                if self.low < u32::MIN as i128 || self.high > u32::MAX as i128 {
186                    handle_range_oob(&self, dtype)?;
187                }
188                new_int_range::<UInt32Type>(
189                    self.low as u32,
190                    self.high as u32,
191                    1,
192                    PlSmallStr::from_static("range"),
193                )
194                .unwrap()
195            },
196            _ => polars_bail!(InvalidOperation: "unsupported range datatype `{dtype}`"),
197        };
198
199        Ok(s)
200    }
201}
202
203impl LiteralValue {
204    /// Get the output name as `&str`.
205    pub(crate) fn output_name(&self) -> &PlSmallStr {
206        match self {
207            LiteralValue::Series(s) => s.name(),
208            _ => get_literal_name(),
209        }
210    }
211
212    /// Get the output name as [`PlSmallStr`].
213    pub(crate) fn output_column_name(&self) -> &PlSmallStr {
214        match self {
215            LiteralValue::Series(s) => s.name(),
216            _ => get_literal_name(),
217        }
218    }
219
220    pub fn try_materialize_to_dtype(
221        self,
222        dtype: &DataType,
223    ) -> PolarsResult<MaterializedLiteralValue> {
224        use LiteralValue as L;
225        match self {
226            L::Dyn(dyn_value) => dyn_value
227                .try_materialize_to_dtype(dtype)
228                .map(MaterializedLiteralValue::Scalar),
229            L::Scalar(sc) => Ok(MaterializedLiteralValue::Scalar(
230                sc.cast_with_options(dtype, CastOptions::Strict)?,
231            )),
232            L::Range(range) => {
233                let Some(inner_dtype) = dtype.inner_dtype() else {
234                    polars_bail!(
235                        InvalidOperation: "cannot turn `{}` range into `{dtype}`",
236                        range.dtype
237                    );
238                };
239
240                let s = range.try_materialize_to_series(inner_dtype)?;
241                let value = match dtype {
242                    DataType::List(_) => AnyValue::List(s),
243                    #[cfg(feature = "dtype-array")]
244                    DataType::Array(_, size) => AnyValue::Array(s, *size),
245                    _ => unreachable!(),
246                };
247                Ok(MaterializedLiteralValue::Scalar(Scalar::new(
248                    dtype.clone(),
249                    value,
250                )))
251            },
252            L::Series(s) => Ok(MaterializedLiteralValue::Series(
253                s.cast_with_options(dtype, CastOptions::Strict)?,
254            )),
255        }
256    }
257
258    pub fn extract_usize(&self) -> PolarsResult<usize> {
259        macro_rules! cast_usize {
260            ($v:expr) => {
261                usize::try_from($v).map_err(
262                    |_| polars_err!(InvalidOperation: "cannot convert value {} to usize", $v)
263                )
264            }
265        }
266        match &self {
267            Self::Dyn(DynLiteralValue::Int(v)) => cast_usize!(*v),
268            Self::Scalar(sc) => match sc.as_any_value() {
269                AnyValue::UInt8(v) => Ok(v as usize),
270                AnyValue::UInt16(v) => Ok(v as usize),
271                AnyValue::UInt32(v) => cast_usize!(v),
272                AnyValue::UInt64(v) => cast_usize!(v),
273                AnyValue::Int8(v) => cast_usize!(v),
274                AnyValue::Int16(v) => cast_usize!(v),
275                AnyValue::Int32(v) => cast_usize!(v),
276                AnyValue::Int64(v) => cast_usize!(v),
277                AnyValue::Int128(v) => cast_usize!(v),
278                _ => {
279                    polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
280                },
281            },
282            _ => {
283                polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
284            },
285        }
286    }
287
288    pub fn materialize(self) -> Self {
289        match self {
290            LiteralValue::Dyn(_) => {
291                let av = self.to_any_value().unwrap();
292                av.into()
293            },
294            lv => lv,
295        }
296    }
297
298    pub fn is_scalar(&self) -> bool {
299        !matches!(self, LiteralValue::Series(_) | LiteralValue::Range { .. })
300    }
301
302    pub fn to_any_value(&self) -> Option<AnyValue> {
303        let av = match self {
304            Self::Scalar(sc) => sc.value().clone(),
305            Self::Range(range) => {
306                let s = range.clone().try_materialize_to_series(&range.dtype).ok()?;
307                AnyValue::List(s)
308            },
309            Self::Series(_) => return None,
310            Self::Dyn(d) => match d {
311                DynLiteralValue::Int(v) => materialize_dyn_int(*v),
312                DynLiteralValue::Float(v) => AnyValue::Float64(*v),
313                DynLiteralValue::Str(v) => AnyValue::String(v),
314                DynLiteralValue::List(_) => todo!(),
315            },
316        };
317        Some(av)
318    }
319
320    /// Getter for the `DataType` of the value
321    pub fn get_datatype(&self) -> DataType {
322        match self {
323            Self::Dyn(d) => match d {
324                DynLiteralValue::Int(v) => DataType::Unknown(UnknownKind::Int(*v)),
325                DynLiteralValue::Float(_) => DataType::Unknown(UnknownKind::Float),
326                DynLiteralValue::Str(_) => DataType::Unknown(UnknownKind::Str),
327                DynLiteralValue::List(_) => todo!(),
328            },
329            Self::Scalar(sc) => sc.dtype().clone(),
330            Self::Series(s) => s.dtype().clone(),
331            Self::Range(s) => s.dtype.clone(),
332        }
333    }
334
335    pub fn new_idxsize(value: IdxSize) -> Self {
336        LiteralValue::Scalar(value.into())
337    }
338
339    pub fn extract_str(&self) -> Option<&str> {
340        match self {
341            LiteralValue::Dyn(DynLiteralValue::Str(s)) => Some(s.as_str()),
342            LiteralValue::Scalar(sc) => match sc.value() {
343                AnyValue::String(s) => Some(s),
344                AnyValue::StringOwned(s) => Some(s),
345                _ => None,
346            },
347            _ => None,
348        }
349    }
350
351    pub fn extract_binary(&self) -> Option<&[u8]> {
352        match self {
353            LiteralValue::Scalar(sc) => match sc.value() {
354                AnyValue::Binary(s) => Some(s),
355                AnyValue::BinaryOwned(s) => Some(s),
356                _ => None,
357            },
358            _ => None,
359        }
360    }
361
362    pub fn is_null(&self) -> bool {
363        match self {
364            Self::Scalar(sc) => sc.is_null(),
365            Self::Series(s) => s.len() == 1 && s.null_count() == 1,
366            _ => false,
367        }
368    }
369
370    pub fn bool(&self) -> Option<bool> {
371        match self {
372            LiteralValue::Scalar(s) => match s.as_any_value() {
373                AnyValue::Boolean(b) => Some(b),
374                _ => None,
375            },
376            _ => None,
377        }
378    }
379
380    pub const fn untyped_null() -> Self {
381        Self::Scalar(Scalar::null(DataType::Null))
382    }
383}
384
385impl From<Scalar> for LiteralValue {
386    fn from(value: Scalar) -> Self {
387        Self::Scalar(value)
388    }
389}
390
391pub trait Literal {
392    /// [Literal](Expr::Literal) expression.
393    fn lit(self) -> Expr;
394}
395
396pub trait TypedLiteral: Literal {
397    /// [Literal](Expr::Literal) expression.
398    fn typed_lit(self) -> Expr
399    where
400        Self: Sized,
401    {
402        self.lit()
403    }
404}
405
406impl TypedLiteral for String {}
407impl TypedLiteral for &str {}
408
409impl Literal for PlSmallStr {
410    fn lit(self) -> Expr {
411        Expr::Literal(Scalar::from(self).into())
412    }
413}
414
415impl Literal for String {
416    fn lit(self) -> Expr {
417        Expr::Literal(Scalar::from(PlSmallStr::from_string(self)).into())
418    }
419}
420
421impl Literal for &str {
422    fn lit(self) -> Expr {
423        Expr::Literal(Scalar::from(PlSmallStr::from_str(self)).into())
424    }
425}
426
427impl Literal for Vec<u8> {
428    fn lit(self) -> Expr {
429        Expr::Literal(Scalar::from(self).into())
430    }
431}
432
433impl Literal for &[u8] {
434    fn lit(self) -> Expr {
435        Expr::Literal(Scalar::from(self.to_vec()).into())
436    }
437}
438
439impl From<AnyValue<'_>> for LiteralValue {
440    fn from(value: AnyValue<'_>) -> Self {
441        Self::Scalar(Scalar::new(value.dtype(), value.into_static()))
442    }
443}
444
// Implements `Literal` for `$TYPE` by converting the value into a typed `Scalar`.
// The second argument is not used by the expansion.
macro_rules! make_literal {
    ($TYPE:ty, $SCALAR:ident) => {
        impl Literal for $TYPE {
            fn lit(self) -> Expr {
                Expr::Literal(LiteralValue::from(Scalar::from(self)))
            }
        }
    };
}

// Overrides `TypedLiteral::typed_lit` for `$TYPE` so it produces a typed `Scalar`.
// The second argument is not used by the expansion.
macro_rules! make_literal_typed {
    ($TYPE:ty, $SCALAR:ident) => {
        impl TypedLiteral for $TYPE {
            fn typed_lit(self) -> Expr {
                Expr::Literal(LiteralValue::from(Scalar::from(self)))
            }
        }
    };
}

// Implements `Literal` for `$TYPE` as a dynamic literal: the value is widened
// into the payload of `DynLiteralValue::$SCALAR`.
macro_rules! make_dyn_lit {
    ($TYPE:ty, $SCALAR:ident) => {
        impl Literal for $TYPE {
            fn lit(self) -> Expr {
                let dyn_value = DynLiteralValue::$SCALAR(self.try_into().unwrap());
                Expr::Literal(LiteralValue::Dyn(dyn_value))
            }
        }
    };
}
476
477make_literal!(bool, Boolean);
478make_literal_typed!(f32, Float32);
479make_literal_typed!(f64, Float64);
480make_literal_typed!(i8, Int8);
481make_literal_typed!(i16, Int16);
482make_literal_typed!(i32, Int32);
483make_literal_typed!(i64, Int64);
484make_literal_typed!(i128, Int128);
485make_literal_typed!(u8, UInt8);
486make_literal_typed!(u16, UInt16);
487make_literal_typed!(u32, UInt32);
488make_literal_typed!(u64, UInt64);
489
490make_dyn_lit!(f32, Float);
491make_dyn_lit!(f64, Float);
492make_dyn_lit!(i8, Int);
493make_dyn_lit!(i16, Int);
494make_dyn_lit!(i32, Int);
495make_dyn_lit!(i64, Int);
496make_dyn_lit!(u8, Int);
497make_dyn_lit!(u16, Int);
498make_dyn_lit!(u32, Int);
499make_dyn_lit!(u64, Int);
500make_dyn_lit!(i128, Int);
501
502/// The literal Null
503pub struct Null {}
504pub const NULL: Null = Null {};
505
506impl Literal for Null {
507    fn lit(self) -> Expr {
508        Expr::Literal(LiteralValue::Scalar(Scalar::null(DataType::Null)))
509    }
510}
511
#[cfg(feature = "dtype-datetime")]
impl Literal for NaiveDateTime {
    /// A timezone-less datetime literal.
    ///
    /// Nanosecond precision is used when `in_nanoseconds_window` accepts the
    /// value; otherwise the literal falls back to microseconds.
    fn lit(self) -> Expr {
        let utc = self.and_utc();
        let (value, unit) = if in_nanoseconds_window(&self) {
            (utc.timestamp_nanos_opt().unwrap(), TimeUnit::Nanoseconds)
        } else {
            (utc.timestamp_micros(), TimeUnit::Microseconds)
        };
        Expr::Literal(Scalar::new_datetime(value, unit, None).into())
    }
}
536
#[cfg(feature = "dtype-duration")]
impl Literal for ChronoDuration {
    /// A duration literal in the finest time unit that can hold the value.
    ///
    /// Nanoseconds are tried first, then microseconds, then milliseconds.
    /// The millisecond fallback matters because very large `chrono` durations
    /// overflow `i64` microseconds (`num_microseconds` returns `None`). Those
    /// durations still fit in milliseconds, so they no longer cause a panic.
    fn lit(self) -> Expr {
        let (value, unit) = if let Some(ns) = self.num_nanoseconds() {
            (ns, TimeUnit::Nanoseconds)
        } else if let Some(us) = self.num_microseconds() {
            (us, TimeUnit::Microseconds)
        } else {
            (self.num_milliseconds(), TimeUnit::Milliseconds)
        };
        Expr::Literal(Scalar::new_duration(value, unit).into())
    }
}
550
#[cfg(feature = "dtype-duration")]
impl Literal for Duration {
    /// A duration literal in nanoseconds, carrying the duration's sign.
    ///
    /// # Panics
    ///
    /// Panics if the duration has a non-zero month component, because such a
    /// duration has no fixed length.
    fn lit(self) -> Expr {
        assert!(
            self.months() == 0,
            "Cannot create literal duration that is not of fixed length; found {}",
            self
        );
        let magnitude = self.duration_ns();
        let signed_ns = match self.negative() {
            true => -magnitude,
            false => magnitude,
        };
        let scalar = Scalar::new_duration(signed_ns, TimeUnit::Nanoseconds);
        Expr::Literal(scalar.into())
    }
}
569
#[cfg(feature = "dtype-datetime")]
impl Literal for NaiveDate {
    /// A datetime literal at midnight (00:00:00) of this date.
    fn lit(self) -> Expr {
        let midnight = self.and_hms_opt(0, 0, 0).unwrap();
        midnight.lit()
    }
}
576
577impl Literal for Series {
578    fn lit(self) -> Expr {
579        Expr::Literal(LiteralValue::Series(SpecialEq::new(self)))
580    }
581}
582
583impl Literal for LiteralValue {
584    fn lit(self) -> Expr {
585        Expr::Literal(self)
586    }
587}
588
589impl Literal for Scalar {
590    fn lit(self) -> Expr {
591        Expr::Literal(self.into())
592    }
593}
594
595/// Create a Literal Expression from `L`. A literal expression behaves like a column that contains a single distinct
596/// value.
597///
598/// The column is automatically of the "correct" length to make the operations work. Often this is determined by the
599/// length of the `LazyFrame` it is being used with. For instance, `lazy_df.with_column(lit(5).alias("five"))` creates a
600/// new column named "five" that is the length of the Dataframe (at the time `collect` is called), where every value in
601/// the column is `5`.
602pub fn lit<L: Literal>(t: L) -> Expr {
603    t.lit()
604}
605
606pub fn typed_lit<L: TypedLiteral>(t: L) -> Expr {
607    t.typed_lit()
608}
609
610impl Hash for LiteralValue {
611    fn hash<H: Hasher>(&self, state: &mut H) {
612        std::mem::discriminant(self).hash(state);
613        match self {
614            LiteralValue::Series(s) => {
615                // Free stats
616                s.dtype().hash(state);
617                let len = s.len();
618                len.hash(state);
619                s.null_count().hash(state);
620                const RANDOM: u64 = 0x2c194fa5df32a367;
621                let mut rng = (len as u64) ^ RANDOM;
622                for _ in 0..std::cmp::min(5, len) {
623                    let idx = hash_to_partition(rng, len);
624                    s.get(idx).unwrap().hash(state);
625                    rng = rng.rotate_right(17).wrapping_add(RANDOM);
626                }
627            },
628            LiteralValue::Range(range) => range.hash(state),
629            LiteralValue::Scalar(sc) => sc.hash(state),
630            LiteralValue::Dyn(d) => d.hash(state),
631        }
632    }
633}