Skip to main content

datafusion_functions/core/
arrow_metadata.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{MapBuilder, StringBuilder};
19use arrow::datatypes::{DataType, Field, Fields};
20use datafusion_common::types::logical_string;
21use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
22use datafusion_expr::{
23    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
24    TypeSignature, TypeSignatureClass, Volatility,
25};
26use datafusion_macros::user_doc;
27use std::sync::Arc;
28
29#[user_doc(
30    doc_section(label = "Other Functions"),
31    description = "Returns the metadata of the input expression. If a key is provided, returns the value for that key. If no key is provided, returns a Map of all metadata.",
32    syntax_example = "arrow_metadata(expression[, key])",
33    sql_example = r#"```sql
34> select arrow_metadata(col) from table;
35+----------------------------+
36| arrow_metadata(table.col)  |
37+----------------------------+
38| {k: v}                     |
39+----------------------------+
40> select arrow_metadata(col, 'k') from table;
41+-------------------------------+
42| arrow_metadata(table.col, 'k')|
43+-------------------------------+
44| v                             |
45+-------------------------------+
46```"#,
47    argument(
48        name = "expression",
49        description = "The expression to retrieve metadata from. Can be a column or other expression."
50    ),
51    argument(
52        name = "key",
53        description = "Optional. The specific metadata key to retrieve."
54    )
55)]
56#[derive(Debug, Clone, PartialEq, Eq, Hash)]
57pub struct ArrowMetadataFunc {
58    signature: Signature,
59}
60
61impl ArrowMetadataFunc {
62    pub fn new() -> Self {
63        Self {
64            signature: Signature::one_of(
65                vec![
66                    TypeSignature::Coercible(vec![Coercion::new_exact(
67                        TypeSignatureClass::Any,
68                    )]),
69                    TypeSignature::Coercible(vec![
70                        Coercion::new_exact(TypeSignatureClass::Any),
71                        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72                    ]),
73                ],
74                Volatility::Immutable,
75            ),
76        }
77    }
78}
79
80impl Default for ArrowMetadataFunc {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl ScalarUDFImpl for ArrowMetadataFunc {
87    fn name(&self) -> &str {
88        "arrow_metadata"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn documentation(&self) -> Option<&Documentation> {
96        self.doc()
97    }
98
99    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
100        if arg_types.len() == 2 {
101            Ok(DataType::Utf8)
102        } else if arg_types.len() == 1 {
103            Ok(DataType::Map(
104                Arc::new(Field::new(
105                    "entries",
106                    DataType::Struct(Fields::from(vec![
107                        Field::new("keys", DataType::Utf8, false),
108                        Field::new("values", DataType::Utf8, true),
109                    ])),
110                    false,
111                )),
112                false,
113            ))
114        } else {
115            internal_err!("arrow_metadata requires 1 or 2 arguments")
116        }
117    }
118
119    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
120        let metadata = args.arg_fields[0].metadata();
121
122        if args.args.len() == 2 {
123            let key = match &args.args[1] {
124                ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) => key,
125                _ => {
126                    return exec_err!(
127                        "Second argument to arrow_metadata must be a string literal key"
128                    );
129                }
130            };
131            let value = metadata.get(key).cloned();
132            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(value)))
133        } else if args.args.len() == 1 {
134            let mut map_builder =
135                MapBuilder::new(None, StringBuilder::new(), StringBuilder::new());
136
137            let mut entries: Vec<_> = metadata.iter().collect();
138            entries.sort_by_key(|(k, _)| *k);
139
140            for (k, v) in entries {
141                map_builder.keys().append_value(k);
142                map_builder.values().append_value(v);
143            }
144            map_builder.append(true)?;
145
146            let map_array = map_builder.finish();
147
148            Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
149                &map_array, 0,
150            )?))
151        } else {
152            internal_err!("arrow_metadata requires 1 or 2 arguments")
153        }
154    }
155}