datafusion_expr/
arguments.rs1use crate::Expr;
21use datafusion_common::{plan_err, Result};
22use std::collections::HashMap;
23
24pub fn resolve_function_arguments(
51 param_names: &[String],
52 args: Vec<Expr>,
53 arg_names: Vec<Option<String>>,
54) -> Result<Vec<Expr>> {
55 if args.len() != arg_names.len() {
56 return plan_err!(
57 "Internal error: args length ({}) != arg_names length ({})",
58 args.len(),
59 arg_names.len()
60 );
61 }
62
63 if arg_names.iter().all(|name| name.is_none()) {
65 return Ok(args);
66 }
67
68 validate_argument_order(&arg_names)?;
69
70 reorder_named_arguments(param_names, args, arg_names)
71}
72
73fn validate_argument_order(arg_names: &[Option<String>]) -> Result<()> {
75 let mut seen_named = false;
76 for (i, arg_name) in arg_names.iter().enumerate() {
77 match arg_name {
78 Some(_) => seen_named = true,
79 None if seen_named => {
80 return plan_err!(
81 "Positional argument at position {} follows named argument. \
82 All positional arguments must come before named arguments.",
83 i
84 );
85 }
86 None => {}
87 }
88 }
89 Ok(())
90}
91
92fn reorder_named_arguments(
94 param_names: &[String],
95 args: Vec<Expr>,
96 arg_names: Vec<Option<String>>,
97) -> Result<Vec<Expr>> {
98 let param_index_map: HashMap<&str, usize> = param_names
100 .iter()
101 .enumerate()
102 .map(|(idx, name)| (name.as_str(), idx))
103 .collect();
104
105 let positional_count = arg_names.iter().filter(|n| n.is_none()).count();
106
107 let args_len = args.len();
109
110 let expected_arg_count = param_names.len();
111
112 if positional_count > expected_arg_count {
113 return plan_err!(
114 "Too many positional arguments: expected at most {}, got {}",
115 expected_arg_count,
116 positional_count
117 );
118 }
119
120 let mut result: Vec<Option<Expr>> = vec![None; expected_arg_count];
121
122 for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() {
123 if let Some(name) = arg_name {
124 let param_index =
126 param_index_map.get(name.as_str()).copied().ok_or_else(|| {
127 datafusion_common::plan_datafusion_err!(
128 "Unknown parameter name '{}'. Valid parameters are: [{}]",
129 name,
130 param_names.join(", ")
131 )
132 })?;
133
134 if result[param_index].is_some() {
135 return plan_err!("Parameter '{}' specified multiple times", name);
136 }
137
138 result[param_index] = Some(arg);
139 } else {
140 result[i] = Some(arg);
141 }
142 }
143
144 let required_count = args_len;
146 for i in 0..required_count {
147 if result[i].is_none() {
148 return plan_err!("Missing required parameter '{}'", param_names[i]);
149 }
150 }
151
152 Ok(result.into_iter().take(required_count).flatten().collect())
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lit;

    #[test]
    fn test_all_positional() {
        let param_names = vec!["a".to_string(), "b".to_string()];

        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![None, None];

        // All-positional calls take the fast path and come back unchanged.
        let result =
            resolve_function_arguments(&param_names, args.clone(), arg_names).unwrap();
        assert_eq!(result, args);
    }

    #[test]
    fn test_no_arguments() {
        let param_names = vec!["a".to_string()];

        let result = resolve_function_arguments(&param_names, vec![], vec![]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_all_named() {
        let param_names = vec!["a".to_string(), "b".to_string()];

        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![Some("a".to_string()), Some("b".to_string())];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();
        assert_eq!(result, vec![lit(1), lit("hello")]);
    }

    #[test]
    fn test_named_reordering() {
        let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        // Passed as c, a, b; expected back in declaration order a, b, c.
        let args = vec![lit(3.0), lit(1), lit("hello")];
        let arg_names = vec![
            Some("c".to_string()),
            Some("a".to_string()),
            Some("b".to_string()),
        ];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result[0], lit(1));
        assert_eq!(result[1], lit("hello"));
        assert_eq!(result[2], lit(3.0));
    }

    #[test]
    fn test_mixed_positional_and_named() {
        let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        // a is positional; c and b are named and out of order.
        let args = vec![lit(1), lit(3.0), lit("hello")];
        let arg_names = vec![None, Some("c".to_string()), Some("b".to_string())];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result[0], lit(1));
        assert_eq!(result[1], lit("hello"));
        assert_eq!(result[2], lit(3.0));
    }

    #[test]
    fn test_trailing_optional_parameters_omitted() {
        let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![None, Some("b".to_string())];

        // Only the first two parameters are bound; `c` is left off the result.
        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();
        assert_eq!(result, vec![lit(1), lit("hello")]);
    }

    #[test]
    fn test_positional_after_named_error() {
        let param_names = vec!["a".to_string(), "b".to_string()];

        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![Some("a".to_string()), None];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Positional argument"));
        assert!(message.contains("position 1"));
    }

    #[test]
    fn test_unknown_parameter_name() {
        let param_names = vec!["a".to_string(), "b".to_string()];

        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![Some("x".to_string()), Some("b".to_string())];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Unknown parameter"));
        assert!(message.contains("Valid parameters are: [a, b]"));
    }

    #[test]
    fn test_duplicate_parameter_name() {
        let param_names = vec!["a".to_string(), "b".to_string()];

        let args = vec![lit(1), lit(2)];
        let arg_names = vec![Some("a".to_string()), Some("a".to_string())];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("specified multiple times"));
    }

    #[test]
    fn test_positional_and_named_for_same_parameter() {
        let param_names = vec!["a".to_string(), "b".to_string()];

        // `a` is filled positionally and then named again.
        let args = vec![lit(1), lit(2)];
        let arg_names = vec![None, Some("a".to_string())];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Parameter 'a' specified multiple times"));
    }

    #[test]
    fn test_too_many_positional_arguments() {
        let param_names = vec!["a".to_string()];

        let args = vec![lit(1), lit(2), lit(3)];
        let arg_names = vec![None, None, Some("a".to_string())];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Too many positional arguments"));
    }

    #[test]
    fn test_missing_required_parameter() {
        let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        // `b` is skipped, leaving a gap inside the bound prefix.
        let args = vec![lit(1), lit(3.0)];
        let arg_names = vec![Some("a".to_string()), Some("c".to_string())];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Missing required parameter 'b'"));
    }

    #[test]
    fn test_args_and_names_length_mismatch() {
        let param_names = vec!["a".to_string()];

        let result = resolve_function_arguments(&param_names, vec![lit(1)], vec![]);
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("args length (1) != arg_names length (0)"));
    }
}