flowscope_core/completion/
functions.rs1use std::sync::LazyLock;
14
15use crate::generated::{
16 all_function_signatures, get_function_signature, FunctionCategory, FunctionSignature,
17 ReturnTypeRule,
18};
19use crate::types::{CompletionClause, CompletionItem, CompletionItemCategory, CompletionItemKind};
20
21static FUNCTION_COMPLETION_CACHE: LazyLock<Vec<CachedFunctionItem>> = LazyLock::new(|| {
25 all_function_signatures()
26 .map(|sig| CachedFunctionItem {
27 item: function_to_completion_item(&sig),
28 name_lower: sig.name.to_string(),
29 category: sig.category,
30 })
31 .collect()
32});
33
34struct CachedFunctionItem {
36 item: CompletionItem,
38 name_lower: String,
40 category: FunctionCategory,
42}
43
/// Boost applied to aggregate functions when the query has a GROUP BY.
/// Also reused as the additional boost for aggregates in a HAVING clause.
const SCORE_AGGREGATE_IN_GROUP_BY_CONTEXT: i32 = 200;
/// Adjustment applied to aggregate functions when the query has no GROUP BY.
const SCORE_AGGREGATE_NO_GROUP_BY: i32 = -100;
/// Boost applied to window functions inside a window context or WINDOW clause.
const SCORE_WINDOW_IN_WINDOW_CONTEXT: i32 = 150;
/// Additional penalty for aggregate functions in a WHERE clause, where standard
/// SQL does not allow aggregates.
const SCORE_AGGREGATE_IN_WHERE_PENALTY: i32 = -300;
52const KEYWORD_STYLE_FUNCTIONS: &[&str] = &[
54 "current_catalog",
55 "current_date",
56 "current_datetime",
57 "current_database",
58 "current_path",
59 "current_role",
60 "current_schema",
61 "current_session",
62 "current_time",
63 "current_timestamp",
64 "current_timestamp_ltz",
65 "current_timestamp_ntz",
66 "current_timestamp_tz",
67 "current_user",
68 "localtime",
69 "localtimestamp",
70 "session_user",
71 "system_user",
72 "user",
73];
74
75fn uses_keyword_call_style(sig: &FunctionSignature) -> bool {
76 KEYWORD_STYLE_FUNCTIONS
77 .iter()
78 .any(|name| sig.name.eq_ignore_ascii_case(name))
79}
80
81pub fn function_to_completion_item(sig: &FunctionSignature) -> CompletionItem {
89 let category = match sig.category {
90 FunctionCategory::Aggregate => CompletionItemCategory::Aggregate,
91 FunctionCategory::Window | FunctionCategory::Scalar => CompletionItemCategory::Function,
92 };
93
94 let detail = Some(sig.format_signature());
96
97 CompletionItem {
98 label: sig.display_name.to_string(),
99 insert_text: if uses_keyword_call_style(sig) {
100 sig.display_name.to_string()
101 } else {
102 format!("{}(", sig.display_name)
103 },
104 kind: CompletionItemKind::Function,
105 category,
106 score: 0, clause_specific: false,
108 detail,
109 }
110}
111
112pub fn function_return_type_display(name: &str) -> Option<&'static str> {
114 get_function_signature(name).and_then(|sig| {
115 sig.return_type.map(|rt| match rt {
116 ReturnTypeRule::Integer => "INTEGER",
117 ReturnTypeRule::Numeric => "NUMERIC",
118 ReturnTypeRule::Text => "TEXT",
119 ReturnTypeRule::Timestamp => "TIMESTAMP",
120 ReturnTypeRule::Boolean => "BOOLEAN",
121 ReturnTypeRule::Date => "DATE",
122 ReturnTypeRule::MatchFirstArg => "T",
123 })
124 })
125}
126
/// Query context used to filter and rank function completions.
#[derive(Debug, Clone, Default)]
pub struct FunctionCompletionContext {
    /// Clause containing the cursor; WHERE, HAVING and WINDOW alter the scoring
    /// of aggregate and window functions.
    pub clause: CompletionClause,
    /// Whether the query has a GROUP BY. Aggregates are boosted when `true` and
    /// penalised when `false`.
    pub has_group_by: bool,
    /// Whether the cursor is inside a window context; boosts window functions.
    pub in_window_context: bool,
    /// Text typed so far. When `Some`, it is ASCII-lowercased, and only functions
    /// whose cached name starts with it are returned.
    pub prefix: Option<String>,
}
139
140pub fn get_function_completions(ctx: &FunctionCompletionContext) -> Vec<CompletionItem> {
157 let prefix_lower = ctx.prefix.as_ref().map(|p| p.to_ascii_lowercase());
158
159 FUNCTION_COMPLETION_CACHE
160 .iter()
161 .filter(|cached| {
162 match &prefix_lower {
164 Some(prefix) => cached.name_lower.starts_with(prefix.as_str()),
165 None => true,
166 }
167 })
168 .map(|cached| {
169 let mut item = cached.item.clone();
170
171 let score_adjustment =
173 compute_function_score_adjustment_by_category(cached.category, ctx);
174 item.score = score_adjustment;
175
176 if score_adjustment > 0 {
178 item.clause_specific = true;
179 }
180
181 item
182 })
183 .collect()
184}
185
186fn compute_function_score_adjustment_by_category(
188 category: FunctionCategory,
189 ctx: &FunctionCompletionContext,
190) -> i32 {
191 let mut adjustment = 0;
192
193 match category {
194 FunctionCategory::Aggregate => {
195 if ctx.has_group_by {
197 adjustment += SCORE_AGGREGATE_IN_GROUP_BY_CONTEXT;
198 } else {
199 adjustment += SCORE_AGGREGATE_NO_GROUP_BY;
200 }
201
202 if ctx.clause == CompletionClause::Where {
204 adjustment += SCORE_AGGREGATE_IN_WHERE_PENALTY;
205 }
206
207 if ctx.clause == CompletionClause::Having {
209 adjustment += SCORE_AGGREGATE_IN_GROUP_BY_CONTEXT;
210 }
211 }
212 FunctionCategory::Window => {
213 if ctx.in_window_context || ctx.clause == CompletionClause::Window {
215 adjustment += SCORE_WINDOW_IN_WINDOW_CONTEXT;
216 }
217 }
218 FunctionCategory::Scalar => {
219 }
222 }
223
224 adjustment
225}
226
227pub fn is_aggregate(name: &str) -> bool {
229 get_function_signature(name)
230 .map(|sig| sig.category == FunctionCategory::Aggregate)
231 .unwrap_or(false)
232}
233
234pub fn is_window(name: &str) -> bool {
236 get_function_signature(name)
237 .map(|sig| sig.category == FunctionCategory::Window)
238 .unwrap_or(false)
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Finds the completion labelled exactly `label`; fails the test with a
    /// descriptive message when it is missing.
    fn find_item<'a>(items: &'a [CompletionItem], label: &str) -> &'a CompletionItem {
        items
            .iter()
            .find(|item| item.label == label)
            .unwrap_or_else(|| panic!("{label} should be in completions"))
    }

    #[test]
    fn test_function_to_completion_item() {
        let sig = get_function_signature("count").expect("COUNT should exist");
        let item = function_to_completion_item(&sig);

        assert_eq!(item.label, "COUNT");
        assert_eq!(item.insert_text, "COUNT(");
        assert_eq!(item.kind, CompletionItemKind::Function);
        assert_eq!(item.category, CompletionItemCategory::Aggregate);
        assert!(item.detail.is_some());
    }

    #[test]
    fn test_keyword_function_inserts_plain_identifier() {
        let sig = get_function_signature("current_date").expect("CURRENT_DATE should exist");
        let item = function_to_completion_item(&sig);
        assert_eq!(item.insert_text, "CURRENT_DATE");
    }

    #[test]
    fn test_zero_arg_regular_function_still_opens_parenthesis() {
        let sig = get_function_signature("pi").expect("PI should exist");
        let item = function_to_completion_item(&sig);
        assert_eq!(item.insert_text, "PI(");
    }

    #[test]
    fn test_function_completion_with_return_type() {
        let sig = get_function_signature("count").expect("COUNT should exist");
        let formatted = sig.format_signature();

        assert!(
            formatted.contains("INTEGER"),
            "Expected INTEGER in signature: {}",
            formatted
        );
    }

    #[test]
    fn test_aggregate_boosted_with_group_by() {
        let ctx = FunctionCompletionContext {
            clause: CompletionClause::Select,
            has_group_by: true,
            in_window_context: false,
            prefix: Some("sum".to_string()),
        };

        let items = get_function_completions(&ctx);
        let sum = find_item(&items, "SUM");

        assert!(sum.score > 0, "SUM should have positive score with GROUP BY");
        assert!(
            sum.clause_specific,
            "a positive adjustment should mark SUM as clause-specific"
        );
    }

    #[test]
    fn test_aggregate_penalized_in_where() {
        let ctx = FunctionCompletionContext {
            clause: CompletionClause::Where,
            has_group_by: false,
            in_window_context: false,
            prefix: Some("sum".to_string()),
        };

        let items = get_function_completions(&ctx);
        let sum = find_item(&items, "SUM");

        assert!(sum.score < 0, "SUM should have negative score in WHERE clause");
        assert!(
            !sum.clause_specific,
            "a penalised item must not be marked clause-specific"
        );
    }

    #[test]
    fn test_aggregate_boosted_in_having_without_group_by() {
        // HAVING without GROUP BY aggregates the whole input, so aggregates should
        // still come out ahead there even though the query has no GROUP BY.
        let ctx = FunctionCompletionContext {
            clause: CompletionClause::Having,
            has_group_by: false,
            in_window_context: false,
            prefix: Some("sum".to_string()),
        };

        let items = get_function_completions(&ctx);
        let sum = find_item(&items, "SUM");

        assert!(sum.score > 0, "SUM should have positive score in HAVING");
    }

    #[test]
    fn test_prefix_filtering() {
        let ctx = FunctionCompletionContext {
            clause: CompletionClause::Select,
            has_group_by: false,
            in_window_context: false,
            prefix: Some("row_".to_string()),
        };

        let items = get_function_completions(&ctx);

        assert!(items.iter().all(|i| i.label.starts_with("ROW_")));
        assert!(items.iter().any(|i| i.label == "ROW_NUMBER"));
    }

    #[test]
    fn test_prefix_matching_ignores_case() {
        let ctx = FunctionCompletionContext {
            clause: CompletionClause::Select,
            has_group_by: false,
            in_window_context: false,
            prefix: Some("SuM".to_string()),
        };

        let items = get_function_completions(&ctx);

        assert!(
            items.iter().any(|i| i.label == "SUM"),
            "a mixed-case prefix should still match SUM"
        );
    }

    #[test]
    fn test_no_prefix_returns_every_function() {
        let ctx = FunctionCompletionContext {
            clause: CompletionClause::Select,
            has_group_by: false,
            in_window_context: false,
            prefix: None,
        };

        let items = get_function_completions(&ctx);

        assert_eq!(items.len(), FUNCTION_COMPLETION_CACHE.len());
        assert!(items.iter().any(|i| i.label == "COUNT"));
        assert!(items.iter().any(|i| i.label == "ROW_NUMBER"));
    }

    #[test]
    fn test_window_function_in_window_context() {
        let ctx = FunctionCompletionContext {
            clause: CompletionClause::Window,
            has_group_by: false,
            in_window_context: true,
            prefix: Some("row_".to_string()),
        };

        let items = get_function_completions(&ctx);
        let row_number = find_item(&items, "ROW_NUMBER");

        assert!(
            row_number.score > 0,
            "ROW_NUMBER should have positive score in window context"
        );
    }

    #[test]
    fn test_category_predicates() {
        assert!(is_aggregate("count"));
        assert!(!is_window("count"));
        assert!(is_window("row_number"));
        assert!(!is_aggregate("row_number"));
        assert!(!is_aggregate("no_such_function"));
        assert!(!is_window("no_such_function"));
    }

    #[test]
    fn test_function_signature_parameter_order_preserved() {
        let sig = get_function_signature("substring").expect("SUBSTRING should exist");
        let names: Vec<_> = sig.params.iter().map(|p| p.name).collect();

        assert_eq!(names, vec!["this", "start", "length"]);
        assert!(sig.params[0].required);
        assert!(!sig.params[1].required);
        assert!(!sig.params[2].required);
    }
}