use std::any::Any;
use std::sync::Arc;
use arrow::array::{
Array, BooleanArray, Capacities, MutableArrayData, Scalar, make_array,
make_comparator,
};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, FieldRef};
use arrow_buffer::NullBuffer;
use datafusion_common::cast::{as_map_array, as_struct_array};
use datafusion_common::{
Result, ScalarValue, exec_err, internal_err, plan_datafusion_err,
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
// Scalar UDF backing `get_field`: field/key access on structs and maps, including
// multi-level paths such as `get_field(col, 'a', 'b')`.
#[user_doc(
    doc_section(label = "Other Functions"),
    description = r#"Returns a field within a map or a struct with the given key.
Supports nested field access by providing multiple field names.
Looking up a key that is absent from a map returns NULL.
Note: most users invoke `get_field` indirectly via field access
syntax such as `my_struct_col['field_name']` which results in a call to
`get_field(my_struct_col, 'field_name')`.
Nested access like `my_struct['a']['b']` is optimized to a single call:
`get_field(my_struct, 'a', 'b')`."#,
    syntax_example = "get_field(expression, field_name[, field_name2, ...])",
    sql_example = r#"```sql
> -- Access a field from a struct column
> create table test(struct_col) as values
({name: 'Alice', age: 30}),
({name: 'Bob', age: 25});
> select struct_col from test;
+-----------------------------+
| struct_col                  |
+-----------------------------+
| {name: Alice, age: 30}      |
| {name: Bob, age: 25}        |
+-----------------------------+
> select struct_col['name'] as name from test;
+-------+
| name  |
+-------+
| Alice |
| Bob   |
+-------+
> -- Nested field access with multiple arguments
> create table nested_test(struct_col) as values
({outer: {inner_val: 42}});
> select struct_col['outer']['inner_val'] as result from nested_test;
+--------+
| result |
+--------+
| 42     |
+--------+
```"#,
    argument(
        name = "expression",
        description = "The map or struct to retrieve a field from."
    ),
    argument(
        name = "field_name",
        description = "The field name(s) to access, in order for nested access. Must evaluate to strings."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct GetFieldFunc {
    signature: Signature,
}
impl Default for GetFieldFunc {
fn default() -> Self {
Self::new()
}
}
fn process_map_array(
array: &dyn Array,
key_array: Arc<dyn Array>,
) -> Result<ColumnarValue> {
let map_array = as_map_array(array)?;
let keys = if key_array.data_type().is_nested() {
let comparator = make_comparator(
map_array.keys().as_ref(),
key_array.as_ref(),
SortOptions::default(),
)?;
let len = map_array.keys().len().min(key_array.len());
let values = (0..len).map(|i| comparator(i, i).is_eq()).collect();
let nulls = NullBuffer::union(map_array.keys().nulls(), key_array.nulls());
BooleanArray::new(values, nulls)
} else {
let be_compared = Scalar::new(key_array);
arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())?
};
let original_data = map_array.entries().column(1).to_data();
let capacity = Capacities::Array(original_data.len());
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true, capacity);
for entry in 0..map_array.len() {
let start = map_array.value_offsets()[entry] as usize;
let end = map_array.value_offsets()[entry + 1] as usize;
let maybe_matched = keys
.slice(start, end - start)
.iter()
.enumerate()
.find(|(_, t)| t.unwrap());
if maybe_matched.is_none() {
mutable.extend_nulls(1);
continue;
}
let (match_offset, _) = maybe_matched.unwrap();
mutable.extend(0, start + match_offset, start + match_offset + 1);
}
let data = mutable.freeze();
let data = make_array(data);
Ok(ColumnarValue::Array(data))
}
fn process_map_with_nested_key(
array: &dyn Array,
key_array: &dyn Array,
) -> Result<ColumnarValue> {
let map_array = as_map_array(array)?;
let comparator =
make_comparator(map_array.keys().as_ref(), key_array, SortOptions::default())?;
let original_data = map_array.entries().column(1).to_data();
let capacity = Capacities::Array(original_data.len());
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true, capacity);
for entry in 0..map_array.len() {
let start = map_array.value_offsets()[entry] as usize;
let end = map_array.value_offsets()[entry + 1] as usize;
let mut found_match = false;
for i in start..end {
if comparator(i, 0).is_eq() {
mutable.extend(0, i, i + 1);
found_match = true;
break;
}
}
if !found_match {
mutable.extend_nulls(1);
}
}
let data = mutable.freeze();
let data = make_array(data);
Ok(ColumnarValue::Array(data))
}
fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<ColumnarValue> {
let arrays = ColumnarValue::values_to_arrays(&[base])?;
let array = Arc::clone(&arrays[0]);
let string_value = name.try_as_str().flatten().map(|s| s.to_string());
match (array.data_type(), name, string_value) {
(DataType::Map(_, _), ScalarValue::List(arr), _) => {
let key_array: Arc<dyn Array> = arr;
process_map_array(&array, key_array)
}
(DataType::Map(_, _), ScalarValue::Struct(arr), _) => {
process_map_array(&array, arr as Arc<dyn Array>)
}
(DataType::Map(_, _), other, _) => {
let data_type = other.data_type();
if data_type.is_nested() {
process_map_with_nested_key(&array, &other.to_array()?)
} else {
process_map_array(&array, other.to_array()?)
}
}
(DataType::Struct(_), _, Some(k)) => {
let as_struct_array = as_struct_array(&array)?;
match as_struct_array.column_by_name(&k) {
None => exec_err!("Field {k} not found in struct"),
Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))),
}
}
(DataType::Struct(_), name, _) => exec_err!(
"get_field is only possible on struct with utf8 indexes. \
Received with {name:?} index"
),
(DataType::Null, _, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
(dt, name, _) => exec_err!(
"get_field is only possible on maps or structs. Received {dt} with {name:?} index"
),
}
}
impl GetFieldFunc {
    /// Creates the `get_field` UDF.
    ///
    /// The signature is user-defined and immutable; argument types are checked in
    /// `coerce_types`.
    pub fn new() -> Self {
        let signature = Signature::user_defined(Volatility::Immutable);
        Self { signature }
    }
}
/// Validates the `get_field` argument count and renders the field path for naming.
///
/// Returns the base expression and one string per field-name argument. Literal names
/// are rendered by value; any other expression is rendered by its schema name.
///
/// # Errors
/// Returns an execution error when fewer than two arguments are supplied.
fn field_path_names(args: &[Expr]) -> Result<(&Expr, Vec<String>)> {
    if args.len() < 2 {
        return exec_err!(
            "get_field requires at least 2 arguments, got {}",
            args.len()
        );
    }
    let field_names = args[1..]
        .iter()
        .map(|f| match f {
            Expr::Literal(name, _) => name.to_string(),
            other => other.schema_name().to_string(),
        })
        .collect();
    Ok((&args[0], field_names))
}

impl ScalarUDFImpl for GetFieldFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "get_field"
    }

    /// Renders as `base[field1][field2]...` using the base's display form.
    fn display_name(&self, args: &[Expr]) -> Result<String> {
        let (base, field_names) = field_path_names(args)?;
        Ok(format!("{}[{}]", base, field_names.join("][")))
    }

    /// Renders as `base[field1][field2]...` using the base's schema name.
    fn schema_name(&self, args: &[Expr]) -> Result<String> {
        let (base, field_names) = field_path_names(args)?;
        Ok(format!(
            "{}[{}]",
            base.schema_name(),
            field_names.join("][")
        ))
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be called instead")
    }

    /// Walks the field path through the input field's type.
    ///
    /// Map lookups always produce a nullable value field, because a key may be
    /// absent. Struct fields become nullable when their parent is nullable. Reaching
    /// a `Null` type short-circuits to a nullable `Null` field.
    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        if args.scalar_arguments.len() < 2 {
            return exec_err!(
                "get_field requires at least 2 arguments, got {}",
                args.scalar_arguments.len()
            );
        }
        let mut current_field = Arc::clone(&args.arg_fields[0]);
        for (i, sv) in args.scalar_arguments.iter().enumerate().skip(1) {
            match current_field.data_type() {
                DataType::Map(map_field, _) => match map_field.data_type() {
                    DataType::Struct(fields) if fields.len() == 2 => {
                        let value_field = fields
                            .get(1)
                            .expect("fields should have exactly two members");
                        current_field =
                            Arc::new(value_field.as_ref().clone().with_nullable(true));
                    }
                    _ => {
                        return exec_err!(
                            "Map fields must contain a Struct with exactly 2 fields"
                        );
                    }
                },
                DataType::Struct(fields) => {
                    let field_name = sv
                        .as_ref()
                        .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
                        .ok_or_else(|| {
                            datafusion_common::DataFusionError::Execution(
                                "Field name must be a non-empty string".to_string(),
                            )
                        })?;
                    let child_field = fields
                        .iter()
                        .find(|f| f.name() == field_name)
                        .ok_or_else(|| {
                            plan_datafusion_err!("Field {field_name} not found in struct")
                        })?;
                    let mut new_field = child_field.as_ref().clone();
                    if current_field.is_nullable() {
                        new_field = new_field.with_nullable(true);
                    }
                    current_field = Arc::new(new_field);
                }
                DataType::Null => {
                    return Ok(Field::new(self.name(), DataType::Null, true).into());
                }
                other => {
                    return exec_err!(
                        "Cannot access field at argument {}: type {} is not Struct, Map, or Null",
                        i,
                        other
                    );
                }
            }
        }
        Ok(current_field)
    }

    /// Applies each field-name argument in turn to the base value.
    ///
    /// When the base is a scalar, the result is returned as a scalar so it broadcasts
    /// to the batch instead of being reported as a one-row column.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        if args.args.len() < 2 {
            return exec_err!(
                "get_field requires at least 2 arguments, got {}",
                args.args.len()
            );
        }
        let mut current = args.args[0].clone();
        if current.data_type().is_null() {
            return Ok(ColumnarValue::Scalar(ScalarValue::Null));
        }
        let base_is_scalar = matches!(current, ColumnarValue::Scalar(_));
        for field_name in &args.args[1..] {
            let ColumnarValue::Scalar(field_name_scalar) = field_name else {
                return exec_err!(
                    "get_field function requires all field_name arguments to be scalars"
                );
            };
            current = extract_single_field(current, field_name_scalar.clone())?;
            if current.data_type().is_null() {
                return Ok(ColumnarValue::Scalar(ScalarValue::Null));
            }
        }
        // `extract_single_field` expands a scalar base into a one-row array; turn it
        // back into a scalar so it broadcasts correctly.
        match current {
            ColumnarValue::Array(array) if base_is_scalar && array.len() == 1 => Ok(
                ColumnarValue::Scalar(ScalarValue::try_from_array(array.as_ref(), 0)?),
            ),
            other => Ok(other),
        }
    }

    /// Flattens nested calls such as `get_field(get_field(s, 'a'), 'b')` into a single
    /// `get_field(s, 'a', 'b')`.
    fn simplify(
        &self,
        args: Vec<Expr>,
        _info: &dyn datafusion_expr::simplify::SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        if args.len() < 2 {
            return Ok(ExprSimplifyResult::Original(args));
        }
        // Path segments collected from the outermost call inwards.
        let mut path_args_stack = Vec::new();
        let mut current_expr = &args[0];
        path_args_stack.push(&args[1..]);
        let base_expr = loop {
            // Only descend into well-formed inner calls (base + at least one name).
            if let Expr::ScalarFunction(ScalarFunction {
                func,
                args: inner_args,
            }) = current_expr
                && inner_args.len() >= 2
                && func
                    .inner()
                    .as_any()
                    .downcast_ref::<GetFieldFunc>()
                    .is_some()
            {
                path_args_stack.push(&inner_args[1..]);
                current_expr = &inner_args[0];
                continue;
            }
            break current_expr;
        };
        // Only the outer call's own path was collected: nothing to merge.
        if path_args_stack.len() == 1 {
            return Ok(ExprSimplifyResult::Original(args));
        }
        // Innermost segments come first in the merged path.
        let mut merged_args = vec![base_expr.clone()];
        for path_slice in path_args_stack.iter().rev() {
            merged_args.extend_from_slice(path_slice);
        }
        Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
            ScalarFunction::new_udf(
                Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::new())),
                merged_args,
            ),
        )))
    }

    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        if arg_types.len() < 2 {
            return exec_err!(
                "get_field requires at least 2 arguments, got {}",
                arg_types.len()
            );
        }
        Ok(arg_types.to_vec())
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StructArray};
    use arrow::datatypes::Fields;

    /// Builds a three-row, null-free struct column `{a: Int32, b: Int32}`.
    fn sample_struct() -> ColumnarValue {
        let fields: Fields = vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]
        .into();
        let struct_array = StructArray::new(
            fields,
            vec![
                Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])) as ArrayRef,
                Arc::new(Int32Array::from(vec![Some(10), Some(20), Some(30)])) as ArrayRef,
            ],
            None,
        );
        ColumnarValue::Array(Arc::new(struct_array))
    }

    #[test]
    fn test_get_field_utf8view_key() -> Result<()> {
        let key = ScalarValue::Utf8View(Some("a".to_string()));
        let result_array = extract_single_field(sample_struct(), key)?.into_array(3)?;
        let expected = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
        assert_eq!(result_array.as_ref(), &expected as &dyn Array);
        Ok(())
    }

    #[test]
    fn test_get_field_missing_struct_field_errors() {
        let key = ScalarValue::Utf8(Some("c".to_string()));
        let err = extract_single_field(sample_struct(), key).unwrap_err();
        assert!(
            err.to_string().contains("Field c not found in struct"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_get_field_non_string_struct_index_errors() {
        let err = extract_single_field(sample_struct(), ScalarValue::Int64(Some(1)))
            .unwrap_err();
        assert!(
            err.to_string().contains("utf8 indexes"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_get_field_null_base_returns_null() -> Result<()> {
        let base = ColumnarValue::Scalar(ScalarValue::Null);
        let key = ScalarValue::Utf8(Some("a".to_string()));
        let result = extract_single_field(base, key)?;
        assert!(matches!(result, ColumnarValue::Scalar(ScalarValue::Null)));
        Ok(())
    }
}