datafusion_functions/string/contains.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::make_scalar_function;
19use arrow::array::{Array, ArrayRef, AsArray};
20use arrow::compute::contains as arrow_contains;
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
23use datafusion_common::types::logical_string;
24use datafusion_common::{exec_err, DataFusionError, Result};
25use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
26use datafusion_expr::{
27    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28    TypeSignatureClass, Volatility,
29};
30use datafusion_macros::user_doc;
31use std::any::Any;
32use std::sync::Arc;
33
34#[user_doc(
35    doc_section(label = "String Functions"),
36    description = "Return true if search_str is found within string (case-sensitive).",
37    syntax_example = "contains(str, search_str)",
38    sql_example = r#"```sql
39> select contains('the quick brown fox', 'row');
40+---------------------------------------------------+
41| contains(Utf8("the quick brown fox"),Utf8("row")) |
42+---------------------------------------------------+
43| true                                              |
44+---------------------------------------------------+
45```"#,
46    standard_argument(name = "str", prefix = "String"),
47    argument(name = "search_str", description = "The string to search for in str.")
48)]
49#[derive(Debug)]
50pub struct ContainsFunc {
51    signature: Signature,
52}
53
54impl Default for ContainsFunc {
55    fn default() -> Self {
56        ContainsFunc::new()
57    }
58}
59
60impl ContainsFunc {
61    pub fn new() -> Self {
62        Self {
63            signature: Signature::coercible(
64                vec![
65                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
66                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
67                ],
68                Volatility::Immutable,
69            ),
70        }
71    }
72}
73
74impl ScalarUDFImpl for ContainsFunc {
75    fn as_any(&self) -> &dyn Any {
76        self
77    }
78
79    fn name(&self) -> &str {
80        "contains"
81    }
82
83    fn signature(&self) -> &Signature {
84        &self.signature
85    }
86
87    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
88        Ok(Boolean)
89    }
90
91    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92        make_scalar_function(contains, vec![])(&args.args)
93    }
94
95    fn documentation(&self) -> Option<&Documentation> {
96        self.doc()
97    }
98}
99
100/// use `arrow::compute::contains` to do the calculation for contains
101fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
102    if let Some(coercion_data_type) =
103        string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
104            binary_to_string_coercion(args[0].data_type(), args[1].data_type())
105        })
106    {
107        let arg0 = if args[0].data_type() == &coercion_data_type {
108            Arc::clone(&args[0])
109        } else {
110            arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
111        };
112        let arg1 = if args[1].data_type() == &coercion_data_type {
113            Arc::clone(&args[1])
114        } else {
115            arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
116        };
117
118        match coercion_data_type {
119            Utf8View => {
120                let mod_str = arg0.as_string_view();
121                let match_str = arg1.as_string_view();
122                let res = arrow_contains(mod_str, match_str)?;
123                Ok(Arc::new(res) as ArrayRef)
124            }
125            Utf8 => {
126                let mod_str = arg0.as_string::<i32>();
127                let match_str = arg1.as_string::<i32>();
128                let res = arrow_contains(mod_str, match_str)?;
129                Ok(Arc::new(res) as ArrayRef)
130            }
131            LargeUtf8 => {
132                let mod_str = arg0.as_string::<i64>();
133                let match_str = arg1.as_string::<i64>();
134                let res = arrow_contains(mod_str, match_str)?;
135                Ok(Arc::new(res) as ArrayRef)
136            }
137            other => {
138                exec_err!("Unsupported data type {other:?} for function `contains`.")
139            }
140        }
141    } else {
142        exec_err!(
143            "Unsupported data type {:?}, {:?} for function `contains`.",
144            args[0].data_type(),
145            args[1].data_type()
146        )
147    }
148}
149
#[cfg(test)]
mod test {
    use super::ContainsFunc;
    use crate::expr_fn::contains;
    use arrow::array::{BooleanArray, StringArray, StringViewArray};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// A Utf8 column searched with a Utf8 scalar; the needle contains
    /// regex-like metacharacters that must be matched literally.
    #[test]
    fn test_contains_udf() {
        let udf = ContainsFunc::new();
        let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("xxx?()"),
            Some("yyy?()"),
        ])));
        let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string())));
        let arg_fields = vec![
            Field::new("a", DataType::Utf8, true).into(),
            Field::new("a", DataType::Utf8, true).into(),
        ];

        let args = ScalarFunctionArgs {
            args: vec![array, scalar],
            arg_fields,
            number_rows: 2,
            return_field: Field::new("f", DataType::Boolean, true).into(),
        };

        let actual = udf.invoke_with_args(args).unwrap();
        let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
            Some(true),
            Some(false),
        ])));
        assert_eq!(
            *actual.into_array(2).unwrap(),
            *expect.into_array(2).unwrap()
        );
    }

    /// A Utf8View column searched with a LargeUtf8 scalar exercises the
    /// cross-type coercion path, which the Utf8/Utf8 test does not reach.
    /// A null haystack must yield a null result.
    #[test]
    fn test_contains_udf_mixed_string_types_with_nulls() {
        let udf = ContainsFunc::new();
        let array = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("xxx?()"),
            None,
            Some("yyy?()"),
        ])));
        let scalar =
            ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("x?(".to_string())));
        let arg_fields = vec![
            Field::new("a", DataType::Utf8View, true).into(),
            Field::new("b", DataType::LargeUtf8, true).into(),
        ];

        let args = ScalarFunctionArgs {
            args: vec![array, scalar],
            arg_fields,
            number_rows: 3,
            return_field: Field::new("f", DataType::Boolean, true).into(),
        };

        let actual = udf.invoke_with_args(args).unwrap();
        let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
            Some(true),
            None,
            Some(false),
        ])));
        assert_eq!(
            *actual.into_array(3).unwrap(),
            *expect.into_array(3).unwrap()
        );
    }

    /// The `expr_fn::contains` builder produces the expected expression text.
    #[test]
    fn test_contains_api() {
        let expr = contains(
            Expr::Literal(
                ScalarValue::Utf8(Some("the quick brown fox".to_string())),
                None,
            ),
            Expr::Literal(ScalarValue::Utf8(Some("row".to_string())), None),
        );
        assert_eq!(
            expr.to_string(),
            "contains(Utf8(\"the quick brown fox\"), Utf8(\"row\"))"
        );
    }
}