1use arrow_array::{Array, ArrayRef, BooleanArray, StringArray};
7use arrow_schema::DataType;
8use datafusion::logical_expr::{create_udf, ScalarUDF, Volatility};
9use datafusion::prelude::SessionContext;
10use datafusion_functions::utils::make_scalar_function;
11use std::sync::{Arc, LazyLock};
12
13pub mod json;
14
15pub fn register_functions(ctx: &SessionContext) {
17 ctx.register_udf(CONTAINS_TOKENS_UDF.clone());
18 ctx.register_udf(json::json_extract_udf());
20 ctx.register_udf(json::json_extract_with_type_udf());
21 ctx.register_udf(json::json_exists_udf());
22 ctx.register_udf(json::json_get_udf());
23 ctx.register_udf(json::json_get_string_udf());
24 ctx.register_udf(json::json_get_int_udf());
25 ctx.register_udf(json::json_get_float_udf());
26 ctx.register_udf(json::json_get_bool_udf());
27 ctx.register_udf(json::json_array_contains_udf());
28 ctx.register_udf(json::json_array_length_udf());
29}
30
31fn contains_tokens() -> ScalarUDF {
52 let function = Arc::new(make_scalar_function(
53 |args: &[ArrayRef]| {
54 let column = args[0].as_any().downcast_ref::<StringArray>().ok_or(
55 datafusion::error::DataFusionError::Execution(
56 "First argument of contains_tokens can't be cast to string".to_string(),
57 ),
58 )?;
59 let scalar_str = args[1].as_any().downcast_ref::<StringArray>().ok_or(
60 datafusion::error::DataFusionError::Execution(
61 "Second argument of contains_tokens can't be cast to string".to_string(),
62 ),
63 )?;
64
65 let tokens: Option<Vec<&str>> = match scalar_str.len() {
66 0 => None,
67 _ => Some(collect_tokens(scalar_str.value(0))),
68 };
69
70 let result = column.iter().map(|text| {
71 text.map(|text| {
72 let text_tokens = collect_tokens(text);
73 if let Some(tokens) = &tokens {
74 tokens.len()
75 == tokens
76 .iter()
77 .filter(|token| text_tokens.contains(*token))
78 .count()
79 } else {
80 true
81 }
82 })
83 });
84
85 Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef)
86 },
87 vec![],
88 ));
89
90 create_udf(
91 "contains_tokens",
92 vec![DataType::Utf8, DataType::Utf8],
93 DataType::Boolean,
94 Volatility::Immutable,
95 function,
96 )
97}
98
/// Splits `text` into its tokens: the maximal runs of alphanumeric
/// characters, in order of appearance.
///
/// Every non-alphanumeric character acts as a separator, and the empty
/// pieces produced by leading, trailing, or consecutive separators are
/// dropped. The returned slices borrow from `text`.
fn collect_tokens(text: &str) -> Vec<&str> {
    let mut tokens = Vec::new();
    for piece in text.split(|ch: char| !ch.is_alphanumeric()) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece);
    }
    tokens
}
105
/// Shared instance of the `contains_tokens` UDF, built by `contains_tokens`
/// the first time the static is accessed.
pub static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);
107
#[cfg(test)]
mod tests {
    use crate::udf::CONTAINS_TOKENS_UDF;
    use arrow_array::{Array, BooleanArray, StringArray};
    use arrow_schema::{DataType, Field};
    use datafusion::logical_expr::ScalarFunctionArgs;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// Invokes the UDF directly on five rows that are all checked against the
    /// tokens `cat`, `catch` and `fish`. Only the row containing `catchup`
    /// (a different token from `catch`) is expected to be `false`.
    #[tokio::test]
    async fn test_contains_tokens() {
        let udf = CONTAINS_TOKENS_UDF.clone();

        let texts = Arc::new(StringArray::from(vec![
            "a cat catch a fish",
            "a fish catch a cat",
            "a white cat catch a big fish",
            "cat catchup fish",
            "cat fish catch",
        ]));
        // The same token string is supplied for every row.
        let needles = Arc::new(StringArray::from(vec![" cat catch fish."; 5]));

        let invocation = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(texts), ColumnarValue::Array(needles)],
            arg_fields: vec![
                Arc::new(Field::new("text_col".to_string(), DataType::Utf8, false)),
                Arc::new(Field::new("token".to_string(), DataType::Utf8, false)),
            ],
            number_rows: 5,
            return_field: Arc::new(Field::new("res".to_string(), DataType::Boolean, false)),
        };

        let values = udf.invoke_with_args(invocation).unwrap();

        match values {
            ColumnarValue::Array(array) => {
                let actual = array.as_any().downcast_ref::<BooleanArray>().unwrap();
                let expected = BooleanArray::from(vec![true, true, true, false, true]);
                assert_eq!(actual.clone(), expected);
            }
            other => panic!("Expected an Array but got {:?}", other),
        }
    }
}