1use arrow_array::{Array, ArrayRef, BooleanArray, StringArray};
7use arrow_schema::DataType;
8use datafusion::logical_expr::{create_udf, ScalarUDF, Volatility};
9use datafusion::prelude::SessionContext;
10use datafusion_functions::utils::make_scalar_function;
11use std::sync::{Arc, LazyLock};
12
13pub mod json;
14
15pub fn register_functions(ctx: &SessionContext) {
17 ctx.register_udf(CONTAINS_TOKENS_UDF.clone());
18 ctx.register_udf(json::json_extract_udf());
20 ctx.register_udf(json::json_extract_with_type_udf());
21 ctx.register_udf(json::json_exists_udf());
22 ctx.register_udf(json::json_get_udf());
23 ctx.register_udf(json::json_get_string_udf());
24 ctx.register_udf(json::json_get_int_udf());
25 ctx.register_udf(json::json_get_float_udf());
26 ctx.register_udf(json::json_get_bool_udf());
27 ctx.register_udf(json::json_array_contains_udf());
28 ctx.register_udf(json::json_array_length_udf());
29 lance_geo::register_functions(ctx);
31}
32
33fn contains_tokens() -> ScalarUDF {
54 let function = Arc::new(make_scalar_function(
55 |args: &[ArrayRef]| {
56 let column = args[0].as_any().downcast_ref::<StringArray>().ok_or(
57 datafusion::error::DataFusionError::Execution(
58 "First argument of contains_tokens can't be cast to string".to_string(),
59 ),
60 )?;
61 let scalar_str = args[1].as_any().downcast_ref::<StringArray>().ok_or(
62 datafusion::error::DataFusionError::Execution(
63 "Second argument of contains_tokens can't be cast to string".to_string(),
64 ),
65 )?;
66
67 let tokens: Option<Vec<&str>> = match scalar_str.len() {
68 0 => None,
69 _ => Some(collect_tokens(scalar_str.value(0))),
70 };
71
72 let result = column.iter().map(|text| {
73 text.map(|text| {
74 let text_tokens = collect_tokens(text);
75 if let Some(tokens) = &tokens {
76 tokens.len()
77 == tokens
78 .iter()
79 .filter(|token| text_tokens.contains(*token))
80 .count()
81 } else {
82 true
83 }
84 })
85 });
86
87 Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef)
88 },
89 vec![],
90 ));
91
92 create_udf(
93 "contains_tokens",
94 vec![DataType::Utf8, DataType::Utf8],
95 DataType::Boolean,
96 Volatility::Immutable,
97 function,
98 )
99}
100
/// Splits `text` into its tokens.
///
/// Every character that is not alphanumeric (per `char::is_alphanumeric`) acts
/// as a separator. Empty pieces produced by leading, trailing, or consecutive
/// separators are dropped. The returned slices borrow from `text` and keep
/// their original order.
fn collect_tokens(text: &str) -> Vec<&str> {
    let is_separator = |ch: char| !ch.is_alphanumeric();

    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece);
    }
    tokens
}
107
/// Shared instance of the `contains_tokens` UDF.
///
/// The UDF is built by `contains_tokens` on first access and reused afterwards;
/// callers in this file clone it when they need an owned `ScalarUDF`.
pub static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);
109
#[cfg(test)]
mod tests {
    use super::CONTAINS_TOKENS_UDF;
    use arrow_array::{Array, BooleanArray, StringArray};
    use arrow_schema::{DataType, Field};
    use datafusion::logical_expr::ScalarFunctionArgs;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// Invokes the UDF directly with two equally long string arrays and checks
    /// the per-row result. Only the row containing "catchup" instead of
    /// "catch" is expected to be `false`.
    #[tokio::test]
    async fn test_contains_tokens() {
        // Query string repeated for every row, with leading whitespace and
        // trailing punctuation.
        const QUERY: &str = " cat catch fish.";

        let texts = vec![
            "a cat catch a fish",
            "a fish catch a cat",
            "a white cat catch a big fish",
            "cat catchup fish",
            "cat fish catch",
        ];
        let expected = vec![true, true, true, false, true];
        let number_rows = texts.len();

        let text_col = Arc::new(StringArray::from(texts));
        let token = Arc::new(StringArray::from(vec![QUERY; number_rows]));

        let udf = CONTAINS_TOKENS_UDF.clone();
        let invocation = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(text_col), ColumnarValue::Array(token)],
            arg_fields: vec![
                Arc::new(Field::new("text_col".to_string(), DataType::Utf8, false)),
                Arc::new(Field::new("token".to_string(), DataType::Utf8, false)),
            ],
            number_rows,
            return_field: Arc::new(Field::new("res".to_string(), DataType::Boolean, false)),
            config_options: Arc::new(Default::default()),
        };

        let values = udf.invoke_with_args(invocation).unwrap();

        match values {
            ColumnarValue::Array(array) => {
                let actual = array.as_any().downcast_ref::<BooleanArray>().unwrap();
                assert_eq!(actual.clone(), BooleanArray::from(expected));
            }
            other => panic!("Expected an Array but got {:?}", other),
        }
    }
}