datafusion_functions/core/
getfield.rs1use arrow::array::{
19 make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData,
20 Scalar,
21};
22use arrow::compute::SortOptions;
23use arrow::datatypes::{DataType, Field, FieldRef};
24use arrow_buffer::NullBuffer;
25use datafusion_common::cast::{as_map_array, as_struct_array};
26use datafusion_common::{
27 exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result,
28 ScalarValue,
29};
30use datafusion_expr::{
31 ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
32};
33use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
34use datafusion_macros::user_doc;
35use std::any::Any;
36use std::sync::Arc;
37
38#[user_doc(
39 doc_section(label = "Other Functions"),
40 description = r#"Returns a field within a map or a struct with the given key.
41 Note: most users invoke `get_field` indirectly via field access
42 syntax such as `my_struct_col['field_name']` which results in a call to
43 `get_field(my_struct_col, 'field_name')`."#,
44 syntax_example = "get_field(expression1, expression2)",
45 sql_example = r#"```sql
46> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow');
47> select struct(idx, v) from t as c;
48+-------------------------+
49| struct(c.idx,c.v) |
50+-------------------------+
51| {c0: data, c1: fusion} |
52| {c0: apache, c1: arrow} |
53+-------------------------+
54> select get_field((select struct(idx, v) from t), 'c0');
55+-----------------------+
56| struct(t.idx,t.v)[c0] |
57+-----------------------+
58| data |
59| apache |
60+-----------------------+
61> select get_field((select struct(idx, v) from t), 'c1');
62+-----------------------+
63| struct(t.idx,t.v)[c1] |
64+-----------------------+
65| fusion |
66| arrow |
67+-----------------------+
68```"#,
69 argument(
70 name = "expression1",
71 description = "The map or struct to retrieve a field for."
72 ),
73 argument(
74 name = "expression2",
75 description = "The field name in the map or struct to retrieve data for. Must evaluate to a string."
76 )
77)]
/// Scalar UDF implementation backing `get_field(expression1, expression2)`.
///
/// The only state is the [`Signature`], which `GetFieldFunc::new` sets to
/// "any two arguments, immutable volatility".
#[derive(Debug)]
pub struct GetFieldFunc {
    /// Argument signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
}
82
83impl Default for GetFieldFunc {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl GetFieldFunc {
90 pub fn new() -> Self {
91 Self {
92 signature: Signature::any(2, Volatility::Immutable),
93 }
94 }
95}
96
97impl ScalarUDFImpl for GetFieldFunc {
99 fn as_any(&self) -> &dyn Any {
100 self
101 }
102
103 fn name(&self) -> &str {
104 "get_field"
105 }
106
107 fn display_name(&self, args: &[Expr]) -> Result<String> {
108 let [base, field_name] = take_function_args(self.name(), args)?;
109
110 let name = match field_name {
111 Expr::Literal(name, _) => name,
112 other => &ScalarValue::Utf8(Some(other.schema_name().to_string())),
113 };
114
115 Ok(format!("{base}[{name}]"))
116 }
117
118 fn schema_name(&self, args: &[Expr]) -> Result<String> {
119 let [base, field_name] = take_function_args(self.name(), args)?;
120 let name = match field_name {
121 Expr::Literal(name, _) => name,
122 other => &ScalarValue::Utf8(Some(other.schema_name().to_string())),
123 };
124
125 Ok(format!("{}[{}]", base.schema_name(), name))
126 }
127
128 fn signature(&self) -> &Signature {
129 &self.signature
130 }
131
132 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
133 internal_err!("return_field_from_args should be called instead")
134 }
135
136 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
137 debug_assert_eq!(args.scalar_arguments.len(), 2);
139
140 match (&args.arg_fields[0].data_type(), args.scalar_arguments[1].as_ref()) {
141 (DataType::Map(fields, _), _) => {
142 match fields.data_type() {
143 DataType::Struct(fields) if fields.len() == 2 => {
144 let value_field = fields.get(1).expect("fields should have exactly two members");
149
150 Ok(value_field.as_ref().clone().with_nullable(true).into())
151 },
152 _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"),
153 }
154 }
155 (DataType::Struct(fields),sv) => {
156 sv.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
157 .map_or_else(
158 || exec_err!("Field name must be a non-empty string"),
159 |field_name| {
160 fields.iter().find(|f| f.name() == field_name)
161 .ok_or(plan_datafusion_err!("Field {field_name} not found in struct"))
162 .map(|f| {
163 let mut child_field = f.as_ref().clone();
164
165 if args.arg_fields[0].is_nullable() {
169 child_field = child_field.with_nullable(true);
170 }
171 Arc::new(child_field)
172 })
173 })
174 },
175 (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true).into()),
176 (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"),
177 }
178 }
179
180 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
181 let [base, field_name] = take_function_args(self.name(), args.args)?;
182
183 if base.data_type().is_null() {
184 return Ok(ColumnarValue::Scalar(ScalarValue::Null));
185 }
186
187 let arrays =
188 ColumnarValue::values_to_arrays(&[base.clone(), field_name.clone()])?;
189 let array = Arc::clone(&arrays[0]);
190 let name = match field_name {
191 ColumnarValue::Scalar(name) => name,
192 _ => {
193 return exec_err!(
194 "get_field function requires the argument field_name to be a string"
195 );
196 }
197 };
198
199 fn process_map_array(
200 array: Arc<dyn Array>,
201 key_array: Arc<dyn Array>,
202 ) -> Result<ColumnarValue> {
203 let map_array = as_map_array(array.as_ref())?;
204 let keys = if key_array.data_type().is_nested() {
205 let comparator = make_comparator(
206 map_array.keys().as_ref(),
207 key_array.as_ref(),
208 SortOptions::default(),
209 )?;
210 let len = map_array.keys().len().min(key_array.len());
211 let values = (0..len).map(|i| comparator(i, i).is_eq()).collect();
212 let nulls =
213 NullBuffer::union(map_array.keys().nulls(), key_array.nulls());
214 BooleanArray::new(values, nulls)
215 } else {
216 let be_compared = Scalar::new(key_array);
217 arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())?
218 };
219
220 let original_data = map_array.entries().column(1).to_data();
221 let capacity = Capacities::Array(original_data.len());
222 let mut mutable =
223 MutableArrayData::with_capacities(vec![&original_data], true, capacity);
224
225 for entry in 0..map_array.len() {
226 let start = map_array.value_offsets()[entry] as usize;
227 let end = map_array.value_offsets()[entry + 1] as usize;
228
229 let maybe_matched = keys
230 .slice(start, end - start)
231 .iter()
232 .enumerate()
233 .find(|(_, t)| t.unwrap());
234
235 if maybe_matched.is_none() {
236 mutable.extend_nulls(1);
237 continue;
238 }
239 let (match_offset, _) = maybe_matched.unwrap();
240 mutable.extend(0, start + match_offset, start + match_offset + 1);
241 }
242
243 let data = mutable.freeze();
244 let data = make_array(data);
245 Ok(ColumnarValue::Array(data))
246 }
247
248 match (array.data_type(), name) {
249 (DataType::Map(_, _), ScalarValue::List(arr)) => {
250 let key_array: Arc<dyn Array> = arr;
251 process_map_array(array, key_array)
252 }
253 (DataType::Map(_, _), ScalarValue::Struct(arr)) => {
254 process_map_array(array, arr as Arc<dyn Array>)
255 }
256 (DataType::Map(_, _), other) => {
257 let data_type = other.data_type();
258 if data_type.is_nested() {
259 exec_err!("unsupported type {:?} for map access", data_type)
260 } else {
261 process_map_array(array, other.to_array()?)
262 }
263 }
264 (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
265 let as_struct_array = as_struct_array(&array)?;
266 match as_struct_array.column_by_name(&k) {
267 None => exec_err!("get indexed field {k} not found in struct"),
268 Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))),
269 }
270 }
271 (DataType::Struct(_), name) => exec_err!(
272 "get_field is only possible on struct with utf8 indexes. \
273 Received with {name:?} index"
274 ),
275 (DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
276 (dt, name) => exec_err!(
277 "get_field is only possible on maps with utf8 indexes or struct \
278 with utf8 indexes. Received {dt:?} with {name:?} index"
279 ),
280 }
281 }
282
283 fn documentation(&self) -> Option<&Documentation> {
284 self.doc()
285 }
286}