Skip to main content

lance_datafusion/
udf.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Datafusion user defined functions
5
6use arrow_array::{Array, ArrayRef, BooleanArray, StringArray};
7use arrow_schema::DataType;
8use datafusion::logical_expr::{ScalarUDF, Volatility, create_udf};
9use datafusion::prelude::SessionContext;
10use datafusion_functions::utils::make_scalar_function;
11use std::sync::{Arc, LazyLock};
12
13pub mod json;
14
15/// Register UDF functions to datafusion context.
16pub fn register_functions(ctx: &SessionContext) {
17    ctx.register_udf(CONTAINS_TOKENS_UDF.clone());
18    // JSON functions
19    ctx.register_udf(json::json_extract_udf());
20    ctx.register_udf(json::json_extract_with_type_udf());
21    ctx.register_udf(json::json_exists_udf());
22    ctx.register_udf(json::json_get_udf());
23    ctx.register_udf(json::json_get_string_udf());
24    ctx.register_udf(json::json_get_int_udf());
25    ctx.register_udf(json::json_get_float_udf());
26    ctx.register_udf(json::json_get_bool_udf());
27    ctx.register_udf(json::json_array_contains_udf());
28    ctx.register_udf(json::json_array_length_udf());
29    // GEO functions
30    #[cfg(feature = "geo")]
31    lance_geo::register_functions(ctx);
32    #[cfg(not(feature = "geo"))]
33    register_geo_stub_functions(ctx);
34}
35
36/// When the `geo` feature is disabled, register stub UDFs for spatial SQL functions
37/// so that users get a clear error mentioning the feature flag instead of
38/// DataFusion's generic "Unknown function" error.
39#[cfg(not(feature = "geo"))]
40fn register_geo_stub_functions(ctx: &SessionContext) {
41    let geo_funcs = [
42        "st_intersects",
43        "st_contains",
44        "st_within",
45        "st_touches",
46        "st_crosses",
47        "st_overlaps",
48        "st_covers",
49        "st_coveredby",
50        "st_distance",
51        "st_area",
52        "st_length",
53    ];
54
55    for name in geo_funcs {
56        let func_name = name.to_string();
57        let stub = Arc::new(make_scalar_function(
58            move |_args: &[ArrayRef]| {
59                Err(datafusion::error::DataFusionError::Plan(format!(
60                    "Function '{}' requires the `geo` feature. \
61                     Rebuild with `--features geo` to enable geospatial functions.",
62                    func_name
63                )))
64            },
65            vec![],
66        ));
67
68        ctx.register_udf(create_udf(
69            name,
70            vec![DataType::Binary, DataType::Binary],
71            DataType::Boolean,
72            Volatility::Immutable,
73            stub,
74        ));
75    }
76}
77
78/// This method checks whether a string contains all specified tokens. The tokens are separated by
79/// punctuations and white spaces.
80///
81/// The functionality is equivalent to FTS MatchQuery (with fuzziness disabled, Operator::And,
82/// and using the simple tokenizer). If FTS index exists and suites the query, it will be used to
83/// optimize the query.
84///
85/// Usage
86/// * Use `contains_tokens` in sql.
87/// ```rust,ignore
88/// let sql = "SELECT * FROM table WHERE contains_tokens(text_col, 'fox jumps dog')";
89/// let mut ds = Dataset::open(&ds_path).await?;
90/// let ctx = SessionContext::new();
91/// ctx.register_table(
92///     "table",
93///     Arc::new(LanceTableProvider::new(dataset, false, false)),
94/// )?;
95/// register_functions(&ctx);
96/// let df = ctx.sql(sql).await?;
97/// ```
98fn contains_tokens() -> ScalarUDF {
99    let function = Arc::new(make_scalar_function(
100        |args: &[ArrayRef]| {
101            let column = args[0].as_any().downcast_ref::<StringArray>().ok_or(
102                datafusion::error::DataFusionError::Execution(
103                    "First argument of contains_tokens can't be cast to string".to_string(),
104                ),
105            )?;
106            let scalar_str = args[1].as_any().downcast_ref::<StringArray>().ok_or(
107                datafusion::error::DataFusionError::Execution(
108                    "Second argument of contains_tokens can't be cast to string".to_string(),
109                ),
110            )?;
111
112            let tokens: Option<Vec<&str>> = match scalar_str.len() {
113                0 => None,
114                _ => Some(collect_tokens(scalar_str.value(0))),
115            };
116
117            let result = column.iter().map(|text| {
118                text.map(|text| {
119                    let text_tokens = collect_tokens(text);
120                    if let Some(tokens) = &tokens {
121                        tokens.len()
122                            == tokens
123                                .iter()
124                                .filter(|token| text_tokens.contains(*token))
125                                .count()
126                    } else {
127                        true
128                    }
129                })
130            });
131
132            Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef)
133        },
134        vec![],
135    ));
136
137    create_udf(
138        "contains_tokens",
139        vec![DataType::Utf8, DataType::Utf8],
140        DataType::Boolean,
141        Volatility::Immutable,
142        function,
143    )
144}
145
/// Breaks `text` into tokens, treating every non-alphanumeric character as a separator.
///
/// Empty pieces (produced by leading, trailing, or consecutive separators) are dropped,
/// and the returned slices borrow from `text`.
fn collect_tokens(text: &str) -> Vec<&str> {
    let is_separator = |ch: char| !ch.is_alphanumeric();

    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if !piece.is_empty() {
            tokens.push(piece);
        }
    }
    tokens
}
152
/// Shared `contains_tokens` UDF instance, built by `contains_tokens()` the first time it is
/// accessed. Callers clone it when they need an owned `ScalarUDF` to register.
153pub static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);
154
#[cfg(test)]
mod tests {
    use crate::udf::CONTAINS_TOKENS_UDF;
    use arrow_array::{Array, BooleanArray, StringArray};
    use arrow_schema::{DataType, Field};
    use datafusion::logical_expr::ScalarFunctionArgs;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// Invokes the UDF directly on five rows and checks that a row matches only when it
    /// contains every query token as a whole token ("catchup" must not satisfy "catch").
    #[tokio::test]
    async fn test_contains_tokens() {
        let udf = CONTAINS_TOKENS_UDF.clone();

        // Input column and a query column repeating the same token string for each row.
        let texts = StringArray::from(vec![
            "a cat catch a fish",
            "a fish catch a cat",
            "a white cat catch a big fish",
            "cat catchup fish",
            "cat fish catch",
        ]);
        let row_count = texts.len();
        let queries = StringArray::from(vec![" cat catch fish."; row_count]);

        let call = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(texts)),
                ColumnarValue::Array(Arc::new(queries)),
            ],
            arg_fields: vec![
                Arc::new(Field::new("text_col", DataType::Utf8, false)),
                Arc::new(Field::new("token", DataType::Utf8, false)),
            ],
            number_rows: row_count,
            return_field: Arc::new(Field::new("res", DataType::Boolean, false)),
            config_options: Arc::new(Default::default()),
        };

        // Call the UDF by hand, bypassing SQL planning.
        match udf.invoke_with_args(call).unwrap() {
            ColumnarValue::Array(array) => {
                let actual = array.as_any().downcast_ref::<BooleanArray>().unwrap();
                let expected = BooleanArray::from(vec![true, true, true, false, true]);
                assert_eq!(actual, &expected);
            }
            other => panic!("Expected an Array but got {:?}", other),
        }
    }
}