1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, BooleanArray, Capacities, MutableArrayData, Scalar, make_array,
23 make_comparator,
24};
25use arrow::compute::SortOptions;
26use arrow::datatypes::{DataType, Field, FieldRef};
27use arrow_buffer::NullBuffer;
28
29use datafusion_common::cast::{as_map_array, as_struct_array};
30use datafusion_common::{
31 Result, ScalarValue, exec_err, internal_err, plan_datafusion_err,
32};
33use datafusion_expr::expr::ScalarFunction;
34use datafusion_expr::simplify::ExprSimplifyResult;
35use datafusion_expr::{
36 ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
37 ScalarUDFImpl, Signature, Volatility,
38};
39use datafusion_macros::user_doc;
40
#[user_doc(
    doc_section(label = "Other Functions"),
    description = r#"Returns a field within a map or a struct with the given key.
 Supports nested field access by providing multiple field names.
 Note: most users invoke `get_field` indirectly via field access
 syntax such as `my_struct_col['field_name']` which results in a call to
 `get_field(my_struct_col, 'field_name')`.
 Nested access like `my_struct['a']['b']` is optimized to a single call:
 `get_field(my_struct, 'a', 'b')`."#,
    syntax_example = "get_field(expression, field_name[, field_name2, ...])",
    sql_example = r#"```sql
> -- Access a field from a struct column
> create table test( struct_col) as values
 ({name: 'Alice', age: 30}),
 ({name: 'Bob', age: 25});
> select struct_col from test;
+-----------------------------+
| struct_col |
+-----------------------------+
| {name: Alice, age: 30} |
| {name: Bob, age: 25} |
+-----------------------------+
> select struct_col['name'] as name from test;
+-------+
| name |
+-------+
| Alice |
| Bob |
+-------+

> -- Nested field access with multiple arguments
> create table test(struct_col) as values
 ({outer: {inner_val: 42}});
