polars_plan/dsl/function_expr/strings.rs

1use std::borrow::Cow;
2
3use arrow::legacy::utils::CustomIterTools;
4use polars_core::utils::handle_casting_failures;
5#[cfg(feature = "dtype-struct")]
6use polars_utils::format_pl_smallstr;
7#[cfg(feature = "regex")]
8use regex::{NoExpand, escape};
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use super::*;
13use crate::{map, map_as_slice};
14
// Matches format strings that contain one of the `%z`, `%:z`, `%::z`, `%:::z` or `%#z`
// specifiers anywhere, or that consist solely of `%+`. `to_datetime` uses a match to set
// the `tz_aware` flag it passes to the datetime parsers.
// NOTE(review): these appear to be the chrono offset / RFC 3339 specifiers — confirm
// against the parser's format dialect if new specifiers are added.
#[cfg(all(feature = "regex", feature = "timezones"))]
polars_utils::regex_cache::cached_regex! {
    static TZ_AWARE_RE = r"(%z)|(%:z)|(%::z)|(%:::z)|(%#z)|(^%\+$)";
}
19
/// String operations available in the expression DSL.
///
/// Variants hold only scalar parameters; additional column inputs (patterns, prefixes,
/// lengths, ...) reach the kernels as further input columns. `get_field` computes the
/// output field, `function_options` the execution flags, and the
/// `From<StringFunction> for SpecialEq<Arc<dyn ColumnsUdf>>` impl binds the kernel.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum StringFunction {
    /// Horizontal concatenation of several inputs (kernel `strings::concat_hor`);
    /// supports input wildcard expansion. Yields String.
    #[cfg(feature = "concat_str")]
    ConcatHorizontal {
        delimiter: PlSmallStr,
        ignore_nulls: bool,
    },
    /// Aggregating join of a single input (kernel `strings::join`). Yields String.
    #[cfg(feature = "concat_str")]
    ConcatVertical {
        delimiter: PlSmallStr,
        ignore_nulls: bool,
    },
    /// Pattern containment test against a second input; yields Boolean.
    /// `literal` and `strict` are forwarded unchanged to `contains_chunked`.
    #[cfg(feature = "regex")]
    Contains {
        literal: bool,
        strict: bool,
    },
    /// Count pattern matches per string; yields UInt32. The flag is `literal`.
    CountMatches(bool),
    /// Suffix test against a second input; yields Boolean.
    EndsWith,
    /// Extract the capture group with the given index; keeps the input dtype.
    Extract(usize),
    /// Extract all pattern matches; yields `List(String)`.
    ExtractAll,
    /// Extract all capture groups of `pat`; the output dtype is `dtype`.
    #[cfg(feature = "extract_groups")]
    ExtractGroups {
        dtype: DataType,
        pat: PlSmallStr,
    },
    /// Position of a pattern match; yields UInt32.
    /// `literal` and `strict` are forwarded unchanged to `find_chunked`.
    #[cfg(feature = "regex")]
    Find {
        literal: bool,
        strict: bool,
    },
    /// Conversion via `strings::to_integer`; yields Int64. The flag is `strict`.
    #[cfg(feature = "string_to_integer")]
    ToInteger(bool),
    /// String length in bytes; yields UInt32.
    LenBytes,
    /// String length in characters; yields UInt32.
    LenChars,
    Lowercase,
    /// JSON decoding. With `dtype: None` the output dtype is inferred (once, so that it is
    /// consistent across chunks); `infer_schema_len` is forwarded to the kernel.
    #[cfg(feature = "extract_jsonpath")]
    JsonDecode {
        dtype: Option<DataType>,
        infer_schema_len: Option<usize>,
    },
    /// JSON path matching (kernel `strings::json_path_match`); yields String.
    #[cfg(feature = "extract_jsonpath")]
    JsonPathMatch,
    /// Replace pattern matches (kernel `strings::replace`); keeps the input dtype.
    #[cfg(feature = "regex")]
    Replace {
        /// How many matches to replace; a negative value means replace all.
        n: i64,
        literal: bool,
    },
    /// Unicode normalization to `form`.
    #[cfg(feature = "string_normalize")]
    Normalize {
        form: UnicodeForm,
    },
    #[cfg(feature = "string_reverse")]
    Reverse,
    /// Padding via `pad_start(length, fill_char)`.
    #[cfg(feature = "string_pad")]
    PadStart {
        length: usize,
        fill_char: char,
    },
    /// Padding via `pad_end(length, fill_char)`.
    #[cfg(feature = "string_pad")]
    PadEnd {
        length: usize,
        fill_char: char,
    },
    /// Substring operations; their parameters come from additional input columns.
    Slice,
    Head,
    Tail,
    #[cfg(feature = "string_encoding")]
    HexEncode,
    /// Hex decoding; yields Binary. The flag is `strict`.
    #[cfg(feature = "binary_encoding")]
    HexDecode(bool),
    #[cfg(feature = "string_encoding")]
    Base64Encode,
    /// Base64 decoding; yields Binary. The flag is `strict`.
    #[cfg(feature = "binary_encoding")]
    Base64Decode(bool),
    /// Prefix test against a second input; yields Boolean.
    StartsWith,
    /// Strip characters given by a second input (both ends / start / end).
    StripChars,
    StripCharsStart,
    StripCharsEnd,
    /// Remove a prefix / suffix given by a second input.
    StripPrefix,
    StripSuffix,
    /// Split into a struct of `n + 1` String fields named `field_0`, `field_1`, ...;
    /// `inclusive` selects `split_exact_inclusive`.
    #[cfg(feature = "dtype-struct")]
    SplitExact {
        n: usize,
        inclusive: bool,
    },
    /// Split into a struct of `n` String fields named `field_0`, `field_1`, ...
    #[cfg(feature = "dtype-struct")]
    SplitN(usize),
    /// Parse into the given temporal dtype (Date, Datetime or Time).
    #[cfg(feature = "temporal")]
    Strptime(DataType, StrptimeOptions),
    /// Split into `List(String)`; the flag selects `split_inclusive`.
    Split(bool),
    /// Parse into Decimal; the value is forwarded as the kernel's `infer_len`.
    #[cfg(feature = "dtype-decimal")]
    ToDecimal(usize),
    #[cfg(feature = "nightly")]
    Titlecase,
    Uppercase,
    /// Zero-fill to lengths taken from a second input (cast to UInt64).
    #[cfg(feature = "string_pad")]
    ZFill,
    /// Test whether any of a list of patterns occurs; yields Boolean.
    #[cfg(feature = "find_many")]
    ContainsAny {
        ascii_case_insensitive: bool,
    },
    /// Replace a list of patterns with a list of replacements.
    #[cfg(feature = "find_many")]
    ReplaceMany {
        ascii_case_insensitive: bool,
    },
    /// Extract matches of a list of patterns; yields `List(String)`.
    #[cfg(feature = "find_many")]
    ExtractMany {
        ascii_case_insensitive: bool,
        overlapping: bool,
    },
    /// Positions of matches of a list of patterns; yields `List(UInt32)`.
    #[cfg(feature = "find_many")]
    FindMany {
        ascii_case_insensitive: bool,
        overlapping: bool,
    },
    /// Regex escaping (kernel `escape_regex`); keeps the input dtype.
    #[cfg(feature = "regex")]
    EscapeRegex,
}
142
143impl StringFunction {
144    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
145        use StringFunction::*;
146        match self {
147            #[cfg(feature = "concat_str")]
148            ConcatVertical { .. } | ConcatHorizontal { .. } => mapper.with_dtype(DataType::String),
149            #[cfg(feature = "regex")]
150            Contains { .. } => mapper.with_dtype(DataType::Boolean),
151            CountMatches(_) => mapper.with_dtype(DataType::UInt32),
152            EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
153            Extract(_) => mapper.with_same_dtype(),
154            ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
155            #[cfg(feature = "extract_groups")]
156            ExtractGroups { dtype, .. } => mapper.with_dtype(dtype.clone()),
157            #[cfg(feature = "string_to_integer")]
158            ToInteger { .. } => mapper.with_dtype(DataType::Int64),
159            #[cfg(feature = "regex")]
160            Find { .. } => mapper.with_dtype(DataType::UInt32),
161            #[cfg(feature = "extract_jsonpath")]
162            JsonDecode { dtype, .. } => mapper.with_opt_dtype(dtype.clone()),
163            #[cfg(feature = "extract_jsonpath")]
164            JsonPathMatch => mapper.with_dtype(DataType::String),
165            LenBytes => mapper.with_dtype(DataType::UInt32),
166            LenChars => mapper.with_dtype(DataType::UInt32),
167            #[cfg(feature = "regex")]
168            Replace { .. } => mapper.with_same_dtype(),
169            #[cfg(feature = "string_normalize")]
170            Normalize { .. } => mapper.with_same_dtype(),
171            #[cfg(feature = "string_reverse")]
172            Reverse => mapper.with_same_dtype(),
173            #[cfg(feature = "temporal")]
174            Strptime(dtype, _) => mapper.with_dtype(dtype.clone()),
175            Split(_) => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
176            #[cfg(feature = "nightly")]
177            Titlecase => mapper.with_same_dtype(),
178            #[cfg(feature = "dtype-decimal")]
179            ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)),
180            #[cfg(feature = "string_encoding")]
181            HexEncode => mapper.with_same_dtype(),
182            #[cfg(feature = "binary_encoding")]
183            HexDecode(_) => mapper.with_dtype(DataType::Binary),
184            #[cfg(feature = "string_encoding")]
185            Base64Encode => mapper.with_same_dtype(),
186            #[cfg(feature = "binary_encoding")]
187            Base64Decode(_) => mapper.with_dtype(DataType::Binary),
188            Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
189            | StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(),
190            #[cfg(feature = "string_pad")]
191            PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(),
192            #[cfg(feature = "dtype-struct")]
193            SplitExact { n, .. } => mapper.with_dtype(DataType::Struct(
194                (0..n + 1)
195                    .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String))
196                    .collect(),
197            )),
198            #[cfg(feature = "dtype-struct")]
199            SplitN(n) => mapper.with_dtype(DataType::Struct(
200                (0..*n)
201                    .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String))
202                    .collect(),
203            )),
204            #[cfg(feature = "find_many")]
205            ContainsAny { .. } => mapper.with_dtype(DataType::Boolean),
206            #[cfg(feature = "find_many")]
207            ReplaceMany { .. } => mapper.with_same_dtype(),
208            #[cfg(feature = "find_many")]
209            ExtractMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
210            #[cfg(feature = "find_many")]
211            FindMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::UInt32))),
212            #[cfg(feature = "regex")]
213            EscapeRegex => mapper.with_same_dtype(),
214        }
215    }
216
217    pub fn function_options(&self) -> FunctionOptions {
218        use StringFunction as S;
219        match self {
220            #[cfg(feature = "concat_str")]
221            S::ConcatHorizontal { .. } => FunctionOptions::elementwise()
222                .with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION),
223            #[cfg(feature = "concat_str")]
224            S::ConcatVertical { .. } => FunctionOptions::aggregation(),
225            #[cfg(feature = "regex")]
226            S::Contains { .. } => {
227                FunctionOptions::elementwise().with_supertyping(Default::default())
228            },
229            S::CountMatches(_) => FunctionOptions::elementwise(),
230            S::EndsWith | S::StartsWith | S::Extract(_) => {
231                FunctionOptions::elementwise().with_supertyping(Default::default())
232            },
233            S::ExtractAll => FunctionOptions::elementwise(),
234            #[cfg(feature = "extract_groups")]
235            S::ExtractGroups { .. } => FunctionOptions::elementwise(),
236            #[cfg(feature = "string_to_integer")]
237            S::ToInteger { .. } => FunctionOptions::elementwise(),
238            #[cfg(feature = "regex")]
239            S::Find { .. } => FunctionOptions::elementwise().with_supertyping(Default::default()),
240            #[cfg(feature = "extract_jsonpath")]
241            S::JsonDecode { dtype: Some(_), .. } => FunctionOptions::elementwise(),
242            // because dtype should be inferred only once and be consistent over chunks/morsels.
243            #[cfg(feature = "extract_jsonpath")]
244            S::JsonDecode { dtype: None, .. } => FunctionOptions::elementwise_with_infer(),
245            #[cfg(feature = "extract_jsonpath")]
246            S::JsonPathMatch => FunctionOptions::elementwise(),
247            S::LenBytes | S::LenChars => FunctionOptions::elementwise(),
248            #[cfg(feature = "regex")]
249            S::Replace { .. } => {
250                FunctionOptions::elementwise().with_supertyping(Default::default())
251            },
252            #[cfg(feature = "string_normalize")]
253            S::Normalize { .. } => FunctionOptions::elementwise(),
254            #[cfg(feature = "string_reverse")]
255            S::Reverse => FunctionOptions::elementwise(),
256            #[cfg(feature = "temporal")]
257            S::Strptime(_, options) if options.format.is_some() => FunctionOptions::elementwise(),
258            S::Strptime(_, _) => FunctionOptions::elementwise_with_infer(),
259            S::Split(_) => FunctionOptions::elementwise(),
260            #[cfg(feature = "nightly")]
261            S::Titlecase => FunctionOptions::elementwise(),
262            #[cfg(feature = "dtype-decimal")]
263            S::ToDecimal(_) => FunctionOptions::elementwise_with_infer(),
264            #[cfg(feature = "string_encoding")]
265            S::HexEncode | S::Base64Encode => FunctionOptions::elementwise(),
266            #[cfg(feature = "binary_encoding")]
267            S::HexDecode(_) | S::Base64Decode(_) => FunctionOptions::elementwise(),
268            S::Uppercase | S::Lowercase => FunctionOptions::elementwise(),
269            S::StripChars
270            | S::StripCharsStart
271            | S::StripCharsEnd
272            | S::StripPrefix
273            | S::StripSuffix
274            | S::Head
275            | S::Tail => FunctionOptions::elementwise(),
276            S::Slice => FunctionOptions::elementwise(),
277            #[cfg(feature = "string_pad")]
278            S::PadStart { .. } | S::PadEnd { .. } | S::ZFill => FunctionOptions::elementwise(),
279            #[cfg(feature = "dtype-struct")]
280            S::SplitExact { .. } => FunctionOptions::elementwise(),
281            #[cfg(feature = "dtype-struct")]
282            S::SplitN(_) => FunctionOptions::elementwise(),
283            #[cfg(feature = "find_many")]
284            S::ContainsAny { .. } => FunctionOptions::elementwise(),
285            #[cfg(feature = "find_many")]
286            S::ReplaceMany { .. } => FunctionOptions::elementwise(),
287            #[cfg(feature = "find_many")]
288            S::ExtractMany { .. } => FunctionOptions::elementwise(),
289            #[cfg(feature = "find_many")]
290            S::FindMany { .. } => FunctionOptions::elementwise(),
291            #[cfg(feature = "regex")]
292            S::EscapeRegex => FunctionOptions::elementwise(),
293        }
294    }
295}
296
297impl Display for StringFunction {
298    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
299        use StringFunction::*;
300        let s = match self {
301            #[cfg(feature = "regex")]
302            Contains { .. } => "contains",
303            CountMatches(_) => "count_matches",
304            EndsWith => "ends_with",
305            Extract(_) => "extract",
306            #[cfg(feature = "concat_str")]
307            ConcatHorizontal { .. } => "concat_horizontal",
308            #[cfg(feature = "concat_str")]
309            ConcatVertical { .. } => "concat_vertical",
310            ExtractAll => "extract_all",
311            #[cfg(feature = "extract_groups")]
312            ExtractGroups { .. } => "extract_groups",
313            #[cfg(feature = "string_to_integer")]
314            ToInteger { .. } => "to_integer",
315            #[cfg(feature = "regex")]
316            Find { .. } => "find",
317            Head => "head",
318            Tail => "tail",
319            #[cfg(feature = "extract_jsonpath")]
320            JsonDecode { .. } => "json_decode",
321            #[cfg(feature = "extract_jsonpath")]
322            JsonPathMatch => "json_path_match",
323            LenBytes => "len_bytes",
324            Lowercase => "lowercase",
325            LenChars => "len_chars",
326            #[cfg(feature = "string_pad")]
327            PadEnd { .. } => "pad_end",
328            #[cfg(feature = "string_pad")]
329            PadStart { .. } => "pad_start",
330            #[cfg(feature = "regex")]
331            Replace { .. } => "replace",
332            #[cfg(feature = "string_normalize")]
333            Normalize { .. } => "normalize",
334            #[cfg(feature = "string_reverse")]
335            Reverse => "reverse",
336            #[cfg(feature = "string_encoding")]
337            HexEncode => "hex_encode",
338            #[cfg(feature = "binary_encoding")]
339            HexDecode(_) => "hex_decode",
340            #[cfg(feature = "string_encoding")]
341            Base64Encode => "base64_encode",
342            #[cfg(feature = "binary_encoding")]
343            Base64Decode(_) => "base64_decode",
344            Slice => "slice",
345            StartsWith => "starts_with",
346            StripChars => "strip_chars",
347            StripCharsStart => "strip_chars_start",
348            StripCharsEnd => "strip_chars_end",
349            StripPrefix => "strip_prefix",
350            StripSuffix => "strip_suffix",
351            #[cfg(feature = "dtype-struct")]
352            SplitExact { inclusive, .. } => {
353                if *inclusive {
354                    "split_exact_inclusive"
355                } else {
356                    "split_exact"
357                }
358            },
359            #[cfg(feature = "dtype-struct")]
360            SplitN(_) => "splitn",
361            #[cfg(feature = "temporal")]
362            Strptime(_, _) => "strptime",
363            Split(inclusive) => {
364                if *inclusive {
365                    "split_inclusive"
366                } else {
367                    "split"
368                }
369            },
370            #[cfg(feature = "nightly")]
371            Titlecase => "titlecase",
372            #[cfg(feature = "dtype-decimal")]
373            ToDecimal(_) => "to_decimal",
374            Uppercase => "uppercase",
375            #[cfg(feature = "string_pad")]
376            ZFill => "zfill",
377            #[cfg(feature = "find_many")]
378            ContainsAny { .. } => "contains_any",
379            #[cfg(feature = "find_many")]
380            ReplaceMany { .. } => "replace_many",
381            #[cfg(feature = "find_many")]
382            ExtractMany { .. } => "extract_many",
383            #[cfg(feature = "find_many")]
384            FindMany { .. } => "extract_many",
385            #[cfg(feature = "regex")]
386            EscapeRegex => "escape_regex",
387        };
388        write!(f, "str.{s}")
389    }
390}
391
392impl From<StringFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
393    fn from(func: StringFunction) -> Self {
394        use StringFunction::*;
395        match func {
396            #[cfg(feature = "regex")]
397            Contains { literal, strict } => map_as_slice!(strings::contains, literal, strict),
398            CountMatches(literal) => {
399                map_as_slice!(strings::count_matches, literal)
400            },
401            EndsWith => map_as_slice!(strings::ends_with),
402            StartsWith => map_as_slice!(strings::starts_with),
403            Extract(group_index) => map_as_slice!(strings::extract, group_index),
404            ExtractAll => {
405                map_as_slice!(strings::extract_all)
406            },
407            #[cfg(feature = "extract_groups")]
408            ExtractGroups { pat, dtype } => {
409                map!(strings::extract_groups, &pat, &dtype)
410            },
411            #[cfg(feature = "regex")]
412            Find { literal, strict } => map_as_slice!(strings::find, literal, strict),
413            LenBytes => map!(strings::len_bytes),
414            LenChars => map!(strings::len_chars),
415            #[cfg(feature = "string_pad")]
416            PadEnd { length, fill_char } => {
417                map!(strings::pad_end, length, fill_char)
418            },
419            #[cfg(feature = "string_pad")]
420            PadStart { length, fill_char } => {
421                map!(strings::pad_start, length, fill_char)
422            },
423            #[cfg(feature = "string_pad")]
424            ZFill => {
425                map_as_slice!(strings::zfill)
426            },
427            #[cfg(feature = "temporal")]
428            Strptime(dtype, options) => {
429                map_as_slice!(strings::strptime, dtype.clone(), &options)
430            },
431            Split(inclusive) => {
432                map_as_slice!(strings::split, inclusive)
433            },
434            #[cfg(feature = "dtype-struct")]
435            SplitExact { n, inclusive } => map_as_slice!(strings::split_exact, n, inclusive),
436            #[cfg(feature = "dtype-struct")]
437            SplitN(n) => map_as_slice!(strings::splitn, n),
438            #[cfg(feature = "concat_str")]
439            ConcatVertical {
440                delimiter,
441                ignore_nulls,
442            } => map!(strings::join, &delimiter, ignore_nulls),
443            #[cfg(feature = "concat_str")]
444            ConcatHorizontal {
445                delimiter,
446                ignore_nulls,
447            } => map_as_slice!(strings::concat_hor, &delimiter, ignore_nulls),
448            #[cfg(feature = "regex")]
449            Replace { n, literal } => map_as_slice!(strings::replace, literal, n),
450            #[cfg(feature = "string_normalize")]
451            Normalize { form } => map!(strings::normalize, form.clone()),
452            #[cfg(feature = "string_reverse")]
453            Reverse => map!(strings::reverse),
454            Uppercase => map!(uppercase),
455            Lowercase => map!(lowercase),
456            #[cfg(feature = "nightly")]
457            Titlecase => map!(strings::titlecase),
458            StripChars => map_as_slice!(strings::strip_chars),
459            StripCharsStart => map_as_slice!(strings::strip_chars_start),
460            StripCharsEnd => map_as_slice!(strings::strip_chars_end),
461            StripPrefix => map_as_slice!(strings::strip_prefix),
462            StripSuffix => map_as_slice!(strings::strip_suffix),
463            #[cfg(feature = "string_to_integer")]
464            ToInteger(strict) => map_as_slice!(strings::to_integer, strict),
465            Slice => map_as_slice!(strings::str_slice),
466            Head => map_as_slice!(strings::str_head),
467            Tail => map_as_slice!(strings::str_tail),
468            #[cfg(feature = "string_encoding")]
469            HexEncode => map!(strings::hex_encode),
470            #[cfg(feature = "binary_encoding")]
471            HexDecode(strict) => map!(strings::hex_decode, strict),
472            #[cfg(feature = "string_encoding")]
473            Base64Encode => map!(strings::base64_encode),
474            #[cfg(feature = "binary_encoding")]
475            Base64Decode(strict) => map!(strings::base64_decode, strict),
476            #[cfg(feature = "dtype-decimal")]
477            ToDecimal(infer_len) => map!(strings::to_decimal, infer_len),
478            #[cfg(feature = "extract_jsonpath")]
479            JsonDecode {
480                dtype,
481                infer_schema_len,
482            } => map!(strings::json_decode, dtype.clone(), infer_schema_len),
483            #[cfg(feature = "extract_jsonpath")]
484            JsonPathMatch => map_as_slice!(strings::json_path_match),
485            #[cfg(feature = "find_many")]
486            ContainsAny {
487                ascii_case_insensitive,
488            } => {
489                map_as_slice!(contains_any, ascii_case_insensitive)
490            },
491            #[cfg(feature = "find_many")]
492            ReplaceMany {
493                ascii_case_insensitive,
494            } => {
495                map_as_slice!(replace_many, ascii_case_insensitive)
496            },
497            #[cfg(feature = "find_many")]
498            ExtractMany {
499                ascii_case_insensitive,
500                overlapping,
501            } => {
502                map_as_slice!(extract_many, ascii_case_insensitive, overlapping)
503            },
504            #[cfg(feature = "find_many")]
505            FindMany {
506                ascii_case_insensitive,
507                overlapping,
508            } => {
509                map_as_slice!(find_many, ascii_case_insensitive, overlapping)
510            },
511            #[cfg(feature = "regex")]
512            EscapeRegex => map!(escape_regex),
513        }
514    }
515}
516
/// For each string in `s[0]`, test whether any pattern from the list column `s[1]` occurs.
#[cfg(feature = "find_many")]
fn contains_any(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult<Column> {
    let (haystack, patterns) = (s[0].str()?, s[1].list()?);
    let hits = polars_ops::chunked_array::strings::contains_any(
        haystack,
        patterns,
        ascii_case_insensitive,
    )?;
    Ok(hits.into_column())
}
524
/// Replace the patterns in list column `s[1]` with the matching entries of list column
/// `s[2]` in every string of `s[0]` (kernel `replace_all`).
#[cfg(feature = "find_many")]
fn replace_many(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult<Column> {
    let haystack = s[0].str()?;
    let (patterns, replacements) = (s[1].list()?, s[2].list()?);
    let replaced = polars_ops::chunked_array::strings::replace_all(
        haystack,
        patterns,
        replacements,
        ascii_case_insensitive,
    )?;
    Ok(replaced.into_column())
}
538
/// Collect, per string of `s[0]`, the matches of the patterns in list column `s[1]`.
#[cfg(feature = "find_many")]
fn extract_many(
    s: &[Column],
    ascii_case_insensitive: bool,
    overlapping: bool,
) -> PolarsResult<Column> {
    let (haystack, patterns) = (s[0].str()?, s[1].list()?);
    let matches = polars_ops::chunked_array::strings::extract_many(
        haystack,
        patterns,
        ascii_case_insensitive,
        overlapping,
    )?;
    Ok(matches.into_column())
}
556
/// Collect, per string of `s[0]`, the match positions of the patterns in list column `s[1]`.
#[cfg(feature = "find_many")]
fn find_many(
    s: &[Column],
    ascii_case_insensitive: bool,
    overlapping: bool,
) -> PolarsResult<Column> {
    let (haystack, patterns) = (s[0].str()?, s[1].list()?);
    let positions = polars_ops::chunked_array::strings::find_many(
        haystack,
        patterns,
        ascii_case_insensitive,
        overlapping,
    )?;
    Ok(positions.into_column())
}
569
570fn uppercase(s: &Column) -> PolarsResult<Column> {
571    let ca = s.str()?;
572    Ok(ca.to_uppercase().into_column())
573}
574
575fn lowercase(s: &Column) -> PolarsResult<Column> {
576    let ca = s.str()?;
577    Ok(ca.to_lowercase().into_column())
578}
579
/// Title-case every value of a String column; errors if `s` is not String.
#[cfg(feature = "nightly")]
pub(super) fn titlecase(s: &Column) -> PolarsResult<Column> {
    s.str().map(|ca| ca.to_titlecase().into_column())
}
585
586pub(super) fn len_chars(s: &Column) -> PolarsResult<Column> {
587    let ca = s.str()?;
588    Ok(ca.str_len_chars().into_column())
589}
590
591pub(super) fn len_bytes(s: &Column) -> PolarsResult<Column> {
592    let ca = s.str()?;
593    Ok(ca.str_len_bytes().into_column())
594}
595
/// Test each string of `s[0]` for the pattern(s) in `s[1]` after validating input lengths.
#[cfg(feature = "regex")]
pub(super) fn contains(s: &[Column], literal: bool, strict: bool) -> PolarsResult<Column> {
    _check_same_length(s, "contains")?;
    let (haystack, patterns) = (s[0].str()?, s[1].str()?);
    let mask = haystack.contains_chunked(patterns, literal, strict)?;
    Ok(mask.into_column())
}
604
/// Locate the pattern(s) in `s[1]` within each string of `s[0]` after validating lengths.
#[cfg(feature = "regex")]
pub(super) fn find(s: &[Column], literal: bool, strict: bool) -> PolarsResult<Column> {
    _check_same_length(s, "find")?;
    let (haystack, patterns) = (s[0].str()?, s[1].str()?);
    let positions = haystack.find_chunked(patterns, literal, strict)?;
    Ok(positions.into_column())
}
613
614pub(super) fn ends_with(s: &[Column]) -> PolarsResult<Column> {
615    _check_same_length(s, "ends_with")?;
616    let ca = s[0].str()?.as_binary();
617    let suffix = s[1].str()?.as_binary();
618
619    Ok(ca.ends_with_chunked(&suffix)?.into_column())
620}
621
622pub(super) fn starts_with(s: &[Column]) -> PolarsResult<Column> {
623    _check_same_length(s, "starts_with")?;
624    let ca = s[0].str()?.as_binary();
625    let prefix = s[1].str()?.as_binary();
626    Ok(ca.starts_with_chunked(&prefix)?.into_column())
627}
628
629/// Extract a regex pattern from the a string value.
630pub(super) fn extract(s: &[Column], group_index: usize) -> PolarsResult<Column> {
631    let ca = s[0].str()?;
632    let pat = s[1].str()?;
633    ca.extract(pat, group_index).map(|ca| ca.into_column())
634}
635
/// Extract all capture groups of the regex `pat` from each string, producing a column of
/// type `dtype` (a struct of the groups).
#[cfg(feature = "extract_groups")]
pub(super) fn extract_groups(s: &Column, pat: &str, dtype: &DataType) -> PolarsResult<Column> {
    s.str()?.extract_groups(pat, dtype).map(Column::from)
}
642
/// Apply `pad_start(length, fill_char)` to a String column.
#[cfg(feature = "string_pad")]
pub(super) fn pad_start(s: &Column, length: usize, fill_char: char) -> PolarsResult<Column> {
    s.str()
        .map(|ca| ca.pad_start(length, fill_char).into_column())
}
648
/// Apply `pad_end(length, fill_char)` to a String column.
#[cfg(feature = "string_pad")]
pub(super) fn pad_end(s: &Column, length: usize, fill_char: char) -> PolarsResult<Column> {
    s.str()
        .map(|ca| ca.pad_end(length, fill_char).into_column())
}
654
/// Zero-fill the strings of `s[0]` to the lengths in `s[1]`, which is strictly cast to
/// UInt64 first (so a failing cast is an error, not a null).
#[cfg(feature = "string_pad")]
pub(super) fn zfill(s: &[Column]) -> PolarsResult<Column> {
    _check_same_length(s, "zfill")?;
    let ca = s[0].str()?;
    let lengths = s[1].strict_cast(&DataType::UInt64)?;
    Ok(ca.zfill(lengths.u64()?).into_column())
}
663
664pub(super) fn strip_chars(s: &[Column]) -> PolarsResult<Column> {
665    _check_same_length(s, "strip_chars")?;
666    let ca = s[0].str()?;
667    let pat_s = &s[1];
668    ca.strip_chars(pat_s).map(|ok| ok.into_column())
669}
670
671pub(super) fn strip_chars_start(s: &[Column]) -> PolarsResult<Column> {
672    _check_same_length(s, "strip_chars_start")?;
673    let ca = s[0].str()?;
674    let pat_s = &s[1];
675    ca.strip_chars_start(pat_s).map(|ok| ok.into_column())
676}
677
678pub(super) fn strip_chars_end(s: &[Column]) -> PolarsResult<Column> {
679    _check_same_length(s, "strip_chars_end")?;
680    let ca = s[0].str()?;
681    let pat_s = &s[1];
682    ca.strip_chars_end(pat_s).map(|ok| ok.into_column())
683}
684
685pub(super) fn strip_prefix(s: &[Column]) -> PolarsResult<Column> {
686    _check_same_length(s, "strip_prefix")?;
687    let ca = s[0].str()?;
688    let prefix = s[1].str()?;
689    Ok(ca.strip_prefix(prefix).into_column())
690}
691
692pub(super) fn strip_suffix(s: &[Column]) -> PolarsResult<Column> {
693    _check_same_length(s, "strip_suffix")?;
694    let ca = s[0].str()?;
695    let suffix = s[1].str()?;
696    Ok(ca.strip_suffix(suffix).into_column())
697}
698
699pub(super) fn extract_all(args: &[Column]) -> PolarsResult<Column> {
700    let s = &args[0];
701    let pat = &args[1];
702
703    let ca = s.str()?;
704    let pat = pat.str()?;
705
706    if pat.len() == 1 {
707        if let Some(pat) = pat.get(0) {
708            ca.extract_all(pat).map(|ca| ca.into_column())
709        } else {
710            Ok(Column::full_null(
711                ca.name().clone(),
712                ca.len(),
713                &DataType::List(Box::new(DataType::String)),
714            ))
715        }
716    } else {
717        ca.extract_all_many(pat).map(|ca| ca.into_column())
718    }
719}
720
721pub(super) fn count_matches(args: &[Column], literal: bool) -> PolarsResult<Column> {
722    let s = &args[0];
723    let pat = &args[1];
724
725    let ca = s.str()?;
726    let pat = pat.str()?;
727    if pat.len() == 1 {
728        if let Some(pat) = pat.get(0) {
729            ca.count_matches(pat, literal).map(|ca| ca.into_column())
730        } else {
731            Ok(Column::full_null(
732                ca.name().clone(),
733                ca.len(),
734                &DataType::UInt32,
735            ))
736        }
737    } else {
738        ca.count_matches_many(pat, literal)
739            .map(|ca| ca.into_column())
740    }
741}
742
/// Parse `s` into the temporal `dtype` (Date, Datetime or Time, each behind its own feature).
///
/// Datetime parsing receives all inputs (`s[1]` carries the `ambiguous` strings); Date and
/// Time use only `s[0]`. Any other dtype is a `ComputeError`.
#[cfg(feature = "temporal")]
pub(super) fn strptime(
    s: &[Column],
    dtype: DataType,
    options: &StrptimeOptions,
) -> PolarsResult<Column> {
    match dtype {
        #[cfg(feature = "dtype-date")]
        DataType::Date => to_date(&s[0], options),
        #[cfg(feature = "dtype-datetime")]
        DataType::Datetime(tu, tz) => to_datetime(s, &tu, tz.as_ref(), options),
        #[cfg(feature = "dtype-time")]
        DataType::Time => to_time(&s[0], options),
        unsupported => polars_bail!(ComputeError: "not implemented for dtype {}", unsupported),
    }
}
761
/// Split each string of `s[0]` by `s[1]` into a fixed-width struct; `inclusive` selects
/// `split_exact_inclusive`.
#[cfg(feature = "dtype-struct")]
pub(super) fn split_exact(s: &[Column], n: usize, inclusive: bool) -> PolarsResult<Column> {
    let (ca, by) = (s[0].str()?, s[1].str()?);

    if !inclusive {
        return ca.split_exact(by, n).map(|out| out.into_column());
    }
    ca.split_exact_inclusive(by, n)
        .map(|out| out.into_column())
}
773
/// Split each string of `s[0]` by `s[1]` into a struct of `n` parts.
#[cfg(feature = "dtype-struct")]
pub(super) fn splitn(s: &[Column], n: usize) -> PolarsResult<Column> {
    let (ca, by) = (s[0].str()?, s[1].str()?);
    Ok(ca.splitn(by, n)?.into_column())
}
781
782pub(super) fn split(s: &[Column], inclusive: bool) -> PolarsResult<Column> {
783    let ca = s[0].str()?;
784    let by = s[1].str()?;
785
786    if inclusive {
787        Ok(ca.split_inclusive(by)?.into_column())
788    } else {
789        Ok(ca.split(by)?.into_column())
790    }
791}
792
/// Parse a String column into Date, exactly or non-exactly per `options.exact`.
///
/// In strict mode, a difference between input and output null counts (values that failed
/// to parse) is reported through `handle_casting_failures`.
#[cfg(feature = "dtype-date")]
fn to_date(s: &Column, options: &StrptimeOptions) -> PolarsResult<Column> {
    let ca = s.str()?;
    let format = options.format.as_deref();
    let out = if options.exact {
        ca.as_date(format, options.cache)?.into_column()
    } else {
        ca.as_date_not_exact(format)?.into_column()
    };

    if options.strict && ca.null_count() != out.null_count() {
        handle_casting_failures(s.as_materialized_series(), out.as_materialized_series())?;
    }
    Ok(out)
}
811
/// Parse `s[0]` into Datetime with unit `time_unit` and optional `time_zone`, using `s[1]`
/// as the `ambiguous` input.
///
/// `ambiguous` must have the same length as the input, or one of the two must have
/// length 1. When a format is given and the regex/timezones features are enabled, the
/// format is checked against `TZ_AWARE_RE` to set `tz_aware`. In strict mode, a change in
/// null count is reported through `handle_casting_failures`.
#[cfg(feature = "dtype-datetime")]
fn to_datetime(
    s: &[Column],
    time_unit: &TimeUnit,
    time_zone: Option<&TimeZone>,
    options: &StrptimeOptions,
) -> PolarsResult<Column> {
    let datetime_strings = s[0].str()?;
    let ambiguous = s[1].str()?;
    let (n_values, n_ambiguous) = (datetime_strings.len(), ambiguous.len());

    polars_ensure!(
        n_values == n_ambiguous || n_values == 1 || n_ambiguous == 1,
        length_mismatch = "str.strptime",
        n_values,
        n_ambiguous
    );

    let tz_aware = match &options.format {
        #[cfg(all(feature = "regex", feature = "timezones"))]
        Some(format) => TZ_AWARE_RE.is_match(format),
        _ => false,
    };

    let format = options.format.as_deref();
    let out = if options.exact {
        datetime_strings
            .as_datetime(
                format,
                *time_unit,
                options.cache,
                tz_aware,
                time_zone,
                ambiguous,
            )?
            .into_column()
    } else {
        datetime_strings
            .as_datetime_not_exact(format, *time_unit, tz_aware, time_zone, ambiguous)?
            .into_column()
    };

    if options.strict && datetime_strings.null_count() != out.null_count() {
        handle_casting_failures(s[0].as_materialized_series(), out.as_materialized_series())?;
    }
    Ok(out)
}
865
/// Parses a string column into times according to `options`.
///
/// Only exact parsing is supported; with `options.exact == false` a `ComputeError`
/// is returned. With `options.strict`, a change in null count after parsing leads
/// to a call to `handle_casting_failures`.
#[cfg(feature = "dtype-time")]
fn to_time(s: &Column, options: &StrptimeOptions) -> PolarsResult<Column> {
    if !options.exact {
        polars_bail!(ComputeError: "non-exact not implemented for Time data type");
    }

    let ca = s.str()?;
    let parsed = ca
        .as_time(options.format.as_deref(), options.cache)?
        .into_column();

    let nulls_changed = options.strict && ca.null_count() != parsed.null_count();
    if nulls_changed {
        handle_casting_failures(s.as_materialized_series(), parsed.as_materialized_series())?;
    }
    Ok(parsed)
}
882
/// Casts `s` to `String` and joins its values with `delimiter` using
/// `polars_ops::chunked_array::str_join`. `ignore_nulls` is passed through.
#[cfg(feature = "concat_str")]
pub(super) fn join(s: &Column, delimiter: &str, ignore_nulls: bool) -> PolarsResult<Column> {
    let as_strings = s.cast(&DataType::String)?;
    let ca = as_strings.str()?;
    Ok(polars_ops::chunked_array::str_join(ca, delimiter, ignore_nulls).into_column())
}
889
/// Casts every input column to `String` and concatenates them horizontally with
/// `delimiter` using `polars_ops::chunked_array::hor_str_concat`.
///
/// Cast failures are propagated. The `.str()` accessor cannot fail after a
/// successful cast to `String`, so propagating its result is equivalent to
/// unwrapping it.
#[cfg(feature = "concat_str")]
pub(super) fn concat_hor(
    series: &[Column],
    delimiter: &str,
    ignore_nulls: bool,
) -> PolarsResult<Column> {
    let casted = series
        .iter()
        .map(|col| col.cast(&DataType::String))
        .collect::<PolarsResult<Vec<_>>>()?;
    let string_cas = casted
        .iter()
        .map(|col| col.str())
        .collect::<PolarsResult<Vec<_>>>()?;

    let out = polars_ops::chunked_array::hor_str_concat(&string_cas, delimiter, ignore_nulls)?;
    Ok(out.into_column())
}
903
904impl From<StringFunction> for FunctionExpr {
905    fn from(str: StringFunction) -> Self {
906        FunctionExpr::StringExpr(str)
907    }
908}
909
/// Returns the first value of `pat`, or a `ComputeError` if it is null.
///
/// Callers in this file invoke it only when `pat` has length 1.
#[cfg(feature = "regex")]
fn get_pat(pat: &StringChunked) -> PolarsResult<&str> {
    match pat.get(0) {
        Some(first) => Ok(first),
        None => Err(polars_err!(ComputeError: "pattern cannot be 'null' in 'replace' expression")),
    }
}
916
917// used only if feature="regex"
918#[allow(dead_code)]
919fn iter_and_replace<'a, F>(ca: &'a StringChunked, val: &'a StringChunked, f: F) -> StringChunked
920where
921    F: Fn(&'a str, &'a str) -> Cow<'a, str>,
922{
923    let mut out: StringChunked = ca
924        .into_iter()
925        .zip(val)
926        .map(|(opt_src, opt_val)| match (opt_src, opt_val) {
927            (Some(src), Some(val)) => Some(f(src, val)),
928            _ => None,
929        })
930        .collect_trusted();
931
932    out.rename(ca.name().clone());
933    out
934}
935
936#[cfg(feature = "regex")]
/// Returns `true` when `pat` contains no ASCII punctuation.
///
/// Every regex metacharacter is ASCII punctuation, so such a pattern can be
/// matched as plain text instead of being compiled as a regex.
fn is_literal_pat(pat: &str) -> bool {
    !pat.chars().any(|c| c.is_ascii_punctuation())
}
940
/// Replaces at most `n` matches of a single pattern in every string of `ca`.
///
/// * `pat` must contain exactly one non-null pattern.
/// * `val` is either a single non-null value applied to every row, or one value
///   per row of `ca`. In the per-row case, a null string or null value yields null.
/// * With `literal = true` the pattern is matched verbatim and `$` references in
///   the value are not expanded.
/// * `n == 0` performs no replacement on the regex and per-row paths. The literal
///   single-value path passes `n` straight to `replace_literal`.
///
/// # Errors
/// Returns `ComputeError` when:
/// * the pattern or single value is null;
/// * the number of per-row values differs from `ca.len()`;
/// * `pat` does not have length 1;
/// * `n > 1` is combined with a regex pattern or per-row values.
///
/// Also propagates pattern compilation errors.
#[cfg(feature = "regex")]
fn replace_n<'a>(
    ca: &'a StringChunked,
    pat: &'a StringChunked,
    val: &'a StringChunked,
    literal: bool,
    n: usize,
) -> PolarsResult<StringChunked> {
    match (pat.len(), val.len()) {
        (1, 1) => {
            let pat = get_pat(pat)?;
            let val = val.get(0).ok_or_else(
                || polars_err!(ComputeError: "value cannot be 'null' in 'replace' expression"),
            )?;
            // A pattern without ASCII punctuation has no regex metacharacters, so the
            // cheaper literal replacement is used. NOTE(review): this shortcut also
            // disables `$` expansion in `val`, unlike the regex path below.
            let literal = literal || is_literal_pat(pat);

            match literal {
                true => ca.replace_literal(pat, val, n),
                false => {
                    if n > 1 {
                        polars_bail!(ComputeError: "regex replacement with 'n > 1' not yet supported")
                    }
                    if n == 0 {
                        // `replace` always rewrites the first match, so handle `n == 0`
                        // here. The pattern is still compiled so invalid regexes are
                        // reported as before.
                        polars_utils::regex_cache::compile_regex(pat)?;
                        return Ok(ca.clone());
                    }
                    ca.replace(pat, val)
                },
            }
        },
        (1, len_val) => {
            if n > 1 {
                polars_bail!(ComputeError: "multivalue replacement with 'n > 1' not yet supported")
            }
            let mut pat = get_pat(pat)?.to_string();
            polars_ensure!(
                len_val == ca.len(),
                ComputeError:
                "replacement value length ({}) does not match string column length ({})",
                len_val, ca.len(),
            );
            let lit = is_literal_pat(&pat);
            let literal_pat = literal || lit;

            if literal_pat {
                pat = escape(&pat)
            }

            let reg = polars_utils::regex_cache::compile_regex(&pat)?;

            if n == 0 {
                // Nothing is replaced, but a null in `val` still yields a null row,
                // as with `n == 1`.
                return Ok(iter_and_replace(ca, val, |s: &'a str, _: &'a str| {
                    Cow::Borrowed(s)
                }));
            }

            let f = |s: &'a str, val: &'a str| {
                // `str::replacen` never expands `$` references. The fast path is taken
                // only when it gives the same result as the regex path: the pattern is
                // plain text, and either `literal` is set or `val` contains no `$`.
                // Otherwise the result would depend on the string's length.
                if lit && s.len() <= 32 && (literal || !val.contains('$')) {
                    Cow::Owned(s.replacen(&pat, val, 1))
                } else if literal {
                    // According to the docs for replace,
                    // capture groups are ignored when literal = True.
                    reg.replace(s, NoExpand(val))
                } else {
                    reg.replace(s, val)
                }
            };
            Ok(iter_and_replace(ca, val, f))
        },
        _ => polars_bail!(
            ComputeError: "dynamic pattern length in 'str.replace' expressions is not supported yet"
        ),
    }
}
1007
/// Replaces every match of a single pattern in every string of `ca`.
///
/// * `pat` must contain exactly one non-null pattern; otherwise a `ComputeError`
///   is returned.
/// * With a single `val`, the value must be non-null. A plain-text pattern (see
///   `is_literal_pat`), or `literal = true`, uses `replace_literal_all`; every other
///   pattern uses the regex kernel `replace_all`.
/// * With per-row values, their count must equal `ca.len()`. The pattern is escaped
///   when it is plain text or `literal` is set, and `$` references are expanded
///   only when `literal` is false.
#[cfg(feature = "regex")]
fn replace_all<'a>(
    ca: &'a StringChunked,
    pat: &'a StringChunked,
    val: &'a StringChunked,
    literal: bool,
) -> PolarsResult<StringChunked> {
    if pat.len() != 1 {
        polars_bail!(
            ComputeError: "dynamic pattern length in 'str.replace' expressions is not supported yet"
        );
    }
    let pattern = get_pat(pat)?;

    if val.len() == 1 {
        let value = val.get(0).ok_or_else(
            || polars_err!(ComputeError: "value cannot be 'null' in 'replace' expression"),
        )?;
        return if literal || is_literal_pat(pattern) {
            ca.replace_literal_all(pattern, value)
        } else {
            ca.replace_all(pattern, value)
        };
    }

    let value_len = val.len();
    polars_ensure!(
        value_len == ca.len(),
        ComputeError:
        "replacement value length ({}) does not match string column length ({})",
        value_len, ca.len(),
    );

    let regex_src = if literal || is_literal_pat(pattern) {
        escape(pattern)
    } else {
        pattern.to_string()
    };
    let reg = polars_utils::regex_cache::compile_regex(&regex_src)?;

    Ok(iter_and_replace(ca, val, |s: &'a str, v: &'a str| {
        // With `literal = true`, capture-group references in the value are ignored.
        if literal {
            reg.replace_all(s, NoExpand(v))
        } else {
            reg.replace_all(s, v)
        }
    }))
}
1062
/// Entry point for `str.replace`. `s` holds `[strings, pattern, value]`.
///
/// A negative `n` replaces all matches. Otherwise at most `n` matches are replaced.
/// `n` is converted to `usize` with saturation, so a value that does not fit the
/// target's `usize` means "as many as possible" instead of wrapping to an
/// unrelated count.
#[cfg(feature = "regex")]
pub(super) fn replace(s: &[Column], literal: bool, n: i64) -> PolarsResult<Column> {
    let ca = s[0].str()?;
    let pat = s[1].str()?;
    let val = s[2].str()?;

    let replaced = if n < 0 {
        replace_all(ca, pat, val, literal)?
    } else {
        let n = usize::try_from(n).unwrap_or(usize::MAX);
        replace_n(ca, pat, val, literal, n)?
    };
    Ok(replaced.into_column())
}
1081
/// Applies the Unicode normalization `form` to every string via `str_normalize`.
#[cfg(feature = "string_normalize")]
pub(super) fn normalize(s: &Column, form: UnicodeForm) -> PolarsResult<Column> {
    let normalized = s.str()?.str_normalize(form);
    Ok(normalized.into_column())
}
1087
/// Reverses every string in the column via `str_reverse`.
#[cfg(feature = "string_reverse")]
pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
    let reversed = s.str()?.str_reverse();
    Ok(reversed.into_column())
}
1093
/// Parses the strings in `s[0]` as integers, using the bases in `s[1]`.
///
/// The base column is strictly cast to `UInt32` first. `strict` is passed
/// through to the parsing kernel.
#[cfg(feature = "string_to_integer")]
pub(super) fn to_integer(s: &[Column], strict: bool) -> PolarsResult<Column> {
    let ca = s[0].str()?;
    let base_col = s[1].strict_cast(&DataType::UInt32)?;
    let bases = base_col.u32()?;
    let parsed = ca.to_integer(bases, strict)?;
    Ok(parsed.into_column())
}
1101
1102fn _ensure_lengths(s: &[Column]) -> bool {
1103    // Calculate the post-broadcast length and ensure everything is consistent.
1104    let len = s
1105        .iter()
1106        .map(|series| series.len())
1107        .filter(|l| *l != 1)
1108        .max()
1109        .unwrap_or(1);
1110    s.iter()
1111        .all(|series| series.len() == 1 || series.len() == len)
1112}
1113
1114fn _check_same_length(s: &[Column], fn_name: &str) -> Result<(), PolarsError> {
1115    polars_ensure!(
1116        _ensure_lengths(s),
1117        ShapeMismatch: "all series in `str.{}()` should have equal or unit length",
1118        fn_name
1119    );
1120    Ok(())
1121}
1122
1123pub(super) fn str_slice(s: &[Column]) -> PolarsResult<Column> {
1124    _check_same_length(s, "slice")?;
1125    let ca = s[0].str()?;
1126    let offset = &s[1];
1127    let length = &s[2];
1128    Ok(ca.str_slice(offset, length)?.into_column())
1129}
1130
1131pub(super) fn str_head(s: &[Column]) -> PolarsResult<Column> {
1132    _check_same_length(s, "head")?;
1133    let ca = s[0].str()?;
1134    let n = &s[1];
1135    Ok(ca.str_head(n)?.into_column())
1136}
1137
1138pub(super) fn str_tail(s: &[Column]) -> PolarsResult<Column> {
1139    _check_same_length(s, "tail")?;
1140    let ca = s[0].str()?;
1141    let n = &s[1];
1142    Ok(ca.str_tail(n)?.into_column())
1143}
1144
/// Hex-encodes every string via the `hex_encode` kernel.
#[cfg(feature = "string_encoding")]
pub(super) fn hex_encode(s: &Column) -> PolarsResult<Column> {
    let ca = s.str()?;
    Ok(ca.hex_encode().into_column())
}
1149
/// Hex-decodes every string. `strict` is passed to the `hex_decode` kernel.
#[cfg(feature = "binary_encoding")]
pub(super) fn hex_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let ca = s.str()?;
    let decoded = ca.hex_decode(strict)?;
    Ok(decoded.into_column())
}
1154
/// Base64-encodes every string via the `base64_encode` kernel.
#[cfg(feature = "string_encoding")]
pub(super) fn base64_encode(s: &Column) -> PolarsResult<Column> {
    let ca = s.str()?;
    Ok(ca.base64_encode().into_column())
}
1159
/// Base64-decodes every string. `strict` is passed to the `base64_decode` kernel.
#[cfg(feature = "binary_encoding")]
pub(super) fn base64_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let ca = s.str()?;
    let decoded = ca.base64_decode(strict)?;
    Ok(decoded.into_column())
}
1164
/// Converts the strings to decimals with the `to_decimal` kernel, which receives
/// `infer_len` unchanged.
#[cfg(feature = "dtype-decimal")]
pub(super) fn to_decimal(s: &Column, infer_len: usize) -> PolarsResult<Column> {
    let decimals = s.str()?.to_decimal(infer_len)?;
    Ok(Column::from(decimals))
}
1170
/// Decodes every string as JSON with the `json_decode` kernel.
///
/// `dtype` and `infer_schema_len` are forwarded unchanged.
#[cfg(feature = "extract_jsonpath")]
pub(super) fn json_decode(
    s: &Column,
    dtype: Option<DataType>,
    infer_schema_len: Option<usize>,
) -> PolarsResult<Column> {
    let decoded = s.str()?.json_decode(dtype, infer_schema_len)?;
    Ok(Column::from(decoded))
}
1180
/// Applies the JSONPath expressions in `s[1]` to the JSON strings in `s[0]`.
///
/// Both inputs must have equal or unit length.
#[cfg(feature = "extract_jsonpath")]
pub(super) fn json_path_match(s: &[Column]) -> PolarsResult<Column> {
    _check_same_length(s, "json_path_match")?;
    let (ca, paths) = (s[0].str()?, s[1].str()?);
    let matched = ca.json_path_match(paths)?;
    Ok(matched.into_column())
}
1188
/// Escapes regex metacharacters in every string via the `str_escape_regex` kernel.
#[cfg(feature = "regex")]
pub(super) fn escape_regex(s: &Column) -> PolarsResult<Column> {
    let escaped = s.str()?.str_escape_regex();
    Ok(escaped.into_column())
}