1use arrow_array::{Array, ArrayRef, BooleanArray, StringArray};
7use arrow_schema::DataType;
8use datafusion::logical_expr::{ScalarUDF, Volatility, create_udf};
9use datafusion::prelude::SessionContext;
10use datafusion_functions::utils::make_scalar_function;
11use std::sync::{Arc, LazyLock};
12
13pub mod json;
14
15pub fn register_functions(ctx: &SessionContext) {
17 ctx.register_udf(CONTAINS_TOKENS_UDF.clone());
18 ctx.register_udf(json::json_extract_udf());
20 ctx.register_udf(json::json_extract_with_type_udf());
21 ctx.register_udf(json::json_exists_udf());
22 ctx.register_udf(json::json_get_udf());
23 ctx.register_udf(json::json_get_string_udf());
24 ctx.register_udf(json::json_get_int_udf());
25 ctx.register_udf(json::json_get_float_udf());
26 ctx.register_udf(json::json_get_bool_udf());
27 ctx.register_udf(json::json_array_contains_udf());
28 ctx.register_udf(json::json_array_length_udf());
29 #[cfg(feature = "geo")]
31 lance_geo::register_functions(ctx);
32 #[cfg(not(feature = "geo"))]
33 register_geo_stub_functions(ctx);
34}
35
36#[cfg(not(feature = "geo"))]
40fn register_geo_stub_functions(ctx: &SessionContext) {
41 let geo_funcs = [
42 "st_intersects",
43 "st_contains",
44 "st_within",
45 "st_touches",
46 "st_crosses",
47 "st_overlaps",
48 "st_covers",
49 "st_coveredby",
50 "st_distance",
51 "st_area",
52 "st_length",
53 ];
54
55 for name in geo_funcs {
56 let func_name = name.to_string();
57 let stub = Arc::new(make_scalar_function(
58 move |_args: &[ArrayRef]| {
59 Err(datafusion::error::DataFusionError::Plan(format!(
60 "Function '{}' requires the `geo` feature. \
61 Rebuild with `--features geo` to enable geospatial functions.",
62 func_name
63 )))
64 },
65 vec![],
66 ));
67
68 ctx.register_udf(create_udf(
69 name,
70 vec![DataType::Binary, DataType::Binary],
71 DataType::Boolean,
72 Volatility::Immutable,
73 stub,
74 ));
75 }
76}
77
78fn contains_tokens() -> ScalarUDF {
99 let function = Arc::new(make_scalar_function(
100 |args: &[ArrayRef]| {
101 let column = args[0].as_any().downcast_ref::<StringArray>().ok_or(
102 datafusion::error::DataFusionError::Execution(
103 "First argument of contains_tokens can't be cast to string".to_string(),
104 ),
105 )?;
106 let scalar_str = args[1].as_any().downcast_ref::<StringArray>().ok_or(
107 datafusion::error::DataFusionError::Execution(
108 "Second argument of contains_tokens can't be cast to string".to_string(),
109 ),
110 )?;
111
112 let tokens: Option<Vec<&str>> = match scalar_str.len() {
113 0 => None,
114 _ => Some(collect_tokens(scalar_str.value(0))),
115 };
116
117 let result = column.iter().map(|text| {
118 text.map(|text| {
119 let text_tokens = collect_tokens(text);
120 if let Some(tokens) = &tokens {
121 tokens.len()
122 == tokens
123 .iter()
124 .filter(|token| text_tokens.contains(*token))
125 .count()
126 } else {
127 true
128 }
129 })
130 });
131
132 Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef)
133 },
134 vec![],
135 ));
136
137 create_udf(
138 "contains_tokens",
139 vec![DataType::Utf8, DataType::Utf8],
140 DataType::Boolean,
141 Volatility::Immutable,
142 function,
143 )
144}
145
/// Splits `text` into maximal runs of alphanumeric characters, dropping empty pieces.
///
/// Any non-alphanumeric character (Unicode-aware) acts as a separator. For example,
/// `" cat catch fish."` yields `["cat", "catch", "fish"]`. The returned slices borrow from
/// `text`.
fn collect_tokens(text: &str) -> Vec<&str> {
    let is_separator = |c: char| !c.is_alphanumeric();
    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        // Consecutive separators (or leading/trailing ones) produce empty pieces.
        if !piece.is_empty() {
            tokens.push(piece);
        }
    }
    tokens
}
152
153pub static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);
154
#[cfg(test)]
mod tests {
    use crate::udf::CONTAINS_TOKENS_UDF;
    use arrow_array::{Array, BooleanArray, StringArray};
    use arrow_schema::{DataType, Field};
    use datafusion::logical_expr::ScalarFunctionArgs;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// Checks each sample text against the same query. Only the row containing "catchup"
    /// (rather than the exact token "catch") is expected not to match.
    #[tokio::test]
    async fn test_contains_tokens() {
        let udf = CONTAINS_TOKENS_UDF.clone();

        let texts = StringArray::from(vec![
            "a cat catch a fish",
            "a fish catch a cat",
            "a white cat catch a big fish",
            "cat catchup fish",
            "cat fish catch",
        ]);
        let queries = StringArray::from(vec![" cat catch fish."; 5]);

        let invocation = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(texts)),
                ColumnarValue::Array(Arc::new(queries)),
            ],
            arg_fields: vec![
                Arc::new(Field::new("text_col", DataType::Utf8, false)),
                Arc::new(Field::new("token", DataType::Utf8, false)),
            ],
            number_rows: 5,
            return_field: Arc::new(Field::new("res", DataType::Boolean, false)),
            config_options: Arc::new(Default::default()),
        };

        let output = udf.invoke_with_args(invocation).unwrap();
        let array = match output {
            ColumnarValue::Array(array) => array,
            other => panic!("Expected an Array but got {:?}", other),
        };
        let actual = array.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(
            actual,
            &BooleanArray::from(vec![true, true, true, false, true])
        );
    }
}