> select struct_col['outer']['inner_val'] as result from test;
+--------+
| result |
+--------+
| 42 |
+--------+
```"#,
    argument(
        name = "expression",
        description = "The map or struct to retrieve a field from."
    ),
    argument(
        name = "field_name",
        description = "The field name(s) to access, in order for nested access. Must evaluate to strings."
    )
)]
/// Scalar UDF implementation of `get_field`, which extracts a value from a
/// struct (by field name) or from a map (by key).
///
/// One or more field names may be supplied; they are applied in order, so
/// `get_field(s, 'a', 'b')` reads field `b` of field `a` of `s`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct GetFieldFunc {
    /// Signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
}
94
95impl Default for GetFieldFunc {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101fn process_map_array(
106 array: &dyn Array,
107 key_array: Arc<dyn Array>,
108) -> Result<ColumnarValue> {
109 let map_array = as_map_array(array)?;
110 let keys = if key_array.data_type().is_nested() {
111 let comparator = make_comparator(
112 map_array.keys().as_ref(),
113 key_array.as_ref(),
114 SortOptions::default(),
115 )?;
116 let len = map_array.keys().len().min(key_array.len());
117 let values = (0..len).map(|i| comparator(i, i).is_eq()).collect();
118 let nulls = NullBuffer::union(map_array.keys().nulls(), key_array.nulls());
119 BooleanArray::new(values, nulls)
120 } else {
121 let be_compared = Scalar::new(key_array);
122 arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())?
123 };
124
125 let original_data = map_array.entries().column(1).to_data();
126 let capacity = Capacities::Array(original_data.len());
127 let mut mutable =
128 MutableArrayData::with_capacities(vec![&original_data], true, capacity);
129
130 for entry in 0..map_array.len() {
131 let start = map_array.value_offsets()[entry] as usize;
132 let end = map_array.value_offsets()[entry + 1] as usize;
133
134 let maybe_matched = keys
135 .slice(start, end - start)
136 .iter()
137 .enumerate()
138 .find(|(_, t)| t.unwrap());
139
140 if maybe_matched.is_none() {
141 mutable.extend_nulls(1);
142 continue;
143 }
144 let (match_offset, _) = maybe_matched.unwrap();
145 mutable.extend(0, start + match_offset, start + match_offset + 1);
146 }
147
148 let data = mutable.freeze();
149 let data = make_array(data);
150 Ok(ColumnarValue::Array(data))
151}
152
153fn process_map_with_nested_key(
158 array: &dyn Array,
159 key_array: &dyn Array,
160) -> Result<ColumnarValue> {
161 let map_array = as_map_array(array)?;
162
163 let comparator =
164 make_comparator(map_array.keys().as_ref(), key_array, SortOptions::default())?;
165
166 let original_data = map_array.entries().column(1).to_data();
167 let capacity = Capacities::Array(original_data.len());
168 let mut mutable =
169 MutableArrayData::with_capacities(vec![&original_data], true, capacity);
170
171 for entry in 0..map_array.len() {
172 let start = map_array.value_offsets()[entry] as usize;
173 let end = map_array.value_offsets()[entry + 1] as usize;
174
175 let mut found_match = false;
176 for i in start..end {
177 if comparator(i, 0).is_eq() {
178 mutable.extend(0, i, i + 1);
179 found_match = true;
180 break;
181 }
182 }
183
184 if !found_match {
185 mutable.extend_nulls(1);
186 }
187 }
188
189 let data = mutable.freeze();
190 let data = make_array(data);
191 Ok(ColumnarValue::Array(data))
192}
193
194fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<ColumnarValue> {
196 let arrays = ColumnarValue::values_to_arrays(&[base])?;
197 let array = Arc::clone(&arrays[0]);
198
199 let string_value = name.try_as_str().flatten().map(|s| s.to_string());
200
201 match (array.data_type(), name, string_value) {
202 (DataType::Map(_, _), ScalarValue::List(arr), _) => {
203 let key_array: Arc<dyn Array> = arr;
204 process_map_array(&array, key_array)
205 }
206 (DataType::Map(_, _), ScalarValue::Struct(arr), _) => {
207 process_map_array(&array, arr as Arc<dyn Array>)
208 }
209 (DataType::Map(_, _), other, _) => {
210 let data_type = other.data_type();
211 if data_type.is_nested() {
212 process_map_with_nested_key(&array, &other.to_array()?)
213 } else {
214 process_map_array(&array, other.to_array()?)
215 }
216 }
217 (DataType::Struct(_), _, Some(k)) => {
218 let as_struct_array = as_struct_array(&array)?;
219 match as_struct_array.column_by_name(&k) {
220 None => exec_err!("Field {k} not found in struct"),
221 Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))),
222 }
223 }
224 (DataType::Struct(_), name, _) => exec_err!(
225 "get_field is only possible on struct with utf8 indexes. \
226 Received with {name:?} index"
227 ),
228 (DataType::Null, _, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
229 (dt, name, _) => exec_err!(
230 "get_field is only possible on maps or structs. Received {dt} with {name:?} index"
231 ),
232 }
233}
234
235impl GetFieldFunc {
236 pub fn new() -> Self {
237 Self {
238 signature: Signature::user_defined(Volatility::Immutable),
239 }
240 }
241}
242
243impl ScalarUDFImpl for GetFieldFunc {
    /// Returns `self` as `&dyn Any` so callers can downcast to `GetFieldFunc`.
    fn as_any(&self) -> &dyn Any {
        self
    }
248
    /// Returns the function name, `"get_field"`.
    fn name(&self) -> &str {
        "get_field"
    }
252
253 fn display_name(&self, args: &[Expr]) -> Result<String> {
254 if args.len() < 2 {
255 return exec_err!(
256 "get_field requires at least 2 arguments, got {}",
257 args.len()
258 );
259 }
260
261 let base = &args[0];
262 let field_names: Vec<String> = args[1..]
263 .iter()
264 .map(|f| match f {
265 Expr::Literal(name, _) => name.to_string(),
266 other => other.schema_name().to_string(),
267 })
268 .collect();
269
270 Ok(format!("{}[{}]", base, field_names.join("][")))
271 }
272
273 fn schema_name(&self, args: &[Expr]) -> Result<String> {
274 if args.len() < 2 {
275 return exec_err!(
276 "get_field requires at least 2 arguments, got {}",
277 args.len()
278 );
279 }
280
281 let base = &args[0];
282 let field_names: Vec<String> = args[1..]
283 .iter()
284 .map(|f| match f {
285 Expr::Literal(name, _) => name.to_string(),
286 other => other.schema_name().to_string(),
287 })
288 .collect();
289
290 Ok(format!(
291 "{}[{}]",
292 base.schema_name(),
293 field_names.join("][")
294 ))
295 }
296
    /// Returns the signature stored on this function.
    fn signature(&self) -> &Signature {
        &self.signature
    }
300
    /// Always returns an internal error: the result type is computed by
    /// `return_field_from_args`, which also sees the scalar field-name arguments.
    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be called instead")
    }
304
305 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
306 if args.scalar_arguments.len() < 2 {
308 return exec_err!(
309 "get_field requires at least 2 arguments, got {}",
310 args.scalar_arguments.len()
311 );
312 }
313
314 let mut current_field = Arc::clone(&args.arg_fields[0]);
315
316 for (i, sv) in args.scalar_arguments.iter().enumerate().skip(1) {
318 match current_field.data_type() {
319 DataType::Map(map_field, _) => {
320 match map_field.data_type() {
321 DataType::Struct(fields) if fields.len() == 2 => {
322 let value_field = fields
327 .get(1)
328 .expect("fields should have exactly two members");
329
330 current_field = Arc::new(
331 value_field.as_ref().clone().with_nullable(true),
332 );
333 }
334 _ => {
335 return exec_err!(
336 "Map fields must contain a Struct with exactly 2 fields"
337 );
338 }
339 }
340 }
341 DataType::Struct(fields) => {
342 let field_name = sv
343 .as_ref()
344 .and_then(|sv| {
345 sv.try_as_str().flatten().filter(|s| !s.is_empty())
346 })
347 .ok_or_else(|| {
348 datafusion_common::DataFusionError::Execution(
349 "Field name must be a non-empty string".to_string(),
350 )
351 })?;
352
353 let child_field = fields
354 .iter()
355 .find(|f| f.name() == field_name)
356 .ok_or_else(|| {
357 plan_datafusion_err!("Field {field_name} not found in struct")
358 })?;
359
360 let mut new_field = child_field.as_ref().clone();
361
362 if current_field.is_nullable() {
364 new_field = new_field.with_nullable(true);
365 }
366 current_field = Arc::new(new_field);
367 }
368 DataType::Null => {
369 return Ok(Field::new(self.name(), DataType::Null, true).into());
370 }
371 other => {
372 return exec_err!(
373 "Cannot access field at argument {}: type {} is not Struct, Map, or Null",
374 i,
375 other
376 );
377 }
378 }
379 }
380
381 Ok(current_field)
382 }
383
384 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
385 if args.args.len() < 2 {
386 return exec_err!(
387 "get_field requires at least 2 arguments, got {}",
388 args.args.len()
389 );
390 }
391
392 let mut current = args.args[0].clone();
393
394 if current.data_type().is_null() {
396 return Ok(ColumnarValue::Scalar(ScalarValue::Null));
397 }
398
399 for field_name in args.args.iter().skip(1) {
401 let field_name_scalar = match field_name {
402 ColumnarValue::Scalar(name) => name.clone(),
403 _ => {
404 return exec_err!(
405 "get_field function requires all field_name arguments to be scalars"
406 );
407 }
408 };
409
410 current = extract_single_field(current, field_name_scalar)?;
411
412 if current.data_type().is_null() {
414 return Ok(ColumnarValue::Scalar(ScalarValue::Null));
415 }
416 }
417
418 Ok(current)
419 }
420
421 fn simplify(
422 &self,
423 args: Vec<Expr>,
424 _info: &dyn datafusion_expr::simplify::SimplifyInfo,
425 ) -> Result<ExprSimplifyResult> {
426 if args.len() < 2 {
428 return Ok(ExprSimplifyResult::Original(args));
429 }
430
431 let mut path_args_stack = Vec::new();
436 let mut current_expr = &args[0];
437
438 path_args_stack.push(&args[1..]);
440
441 let base_expr = loop {
443 if let Expr::ScalarFunction(ScalarFunction {
444 func,
445 args: inner_args,
446 }) = current_expr
447 && func
448 .inner()
449 .as_any()
450 .downcast_ref::<GetFieldFunc>()
451 .is_some()
452 {
453 path_args_stack.push(&inner_args[1..]);
455
456 current_expr = &inner_args[0];
458 continue;
459 }
460 break current_expr;
462 };
463
464 if path_args_stack.len() == args.len() - 1 {
466 return Ok(ExprSimplifyResult::Original(args));
467 }
468
469 let mut merged_args = vec![base_expr.clone()];
472
473 for path_slice in path_args_stack.iter().rev() {
477 merged_args.extend_from_slice(path_slice);
478 }
479
480 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
481 ScalarFunction::new_udf(
482 Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::new())),
483 merged_args,
484 ),
485 )))
486 }
487
488 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
489 if arg_types.len() < 2 {
490 return exec_err!(
491 "get_field requires at least 2 arguments, got {}",
492 arg_types.len()
493 );
494 }
495 Ok(arg_types.to_vec())
497 }
498
    /// Returns the function's user-facing documentation via `self.doc()`
    /// (presumably generated by the `#[user_doc]` attribute — confirm).
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StructArray};
    use arrow::datatypes::Fields;

    /// A `Utf8View` scalar key selects the struct column with that name.
    #[test]
    fn test_get_field_utf8view_key() -> Result<()> {
        let schema = Fields::from(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]);
        let columns = vec![
            Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])) as ArrayRef,
            Arc::new(Int32Array::from(vec![Some(10), Some(20), Some(30)])) as ArrayRef,
        ];
        let input =
            ColumnarValue::Array(Arc::new(StructArray::new(schema, columns, None)));

        // The key is a string view rather than a plain `Utf8` scalar.
        let lookup = ScalarValue::Utf8View(Some("a".to_string()));

        let extracted = extract_single_field(input, lookup)?.into_array(3)?;

        let expected = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
        assert_eq!(extracted.as_ref(), &expected as &dyn Array);

        Ok(())
    }
}