lance_datafusion/
udf.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Datafusion user defined functions
5
6use arrow_array::{Array, ArrayRef, BooleanArray, StringArray};
7use arrow_schema::DataType;
8use datafusion::logical_expr::{create_udf, ScalarUDF, Volatility};
9use datafusion::prelude::SessionContext;
10use datafusion_functions::utils::make_scalar_function;
11use std::sync::{Arc, LazyLock};
12
13pub mod json;
14
15/// Register UDF functions to datafusion context.
16pub fn register_functions(ctx: &SessionContext) {
17    ctx.register_udf(CONTAINS_TOKENS_UDF.clone());
18    // JSON functions
19    ctx.register_udf(json::json_extract_udf());
20    ctx.register_udf(json::json_extract_with_type_udf());
21    ctx.register_udf(json::json_exists_udf());
22    ctx.register_udf(json::json_get_udf());
23    ctx.register_udf(json::json_get_string_udf());
24    ctx.register_udf(json::json_get_int_udf());
25    ctx.register_udf(json::json_get_float_udf());
26    ctx.register_udf(json::json_get_bool_udf());
27    ctx.register_udf(json::json_array_contains_udf());
28    ctx.register_udf(json::json_array_length_udf());
29    // GEO functions
30    lance_geo::register_functions(ctx);
31}
32
33/// This method checks whether a string contains all specified tokens. The tokens are separated by
34/// punctuations and white spaces.
35///
36/// The functionality is equivalent to FTS MatchQuery (with fuzziness disabled, Operator::And,
37/// and using the simple tokenizer). If FTS index exists and suites the query, it will be used to
38/// optimize the query.
39///
40/// Usage
41/// * Use `contains_tokens` in sql.
42/// ```rust,ignore
43/// let sql = "SELECT * FROM table WHERE contains_tokens(text_col, 'fox jumps dog')";
44/// let mut ds = Dataset::open(&ds_path).await?;
45/// let ctx = SessionContext::new();
46/// ctx.register_table(
47///     "table",
48///     Arc::new(LanceTableProvider::new(dataset, false, false)),
49/// )?;
50/// register_functions(&ctx);
51/// let df = ctx.sql(sql).await?;
52/// ```
53fn contains_tokens() -> ScalarUDF {
54    let function = Arc::new(make_scalar_function(
55        |args: &[ArrayRef]| {
56            let column = args[0].as_any().downcast_ref::<StringArray>().ok_or(
57                datafusion::error::DataFusionError::Execution(
58                    "First argument of contains_tokens can't be cast to string".to_string(),
59                ),
60            )?;
61            let scalar_str = args[1].as_any().downcast_ref::<StringArray>().ok_or(
62                datafusion::error::DataFusionError::Execution(
63                    "Second argument of contains_tokens can't be cast to string".to_string(),
64                ),
65            )?;
66
67            let tokens: Option<Vec<&str>> = match scalar_str.len() {
68                0 => None,
69                _ => Some(collect_tokens(scalar_str.value(0))),
70            };
71
72            let result = column.iter().map(|text| {
73                text.map(|text| {
74                    let text_tokens = collect_tokens(text);
75                    if let Some(tokens) = &tokens {
76                        tokens.len()
77                            == tokens
78                                .iter()
79                                .filter(|token| text_tokens.contains(*token))
80                                .count()
81                    } else {
82                        true
83                    }
84                })
85            });
86
87            Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef)
88        },
89        vec![],
90    ));
91
92    create_udf(
93        "contains_tokens",
94        vec![DataType::Utf8, DataType::Utf8],
95        DataType::Boolean,
96        Volatility::Immutable,
97        function,
98    )
99}
100
/// Splits `text` into tokens, i.e. maximal runs of characters for which `char::is_alphanumeric`
/// holds (so non-ASCII letters and digits are kept). Every other character acts as a separator,
/// empty pieces are dropped, and no case folding or other normalization is applied.
fn collect_tokens(text: &str) -> Vec<&str> {
    let is_separator = |ch: char| !ch.is_alphanumeric();
    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if !piece.is_empty() {
            tokens.push(piece);
        }
    }
    tokens
}
107
108pub static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);
109
#[cfg(test)]
mod tests {
    use crate::udf::CONTAINS_TOKENS_UDF;
    use arrow_array::{Array, BooleanArray, StringArray};
    use arrow_schema::{DataType, Field};
    use datafusion::logical_expr::ScalarFunctionArgs;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// Invokes `contains_tokens(text, query)` directly on two string arrays of equal length and
    /// returns the resulting boolean array.
    fn invoke_contains_tokens(text: StringArray, query: StringArray) -> BooleanArray {
        let number_rows = text.len();
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(text)),
                ColumnarValue::Array(Arc::new(query)),
            ],
            arg_fields: vec![
                Arc::new(Field::new("text_col".to_string(), DataType::Utf8, true)),
                Arc::new(Field::new("token".to_string(), DataType::Utf8, false)),
            ],
            number_rows,
            return_field: Arc::new(Field::new("res".to_string(), DataType::Boolean, true)),
            config_options: Arc::new(Default::default()),
        };

        match CONTAINS_TOKENS_UDF.invoke_with_args(args).unwrap() {
            ColumnarValue::Array(array) => array
                .as_any()
                .downcast_ref::<BooleanArray>()
                .expect("contains_tokens should return a BooleanArray")
                .clone(),
            other => panic!("Expected an Array but got {:?}", other),
        }
    }

    // The UDF is synchronous, so plain `#[test]` is enough; no async runtime is needed.
    #[test]
    fn test_contains_tokens() {
        let text = StringArray::from(vec![
            "a cat catch a fish",
            "a fish catch a cat",
            "a white cat catch a big fish",
            "cat catchup fish",
            "cat fish catch",
        ]);
        let query = StringArray::from(vec![" cat catch fish."; 5]);

        assert_eq!(
            invoke_contains_tokens(text, query),
            BooleanArray::from(vec![true, true, true, false, true])
        );
    }

    #[test]
    fn test_contains_tokens_null_text() {
        let text = StringArray::from(vec![Some("cat catch fish"), None, Some("fish")]);
        let query = StringArray::from(vec!["cat fish"; 3]);

        assert_eq!(
            invoke_contains_tokens(text, query),
            BooleanArray::from(vec![Some(true), None, Some(false)])
        );
    }

    #[test]
    fn test_contains_tokens_query_without_tokens() {
        // A query made only of separators has no tokens and therefore matches any non-null text.
        let text = StringArray::from(vec!["anything goes", ""]);
        let query = StringArray::from(vec!["., !"; 2]);

        assert_eq!(
            invoke_contains_tokens(text, query),
            BooleanArray::from(vec![true, true])
        );
    }
}