lance_datafusion/
udf.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Datafusion user defined functions
5
6use arrow_array::{ArrayRef, BooleanArray, StringArray};
7use arrow_schema::DataType;
8use datafusion::logical_expr::{create_udf, ScalarUDF, Volatility};
9use datafusion::prelude::SessionContext;
10use datafusion_functions::utils::make_scalar_function;
11use std::sync::{Arc, LazyLock};
12
13/// Register UDF functions to datafusion context.
14pub fn register_functions(ctx: &SessionContext) {
15    ctx.register_udf(CONTAINS_TOKENS_UDF.clone());
16}
17
18/// This method checks whether a string contains another string. It utilizes FTS (Full-Text Search)
19/// indexes, but due to the false negative characteristic of FTS, the results may have omissions.
20/// For example, "bakin" will not match documents containing "baking."
21/// If the query string is a whole word, or if you prioritize better performance, `contains_tokens`
22/// is the better choice. Otherwise, you can use the `contains` method to obtain accurate results.
23///
24///
25/// Usage
26/// * Use `contains_tokens` in sql.
27/// ```rust,ignore
28/// let sql = "SELECT * FROM table WHERE contains_tokens(text_col, 'bakin')"
29/// let mut ds = Dataset::open(&ds_path).await?;
30/// let mut builder = ds.sql(&sql);
31/// let records = builder.clone().build().await?.into_batch_records().await?;
32/// ```
33fn contains_tokens() -> ScalarUDF {
34    let function = Arc::new(make_scalar_function(
35        |args: &[ArrayRef]| {
36            let column = args[0].as_any().downcast_ref::<StringArray>().ok_or(
37                datafusion::error::DataFusionError::Execution(
38                    "First argument of contains_tokens can't be cast to string".to_string(),
39                ),
40            )?;
41            let scalar_str = args[1].as_any().downcast_ref::<StringArray>().ok_or(
42                datafusion::error::DataFusionError::Execution(
43                    "Second argument of contains_tokens can't be cast to string".to_string(),
44                ),
45            )?;
46
47            let result = column
48                .iter()
49                .enumerate()
50                .map(|(i, column)| column.map(|value| value.contains(scalar_str.value(i))));
51
52            Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef)
53        },
54        vec![],
55    ));
56
57    create_udf(
58        "contains_tokens",
59        vec![DataType::Utf8, DataType::Utf8],
60        DataType::Boolean,
61        Volatility::Immutable,
62        function,
63    )
64}
65
66static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);
67
#[cfg(test)]
mod tests {
    use crate::udf::CONTAINS_TOKENS_UDF;
    use arrow_array::{Array, BooleanArray, StringArray};
    use arrow_schema::{DataType, Field};
    use datafusion::logical_expr::ScalarFunctionArgs;
    use datafusion::physical_plan::ColumnarValue;
    use std::sync::Arc;

    /// Invokes the UDF directly on two equal-length string arrays and checks the
    /// per-row substring results, including a match inside a longer word ("catch").
    #[tokio::test]
    async fn test_contains_tokens() {
        // Inputs: five texts, each probed with the same token.
        let haystacks = StringArray::from(vec![
            "a cat",
            "lovely cat",
            "white cat",
            "catch up",
            "fish",
        ]);
        let needles = StringArray::from(vec!["cat"; 5]);
        let row_count = haystacks.len();

        let call = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(haystacks)),
                ColumnarValue::Array(Arc::new(needles)),
            ],
            arg_fields: vec![
                Arc::new(Field::new("text_col", DataType::Utf8, false)),
                Arc::new(Field::new("token", DataType::Utf8, false)),
            ],
            number_rows: row_count,
            return_field: Arc::new(Field::new("res", DataType::Boolean, false)),
        };

        // Call the UDF by hand, bypassing SQL planning.
        let output = CONTAINS_TOKENS_UDF.invoke_with_args(call).unwrap();

        match output {
            ColumnarValue::Array(array) => {
                let actual = array.as_any().downcast_ref::<BooleanArray>().unwrap();
                let expected = BooleanArray::from(vec![true, true, true, true, false]);
                assert_eq!(actual, &expected);
            }
            other => panic!("Expected an Array but got {:?}", other),
        }
    }
}