gitcortex_mcp/mcp/
search.rs1use std::collections::HashSet;
17
18use gitcortex_core::{error::Result, graph::Node, schema::NodeKind, store::GraphStore};
19use serde::Serialize;
20
/// One ranked result produced by [`search`].
///
/// Serialized with serde; the derive emits fields in declaration order, so the
/// field order below is also the key order in serialized output.
#[derive(Debug, Clone, Serialize)]
pub struct SearchHit {
    /// The matched node's `name`.
    pub name: String,
    /// The matched node's `qualified_name`.
    pub qualified_name: String,
    /// The node kind, rendered with its `to_string()` form.
    pub kind: String,
    /// The node's file path, rendered with `display()`.
    pub file: String,
    /// `start_line` of the node's span (indexing base not visible here — TODO confirm).
    pub start_line: u32,
    /// Relevance score; higher is better. [`search`] sorts hits by it, descending.
    pub score: i32,
}
30
/// Number of hits [`search`] returns when the caller passes no limit.
const DEFAULT_LIMIT: usize = 10;
/// Upper bound applied to any caller-supplied limit in [`search`].
const MAX_LIMIT: usize = 200;
/// Minimum length (in bytes) for a query token to be sent to the store on its
/// own by [`search`], or to count toward partial token overlap when scoring.
const MIN_TOKEN_LEN: usize = 3;
34
/// Splits an identifier or free-text query into lowercase word tokens.
///
/// Every non-alphanumeric character (`_`, `-`, `.`, `:`, `/`, whitespace,
/// brackets, commas, …) ends the current token and is discarded, so inputs such
/// as `validate_token()` or `Vec<String>` yield clean words instead of tokens
/// with punctuation stuck to them.
///
/// Within a run of alphanumerics a new token starts:
/// * at a capital that follows a non-capital (`validateToken` → `validate`, `token`);
/// * at the last capital of an acronym that is followed by a lowercase letter
///   (`HTTPClient` → `http`, `client`).
///
/// Digits stay attached to the preceding token (`Base64Encoder` → `base64`,
/// `encoder`). Lowercasing is ASCII-only, like the `to_ascii_lowercase` calls
/// applied to queries and names.
fn tokenize(s: &str) -> Vec<String> {
    let chars: Vec<char> = s.chars().collect();
    let mut tokens = Vec::new();
    let mut current = String::new();

    for (i, &ch) in chars.iter().enumerate() {
        if !ch.is_alphanumeric() {
            // Punctuation/whitespace: close the current word and drop the char.
            if !current.is_empty() {
                tokens.push(std::mem::take(&mut current));
            }
            continue;
        }

        if ch.is_uppercase() {
            let prev_is_upper = i > 0 && chars[i - 1].is_uppercase();
            let next_is_lower = matches!(chars.get(i + 1), Some(c) if c.is_lowercase());
            // Start a new word after a non-capital (`validate|Token`), or at the
            // last capital of an acronym that begins a lowercase word (`HTTP|Client`).
            if !current.is_empty() && (!prev_is_upper || next_is_lower) {
                tokens.push(std::mem::take(&mut current));
            }
            current.push(ch.to_ascii_lowercase());
        } else {
            current.push(ch);
        }
    }

    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
70
/// Levenshtein distance between `a` and `b`, measured in `char`s.
///
/// When the lengths differ by more than 3 the distance is necessarily at least
/// 4, so the DP is skipped and `usize::MAX` is returned as a "too far" marker.
///
/// The table is kept as a single rolling row: before cell `j + 1` is
/// overwritten it still holds the previous row's value ("above"), and `diag`
/// carries the previous row's value one column to the left.
fn edit_distance(a: &str, b: &str) -> usize {
    let left: Vec<char> = a.chars().collect();
    let right: Vec<char> = b.chars().collect();
    if left.len().abs_diff(right.len()) > 3 {
        return usize::MAX;
    }

    let mut row: Vec<usize> = (0..=right.len()).collect();
    for (i, &lc) in left.iter().enumerate() {
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &rc) in right.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if lc == rc {
                diag
            } else {
                1 + diag.min(above).min(row[j])
            };
            diag = above;
        }
    }
    row[right.len()]
}
96
97fn score(n: &Node, q_lower: &str, q_tokens: &[String]) -> Option<i32> {
99 let name_lower = n.name.to_ascii_lowercase();
100 let qname_lower = n.qualified_name.to_ascii_lowercase();
101 let name_tokens = tokenize(&n.name);
102
103 let base = if name_lower == q_lower {
104 100
106 } else if name_lower.starts_with(q_lower) {
107 60
108 } else if !q_tokens.is_empty() && q_tokens.iter().all(|t| name_tokens.contains(t)) {
109 50
112 } else if name_lower.contains(q_lower) {
113 30
114 } else {
115 let overlap = q_tokens
117 .iter()
118 .filter(|qt| qt.len() >= MIN_TOKEN_LEN && name_tokens.contains(*qt))
119 .count();
120 if overlap > 0 {
121 10 + (overlap as i32 * 5).min(15)
122 } else if qname_lower.contains(q_lower) {
123 10
125 } else if q_lower.len() >= 4 && q_lower.len() <= 15 && name_lower.len() <= 25 {
126 let dist = edit_distance(q_lower, &name_lower);
128 if dist <= 1 {
129 20
130 } else if dist <= 2 {
131 10
132 } else {
133 return None;
134 }
135 } else {
136 return None;
137 }
138 };
139
140 Some(base + kind_boost(&n.kind))
141}
142
143fn kind_boost(k: &NodeKind) -> i32 {
144 match k {
145 NodeKind::Function | NodeKind::Method => 5,
146 NodeKind::Struct | NodeKind::Trait | NodeKind::Interface => 4,
147 NodeKind::Enum | NodeKind::TypeAlias => 3,
148 NodeKind::Constant | NodeKind::Macro | NodeKind::Annotation => 2,
149 _ => 0,
150 }
151}
152
153fn to_hit(n: Node, score: i32) -> SearchHit {
154 SearchHit {
155 name: n.name,
156 qualified_name: n.qualified_name,
157 kind: n.kind.to_string(),
158 file: n.file.display().to_string(),
159 start_line: n.span.start_line,
160 score,
161 }
162}
163
164pub fn search<S: GraphStore + ?Sized>(
171 store: &S,
172 branch: &str,
173 query: &str,
174 limit: Option<usize>,
175) -> Result<Vec<SearchHit>> {
176 let limit = limit.unwrap_or(DEFAULT_LIMIT).min(MAX_LIMIT);
177 let q = query.trim();
178 if q.is_empty() {
179 return Ok(Vec::new());
180 }
181
182 let q_lower = q.to_ascii_lowercase();
183 let q_tokens = tokenize(q);
184 let candidate_limit = (limit * 50).max(500);
185
186 let mut seen: HashSet<String> = HashSet::new();
188 let mut nodes: Vec<Node> = Vec::new();
189
190 let push = |nodes: &mut Vec<Node>, seen: &mut HashSet<String>, batch: Vec<Node>| {
191 for n in batch {
192 let id = n.id.as_str();
193 if seen.insert(id) {
194 nodes.push(n);
195 }
196 }
197 };
198
199 push(
200 &mut nodes,
201 &mut seen,
202 store.search_nodes(branch, q, candidate_limit)?,
203 );
204
205 for token in &q_tokens {
208 if token.len() < MIN_TOKEN_LEN {
209 continue;
210 }
211 if token.as_str() == q_lower {
213 continue;
214 }
215 push(
216 &mut nodes,
217 &mut seen,
218 store.search_nodes(branch, token, candidate_limit)?,
219 );
220 }
221
222 if nodes.is_empty() && q_lower.len() >= 4 && q_lower.len() <= 20 {
227 push(&mut nodes, &mut seen, store.list_all_nodes(branch)?);
228 }
229
230 let mut hits: Vec<SearchHit> = nodes
231 .into_iter()
232 .filter_map(|n| score(&n, &q_lower, &q_tokens).map(|s| to_hit(n, s)))
233 .collect();
234
235 hits.sort_by(|a, b| {
236 b.score
237 .cmp(&a.score)
238 .then_with(|| a.name.len().cmp(&b.name.len()))
239 .then_with(|| a.qualified_name.cmp(&b.qualified_name))
240 });
241 hits.truncate(limit);
242 Ok(hits)
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tokenize_camel_case() {
        assert_eq!(tokenize("AuthConfig"), vec!["auth", "config"]);
        assert_eq!(tokenize("validateToken"), vec!["validate", "token"]);
        assert_eq!(tokenize("HTTPClient"), vec!["http", "client"]);
    }

    #[test]
    fn tokenize_snake_case() {
        assert_eq!(tokenize("validate_token"), vec!["validate", "token"]);
        assert_eq!(tokenize("auth_middleware"), vec!["auth", "middleware"]);
    }

    #[test]
    fn tokenize_pascal_case() {
        assert_eq!(tokenize("KuzuGraphStore"), vec!["kuzu", "graph", "store"]);
    }

    #[test]
    fn tokenize_path_separators_and_digits() {
        assert_eq!(
            tokenize("gitcortex_core::store"),
            vec!["gitcortex", "core", "store"]
        );
        assert_eq!(
            tokenize("auth-middleware/config.rs"),
            vec!["auth", "middleware", "config", "rs"]
        );
        // Digits stay attached to the word they follow.
        assert_eq!(tokenize("Base64Encoder"), vec!["base64", "encoder"]);
    }

    #[test]
    fn tokenize_empty_and_separator_only() {
        assert!(tokenize("").is_empty());
        assert!(tokenize("__::--").is_empty());
    }

    #[test]
    fn edit_distance_exact() {
        assert_eq!(edit_distance("validate", "validate"), 0);
    }

    #[test]
    fn edit_distance_typo() {
        assert_eq!(edit_distance("vlidate", "validate"), 1);
        assert_eq!(edit_distance("authnticate", "authenticate"), 1);
    }

    #[test]
    fn edit_distance_empty_and_classic() {
        assert_eq!(edit_distance("", ""), 0);
        assert_eq!(edit_distance("", "abc"), 3);
        assert_eq!(edit_distance("kitten", "sitting"), 3);
        assert_eq!(edit_distance("flaw", "lawn"), 2);
    }

    #[test]
    fn edit_distance_length_short_circuit() {
        assert_eq!(edit_distance("a", "abcde"), usize::MAX);
    }

    #[test]
    fn edit_distance_length_gap_boundary() {
        // A length gap of exactly 3 is still computed; 4 or more short-circuits.
        assert_eq!(edit_distance("abcd", "abcdefg"), 3);
        assert_eq!(edit_distance("abc", "abcdefg"), usize::MAX);
    }
}