polars_plan/dsl/function_expr/
strings.rs

1use std::borrow::Cow;
2
3use arrow::legacy::utils::CustomIterTools;
4#[cfg(feature = "timezones")]
5use polars_core::chunked_array::temporal::validate_time_zone;
6use polars_core::utils::handle_casting_failures;
7#[cfg(feature = "dtype-struct")]
8use polars_utils::format_pl_smallstr;
9#[cfg(feature = "regex")]
10use regex::{NoExpand, escape};
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
14use super::*;
15use crate::{map, map_as_slice};
16
// Matches strptime format strings that contain one of the `%z`, `%:z`, `%::z`,
// `%:::z` or `%#z` specifiers, or that consist solely of `%+`.
// `to_datetime` tests the user-supplied format against this pattern to set its
// `tz_aware` flag.
#[cfg(all(feature = "regex", feature = "timezones"))]
polars_utils::regex_cache::cached_regex! {
    static TZ_AWARE_RE = r"(%z)|(%:z)|(%::z)|(%:::z)|(%#z)|(^%\+$)";
}
21
/// The string-namespace functions of the expression DSL.
///
/// Every variant is rendered as `str.<name>` by the `Display` impl, typed by
/// `StringFunction::get_field`, and bound to its kernel by the
/// `From<StringFunction> for SpecialEq<Arc<dyn ColumnsUdf>>` impl.
/// Most variants only exist when the named cargo feature is enabled.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum StringFunction {
    /// `str.concat_horizontal`; executed by `strings::concat_hor`, yields `String`.
    #[cfg(feature = "concat_str")]
    ConcatHorizontal {
        delimiter: PlSmallStr,
        ignore_nulls: bool,
    },
    /// `str.concat_vertical`; executed by `strings::join` and flagged as an
    /// aggregation in `function_options`.
    #[cfg(feature = "concat_str")]
    ConcatVertical {
        delimiter: PlSmallStr,
        ignore_nulls: bool,
    },
    /// `str.contains`; yields `Boolean`.
    #[cfg(feature = "regex")]
    Contains {
        literal: bool,
        strict: bool,
    },
    /// `str.count_matches`; the payload is forwarded to the kernel as `literal`.
    /// Yields `UInt32`.
    CountMatches(bool),
    /// `str.ends_with`; yields `Boolean`.
    EndsWith,
    /// `str.extract`; the payload is forwarded to the kernel as `group_index`.
    Extract(usize),
    /// `str.extract_all`; yields `List(String)`.
    ExtractAll,
    /// `str.extract_groups`; `dtype` is reported verbatim as the output dtype.
    #[cfg(feature = "extract_groups")]
    ExtractGroups {
        dtype: DataType,
        pat: PlSmallStr,
    },
    /// `str.find`; yields `UInt32`.
    #[cfg(feature = "regex")]
    Find {
        literal: bool,
        strict: bool,
    },
    /// `str.to_integer`; the payload is forwarded to the kernel as `strict`.
    /// Yields `Int64`.
    #[cfg(feature = "string_to_integer")]
    ToInteger(bool),
    /// `str.len_bytes`; yields `UInt32`.
    LenBytes,
    /// `str.len_chars`; yields `UInt32`.
    LenChars,
    /// `str.lowercase`; keeps the input dtype.
    Lowercase,
    /// `str.json_decode`; when `dtype` is `None`, `function_options` selects
    /// `elementwise_with_infer` so the dtype is inferred only once.
    #[cfg(feature = "extract_jsonpath")]
    JsonDecode {
        dtype: Option<DataType>,
        infer_schema_len: Option<usize>,
    },
    /// `str.json_path_match`; yields `String`.
    #[cfg(feature = "extract_jsonpath")]
    JsonPathMatch,
    /// `str.replace`; keeps the input dtype.
    #[cfg(feature = "regex")]
    Replace {
        // How many matches to replace; a negative value means replace all.
        n: i64,
        literal: bool,
    },
    /// `str.normalize`; keeps the input dtype.
    #[cfg(feature = "string_normalize")]
    Normalize {
        form: UnicodeForm,
    },
    /// `str.reverse`; keeps the input dtype.
    #[cfg(feature = "string_reverse")]
    Reverse,
    /// `str.pad_start`; keeps the input dtype.
    #[cfg(feature = "string_pad")]
    PadStart {
        length: usize,
        fill_char: char,
    },
    /// `str.pad_end`; keeps the input dtype.
    #[cfg(feature = "string_pad")]
    PadEnd {
        length: usize,
        fill_char: char,
    },
    /// `str.slice`; keeps the input dtype.
    Slice,
    /// `str.head`; keeps the input dtype.
    Head,
    /// `str.tail`; keeps the input dtype.
    Tail,
    /// `str.hex_encode`; keeps the input dtype.
    #[cfg(feature = "string_encoding")]
    HexEncode,
    /// `str.hex_decode`; the payload is forwarded to the kernel as `strict`.
    /// Yields `Binary`.
    #[cfg(feature = "binary_encoding")]
    HexDecode(bool),
    /// `str.base64_encode`; keeps the input dtype.
    #[cfg(feature = "string_encoding")]
    Base64Encode,
    /// `str.base64_decode`; the payload is forwarded to the kernel as `strict`.
    /// Yields `Binary`.
    #[cfg(feature = "binary_encoding")]
    Base64Decode(bool),
    /// `str.starts_with`; yields `Boolean`.
    StartsWith,
    /// `str.strip_chars`; keeps the input dtype.
    StripChars,
    /// `str.strip_chars_start`; keeps the input dtype.
    StripCharsStart,
    /// `str.strip_chars_end`; keeps the input dtype.
    StripCharsEnd,
    /// `str.strip_prefix`; keeps the input dtype.
    StripPrefix,
    /// `str.strip_suffix`; keeps the input dtype.
    StripSuffix,
    /// `str.split_exact` (or `str.split_exact_inclusive` when `inclusive` is set);
    /// yields a struct with `n + 1` `String` fields named `field_{i}`.
    #[cfg(feature = "dtype-struct")]
    SplitExact {
        n: usize,
        inclusive: bool,
    },
    /// `str.splitn`; yields a struct with `n` `String` fields named `field_{i}`.
    #[cfg(feature = "dtype-struct")]
    SplitN(usize),
    /// `str.strptime`; the `DataType` payload is the output dtype.
    #[cfg(feature = "temporal")]
    Strptime(DataType, StrptimeOptions),
    /// `str.split` (or `str.split_inclusive` when the payload is `true`);
    /// yields `List(String)`.
    Split(bool),
    /// `str.to_decimal`; the payload is forwarded to the kernel as `infer_len`.
    #[cfg(feature = "dtype-decimal")]
    ToDecimal(usize),
    /// `str.titlecase`; keeps the input dtype.
    #[cfg(feature = "nightly")]
    Titlecase,
    /// `str.uppercase`; keeps the input dtype.
    Uppercase,
    /// `str.zfill`; keeps the input dtype.
    #[cfg(feature = "string_pad")]
    ZFill,
    /// `str.contains_any`; yields `Boolean`.
    #[cfg(feature = "find_many")]
    ContainsAny {
        ascii_case_insensitive: bool,
    },
    /// `str.replace_many`; keeps the input dtype.
    #[cfg(feature = "find_many")]
    ReplaceMany {
        ascii_case_insensitive: bool,
    },
    /// `str.extract_many`; yields `List(String)`.
    #[cfg(feature = "find_many")]
    ExtractMany {
        ascii_case_insensitive: bool,
        overlapping: bool,
    },
    /// Multi-pattern find, executed by `find_many`; yields `List(UInt32)`.
    #[cfg(feature = "find_many")]
    FindMany {
        ascii_case_insensitive: bool,
        overlapping: bool,
    },
    /// `str.escape_regex`; keeps the input dtype.
    #[cfg(feature = "regex")]
    EscapeRegex,
}
144
145impl StringFunction {
146    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
147        use StringFunction::*;
148        match self {
149            #[cfg(feature = "concat_str")]
150            ConcatVertical { .. } | ConcatHorizontal { .. } => mapper.with_dtype(DataType::String),
151            #[cfg(feature = "regex")]
152            Contains { .. } => mapper.with_dtype(DataType::Boolean),
153            CountMatches(_) => mapper.with_dtype(DataType::UInt32),
154            EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
155            Extract(_) => mapper.with_same_dtype(),
156            ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
157            #[cfg(feature = "extract_groups")]
158            ExtractGroups { dtype, .. } => mapper.with_dtype(dtype.clone()),
159            #[cfg(feature = "string_to_integer")]
160            ToInteger { .. } => mapper.with_dtype(DataType::Int64),
161            #[cfg(feature = "regex")]
162            Find { .. } => mapper.with_dtype(DataType::UInt32),
163            #[cfg(feature = "extract_jsonpath")]
164            JsonDecode { dtype, .. } => mapper.with_opt_dtype(dtype.clone()),
165            #[cfg(feature = "extract_jsonpath")]
166            JsonPathMatch => mapper.with_dtype(DataType::String),
167            LenBytes => mapper.with_dtype(DataType::UInt32),
168            LenChars => mapper.with_dtype(DataType::UInt32),
169            #[cfg(feature = "regex")]
170            Replace { .. } => mapper.with_same_dtype(),
171            #[cfg(feature = "string_normalize")]
172            Normalize { .. } => mapper.with_same_dtype(),
173            #[cfg(feature = "string_reverse")]
174            Reverse => mapper.with_same_dtype(),
175            #[cfg(feature = "temporal")]
176            Strptime(dtype, _) => mapper.with_dtype(dtype.clone()),
177            Split(_) => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
178            #[cfg(feature = "nightly")]
179            Titlecase => mapper.with_same_dtype(),
180            #[cfg(feature = "dtype-decimal")]
181            ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)),
182            #[cfg(feature = "string_encoding")]
183            HexEncode => mapper.with_same_dtype(),
184            #[cfg(feature = "binary_encoding")]
185            HexDecode(_) => mapper.with_dtype(DataType::Binary),
186            #[cfg(feature = "string_encoding")]
187            Base64Encode => mapper.with_same_dtype(),
188            #[cfg(feature = "binary_encoding")]
189            Base64Decode(_) => mapper.with_dtype(DataType::Binary),
190            Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
191            | StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(),
192            #[cfg(feature = "string_pad")]
193            PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(),
194            #[cfg(feature = "dtype-struct")]
195            SplitExact { n, .. } => mapper.with_dtype(DataType::Struct(
196                (0..n + 1)
197                    .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String))
198                    .collect(),
199            )),
200            #[cfg(feature = "dtype-struct")]
201            SplitN(n) => mapper.with_dtype(DataType::Struct(
202                (0..*n)
203                    .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String))
204                    .collect(),
205            )),
206            #[cfg(feature = "find_many")]
207            ContainsAny { .. } => mapper.with_dtype(DataType::Boolean),
208            #[cfg(feature = "find_many")]
209            ReplaceMany { .. } => mapper.with_same_dtype(),
210            #[cfg(feature = "find_many")]
211            ExtractMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
212            #[cfg(feature = "find_many")]
213            FindMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::UInt32))),
214            #[cfg(feature = "regex")]
215            EscapeRegex => mapper.with_same_dtype(),
216        }
217    }
218
219    pub fn function_options(&self) -> FunctionOptions {
220        use StringFunction as S;
221        match self {
222            #[cfg(feature = "concat_str")]
223            S::ConcatHorizontal { .. } => FunctionOptions::elementwise()
224                .with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION),
225            #[cfg(feature = "concat_str")]
226            S::ConcatVertical { .. } => FunctionOptions::aggregation(),
227            #[cfg(feature = "regex")]
228            S::Contains { .. } => {
229                FunctionOptions::elementwise().with_supertyping(Default::default())
230            },
231            S::CountMatches(_) => FunctionOptions::elementwise(),
232            S::EndsWith | S::StartsWith | S::Extract(_) => {
233                FunctionOptions::elementwise().with_supertyping(Default::default())
234            },
235            S::ExtractAll => FunctionOptions::elementwise(),
236            #[cfg(feature = "extract_groups")]
237            S::ExtractGroups { .. } => FunctionOptions::elementwise(),
238            #[cfg(feature = "string_to_integer")]
239            S::ToInteger { .. } => FunctionOptions::elementwise(),
240            #[cfg(feature = "regex")]
241            S::Find { .. } => FunctionOptions::elementwise().with_supertyping(Default::default()),
242            #[cfg(feature = "extract_jsonpath")]
243            S::JsonDecode { dtype: Some(_), .. } => FunctionOptions::elementwise(),
244            // because dtype should be inferred only once and be consistent over chunks/morsels.
245            #[cfg(feature = "extract_jsonpath")]
246            S::JsonDecode { dtype: None, .. } => FunctionOptions::elementwise_with_infer(),
247            #[cfg(feature = "extract_jsonpath")]
248            S::JsonPathMatch => FunctionOptions::elementwise(),
249            S::LenBytes | S::LenChars => FunctionOptions::elementwise(),
250            #[cfg(feature = "regex")]
251            S::Replace { .. } => {
252                FunctionOptions::elementwise().with_supertyping(Default::default())
253            },
254            #[cfg(feature = "string_normalize")]
255            S::Normalize { .. } => FunctionOptions::elementwise(),
256            #[cfg(feature = "string_reverse")]
257            S::Reverse => FunctionOptions::elementwise(),
258            #[cfg(feature = "temporal")]
259            S::Strptime(_, options) if options.format.is_some() => FunctionOptions::elementwise(),
260            S::Strptime(_, _) => FunctionOptions::elementwise_with_infer(),
261            S::Split(_) => FunctionOptions::elementwise(),
262            #[cfg(feature = "nightly")]
263            S::Titlecase => FunctionOptions::elementwise(),
264            #[cfg(feature = "dtype-decimal")]
265            S::ToDecimal(_) => FunctionOptions::elementwise_with_infer(),
266            #[cfg(feature = "string_encoding")]
267            S::HexEncode | S::Base64Encode => FunctionOptions::elementwise(),
268            #[cfg(feature = "binary_encoding")]
269            S::HexDecode(_) | S::Base64Decode(_) => FunctionOptions::elementwise(),
270            S::Uppercase | S::Lowercase => FunctionOptions::elementwise(),
271            S::StripChars
272            | S::StripCharsStart
273            | S::StripCharsEnd
274            | S::StripPrefix
275            | S::StripSuffix
276            | S::Head
277            | S::Tail => FunctionOptions::elementwise(),
278            S::Slice => FunctionOptions::elementwise(),
279            #[cfg(feature = "string_pad")]
280            S::PadStart { .. } | S::PadEnd { .. } | S::ZFill => FunctionOptions::elementwise(),
281            #[cfg(feature = "dtype-struct")]
282            S::SplitExact { .. } => FunctionOptions::elementwise(),
283            #[cfg(feature = "dtype-struct")]
284            S::SplitN(_) => FunctionOptions::elementwise(),
285            #[cfg(feature = "find_many")]
286            S::ContainsAny { .. } => FunctionOptions::elementwise(),
287            #[cfg(feature = "find_many")]
288            S::ReplaceMany { .. } => FunctionOptions::elementwise(),
289            #[cfg(feature = "find_many")]
290            S::ExtractMany { .. } => FunctionOptions::elementwise(),
291            #[cfg(feature = "find_many")]
292            S::FindMany { .. } => FunctionOptions::elementwise(),
293            #[cfg(feature = "regex")]
294            S::EscapeRegex => FunctionOptions::elementwise(),
295        }
296    }
297}
298
299impl Display for StringFunction {
300    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
301        use StringFunction::*;
302        let s = match self {
303            #[cfg(feature = "regex")]
304            Contains { .. } => "contains",
305            CountMatches(_) => "count_matches",
306            EndsWith => "ends_with",
307            Extract(_) => "extract",
308            #[cfg(feature = "concat_str")]
309            ConcatHorizontal { .. } => "concat_horizontal",
310            #[cfg(feature = "concat_str")]
311            ConcatVertical { .. } => "concat_vertical",
312            ExtractAll => "extract_all",
313            #[cfg(feature = "extract_groups")]
314            ExtractGroups { .. } => "extract_groups",
315            #[cfg(feature = "string_to_integer")]
316            ToInteger { .. } => "to_integer",
317            #[cfg(feature = "regex")]
318            Find { .. } => "find",
319            Head => "head",
320            Tail => "tail",
321            #[cfg(feature = "extract_jsonpath")]
322            JsonDecode { .. } => "json_decode",
323            #[cfg(feature = "extract_jsonpath")]
324            JsonPathMatch => "json_path_match",
325            LenBytes => "len_bytes",
326            Lowercase => "lowercase",
327            LenChars => "len_chars",
328            #[cfg(feature = "string_pad")]
329            PadEnd { .. } => "pad_end",
330            #[cfg(feature = "string_pad")]
331            PadStart { .. } => "pad_start",
332            #[cfg(feature = "regex")]
333            Replace { .. } => "replace",
334            #[cfg(feature = "string_normalize")]
335            Normalize { .. } => "normalize",
336            #[cfg(feature = "string_reverse")]
337            Reverse => "reverse",
338            #[cfg(feature = "string_encoding")]
339            HexEncode => "hex_encode",
340            #[cfg(feature = "binary_encoding")]
341            HexDecode(_) => "hex_decode",
342            #[cfg(feature = "string_encoding")]
343            Base64Encode => "base64_encode",
344            #[cfg(feature = "binary_encoding")]
345            Base64Decode(_) => "base64_decode",
346            Slice => "slice",
347            StartsWith => "starts_with",
348            StripChars => "strip_chars",
349            StripCharsStart => "strip_chars_start",
350            StripCharsEnd => "strip_chars_end",
351            StripPrefix => "strip_prefix",
352            StripSuffix => "strip_suffix",
353            #[cfg(feature = "dtype-struct")]
354            SplitExact { inclusive, .. } => {
355                if *inclusive {
356                    "split_exact_inclusive"
357                } else {
358                    "split_exact"
359                }
360            },
361            #[cfg(feature = "dtype-struct")]
362            SplitN(_) => "splitn",
363            #[cfg(feature = "temporal")]
364            Strptime(_, _) => "strptime",
365            Split(inclusive) => {
366                if *inclusive {
367                    "split_inclusive"
368                } else {
369                    "split"
370                }
371            },
372            #[cfg(feature = "nightly")]
373            Titlecase => "titlecase",
374            #[cfg(feature = "dtype-decimal")]
375            ToDecimal(_) => "to_decimal",
376            Uppercase => "uppercase",
377            #[cfg(feature = "string_pad")]
378            ZFill => "zfill",
379            #[cfg(feature = "find_many")]
380            ContainsAny { .. } => "contains_any",
381            #[cfg(feature = "find_many")]
382            ReplaceMany { .. } => "replace_many",
383            #[cfg(feature = "find_many")]
384            ExtractMany { .. } => "extract_many",
385            #[cfg(feature = "find_many")]
386            FindMany { .. } => "extract_many",
387            #[cfg(feature = "regex")]
388            EscapeRegex => "escape_regex",
389        };
390        write!(f, "str.{s}")
391    }
392}
393
394impl From<StringFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
395    fn from(func: StringFunction) -> Self {
396        use StringFunction::*;
397        match func {
398            #[cfg(feature = "regex")]
399            Contains { literal, strict } => map_as_slice!(strings::contains, literal, strict),
400            CountMatches(literal) => {
401                map_as_slice!(strings::count_matches, literal)
402            },
403            EndsWith => map_as_slice!(strings::ends_with),
404            StartsWith => map_as_slice!(strings::starts_with),
405            Extract(group_index) => map_as_slice!(strings::extract, group_index),
406            ExtractAll => {
407                map_as_slice!(strings::extract_all)
408            },
409            #[cfg(feature = "extract_groups")]
410            ExtractGroups { pat, dtype } => {
411                map!(strings::extract_groups, &pat, &dtype)
412            },
413            #[cfg(feature = "regex")]
414            Find { literal, strict } => map_as_slice!(strings::find, literal, strict),
415            LenBytes => map!(strings::len_bytes),
416            LenChars => map!(strings::len_chars),
417            #[cfg(feature = "string_pad")]
418            PadEnd { length, fill_char } => {
419                map!(strings::pad_end, length, fill_char)
420            },
421            #[cfg(feature = "string_pad")]
422            PadStart { length, fill_char } => {
423                map!(strings::pad_start, length, fill_char)
424            },
425            #[cfg(feature = "string_pad")]
426            ZFill => {
427                map_as_slice!(strings::zfill)
428            },
429            #[cfg(feature = "temporal")]
430            Strptime(dtype, options) => {
431                map_as_slice!(strings::strptime, dtype.clone(), &options)
432            },
433            Split(inclusive) => {
434                map_as_slice!(strings::split, inclusive)
435            },
436            #[cfg(feature = "dtype-struct")]
437            SplitExact { n, inclusive } => map_as_slice!(strings::split_exact, n, inclusive),
438            #[cfg(feature = "dtype-struct")]
439            SplitN(n) => map_as_slice!(strings::splitn, n),
440            #[cfg(feature = "concat_str")]
441            ConcatVertical {
442                delimiter,
443                ignore_nulls,
444            } => map!(strings::join, &delimiter, ignore_nulls),
445            #[cfg(feature = "concat_str")]
446            ConcatHorizontal {
447                delimiter,
448                ignore_nulls,
449            } => map_as_slice!(strings::concat_hor, &delimiter, ignore_nulls),
450            #[cfg(feature = "regex")]
451            Replace { n, literal } => map_as_slice!(strings::replace, literal, n),
452            #[cfg(feature = "string_normalize")]
453            Normalize { form } => map!(strings::normalize, form.clone()),
454            #[cfg(feature = "string_reverse")]
455            Reverse => map!(strings::reverse),
456            Uppercase => map!(uppercase),
457            Lowercase => map!(lowercase),
458            #[cfg(feature = "nightly")]
459            Titlecase => map!(strings::titlecase),
460            StripChars => map_as_slice!(strings::strip_chars),
461            StripCharsStart => map_as_slice!(strings::strip_chars_start),
462            StripCharsEnd => map_as_slice!(strings::strip_chars_end),
463            StripPrefix => map_as_slice!(strings::strip_prefix),
464            StripSuffix => map_as_slice!(strings::strip_suffix),
465            #[cfg(feature = "string_to_integer")]
466            ToInteger(strict) => map_as_slice!(strings::to_integer, strict),
467            Slice => map_as_slice!(strings::str_slice),
468            Head => map_as_slice!(strings::str_head),
469            Tail => map_as_slice!(strings::str_tail),
470            #[cfg(feature = "string_encoding")]
471            HexEncode => map!(strings::hex_encode),
472            #[cfg(feature = "binary_encoding")]
473            HexDecode(strict) => map!(strings::hex_decode, strict),
474            #[cfg(feature = "string_encoding")]
475            Base64Encode => map!(strings::base64_encode),
476            #[cfg(feature = "binary_encoding")]
477            Base64Decode(strict) => map!(strings::base64_decode, strict),
478            #[cfg(feature = "dtype-decimal")]
479            ToDecimal(infer_len) => map!(strings::to_decimal, infer_len),
480            #[cfg(feature = "extract_jsonpath")]
481            JsonDecode {
482                dtype,
483                infer_schema_len,
484            } => map!(strings::json_decode, dtype.clone(), infer_schema_len),
485            #[cfg(feature = "extract_jsonpath")]
486            JsonPathMatch => map_as_slice!(strings::json_path_match),
487            #[cfg(feature = "find_many")]
488            ContainsAny {
489                ascii_case_insensitive,
490            } => {
491                map_as_slice!(contains_any, ascii_case_insensitive)
492            },
493            #[cfg(feature = "find_many")]
494            ReplaceMany {
495                ascii_case_insensitive,
496            } => {
497                map_as_slice!(replace_many, ascii_case_insensitive)
498            },
499            #[cfg(feature = "find_many")]
500            ExtractMany {
501                ascii_case_insensitive,
502                overlapping,
503            } => {
504                map_as_slice!(extract_many, ascii_case_insensitive, overlapping)
505            },
506            #[cfg(feature = "find_many")]
507            FindMany {
508                ascii_case_insensitive,
509                overlapping,
510            } => {
511                map_as_slice!(find_many, ascii_case_insensitive, overlapping)
512            },
513            #[cfg(feature = "regex")]
514            EscapeRegex => map!(escape_regex),
515        }
516    }
517}
518
/// Kernel for `str.contains_any`.
///
/// `s[0]` is read as the string column and `s[1]` as the list column of patterns.
#[cfg(feature = "find_many")]
fn contains_any(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult<Column> {
    let (haystack, needles) = (s[0].str()?, s[1].list()?);
    let mask = polars_ops::chunked_array::strings::contains_any(
        haystack,
        needles,
        ascii_case_insensitive,
    )?;
    Ok(mask.into_column())
}
526
/// Kernel for `str.replace_many`.
///
/// `s[0]` is read as the string column, `s[1]` as the list column of patterns
/// and `s[2]` as the list column of replacements.
#[cfg(feature = "find_many")]
fn replace_many(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult<Column> {
    let (haystack, needles, replacements) = (s[0].str()?, s[1].list()?, s[2].list()?);
    let replaced = polars_ops::chunked_array::strings::replace_all(
        haystack,
        needles,
        replacements,
        ascii_case_insensitive,
    )?;
    Ok(replaced.into_column())
}
540
/// Kernel for `str.extract_many`.
///
/// `s[0]` is read as the string column and `s[1]` as the list column of patterns.
#[cfg(feature = "find_many")]
fn extract_many(
    s: &[Column],
    ascii_case_insensitive: bool,
    overlapping: bool,
) -> PolarsResult<Column> {
    let (haystack, needles) = (s[0].str()?, s[1].list()?);
    let matches = polars_ops::chunked_array::strings::extract_many(
        haystack,
        needles,
        ascii_case_insensitive,
        overlapping,
    )?;
    Ok(matches.into_column())
}
558
/// Kernel for the `FindMany` variant.
///
/// `s[0]` is read as the string column and `s[1]` as the list column of patterns.
#[cfg(feature = "find_many")]
fn find_many(
    s: &[Column],
    ascii_case_insensitive: bool,
    overlapping: bool,
) -> PolarsResult<Column> {
    let (haystack, needles) = (s[0].str()?, s[1].list()?);
    let positions = polars_ops::chunked_array::strings::find_many(
        haystack,
        needles,
        ascii_case_insensitive,
        overlapping,
    )?;
    Ok(positions.into_column())
}
571
572fn uppercase(s: &Column) -> PolarsResult<Column> {
573    let ca = s.str()?;
574    Ok(ca.to_uppercase().into_column())
575}
576
577fn lowercase(s: &Column) -> PolarsResult<Column> {
578    let ca = s.str()?;
579    Ok(ca.to_lowercase().into_column())
580}
581
/// Kernel for `str.titlecase`; errors if `s` is not a string column.
#[cfg(feature = "nightly")]
pub(super) fn titlecase(s: &Column) -> PolarsResult<Column> {
    s.str().map(|text| text.to_titlecase().into_column())
}
587
588pub(super) fn len_chars(s: &Column) -> PolarsResult<Column> {
589    let ca = s.str()?;
590    Ok(ca.str_len_chars().into_column())
591}
592
593pub(super) fn len_bytes(s: &Column) -> PolarsResult<Column> {
594    let ca = s.str()?;
595    Ok(ca.str_len_bytes().into_column())
596}
597
/// Kernel for `str.contains`.
///
/// `s[0]` is the string column and `s[1]` the pattern column; the inputs are
/// first validated with `_check_same_length`.
#[cfg(feature = "regex")]
pub(super) fn contains(s: &[Column], literal: bool, strict: bool) -> PolarsResult<Column> {
    _check_same_length(s, "contains")?;
    let (haystack, pattern) = (s[0].str()?, s[1].str()?);
    let mask = haystack.contains_chunked(pattern, literal, strict)?;
    Ok(mask.into_column())
}
606
/// Kernel for `str.find`.
///
/// `s[0]` is the string column and `s[1]` the pattern column; the inputs are
/// first validated with `_check_same_length`.
#[cfg(feature = "regex")]
pub(super) fn find(s: &[Column], literal: bool, strict: bool) -> PolarsResult<Column> {
    _check_same_length(s, "find")?;
    let (haystack, pattern) = (s[0].str()?, s[1].str()?);
    let positions = haystack.find_chunked(pattern, literal, strict)?;
    Ok(positions.into_column())
}
615
616pub(super) fn ends_with(s: &[Column]) -> PolarsResult<Column> {
617    _check_same_length(s, "ends_with")?;
618    let ca = s[0].str()?.as_binary();
619    let suffix = s[1].str()?.as_binary();
620
621    Ok(ca.ends_with_chunked(&suffix)?.into_column())
622}
623
624pub(super) fn starts_with(s: &[Column]) -> PolarsResult<Column> {
625    _check_same_length(s, "starts_with")?;
626    let ca = s[0].str()?.as_binary();
627    let prefix = s[1].str()?.as_binary();
628    Ok(ca.starts_with_chunked(&prefix)?.into_column())
629}
630
631/// Extract a regex pattern from the a string value.
632pub(super) fn extract(s: &[Column], group_index: usize) -> PolarsResult<Column> {
633    let ca = s[0].str()?;
634    let pat = s[1].str()?;
635    ca.extract(pat, group_index).map(|ca| ca.into_column())
636}
637
/// Kernel for `str.extract_groups`: extracts all capture groups of `pat`
/// from the string column `s`, passing `dtype` through to the underlying call.
#[cfg(feature = "extract_groups")]
pub(super) fn extract_groups(s: &Column, pat: &str, dtype: &DataType) -> PolarsResult<Column> {
    let groups = s.str()?.extract_groups(pat, dtype)?;
    Ok(Column::from(groups))
}
644
/// Kernel for `str.pad_start`; forwards `length` and `fill_char` unchanged.
#[cfg(feature = "string_pad")]
pub(super) fn pad_start(s: &Column, length: usize, fill_char: char) -> PolarsResult<Column> {
    s.str()
        .map(|text| text.pad_start(length, fill_char).into_column())
}
650
/// Kernel for `str.pad_end`; forwards `length` and `fill_char` unchanged.
#[cfg(feature = "string_pad")]
pub(super) fn pad_end(s: &Column, length: usize, fill_char: char) -> PolarsResult<Column> {
    s.str()
        .map(|text| text.pad_end(length, fill_char).into_column())
}
656
/// Kernel for `str.zfill`.
///
/// `s[0]` is the string column; `s[1]` holds the target lengths and is
/// strict-cast to `UInt64` before use, so a failing cast is an error.
#[cfg(feature = "string_pad")]
pub(super) fn zfill(s: &[Column]) -> PolarsResult<Column> {
    _check_same_length(s, "zfill")?;
    let text = s[0].str()?;
    let widths = s[1].strict_cast(&DataType::UInt64)?;
    Ok(text.zfill(widths.u64()?).into_column())
}
665
666pub(super) fn strip_chars(s: &[Column]) -> PolarsResult<Column> {
667    _check_same_length(s, "strip_chars")?;
668    let ca = s[0].str()?;
669    let pat_s = &s[1];
670    ca.strip_chars(pat_s).map(|ok| ok.into_column())
671}
672
673pub(super) fn strip_chars_start(s: &[Column]) -> PolarsResult<Column> {
674    _check_same_length(s, "strip_chars_start")?;
675    let ca = s[0].str()?;
676    let pat_s = &s[1];
677    ca.strip_chars_start(pat_s).map(|ok| ok.into_column())
678}
679
680pub(super) fn strip_chars_end(s: &[Column]) -> PolarsResult<Column> {
681    _check_same_length(s, "strip_chars_end")?;
682    let ca = s[0].str()?;
683    let pat_s = &s[1];
684    ca.strip_chars_end(pat_s).map(|ok| ok.into_column())
685}
686
687pub(super) fn strip_prefix(s: &[Column]) -> PolarsResult<Column> {
688    _check_same_length(s, "strip_prefix")?;
689    let ca = s[0].str()?;
690    let prefix = s[1].str()?;
691    Ok(ca.strip_prefix(prefix).into_column())
692}
693
694pub(super) fn strip_suffix(s: &[Column]) -> PolarsResult<Column> {
695    _check_same_length(s, "strip_suffix")?;
696    let ca = s[0].str()?;
697    let suffix = s[1].str()?;
698    Ok(ca.strip_suffix(suffix).into_column())
699}
700
701pub(super) fn extract_all(args: &[Column]) -> PolarsResult<Column> {
702    let s = &args[0];
703    let pat = &args[1];
704
705    let ca = s.str()?;
706    let pat = pat.str()?;
707
708    if pat.len() == 1 {
709        if let Some(pat) = pat.get(0) {
710            ca.extract_all(pat).map(|ca| ca.into_column())
711        } else {
712            Ok(Column::full_null(
713                ca.name().clone(),
714                ca.len(),
715                &DataType::List(Box::new(DataType::String)),
716            ))
717        }
718    } else {
719        ca.extract_all_many(pat).map(|ca| ca.into_column())
720    }
721}
722
723pub(super) fn count_matches(args: &[Column], literal: bool) -> PolarsResult<Column> {
724    let s = &args[0];
725    let pat = &args[1];
726
727    let ca = s.str()?;
728    let pat = pat.str()?;
729    if pat.len() == 1 {
730        if let Some(pat) = pat.get(0) {
731            ca.count_matches(pat, literal).map(|ca| ca.into_column())
732        } else {
733            Ok(Column::full_null(
734                ca.name().clone(),
735                ca.len(),
736                &DataType::UInt32,
737            ))
738        }
739    } else {
740        ca.count_matches_many(pat, literal)
741            .map(|ca| ca.into_column())
742    }
743}
744
/// Kernel for `str.strptime`: dispatches on the requested output `dtype`.
///
/// `Date` and `Time` read only `s[0]`; `Datetime` receives the whole slice.
/// Any dtype without an enabled arm is a `ComputeError`.
#[cfg(feature = "temporal")]
pub(super) fn strptime(
    s: &[Column],
    dtype: DataType,
    options: &StrptimeOptions,
) -> PolarsResult<Column> {
    match dtype {
        #[cfg(feature = "dtype-date")]
        DataType::Date => to_date(&s[0], options),
        #[cfg(feature = "dtype-time")]
        DataType::Time => to_time(&s[0], options),
        #[cfg(feature = "dtype-datetime")]
        DataType::Datetime(unit, zone) => to_datetime(s, &unit, zone.as_ref(), options),
        unsupported => {
            polars_bail!(ComputeError: "not implemented for dtype {}", unsupported)
        },
    }
}
763
/// Kernel for `str.split_exact` / `str.split_exact_inclusive`.
///
/// `s[0]` is the string column and `s[1]` the separator column; `inclusive`
/// selects which of the two underlying calls is used.
#[cfg(feature = "dtype-struct")]
pub(super) fn split_exact(s: &[Column], n: usize, inclusive: bool) -> PolarsResult<Column> {
    let (text, separator) = (s[0].str()?, s[1].str()?);

    if inclusive {
        let parts = text.split_exact_inclusive(separator, n)?;
        return Ok(parts.into_column());
    }
    let parts = text.split_exact(separator, n)?;
    Ok(parts.into_column())
}
775
/// Kernel for `str.splitn`; `s[0]` is the string column and `s[1]` the
/// separator column.
#[cfg(feature = "dtype-struct")]
pub(super) fn splitn(s: &[Column], n: usize) -> PolarsResult<Column> {
    let (text, separator) = (s[0].str()?, s[1].str()?);
    let parts = text.splitn(separator, n)?;
    Ok(parts.into_column())
}
783
784pub(super) fn split(s: &[Column], inclusive: bool) -> PolarsResult<Column> {
785    let ca = s[0].str()?;
786    let by = s[1].str()?;
787
788    if inclusive {
789        Ok(ca.split_inclusive(by)?.into_column())
790    } else {
791        Ok(ca.split(by)?.into_column())
792    }
793}
794
/// Parses a string column into a date column.
///
/// `options.exact` selects `as_date` (which also receives `options.cache`) or
/// `as_date_not_exact`. When `options.strict` is set and the output has a
/// different null count than the input, `handle_casting_failures` is called
/// with both columns and its error, if any, is returned.
#[cfg(feature = "dtype-date")]
fn to_date(s: &Column, options: &StrptimeOptions) -> PolarsResult<Column> {
    let ca = s.str()?;
    let format = options.format.as_deref();

    let parsed = if options.exact {
        ca.as_date(format, options.cache)?.into_column()
    } else {
        ca.as_date_not_exact(format)?.into_column()
    };

    let null_count_changed = ca.null_count() != parsed.null_count();
    if options.strict && null_count_changed {
        handle_casting_failures(s.as_materialized_series(), parsed.as_materialized_series())?;
    }
    Ok(parsed)
}
813
/// Parses a string column into a datetime column.
///
/// `s[0]` holds the datetime strings and `s[1]` the `ambiguous` strings; the
/// two must have equal length unless one of them has length 1. The format is
/// tested against `TZ_AWARE_RE` to set `tz_aware` (always `false` when that
/// regex is compiled out or no format is given). With the `timezones` feature
/// a supplied `time_zone` is validated before parsing.
///
/// `options.exact` selects `as_datetime` (which also receives `options.cache`)
/// or `as_datetime_not_exact`. When `options.strict` is set and the output has
/// a different null count than the input, `handle_casting_failures` is called
/// and its error, if any, is returned.
#[cfg(feature = "dtype-datetime")]
fn to_datetime(
    s: &[Column],
    time_unit: &TimeUnit,
    time_zone: Option<&TimeZone>,
    options: &StrptimeOptions,
) -> PolarsResult<Column> {
    let strings = s[0].str()?;
    let ambiguous = s[1].str()?;

    polars_ensure!(
        strings.len() == ambiguous.len()
            || strings.len() == 1
            || ambiguous.len() == 1,
        length_mismatch = "str.strptime",
        strings.len(),
        ambiguous.len()
    );

    let tz_aware = match &options.format {
        #[cfg(all(feature = "regex", feature = "timezones"))]
        Some(format) => TZ_AWARE_RE.is_match(format),
        _ => false,
    };
    #[cfg(feature = "timezones")]
    if let Some(time_zone) = time_zone {
        validate_time_zone(time_zone)?;
    }

    let format = options.format.as_deref();
    let parsed = if options.exact {
        strings
            .as_datetime(
                format,
                *time_unit,
                options.cache,
                tz_aware,
                time_zone,
                ambiguous,
            )?
            .into_column()
    } else {
        strings
            .as_datetime_not_exact(format, *time_unit, tz_aware, time_zone, ambiguous)?
            .into_column()
    };

    let null_count_changed = strings.null_count() != parsed.null_count();
    if options.strict && null_count_changed {
        handle_casting_failures(
            s[0].as_materialized_series(),
            parsed.as_materialized_series(),
        )?;
    }
    Ok(parsed)
}
870
/// Parses the string column `s` into a time column according to `options`.
///
/// Only exact parsing is supported: a `ComputeError` is returned when `options.exact` is
/// false. When `options.strict` is set and the null count of the output differs from that
/// of the input, `handle_casting_failures` is invoked and its error, if any, is returned.
#[cfg(feature = "dtype-time")]
fn to_time(s: &Column, options: &StrptimeOptions) -> PolarsResult<Column> {
    if !options.exact {
        polars_bail!(ComputeError: "non-exact not implemented for Time data type");
    }

    let ca = s.str()?;
    let parsed = ca.as_time(options.format.as_deref(), options.cache)?;
    let out = parsed.into_column();

    // A differing null count means some values did not parse.
    if options.strict && ca.null_count() != out.null_count() {
        handle_casting_failures(s.as_materialized_series(), out.as_materialized_series())?;
    }
    Ok(out)
}
887
/// Casts `s` to `DataType::String` and passes it to `str_join` together with `delimiter`
/// and `ignore_nulls`, returning the result as a [`Column`].
///
/// Errors from the cast or from `str()` are propagated.
#[cfg(feature = "concat_str")]
pub(super) fn join(s: &Column, delimiter: &str, ignore_nulls: bool) -> PolarsResult<Column> {
    let as_strings = s.cast(&DataType::String)?;
    let ca = as_strings.str()?;
    Ok(polars_ops::chunked_array::str_join(ca, delimiter, ignore_nulls).into_column())
}
894
/// Casts every column in `series` to `DataType::String` and passes them to
/// `hor_str_concat` together with `delimiter` and `ignore_nulls`.
///
/// # Errors
///
/// Returns an error if any cast fails, if a casted column cannot be viewed as a string
/// array, or if `hor_str_concat` fails.
#[cfg(feature = "concat_str")]
pub(super) fn concat_hor(
    series: &[Column],
    delimiter: &str,
    ignore_nulls: bool,
) -> PolarsResult<Column> {
    let str_series = series
        .iter()
        .map(|s| s.cast(&DataType::String))
        .collect::<PolarsResult<Vec<_>>>()?;
    // Propagate a failed string view as an error instead of panicking via `unwrap`;
    // this function already returns `PolarsResult`.
    let cas = str_series
        .iter()
        .map(|s| s.str())
        .collect::<PolarsResult<Vec<_>>>()?;
    Ok(polars_ops::chunked_array::hor_str_concat(&cas, delimiter, ignore_nulls)?.into_column())
}
908
909impl From<StringFunction> for FunctionExpr {
910    fn from(str: StringFunction) -> Self {
911        FunctionExpr::StringExpr(str)
912    }
913}
914
/// Returns the first value of `pat`, or a `ComputeError` when that value is null/absent.
#[cfg(feature = "regex")]
fn get_pat(pat: &StringChunked) -> PolarsResult<&str> {
    match pat.get(0) {
        Some(first) => Ok(first),
        None => Err(polars_err!(ComputeError: "pattern cannot be 'null' in 'replace' expression")),
    }
}
921
922// used only if feature="regex"
923#[allow(dead_code)]
924fn iter_and_replace<'a, F>(ca: &'a StringChunked, val: &'a StringChunked, f: F) -> StringChunked
925where
926    F: Fn(&'a str, &'a str) -> Cow<'a, str>,
927{
928    let mut out: StringChunked = ca
929        .into_iter()
930        .zip(val)
931        .map(|(opt_src, opt_val)| match (opt_src, opt_val) {
932            (Some(src), Some(val)) => Some(f(src, val)),
933            _ => None,
934        })
935        .collect_trusted();
936
937    out.rename(ca.name().clone());
938    out
939}
940
941#[cfg(feature = "regex")]
/// Returns `true` when `pat` contains no ASCII punctuation character.
///
/// The replace helpers in this file use this to decide whether a pattern can be handled
/// as a literal.
fn is_literal_pat(pat: &str) -> bool {
    !pat.chars().any(|ch| ch.is_ascii_punctuation())
}
945
/// Replaces matches of a single pattern in `ca`, bounded by `n`.
///
/// Supported shapes, keyed on `(pat.len(), val.len())`:
/// - `(1, 1)`: one pattern, one replacement value. A pattern that is literal (either
///   requested via `literal` or detected by `is_literal_pat`) goes to `replace_literal`
///   with `n`; otherwise the regex path `ca.replace` is used and `n > 1` is rejected.
/// - `(1, len)`: one pattern, one replacement value per row. `n > 1` is rejected and
///   `len` must equal `ca.len()`.
/// - anything else yields a `ComputeError`.
///
/// A null pattern, or a null value in the `(1, 1)` case, yields a `ComputeError`.
///
/// NOTE(review): `n == 0` passes the `n > 1` guards, so the regex and per-row paths still
/// perform a replacement in that case — confirm this is intended.
#[cfg(feature = "regex")]
fn replace_n<'a>(
    ca: &'a StringChunked,
    pat: &'a StringChunked,
    val: &'a StringChunked,
    literal: bool,
    n: usize,
) -> PolarsResult<StringChunked> {
    match (pat.len(), val.len()) {
        (1, 1) => {
            let pat = get_pat(pat)?;
            let val = val.get(0).ok_or_else(
                || polars_err!(ComputeError: "value cannot be 'null' in 'replace' expression"),
            )?;

            if literal || is_literal_pat(pat) {
                ca.replace_literal(pat, val, n)
            } else {
                polars_ensure!(
                    n <= 1,
                    ComputeError: "regex replacement with 'n > 1' not yet supported"
                );
                ca.replace(pat, val)
            }
        },
        (1, len_val) => {
            polars_ensure!(
                n <= 1,
                ComputeError: "multivalue replacement with 'n > 1' not yet supported"
            );
            let raw_pat = get_pat(pat)?;
            polars_ensure!(
                len_val == ca.len(),
                ComputeError:
                "replacement value length ({}) does not match string column length ({})",
                len_val, ca.len(),
            );

            // `plain` is decided on the unescaped pattern.
            let plain = is_literal_pat(raw_pat);
            let pattern = if literal || plain {
                escape(raw_pat)
            } else {
                raw_pat.to_string()
            };
            let re = polars_utils::regex_cache::compile_regex(&pattern)?;

            let f = |s: &'a str, val: &'a str| {
                if plain && s.len() <= 32 {
                    // Short inputs with a punctuation-free pattern use `str::replacen`.
                    Cow::Owned(s.replacen(pattern.as_str(), val, 1))
                } else if literal {
                    // According to the docs for replace
                    // when literal = True then capture groups are ignored.
                    re.replace(s, NoExpand(val))
                } else {
                    re.replace(s, val)
                }
            };
            Ok(iter_and_replace(ca, val, f))
        },
        _ => polars_bail!(
            ComputeError: "dynamic pattern length in 'str.replace' expressions is not supported yet"
        ),
    }
}
1012
/// Replaces every match of a single pattern in `ca`.
///
/// Supported shapes, keyed on `(pat.len(), val.len())`:
/// - `(1, 1)`: one pattern, one replacement value. A pattern that is literal (either
///   requested via `literal` or detected by `is_literal_pat`) goes to
///   `replace_literal_all`; otherwise `ca.replace_all` is used.
/// - `(1, len)`: one pattern, one replacement value per row; `len` must equal `ca.len()`.
/// - anything else yields a `ComputeError`.
///
/// A null pattern, or a null value in the `(1, 1)` case, yields a `ComputeError`.
#[cfg(feature = "regex")]
fn replace_all<'a>(
    ca: &'a StringChunked,
    pat: &'a StringChunked,
    val: &'a StringChunked,
    literal: bool,
) -> PolarsResult<StringChunked> {
    match (pat.len(), val.len()) {
        (1, 1) => {
            let pat = get_pat(pat)?;
            let val = val.get(0).ok_or_else(
                || polars_err!(ComputeError: "value cannot be 'null' in 'replace' expression"),
            )?;

            if literal || is_literal_pat(pat) {
                ca.replace_literal_all(pat, val)
            } else {
                ca.replace_all(pat, val)
            }
        },
        (1, len_val) => {
            let raw_pat = get_pat(pat)?;
            polars_ensure!(
                len_val == ca.len(),
                ComputeError:
                "replacement value length ({}) does not match string column length ({})",
                len_val, ca.len(),
            );

            // Literal patterns are escaped so they can still go through the regex engine.
            let pattern = if literal || is_literal_pat(raw_pat) {
                escape(raw_pat)
            } else {
                raw_pat.to_string()
            };
            let re = polars_utils::regex_cache::compile_regex(&pattern)?;

            // According to the docs for replace_all
            // when literal = True then capture groups are ignored.
            let f = |s: &'a str, val: &'a str| match literal {
                true => re.replace_all(s, NoExpand(val)),
                false => re.replace_all(s, val),
            };

            Ok(iter_and_replace(ca, val, f))
        },
        _ => polars_bail!(
            ComputeError: "dynamic pattern length in 'str.replace' expressions is not supported yet"
        ),
    }
}
1067
/// Entry point for string replacement over `s = [column, pattern, value]`.
///
/// A negative `n` replaces all matches via `replace_all`; otherwise `replace_n` is called
/// with `n` converted to `usize`. All three inputs must be string columns.
#[cfg(feature = "regex")]
pub(super) fn replace(s: &[Column], literal: bool, n: i64) -> PolarsResult<Column> {
    let (column, pat, val) = (&s[0], &s[1], &s[2]);

    let column = column.str()?;
    let pat = pat.str()?;
    let val = val.str()?;

    let replaced = if n < 0 {
        replace_all(column, pat, val, literal)
    } else {
        // `n` is non-negative in this branch.
        replace_n(column, pat, val, literal, n as usize)
    };
    Ok(replaced?.into_column())
}
1086
/// Applies `str_normalize` with the given `form` to the string column `s`.
#[cfg(feature = "string_normalize")]
pub(super) fn normalize(s: &Column, form: UnicodeForm) -> PolarsResult<Column> {
    s.str().map(|ca| ca.str_normalize(form).into_column())
}
1092
/// Applies `str_reverse` to the string column `s`.
#[cfg(feature = "string_reverse")]
pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
    s.str().map(|ca| ca.str_reverse().into_column())
}
1098
/// Calls `to_integer` on the string column `s[0]`, using `s[1]` as the base.
///
/// `s[1]` is strictly cast to `UInt32` first; a failing cast is returned as an error.
/// `strict` is forwarded unchanged to `to_integer`.
#[cfg(feature = "string_to_integer")]
pub(super) fn to_integer(s: &[Column], strict: bool) -> PolarsResult<Column> {
    let ca = s[0].str()?;
    let base = s[1].strict_cast(&DataType::UInt32)?;
    let parsed = ca.to_integer(base.u32()?, strict)?;
    Ok(parsed.into_column())
}
1106
1107fn _ensure_lengths(s: &[Column]) -> bool {
1108    // Calculate the post-broadcast length and ensure everything is consistent.
1109    let len = s
1110        .iter()
1111        .map(|series| series.len())
1112        .filter(|l| *l != 1)
1113        .max()
1114        .unwrap_or(1);
1115    s.iter()
1116        .all(|series| series.len() == 1 || series.len() == len)
1117}
1118
1119fn _check_same_length(s: &[Column], fn_name: &str) -> Result<(), PolarsError> {
1120    polars_ensure!(
1121        _ensure_lengths(s),
1122        ShapeMismatch: "all series in `str.{}()` should have equal or unit length",
1123        fn_name
1124    );
1125    Ok(())
1126}
1127
1128pub(super) fn str_slice(s: &[Column]) -> PolarsResult<Column> {
1129    _check_same_length(s, "slice")?;
1130    let ca = s[0].str()?;
1131    let offset = &s[1];
1132    let length = &s[2];
1133    Ok(ca.str_slice(offset, length)?.into_column())
1134}
1135
1136pub(super) fn str_head(s: &[Column]) -> PolarsResult<Column> {
1137    _check_same_length(s, "head")?;
1138    let ca = s[0].str()?;
1139    let n = &s[1];
1140    Ok(ca.str_head(n)?.into_column())
1141}
1142
1143pub(super) fn str_tail(s: &[Column]) -> PolarsResult<Column> {
1144    _check_same_length(s, "tail")?;
1145    let ca = s[0].str()?;
1146    let n = &s[1];
1147    Ok(ca.str_tail(n)?.into_column())
1148}
1149
/// Applies `hex_encode` to the string column `s`.
#[cfg(feature = "string_encoding")]
pub(super) fn hex_encode(s: &Column) -> PolarsResult<Column> {
    let ca = s.str()?;
    Ok(ca.hex_encode().into_column())
}
1154
/// Applies `hex_decode` to the string column `s`, forwarding `strict`.
///
/// Errors from `str()` or from `hex_decode` are propagated.
#[cfg(feature = "binary_encoding")]
pub(super) fn hex_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let decoded = s.str()?.hex_decode(strict)?;
    Ok(decoded.into_column())
}
1159
/// Applies `base64_encode` to the string column `s`.
#[cfg(feature = "string_encoding")]
pub(super) fn base64_encode(s: &Column) -> PolarsResult<Column> {
    let ca = s.str()?;
    Ok(ca.base64_encode().into_column())
}
1164
/// Applies `base64_decode` to the string column `s`, forwarding `strict`.
///
/// Errors from `str()` or from `base64_decode` are propagated.
#[cfg(feature = "binary_encoding")]
pub(super) fn base64_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let decoded = s.str()?.base64_decode(strict)?;
    Ok(decoded.into_column())
}
1169
/// Calls `to_decimal` on the string column `s`, forwarding `infer_len`, and wraps the
/// result in a [`Column`].
///
/// Errors from `str()` or from `to_decimal` are propagated.
#[cfg(feature = "dtype-decimal")]
pub(super) fn to_decimal(s: &Column, infer_len: usize) -> PolarsResult<Column> {
    let decimals = s.str()?.to_decimal(infer_len)?;
    Ok(Column::from(decimals))
}
1175
/// Calls `json_decode` on the string column `s`, forwarding the optional `dtype` and
/// `infer_schema_len`, and wraps the result in a [`Column`].
///
/// Errors from `str()` or from `json_decode` are propagated.
#[cfg(feature = "extract_jsonpath")]
pub(super) fn json_decode(
    s: &Column,
    dtype: Option<DataType>,
    infer_schema_len: Option<usize>,
) -> PolarsResult<Column> {
    let decoded = s.str()?.json_decode(dtype, infer_schema_len)?;
    Ok(Column::from(decoded))
}
1185
/// Calls `json_path_match` on the string column `s[0]` with the string column `s[1]` as
/// the pattern argument.
///
/// All inputs must have equal or unit length, otherwise a `ShapeMismatch` error is
/// returned.
#[cfg(feature = "extract_jsonpath")]
pub(super) fn json_path_match(s: &[Column]) -> PolarsResult<Column> {
    _check_same_length(s, "json_path_match")?;
    let (ca, pat) = (s[0].str()?, s[1].str()?);
    ca.json_path_match(pat).map(|out| out.into_column())
}
1193
1194#[cfg(feature = "regex")]
1195pub(super) fn escape_regex(s: &Column) -> PolarsResult<Column> {
1196    let ca = s.str()?;
1197    Ok(ca.str_escape_regex().into_column())
1198}