use crate::dialects::DialectType;
use std::collections::HashMap;
/// How function names are matched when a catalog is queried.
///
/// Can be set per dialect and overridden per function; `Insensitive` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FunctionNameCase {
    /// Match on the normalized (lowercased) function name.
    #[default]
    Insensitive,
    /// Match on the function name exactly as it was written.
    Sensitive,
}
/// Accepted argument counts for one overload of a function.
///
/// Both bounds are inclusive: a call with `n` arguments matches when `min_arity <= n`
/// and, if `max_arity` is `Some(max)`, also `n <= max`. `max_arity == None` means the
/// function has no upper bound on its argument count.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionSignature {
    /// Smallest accepted argument count (inclusive).
    pub min_arity: usize,
    /// Largest accepted argument count (inclusive), or `None` for unbounded.
    pub max_arity: Option<usize>,
}

impl FunctionSignature {
    /// Signature that accepts exactly `arity` arguments.
    pub const fn exact(arity: usize) -> Self {
        Self {
            min_arity: arity,
            max_arity: Some(arity),
        }
    }

    /// Signature that accepts between `min_arity` and `max_arity` arguments, both inclusive.
    ///
    /// # Panics
    ///
    /// Panics if `min_arity > max_arity`. Such a signature could never match any call,
    /// so it is treated as a construction bug rather than silently accepted. When this is
    /// evaluated in a `const` context (e.g. a static signature table) the panic surfaces
    /// as a compile-time error.
    pub const fn range(min_arity: usize, max_arity: usize) -> Self {
        // A plain string-literal message keeps this usable inside `const fn`.
        assert!(
            min_arity <= max_arity,
            "FunctionSignature::range called with min_arity > max_arity"
        );
        Self {
            min_arity,
            max_arity: Some(max_arity),
        }
    }

    /// Signature that accepts `min_arity` or more arguments, with no upper bound.
    pub const fn variadic(min_arity: usize) -> Self {
        Self {
            min_arity,
            max_arity: None,
        }
    }

    /// Returns `true` if a call with `arity` arguments satisfies this signature.
    pub fn matches_arity(&self, arity: usize) -> bool {
        arity >= self.min_arity
            && match self.max_arity {
                Some(max) => arity <= max,
                None => true,
            }
    }

    /// Human-readable arity for diagnostics.
    ///
    /// Produces `"2"` for an exact arity, `"1..3"` for a bounded range (both ends
    /// inclusive, unlike Rust's `..` range syntax), and `"1+"` for an unbounded one.
    pub fn describe_arity(&self) -> String {
        match self.max_arity {
            Some(max) if max == self.min_arity => self.min_arity.to_string(),
            Some(max) => format!("{}..{}", self.min_arity, max),
            None => format!("{}+", self.min_arity),
        }
    }
}
/// A queryable source of function signatures, keyed by dialect and function name.
///
/// The `Send + Sync` supertraits allow a catalog to be shared across threads.
pub trait FunctionCatalog: Send + Sync {
    /// Returns every signature known for the named function in `dialect`, or `None`
    /// when the catalog has no entry for it.
    ///
    /// Callers pass both the name as written (`raw_function_name`) and its normalized
    /// form (`normalized_name`). The implementation decides which one to match on,
    /// e.g. depending on whether names are case-sensitive for the dialect.
    fn lookup(
        &self,
        dialect: DialectType,
        raw_function_name: &str,
        normalized_name: &str,
    ) -> Option<&[FunctionSignature]>;
}
/// In-memory [`FunctionCatalog`] built from nested hash maps.
///
/// Each registered function is stored twice: once under its lowercased name (for
/// case-insensitive lookup) and once under the exact registered spelling (for
/// case-sensitive lookup).
#[derive(Debug, Clone, Default)]
pub struct HashMapFunctionCatalog {
    // dialect -> lowercased function name -> deduplicated signatures
    entries_normalized: HashMap<DialectType, HashMap<String, Vec<FunctionSignature>>>,
    // dialect -> function name exactly as registered -> deduplicated signatures
    entries_exact: HashMap<DialectType, HashMap<String, Vec<FunctionSignature>>>,
    // Default name matching per dialect; dialects missing here fall back to
    // `FunctionNameCase::default()`.
    dialect_name_case: HashMap<DialectType, FunctionNameCase>,
    // Per-function overrides of the dialect default, keyed by lowercased function name.
    function_name_case_overrides: HashMap<DialectType, HashMap<String, FunctionNameCase>>,
}
impl HashMapFunctionCatalog {
    /// Sets how function names are matched for every function of `dialect`,
    /// replacing any previous dialect-wide setting. Per-function overrides still win.
    pub fn set_dialect_name_case(&mut self, dialect: DialectType, name_case: FunctionNameCase) {
        *self.dialect_name_case.entry(dialect).or_default() = name_case;
    }

