use crate::Expr;
use datafusion_common::{plan_err, Result};
use std::collections::HashMap;
/// Resolves a call's arguments into the declaration order of `param_names`.
///
/// `args` and `arg_names` are parallel: `arg_names[i]` is `Some(name)` when
/// `args[i]` was supplied by name and `None` when it was supplied
/// positionally. If no argument is named, `args` is returned unchanged.
/// Otherwise all positional arguments must precede the named ones, and the
/// arguments are rearranged so each one sits at its parameter's index.
///
/// # Errors
///
/// Returns a planning error if `args` and `arg_names` differ in length, if a
/// positional argument follows a named one, if there are more positional
/// arguments than parameters, if a name matches no parameter, if a parameter
/// is supplied more than once, or if a parameter before the last supplied
/// position is left without an argument.
pub fn resolve_function_arguments(
    param_names: &[String],
    args: Vec<Expr>,
    arg_names: Vec<Option<String>>,
) -> Result<Vec<Expr>> {
    let (arg_count, name_count) = (args.len(), arg_names.len());
    if arg_count != name_count {
        return plan_err!(
            "Internal error: args length ({}) != arg_names length ({})",
            arg_count,
            name_count
        );
    }

    let has_named = arg_names.iter().any(Option::is_some);
    if !has_named {
        // Purely positional call: nothing to reorder.
        return Ok(args);
    }

    validate_argument_order(&arg_names)?;
    reorder_named_arguments(param_names, args, arg_names)
}
/// Checks that every positional (unnamed) argument precedes all named ones.
///
/// # Errors
///
/// Returns a planning error reporting the index of the first positional
/// argument that appears after a named argument.
fn validate_argument_order(arg_names: &[Option<String>]) -> Result<()> {
    // Everything before the first named argument is positional by definition,
    // so only the tail starting at that point needs to be inspected.
    let first_named = match arg_names.iter().position(Option::is_some) {
        Some(idx) => idx,
        None => return Ok(()),
    };

    if let Some(offset) = arg_names[first_named..].iter().position(Option::is_none) {
        return plan_err!(
            "Positional argument at position {} follows named argument. \
             All positional arguments must come before named arguments.",
            first_named + offset
        );
    }

    Ok(())
}
/// Places each supplied argument into its parameter slot and returns the
/// arguments in parameter declaration order.
///
/// Positional arguments (`None` in `arg_names`) fill the slot matching their
/// position; named arguments fill the slot of the parameter with that name.
/// The caller is expected to have already verified (via
/// `validate_argument_order`) that positional arguments form a prefix.
///
/// Only the first `args.len()` parameters must receive an argument. Trailing
/// parameters that received none are omitted from the result, which lets
/// callers leave optional trailing parameters out.
///
/// # Errors
///
/// Returns a planning error if there are more positional arguments than
/// parameters, if a name matches no parameter, if a parameter is supplied
/// more than once (by name or by a name/position collision), or if one of the
/// first `args.len()` parameters is left without an argument.
fn reorder_named_arguments(
    param_names: &[String],
    args: Vec<Expr>,
    arg_names: Vec<Option<String>>,
) -> Result<Vec<Expr>> {
    // Name -> declaration index, for constant-time lookup of named arguments.
    let param_index_map: HashMap<&str, usize> = param_names
        .iter()
        .enumerate()
        .map(|(idx, name)| (name.as_str(), idx))
        .collect();

    let positional_count = arg_names.iter().filter(|n| n.is_none()).count();
    // Captured before `args` is consumed by the loop below.
    let args_len = args.len();
    let expected_arg_count = param_names.len();

    if positional_count > expected_arg_count {
        return plan_err!(
            "Too many positional arguments: expected at most {}, got {}",
            expected_arg_count,
            positional_count
        );
    }

    let mut result: Vec<Option<Expr>> = vec![None; expected_arg_count];

    for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() {
        match arg_name {
            Some(name) => {
                let param_index =
                    param_index_map.get(name.as_str()).copied().ok_or_else(|| {
                        datafusion_common::plan_datafusion_err!(
                            "Unknown parameter name '{}'. Valid parameters are: [{}]",
                            name,
                            param_names.join(", ")
                        )
                    })?;

                let slot = &mut result[param_index];
                if slot.is_some() {
                    return plan_err!("Parameter '{}' specified multiple times", name);
                }
                *slot = Some(arg);
            }
            // Positional arguments form a prefix (checked by the caller), so
            // `i < positional_count <= expected_arg_count` and this is in bounds.
            None => result[i] = Some(arg),
        }
    }

    // Every filled slot came from exactly one argument, so at most `args_len`
    // slots are filled. Requiring the first `args_len` slots to be filled
    // therefore means any unfilled parameters are trailing (optional) ones.
    if let Some(missing) = result.iter().take(args_len).position(Option::is_none) {
        return plan_err!("Missing required parameter '{}'", param_names[missing]);
    }

    Ok(result.into_iter().take(args_len).flatten().collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lit;

    #[test]
    fn test_all_positional() {
        let param_names = vec!["a".to_string(), "b".to_string()];
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![None, None];
        let result =
            resolve_function_arguments(&param_names, args.clone(), arg_names).unwrap();
        // A purely positional call is passed through untouched.
        assert_eq!(result, args);
    }

    #[test]
    fn test_all_named() {
        let param_names = vec!["a".to_string(), "b".to_string()];
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![Some("a".to_string()), Some("b".to_string())];
        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], lit(1));
        assert_eq!(result[1], lit("hello"));
    }

    #[test]
    fn test_named_reordering() {
        let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let args = vec![lit(3.0), lit(1), lit("hello")];
        let arg_names = vec![
            Some("c".to_string()),
            Some("a".to_string()),
            Some("b".to_string()),
        ];
        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], lit(1));
        assert_eq!(result[1], lit("hello"));
        assert_eq!(result[2], lit(3.0));
    }

    #[test]
    fn test_mixed_positional_and_named() {
        let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let args = vec![lit(1), lit(3.0), lit("hello")];
        let arg_names = vec![None, Some("c".to_string()), Some("b".to_string())];
        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], lit(1));
        assert_eq!(result[1], lit("hello"));
        assert_eq!(result[2], lit(3.0));
    }

    #[test]
    fn test_omitted_trailing_parameters() {
        // Parameters after the last supplied position may be left out.
        let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let args = vec![lit(2), lit(1)];
        let arg_names = vec![Some("b".to_string()), Some("a".to_string())];
        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();
        assert_eq!(result, vec![lit(1), lit(2)]);
    }

    #[test]
    fn test_length_mismatch_error() {
        let param_names = vec!["a".to_string(), "b".to_string()];
        let args = vec![lit(1)];
        let arg_names = vec![None, None];
        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("args length (1) != arg_names length (2)"));
    }

    #[test]
    fn test_positional_after_named_error() {
        let param_names = vec!["a".to_string(), "b".to_string()];
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![Some("a".to_string()), None];
        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Positional argument"));
    }

    #[test]
    fn test_too_many_positional_arguments() {
        let param_names = vec!["a".to_string()];
        let args = vec![lit(1), lit(2), lit(3)];
        let arg_names = vec![None, None, Some("a".to_string())];
        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Too many positional arguments"));
    }

    #[test]
    fn test_unknown_parameter_name() {
        let param_names = vec!["a".to_string(), "b".to_string()];
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![Some("x".to_string()), Some("b".to_string())];
        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unknown parameter"));
    }

    #[test]
    fn test_duplicate_parameter_name() {
        let param_names = vec!["a".to_string(), "b".to_string()];
        let args = vec![lit(1), lit(2)];
        let arg_names = vec![Some("a".to_string()), Some("a".to_string())];
        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("specified multiple times"));
    }

    #[test]
    fn test_named_collides_with_positional() {
        // `a` is filled positionally and then again by name.
        let param_names = vec!["a".to_string(), "b".to_string()];
        let args = vec![lit(1), lit(2)];
        let arg_names = vec![None, Some("a".to_string())];
        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Parameter 'a' specified multiple times"));
    }

    #[test]
    fn test_missing_required_parameter() {
        let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let args = vec![lit(1), lit(3.0)];
        let arg_names = vec![Some("a".to_string()), Some("c".to_string())];
        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Missing required parameter"));
    }
}