Skip to main content

graphify_serve/
search.rs

//! In-memory inverted index for fast node lookup.
//!
//! Provides [`SearchIndex`], which tokenizes node labels, ids, and source files
//! into an inverted index for token-based (exact and prefix) search across the
//! knowledge graph.

use std::collections::{BTreeMap, HashMap};
use std::ops::Bound;

use graphify_core::graph::KnowledgeGraph;

// ---------------------------------------------------------------------------
// Tokenizer
// ---------------------------------------------------------------------------

/// Break `input` into lowercase search tokens.
///
/// Pieces are first separated by whitespace and by any of `_`, `.`, `:`, `/`, `\`
/// and `-` (a single `:` splits, so `::` paths work as well). Inside each piece a
/// new token starts wherever an uppercase character directly follows a
/// non-uppercase one, which splits camelCase / PascalCase words. Consecutive
/// uppercase characters stay in the same token, so `"HTTPServer"` yields the
/// single token `"httpserver"`. Every token is lowercased and none is empty.
///
/// # Examples (implicit, tested below)
/// * `"camelCase"` -> `["camel", "case"]`
/// * `"foo_bar.baz"` -> `["foo", "bar", "baz"]`
/// * `"std::collections::HashMap"` -> `["std", "collections", "hash", "map"]`
/// * `"src/main/mod.rs"` -> `["src", "main", "mod", "rs"]`
pub fn tokenize(input: &str) -> Vec<String> {
    const DELIMITERS: [char; 6] = ['_', '.', ':', '/', '\\', '-'];

    let pieces = input
        .split(|c: char| DELIMITERS.contains(&c))
        .flat_map(str::split_whitespace);

    let mut out = Vec::new();
    for piece in pieces {
        let mut current = String::new();
        // Whether the most recently pushed char of `current` was uppercase.
        let mut prev_upper = false;
        for ch in piece.chars() {
            let upper = ch.is_uppercase();
            // lower/other -> UPPER transition closes the current camelCase word.
            if upper && !prev_upper && !current.is_empty() {
                out.push(current.to_lowercase());
                current.clear();
            }
            current.push(ch);
            prev_upper = upper;
        }
        if !current.is_empty() {
            out.push(current.to_lowercase());
        }
    }
    out
}

