datafusion_functions/string/
contains.rs1use arrow::array::{Array, ArrayRef, Scalar};
19use arrow::compute::contains as arrow_contains;
20use arrow::datatypes::DataType;
21use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
22use datafusion_common::types::logical_string;
23use datafusion_common::{Result, exec_err};
24use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
25use datafusion_expr::{
26 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27 TypeSignatureClass, Volatility,
28};
29use datafusion_macros::user_doc;
30use std::sync::Arc;
31
32#[user_doc(
33 doc_section(label = "String Functions"),
34 description = "Return true if search_str is found within string (case-sensitive).",
35 syntax_example = "contains(str, search_str)",
36 sql_example = r#"```sql
37> select contains('the quick brown fox', 'row');
38+---------------------------------------------------+
39| contains(Utf8("the quick brown fox"),Utf8("row")) |
40+---------------------------------------------------+
41| true |
42+---------------------------------------------------+
43```"#,
44 standard_argument(name = "str", prefix = "String"),
45 argument(name = "search_str", description = "The string to search for in str.")
46)]
47#[derive(Debug, PartialEq, Eq, Hash)]
48pub struct ContainsFunc {
49 signature: Signature,
50}
51
52impl Default for ContainsFunc {
53 fn default() -> Self {
54 ContainsFunc::new()
55 }
56}
57
58impl ContainsFunc {
59 pub fn new() -> Self {
60 Self {
61 signature: Signature::coercible(
62 vec![
63 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
64 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
65 ],
66 Volatility::Immutable,
67 ),
68 }
69 }
70}
71
72impl ScalarUDFImpl for ContainsFunc {
73 fn name(&self) -> &str {
74 "contains"
75 }
76
77 fn signature(&self) -> &Signature {
78 &self.signature
79 }
80
81 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
82 Ok(Boolean)
83 }
84
85 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86 contains(args.args.as_slice())
87 }
88
89 fn documentation(&self) -> Option<&Documentation> {
90 self.doc()
91 }
92}
93
94fn to_array(value: &ColumnarValue) -> Result<(ArrayRef, bool)> {
95 match value {
96 ColumnarValue::Array(array) => Ok((Arc::clone(array), false)),
97 ColumnarValue::Scalar(scalar) => Ok((scalar.to_array()?, true)),
98 }
99}
100
101fn call_arrow_contains(
105 haystack: &ArrayRef,
106 haystack_is_scalar: bool,
107 needle: &ArrayRef,
108 needle_is_scalar: bool,
109) -> Result<ColumnarValue> {
110 let result = match (haystack_is_scalar, needle_is_scalar) {
113 (false, false) => arrow_contains(haystack, needle)?,
114 (false, true) => arrow_contains(haystack, &Scalar::new(Arc::clone(needle)))?,
115 (true, false) => arrow_contains(&Scalar::new(Arc::clone(haystack)), needle)?,
116 (true, true) => arrow_contains(
117 &Scalar::new(Arc::clone(haystack)),
118 &Scalar::new(Arc::clone(needle)),
119 )?,
120 };
121
122 if haystack_is_scalar && needle_is_scalar {
124 let scalar = datafusion_common::ScalarValue::try_from_array(&result, 0)?;
125 Ok(ColumnarValue::Scalar(scalar))
126 } else {
127 Ok(ColumnarValue::Array(Arc::new(result)))
128 }
129}
130
131fn contains(args: &[ColumnarValue]) -> Result<ColumnarValue> {
133 let (haystack, haystack_is_scalar) = to_array(&args[0])?;
134 let (needle, needle_is_scalar) = to_array(&args[1])?;
135
136 if let Some(coercion_data_type) =
137 string_coercion(haystack.data_type(), needle.data_type()).or_else(|| {
138 binary_to_string_coercion(haystack.data_type(), needle.data_type())
139 })
140 {
141 let haystack = if haystack.data_type() == &coercion_data_type {
142 haystack
143 } else {
144 arrow::compute::kernels::cast::cast(&haystack, &coercion_data_type)?
145 };
146 let needle = if needle.data_type() == &coercion_data_type {
147 needle
148 } else {
149 arrow::compute::kernels::cast::cast(&needle, &coercion_data_type)?
150 };
151
152 match coercion_data_type {
153 Utf8View | Utf8 | LargeUtf8 => call_arrow_contains(
154 &haystack,
155 haystack_is_scalar,
156 &needle,
157 needle_is_scalar,
158 ),
159 other => {
160 exec_err!("Unsupported data type {other:?} for function `contains`.")
161 }
162 }
163 } else {
164 exec_err!(
165 "Unsupported data type {}, {:?} for function `contains`.",
166 args[0].data_type(),
167 args[1].data_type()
168 )
169 }
170}
171
#[cfg(test)]
mod test {
    use super::ContainsFunc;
    use crate::expr_fn::contains;
    use arrow::array::{BooleanArray, StringArray};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Array haystack against a scalar needle: one boolean per input row.
    ///
    /// The needle `x?(` includes characters that are special in regular
    /// expressions; presumably this checks that matching is a literal
    /// substring search.
    #[test]
    fn test_contains_udf() {
        let haystack = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("xxx?()"),
            Some("yyy?()"),
        ])));
        let needle = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string())));

        let invocation = ScalarFunctionArgs {
            args: vec![haystack, needle],
            arg_fields: vec![
                Field::new("a", DataType::Utf8, true).into(),
                Field::new("a", DataType::Utf8, true).into(),
            ],
            number_rows: 2,
            return_field: Field::new("f", DataType::Boolean, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let actual = ContainsFunc::new().invoke_with_args(invocation).unwrap();
        let expected = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
            Some(true),
            Some(false),
        ])));

        let actual_rows = actual.into_array(2).unwrap();
        let expected_rows = expected.into_array(2).unwrap();
        assert_eq!(*actual_rows, *expected_rows);
    }

    /// The `contains` expression builder renders as a `contains(...)` call.
    #[test]
    fn test_contains_api() {
        let utf8_literal =
            |text: &str| Expr::Literal(ScalarValue::Utf8(Some(text.to_string())), None);

        let expr = contains(utf8_literal("the quick brown fox"), utf8_literal("row"));

        assert_eq!(
            expr.to_string(),
            "contains(Utf8(\"the quick brown fox\"), Utf8(\"row\"))"
        );
    }
}