datafusion_functions/core/
getfield.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, BooleanArray, Capacities, MutableArrayData, Scalar, make_array,
23    make_comparator,
24};
25use arrow::compute::SortOptions;
26use arrow::datatypes::{DataType, Field, FieldRef};
27use arrow_buffer::NullBuffer;
28
29use datafusion_common::cast::{as_map_array, as_struct_array};
30use datafusion_common::{
31    Result, ScalarValue, exec_err, internal_err, plan_datafusion_err,
32};
33use datafusion_expr::expr::ScalarFunction;
34use datafusion_expr::simplify::ExprSimplifyResult;
35use datafusion_expr::{
36    ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
37    ScalarUDFImpl, Signature, Volatility,
38};
39use datafusion_macros::user_doc;
40
#[user_doc(
    doc_section(label = "Other Functions"),
    description = r#"Returns a field within a map or a struct with the given key.
    Supports nested field access by providing multiple field names.
    Note: most users invoke `get_field` indirectly via field access
    syntax such as `my_struct_col['field_name']` which results in a call to
    `get_field(my_struct_col, 'field_name')`.
    Nested access like `my_struct['a']['b']` is optimized to a single call:
    `get_field(my_struct, 'a', 'b')`."#,
    syntax_example = "get_field(expression, field_name[, field_name2, ...])",
    sql_example = r#"```sql
> -- Access a field from a struct column
> create table test( struct_col) as values
    ({name: 'Alice', age: 30}),
    ({name: 'Bob', age: 25});
> select struct_col from test;
+-----------------------------+
| struct_col                  |
+-----------------------------+
| {name: Alice, age: 30}      |
| {name: Bob, age: 25}        |
+-----------------------------+
> select struct_col['name'] as name from test;
+-------+
| name  |
+-------+
| Alice |
| Bob   |
+-------+

> -- Nested field access with multiple arguments
> create table test(struct_col) as values
    ({outer: {inner_val: 42}});
