datafusion_expr/
arguments.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Argument resolution logic for named function parameters
19
20use crate::Expr;
21use datafusion_common::{plan_err, Result};
22use std::collections::HashMap;
23
24/// Resolves function arguments, handling named and positional notation.
25///
26/// This function validates and reorders arguments to match the function's parameter names
27/// when named arguments are used.
28///
29/// # Rules
30/// - All positional arguments must come before named arguments
31/// - Named arguments can be in any order after positional arguments
32/// - Parameter names follow SQL identifier rules: unquoted names are case-insensitive
33///   (normalized to lowercase), quoted names are case-sensitive
34/// - No duplicate parameter names allowed
35///
36/// # Arguments
37/// * `param_names` - The function's parameter names in order
38/// * `args` - The argument expressions
39/// * `arg_names` - Optional parameter name for each argument
40///
41/// # Returns
42/// A vector of expressions in the correct order matching the parameter names
43///
44/// # Examples
45/// ```text
46/// Given parameters ["a", "b", "c"]
47/// And call: func(10, c => 30, b => 20)
48/// Returns: [Expr(10), Expr(20), Expr(30)]
49/// ```
50pub fn resolve_function_arguments(
51    param_names: &[String],
52    args: Vec<Expr>,
53    arg_names: Vec<Option<String>>,
54) -> Result<Vec<Expr>> {
55    if args.len() != arg_names.len() {
56        return plan_err!(
57            "Internal error: args length ({}) != arg_names length ({})",
58            args.len(),
59            arg_names.len()
60        );
61    }
62
63    // Check if all arguments are positional (fast path)
64    if arg_names.iter().all(|name| name.is_none()) {
65        return Ok(args);
66    }
67
68    validate_argument_order(&arg_names)?;
69
70    reorder_named_arguments(param_names, args, arg_names)
71}
72
73/// Validates that positional arguments come before named arguments
74fn validate_argument_order(arg_names: &[Option<String>]) -> Result<()> {
75    let mut seen_named = false;
76    for (i, arg_name) in arg_names.iter().enumerate() {
77        match arg_name {
78            Some(_) => seen_named = true,
79            None if seen_named => {
80                return plan_err!(
81                    "Positional argument at position {} follows named argument. \
82                     All positional arguments must come before named arguments.",
83                    i
84                );
85            }
86            None => {}
87        }
88    }
89    Ok(())
90}
91
92/// Reorders arguments based on named parameters to match signature order
93fn reorder_named_arguments(
94    param_names: &[String],
95    args: Vec<Expr>,
96    arg_names: Vec<Option<String>>,
97) -> Result<Vec<Expr>> {
98    // Build HashMap for O(1) parameter name lookups
99    let param_index_map: HashMap<&str, usize> = param_names
100        .iter()
101        .enumerate()
102        .map(|(idx, name)| (name.as_str(), idx))
103        .collect();
104
105    let positional_count = arg_names.iter().filter(|n| n.is_none()).count();
106
107    // Capture args length before consuming the vector
108    let args_len = args.len();
109
110    let expected_arg_count = param_names.len();
111
112    if positional_count > expected_arg_count {
113        return plan_err!(
114            "Too many positional arguments: expected at most {}, got {}",
115            expected_arg_count,
116            positional_count
117        );
118    }
119
120    let mut result: Vec<Option<Expr>> = vec![None; expected_arg_count];
121
122    for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() {
123        if let Some(name) = arg_name {
124            // Named argument - O(1) lookup in HashMap
125            let param_index =
126                param_index_map.get(name.as_str()).copied().ok_or_else(|| {
127                    datafusion_common::plan_datafusion_err!(
128                        "Unknown parameter name '{}'. Valid parameters are: [{}]",
129                        name,
130                        param_names.join(", ")
131                    )
132                })?;
133
134            if result[param_index].is_some() {
135                return plan_err!("Parameter '{}' specified multiple times", name);
136            }
137
138            result[param_index] = Some(arg);
139        } else {
140            result[i] = Some(arg);
141        }
142    }
143
144    // Only require parameters up to the number of arguments provided (supports optional parameters)
145    let required_count = args_len;
146    for i in 0..required_count {
147        if result[i].is_none() {
148            return plan_err!("Missing required parameter '{}'", param_names[i]);
149        }
150    }
151
152    // Return only the assigned parameters (handles optional trailing parameters)
153    Ok(result.into_iter().take(required_count).flatten().collect())
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lit;

    /// Builds an owned parameter-name list from string literals.
    fn params(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Marks an argument as passed by name.
    fn named(name: &str) -> Option<String> {
        Some(name.to_string())
    }

    /// Asserts that `result` is an error whose message contains `expected`.
    fn assert_error_contains(result: Result<Vec<Expr>>, expected: &str) {
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(
            message.contains(expected),
            "error message {:?} does not contain {:?}",
            message,
            expected
        );
    }

    #[test]
    fn test_all_positional() {
        let param_names = params(&["a", "b"]);

        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![None, None];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Positional-only calls must come back unchanged, values included
        assert_eq!(result, vec![lit(1), lit("hello")]);
    }

    #[test]
    fn test_all_named() {
        let param_names = params(&["a", "b"]);

        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![named("a"), named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Already in signature order, so the values must be preserved as-is
        assert_eq!(result, vec![lit(1), lit("hello")]);
    }

    #[test]
    fn test_named_reordering() {
        let param_names = params(&["a", "b", "c"]);

        // Call with: func(c => 3.0, a => 1, b => "hello")
        let args = vec![lit(3.0), lit(1), lit("hello")];
        let arg_names = vec![named("c"), named("a"), named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Should be reordered to [a, b, c] = [1, "hello", 3.0]
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], lit(1));
        assert_eq!(result[1], lit("hello"));
        assert_eq!(result[2], lit(3.0));
    }

    #[test]
    fn test_mixed_positional_and_named() {
        let param_names = params(&["a", "b", "c"]);

        // Call with: func(1, c => 3.0, b => "hello")
        let args = vec![lit(1), lit(3.0), lit("hello")];
        let arg_names = vec![None, named("c"), named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Should be reordered to [a, b, c] = [1, "hello", 3.0]
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], lit(1));
        assert_eq!(result[1], lit("hello"));
        assert_eq!(result[2], lit(3.0));
    }

    #[test]
    fn test_trailing_parameters_omitted() {
        let param_names = params(&["a", "b", "c"]);

        // Call with: func(1, b => "hello") - trailing 'c' is not supplied
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![None, named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Only the supplied leading parameters are returned
        assert_eq!(result, vec![lit(1), lit("hello")]);
    }

    #[test]
    fn test_positional_after_named_error() {
        let param_names = params(&["a", "b"]);

        // Call with: func(a => 1, "hello") - ERROR
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![named("a"), None];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "Positional argument");
    }

    #[test]
    fn test_too_many_positional_arguments() {
        let param_names = params(&["a"]);

        // Call with: func(1, 2, a => 3) - two positional values, one parameter
        let args = vec![lit(1), lit(2), lit(3)];
        let arg_names = vec![None, None, named("a")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "Too many positional arguments");
    }

    #[test]
    fn test_unknown_parameter_name() {
        let param_names = params(&["a", "b"]);

        // Call with: func(x => 1, b => "hello") - ERROR
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![named("x"), named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "Unknown parameter");
    }

    #[test]
    fn test_duplicate_parameter_name() {
        let param_names = params(&["a", "b"]);

        // Call with: func(a => 1, a => 2) - ERROR
        let args = vec![lit(1), lit(2)];
        let arg_names = vec![named("a"), named("a")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "specified multiple times");
    }

    #[test]
    fn test_named_conflicts_with_positional() {
        let param_names = params(&["a", "b"]);

        // Call with: func(1, a => 2) - 'a' is already filled positionally
        let args = vec![lit(1), lit(2)];
        let arg_names = vec![None, named("a")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "specified multiple times");
    }

    #[test]
    fn test_missing_required_parameter() {
        let param_names = params(&["a", "b", "c"]);

        // Call with: func(a => 1, c => 3.0) - missing 'b'
        let args = vec![lit(1), lit(3.0)];
        let arg_names = vec![named("a"), named("c")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "Missing required parameter");
    }
}