Skip to main content

graphify_serve/
lib.rs

1//! MCP server for graph queries.
2//!
3//! Provides graph traversal and scoring functions used by the query
4//! engine and MCP protocol server. Port of Python query tools.
5
6pub mod mcp;
7
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::path::Path;
10
11use graphify_core::graph::KnowledgeGraph;
12use serde_json::Value;
13use thiserror::Error;
14
15/// Errors from the server.
16#[derive(Debug, Error)]
17pub enum ServeError {
18    #[error("IO error: {0}")]
19    Io(#[from] std::io::Error),
20
21    #[error("graph load error: {0}")]
22    GraphLoad(String),
23
24    #[error("serialization error: {0}")]
25    Serialization(#[from] serde_json::Error),
26}
27
28/// Score nodes by relevance to search terms.
29///
30/// Returns `(score, node_id)` pairs sorted by descending score.
31/// Scoring: +2.0 for exact label match, +1.0 for label contains,
32/// +0.5 for id contains, plus a small degree-based boost.
33pub fn score_nodes(graph: &KnowledgeGraph, terms: &[String]) -> Vec<(f64, String)> {
34    let lower_terms: Vec<String> = terms.iter().map(|t| t.to_lowercase()).collect();
35
36    let mut scored = Vec::new();
37    for node_id in graph.node_ids() {
38        if let Some(node) = graph.get_node(&node_id) {
39            let label_lower = node.label.to_lowercase();
40            let id_lower = node.id.to_lowercase();
41
42            let mut score: f64 = 0.0;
43
44            for term in &lower_terms {
45                // Exact match in label
46                if label_lower == *term {
47                    score += 2.0;
48                } else if label_lower.contains(term.as_str()) {
49                    score += 1.0;
50                }
51
52                // Match in node ID
53                if id_lower.contains(term.as_str()) {
54                    score += 0.5;
55                }
56            }
57
58            if score > 0.0 {
59                // Boost by degree (well-connected nodes are more relevant)
60                let degree_boost = (graph.degree(&node_id) as f64).ln_1p() * 0.1;
61                score += degree_boost;
62                scored.push((score, node_id.clone()));
63            }
64        }
65    }
66
67    scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
68    scored
69}
70
71/// BFS traversal from start nodes up to a maximum depth.
72///
73/// Returns `(visited_nodes, edges_traversed)` where edges are `(source, target)` pairs.
74pub fn bfs(
75    graph: &KnowledgeGraph,
76    start: &[String],
77    depth: usize,
78) -> (Vec<String>, Vec<(String, String)>) {
79    let mut visited: HashSet<String> = HashSet::new();
80    let mut edges: Vec<(String, String)> = Vec::new();
81    let mut queue: VecDeque<(String, usize)> = VecDeque::new();
82
83    for s in start {
84        if graph.get_node(s).is_some() {
85            visited.insert(s.clone());
86            queue.push_back((s.clone(), 0));
87        }
88    }
89
90    while let Some((current, current_depth)) = queue.pop_front() {
91        if current_depth >= depth {
92            continue;
93        }
94
95        for neighbor_id in graph.neighbor_ids(&current) {
96            edges.push((current.clone(), neighbor_id.clone()));
97
98            if !visited.contains(&neighbor_id) {
99                visited.insert(neighbor_id.clone());
100                queue.push_back((neighbor_id, current_depth + 1));
101            }
102        }
103    }
104
105    let visited_vec: Vec<String> = visited.into_iter().collect();
106    (visited_vec, edges)
107}
108
109/// DFS traversal from start nodes up to a maximum depth.
110///
111/// Returns `(visited_nodes, edges_traversed)` where edges are `(source, target)` pairs.
112pub fn dfs(
113    graph: &KnowledgeGraph,
114    start: &[String],
115    depth: usize,
116) -> (Vec<String>, Vec<(String, String)>) {
117    let mut visited: HashSet<String> = HashSet::new();
118    let mut edges: Vec<(String, String)> = Vec::new();
119    let mut stack: Vec<(String, usize)> = Vec::new();
120
121    for s in start {
122        if graph.get_node(s).is_some() {
123            visited.insert(s.clone());
124            stack.push((s.clone(), 0));
125        }
126    }
127
128    while let Some((current, current_depth)) = stack.pop() {
129        if current_depth >= depth {
130            continue;
131        }
132
133        for neighbor_id in graph.neighbor_ids(&current) {
134            edges.push((current.clone(), neighbor_id.clone()));
135
136            if !visited.contains(&neighbor_id) {
137                visited.insert(neighbor_id.clone());
138                stack.push((neighbor_id, current_depth + 1));
139            }
140        }
141    }
142
143    let visited_vec: Vec<String> = visited.into_iter().collect();
144    (visited_vec, edges)
145}
146
147/// Convert a subgraph (set of nodes and edges) to a text representation
148/// suitable for LLM context windows.
149///
150/// Respects a `token_budget` (approximate: 1 token ≈ 4 chars).
151pub fn subgraph_to_text(
152    graph: &KnowledgeGraph,
153    nodes: &[String],
154    edges: &[(String, String)],
155    token_budget: usize,
156) -> String {
157    let char_budget = token_budget * 4;
158    let mut output = String::with_capacity(char_budget.min(64 * 1024));
159
160    // Header
161    output.push_str(&format!(
162        "=== Knowledge Graph Context ({} nodes, {} edges) ===\n\n",
163        nodes.len(),
164        edges.len()
165    ));
166
167    // Nodes section
168    output.push_str("## Nodes\n\n");
169    for node_id in nodes {
170        if output.len() >= char_budget {
171            output.push_str("\n... (truncated due to token budget)\n");
172            break;
173        }
174
175        if let Some(node) = graph.get_node(node_id) {
176            output.push_str(&format!(
177                "- **{}** [{}] (type: {:?}",
178                node.label, node.id, node.node_type
179            ));
180            if let Some(community) = node.community {
181                output.push_str(&format!(", community: {}", community));
182            }
183            output.push_str(&format!(", file: {})\n", node.source_file));
184        }
185    }
186
187    // Edges section
188    if output.len() < char_budget {
189        output.push_str("\n## Relationships\n\n");
190
191        // Deduplicate edges for display
192        let mut seen: HashSet<(&str, &str)> = HashSet::new();
193        for (src, tgt) in edges {
194            if output.len() >= char_budget {
195                output.push_str("\n... (truncated due to token budget)\n");
196                break;
197            }
198
199            if seen.insert((src.as_str(), tgt.as_str())) {
200                let src_label = graph.get_node(src).map(|n| n.label.as_str()).unwrap_or(src);
201                let tgt_label = graph.get_node(tgt).map(|n| n.label.as_str()).unwrap_or(tgt);
202                output.push_str(&format!("- {} -> {}\n", src_label, tgt_label));
203            }
204        }
205    }
206
207    output
208}
209
210/// Load a knowledge graph from a JSON file.
211pub fn load_graph(graph_path: &Path) -> Result<KnowledgeGraph, ServeError> {
212    let content = std::fs::read_to_string(graph_path)?;
213    let value: Value = serde_json::from_str(&content)?;
214    KnowledgeGraph::from_node_link_json(&value).map_err(|e| ServeError::GraphLoad(e.to_string()))
215}
216
217/// Get basic statistics about the graph.
218pub fn graph_stats(graph: &KnowledgeGraph) -> HashMap<String, Value> {
219    let mut stats = HashMap::new();
220    stats.insert("node_count".to_string(), Value::from(graph.node_count()));
221    stats.insert("edge_count".to_string(), Value::from(graph.edge_count()));
222    stats.insert(
223        "community_count".to_string(),
224        Value::from(graph.communities.len()),
225    );
226
227    // Degree statistics
228    let node_ids = graph.node_ids();
229    if !node_ids.is_empty() {
230        let degrees: Vec<usize> = node_ids.iter().map(|id| graph.degree(id)).collect();
231        let max_degree = degrees.iter().copied().max().unwrap_or(0);
232        let avg_degree = degrees.iter().sum::<usize>() as f64 / degrees.len() as f64;
233        stats.insert("max_degree".to_string(), Value::from(max_degree));
234        stats.insert(
235            "avg_degree".to_string(),
236            Value::from(format!("{:.2}", avg_degree)),
237        );
238    }
239
240    stats
241}
242
243/// Start the MCP server over stdio (JSON-RPC 2.0).
244///
245/// Reads requests from stdin, writes responses to stdout.
246/// This is the entry point called by the CLI `serve` command.
247pub async fn start_server(graph_path: &Path) -> Result<(), ServeError> {
248    // Run the synchronous stdio loop; use spawn_blocking so we don't
249    // block the tokio runtime (though for stdio this is fine).
250    let path = graph_path.to_path_buf();
251    tokio::task::spawn_blocking(move || mcp::run_mcp_server(&path))
252        .await
253        .map_err(|e| ServeError::Io(std::io::Error::other(e)))??;
254    Ok(())
255}
256
257// ---------------------------------------------------------------------------
258// Tests
259// ---------------------------------------------------------------------------
260
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};

    /// Build a `Class` node in `test.rs` with the given id and label.
    fn make_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: label.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    /// Build an extracted `calls` edge from `src` to `tgt` with unit weight.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: HashMap::new(),
        }
    }

    /// Four-node fixture: auth -> {user, db}, user -> {db, cache}.
    fn make_test_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        for (id, label) in [
            ("auth", "AuthService"),
            ("user", "UserManager"),
            ("db", "Database"),
            ("cache", "CacheLayer"),
        ] {
            graph.add_node(make_node(id, label)).unwrap();
        }
        for (src, tgt) in [
            ("auth", "user"),
            ("auth", "db"),
            ("user", "db"),
            ("user", "cache"),
        ] {
            graph.add_edge(make_edge(src, tgt)).unwrap();
        }
        graph
    }

    /// Turn string literals into the owned `Vec<String>` the API expects.
    fn owned(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn test_score_nodes_basic() {
        let graph = make_test_graph();
        let ranked = score_nodes(&graph, &owned(&["auth"]));
        assert!(!ranked.is_empty());
        // "auth" node should score highest
        let (_, best_id) = &ranked[0];
        assert_eq!(best_id, "auth");
    }

    #[test]
    fn test_score_nodes_no_match() {
        let graph = make_test_graph();
        let ranked = score_nodes(&graph, &owned(&["nonexistent"]));
        assert!(ranked.is_empty());
    }

    #[test]
    fn test_score_nodes_multiple_terms() {
        let graph = make_test_graph();
        let ranked = score_nodes(&graph, &owned(&["user", "manager"]));
        assert!(!ranked.is_empty());
        assert!(ranked.iter().any(|(_, id)| id == "user"));
    }

    #[test]
    fn test_bfs_depth_0() {
        let graph = make_test_graph();
        let (reached, traversed) = bfs(&graph, &owned(&["auth"]), 0);
        assert_eq!(reached.len(), 1);
        assert!(traversed.is_empty());
    }

    #[test]
    fn test_bfs_depth_1() {
        let graph = make_test_graph();
        let (reached, traversed) = bfs(&graph, &owned(&["auth"]), 1);
        // auth -> user, auth -> db
        assert!(reached.len() >= 3); // auth, user, db
        assert!(!traversed.is_empty());
    }

    #[test]
    fn test_bfs_depth_2() {
        let graph = make_test_graph();
        let (reached, _traversed) = bfs(&graph, &owned(&["auth"]), 2);
        // Should reach all 4 nodes
        assert_eq!(reached.len(), 4);
    }

    #[test]
    fn test_dfs_depth_1() {
        let graph = make_test_graph();
        let (reached, traversed) = dfs(&graph, &owned(&["auth"]), 1);
        assert!(reached.len() >= 3);
        assert!(!traversed.is_empty());
    }

    #[test]
    fn test_bfs_nonexistent_start() {
        let graph = make_test_graph();
        let (reached, traversed) = bfs(&graph, &owned(&["nonexistent"]), 3);
        assert!(reached.is_empty());
        assert!(traversed.is_empty());
    }

    #[test]
    fn test_subgraph_to_text() {
        let graph = make_test_graph();
        let nodes = owned(&["auth", "user"]);
        let edges = vec![("auth".to_string(), "user".to_string())];
        let rendered = subgraph_to_text(&graph, &nodes, &edges, 1000);

        for expected in [
            "Knowledge Graph Context",
            "AuthService",
            "UserManager",
            "Relationships",
        ] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_subgraph_to_text_budget() {
        let graph = make_test_graph();
        let nodes: Vec<String> = graph.node_ids();
        let edges = vec![
            ("auth".to_string(), "user".to_string()),
            ("auth".to_string(), "db".to_string()),
        ];
        // Very small budget
        let rendered = subgraph_to_text(&graph, &nodes, &edges, 10);
        assert!(rendered.contains("truncated") || rendered.len() < 200);
    }

    #[test]
    fn test_graph_stats() {
        let graph = make_test_graph();
        let stats = graph_stats(&graph);
        assert_eq!(stats["node_count"], 4);
        assert_eq!(stats["edge_count"], 4);
    }

    #[test]
    fn test_bfs_multiple_starts() {
        let graph = make_test_graph();
        let (reached, _) = bfs(&graph, &owned(&["auth", "cache"]), 1);
        assert!(reached.len() >= 4);
    }
}