datafusion_loki/
function.rs

1use std::{any::Any, sync::Arc};
2
3use datafusion::{
4    arrow::{
5        array::{Array, ArrayRef, Capacities, MapArray, MutableArrayData, make_array},
6        datatypes::{DataType, Fields},
7    },
8    common::{cast::as_map_array, exec_err, internal_err, utils::take_function_args},
9    logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility},
10    scalar::ScalarValue,
11};
12
13use crate::DFResult;
14
15#[derive(Debug)]
16pub struct MapGet {
17    signature: Signature,
18    aliases: Vec<String>,
19}
20
21impl Default for MapGet {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl MapGet {
28    pub fn new() -> Self {
29        Self {
30            signature: Signature::user_defined(Volatility::Immutable),
31            aliases: vec![],
32        }
33    }
34}
35
36impl ScalarUDFImpl for MapGet {
37    fn as_any(&self) -> &dyn Any {
38        self
39    }
40    fn name(&self) -> &str {
41        "map_get"
42    }
43
44    fn signature(&self) -> &Signature {
45        &self.signature
46    }
47
48    fn return_type(&self, arg_types: &[DataType]) -> DFResult<DataType> {
49        let [map_type, _] = take_function_args(self.name(), arg_types)?;
50        let map_fields = get_map_entry_field(map_type)?;
51        let value_type = map_fields.last().unwrap().data_type().clone();
52        Ok(value_type)
53    }
54
55    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
56        make_scalar_function(map_extract_inner)(&args.args)
57    }
58
59    fn aliases(&self) -> &[String] {
60        &self.aliases
61    }
62
63    fn coerce_types(&self, arg_types: &[DataType]) -> DFResult<Vec<DataType>> {
64        let [map_type, _] = take_function_args(self.name(), arg_types)?;
65
66        let field = get_map_entry_field(map_type)?;
67        Ok(vec![
68            map_type.clone(),
69            field.first().unwrap().data_type().clone(),
70        ])
71    }
72}
73
74pub(crate) fn get_map_entry_field(data_type: &DataType) -> DFResult<&Fields> {
75    match data_type {
76        DataType::Map(field, _) => {
77            let field_data_type = field.data_type();
78            match field_data_type {
79                DataType::Struct(fields) => Ok(fields),
80                _ => {
81                    internal_err!("Expected a Struct type, got {:?}", field_data_type)
82                }
83            }
84        }
85        _ => internal_err!("Expected a Map type, got {:?}", data_type),
86    }
87}
88
89/// array function wrapper that differentiates between scalar (length 1) and array.
90pub(crate) fn make_scalar_function<F>(
91    inner: F,
92) -> impl Fn(&[ColumnarValue]) -> DFResult<ColumnarValue>
93where
94    F: Fn(&[ArrayRef]) -> DFResult<ArrayRef>,
95{
96    move |args: &[ColumnarValue]| {
97        // first, identify if any of the arguments is an Array. If yes, store its `len`,
98        // as any scalar will need to be converted to an array of len `len`.
99        let len = args
100            .iter()
101            .fold(Option::<usize>::None, |acc, arg| match arg {
102                ColumnarValue::Scalar(_) => acc,
103                ColumnarValue::Array(a) => Some(a.len()),
104            });
105
106        let is_scalar = len.is_none();
107
108        let args = ColumnarValue::values_to_arrays(args)?;
109
110        let result = (inner)(&args);
111
112        if is_scalar {
113            // If all inputs are scalar, keeps output as scalar
114            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
115            result.map(ColumnarValue::Scalar)
116        } else {
117            result.map(ColumnarValue::Array)
118        }
119    }
120}
121
122fn general_map_extract_inner(
123    map_array: &MapArray,
124    query_keys_array: &dyn Array,
125) -> DFResult<ArrayRef> {
126    let keys = map_array.keys();
127    let mut offsets = vec![0_i32];
128
129    let values = map_array.values();
130    let original_data = values.to_data();
131    let capacity = Capacities::Array(original_data.len());
132
133    let mut mutable = MutableArrayData::with_capacities(vec![&original_data], true, capacity);
134
135    for (row_index, offset_window) in map_array.value_offsets().windows(2).enumerate() {
136        let start = offset_window[0] as usize;
137        let end = offset_window[1] as usize;
138        let len = end - start;
139
140        let query_key = query_keys_array.slice(row_index, 1);
141
142        let value_index =
143            (0..len).find(|&i| keys.slice(start + i, 1).as_ref() == query_key.as_ref());
144
145        match value_index {
146            Some(index) => {
147                mutable.extend(0, start + index, start + index + 1);
148            }
149            None => {
150                mutable.extend_nulls(1);
151            }
152        }
153        offsets.push(offsets[row_index] + 1);
154    }
155
156    let data = mutable.freeze();
157
158    Ok(Arc::new(make_array(data)))
159}
160
161fn map_extract_inner(args: &[ArrayRef]) -> DFResult<ArrayRef> {
162    let [map_arg, key_arg] = take_function_args("map_extract", args)?;
163
164    let map_array = match map_arg.data_type() {
165        DataType::Map(_, _) => as_map_array(&map_arg)?,
166        _ => return exec_err!("The first argument in map_get must be a map"),
167    };
168
169    let key_type = map_array.key_type();
170
171    if key_type != key_arg.data_type() {
172        return exec_err!(
173            "The key type {} does not match the map key type {}",
174            key_arg.data_type(),
175            key_type
176        );
177    }
178
179    general_map_extract_inner(map_array, key_arg)
180}