Skip to main content

polyglot_sql/
function_catalog.rs

1use crate::dialects::DialectType;
2use std::collections::HashMap;
3
/// How a catalog compares a looked-up function name against registered names.
///
/// The default is [`FunctionNameCase::Insensitive`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FunctionNameCase {
    /// Names match regardless of letter case (the default variant).
    #[default]
    Insensitive,
    /// Names match only when their letter case is identical.
    Sensitive,
}
13
/// Arity constraints for one overload of a function, used by semantic validation.
///
/// Both bounds are inclusive: a call with `n` positional arguments is accepted when
/// `min_arity <= n` and, if `max_arity` is `Some(max)`, also `n <= max`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionSignature {
    /// Minimum number of positional arguments (inclusive).
    pub min_arity: usize,
    /// Maximum number of positional arguments (inclusive).
    /// `None` means unbounded/variadic.
    pub max_arity: Option<usize>,
}

impl FunctionSignature {
    /// Build a signature that accepts exactly `arity` arguments.
    #[must_use]
    pub const fn exact(arity: usize) -> Self {
        Self {
            min_arity: arity,
            max_arity: Some(arity),
        }
    }

    /// Build a signature that accepts between `min_arity` and `max_arity` arguments,
    /// both inclusive.
    ///
    /// # Panics
    ///
    /// With debug assertions enabled, panics if `min_arity > max_arity`. When evaluated
    /// in a `const` item, this becomes a compile-time error instead.
    #[must_use]
    pub const fn range(min_arity: usize, max_arity: usize) -> Self {
        // An inverted range would yield a signature that rejects every arity (and renders
        // as e.g. "3..1"), which is almost certainly a catalog-authoring mistake.
        debug_assert!(
            min_arity <= max_arity,
            "FunctionSignature::range: min_arity must not exceed max_arity"
        );
        Self {
            min_arity,
            max_arity: Some(max_arity),
        }
    }

    /// Build a variadic signature that accepts `min_arity` or more arguments.
    #[must_use]
    pub const fn variadic(min_arity: usize) -> Self {
        Self {
            min_arity,
            max_arity: None,
        }
    }

    /// Return `true` when a call with `arity` positional arguments satisfies this
    /// signature.
    #[must_use]
    pub fn matches_arity(&self, arity: usize) -> bool {
        arity >= self.min_arity
            && match self.max_arity {
                Some(max) => arity <= max,
                None => true,
            }
    }

    /// Render a human-readable arity descriptor.
    ///
    /// Produces `"N"` for an exact arity, `"MIN..MAX"` for a bounded range (both ends
    /// inclusive, despite the Rust-like `..`), and `"MIN+"` for a variadic signature.
    #[must_use]
    pub fn describe_arity(&self) -> String {
        match self.max_arity {
            Some(max) if max == self.min_arity => self.min_arity.to_string(),
            Some(max) => format!("{}..{}", self.min_arity, max),
            None => format!("{}+", self.min_arity),
        }
    }
}
69
70/// Catalog abstraction for dialect-specific function metadata.
71///
72/// Implementations can be backed by generated files, external crates, or runtime-loaded assets.
73pub trait FunctionCatalog: Send + Sync {
74    /// Lookup overloads for a function name in a given dialect.
75    ///
76    /// `raw_function_name` should preserve user query casing.
77    /// `normalized_name` should be canonicalized/lowercased by the caller.
78    fn lookup(
79        &self,
80        dialect: DialectType,
81        raw_function_name: &str,
82        normalized_name: &str,
83    ) -> Option<&[FunctionSignature]>;
84}
85
86/// Minimal in-memory catalog implementation for runtime registration and tests.
87#[derive(Debug, Clone, Default)]
88pub struct HashMapFunctionCatalog {
89    entries_normalized: HashMap<DialectType, HashMap<String, Vec<FunctionSignature>>>,
90    entries_exact: HashMap<DialectType, HashMap<String, Vec<FunctionSignature>>>,
91    dialect_name_case: HashMap<DialectType, FunctionNameCase>,
92    function_name_case_overrides: HashMap<DialectType, HashMap<String, FunctionNameCase>>,
93}
94
95impl HashMapFunctionCatalog {
96    /// Set default function-name casing behavior for a dialect.
97    pub fn set_dialect_name_case(&mut self, dialect: DialectType, name_case: FunctionNameCase) {
98        self.dialect_name_case.insert(dialect, name_case);
99    }
100
101    /// Set optional per-function casing behavior override for a dialect.
102    ///
103    /// The override key is normalized to lowercase.
104    pub fn set_function_name_case(
105        &mut self,
106        dialect: DialectType,
107        function_name: impl Into<String>,
108        name_case: FunctionNameCase,
109    ) {
110        self.function_name_case_overrides
111            .entry(dialect)
112            .or_default()
113            .insert(function_name.into().to_lowercase(), name_case);
114    }
115
116    /// Register overloads for a function in a dialect.
117    pub fn register(
118        &mut self,
119        dialect: DialectType,
120        function_name: impl Into<String>,
121        signatures: Vec<FunctionSignature>,
122    ) {
123        let function_name = function_name.into();
124        let normalized_name = function_name.to_lowercase();
125
126        let normalized_entry = self
127            .entries_normalized
128            .entry(dialect)
129            .or_default()
130            .entry(normalized_name)
131            .or_default();
132        let exact_entry = self
133            .entries_exact
134            .entry(dialect)
135            .or_default()
136            .entry(function_name)
137            .or_default();
138
139        for sig in signatures {
140            if !normalized_entry.contains(&sig) {
141                normalized_entry.push(sig.clone());
142            }
143            if !exact_entry.contains(&sig) {
144                exact_entry.push(sig);
145            }
146        }
147    }
148
149    fn effective_name_case(&self, dialect: DialectType, normalized_name: &str) -> FunctionNameCase {
150        if let Some(overrides) = self.function_name_case_overrides.get(&dialect) {
151            if let Some(name_case) = overrides.get(normalized_name) {
152                return *name_case;
153            }
154        }
155        self.dialect_name_case
156            .get(&dialect)
157            .copied()
158            .unwrap_or_default()
159    }
160}
161
162impl FunctionCatalog for HashMapFunctionCatalog {
163    fn lookup(
164        &self,
165        dialect: DialectType,
166        raw_function_name: &str,
167        normalized_name: &str,
168    ) -> Option<&[FunctionSignature]> {
169        match self.effective_name_case(dialect, normalized_name) {
170            FunctionNameCase::Insensitive => self
171                .entries_normalized
172                .get(&dialect)
173                .and_then(|entries| entries.get(normalized_name))
174                .map(|v| v.as_slice()),
175            FunctionNameCase::Sensitive => self
176                .entries_exact
177                .get(&dialect)
178                .and_then(|entries| entries.get(raw_function_name))
179                .map(|v| v.as_slice()),
180        }
181    }
182}