> select struct_col['outer']['inner_val'] as result from test;
+--------+
| result |
+--------+
| 42     |
+--------+
```"#,
    argument(
        name = "expression",
        description = "The map or struct to retrieve a field from."
    ),
    argument(
        name = "field_name",
        description = "The field name(s) to access, in order for nested access. Must evaluate to strings."
    )
)]
/// Scalar UDF registered under the name `get_field`.
///
/// It extracts a value from a struct (by field name) or from a map (by key),
/// and accepts several trailing field names to walk nested values in one call.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct GetFieldFunc {
    /// Signature returned from `ScalarUDFImpl::signature`; built in
    /// `GetFieldFunc::new` as user-defined with immutable volatility.
    signature: Signature,
}
94
95impl Default for GetFieldFunc {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101/// Process a map array by finding matching keys and extracting corresponding values.
102///
103/// This function handles both simple (scalar) and nested key types by using
104/// appropriate comparison strategies.
105fn process_map_array(
106    array: &dyn Array,
107    key_array: Arc<dyn Array>,
108) -> Result<ColumnarValue> {
109    let map_array = as_map_array(array)?;
110    let keys = if key_array.data_type().is_nested() {
111        let comparator = make_comparator(
112            map_array.keys().as_ref(),
113            key_array.as_ref(),
114            SortOptions::default(),
115        )?;
116        let len = map_array.keys().len().min(key_array.len());
117        let values = (0..len).map(|i| comparator(i, i).is_eq()).collect();
118        let nulls = NullBuffer::union(map_array.keys().nulls(), key_array.nulls());
119        BooleanArray::new(values, nulls)
120    } else {
121        let be_compared = Scalar::new(key_array);
122        arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())?
123    };
124
125    let original_data = map_array.entries().column(1).to_data();
126    let capacity = Capacities::Array(original_data.len());
127    let mut mutable =
128        MutableArrayData::with_capacities(vec![&original_data], true, capacity);
129
130    for entry in 0..map_array.len() {
131        let start = map_array.value_offsets()[entry] as usize;
132        let end = map_array.value_offsets()[entry + 1] as usize;
133
134        let maybe_matched = keys
135            .slice(start, end - start)
136            .iter()
137            .enumerate()
138            .find(|(_, t)| t.unwrap());
139
140        if maybe_matched.is_none() {
141            mutable.extend_nulls(1);
142            continue;
143        }
144        let (match_offset, _) = maybe_matched.unwrap();
145        mutable.extend(0, start + match_offset, start + match_offset + 1);
146    }
147
148    let data = mutable.freeze();
149    let data = make_array(data);
150    Ok(ColumnarValue::Array(data))
151}
152
153/// Process a map array with a nested key type by iterating through entries
154/// and using a comparator for key matching.
155///
156/// This specialized version is used when the key type is nested (e.g., struct, list).
157fn process_map_with_nested_key(
158    array: &dyn Array,
159    key_array: &dyn Array,
160) -> Result<ColumnarValue> {
161    let map_array = as_map_array(array)?;
162
163    let comparator =
164        make_comparator(map_array.keys().as_ref(), key_array, SortOptions::default())?;
165
166    let original_data = map_array.entries().column(1).to_data();
167    let capacity = Capacities::Array(original_data.len());
168    let mut mutable =
169        MutableArrayData::with_capacities(vec![&original_data], true, capacity);
170
171    for entry in 0..map_array.len() {
172        let start = map_array.value_offsets()[entry] as usize;
173        let end = map_array.value_offsets()[entry + 1] as usize;
174
175        let mut found_match = false;
176        for i in start..end {
177            if comparator(i, 0).is_eq() {
178                mutable.extend(0, i, i + 1);
179                found_match = true;
180                break;
181            }
182        }
183
184        if !found_match {
185            mutable.extend_nulls(1);
186        }
187    }
188
189    let data = mutable.freeze();
190    let data = make_array(data);
191    Ok(ColumnarValue::Array(data))
192}
193
194/// Extract a single field from a struct or map array
195fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<ColumnarValue> {
196    let arrays = ColumnarValue::values_to_arrays(&[base])?;
197    let array = Arc::clone(&arrays[0]);
198
199    let string_value = name.try_as_str().flatten().map(|s| s.to_string());
200
201    match (array.data_type(), name, string_value) {
202        (DataType::Map(_, _), ScalarValue::List(arr), _) => {
203            let key_array: Arc<dyn Array> = arr;
204            process_map_array(&array, key_array)
205        }
206        (DataType::Map(_, _), ScalarValue::Struct(arr), _) => {
207            process_map_array(&array, arr as Arc<dyn Array>)
208        }
209        (DataType::Map(_, _), other, _) => {
210            let data_type = other.data_type();
211            if data_type.is_nested() {
212                process_map_with_nested_key(&array, &other.to_array()?)
213            } else {
214                process_map_array(&array, other.to_array()?)
215            }
216        }
217        (DataType::Struct(_), _, Some(k)) => {
218            let as_struct_array = as_struct_array(&array)?;
219            match as_struct_array.column_by_name(&k) {
220                None => exec_err!("Field {k} not found in struct"),
221                Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))),
222            }
223        }
224        (DataType::Struct(_), name, _) => exec_err!(
225            "get_field is only possible on struct with utf8 indexes. \
226                         Received with {name:?} index"
227        ),
228        (DataType::Null, _, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
229        (dt, name, _) => exec_err!(
230            "get_field is only possible on maps or structs. Received {dt} with {name:?} index"
231        ),
232    }
233}
234
235impl GetFieldFunc {
236    pub fn new() -> Self {
237        Self {
238            signature: Signature::user_defined(Volatility::Immutable),
239        }
240    }
241}
242
243// get_field(struct_array, field_name)
244impl ScalarUDFImpl for GetFieldFunc {
245    fn as_any(&self) -> &dyn Any {
246        self
247    }
248
249    fn name(&self) -> &str {
250        "get_field"
251    }
252
253    fn display_name(&self, args: &[Expr]) -> Result<String> {
254        if args.len() < 2 {
255            return exec_err!(
256                "get_field requires at least 2 arguments, got {}",
257                args.len()
258            );
259        }
260
261        let base = &args[0];
262        let field_names: Vec<String> = args[1..]
263            .iter()
264            .map(|f| match f {
265                Expr::Literal(name, _) => name.to_string(),
266                other => other.schema_name().to_string(),
267            })
268            .collect();
269
270        Ok(format!("{}[{}]", base, field_names.join("][")))
271    }
272
273    fn schema_name(&self, args: &[Expr]) -> Result<String> {
274        if args.len() < 2 {
275            return exec_err!(
276                "get_field requires at least 2 arguments, got {}",
277                args.len()
278            );
279        }
280
281        let base = &args[0];
282        let field_names: Vec<String> = args[1..]
283            .iter()
284            .map(|f| match f {
285                Expr::Literal(name, _) => name.to_string(),
286                other => other.schema_name().to_string(),
287            })
288            .collect();
289
290        Ok(format!(
291            "{}[{}]",
292            base.schema_name(),
293            field_names.join("][")
294        ))
295    }
296
297    fn signature(&self) -> &Signature {
298        &self.signature
299    }
300
301    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
302        internal_err!("return_field_from_args should be called instead")
303    }
304
305    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
306        // Validate minimum 2 arguments: base expression + at least one field name
307        if args.scalar_arguments.len() < 2 {
308            return exec_err!(
309                "get_field requires at least 2 arguments, got {}",
310                args.scalar_arguments.len()
311            );
312        }
313
314        let mut current_field = Arc::clone(&args.arg_fields[0]);
315
316        // Iterate through each field name (starting from index 1)
317        for (i, sv) in args.scalar_arguments.iter().enumerate().skip(1) {
318            match current_field.data_type() {
319                DataType::Map(map_field, _) => {
320                    match map_field.data_type() {
321                        DataType::Struct(fields) if fields.len() == 2 => {
322                            // Arrow's MapArray is essentially a ListArray of structs with two columns. They are
323                            // often named "key", and "value", but we don't require any specific naming here;
324                            // instead, we assume that the second column is the "value" column both here and in
325                            // execution.
326                            let value_field = fields
327                                .get(1)
328                                .expect("fields should have exactly two members");
329
330                            current_field = Arc::new(
331                                value_field.as_ref().clone().with_nullable(true),
332                            );
333                        }
334                        _ => {
335                            return exec_err!(
336                                "Map fields must contain a Struct with exactly 2 fields"
337                            );
338                        }
339                    }
340                }
341                DataType::Struct(fields) => {
342                    let field_name = sv
343                        .as_ref()
344                        .and_then(|sv| {
345                            sv.try_as_str().flatten().filter(|s| !s.is_empty())
346                        })
347                        .ok_or_else(|| {
348                            datafusion_common::DataFusionError::Execution(
349                                "Field name must be a non-empty string".to_string(),
350                            )
351                        })?;
352
353                    let child_field = fields
354                        .iter()
355                        .find(|f| f.name() == field_name)
356                        .ok_or_else(|| {
357                            plan_datafusion_err!("Field {field_name} not found in struct")
358                        })?;
359
360                    let mut new_field = child_field.as_ref().clone();
361
362                    // If the parent is nullable, then getting the child must be nullable
363                    if current_field.is_nullable() {
364                        new_field = new_field.with_nullable(true);
365                    }
366                    current_field = Arc::new(new_field);
367                }
368                DataType::Null => {
369                    return Ok(Field::new(self.name(), DataType::Null, true).into());
370                }
371                other => {
372                    return exec_err!(
373                        "Cannot access field at argument {}: type {} is not Struct, Map, or Null",
374                        i,
375                        other
376                    );
377                }
378            }
379        }
380
381        Ok(current_field)
382    }
383
384    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
385        if args.args.len() < 2 {
386            return exec_err!(
387                "get_field requires at least 2 arguments, got {}",
388                args.args.len()
389            );
390        }
391
392        let mut current = args.args[0].clone();
393
394        // Early exit for null base
395        if current.data_type().is_null() {
396            return Ok(ColumnarValue::Scalar(ScalarValue::Null));
397        }
398
399        // Iterate through each field name
400        for field_name in args.args.iter().skip(1) {
401            let field_name_scalar = match field_name {
402                ColumnarValue::Scalar(name) => name.clone(),
403                _ => {
404                    return exec_err!(
405                        "get_field function requires all field_name arguments to be scalars"
406                    );
407                }
408            };
409
410            current = extract_single_field(current, field_name_scalar)?;
411
412            // Early exit if we hit null
413            if current.data_type().is_null() {
414                return Ok(ColumnarValue::Scalar(ScalarValue::Null));
415            }
416        }
417
418        Ok(current)
419    }
420
421    fn simplify(
422        &self,
423        args: Vec<Expr>,
424        _info: &dyn datafusion_expr::simplify::SimplifyInfo,
425    ) -> Result<ExprSimplifyResult> {
426        // Need at least 2 args (base + field)
427        if args.len() < 2 {
428            return Ok(ExprSimplifyResult::Original(args));
429        }
430
431        // Flatten all nested get_field calls in a single pass
432        // Pattern: get_field(get_field(get_field(base, a), b), c) => get_field(base, a, b, c)
433
434        // Collect path arguments from all nested levels
435        let mut path_args_stack = Vec::new();
436        let mut current_expr = &args[0];
437
438        // Push the outermost path arguments first
439        path_args_stack.push(&args[1..]);
440
441        // Walk down the chain of nested get_field calls
442        let base_expr = loop {
443            if let Expr::ScalarFunction(ScalarFunction {
444                func,
445                args: inner_args,
446            }) = current_expr
447                && func
448                    .inner()
449                    .as_any()
450                    .downcast_ref::<GetFieldFunc>()
451                    .is_some()
452            {
453                // Store this level's path arguments (all except the first, which is base/nested call)
454                path_args_stack.push(&inner_args[1..]);
455
456                // Move to the next level down
457                current_expr = &inner_args[0];
458                continue;
459            }
460            // Not a get_field call, this is the base expression
461            break current_expr;
462        };
463
464        // If no nested get_field calls were found, return original
465        if path_args_stack.len() == args.len() - 1 {
466            return Ok(ExprSimplifyResult::Original(args));
467        }
468
469        // If we found any nested get_field calls, flatten them
470        // Build merged args: [base, ...all_path_args_in_correct_order]
471        let mut merged_args = vec![base_expr.clone()];
472
473        // Add path args in reverse order (innermost to outermost)
474        // Stack is: [outermost_paths, ..., innermost_paths]
475        // We want: [base, innermost_paths, ..., outermost_paths]
476        for path_slice in path_args_stack.iter().rev() {
477            merged_args.extend_from_slice(path_slice);
478        }
479
480        Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
481            ScalarFunction::new_udf(
482                Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::new())),
483                merged_args,
484            ),
485        )))
486    }
487
488    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
489        if arg_types.len() < 2 {
490            return exec_err!(
491                "get_field requires at least 2 arguments, got {}",
492                arg_types.len()
493            );
494        }
495        // Accept types as-is, validation happens in return_field_from_args
496        Ok(arg_types.to_vec())
497    }
498
499    fn documentation(&self) -> Option<&Documentation> {
500        self.doc()
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StructArray};
    use arrow::datatypes::Fields;

    /// A `Utf8View` scalar must be usable as a struct field name.
    #[test]
    fn test_get_field_utf8view_key() -> Result<()> {
        // Struct with two Int32 columns, "a" and "b", and no null rows.
        let column_a: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]));
        let column_b: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(10), Some(20), Some(30)]));
        let struct_fields = Fields::from(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]);
        let input = ColumnarValue::Array(Arc::new(StructArray::new(
            struct_fields,
            vec![column_a, column_b],
            None,
        )));

        // Ask for field "a" with a Utf8View key.
        let field_name = ScalarValue::Utf8View(Some("a".to_string()));
        let extracted = extract_single_field(input, field_name)?.into_array(3)?;

        let expected = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
        assert_eq!(extracted.as_ref(), &expected as &dyn Array);

        Ok(())
    }
}