Skip to main content

datafusion_functions/core/
getfield.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{
21    Array, BooleanArray, Capacities, MutableArrayData, Scalar, cast::AsArray, make_array,
22    make_comparator,
23};
24use arrow::compute::SortOptions;
25use arrow::datatypes::{DataType, Field, FieldRef};
26use arrow_buffer::NullBuffer;
27
28use datafusion_common::cast::{as_map_array, as_struct_array};
29use datafusion_common::{
30    Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, plan_datafusion_err,
31};
32use datafusion_expr::expr::ScalarFunction;
33use datafusion_expr::simplify::ExprSimplifyResult;
34use datafusion_expr::{
35    ColumnarValue, Documentation, Expr, ExpressionPlacement, ReturnFieldArgs,
36    ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
37};
38use datafusion_macros::user_doc;
39
40#[user_doc(
41    doc_section(label = "Other Functions"),
42    description = r#"Returns a field within a map or a struct with the given key.
43    Supports nested field access by providing multiple field names.
44    Note: most users invoke `get_field` indirectly via field access
45    syntax such as `my_struct_col['field_name']` which results in a call to
46    `get_field(my_struct_col, 'field_name')`.
47    Nested access like `my_struct['a']['b']` is optimized to a single call:
48    `get_field(my_struct, 'a', 'b')`."#,
49    syntax_example = "get_field(expression, field_name[, field_name2, ...])",
50    sql_example = r#"```sql
51> -- Access a field from a struct column
52> create table test( struct_col) as values
53    ({name: 'Alice', age: 30}),
54    ({name: 'Bob', age: 25});
55> select struct_col from test;
56+-----------------------------+
57| struct_col                  |
58+-----------------------------+
59| {name: Alice, age: 30}      |
60| {name: Bob, age: 25}        |
61+-----------------------------+
62> select struct_col['name'] as name from test;
63+-------+
64| name  |
65+-------+
66| Alice |
67| Bob   |
68+-------+
69
70> -- Nested field access with multiple arguments
71> create table test(struct_col) as values
72    ({outer: {inner_val: 42}});
73> select struct_col['outer']['inner_val'] as result from test;
74+--------+
75| result |
76+--------+
77| 42     |
78+--------+
79```"#,
80    argument(
81        name = "expression",
82        description = "The map or struct to retrieve a field from."
83    ),
84    argument(
85        name = "field_name",
86        description = "The field name(s) to access, in order for nested access. Must evaluate to strings."
87    )
88)]
89#[derive(Debug, PartialEq, Eq, Hash)]
90pub struct GetFieldFunc {
91    signature: Signature,
92}
93
94impl Default for GetFieldFunc {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100/// Process a map array by finding matching keys and extracting corresponding values.
101///
102/// This function handles both simple (scalar) and nested key types by using
103/// appropriate comparison strategies.
104fn process_map_array(
105    array: &dyn Array,
106    key_array: Arc<dyn Array>,
107) -> Result<ColumnarValue> {
108    let map_array = as_map_array(array)?;
109    let keys = if key_array.data_type().is_nested() {
110        let comparator = make_comparator(
111            map_array.keys().as_ref(),
112            key_array.as_ref(),
113            SortOptions::default(),
114        )?;
115        let len = map_array.keys().len().min(key_array.len());
116        let values = (0..len).map(|i| comparator(i, i).is_eq()).collect();
117        let nulls = NullBuffer::union(map_array.keys().nulls(), key_array.nulls());
118        BooleanArray::new(values, nulls)
119    } else {
120        let be_compared = Scalar::new(key_array);
121        arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())?
122    };
123
124    let original_data = map_array.entries().column(1).to_data();
125    let capacity = Capacities::Array(original_data.len());
126    let mut mutable =
127        MutableArrayData::with_capacities(vec![&original_data], true, capacity);
128
129    for entry in 0..map_array.len() {
130        let start = map_array.value_offsets()[entry] as usize;
131        let end = map_array.value_offsets()[entry + 1] as usize;
132
133        let maybe_matched = keys
134            .slice(start, end - start)
135            .iter()
136            .enumerate()
137            .find(|(_, t)| t.unwrap());
138
139        if maybe_matched.is_none() {
140            mutable.extend_nulls(1);
141            continue;
142        }
143        let (match_offset, _) = maybe_matched.unwrap();
144        mutable.extend(0, start + match_offset, start + match_offset + 1);
145    }
146
147    let data = mutable.freeze();
148    let data = make_array(data);
149    Ok(ColumnarValue::Array(data))
150}
151
152/// Process a map array with a nested key type by iterating through entries
153/// and using a comparator for key matching.
154///
155/// This specialized version is used when the key type is nested (e.g., struct, list).
156fn process_map_with_nested_key(
157    array: &dyn Array,
158    key_array: &dyn Array,
159) -> Result<ColumnarValue> {
160    let map_array = as_map_array(array)?;
161
162    let comparator =
163        make_comparator(map_array.keys().as_ref(), key_array, SortOptions::default())?;
164
165    let original_data = map_array.entries().column(1).to_data();
166    let capacity = Capacities::Array(original_data.len());
167    let mut mutable =
168        MutableArrayData::with_capacities(vec![&original_data], true, capacity);
169
170    for entry in 0..map_array.len() {
171        let start = map_array.value_offsets()[entry] as usize;
172        let end = map_array.value_offsets()[entry + 1] as usize;
173
174        let mut found_match = false;
175        for i in start..end {
176            if comparator(i, 0).is_eq() {
177                mutable.extend(0, i, i + 1);
178                found_match = true;
179                break;
180            }
181        }
182
183        if !found_match {
184            mutable.extend_nulls(1);
185        }
186    }
187
188    let data = mutable.freeze();
189    let data = make_array(data);
190    Ok(ColumnarValue::Array(data))
191}
192
193/// Extract a single field from a struct or map array
194fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<ColumnarValue> {
195    let arrays = ColumnarValue::values_to_arrays(&[base])?;
196    let array = Arc::clone(&arrays[0]);
197
198    let string_value = name.try_as_str().flatten().map(|s| s.to_string());
199
200    match (array.data_type(), name, string_value) {
201        // Dictionary-encoded struct: extract the field from the dictionary's
202        // values (the deduplicated struct array) and rebuild a dictionary with
203        // the same keys. This preserves dictionary encoding without expanding.
204        (DataType::Dictionary(_, value_type), _, Some(field_name))
205            if matches!(value_type.as_ref(), DataType::Struct(_)) =>
206        {
207            let dict = array.as_any_dictionary();
208            let values_struct = dict.values().as_struct();
209            let field_col =
210                values_struct.column_by_name(&field_name).ok_or_else(|| {
211                    exec_datafusion_err!(
212                        "Field {field_name} not found in dictionary struct"
213                    )
214                })?;
215            Ok(ColumnarValue::Array(
216                dict.with_values(Arc::clone(field_col)),
217            ))
218        }
219        (DataType::Map(_, _), ScalarValue::List(arr), _) => {
220            let key_array: Arc<dyn Array> = arr;
221            process_map_array(&array, key_array)
222        }
223        (DataType::Map(_, _), ScalarValue::Struct(arr), _) => {
224            process_map_array(&array, arr as Arc<dyn Array>)
225        }
226        (DataType::Map(_, _), other, _) => {
227            let data_type = other.data_type();
228            if data_type.is_nested() {
229                process_map_with_nested_key(&array, &other.to_array()?)
230            } else {
231                process_map_array(&array, other.to_array()?)
232            }
233        }
234        (DataType::Struct(_), _, Some(k)) => {
235            let as_struct_array = as_struct_array(&array)?;
236            match as_struct_array.column_by_name(&k) {
237                None => exec_err!("Field {k} not found in struct"),
238                Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))),
239            }
240        }
241        (DataType::Struct(_), name, _) => exec_err!(
242            "get_field is only possible on struct with utf8 indexes. \
243                         Received with {name:?} index"
244        ),
245        (DataType::Null, _, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
246        (dt, name, _) => exec_err!(
247            "get_field is only possible on maps or structs. Received {dt} with {name:?} index"
248        ),
249    }
250}
251
252impl GetFieldFunc {
253    pub fn new() -> Self {
254        Self {
255            signature: Signature::user_defined(Volatility::Immutable),
256        }
257    }
258}
259
260// get_field(struct_array, field_name)
261impl ScalarUDFImpl for GetFieldFunc {
262    fn name(&self) -> &str {
263        "get_field"
264    }
265
266    fn display_name(&self, args: &[Expr]) -> Result<String> {
267        if args.len() < 2 {
268            return exec_err!(
269                "get_field requires at least 2 arguments, got {}",
270                args.len()
271            );
272        }
273
274        let base = &args[0];
275        let field_names: Vec<String> = args[1..]
276            .iter()
277            .map(|f| match f {
278                Expr::Literal(name, _) => name.to_string(),
279                other => other.schema_name().to_string(),
280            })
281            .collect();
282
283        Ok(format!("{}[{}]", base, field_names.join("][")))
284    }
285
286    fn schema_name(&self, args: &[Expr]) -> Result<String> {
287        if args.len() < 2 {
288            return exec_err!(
289                "get_field requires at least 2 arguments, got {}",
290                args.len()
291            );
292        }
293
294        let base = &args[0];
295        let field_names: Vec<String> = args[1..]
296            .iter()
297            .map(|f| match f {
298                Expr::Literal(name, _) => name.to_string(),
299                other => other.schema_name().to_string(),
300            })
301            .collect();
302
303        Ok(format!(
304            "{}[{}]",
305            base.schema_name(),
306            field_names.join("][")
307        ))
308    }
309
310    fn signature(&self) -> &Signature {
311        &self.signature
312    }
313
314    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
315        internal_err!("return_field_from_args should be called instead")
316    }
317
318    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
319        // Validate minimum 2 arguments: base expression + at least one field name
320        if args.scalar_arguments.len() < 2 {
321            return exec_err!(
322                "get_field requires at least 2 arguments, got {}",
323                args.scalar_arguments.len()
324            );
325        }
326
327        let mut current_field = Arc::clone(&args.arg_fields[0]);
328
329        // Iterate through each field name (starting from index 1)
330        for (i, sv) in args.scalar_arguments.iter().enumerate().skip(1) {
331            match current_field.data_type() {
332                DataType::Map(map_field, _) => {
333                    match map_field.data_type() {
334                        DataType::Struct(fields) if fields.len() == 2 => {
335                            // Arrow's MapArray is essentially a ListArray of structs with two columns. They are
336                            // often named "key", and "value", but we don't require any specific naming here;
337                            // instead, we assume that the second column is the "value" column both here and in
338                            // execution.
339                            let value_field = fields
340                                .get(1)
341                                .expect("fields should have exactly two members");
342
343                            current_field = Arc::new(
344                                value_field.as_ref().clone().with_nullable(true),
345                            );
346                        }
347                        _ => {
348                            return exec_err!(
349                                "Map fields must contain a Struct with exactly 2 fields"
350                            );
351                        }
352                    }
353                }
354                // Dictionary-encoded struct: resolve the child field from
355                // the underlying struct, then wrap the result back in the
356                // same Dictionary type so the promised type matches execution.
357                DataType::Dictionary(key_type, value_type)
358                    if matches!(value_type.as_ref(), DataType::Struct(_)) =>
359                {
360                    let DataType::Struct(fields) = value_type.as_ref() else {
361                        unreachable!()
362                    };
363                    let field_name = sv
364                        .as_ref()
365                        .and_then(|sv| {
366                            sv.try_as_str().flatten().filter(|s| !s.is_empty())
367                        })
368                        .ok_or_else(|| {
369                            exec_datafusion_err!("Field name must be a non-empty string")
370                        })?;
371
372                    let child_field = fields
373                        .iter()
374                        .find(|f| f.name() == field_name)
375                        .ok_or_else(|| {
376                            plan_datafusion_err!("Field {field_name} not found in struct")
377                        })?;
378
379                    let dict_type = DataType::Dictionary(
380                        key_type.clone(),
381                        Box::new(child_field.data_type().clone()),
382                    );
383                    let mut new_field =
384                        child_field.as_ref().clone().with_data_type(dict_type);
385                    if current_field.is_nullable() {
386                        new_field = new_field.with_nullable(true);
387                    }
388                    current_field = Arc::new(new_field);
389                }
390                DataType::Struct(fields) => {
391                    let field_name = sv
392                        .as_ref()
393                        .and_then(|sv| {
394                            sv.try_as_str().flatten().filter(|s| !s.is_empty())
395                        })
396                        .ok_or_else(|| {
397                            datafusion_common::DataFusionError::Execution(
398                                "Field name must be a non-empty string".to_string(),
399                            )
400                        })?;
401
402                    let child_field = fields
403                        .iter()
404                        .find(|f| f.name() == field_name)
405                        .ok_or_else(|| {
406                            plan_datafusion_err!("Field {field_name} not found in struct")
407                        })?;
408
409                    let mut new_field = child_field.as_ref().clone();
410
411                    // If the parent is nullable, then getting the child must be nullable
412                    if current_field.is_nullable() {
413                        new_field = new_field.with_nullable(true);
414                    }
415                    current_field = Arc::new(new_field);
416                }
417                DataType::Null => {
418                    return Ok(Field::new(self.name(), DataType::Null, true).into());
419                }
420                other => {
421                    return exec_err!(
422                        "Cannot access field at argument {}: type {} is not Struct, Map, or Null",
423                        i,
424                        other
425                    );
426                }
427            }
428        }
429
430        Ok(current_field)
431    }
432
433    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
434        if args.args.len() < 2 {
435            return exec_err!(
436                "get_field requires at least 2 arguments, got {}",
437                args.args.len()
438            );
439        }
440
441        let mut current = args.args[0].clone();
442
443        // Early exit for null base
444        if current.data_type().is_null() {
445            return Ok(ColumnarValue::Scalar(ScalarValue::Null));
446        }
447
448        // Iterate through each field name
449        for field_name in args.args.iter().skip(1) {
450            let field_name_scalar = match field_name {
451                ColumnarValue::Scalar(name) => name.clone(),
452                _ => {
453                    return exec_err!(
454                        "get_field function requires all field_name arguments to be scalars"
455                    );
456                }
457            };
458
459            current = extract_single_field(current, field_name_scalar)?;
460
461            // Early exit if we hit null
462            if current.data_type().is_null() {
463                return Ok(ColumnarValue::Scalar(ScalarValue::Null));
464            }
465        }
466
467        Ok(current)
468    }
469
470    fn simplify(
471        &self,
472        args: Vec<Expr>,
473        _info: &datafusion_expr::simplify::SimplifyContext,
474    ) -> Result<ExprSimplifyResult> {
475        // Need at least 2 args (base + field)
476        if args.len() < 2 {
477            return Ok(ExprSimplifyResult::Original(args));
478        }
479
480        // Flatten all nested get_field calls in a single pass
481        // Pattern: get_field(get_field(get_field(base, a), b), c) => get_field(base, a, b, c)
482
483        // Collect path arguments from all nested levels
484        let mut path_args_stack = Vec::new();
485        let mut current_expr = &args[0];
486
487        // Push the outermost path arguments first
488        path_args_stack.push(&args[1..]);
489
490        // Walk down the chain of nested get_field calls
491        let base_expr = loop {
492            if let Expr::ScalarFunction(ScalarFunction {
493                func,
494                args: inner_args,
495            }) = current_expr
496                && func.inner().is::<GetFieldFunc>()
497            {
498                // Store this level's path arguments (all except the first, which is base/nested call)
499                path_args_stack.push(&inner_args[1..]);
500
501                // Move to the next level down
502                current_expr = &inner_args[0];
503                continue;
504            }
505            // Not a get_field call, this is the base expression
506            break current_expr;
507        };
508
509        // If no nested get_field calls were found, return original
510        if path_args_stack.len() == args.len() - 1 {
511            return Ok(ExprSimplifyResult::Original(args));
512        }
513
514        // If we found any nested get_field calls, flatten them
515        // Build merged args: [base, ...all_path_args_in_correct_order]
516        let mut merged_args = vec![base_expr.clone()];
517
518        // Add path args in reverse order (innermost to outermost)
519        // Stack is: [outermost_paths, ..., innermost_paths]
520        // We want: [base, innermost_paths, ..., outermost_paths]
521        for path_slice in path_args_stack.iter().rev() {
522            merged_args.extend_from_slice(path_slice);
523        }
524
525        Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
526            ScalarFunction::new_udf(
527                Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::new())),
528                merged_args,
529            ),
530        )))
531    }
532
533    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
534        if arg_types.len() < 2 {
535            return exec_err!(
536                "get_field requires at least 2 arguments, got {}",
537                arg_types.len()
538            );
539        }
540        // Accept types as-is, validation happens in return_field_from_args
541        Ok(arg_types.to_vec())
542    }
543
544    fn documentation(&self) -> Option<&Documentation> {
545        self.doc()
546    }
547
548    fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
549        // get_field can be pushed to leaves if:
550        // 1. The base (first arg) is a column or already placeable at leaves
551        // 2. All field keys (remaining args) are literals
552        if args.is_empty() {
553            return ExpressionPlacement::KeepInPlace;
554        }
555
556        let base_placement = args[0];
557        let base_is_pushable = matches!(
558            base_placement,
559            ExpressionPlacement::Column | ExpressionPlacement::MoveTowardsLeafNodes
560        );
561
562        let all_keys_are_literals = args
563            .iter()
564            .skip(1)
565            .all(|p| *p == ExpressionPlacement::Literal);
566
567        if base_is_pushable && all_keys_are_literals {
568            ExpressionPlacement::MoveTowardsLeafNodes
569        } else {
570            ExpressionPlacement::KeepInPlace
571        }
572    }
573}
574
#[cfg(test)]
/// Unit tests for `extract_single_field` (struct, Utf8View key and
/// dictionary-encoded struct inputs) and for `GetFieldFunc::placement`.
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StructArray};
    use arrow::datatypes::Fields;

    /// A `Utf8View` field-name scalar must resolve a struct field just like `Utf8`.
    #[test]
    fn test_get_field_utf8view_key() -> Result<()> {
        // Create a struct array with fields "a" and "b"
        let a_values = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
        let b_values = Int32Array::from(vec![Some(10), Some(20), Some(30)]);

        let fields: Fields = vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]
        .into();

        let struct_array = StructArray::new(
            fields,
            vec![
                Arc::new(a_values) as ArrayRef,
                Arc::new(b_values) as ArrayRef,
            ],
            None,
        );

        let base = ColumnarValue::Array(Arc::new(struct_array));

        // Use Utf8View key to access field "a"
        let key = ScalarValue::Utf8View(Some("a".to_string()));

        let result = extract_single_field(base, key)?;

        let result_array = result.into_array(3)?;
        let expected = Int32Array::from(vec![Some(1), Some(2), Some(3)]);

        assert_eq!(result_array.as_ref(), &expected as &dyn Array);

        Ok(())
    }

    /// Extracting a field from a dictionary-encoded struct keeps the dictionary
    /// encoding: the output has the same keys (5 rows) over the 3 distinct values.
    #[test]
    fn test_get_field_dict_encoded_struct() -> Result<()> {
        use arrow::array::{DictionaryArray, StringArray, UInt32Array};
        use arrow::datatypes::UInt32Type;

        let names = Arc::new(StringArray::from(vec!["main", "foo", "bar"])) as ArrayRef;
        let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;

        let struct_fields: Fields = vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("id", DataType::Int32, false),
        ]
        .into();

        let values_struct =
            Arc::new(StructArray::new(struct_fields, vec![names, ids], None)) as ArrayRef;

        // Five rows referencing the three distinct struct values.
        let keys = UInt32Array::from(vec![0u32, 1, 2, 0, 1]);
        let dict = DictionaryArray::<UInt32Type>::try_new(keys, values_struct)?;

        let base = ColumnarValue::Array(Arc::new(dict));
        let key = ScalarValue::Utf8(Some("name".to_string()));

        let result = extract_single_field(base, key)?;
        let result_array = result.into_array(5)?;

        assert!(
            matches!(result_array.data_type(), DataType::Dictionary(_, _)),
            "expected dictionary output, got {:?}",
            result_array.data_type()
        );

        let result_dict = result_array
            .as_any()
            .downcast_ref::<DictionaryArray<UInt32Type>>()
            .unwrap();
        // Values were not expanded: still 3 distinct values for 5 rows.
        assert_eq!(result_dict.values().len(), 3);
        assert_eq!(result_dict.len(), 5);

        // Decode the dictionary to check row-by-row contents.
        let resolved = arrow::compute::cast(&result_array, &DataType::Utf8)?;
        let string_arr = resolved.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(string_arr.value(0), "main");
        assert_eq!(string_arr.value(1), "foo");
        assert_eq!(string_arr.value(2), "bar");
        assert_eq!(string_arr.value(3), "main");
        assert_eq!(string_arr.value(4), "foo");

        Ok(())
    }

    /// Two-step access through a struct whose child is a dictionary-encoded
    /// struct: both steps must preserve dictionary encoding.
    #[test]
    fn test_get_field_nested_dict_struct() -> Result<()> {
        use arrow::array::{DictionaryArray, StringArray, UInt32Array};
        use arrow::datatypes::UInt32Type;

        // Inner struct {name, file}, dictionary-encoded with keys [0, 1, 0].
        let func_names = Arc::new(StringArray::from(vec!["main", "foo"])) as ArrayRef;
        let func_files = Arc::new(StringArray::from(vec!["main.c", "foo.c"])) as ArrayRef;
        let func_fields: Fields = vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("file", DataType::Utf8, false),
        ]
        .into();
        let func_struct = Arc::new(StructArray::new(
            func_fields.clone(),
            vec![func_names, func_files],
            None,
        )) as ArrayRef;
        let func_dict = Arc::new(DictionaryArray::<UInt32Type>::try_new(
            UInt32Array::from(vec![0u32, 1, 0]),
            func_struct,
        )?) as ArrayRef;

        // Outer struct {num, function} where `function` is the dictionary above.
        let line_nums = Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef;
        let line_fields: Fields = vec![
            Field::new("num", DataType::Int32, false),
            Field::new(
                "function",
                DataType::Dictionary(
                    Box::new(DataType::UInt32),
                    Box::new(DataType::Struct(func_fields)),
                ),
                false,
            ),
        ]
        .into();
        let line_struct = StructArray::new(line_fields, vec![line_nums, func_dict], None);

        let base = ColumnarValue::Array(Arc::new(line_struct));

        // Step 1: outer struct -> `function` (a dictionary column).
        let func_result =
            extract_single_field(base, ScalarValue::Utf8(Some("function".to_string())))?;

        let func_array = func_result.into_array(3)?;
        assert!(
            matches!(func_array.data_type(), DataType::Dictionary(_, _)),
            "expected dictionary for function, got {:?}",
            func_array.data_type()
        );

        // Step 2: dictionary struct -> `name`, still dictionary-encoded.
        let name_result = extract_single_field(
            ColumnarValue::Array(func_array),
            ScalarValue::Utf8(Some("name".to_string())),
        )?;
        let name_array = name_result.into_array(3)?;

        assert!(
            matches!(name_array.data_type(), DataType::Dictionary(_, _)),
            "expected dictionary for name, got {:?}",
            name_array.data_type()
        );

        let name_dict = name_array
            .as_any()
            .downcast_ref::<DictionaryArray<UInt32Type>>()
            .unwrap();
        assert_eq!(name_dict.values().len(), 2);
        assert_eq!(name_dict.len(), 3);

        let resolved = arrow::compute::cast(&name_array, &DataType::Utf8)?;
        let strings = resolved.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(strings.value(0), "main");
        assert_eq!(strings.value(1), "foo");
        assert_eq!(strings.value(2), "main");

        Ok(())
    }

    /// Column (or already-pushable) base with only literal keys moves toward leaves.
    #[test]
    fn test_placement_literal_key() {
        let func = GetFieldFunc::new();

        // get_field(col, 'literal') -> leaf-pushable (static field access)
        let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Literal];
        assert_eq!(
            func.placement(&args),
            ExpressionPlacement::MoveTowardsLeafNodes
        );

        // get_field(col, 'a', 'b') -> leaf-pushable (nested static field access)
        let args = vec![
            ExpressionPlacement::Column,
            ExpressionPlacement::Literal,
            ExpressionPlacement::Literal,
        ];
        assert_eq!(
            func.placement(&args),
            ExpressionPlacement::MoveTowardsLeafNodes
        );

        // get_field(get_field(col, 'a'), 'b') represented as MoveTowardsLeafNodes for base
        let args = vec![
            ExpressionPlacement::MoveTowardsLeafNodes,
            ExpressionPlacement::Literal,
        ];
        assert_eq!(
            func.placement(&args),
            ExpressionPlacement::MoveTowardsLeafNodes
        );
    }

    /// Any non-literal key keeps the expression in place.
    #[test]
    fn test_placement_column_key() {
        let func = GetFieldFunc::new();

        // get_field(col, other_col) -> NOT leaf-pushable (dynamic per-row lookup)
        let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Column];
        assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace);

        // get_field(col, 'a', other_col) -> NOT leaf-pushable (dynamic nested lookup)
        let args = vec![
            ExpressionPlacement::Column,
            ExpressionPlacement::Literal,
            ExpressionPlacement::Column,
        ];
        assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace);
    }

    /// A KeepInPlace argument in either position keeps the expression in place.
    #[test]
    fn test_placement_root() {
        let func = GetFieldFunc::new();

        // get_field(root_expr, 'literal') -> NOT leaf-pushable
        let args = vec![
            ExpressionPlacement::KeepInPlace,
            ExpressionPlacement::Literal,
        ];
        assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace);

        // get_field(col, root_expr) -> NOT leaf-pushable
        let args = vec![
            ExpressionPlacement::Column,
            ExpressionPlacement::KeepInPlace,
        ];
        assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace);
    }

    /// Degenerate argument lists: empty, base-only, and literal base.
    #[test]
    fn test_placement_edge_cases() {
        let func = GetFieldFunc::new();

        // Empty args -> NOT leaf-pushable
        assert_eq!(func.placement(&[]), ExpressionPlacement::KeepInPlace);

        // Just base, no key -> MoveTowardsLeafNodes (not a valid call but should handle gracefully)
        let args = vec![ExpressionPlacement::Column];
        assert_eq!(
            func.placement(&args),
            ExpressionPlacement::MoveTowardsLeafNodes
        );

        // Literal base with literal key -> NOT leaf-pushable (would be constant-folded)
        let args = vec![ExpressionPlacement::Literal, ExpressionPlacement::Literal];
        assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace);
    }
}