// datafusion_functions/core/arrow_metadata.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use arrow::array::{MapBuilder, StringBuilder};
19use arrow::datatypes::{DataType, Field, Fields};
20use datafusion_common::types::logical_string;
21use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
22use datafusion_expr::{
23    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
24    TypeSignature, TypeSignatureClass, Volatility,
25};
26use datafusion_macros::user_doc;
27use std::any::Any;
28use std::sync::Arc;
29
#[user_doc(
    doc_section(label = "Other Functions"),
    description = "Returns the metadata of the input expression. If a key is provided, returns the value for that key. If no key is provided, returns a Map of all metadata.",
    syntax_example = "arrow_metadata(expression[, key])",
    sql_example = r#"```sql
> select arrow_metadata(col) from table;
+----------------------------+
| arrow_metadata(table.col)  |
+----------------------------+
| {k: v}                     |
+----------------------------+
> select arrow_metadata(col, 'k') from table;
+-------------------------------+
| arrow_metadata(table.col, 'k')|
+-------------------------------+
| v                             |
+-------------------------------+
```"#,
    argument(
        name = "expression",
        description = "The expression to retrieve metadata from. Can be a column or other expression."
    ),
    argument(
        name = "key",
        description = "Optional. The specific metadata key to retrieve."
    )
)]
/// Scalar UDF `arrow_metadata`: exposes the Arrow field metadata attached to
/// its first argument.
///
/// With one argument it returns every metadata entry as a map. With a second,
/// string key argument it returns the value stored under that key, or NULL
/// when the key is absent. The user-facing SQL documentation is generated
/// from the `user_doc` attribute above.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ArrowMetadataFunc {
    /// Accepted argument shapes; constructed in `ArrowMetadataFunc::new`.
    signature: Signature,
}
61
62impl ArrowMetadataFunc {
63    pub fn new() -> Self {
64        Self {
65            signature: Signature::one_of(
66                vec![
67                    TypeSignature::Coercible(vec![Coercion::new_exact(
68                        TypeSignatureClass::Any,
69                    )]),
70                    TypeSignature::Coercible(vec![
71                        Coercion::new_exact(TypeSignatureClass::Any),
72                        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73                    ]),
74                ],
75                Volatility::Immutable,
76            ),
77        }
78    }
79}
80
81impl Default for ArrowMetadataFunc {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl ScalarUDFImpl for ArrowMetadataFunc {
88    fn as_any(&self) -> &dyn Any {
89        self
90    }
91
92    fn name(&self) -> &str {
93        "arrow_metadata"
94    }
95
96    fn signature(&self) -> &Signature {
97        &self.signature
98    }
99
100    fn documentation(&self) -> Option<&Documentation> {
101        self.doc()
102    }
103
104    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
105        if arg_types.len() == 2 {
106            Ok(DataType::Utf8)
107        } else if arg_types.len() == 1 {
108            Ok(DataType::Map(
109                Arc::new(Field::new(
110                    "entries",
111                    DataType::Struct(Fields::from(vec![
112                        Field::new("keys", DataType::Utf8, false),
113                        Field::new("values", DataType::Utf8, true),
114                    ])),
115                    false,
116                )),
117                false,
118            ))
119        } else {
120            internal_err!("arrow_metadata requires 1 or 2 arguments")
121        }
122    }
123
124    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
125        let metadata = args.arg_fields[0].metadata();
126
127        if args.args.len() == 2 {
128            let key = match &args.args[1] {
129                ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) => key,
130                _ => {
131                    return exec_err!(
132                        "Second argument to arrow_metadata must be a string literal key"
133                    );
134                }
135            };
136            let value = metadata.get(key).cloned();
137            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(value)))
138        } else if args.args.len() == 1 {
139            let mut map_builder =
140                MapBuilder::new(None, StringBuilder::new(), StringBuilder::new());
141
142            let mut entries: Vec<_> = metadata.iter().collect();
143            entries.sort_by_key(|(k, _)| *k);
144
145            for (k, v) in entries {
146                map_builder.keys().append_value(k);
147                map_builder.values().append_value(v);
148            }
149            map_builder.append(true)?;
150
151            let map_array = map_builder.finish();
152
153            Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
154                &map_array, 0,
155            )?))
156        } else {
157            internal_err!("arrow_metadata requires 1 or 2 arguments")
158        }
159    }
160}