1use arrow::datatypes::{DataType, Field, FieldRef};
19use datafusion_common::{Result, exec_err, internal_err};
20use datafusion_expr::{
21 ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
22 Signature, Volatility,
23};
24use datafusion_macros::user_doc;
25
#[user_doc(
    doc_section(label = "Other Functions"),
    description = "Attaches Arrow field metadata (key/value pairs) to the input expression. Keys must be non-empty constant strings and values must be constant strings (empty values are allowed). Existing metadata on the input field is preserved; new keys overwrite on collision. This is the inverse of `arrow_metadata`.",
    syntax_example = "with_metadata(expression, key1, value1[, key2, value2, ...])",
    // NOTE(review): the ASCII table borders in this example do not line up with the
    // header / cell widths, and the second result has no header row; consider
    // regenerating it from real CLI output.
    sql_example = r#"```sql
> select arrow_metadata(with_metadata(column1, 'unit', 'ms'), 'unit') from (values (1));
+---------------------------------------------------------------+
| arrow_metadata(with_metadata(column1,Utf8("unit"),Utf8("ms")),Utf8("unit")) |
+---------------------------------------------------------------+
| ms |
+---------------------------------------------------------------+
> select arrow_metadata(with_metadata(column1, 'unit', 'ms', 'source', 'sensor')) from (values (1));
+--------------------------+
| {source: sensor, unit: ms} |
+--------------------------+
```"#,
    argument(
        name = "expression",
        description = "The expression whose output Arrow field should be annotated. Values flow through unchanged."
    ),
    argument(
        name = "key",
        description = "Metadata key. Must be a non-empty constant string literal."
    ),
    argument(
        name = "value",
        description = "Metadata value. Must be a constant string literal (may be empty)."
    )
)]
/// Scalar UDF `with_metadata(expr, key, value[, key, value, ...])`.
///
/// At planning time it returns the input expression's field with the given
/// constant key/value pairs merged into its Arrow metadata. At execution time
/// it passes the input value through unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WithMetadataFunc {
    /// Variadic "any" signature. Arity and the constant-string requirement on
    /// keys/values are checked in `return_field_from_args`, not by the signature.
    signature: Signature,
}
59
60impl Default for WithMetadataFunc {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl WithMetadataFunc {
67 pub fn new() -> Self {
68 Self {
69 signature: Signature::variadic_any(Volatility::Immutable),
70 }
71 }
72}
73
74impl ScalarUDFImpl for WithMetadataFunc {
75 fn name(&self) -> &str {
76 "with_metadata"
77 }
78
79 fn signature(&self) -> &Signature {
80 &self.signature
81 }
82
83 fn documentation(&self) -> Option<&Documentation> {
84 self.doc()
85 }
86
87 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
88 internal_err!(
89 "with_metadata: return_type called instead of return_field_from_args"
90 )
91 }
92
93 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
94 if args.arg_fields.len() < 3 {
97 return exec_err!(
98 "with_metadata requires the input expression plus at least one (key, value) pair (minimum 3 arguments), got {}",
99 args.arg_fields.len()
100 );
101 }
102 if args.arg_fields.len().is_multiple_of(2) {
103 return exec_err!(
104 "with_metadata requires an odd number of arguments (expression followed by key/value pairs), got {}",
105 args.arg_fields.len()
106 );
107 }
108
109 let input_field = &args.arg_fields[0];
110 let mut metadata = input_field.metadata().clone();
111
112 for pair_idx in 0..((args.scalar_arguments.len() - 1) / 2) {
114 let key_idx = 1 + pair_idx * 2;
115 let value_idx = key_idx + 1;
116
117 let key = args.scalar_arguments[key_idx]
118 .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
119 .ok_or_else(|| {
120 datafusion_common::DataFusionError::Execution(format!(
121 "with_metadata requires argument {key_idx} (key) to be a non-empty constant string"
122 ))
123 })?;
124
125 let value = args.scalar_arguments[value_idx]
126 .and_then(|sv| sv.try_as_str().flatten())
127 .ok_or_else(|| {
128 datafusion_common::DataFusionError::Execution(format!(
129 "with_metadata requires argument {value_idx} (value) to be a constant string"
130 ))
131 })?;
132
133 metadata.insert(key.to_string(), value.to_string());
134 }
135
136 let field = Field::new(
140 input_field.name(),
141 input_field.data_type().clone(),
142 input_field.is_nullable(),
143 )
144 .with_metadata(metadata);
145
146 Ok(field.into())
147 }
148
149 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
150 Ok(args.args[0].clone())
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::ScalarValue;
    use std::sync::Arc;

    /// Wraps a metadata-free `Field` in an `Arc`.
    fn field(name: &str, dt: DataType, nullable: bool) -> FieldRef {
        Arc::new(Field::new(name, dt, nullable))
    }

    /// Non-null `Utf8` literal.
    fn str_lit(s: &str) -> ScalarValue {
        ScalarValue::Utf8(Some(s.to_string()))
    }

    /// Unnamed, non-nullable `Utf8` field used for key/value argument slots.
    fn utf8_arg() -> FieldRef {
        field("", DataType::Utf8, false)
    }

    /// Plans `with_metadata` against the given argument fields and literals.
    fn plan(
        arg_fields: &[FieldRef],
        scalar_arguments: &[Option<&ScalarValue>],
    ) -> Result<FieldRef> {
        WithMetadataFunc::new().return_field_from_args(ReturnFieldArgs {
            arg_fields,
            scalar_arguments,
        })
    }

    /// Reads one metadata entry from a planned field.
    fn meta<'a>(f: &'a FieldRef, key: &str) -> Option<&'a str> {
        f.metadata().get(key).map(String::as_str)
    }

    #[test]
    fn attaches_single_key() {
        let input = field("my_col", DataType::Int32, true);
        let (key, value) = (str_lit("unit"), str_lit("ms"));
        let planned = plan(
            &[Arc::clone(&input), utf8_arg(), utf8_arg()],
            &[None, Some(&key), Some(&value)],
        )
        .unwrap();
        assert_eq!(planned.name(), "my_col");
        assert_eq!(planned.data_type(), &DataType::Int32);
        assert!(planned.is_nullable());
        assert_eq!(meta(&planned, "unit"), Some("ms"));
    }

    #[test]
    fn merges_existing_metadata_and_overwrites_on_collision() {
        let mut existing = Field::new("x", DataType::Float64, false);
        existing.set_metadata(
            [
                ("keep".to_string(), "yes".to_string()),
                ("unit".to_string(), "old".to_string()),
            ]
            .into_iter()
            .collect(),
        );
        let input: FieldRef = Arc::new(existing);
        let (key, value) = (str_lit("unit"), str_lit("new"));
        let planned = plan(
            &[input, utf8_arg(), utf8_arg()],
            &[None, Some(&key), Some(&value)],
        )
        .unwrap();
        assert_eq!(planned.name(), "x");
        assert!(!planned.is_nullable());
        assert_eq!(meta(&planned, "keep"), Some("yes"));
        assert_eq!(meta(&planned, "unit"), Some("new"));
    }

    #[test]
    fn multiple_pairs() {
        let input = field("c", DataType::Utf8, true);
        let (k1, v1) = (str_lit("a"), str_lit("1"));
        let (k2, v2) = (str_lit("b"), str_lit("2"));
        let arg_fields: Vec<FieldRef> = std::iter::once(input)
            .chain(std::iter::repeat_with(utf8_arg).take(4))
            .collect();
        let planned = plan(
            &arg_fields,
            &[None, Some(&k1), Some(&v1), Some(&k2), Some(&v2)],
        )
        .unwrap();
        assert_eq!(meta(&planned, "a"), Some("1"));
        assert_eq!(meta(&planned, "b"), Some("2"));
    }

    #[test]
    fn rejects_even_arity() {
        let input = field("c", DataType::Int32, true);
        let (a, b, c) = (str_lit("a"), str_lit("b"), str_lit("c"));
        let err = plan(
            &[input, utf8_arg(), utf8_arg(), utf8_arg()],
            &[None, Some(&a), Some(&b), Some(&c)],
        )
        .unwrap_err();
        assert!(err.to_string().contains("odd number"));
    }

    #[test]
    fn rejects_too_few_args() {
        let input = field("c", DataType::Int32, true);
        let key = str_lit("a");
        let err = plan(&[input, utf8_arg()], &[None, Some(&key)]).unwrap_err();
        assert!(err.to_string().contains("at least one"));
    }

    #[test]
    fn allows_empty_value() {
        let input = field("c", DataType::Int32, true);
        let (key, value) = (str_lit("unit"), str_lit(""));
        let planned = plan(
            &[input, utf8_arg(), utf8_arg()],
            &[None, Some(&key), Some(&value)],
        )
        .unwrap();
        assert_eq!(meta(&planned, "unit"), Some(""));
    }

    #[test]
    fn rejects_non_literal_key() {
        let input = field("c", DataType::Int32, true);
        let value = str_lit("v");
        let err = plan(
            &[input, field("", DataType::Utf8, true), utf8_arg()],
            &[None, None, Some(&value)],
        )
        .unwrap_err();
        assert!(err.to_string().contains("non-empty constant string"));
    }
}