polyglot_sql/
function_catalog.rs1use crate::dialects::DialectType;
2use std::collections::HashMap;
3
/// How function names are compared when looking them up in a catalog.
///
/// [`FunctionNameCase::Insensitive`] is the [`Default`] mode.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum FunctionNameCase {
    /// Names are compared using their normalized form.
    #[default]
    Insensitive,
    /// Names are compared exactly as written.
    Sensitive,
}
13
/// Accepted argument-count range for one overload of a function.
///
/// Both bounds are inclusive. A `max_arity` of `None` means the function accepts any
/// number of arguments at or above `min_arity`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionSignature {
    /// Smallest accepted argument count (inclusive).
    pub min_arity: usize,
    /// Largest accepted argument count (inclusive); `None` means unbounded.
    pub max_arity: Option<usize>,
}
23
24impl FunctionSignature {
25 pub const fn exact(arity: usize) -> Self {
27 Self {
28 min_arity: arity,
29 max_arity: Some(arity),
30 }
31 }
32
33 pub const fn range(min_arity: usize, max_arity: usize) -> Self {
35 Self {
36 min_arity,
37 max_arity: Some(max_arity),
38 }
39 }
40
41 pub const fn variadic(min_arity: usize) -> Self {
43 Self {
44 min_arity,
45 max_arity: None,
46 }
47 }
48
49 pub fn matches_arity(&self, arity: usize) -> bool {
51 if arity < self.min_arity {
52 return false;
53 }
54 match self.max_arity {
55 Some(max) => arity <= max,
56 None => true,
57 }
58 }
59
60 pub fn describe_arity(&self) -> String {
62 match self.max_arity {
63 Some(max) if max == self.min_arity => self.min_arity.to_string(),
64 Some(max) => format!("{}..{}", self.min_arity, max),
65 None => format!("{}+", self.min_arity),
66 }
67 }
68}
69
70pub trait FunctionCatalog: Send + Sync {
74 fn lookup(
79 &self,
80 dialect: DialectType,
81 raw_function_name: &str,
82 normalized_name: &str,
83 ) -> Option<&[FunctionSignature]>;
84}
85
86#[derive(Debug, Clone, Default)]
88pub struct HashMapFunctionCatalog {
89 entries_normalized: HashMap<DialectType, HashMap<String, Vec<FunctionSignature>>>,
90 entries_exact: HashMap<DialectType, HashMap<String, Vec<FunctionSignature>>>,
91 dialect_name_case: HashMap<DialectType, FunctionNameCase>,
92 function_name_case_overrides: HashMap<DialectType, HashMap<String, FunctionNameCase>>,
93}
94
95impl HashMapFunctionCatalog {
96 pub fn set_dialect_name_case(&mut self, dialect: DialectType, name_case: FunctionNameCase) {
98 self.dialect_name_case.insert(dialect, name_case);
99 }
100
101 pub fn set_function_name_case(
105 &mut self,
106 dialect: DialectType,
107 function_name: impl Into<String>,
108 name_case: FunctionNameCase,
109 ) {
110 self.function_name_case_overrides
111 .entry(dialect)
112 .or_default()
113 .insert(function_name.into().to_lowercase(), name_case);
114 }
115
116 pub fn register(
118 &mut self,
119 dialect: DialectType,
120 function_name: impl Into<String>,
121 signatures: Vec<FunctionSignature>,
122 ) {
123 let function_name = function_name.into();
124 let normalized_name = function_name.to_lowercase();
125
126 let normalized_entry = self
127 .entries_normalized
128 .entry(dialect)
129 .or_default()
130 .entry(normalized_name)
131 .or_default();
132 let exact_entry = self
133 .entries_exact
134 .entry(dialect)
135 .or_default()
136 .entry(function_name)
137 .or_default();
138
139 for sig in signatures {
140 if !normalized_entry.contains(&sig) {
141 normalized_entry.push(sig.clone());
142 }
143 if !exact_entry.contains(&sig) {
144 exact_entry.push(sig);
145 }
146 }
147 }
148
149 fn effective_name_case(&self, dialect: DialectType, normalized_name: &str) -> FunctionNameCase {
150 if let Some(overrides) = self.function_name_case_overrides.get(&dialect) {
151 if let Some(name_case) = overrides.get(normalized_name) {
152 return *name_case;
153 }
154 }
155 self.dialect_name_case
156 .get(&dialect)
157 .copied()
158 .unwrap_or_default()
159 }
160}
161
162impl FunctionCatalog for HashMapFunctionCatalog {
163 fn lookup(
164 &self,
165 dialect: DialectType,
166 raw_function_name: &str,
167 normalized_name: &str,
168 ) -> Option<&[FunctionSignature]> {
169 match self.effective_name_case(dialect, normalized_name) {
170 FunctionNameCase::Insensitive => self
171 .entries_normalized
172 .get(&dialect)
173 .and_then(|entries| entries.get(normalized_name))
174 .map(|v| v.as_slice()),
175 FunctionNameCase::Sensitive => self
176 .entries_exact
177 .get(&dialect)
178 .and_then(|entries| entries.get(raw_function_name))
179 .map(|v| v.as_slice()),
180 }
181 }
182}