datafusion_loki/
function.rs1use std::{any::Any, sync::Arc};
2
3use datafusion::{
4 arrow::{
5 array::{Array, ArrayRef, Capacities, MapArray, MutableArrayData, make_array},
6 datatypes::{DataType, Fields},
7 },
8 common::{cast::as_map_array, exec_err, internal_err, utils::take_function_args},
9 logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility},
10 scalar::ScalarValue,
11};
12
13use crate::DFResult;
14
15#[derive(Debug)]
16pub struct MapGet {
17 signature: Signature,
18 aliases: Vec<String>,
19}
20
21impl Default for MapGet {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl MapGet {
28 pub fn new() -> Self {
29 Self {
30 signature: Signature::user_defined(Volatility::Immutable),
31 aliases: vec![],
32 }
33 }
34}
35
36impl ScalarUDFImpl for MapGet {
37 fn as_any(&self) -> &dyn Any {
38 self
39 }
40 fn name(&self) -> &str {
41 "map_get"
42 }
43
44 fn signature(&self) -> &Signature {
45 &self.signature
46 }
47
48 fn return_type(&self, arg_types: &[DataType]) -> DFResult<DataType> {
49 let [map_type, _] = take_function_args(self.name(), arg_types)?;
50 let map_fields = get_map_entry_field(map_type)?;
51 let value_type = map_fields.last().unwrap().data_type().clone();
52 Ok(value_type)
53 }
54
55 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
56 make_scalar_function(map_extract_inner)(&args.args)
57 }
58
59 fn aliases(&self) -> &[String] {
60 &self.aliases
61 }
62
63 fn coerce_types(&self, arg_types: &[DataType]) -> DFResult<Vec<DataType>> {
64 let [map_type, _] = take_function_args(self.name(), arg_types)?;
65
66 let field = get_map_entry_field(map_type)?;
67 Ok(vec![
68 map_type.clone(),
69 field.first().unwrap().data_type().clone(),
70 ])
71 }
72}
73
74pub(crate) fn get_map_entry_field(data_type: &DataType) -> DFResult<&Fields> {
75 match data_type {
76 DataType::Map(field, _) => {
77 let field_data_type = field.data_type();
78 match field_data_type {
79 DataType::Struct(fields) => Ok(fields),
80 _ => {
81 internal_err!("Expected a Struct type, got {:?}", field_data_type)
82 }
83 }
84 }
85 _ => internal_err!("Expected a Map type, got {:?}", data_type),
86 }
87}
88
89pub(crate) fn make_scalar_function<F>(
91 inner: F,
92) -> impl Fn(&[ColumnarValue]) -> DFResult<ColumnarValue>
93where
94 F: Fn(&[ArrayRef]) -> DFResult<ArrayRef>,
95{
96 move |args: &[ColumnarValue]| {
97 let len = args
100 .iter()
101 .fold(Option::<usize>::None, |acc, arg| match arg {
102 ColumnarValue::Scalar(_) => acc,
103 ColumnarValue::Array(a) => Some(a.len()),
104 });
105
106 let is_scalar = len.is_none();
107
108 let args = ColumnarValue::values_to_arrays(args)?;
109
110 let result = (inner)(&args);
111
112 if is_scalar {
113 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
115 result.map(ColumnarValue::Scalar)
116 } else {
117 result.map(ColumnarValue::Array)
118 }
119 }
120}
121
122fn general_map_extract_inner(
123 map_array: &MapArray,
124 query_keys_array: &dyn Array,
125) -> DFResult<ArrayRef> {
126 let keys = map_array.keys();
127 let mut offsets = vec![0_i32];
128
129 let values = map_array.values();
130 let original_data = values.to_data();
131 let capacity = Capacities::Array(original_data.len());
132
133 let mut mutable = MutableArrayData::with_capacities(vec![&original_data], true, capacity);
134
135 for (row_index, offset_window) in map_array.value_offsets().windows(2).enumerate() {
136 let start = offset_window[0] as usize;
137 let end = offset_window[1] as usize;
138 let len = end - start;
139
140 let query_key = query_keys_array.slice(row_index, 1);
141
142 let value_index =
143 (0..len).find(|&i| keys.slice(start + i, 1).as_ref() == query_key.as_ref());
144
145 match value_index {
146 Some(index) => {
147 mutable.extend(0, start + index, start + index + 1);
148 }
149 None => {
150 mutable.extend_nulls(1);
151 }
152 }
153 offsets.push(offsets[row_index] + 1);
154 }
155
156 let data = mutable.freeze();
157
158 Ok(Arc::new(make_array(data)))
159}
160
161fn map_extract_inner(args: &[ArrayRef]) -> DFResult<ArrayRef> {
162 let [map_arg, key_arg] = take_function_args("map_extract", args)?;
163
164 let map_array = match map_arg.data_type() {
165 DataType::Map(_, _) => as_map_array(&map_arg)?,
166 _ => return exec_err!("The first argument in map_get must be a map"),
167 };
168
169 let key_type = map_array.key_type();
170
171 if key_type != key_arg.data_type() {
172 return exec_err!(
173 "The key type {} does not match the map key type {}",
174 key_arg.data_type(),
175 key_type
176 );
177 }
178
179 general_map_extract_inner(map_array, key_arg)
180}