54// ---------------------------------------------------------------------------
55// SearchIndex
56// ---------------------------------------------------------------------------
57
58/// In-memory inverted index mapping tokens to weighted `(node_id, weight)` pairs.
59///
60/// Built from a [`KnowledgeGraph`] via [`SearchIndex::build`]. Each node contributes
61/// tokens from its **label**, **id**, and **source_file**, each with a different base
62/// weight plus a degree-based boost.
63pub struct SearchIndex {
64    /// token -> [(node_id, weight)]
65    index: HashMap<String, Vec<(String, f64)>>,
66}
67
68impl SearchIndex {
69    /// Build the inverted index from a knowledge graph.
70    ///
71    /// Token weights:
72    /// - Label token: `2.0 + ln_1p(degree) * 0.1`
73    /// - Id token:    `1.0 + ln_1p(degree) * 0.1`
74    /// - Source file token: `0.5 + ln_1p(degree) * 0.1`
75    ///
76    /// Note: the same token from multiple fields (label + id) stacks additively,
77    /// rewarding nodes whose token appears in multiple fields.
78    pub fn build(graph: &KnowledgeGraph) -> Self {
79        let mut index: HashMap<String, Vec<(String, f64)>> = HashMap::new();
80
81        for node_id in graph.node_ids() {
82            let Some(node) = graph.get_node(&node_id) else {
83                continue;
84            };
85            let degree = graph.degree(&node_id) as f64;
86            let degree_boost = degree.ln_1p() * 0.1;
87
88            // Helper: insert tokens with a given base weight.
89            let mut insert = |text: &str, base: f64| {
90                for tok in tokenize(text) {
91                    let weight = base + degree_boost;
92                    index
93                        .entry(tok)
94                        .or_default()
95                        .push((node_id.clone(), weight));
96                }
97            };
98
99            insert(&node.label, 2.0);
100            insert(&node.id, 1.0);
101            insert(&node.source_file, 0.5);
102        }
103
104        SearchIndex { index }
105    }
106
107    /// Search for nodes matching any of the given terms.
108    ///
109    /// Each term is tokenized and matched against the index using **exact** and
110    /// **prefix** matching. Scores are aggregated per node. Results are returned
111    /// sorted by descending score.
112    ///
113    /// Prefix matches receive half the weight of an exact match.
114    pub fn search(&self, terms: &[String]) -> Vec<(f64, String)> {
115        let mut scores: HashMap<String, f64> = HashMap::new();
116
117        let term_tokens: Vec<String> = terms.iter().flat_map(|t| tokenize(t)).collect();
118
119        for term_tok in &term_tokens {
120            // Exact match.
121            if let Some(entries) = self.index.get(term_tok) {
122                for (node_id, weight) in entries {
123                    *scores.entry(node_id.clone()).or_default() += weight;
124                }
125            }
126
127            // PERF: prefix match is O(vocabulary size). Acceptable for graphs up to ~10k nodes.
128            // A future optimization would use a sorted token list or BTreeMap for range scan.
129            for (token, entries) in &self.index {
130                if token != term_tok && token.starts_with(term_tok) {
131                    for (node_id, weight) in entries {
132                        *scores.entry(node_id.clone()).or_default() += weight * 0.5;
133                    }
134                }
135            }
136        }
137
138        let mut results: Vec<(f64, String)> = scores
139            .into_iter()
140            .map(|(node_id, score)| (score, node_id))
141            .collect();
142        results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
143        results
144    }
145}
146
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::model::{GraphNode, NodeType};
    use std::collections::HashMap;

    // -- helpers --

    fn make_node(id: &str, label: &str, source_file: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: label.into(),
            source_file: source_file.into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    /// Four unconnected nodes whose ids, labels and source paths share tokens.
    fn make_graph() -> KnowledgeGraph {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node(
            "auth_service",
            "AuthService",
            "src/auth/service.rs",
        ))
        .unwrap();
        g.add_node(make_node(
            "user_manager",
            "UserManager",
            "src/user/manager.rs",
        ))
        .unwrap();
        g.add_node(make_node("database_pool", "DatabasePool", "src/db/pool.rs"))
            .unwrap();
        g.add_node(make_node("cache_layer", "CacheLayer", "src/cache/layer.rs"))
            .unwrap();
        g
    }

    /// Four unconnected nodes with short ids: the edge-free twin of
    /// `make_graph_with_edges`, so the two differ only in node degree.
    fn make_small_graph() -> KnowledgeGraph {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node("auth", "AuthService", "src/auth.rs"))
            .unwrap();
        g.add_node(make_node("user", "UserManager", "src/user.rs"))
            .unwrap();
        g.add_node(make_node("db", "Database", "src/db.rs"))
            .unwrap();
        g.add_node(make_node("cache", "CacheLayer", "src/cache.rs"))
            .unwrap();
        g
    }

    /// `make_small_graph` plus a single `auth -> user` edge.
    fn make_graph_with_edges() -> KnowledgeGraph {
        use graphify_core::confidence::Confidence;
        use graphify_core::model::GraphEdge;

        let mut g = make_small_graph();
        let edge = GraphEdge {
            source: "auth".into(),
            target: "user".into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            provenance: None,
            extra: HashMap::new(),
        };
        g.add_edge(edge).unwrap();
        g
    }

    /// Score of node `id` in `results`, or `0.0` when the node is absent.
    fn score_for(results: &[(f64, String)], id: &str) -> f64 {
        results
            .iter()
            .find(|(_, node_id)| node_id == id)
            .map_or(0.0, |(score, _)| *score)
    }

    // -- tokenize tests --

    #[test]
    fn tokenize_camel_case() {
        let tokens = tokenize("camelCase");
        assert_eq!(tokens, vec!["camel", "case"]);
    }

    #[test]
    fn tokenize_underscore() {
        let tokens = tokenize("foo_bar_baz");
        assert_eq!(tokens, vec!["foo", "bar", "baz"]);
    }

    #[test]
    fn tokenize_dot_separator() {
        let tokens = tokenize("mod.rs");
        assert_eq!(tokens, vec!["mod", "rs"]);
    }

    #[test]
    fn tokenize_double_colon() {
        let tokens = tokenize("std::collections::HashMap");
        assert_eq!(tokens, vec!["std", "collections", "hash", "map"]);
    }

    #[test]
    fn tokenize_slash() {
        let tokens = tokenize("src/main/mod.rs");
        assert_eq!(tokens, vec!["src", "main", "mod", "rs"]);
    }

    #[test]
    fn tokenize_backslash() {
        let tokens = tokenize(r"src\main\mod.rs");
        assert_eq!(tokens, vec!["src", "main", "mod", "rs"]);
    }

    #[test]
    fn tokenize_hyphen() {
        let tokens = tokenize("my-component-name");
        assert_eq!(tokens, vec!["my", "component", "name"]);
    }

    #[test]
    fn tokenize_whitespace() {
        let tokens = tokenize("foo   bar\tbaz");
        assert_eq!(tokens, vec!["foo", "bar", "baz"]);
    }

    #[test]
    fn tokenize_mixed() {
        let tokens = tokenize("MyComponent_test.rs");
        assert_eq!(tokens, vec!["my", "component", "test", "rs"]);
    }

    #[test]
    fn tokenize_empty() {
        let tokens = tokenize("");
        assert!(tokens.is_empty());
    }

    #[test]
    fn tokenize_all_lowercase() {
        let tokens = tokenize("AuthService");
        assert!(tokens.iter().all(|t| t == &t.to_lowercase()));
    }

    #[test]
    fn tokenize_consecutive_uppercase() {
        // A split only happens on a non-uppercase -> uppercase transition, so a run of
        // capitals (an acronym) stays attached to the word that follows it.
        let tokens = tokenize("HTTPServer");
        assert_eq!(tokens, vec!["httpserver"]);
    }

    // -- SearchIndex::build tests --

    #[test]
    fn build_creates_index_entries() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);

        // Label "AuthService" and id "auth_service" both yield the token "auth".
        assert!(idx.index.contains_key("auth"));
        let entries = &idx.index["auth"];
        assert!(entries.iter().any(|(id, _)| id == "auth_service"));
    }

    #[test]
    fn build_label_weight_higher_than_id() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);

        // "service" comes from the label "AuthService" (base 2.0), the id
        // "auth_service" (base 1.0) and the path (base 0.5); there are no edges,
        // so no degree boost is added.
        let weights: Vec<f64> = idx.index["service"]
            .iter()
            .filter(|(id, _)| id == "auth_service")
            .map(|(_, w)| *w)
            .collect();
        assert!(
            weights.len() >= 2,
            "should have label and id entries for 'service'"
        );
        assert!(
            weights.iter().any(|w| *w >= 2.0),
            "at least one weight >= 2.0 (label), got {:?}",
            weights
        );
        assert!(
            weights.iter().any(|w| (1.0..2.0).contains(w)),
            "an id-derived weight in [1.0, 2.0) should exist, got {:?}",
            weights
        );
    }

    #[test]
    fn build_source_file_tokens() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);

        // "pool" from "src/db/pool.rs" source_file
        assert!(idx.index.contains_key("pool"));
    }

    // -- SearchIndex::search tests --

    #[test]
    fn search_exact_label_match() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);
        let results = idx.search(&["auth".to_string()]);

        assert!(!results.is_empty());
        // auth_service should appear (label "AuthService" -> "auth")
        assert!(results.iter().any(|(_, id)| id == "auth_service"));
    }

    #[test]
    fn search_exact_id_match() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);
        let results = idx.search(&["database".to_string()]);

        assert!(!results.is_empty());
        // "database" from id "database_pool"
        assert!(results.iter().any(|(_, id)| id == "database_pool"));
    }

    #[test]
    fn search_source_file_match() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);
        let results = idx.search(&["cache".to_string()]);

        assert!(!results.is_empty());
        // "cache" comes from the path "src/cache/layer.rs" as well as the label and id.
        assert!(results.iter().any(|(_, id)| id == "cache_layer"));
    }

    #[test]
    fn search_no_match() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);
        let results = idx.search(&["nonexistent_xyz".to_string()]);
        assert!(results.is_empty());
    }

    #[test]
    fn search_prefix_match() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);

        // No node has the token "use", but "use" is a proper prefix of "user"
        // (from label "UserManager" and id "user_manager").
        let results = idx.search(&["use".to_string()]);
        assert!(
            results.iter().any(|(_, id)| id == "user_manager"),
            "'use' should prefix-match 'user' from UserManager, got: {:?}",
            results
        );
    }

    #[test]
    fn search_prefix_is_anchored_at_token_start() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);

        // "ser" is a prefix of "service" but only an infix of "user", so only
        // auth_service may match.
        let results = idx.search(&["ser".to_string()]);
        let ids: Vec<&str> = results.iter().map(|(_, id)| id.as_str()).collect();
        assert_eq!(ids, vec!["auth_service"]);
    }

    #[test]
    fn search_prefix_lower_weight() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);

        let exact_score = score_for(&idx.search(&["user".to_string()]), "user_manager");
        let prefix_score = score_for(&idx.search(&["use".to_string()]), "user_manager");

        assert!(prefix_score > 0.0, "'use' should prefix-match user_manager");
        assert!(
            exact_score > prefix_score,
            "exact match ({}) should score higher than prefix match ({})",
            exact_score,
            prefix_score
        );
    }

    #[test]
    fn search_multiple_terms_aggregate() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);
        let results = idx.search(&["auth".to_string(), "service".to_string()]);

        assert!(!results.is_empty());
        // auth_service matches both terms, so it must rank first.
        let top = &results[0];
        assert_eq!(top.1, "auth_service");
        // ...and score more than it does for "auth" alone.
        let single_score = score_for(&idx.search(&["auth".to_string()]), "auth_service");
        assert!(
            top.0 > single_score,
            "two-term match ({}) should score higher than single-term ({})",
            top.0,
            single_score
        );
    }

    #[test]
    fn search_sorted_descending() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);
        let results = idx.search(&["service".to_string()]);

        for w in results.windows(2) {
            assert!(
                w[0].0 >= w[1].0,
                "results should be sorted descending by score"
            );
        }
    }

    #[test]
    fn search_degree_boost() {
        // Same nodes in both graphs; only the `auth -> user` edge differs, so any
        // score difference for "auth" comes from the degree boost.
        let idx_no = SearchIndex::build(&make_small_graph());
        let idx_with = SearchIndex::build(&make_graph_with_edges());

        let score_no = score_for(&idx_no.search(&["auth".to_string()]), "auth");
        let score_with = score_for(&idx_with.search(&["auth".to_string()]), "auth");

        assert!(
            score_no > 0.0,
            "node 'auth' should match even without edges"
        );
        assert!(
            score_with > score_no,
            "node with edges ({}) should score higher than without ({})",
            score_with,
            score_no
        );
    }

    #[test]
    fn search_empty_terms() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);
        let results = idx.search(&[]);
        assert!(results.is_empty());
    }

    #[test]
    fn search_case_insensitive_terms() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);
        let mut lower = idx.search(&["auth".to_string()]);
        let mut upper = idx.search(&["AUTH".to_string()]);

        // Tokenize lowercases input, so both queries must yield the same
        // (score, id) pairs; order by id so tie ordering cannot matter.
        lower.sort_by(|a, b| a.1.cmp(&b.1));
        upper.sort_by(|a, b| a.1.cmp(&b.1));
        assert!(!lower.is_empty());
        assert_eq!(lower, upper);
    }

    #[test]
    fn search_source_file_token_lower_than_label() {
        let g = make_graph();
        let idx = SearchIndex::build(&g);

        // "rs" only appears in source file paths (base 0.5), while "manager"
        // appears in user_manager's label, id and path (2.0 + 1.0 + 0.5).
        let rs_results = idx.search(&["rs".to_string()]);
        assert!(!rs_results.is_empty(), "should find 'rs' in source files");
        let label_results = idx.search(&["manager".to_string()]);
        assert!(!label_results.is_empty(), "should find 'manager' in labels");

        let rs_score = rs_results[0].0;
        let label_score = label_results[0].0;
        assert!(
            label_score > rs_score,
            "label token ({}) should score higher than source_file-only token ({})",
            label_score,
            rs_score
        );
    }
}