Skip to main content

gitcortex_mcp/mcp/
search.rs

1//! Fuzzy search over the graph — multi-signal ranking with CamelCase/snake_case
2//! tokenisation, token overlap scoring, and edit-distance typo tolerance.
3//!
4//! Ranking signals (higher score = better match):
5//! - exact name match:                   +100
6//! - prefix name match:                  +60
7//! - all query tokens match name tokens: +50
8//! - substring in name:                  +30
//! - partial token overlap:              +15..+25 (10 + 5 per shared token, capped at +25)
10//! - edit distance ≤1 (typo):            +20
11//! - edit distance ≤2:                   +10
12//! - substring in qualified_name only:   +10
13//! - shorter names break ties
14//! - kind boost: Function/Method/Struct/Trait > others
15
16use std::collections::HashSet;
17
18use gitcortex_core::{error::Result, graph::Node, schema::NodeKind, store::GraphStore};
19use serde::Serialize;
20
21#[derive(Debug, Clone, Serialize)]
22pub struct SearchHit {
23    pub name: String,
24    pub qualified_name: String,
25    pub kind: String,
26    pub file: String,
27    pub start_line: u32,
28    pub score: i32,
29}
30
/// Number of hits returned by `search` when the caller passes `limit = None`.
const DEFAULT_LIMIT: usize = 10;
/// Upper bound applied to any caller-supplied `limit` in `search`.
const MAX_LIMIT: usize = 200;
/// Query tokens shorter than this are ignored for partial-overlap scoring and
/// are not used for per-token candidate expansion.
const MIN_TOKEN_LEN: usize = 3;
34
/// Split a camelCase / PascalCase / snake_case identifier into lowercase tokens.
///
/// Token boundaries are:
/// - any of the separator characters `_ - . : /` and space (separators are
///   dropped and never produce empty tokens);
/// - an uppercase letter that follows a non-uppercase character;
/// - the last capital of a run of capitals when it is followed by a lowercase
///   letter, so an acronym stays together as a single token.
///
/// Lowercasing is ASCII-only; non-ASCII characters are kept unchanged.
///
/// "AuthConfig"        → ["auth", "config"]
/// "validate_token"    → ["validate", "token"]
/// "parseJSONResponse" → ["parse", "json", "response"]
/// "HTTPClient"        → ["http", "client"]
fn tokenize(s: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    // Whether the previously visited character was uppercase (false at start
    // and after separators, digits and lowercase letters).
    let mut prev_is_upper = false;
    let mut chars = s.chars().peekable();

    while let Some(ch) = chars.next() {
        let is_upper = ch.is_uppercase();
        if matches!(ch, '_' | '-' | '.' | ':' | '/' | ' ') {
            if !current.is_empty() {
                // `current` only ever holds ASCII-lowercased characters, so it
                // can be moved into the output without another lowercase pass.
                tokens.push(std::mem::take(&mut current));
            }
        } else if is_upper {
            // Start a new token on an uppercase letter, but keep a run of
            // capitals together ("HTTP" → "http"); the run ends at the capital
            // that begins the next word ("HTTPClient" → "http", "client").
            let next_is_lower = chars.peek().map_or(false, |c| c.is_lowercase());
            if !current.is_empty() && (!prev_is_upper || next_is_lower) {
                tokens.push(std::mem::take(&mut current));
            }
            current.push(ch.to_ascii_lowercase());
        } else {
            current.push(ch);
        }
        prev_is_upper = is_upper;
    }

    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
70
/// Levenshtein edit distance between two strings, counted in `char`s.
///
/// Insertions, deletions and substitutions each cost 1.
///
/// The distance is not capped in general. The only short-circuit is on
/// length: when the two strings differ in length by more than 3 characters
/// the function returns `usize::MAX` without running the dynamic programme,
/// because the length difference is already a lower bound on the distance.
/// Callers must therefore treat `usize::MAX` as "too far apart", not as a
/// real distance.
fn edit_distance(a: &str, b: &str) -> usize {
    // Largest length difference for which the full distance is still computed.
    const MAX_LEN_DIFF: usize = 3;

    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    if a.len().abs_diff(b.len()) > MAX_LEN_DIFF {
        return usize::MAX;
    }

    // Two-row dynamic programme: `prev` is the row for the previous character
    // of `a`, `curr` the row being filled. Column `j` is the distance to the
    // first `j` characters of `b`.
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    let mut curr = vec![0usize; b.len() + 1];
    for (i, &ca) in a.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            curr[j + 1] = if ca == cb {
                prev[j]
            } else {
                // substitution, deletion, insertion
                1 + prev[j].min(prev[j + 1]).min(curr[j])
            };
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    // After the final swap the last completed row lives in `prev`.
    prev[b.len()]
}
96
97/// Score a node against a query. Returns `None` when the node is not a match.
98fn score(n: &Node, q_lower: &str, q_tokens: &[String]) -> Option<i32> {
99    let name_lower = n.name.to_ascii_lowercase();
100    let qname_lower = n.qualified_name.to_ascii_lowercase();
101    let name_tokens = tokenize(&n.name);
102
103    let base = if name_lower == q_lower {
104        // Exact name match — highest confidence.
105        100
106    } else if name_lower.starts_with(q_lower) {
107        60
108    } else if !q_tokens.is_empty() && q_tokens.iter().all(|t| name_tokens.contains(t)) {
109        // All query tokens present in name tokens.
110        // "auth config" fully matches "AuthConfig" or "auth_config".
111        50
112    } else if name_lower.contains(q_lower) {
113        30
114    } else {
115        // Partial token overlap.
116        let overlap = q_tokens
117            .iter()
118            .filter(|qt| qt.len() >= MIN_TOKEN_LEN && name_tokens.contains(*qt))
119            .count();
120        if overlap > 0 {
121            10 + (overlap as i32 * 5).min(15)
122        } else if qname_lower.contains(q_lower) {
123            // Match only in qualified path (e.g. module prefix).
124            10
125        } else if q_lower.len() >= 4 && q_lower.len() <= 15 && name_lower.len() <= 25 {
126            // Typo tolerance: edit distance on short-ish queries.
127            let dist = edit_distance(q_lower, &name_lower);
128            if dist <= 1 {
129                20
130            } else if dist <= 2 {
131                10
132            } else {
133                return None;
134            }
135        } else {
136            return None;
137        }
138    };
139
140    Some(base + kind_boost(&n.kind))
141}
142
143fn kind_boost(k: &NodeKind) -> i32 {
144    match k {
145        NodeKind::Function | NodeKind::Method => 5,
146        NodeKind::Struct | NodeKind::Trait | NodeKind::Interface => 4,
147        NodeKind::Enum | NodeKind::TypeAlias => 3,
148        NodeKind::Constant | NodeKind::Macro | NodeKind::Annotation => 2,
149        _ => 0,
150    }
151}
152
153fn to_hit(n: Node, score: i32) -> SearchHit {
154    SearchHit {
155        name: n.name,
156        qualified_name: n.qualified_name,
157        kind: n.kind.to_string(),
158        file: n.file.display().to_string(),
159        start_line: n.span.start_line,
160        score,
161    }
162}
163
164/// Run a fuzzy search across all nodes on `branch`.
165///
166/// Candidate set is built by querying the store for the whole query string AND
167/// for each individual token (for multi-word / camelCase queries). Candidates
168/// are deduplicated, scored with the multi-signal scorer, sorted by score
169/// descending, and truncated to `limit`.
170pub fn search<S: GraphStore + ?Sized>(
171    store: &S,
172    branch: &str,
173    query: &str,
174    limit: Option<usize>,
175) -> Result<Vec<SearchHit>> {
176    let limit = limit.unwrap_or(DEFAULT_LIMIT).min(MAX_LIMIT);
177    let q = query.trim();
178    if q.is_empty() {
179        return Ok(Vec::new());
180    }
181
182    let q_lower = q.to_ascii_lowercase();
183    let q_tokens = tokenize(q);
184    let candidate_limit = (limit * 50).max(500);
185
186    // Fetch candidates: whole query first, then per token.
187    let mut seen: HashSet<String> = HashSet::new();
188    let mut nodes: Vec<Node> = Vec::new();
189
190    let push = |nodes: &mut Vec<Node>, seen: &mut HashSet<String>, batch: Vec<Node>| {
191        for n in batch {
192            let id = n.id.as_str();
193            if seen.insert(id) {
194                nodes.push(n);
195            }
196        }
197    };
198
199    push(
200        &mut nodes,
201        &mut seen,
202        store.search_nodes(branch, q, candidate_limit)?,
203    );
204
205    // Per-token expansion: lets "validate token" find "validate_token" even
206    // when the store's CONTAINS filter requires the full substring.
207    for token in &q_tokens {
208        if token.len() < MIN_TOKEN_LEN {
209            continue;
210        }
211        // Skip if token equals the whole query (already fetched above).
212        if token.as_str() == q_lower {
213            continue;
214        }
215        push(
216            &mut nodes,
217            &mut seen,
218            store.search_nodes(branch, token, candidate_limit)?,
219        );
220    }
221
222    // Typo-fallback: CONTAINS can't find misspelled queries ("Greetter" won't
223    // match "Greeter"). When no candidates found and query is short enough for
224    // edit-distance to be meaningful, scan all nodes so the scorer can apply
225    // typo tolerance.
226    if nodes.is_empty() && q_lower.len() >= 4 && q_lower.len() <= 20 {
227        push(&mut nodes, &mut seen, store.list_all_nodes(branch)?);
228    }
229
230    let mut hits: Vec<SearchHit> = nodes
231        .into_iter()
232        .filter_map(|n| score(&n, &q_lower, &q_tokens).map(|s| to_hit(n, s)))
233        .collect();
234
235    hits.sort_by(|a, b| {
236        b.score
237            .cmp(&a.score)
238            .then_with(|| a.name.len().cmp(&b.name.len()))
239            .then_with(|| a.qualified_name.cmp(&b.qualified_name))
240    });
241    hits.truncate(limit);
242    Ok(hits)
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that each `(input, expected tokens)` pair holds for `tokenize`.
    fn check_tokens(cases: &[(&str, &[&str])]) {
        for (input, expected) in cases {
            assert_eq!(tokenize(input), expected.to_vec(), "input: {}", input);
        }
    }

    // camelCase words split on each capital; a run of capitals stays together.
    #[test]
    fn tokenize_camel_case() {
        check_tokens(&[
            ("AuthConfig", &["auth", "config"]),
            ("validateToken", &["validate", "token"]),
            ("HTTPClient", &["http", "client"]),
        ]);
    }

    // Underscores separate tokens and are dropped.
    #[test]
    fn tokenize_snake_case() {
        check_tokens(&[
            ("validate_token", &["validate", "token"]),
            ("auth_middleware", &["auth", "middleware"]),
        ]);
    }

    #[test]
    fn tokenize_pascal_case() {
        check_tokens(&[("KuzuGraphStore", &["kuzu", "graph", "store"])]);
    }

    #[test]
    fn edit_distance_exact() {
        let d = edit_distance("validate", "validate");
        assert_eq!(d, 0);
    }

    // A single missing character is one edit.
    #[test]
    fn edit_distance_typo() {
        for (typo, correct) in [("vlidate", "validate"), ("authnticate", "authenticate")].iter() {
            assert_eq!(edit_distance(typo, correct), 1);
        }
    }

    #[test]
    fn edit_distance_length_short_circuit() {
        // length difference > 3 → MAX
        let d = edit_distance("a", "abcde");
        assert_eq!(d, usize::MAX);
    }
}
283}