Skip to main content

gobby_code/search/fts/
common.rs

1use std::collections::HashSet;
2
3use postgres::Client;
4use postgres::types::ToSql;
5
6use crate::config::{Context, ProjectIndexScope};
7use crate::db;
8use crate::models::Symbol;
9use crate::visibility;
10
11// Keep BM25 query sanitation centralized in gobby-core so gcode and gwiki
12// escape pg_search's DSL identically.
13pub use gobby_core::search::sanitize_pg_search_query;
14
15pub(super) type PgParam = Box<dyn ToSql + Sync>;
16
/// A graph symbol reduced to its identifier and the name used to present it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResolvedGraphSymbol {
    /// Identifier of the symbol.
    pub id: String,
    /// Name shown for the symbol.
    pub display_name: String,
}
22
23#[derive(Debug, Clone, Copy, Default)]
24pub(super) struct SymbolFilters<'a> {
25    pub(super) kind: Option<&'a str>,
26    pub(super) language: Option<&'a str>,
27    pub(super) paths: &'a [String],
28}
29
30#[derive(Debug, Clone)]
31pub(super) enum SymbolOrder {
32    Bm25Score,
33    Name,
34    ExactCaseFirst(String),
35}
36
37impl SymbolOrder {
38    fn sql(&self) -> String {
39        match self {
40            Self::Bm25Score => "pg_search.score(cs.id) DESC, cs.id ASC".to_string(),
41            Self::Name => "cs.name ASC, cs.file_path ASC, cs.line_start ASC".to_string(),
42            Self::ExactCaseFirst(query_param) => format!(
43                "CASE WHEN cs.name = {q} OR cs.qualified_name = {q} THEN 0 ELSE 1 END,
44                 cs.file_path ASC,
45                 cs.line_start ASC",
46                q = query_param
47            ),
48        }
49    }
50}
51
/// Minimum number of rows requested from SQL when results must additionally
/// be glob-filtered after the query (see `query_symbols_by_conditions`), so
/// the post-filter has enough candidates to fill the caller's limit.
pub const FILTERED_FETCH_CAP: usize = 10_000;
53
54pub(super) fn push_param<T>(params: &mut Vec<PgParam>, value: T) -> String
55where
56    T: ToSql + Sync + 'static,
57{
58    params.push(Box::new(value));
59    format!("${}", params.len())
60}
61
62pub(super) fn param_refs(params: &[PgParam]) -> Vec<&(dyn ToSql + Sync)> {
63    params
64        .iter()
65        .map(|param| param.as_ref() as &(dyn ToSql + Sync))
66        .collect()
67}
68
69pub(super) fn query_count(
70    conn: &mut Client,
71    sql: &str,
72    params: &[PgParam],
73) -> Result<usize, postgres::Error> {
74    let refs = param_refs(params);
75    let row = conn.query_one(sql, &refs)?;
76    Ok(row.try_get::<_, i64>("count")? as usize)
77}
78
79pub(super) fn push_visible_project_file_filter(
80    conditions: &mut Vec<String>,
81    params: &mut Vec<PgParam>,
82    row_alias: &str,
83    indexed_file_alias: &str,
84    ctx: &Context,
85) {
86    let tombstone = push_param(params, visibility::TOMBSTONE_LANGUAGE.to_string());
87    conditions.push(format!("{indexed_file_alias}.language != {tombstone}"));
88
89    match &ctx.index_scope {
90        ProjectIndexScope::Single => {
91            let project = push_param(params, ctx.project_id.clone());
92            conditions.push(format!("{row_alias}.project_id = {project}"));
93        }
94        ProjectIndexScope::Overlay {
95            overlay_project_id,
96            parent_project_id,
97            ..
98        } => {
99            let overlay = push_param(params, overlay_project_id.clone());
100            let parent = push_param(params, parent_project_id.clone());
101            conditions.push(format!(
102                "({row_alias}.project_id = {overlay}
103                  OR (
104                      {row_alias}.project_id = {parent}
105                      AND NOT EXISTS (
106                          SELECT 1 FROM code_indexed_files shadow
107                          WHERE shadow.project_id = {overlay}
108                            AND shadow.file_path = {row_alias}.file_path
109                      )
110                  ))"
111            ));
112        }
113    }
114}
115
116/// Escape LIKE wildcards (`%`, `_`) and the backslash escape char itself.
117pub(super) fn escape_like(s: &str) -> String {
118    let mut out = String::with_capacity(s.len());
119    for c in s.chars() {
120        if matches!(c, '\\' | '%' | '_') {
121            out.push('\\');
122        }
123        out.push(c);
124    }
125    out
126}
127
128/// Extract a SQL LIKE prefix from a glob pattern for index-assisted pre-filtering.
129pub(super) fn glob_to_like_prefix(pattern: &str) -> Option<String> {
130    let prefix: String = pattern
131        .chars()
132        .take_while(|c| !matches!(c, '*' | '?' | '['))
133        .collect();
134    if prefix.is_empty() {
135        None
136    } else {
137        Some(format!("{}%", escape_like(&prefix)))
138    }
139}
140
141pub(super) fn has_glob_meta(path: &str) -> bool {
142    path.chars().any(|c| matches!(c, '*' | '?' | '['))
143}
144
145pub fn expand_paths(paths: &[String]) -> Vec<String> {
146    let mut expanded = Vec::new();
147    let mut seen = HashSet::new();
148    for path in paths {
149        let trimmed = path.trim().trim_end_matches('/');
150        if trimmed.is_empty() {
151            continue;
152        }
153
154        let patterns = if has_glob_meta(trimmed) {
155            vec![trimmed.to_string()]
156        } else {
157            vec![trimmed.to_string(), format!("{trimmed}/**")]
158        };
159        for pattern in patterns {
160            if seen.insert(pattern.clone()) {
161                expanded.push(pattern);
162            }
163        }
164    }
165    expanded
166}
167
168pub fn compile_patterns(paths: &[String]) -> anyhow::Result<Vec<glob::Pattern>> {
169    paths
170        .iter()
171        .map(|path| {
172            glob::Pattern::new(path).map_err(|e| anyhow::anyhow!("invalid path glob `{path}`: {e}"))
173        })
174        .collect()
175}
176
177pub(super) fn path_like_prefixes(paths: &[String]) -> Option<Vec<String>> {
178    if paths.is_empty() {
179        return Some(Vec::new());
180    }
181
182    let mut prefixes = Vec::with_capacity(paths.len());
183    for path in paths {
184        prefixes.push(glob_to_like_prefix(path)?);
185    }
186    Some(prefixes)
187}
188
189pub fn path_filter_falls_back(paths: &[String]) -> bool {
190    !paths.is_empty() && path_like_prefixes(paths).is_none()
191}
192
193pub(super) fn push_path_filter(
194    conditions: &mut Vec<String>,
195    params: &mut Vec<PgParam>,
196    alias: &str,
197    paths: &[String],
198) -> bool {
199    let requires_post_filter = !paths.is_empty();
200    let Some(prefixes) = path_like_prefixes(paths) else {
201        for path in paths
202            .iter()
203            .filter(|path| glob_to_like_prefix(path).is_none())
204        {
205            log::warn!(
206                "omitting SQL path filter for alias `{alias}` because path filter `{path}` cannot be converted to a LIKE prefix; relying on post-query glob matching",
207            );
208        }
209        return requires_post_filter;
210    };
211    if prefixes.is_empty() {
212        return requires_post_filter;
213    }
214
215    let predicates = prefixes
216        .into_iter()
217        .map(|prefix| {
218            let placeholder = push_param(params, prefix);
219            format!("{alias}.file_path LIKE {placeholder} ESCAPE '\\'")
220        })
221        .collect::<Vec<_>>();
222    conditions.push(format!("({})", predicates.join(" OR ")));
223    requires_post_filter
224}
225
226pub(super) fn push_symbol_filters(
227    conditions: &mut Vec<String>,
228    params: &mut Vec<PgParam>,
229    alias: &str,
230    filters: SymbolFilters<'_>,
231) -> bool {
232    if let Some(kind) = filters.kind {
233        let placeholder = push_param(params, kind.to_string());
234        conditions.push(format!("{alias}.kind = {placeholder}"));
235    }
236    if let Some(language) = filters.language {
237        let placeholder = push_param(params, language.to_string());
238        conditions.push(format!("{alias}.language = {placeholder}"));
239    }
240    push_path_filter(conditions, params, alias, filters.paths)
241}
242
243pub(super) fn symbols_matching_paths(
244    symbols: impl IntoIterator<Item = Symbol>,
245    paths: &[String],
246) -> Vec<Symbol> {
247    let patterns = match compile_patterns(paths) {
248        Ok(patterns) => patterns,
249        Err(error) => {
250            log::warn!("invalid post-query symbol path filter: {error}");
251            return Vec::new();
252        }
253    };
254    symbols
255        .into_iter()
256        .filter(|symbol| {
257            patterns.is_empty()
258                || patterns
259                    .iter()
260                    .any(|pattern| pattern.matches(&symbol.file_path))
261        })
262        .collect()
263}
264
265pub(super) fn append_unique_symbols(
266    out: &mut Vec<Symbol>,
267    seen: &mut HashSet<String>,
268    symbols: Vec<Symbol>,
269    limit: usize,
270) {
271    if limit == 0 {
272        return;
273    }
274    for symbol in symbols {
275        if seen.insert(symbol.id.clone()) {
276            out.push(symbol);
277            if out.len() >= limit {
278                return;
279            }
280        }
281    }
282}
283
284pub(super) fn query_symbols_by_conditions(
285    conn: &mut Client,
286    mut conditions: Vec<String>,
287    mut params: Vec<PgParam>,
288    filters: SymbolFilters<'_>,
289    limit: usize,
290    order: SymbolOrder,
291) -> Vec<Symbol> {
292    let path_filter_fallback = push_symbol_filters(&mut conditions, &mut params, "cs", filters);
293    let query_limit = if path_filter_fallback {
294        limit.max(FILTERED_FETCH_CAP)
295    } else {
296        limit
297    };
298    let limit_placeholder = push_param(&mut params, query_limit as i64);
299    let where_clause = if conditions.is_empty() {
300        "TRUE".to_string()
301    } else {
302        conditions.join(" AND ")
303    };
304    let columns = db::symbol_select_columns("cs");
305    let sql = format!(
306        "SELECT {columns}
307         FROM code_symbols cs
308         JOIN code_indexed_files cf
309           ON cf.project_id = cs.project_id AND cf.file_path = cs.file_path
310         WHERE {where_clause}
311         ORDER BY {order_by}
312         LIMIT {limit_placeholder}",
313        order_by = order.sql()
314    );
315    let refs = param_refs(&params);
316    let mut symbols = match conn.query(&sql, &refs) {
317        Ok(rows) => rows
318            .iter()
319            .filter_map(|row| Symbol::from_row(row).ok())
320            .collect(),
321        Err(error) => {
322            log::error!("symbol query failed: {error}");
323            return Vec::new();
324        }
325    };
326    if path_filter_fallback {
327        symbols = symbols_matching_paths(symbols, filters.paths);
328        symbols.truncate(limit);
329    }
330    symbols
331}