Skip to main content

datafusion_functions/core/
with_metadata.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field, FieldRef};
19use datafusion_common::{Result, exec_err, internal_err};
20use datafusion_expr::{
21    ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
22    Signature, Volatility,
23};
24use datafusion_macros::user_doc;
25
26#[user_doc(
27    doc_section(label = "Other Functions"),
28    description = "Attaches Arrow field metadata (key/value pairs) to the input expression. Keys must be non-empty constant strings and values must be constant strings (empty values are allowed). Existing metadata on the input field is preserved; new keys overwrite on collision. This is the inverse of `arrow_metadata`.",
29    syntax_example = "with_metadata(expression, key1, value1[, key2, value2, ...])",
30    sql_example = r#"```sql
31> select arrow_metadata(with_metadata(column1, 'unit', 'ms'), 'unit') from (values (1));
32+---------------------------------------------------------------+
33| arrow_metadata(with_metadata(column1,Utf8("unit"),Utf8("ms")),Utf8("unit")) |
34+---------------------------------------------------------------+
35| ms                                                            |
36+---------------------------------------------------------------+
37> select arrow_metadata(with_metadata(column1, 'unit', 'ms', 'source', 'sensor')) from (values (1));
38+--------------------------+
39| {source: sensor, unit: ms} |
40+--------------------------+
41```"#,
42    argument(
43        name = "expression",
44        description = "The expression whose output Arrow field should be annotated. Values flow through unchanged."
45    ),
46    argument(
47        name = "key",
48        description = "Metadata key. Must be a non-empty constant string literal."
49    ),
50    argument(
51        name = "value",
52        description = "Metadata value. Must be a constant string literal (may be empty)."
53    )
54)]
55#[derive(Debug, Clone, PartialEq, Eq, Hash)]
56pub struct WithMetadataFunc {
57    signature: Signature,
58}
59
60impl Default for WithMetadataFunc {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl WithMetadataFunc {
67    pub fn new() -> Self {
68        Self {
69            signature: Signature::variadic_any(Volatility::Immutable),
70        }
71    }
72}
73
74impl ScalarUDFImpl for WithMetadataFunc {
75    fn name(&self) -> &str {
76        "with_metadata"
77    }
78
79    fn signature(&self) -> &Signature {
80        &self.signature
81    }
82
83    fn documentation(&self) -> Option<&Documentation> {
84        self.doc()
85    }
86
87    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
88        internal_err!(
89            "with_metadata: return_type called instead of return_field_from_args"
90        )
91    }
92
93    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
94        // Require at least the value expression plus one (key, value) pair,
95        // and an odd total (1 + 2*N).
96        if args.arg_fields.len() < 3 {
97            return exec_err!(
98                "with_metadata requires the input expression plus at least one (key, value) pair (minimum 3 arguments), got {}",
99                args.arg_fields.len()
100            );
101        }
102        if args.arg_fields.len().is_multiple_of(2) {
103            return exec_err!(
104                "with_metadata requires an odd number of arguments (expression followed by key/value pairs), got {}",
105                args.arg_fields.len()
106            );
107        }
108
109        let input_field = &args.arg_fields[0];
110        let mut metadata = input_field.metadata().clone();
111
112        // Keys are at indices 1, 3, 5, ...; values at 2, 4, 6, ...
113        for pair_idx in 0..((args.scalar_arguments.len() - 1) / 2) {
114            let key_idx = 1 + pair_idx * 2;
115            let value_idx = key_idx + 1;
116
117            let key = args.scalar_arguments[key_idx]
118                .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
119                .ok_or_else(|| {
120                    datafusion_common::DataFusionError::Execution(format!(
121                        "with_metadata requires argument {key_idx} (key) to be a non-empty constant string"
122                    ))
123                })?;
124
125            let value = args.scalar_arguments[value_idx]
126                .and_then(|sv| sv.try_as_str().flatten())
127                .ok_or_else(|| {
128                    datafusion_common::DataFusionError::Execution(format!(
129                        "with_metadata requires argument {value_idx} (value) to be a constant string"
130                    ))
131                })?;
132
133            metadata.insert(key.to_string(), value.to_string());
134        }
135
136        // Preserve the input field's name, data type, and nullability; only the
137        // metadata changes. This makes `with_metadata(col, ...)` a true
138        // pass-through annotation from a schema perspective.
139        let field = Field::new(
140            input_field.name(),
141            input_field.data_type().clone(),
142            input_field.is_nullable(),
143        )
144        .with_metadata(metadata);
145
146        Ok(field.into())
147    }
148
149    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
150        // Pure value pass-through. The metadata was attached to the return
151        // field during planning and flows through record batch schemas; the
152        // physical operator does not need to rebuild arrays.
153        Ok(args.args[0].clone())
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::ScalarValue;
    use std::sync::Arc;

    /// Builds a field with the given name, type, and nullability.
    fn field(name: &str, dt: DataType, nullable: bool) -> FieldRef {
        Arc::new(Field::new(name, dt, nullable))
    }

    /// An unnamed, non-nullable Utf8 field standing in for a key/value argument.
    fn utf8_arg() -> FieldRef {
        field("", DataType::Utf8, false)
    }

    /// A non-null Utf8 literal.
    fn str_lit(s: &str) -> ScalarValue {
        ScalarValue::Utf8(Some(s.to_string()))
    }

    /// Runs `return_field_from_args` on a fresh `WithMetadataFunc`.
    fn plan(fields: &[FieldRef], scalars: &[Option<&ScalarValue>]) -> Result<FieldRef> {
        WithMetadataFunc::new().return_field_from_args(ReturnFieldArgs {
            arg_fields: fields,
            scalar_arguments: scalars,
        })
    }

    #[test]
    fn attaches_single_key() {
        let (k, v) = (str_lit("unit"), str_lit("ms"));
        let fields = [field("my_col", DataType::Int32, true), utf8_arg(), utf8_arg()];

        let ret = plan(&fields, &[None, Some(&k), Some(&v)]).unwrap();

        assert_eq!(ret.name(), "my_col");
        assert_eq!(ret.data_type(), &DataType::Int32);
        assert!(ret.is_nullable());
        assert_eq!(ret.metadata().get("unit").map(String::as_str), Some("ms"));
    }

    #[test]
    fn merges_existing_metadata_and_overwrites_on_collision() {
        let existing = Field::new("x", DataType::Float64, false).with_metadata(
            [("keep", "yes"), ("unit", "old")]
                .into_iter()
                .map(|(key, val)| (key.to_string(), val.to_string()))
                .collect(),
        );
        let (k, v) = (str_lit("unit"), str_lit("new"));
        let fields = [Arc::new(existing), utf8_arg(), utf8_arg()];

        let ret = plan(&fields, &[None, Some(&k), Some(&v)]).unwrap();

        assert_eq!(ret.name(), "x");
        assert!(!ret.is_nullable());
        let meta = ret.metadata();
        assert_eq!(meta.get("keep").map(String::as_str), Some("yes"));
        assert_eq!(meta.get("unit").map(String::as_str), Some("new"));
    }

    #[test]
    fn multiple_pairs() {
        let lits = ["a", "1", "b", "2"].map(str_lit);
        let fields = [
            field("c", DataType::Utf8, true),
            utf8_arg(),
            utf8_arg(),
            utf8_arg(),
            utf8_arg(),
        ];
        let scalars: Vec<Option<&ScalarValue>> =
            std::iter::once(None).chain(lits.iter().map(Some)).collect();

        let ret = plan(&fields, &scalars).unwrap();

        assert_eq!(ret.metadata().get("a").map(String::as_str), Some("1"));
        assert_eq!(ret.metadata().get("b").map(String::as_str), Some("2"));
    }

    #[test]
    fn rejects_even_arity() {
        let lits = ["a", "b", "c"].map(str_lit);
        // 4 args total: input + 3 literals (odd key count)
        let fields = [
            field("c", DataType::Int32, true),
            utf8_arg(),
            utf8_arg(),
            utf8_arg(),
        ];
        let scalars = [None, Some(&lits[0]), Some(&lits[1]), Some(&lits[2])];

        let err = plan(&fields, &scalars).unwrap_err();

        assert!(err.to_string().contains("odd number"));
    }

    #[test]
    fn rejects_too_few_args() {
        let k = str_lit("a");
        let fields = [field("c", DataType::Int32, true), utf8_arg()];

        let err = plan(&fields, &[None, Some(&k)]).unwrap_err();

        assert!(err.to_string().contains("at least one"));
    }

    #[test]
    fn allows_empty_value() {
        let (k, v) = (str_lit("unit"), str_lit(""));
        let fields = [field("c", DataType::Int32, true), utf8_arg(), utf8_arg()];

        let ret = plan(&fields, &[None, Some(&k), Some(&v)]).unwrap();

        assert_eq!(ret.metadata().get("unit").map(String::as_str), Some(""));
    }

    #[test]
    fn rejects_non_literal_key() {
        let v = str_lit("v");
        // The key argument is a (nullable) non-constant expression.
        let fields = [
            field("c", DataType::Int32, true),
            field("", DataType::Utf8, true),
            utf8_arg(),
        ];

        let err = plan(&fields, &[None, None, Some(&v)]).unwrap_err();

        assert!(err.to_string().contains("non-empty constant string"));
    }
}