    /// Overrides the name-matching mode for a single function in `dialect`.
    ///
    /// The override is keyed by the lowercased `function_name`, so it applies no matter
    /// how the name is capitalised here.
    pub fn set_function_name_case(
        &mut self,
        dialect: DialectType,
        function_name: impl Into<String>,
        name_case: FunctionNameCase,
    ) {
        let per_dialect = self.function_name_case_overrides.entry(dialect).or_default();
        let lowered = function_name.into().to_lowercase();
        per_dialect.insert(lowered, name_case);
    }

    /// Adds `signatures` for `function_name` in `dialect`.
    ///
    /// The signatures are recorded under both the lowercased name and the exact
    /// spelling. A signature already present under a key is not added again. Names
    /// differing only in case therefore share one case-insensitive entry but keep
    /// separate case-sensitive ones. Registering an empty `signatures` list still
    /// creates (empty) entries, so a later lookup returns `Some(&[])`.
    pub fn register(
        &mut self,
        dialect: DialectType,
        function_name: impl Into<String>,
        signatures: Vec<FunctionSignature>,
    ) {
        // Appends `signature` unless an equal one is already stored.
        fn push_unique(target: &mut Vec<FunctionSignature>, signature: FunctionSignature) {
            if !target.contains(&signature) {
                target.push(signature);
            }
        }

        let exact_name: String = function_name.into();
        let lowered_name = exact_name.to_lowercase();

        let by_lowered = self
            .entries_normalized
            .entry(dialect)
            .or_default()
            .entry(lowered_name)
            .or_default();
        let by_exact = self
            .entries_exact
            .entry(dialect)
            .or_default()
            .entry(exact_name)
            .or_default();

        for signature in signatures {
            push_unique(by_lowered, signature.clone());
            push_unique(by_exact, signature);
        }
    }

    /// Resolves the matching mode for one function.
    ///
    /// Order: the per-function override for `normalized_name`, then the dialect
    /// setting, then `FunctionNameCase::default()`.
    fn effective_name_case(&self, dialect: DialectType, normalized_name: &str) -> FunctionNameCase {
        self.function_name_case_overrides
            .get(&dialect)
            .and_then(|per_function| per_function.get(normalized_name))
            .or_else(|| self.dialect_name_case.get(&dialect))
            .copied()
            .unwrap_or_default()
    }
}
impl FunctionCatalog for HashMapFunctionCatalog {
    /// Looks the function up in the case-insensitive table (keyed by `normalized_name`)
    /// or in the case-sensitive table (keyed by `raw_function_name`). The choice follows
    /// the effective [`FunctionNameCase`] for this dialect and function.
    fn lookup(
        &self,
        dialect: DialectType,
        raw_function_name: &str,
        normalized_name: &str,
    ) -> Option<&[FunctionSignature]> {
        let (table, key) = match self.effective_name_case(dialect, normalized_name) {
            FunctionNameCase::Insensitive => (&self.entries_normalized, normalized_name),
            FunctionNameCase::Sensitive => (&self.entries_exact, raw_function_name),
        };
        table.get(&dialect)?.get(key).map(Vec::as_slice)
    }
}