1use arrow_array::{Array, ArrayRef, BooleanArray, StringArray};
7use arrow_schema::DataType;
8use datafusion::logical_expr::{create_udf, ScalarUDF, Volatility};
9use datafusion::prelude::SessionContext;
10use datafusion_functions::utils::make_scalar_function;
11use std::sync::{Arc, LazyLock};
12
13pub mod json;
14
15pub fn register_functions(ctx: &SessionContext) {
17 ctx.register_udf(CONTAINS_TOKENS_UDF.clone());
18 ctx.register_udf(json::json_extract_udf());
20 ctx.register_udf(json::json_exists_udf());
21 ctx.register_udf(json::json_get_udf());
22 ctx.register_udf(json::json_get_string_udf());
23 ctx.register_udf(json::json_get_int_udf());
24 ctx.register_udf(json::json_get_float_udf());
25 ctx.register_udf(json::json_get_bool_udf());
26 ctx.register_udf(json::json_array_contains_udf());
27 ctx.register_udf(json::json_array_length_udf());
28}
29
30fn contains_tokens() -> ScalarUDF {
51 let function = Arc::new(make_scalar_function(
52 |args: &[ArrayRef]| {
53 let column = args[0].as_any().downcast_ref::<StringArray>().ok_or(
54 datafusion::error::DataFusionError::Execution(
55 "First argument of contains_tokens can't be cast to string".to_string(),
56 ),
57 )?;
58 let scalar_str = args[1].as_any().downcast_ref::<StringArray>().ok_or(
59 datafusion::error::DataFusionError::Execution(
60 "Second argument of contains_tokens can't be cast to string".to_string(),
61 ),
62 )?;
63
64 let tokens: Option<Vec<&str>> = match scalar_str.len() {
65 0 => None,
66 _ => Some(collect_tokens(scalar_str.value(0))),
67 };
68
69 let result = column.iter().map(|text| {
70 text.map(|text| {
71 let text_tokens = collect_tokens(text);
72 if let Some(tokens) = &tokens {
73 tokens.len()
74 == tokens
75 .iter()
76 .filter(|token| text_tokens.contains(*token))
77 .count()
78 } else {
79 true
80 }
81 })
82 });
83
84 Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef)
85 },
86 vec![],
87 ));
88
89 create_udf(
90 "contains_tokens",
91 vec![DataType::Utf8, DataType::Utf8],
92 DataType::Boolean,
93 Volatility::Immutable,
94 function,
95 )
96}
97
/// Splits `text` into tokens: maximal runs of characters for which
/// `char::is_alphanumeric` holds. Every other character (whitespace,
/// punctuation, `_`, …) is a separator, and empty pieces are dropped.
/// Tokens borrow from `text`; no case folding is applied.
fn collect_tokens(text: &str) -> Vec<&str> {
    let is_separator = |ch: char| !ch.is_alphanumeric();
    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if !piece.is_empty() {
            tokens.push(piece);
        }
    }
    tokens
}
104
/// Shared instance of the `contains_tokens` UDF, built lazily by
/// `contains_tokens()` on first access. Callers take their own copy with
/// `.clone()` (e.g. when registering it on a `SessionContext`).
pub static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);
106
#[cfg(test)]
mod tests {
    use crate::udf::CONTAINS_TOKENS_UDF;
    use arrow_array::{Array, BooleanArray, StringArray};
    use arrow_schema::{DataType, Field};
    use datafusion::logical_expr::ScalarFunctionArgs;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// Every query token must appear as a whole token of the text, in any
    /// order: extra words are fine, but `catchup` does not satisfy `catch`.
    #[tokio::test]
    async fn test_contains_tokens() {
        let udf = CONTAINS_TOKENS_UDF.clone();

        let texts = StringArray::from(vec![
            "a cat catch a fish",
            "a fish catch a cat",
            "a white cat catch a big fish",
            "cat catchup fish",
            "cat fish catch",
        ]);
        let row_count = texts.len();
        let queries = StringArray::from(vec![" cat catch fish."; 5]);

        let call_args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(texts)),
                ColumnarValue::Array(Arc::new(queries)),
            ],
            arg_fields: vec![
                Arc::new(Field::new("text_col", DataType::Utf8, false)),
                Arc::new(Field::new("token", DataType::Utf8, false)),
            ],
            number_rows: row_count,
            return_field: Arc::new(Field::new("res", DataType::Boolean, false)),
        };

        let output = udf.invoke_with_args(call_args).unwrap();
        let array = match output {
            ColumnarValue::Array(array) => array,
            other => panic!("Expected an Array but got {:?}", other),
        };
        let actual = array.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(
            actual,
            &BooleanArray::from(vec![true, true, true, false, true])
        );
    }
}
161}