1use arrow_array::{ArrayRef, BooleanArray, StringArray};
7use arrow_schema::DataType;
8use datafusion::logical_expr::{create_udf, ScalarUDF, Volatility};
9use datafusion::prelude::SessionContext;
10use datafusion_functions::utils::make_scalar_function;
11use std::sync::{Arc, LazyLock};
12
13pub fn register_functions(ctx: &SessionContext) {
15 ctx.register_udf(CONTAINS_TOKENS_UDF.clone());
16}
17
18fn contains_tokens() -> ScalarUDF {
34 let function = Arc::new(make_scalar_function(
35 |args: &[ArrayRef]| {
36 let column = args[0].as_any().downcast_ref::<StringArray>().ok_or(
37 datafusion::error::DataFusionError::Execution(
38 "First argument of contains_tokens can't be cast to string".to_string(),
39 ),
40 )?;
41 let scalar_str = args[1].as_any().downcast_ref::<StringArray>().ok_or(
42 datafusion::error::DataFusionError::Execution(
43 "Second argument of contains_tokens can't be cast to string".to_string(),
44 ),
45 )?;
46
47 let result = column
48 .iter()
49 .enumerate()
50 .map(|(i, column)| column.map(|value| value.contains(scalar_str.value(i))));
51
52 Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef)
53 },
54 vec![],
55 ));
56
57 create_udf(
58 "contains_tokens",
59 vec![DataType::Utf8, DataType::Utf8],
60 DataType::Boolean,
61 Volatility::Immutable,
62 function,
63 )
64}
65
66static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);
67
#[cfg(test)]
mod tests {
    use super::CONTAINS_TOKENS_UDF;
    use arrow_array::{Array, BooleanArray, StringArray};
    use arrow_schema::{DataType, Field};
    use datafusion::logical_expr::ScalarFunctionArgs;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// `contains_tokens` matches row by row as a substring test: `"catch up"`
    /// matches `"cat"`, while `"fish"` does not.
    ///
    /// UDF invocation is synchronous, so a plain `#[test]` suffices (no async runtime).
    #[test]
    fn test_contains_tokens() {
        let contains_tokens = CONTAINS_TOKENS_UDF.clone();

        let text_col = Arc::new(StringArray::from(vec![
            "a cat",
            "lovely cat",
            "white cat",
            "catch up",
            "fish",
        ]));
        let token = Arc::new(StringArray::from(vec!["cat"; 5]));
        // Derive the row count from the input rather than hard-coding it.
        let number_rows = text_col.len();

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(text_col), ColumnarValue::Array(token)],
            arg_fields: vec![
                Arc::new(Field::new("text_col", DataType::Utf8, false)),
                Arc::new(Field::new("token", DataType::Utf8, false)),
            ],
            number_rows,
            return_field: Arc::new(Field::new("res", DataType::Boolean, false)),
        };

        let values = contains_tokens.invoke_with_args(args).unwrap();

        let array = match values {
            ColumnarValue::Array(array) => array,
            other => panic!("Expected an Array but got {other:?}"),
        };
        let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(
            array,
            &BooleanArray::from(vec![true, true, true, true, false])
        );
    }
}