Skip to main content

gobby_code/search/fts/
common.rs

1use std::collections::HashSet;
2
3use postgres::Client;
4use postgres::types::ToSql;
5
6use crate::config::{Context, ProjectIndexScope};
7use crate::db;
8use crate::models::Symbol;
9use crate::visibility;
10
// Keep BM25 query sanitization centralized in gobby-core so gcode and gwiki
// escape pg_search's DSL identically.
13pub use gobby_core::search::sanitize_pg_search_query;
14pub(super) use gobby_core::search::{TrustedRowId, bm25_score_expr};
15
16pub(super) type PgParam = Box<dyn ToSql + Sync>;
17
/// A symbol identifier paired with the name used when presenting it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResolvedGraphSymbol {
    /// Symbol identifier.
    pub id: String,
    /// Name used when displaying the symbol.
    pub display_name: String,
}
23
24#[derive(Debug, Clone, Copy, Default)]
25pub(super) struct SymbolFilters<'a> {
26    pub(super) kind: Option<&'a str>,
27    pub(super) language: Option<&'a str>,
28    pub(super) paths: &'a [String],
29}
30
31#[derive(Debug, Clone)]
32pub(super) enum SymbolOrder {
33    Bm25Score,
34    Name,
35    ExactCaseFirst(String),
36}
37
38impl SymbolOrder {
39    fn sql(&self) -> String {
40        match self {
41            Self::Bm25Score => {
42                let row_id = trusted_row_id("cs.id");
43                format!("{} DESC, cs.id ASC", bm25_score_expr(&row_id))
44            }
45            Self::Name => "cs.name ASC, cs.file_path ASC, cs.line_start ASC".to_string(),
46            Self::ExactCaseFirst(query_param) => format!(
47                "CASE WHEN cs.name = {q} OR cs.qualified_name = {q} THEN 0 ELSE 1 END,
48                 cs.file_path ASC,
49                 cs.line_start ASC",
50                q = query_param
51            ),
52        }
53    }
54}
55
56pub(super) fn trusted_row_id(row_id: &str) -> TrustedRowId {
57    // SAFETY: FTS callers pass static SQL row identifiers for local table aliases.
58    unsafe { TrustedRowId::new_unchecked(row_id) }
59}
60
/// Minimum number of rows fetched from SQL when results must be glob-filtered
/// in Rust afterwards.
///
/// The headroom covers rows the post-query filter discards; the requested
/// limit is applied after filtering.
pub const FILTERED_FETCH_CAP: usize = 10_000;
62
63pub(super) fn push_param<T>(params: &mut Vec<PgParam>, value: T) -> String
64where
65    T: ToSql + Sync + 'static,
66{
67    params.push(Box::new(value));
68    format!("${}", params.len())
69}
70
71pub(super) fn param_refs(params: &[PgParam]) -> Vec<&(dyn ToSql + Sync)> {
72    params
73        .iter()
74        .map(|param| param.as_ref() as &(dyn ToSql + Sync))
75        .collect()
76}
77
78pub(super) fn query_count(
79    conn: &mut Client,
80    sql: &str,
81    params: &[PgParam],
82) -> Result<usize, postgres::Error> {
83    let refs = param_refs(params);
84    let row = conn.query_one(sql, &refs)?;
85    Ok(row.try_get::<_, i64>("count")? as usize)
86}
87
88pub(super) fn push_visible_project_file_filter(
89    conditions: &mut Vec<String>,
90    params: &mut Vec<PgParam>,
91    row_alias: &str,
92    indexed_file_alias: &str,
93    ctx: &Context,
94) {
95    let tombstone = push_param(params, visibility::TOMBSTONE_LANGUAGE.to_string());
96    conditions.push(format!("{indexed_file_alias}.language != {tombstone}"));
97
98    match &ctx.index_scope {
99        ProjectIndexScope::Single => {
100            let project = push_param(params, ctx.project_id.clone());
101            conditions.push(format!("{row_alias}.project_id = {project}"));
102        }
103        ProjectIndexScope::Overlay {
104            overlay_project_id,
105            parent_project_id,
106            ..
107        } => {
108            let overlay = push_param(params, overlay_project_id.clone());
109            let parent = push_param(params, parent_project_id.clone());
110            conditions.push(format!(
111                "({row_alias}.project_id = {overlay}
112                  OR (
113                      {row_alias}.project_id = {parent}
114                      AND NOT EXISTS (
115                          SELECT 1 FROM code_indexed_files shadow
116                          WHERE shadow.project_id = {overlay}
117                            AND shadow.file_path = {row_alias}.file_path
118                      )
119                  ))"
120            ));
121        }
122    }
123}
124
125/// Escape LIKE wildcards (`%`, `_`) and the backslash escape char itself.
126pub(super) fn escape_like(s: &str) -> String {
127    let mut out = String::with_capacity(s.len());
128    for c in s.chars() {
129        if matches!(c, '\\' | '%' | '_') {
130            out.push('\\');
131        }
132        out.push(c);
133    }
134    out
135}
136
137/// Extract a SQL LIKE prefix from a glob pattern for index-assisted pre-filtering.
138pub(super) fn glob_to_like_prefix(pattern: &str) -> Option<String> {
139    let prefix: String = pattern
140        .chars()
141        .take_while(|c| !matches!(c, '*' | '?' | '['))
142        .collect();
143    if prefix.is_empty() {
144        None
145    } else {
146        Some(format!("{}%", escape_like(&prefix)))
147    }
148}
149
150pub(super) fn has_glob_meta(path: &str) -> bool {
151    path.chars().any(|c| matches!(c, '*' | '?' | '['))
152}
153
154pub fn expand_paths(paths: &[String]) -> Vec<String> {
155    let mut expanded = Vec::new();
156    let mut seen = HashSet::new();
157    for path in paths {
158        let trimmed = path.trim().trim_end_matches('/');
159        if trimmed.is_empty() {
160            continue;
161        }
162
163        let patterns = if has_glob_meta(trimmed) {
164            vec![trimmed.to_string()]
165        } else {
166            vec![trimmed.to_string(), format!("{trimmed}/**")]
167        };
168        for pattern in patterns {
169            if seen.insert(pattern.clone()) {
170                expanded.push(pattern);
171            }
172        }
173    }
174    expanded
175}
176
177pub fn compile_patterns(paths: &[String]) -> anyhow::Result<Vec<glob::Pattern>> {
178    paths
179        .iter()
180        .map(|path| {
181            glob::Pattern::new(path).map_err(|e| anyhow::anyhow!("invalid path glob `{path}`: {e}"))
182        })
183        .collect()
184}
185
186pub(super) fn path_like_prefixes(paths: &[String]) -> Option<Vec<String>> {
187    if paths.is_empty() {
188        return Some(Vec::new());
189    }
190
191    let mut prefixes = Vec::with_capacity(paths.len());
192    for path in paths {
193        prefixes.push(glob_to_like_prefix(path)?);
194    }
195    Some(prefixes)
196}
197
198pub fn path_filter_requires_post_filter(paths: &[String]) -> bool {
199    !paths.is_empty() && path_like_prefixes(paths).is_none()
200}
201
202pub(super) fn push_path_filter(
203    conditions: &mut Vec<String>,
204    params: &mut Vec<PgParam>,
205    alias: &str,
206    paths: &[String],
207) -> bool {
208    let requires_post_filter = !paths.is_empty();
209    let Some(prefixes) = path_like_prefixes(paths) else {
210        for path in paths
211            .iter()
212            .filter(|path| glob_to_like_prefix(path).is_none())
213        {
214            log::warn!(
215                "omitting SQL path filter for alias `{alias}` because path filter `{path}` cannot be converted to a LIKE prefix; relying on post-query glob matching",
216            );
217        }
218        return requires_post_filter;
219    };
220    if prefixes.is_empty() {
221        return requires_post_filter;
222    }
223
224    let predicates = prefixes
225        .into_iter()
226        .map(|prefix| {
227            let placeholder = push_param(params, prefix);
228            format!("{alias}.file_path LIKE {placeholder} ESCAPE '\\'")
229        })
230        .collect::<Vec<_>>();
231    conditions.push(format!("({})", predicates.join(" OR ")));
232    requires_post_filter
233}
234
235pub(super) fn push_symbol_filters(
236    conditions: &mut Vec<String>,
237    params: &mut Vec<PgParam>,
238    alias: &str,
239    filters: SymbolFilters<'_>,
240) -> bool {
241    if let Some(kind) = filters.kind {
242        let placeholder = push_param(params, kind.to_string());
243        conditions.push(format!("{alias}.kind = {placeholder}"));
244    }
245    if let Some(language) = filters.language {
246        let placeholder = push_param(params, language.to_string());
247        conditions.push(format!("{alias}.language = {placeholder}"));
248    }
249    push_path_filter(conditions, params, alias, filters.paths)
250}
251
252pub(super) fn symbols_matching_paths(
253    symbols: impl IntoIterator<Item = Symbol>,
254    paths: &[String],
255) -> Vec<Symbol> {
256    let patterns = match compile_patterns(paths) {
257        Ok(patterns) => patterns,
258        Err(error) => {
259            log::warn!("invalid post-query symbol path filter: {error}");
260            return Vec::new();
261        }
262    };
263    symbols
264        .into_iter()
265        .filter(|symbol| {
266            patterns.is_empty()
267                || patterns
268                    .iter()
269                    .any(|pattern| pattern.matches(&symbol.file_path))
270        })
271        .collect()
272}
273
274pub(super) fn append_unique_symbols(
275    out: &mut Vec<Symbol>,
276    seen: &mut HashSet<String>,
277    symbols: Vec<Symbol>,
278    limit: usize,
279) {
280    if limit == 0 {
281        return;
282    }
283    for symbol in symbols {
284        if seen.insert(symbol.id.clone()) {
285            out.push(symbol);
286            if out.len() >= limit {
287                return;
288            }
289        }
290    }
291}
292
293pub(super) fn query_symbols_by_conditions(
294    conn: &mut Client,
295    mut conditions: Vec<String>,
296    mut params: Vec<PgParam>,
297    filters: SymbolFilters<'_>,
298    limit: usize,
299    order: SymbolOrder,
300) -> Vec<Symbol> {
301    let path_filter_requires_post_filter =
302        push_symbol_filters(&mut conditions, &mut params, "cs", filters);
303    let query_limit = if path_filter_requires_post_filter {
304        limit.max(FILTERED_FETCH_CAP)
305    } else {
306        limit
307    };
308    let limit_placeholder = push_param(&mut params, query_limit as i64);
309    let where_clause = if conditions.is_empty() {
310        "TRUE".to_string()
311    } else {
312        conditions.join(" AND ")
313    };
314    let columns = db::symbol_select_columns("cs");
315    let sql = format!(
316        "SELECT {columns}
317         FROM code_symbols cs
318         JOIN code_indexed_files cf
319           ON cf.project_id = cs.project_id AND cf.file_path = cs.file_path
320         WHERE {where_clause}
321         ORDER BY {order_by}
322         LIMIT {limit_placeholder}",
323        order_by = order.sql()
324    );
325    let refs = param_refs(&params);
326    let mut symbols = match conn.query(&sql, &refs) {
327        Ok(rows) => rows
328            .iter()
329            .filter_map(|row| Symbol::from_row(row).ok())
330            .collect(),
331        Err(error) => {
332            log::error!("symbol query failed: {error}");
333            return Vec::new();
334        }
335    };
336    if path_filter_requires_post_filter {
337        symbols = symbols_matching_paths(symbols, filters.paths);
338        symbols.truncate(limit);
339    }
340    symbols
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bm25_score_expression_uses_pdb_score() {
        let row_id = trusted_row_id("cs.id");
        let sql = bm25_score_expr(&row_id);

        assert_eq!(sql, "pdb.score(cs.id)");
        assert!(!sql.contains("pg_search.score"));
    }

    #[test]
    fn symbol_bm25_order_uses_pdb_score() {
        let sql = SymbolOrder::Bm25Score.sql();

        assert_eq!(sql, "pdb.score(cs.id) DESC, cs.id ASC");
        assert!(!sql.contains("pg_search.score"));
    }

    #[test]
    fn symbol_name_order_is_name_then_location() {
        assert_eq!(
            SymbolOrder::Name.sql(),
            "cs.name ASC, cs.file_path ASC, cs.line_start ASC"
        );
    }

    #[test]
    fn escape_like_escapes_wildcards_and_backslash() {
        assert_eq!(escape_like(""), "");
        assert_eq!(escape_like("src/main.rs"), "src/main.rs");
        assert_eq!(escape_like(r"a%b_c\d"), r"a\%b\_c\\d");
    }

    #[test]
    fn glob_to_like_prefix_keeps_literal_prefix_only() {
        assert_eq!(glob_to_like_prefix("src/*.rs").as_deref(), Some("src/%"));
        assert_eq!(glob_to_like_prefix("a_b/[xy].rs").as_deref(), Some(r"a\_b/%"));
        assert_eq!(glob_to_like_prefix("src").as_deref(), Some("src%"));
        assert_eq!(glob_to_like_prefix("*.rs"), None);
        assert_eq!(glob_to_like_prefix(""), None);
    }

    #[test]
    fn expand_paths_trims_dedupes_and_adds_recursive_variant() {
        let input = vec![
            "src/".to_string(),
            " src ".to_string(),
            "**/*.rs".to_string(),
            "  ".to_string(),
        ];

        assert_eq!(expand_paths(&input), vec!["src", "src/**", "**/*.rs"]);
    }

    #[test]
    fn path_filter_requires_post_filter_only_without_like_prefix() {
        assert!(!path_filter_requires_post_filter(&[]));
        assert!(!path_filter_requires_post_filter(&[
            "src".to_string(),
            "src/**".to_string()
        ]));
        assert!(path_filter_requires_post_filter(&[
            "src".to_string(),
            "**/*.rs".to_string()
        ]));
    }

    #[test]
    fn push_path_filter_adds_escaped_like_prefixes() {
        let mut conditions = Vec::new();
        let mut params: Vec<PgParam> = Vec::new();
        let paths = vec!["src".to_string(), "lib/*.rs".to_string()];

        assert!(push_path_filter(&mut conditions, &mut params, "cs", &paths));
        assert_eq!(
            conditions,
            vec![
                r"(cs.file_path LIKE $1 ESCAPE '\' OR cs.file_path LIKE $2 ESCAPE '\')"
                    .to_string()
            ]
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn push_path_filter_skips_sql_when_a_glob_has_no_prefix() {
        let mut conditions = Vec::new();
        let mut params: Vec<PgParam> = Vec::new();
        let paths = vec!["*.rs".to_string()];

        assert!(push_path_filter(&mut conditions, &mut params, "cs", &paths));
        assert!(conditions.is_empty());
        assert!(params.is_empty());
    }

    #[test]
    fn push_path_filter_is_noop_without_paths() {
        let mut conditions = Vec::new();
        let mut params: Vec<PgParam> = Vec::new();

        assert!(!push_path_filter(&mut conditions, &mut params, "cs", &[]));
        assert!(conditions.is_empty());
        assert!(params.is_empty());
    }
}