Skip to main content

graphify_serve/
lib.rs

1//! MCP server for graph queries.
2//!
3//! Provides graph traversal and scoring functions used by the query
4//! engine and MCP protocol server. Port of Python query tools.
5
6pub mod mcp;
7
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::path::Path;
10
11use graphify_core::graph::KnowledgeGraph;
12use serde_json::Value;
13use thiserror::Error;
14
15/// Errors from the server.
16#[derive(Debug, Error)]
17pub enum ServeError {
18    #[error("IO error: {0}")]
19    Io(#[from] std::io::Error),
20
21    #[error("graph load error: {0}")]
22    GraphLoad(String),
23
24    #[error("serialization error: {0}")]
25    Serialization(#[from] serde_json::Error),
26}
27
28/// Score nodes by relevance to search terms.
29///
30/// Returns `(score, node_id)` pairs sorted by descending score.
31/// Scoring: +2.0 for exact label match, +1.0 for label contains,
32/// +0.5 for id contains, plus a small degree-based boost.
33pub fn score_nodes(graph: &KnowledgeGraph, terms: &[String]) -> Vec<(f64, String)> {
34    let lower_terms: Vec<String> = terms.iter().map(|t| t.to_lowercase()).collect();
35
36    let mut scored = Vec::new();
37    for node_id in graph.node_ids() {
38        if let Some(node) = graph.get_node(&node_id) {
39            let label_lower = node.label.to_lowercase();
40            let id_lower = node.id.to_lowercase();
41
42            let mut score: f64 = 0.0;
43
44            for term in &lower_terms {
45                // Exact match in label
46                if label_lower == *term {
47                    score += 2.0;
48                } else if label_lower.contains(term.as_str()) {
49                    score += 1.0;
50                }
51
52                // Match in node ID
53                if id_lower.contains(term.as_str()) {
54                    score += 0.5;
55                }
56            }
57
58            if score > 0.0 {
59                // Boost by degree (well-connected nodes are more relevant)
60                let degree_boost = (graph.degree(&node_id) as f64).ln_1p() * 0.1;
61                score += degree_boost;
62                scored.push((score, node_id.clone()));
63            }
64        }
65    }
66
67    scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
68    scored
69}
70
71/// BFS traversal from start nodes up to a maximum depth.
72///
73/// Returns `(visited_nodes, edges_traversed)` where edges are `(source, target)` pairs.
74pub fn bfs(
75    graph: &KnowledgeGraph,
76    start: &[String],
77    depth: usize,
78) -> (Vec<String>, Vec<(String, String)>) {
79    let mut visited: HashSet<String> = HashSet::new();
80    let mut edges: Vec<(String, String)> = Vec::new();
81    let mut queue: VecDeque<(String, usize)> = VecDeque::new();
82
83    for s in start {
84        if graph.get_node(s).is_some() {
85            visited.insert(s.clone());
86            queue.push_back((s.clone(), 0));
87        }
88    }
89
90    while let Some((current, current_depth)) = queue.pop_front() {
91        if current_depth >= depth {
92            continue;
93        }
94
95        for neighbor_id in graph.neighbor_ids(&current) {
96            edges.push((current.clone(), neighbor_id.clone()));
97
98            if !visited.contains(&neighbor_id) {
99                visited.insert(neighbor_id.clone());
100                queue.push_back((neighbor_id, current_depth + 1));
101            }
102        }
103    }
104
105    let visited_vec: Vec<String> = visited.into_iter().collect();
106    (visited_vec, edges)
107}
108
109/// DFS traversal from start nodes up to a maximum depth.
110///
111/// Returns `(visited_nodes, edges_traversed)` where edges are `(source, target)` pairs.
112pub fn dfs(
113    graph: &KnowledgeGraph,
114    start: &[String],
115    depth: usize,
116) -> (Vec<String>, Vec<(String, String)>) {
117    let mut visited: HashSet<String> = HashSet::new();
118    let mut edges: Vec<(String, String)> = Vec::new();
119    let mut stack: Vec<(String, usize)> = Vec::new();
120
121    for s in start {
122        if graph.get_node(s).is_some() {
123            visited.insert(s.clone());
124            stack.push((s.clone(), 0));
125        }
126    }
127
128    while let Some((current, current_depth)) = stack.pop() {
129        if current_depth >= depth {
130            continue;
131        }
132
133        for neighbor_id in graph.neighbor_ids(&current) {
134            edges.push((current.clone(), neighbor_id.clone()));
135
136            if !visited.contains(&neighbor_id) {
137                visited.insert(neighbor_id.clone());
138                stack.push((neighbor_id, current_depth + 1));
139            }
140        }
141    }
142
143    let visited_vec: Vec<String> = visited.into_iter().collect();
144    (visited_vec, edges)
145}
146
147/// Convert a subgraph (set of nodes and edges) to a text representation
148/// suitable for LLM context windows.
149///
150/// Respects a `token_budget` (approximate: 1 token ≈ 4 chars).
151pub fn subgraph_to_text(
152    graph: &KnowledgeGraph,
153    nodes: &[String],
154    edges: &[(String, String)],
155    token_budget: usize,
156) -> String {
157    let char_budget = token_budget * 4;
158    let mut output = String::with_capacity(char_budget.min(64 * 1024));
159
160    // Header
161    output.push_str(&format!(
162        "=== Knowledge Graph Context ({} nodes, {} edges) ===\n\n",
163        nodes.len(),
164        edges.len()
165    ));
166
167    // Nodes section
168    output.push_str("## Nodes\n\n");
169    for node_id in nodes {
170        if output.len() >= char_budget {
171            output.push_str("\n... (truncated due to token budget)\n");
172            break;
173        }
174
175        if let Some(node) = graph.get_node(node_id) {
176            output.push_str(&format!(
177                "- **{}** [{}] (type: {:?}",
178                node.label, node.id, node.node_type
179            ));
180            if let Some(community) = node.community {
181                output.push_str(&format!(", community: {}", community));
182            }
183            output.push_str(&format!(", file: {})\n", node.source_file));
184        }
185    }
186
187    // Edges section
188    if output.len() < char_budget {
189        output.push_str("\n## Relationships\n\n");
190
191        // Deduplicate edges for display
192        let mut seen: HashSet<(&str, &str)> = HashSet::new();
193        for (src, tgt) in edges {
194            if output.len() >= char_budget {
195                output.push_str("\n... (truncated due to token budget)\n");
196                break;
197            }
198
199            if seen.insert((src.as_str(), tgt.as_str())) {
200                let src_label = graph.get_node(src).map(|n| n.label.as_str()).unwrap_or(src);
201                let tgt_label = graph.get_node(tgt).map(|n| n.label.as_str()).unwrap_or(tgt);
202                output.push_str(&format!("- {} -> {}\n", src_label, tgt_label));
203            }
204        }
205    }
206
207    output
208}
209
210/// Load a knowledge graph from a JSON file.
211pub fn load_graph(graph_path: &Path) -> Result<KnowledgeGraph, ServeError> {
212    let content = std::fs::read_to_string(graph_path)?;
213    let value: Value = serde_json::from_str(&content)?;
214    KnowledgeGraph::from_node_link_json(&value).map_err(|e| ServeError::GraphLoad(e.to_string()))
215}
216
217/// Get basic statistics about the graph.
218pub fn graph_stats(graph: &KnowledgeGraph) -> HashMap<String, Value> {
219    let mut stats = HashMap::new();
220    stats.insert("node_count".to_string(), Value::from(graph.node_count()));
221    stats.insert("edge_count".to_string(), Value::from(graph.edge_count()));
222    stats.insert(
223        "community_count".to_string(),
224        Value::from(graph.communities.len()),
225    );
226
227    // Degree statistics
228    let node_ids = graph.node_ids();
229    if !node_ids.is_empty() {
230        let degrees: Vec<usize> = node_ids.iter().map(|id| graph.degree(id)).collect();
231        let max_degree = degrees.iter().copied().max().unwrap_or(0);
232        let avg_degree = degrees.iter().sum::<usize>() as f64 / degrees.len() as f64;
233        stats.insert("max_degree".to_string(), Value::from(max_degree));
234        stats.insert(
235            "avg_degree".to_string(),
236            Value::from(format!("{:.2}", avg_degree)),
237        );
238    }
239
240    stats
241}
242
243/// Start the MCP server over stdio (JSON-RPC 2.0).
244///
245/// Reads requests from stdin, writes responses to stdout.
246/// This is the entry point called by the CLI `serve` command.
247pub async fn start_server(graph_path: &Path) -> Result<(), ServeError> {
248    // Run the synchronous stdio loop; use spawn_blocking so we don't
249    // block the tokio runtime (though for stdio this is fine).
250    let path = graph_path.to_path_buf();
251    tokio::task::spawn_blocking(move || mcp::run_mcp_server(&path))
252        .await
253        .map_err(|e| ServeError::Io(std::io::Error::other(e)))??;
254    Ok(())
255}
256
257// ---------------------------------------------------------------------------
258// Advanced graph algorithms
259// ---------------------------------------------------------------------------
260
261/// Find all simple paths between `source` and `target` up to `max_length` edges.
262///
263/// Returns a vec of paths, each path being a vec of node IDs.
264/// Limits to at most 50 paths to prevent combinatorial explosion.
265pub fn all_simple_paths(
266    graph: &KnowledgeGraph,
267    source: &str,
268    target: &str,
269    max_length: usize,
270) -> Vec<Vec<String>> {
271    const MAX_PATHS: usize = 50;
272    let mut result: Vec<Vec<String>> = Vec::new();
273    let mut stack: Vec<(String, Vec<String>)> = Vec::new();
274
275    if graph.get_node(source).is_none() || graph.get_node(target).is_none() {
276        return result;
277    }
278
279    stack.push((source.to_string(), vec![source.to_string()]));
280
281    while let Some((current, path)) = stack.pop() {
282        if result.len() >= MAX_PATHS {
283            break;
284        }
285        if current == target && path.len() > 1 {
286            result.push(path);
287            continue;
288        }
289        if path.len() > max_length + 1 {
290            continue;
291        }
292
293        for neighbor_id in graph.neighbor_ids(&current) {
294            if !path.contains(&neighbor_id) {
295                let mut new_path = path.clone();
296                new_path.push(neighbor_id.clone());
297                stack.push((neighbor_id, new_path));
298            }
299        }
300    }
301
302    result.sort_by_key(|p| p.len());
303    result
304}
305
306/// Edge detail in a weighted path: (from_id, to_id, cost, relation).
307pub type EdgeDetail = (String, String, f64, String);
308
309/// Dijkstra shortest path using edge weights.
310///
311/// Cost = 1.0 / edge.weight (higher weight = shorter distance).
312/// Optionally filters edges below `min_confidence` score.
313/// Returns `(path, total_cost, edge_details)` or None if no path exists.
314pub fn dijkstra_path(
315    graph: &KnowledgeGraph,
316    source: &str,
317    target: &str,
318    min_confidence: f64,
319) -> Option<(Vec<String>, f64, Vec<EdgeDetail>)> {
320    use std::cmp::Ordering;
321    use std::collections::BinaryHeap;
322
323    if graph.get_node(source).is_none() || graph.get_node(target).is_none() {
324        return None;
325    }
326    if source == target {
327        return Some((vec![source.to_string()], 0.0, Vec::new()));
328    }
329
330    // Build adjacency with weights from edges
331    let mut adj: HashMap<String, Vec<(String, f64, String)>> = HashMap::new();
332    for (src, tgt, edge) in graph.edges_with_endpoints() {
333        if edge.confidence_score < min_confidence {
334            continue;
335        }
336        let cost = if edge.weight > 0.0 {
337            1.0 / edge.weight
338        } else {
339            f64::MAX
340        };
341        adj.entry(src.to_string()).or_default().push((
342            tgt.to_string(),
343            cost,
344            edge.relation.clone(),
345        ));
346        adj.entry(tgt.to_string()).or_default().push((
347            src.to_string(),
348            cost,
349            edge.relation.clone(),
350        ));
351    }
352
353    #[derive(PartialEq)]
354    struct State {
355        cost: f64,
356        node: String,
357    }
358    impl Eq for State {}
359    impl PartialOrd for State {
360        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
361            Some(self.cmp(other))
362        }
363    }
364    impl Ord for State {
365        fn cmp(&self, other: &Self) -> Ordering {
366            other
367                .cost
368                .partial_cmp(&self.cost)
369                .unwrap_or(Ordering::Equal)
370        }
371    }
372
373    let mut dist: HashMap<String, f64> = HashMap::new();
374    let mut prev: HashMap<String, (String, f64, String)> = HashMap::new();
375    let mut heap = BinaryHeap::new();
376
377    dist.insert(source.to_string(), 0.0);
378    heap.push(State {
379        cost: 0.0,
380        node: source.to_string(),
381    });
382
383    while let Some(State { cost, node }) = heap.pop() {
384        if node == target {
385            break;
386        }
387        if cost > *dist.get(&node).unwrap_or(&f64::MAX) {
388            continue;
389        }
390        if let Some(neighbors) = adj.get(&node) {
391            for (next, edge_cost, relation) in neighbors {
392                let new_cost = cost + edge_cost;
393                if new_cost < *dist.get(next).unwrap_or(&f64::MAX) {
394                    dist.insert(next.clone(), new_cost);
395                    prev.insert(next.clone(), (node.clone(), *edge_cost, relation.clone()));
396                    heap.push(State {
397                        cost: new_cost,
398                        node: next.clone(),
399                    });
400                }
401            }
402        }
403    }
404
405    // Reconstruct path
406    if !prev.contains_key(target) {
407        return None;
408    }
409
410    let mut path = vec![target.to_string()];
411    let mut edge_details: Vec<(String, String, f64, String)> = Vec::new();
412    let mut current = target.to_string();
413    while let Some((from, cost, relation)) = prev.get(&current) {
414        edge_details.push((from.clone(), current.clone(), *cost, relation.clone()));
415        path.push(from.clone());
416        current = from.clone();
417    }
418    path.reverse();
419    edge_details.reverse();
420
421    let total_cost = *dist.get(target).unwrap_or(&f64::MAX);
422    Some((path, total_cost, edge_details))
423}
424
425// ---------------------------------------------------------------------------
426// Tests
427// ---------------------------------------------------------------------------
428
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};

    // -- fixtures --

    /// Turn string literals into the owned ids the query functions expect.
    fn ids(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// A class node living in `test.rs` with no community assigned.
    fn class_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: label.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    /// A fully trusted `calls` edge with unit weight.
    fn calls_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: HashMap::new(),
        }
    }

    /// Same as [`calls_edge`] but with an explicit confidence score.
    fn calls_edge_with_score(src: &str, tgt: &str, score: f64) -> GraphEdge {
        GraphEdge {
            confidence_score: score,
            ..calls_edge(src, tgt)
        }
    }

    /// Assemble a graph from `(id, label)` nodes and `(source, target)` edges.
    fn graph_of(nodes: &[(&str, &str)], edges: &[(&str, &str)]) -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        for (id, label) in nodes {
            graph.add_node(class_node(id, label)).unwrap();
        }
        for (src, tgt) in edges {
            graph.add_edge(calls_edge(src, tgt)).unwrap();
        }
        graph
    }

    /// Four services: auth and user both touch db, user also touches cache.
    fn sample_graph() -> KnowledgeGraph {
        graph_of(
            &[
                ("auth", "AuthService"),
                ("user", "UserManager"),
                ("db", "Database"),
                ("cache", "CacheLayer"),
            ],
            &[
                ("auth", "user"),
                ("auth", "db"),
                ("user", "db"),
                ("user", "cache"),
            ],
        )
    }

    /// Two nodes with no edge between them.
    fn disconnected_pair() -> KnowledgeGraph {
        graph_of(&[("a", "A"), ("b", "B")], &[])
    }

    // -- score_nodes tests --

    #[test]
    fn test_score_nodes_basic() {
        let graph = sample_graph();
        let ranked = score_nodes(&graph, &ids(&["auth"]));
        assert!(!ranked.is_empty());
        // The node whose id and label both mention "auth" must rank first.
        let (_, best) = &ranked[0];
        assert_eq!(best, "auth");
    }

    #[test]
    fn test_score_nodes_no_match() {
        let graph = sample_graph();
        assert!(score_nodes(&graph, &ids(&["nonexistent"])).is_empty());
    }

    #[test]
    fn test_score_nodes_multiple_terms() {
        let graph = sample_graph();
        let ranked = score_nodes(&graph, &ids(&["user", "manager"]));
        assert!(!ranked.is_empty());
        assert!(ranked.iter().any(|(_, id)| id == "user"));
    }

    // -- traversal tests --

    #[test]
    fn test_bfs_depth_0() {
        let graph = sample_graph();
        let (reached, walked) = bfs(&graph, &ids(&["auth"]), 0);
        assert_eq!(reached.len(), 1);
        assert!(walked.is_empty());
    }

    #[test]
    fn test_bfs_depth_1() {
        let graph = sample_graph();
        let (reached, walked) = bfs(&graph, &ids(&["auth"]), 1);
        // One hop from auth covers at least auth, user and db.
        assert!(reached.len() >= 3);
        assert!(!walked.is_empty());
    }

    #[test]
    fn test_bfs_depth_2() {
        let graph = sample_graph();
        let (reached, _) = bfs(&graph, &ids(&["auth"]), 2);
        // Two hops are enough to touch every node in the fixture.
        assert_eq!(reached.len(), 4);
    }

    #[test]
    fn test_dfs_depth_1() {
        let graph = sample_graph();
        let (reached, walked) = dfs(&graph, &ids(&["auth"]), 1);
        assert!(reached.len() >= 3);
        assert!(!walked.is_empty());
    }

    #[test]
    fn test_bfs_nonexistent_start() {
        let graph = sample_graph();
        let (reached, walked) = bfs(&graph, &ids(&["nonexistent"]), 3);
        assert!(reached.is_empty());
        assert!(walked.is_empty());
    }

    #[test]
    fn test_bfs_multiple_starts() {
        let graph = sample_graph();
        let (reached, _) = bfs(&graph, &ids(&["auth", "cache"]), 1);
        assert!(reached.len() >= 4);
    }

    // -- subgraph_to_text tests --

    #[test]
    fn test_subgraph_to_text() {
        let graph = sample_graph();
        let node_ids = ids(&["auth", "user"]);
        let edge_pairs = vec![("auth".to_string(), "user".to_string())];
        let rendered = subgraph_to_text(&graph, &node_ids, &edge_pairs, 1000);

        for expected in [
            "Knowledge Graph Context",
            "AuthService",
            "UserManager",
            "Relationships",
        ] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_subgraph_to_text_budget() {
        let graph = sample_graph();
        let node_ids: Vec<String> = graph.node_ids();
        let edge_pairs = vec![
            ("auth".to_string(), "user".to_string()),
            ("auth".to_string(), "db".to_string()),
        ];
        // A ten-token budget cannot hold the whole subgraph.
        let rendered = subgraph_to_text(&graph, &node_ids, &edge_pairs, 10);
        assert!(rendered.contains("truncated") || rendered.len() < 200);
    }

    // -- graph_stats tests --

    #[test]
    fn test_graph_stats() {
        let graph = sample_graph();
        let stats = graph_stats(&graph);
        assert_eq!(stats["node_count"], 4);
        assert_eq!(stats["edge_count"], 4);
    }

    // -- all_simple_paths tests --

    #[test]
    fn test_all_simple_paths_direct() {
        let graph = sample_graph();
        let paths = all_simple_paths(&graph, "auth", "user", 4);
        assert!(!paths.is_empty());
        // The direct edge auth → user shows up as a two-node path.
        assert!(paths.iter().any(|p| p.len() == 2));
    }

    #[test]
    fn test_all_simple_paths_indirect() {
        let graph = sample_graph();
        // Both auth → db (direct) and auth → user → db (via user) qualify.
        let paths = all_simple_paths(&graph, "auth", "db", 4);
        assert!(
            paths.len() >= 2,
            "should find multiple paths, got {}",
            paths.len()
        );
    }

    #[test]
    fn test_all_simple_paths_no_path() {
        let graph = disconnected_pair();
        assert!(all_simple_paths(&graph, "a", "b", 4).is_empty());
    }

    #[test]
    fn test_all_simple_paths_nonexistent_node() {
        let graph = sample_graph();
        assert!(all_simple_paths(&graph, "auth", "nonexistent", 4).is_empty());
    }

    #[test]
    fn test_all_simple_paths_sorted_by_length() {
        let graph = sample_graph();
        let paths = all_simple_paths(&graph, "auth", "cache", 5);
        for pair in paths.windows(2) {
            assert!(
                pair[0].len() <= pair[1].len(),
                "paths should be sorted by length"
            );
        }
    }

    // -- dijkstra_path tests --

    #[test]
    fn test_dijkstra_direct_path() {
        let graph = sample_graph();
        let found = dijkstra_path(&graph, "auth", "user", 0.0);
        assert!(found.is_some());
        let (path, cost, hops) = found.unwrap();
        assert_eq!(path.first().unwrap(), "auth");
        assert_eq!(path.last().unwrap(), "user");
        assert!(cost > 0.0);
        assert!(!hops.is_empty());
    }

    #[test]
    fn test_dijkstra_same_node() {
        let graph = sample_graph();
        let found = dijkstra_path(&graph, "auth", "auth", 0.0);
        assert!(found.is_some());
        let (path, cost, _) = found.unwrap();
        assert_eq!(path.len(), 1);
        assert!((cost - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_dijkstra_no_path() {
        let graph = disconnected_pair();
        assert!(dijkstra_path(&graph, "a", "b", 0.0).is_none());
    }

    #[test]
    fn test_dijkstra_nonexistent_node() {
        let graph = sample_graph();
        assert!(dijkstra_path(&graph, "auth", "nonexistent", 0.0).is_none());
    }

    #[test]
    fn test_dijkstra_min_confidence_filter() {
        // Triangle where the direct a→b edge is weak (0.3) and the detour
        // a→c→b is fully trusted (1.0).
        let mut graph = graph_of(&[("a", "A"), ("b", "B"), ("c", "C")], &[]);
        for edge in [
            calls_edge_with_score("a", "b", 0.3),
            calls_edge_with_score("a", "c", 1.0),
            calls_edge_with_score("c", "b", 1.0),
        ] {
            graph.add_edge(edge).unwrap();
        }

        // A 0.5 threshold drops a→b, leaving only the route through c.
        let found = dijkstra_path(&graph, "a", "b", 0.5);
        assert!(found.is_some());
        let (path, _, _) = found.unwrap();
        assert_eq!(path.len(), 3, "should go through c, got path: {path:?}");
    }
}