Skip to main content

graphmind_sdk/
algo.rs

1//! AlgorithmClient — extension trait for graph algorithms (EmbeddedClient only)
2//!
3//! Provides PageRank, community detection, pathfinding, and other graph algorithms
4//! via the `AlgorithmClient` trait. Only `EmbeddedClient` implements this trait
5//! since algorithms require direct in-process access to the graph store.
6
7use async_trait::async_trait;
8use std::collections::HashMap;
9
10use graphmind::algo::{
11    bfs, bfs_all_shortest_paths, build_view, cdlp, count_triangles, dijkstra, edmonds_karp,
12    local_clustering_coefficient, page_rank, pca, prim_mst, strongly_connected_components,
13    weakly_connected_components, CdlpConfig, CdlpResult, FlowResult, LccResult, MSTResult,
14    PageRankConfig, PathResult, PcaConfig, PcaResult, SccResult, WccResult,
15};
16use graphmind_graph_algorithms::GraphView;
17
18use crate::embedded::EmbeddedClient;
19
20/// Extension trait for graph algorithm operations.
21///
22/// Only implemented by `EmbeddedClient` since algorithms need direct store access.
23#[async_trait]
24pub trait AlgorithmClient {
25    /// Build a `GraphView` projection for algorithm execution.
26    ///
27    /// Optionally filter by node label, edge type, and extract edge weights.
28    async fn build_view(
29        &self,
30        label: Option<&str>,
31        edge_type: Option<&str>,
32        weight_prop: Option<&str>,
33    ) -> GraphView;
34
35    /// Run PageRank on the graph (or a subgraph filtered by label/edge_type).
36    async fn page_rank(
37        &self,
38        config: PageRankConfig,
39        label: Option<&str>,
40        edge_type: Option<&str>,
41    ) -> HashMap<u64, f64>;
42
43    /// Detect weakly connected components.
44    async fn weakly_connected_components(
45        &self,
46        label: Option<&str>,
47        edge_type: Option<&str>,
48    ) -> WccResult;
49
50    /// Detect strongly connected components.
51    async fn strongly_connected_components(
52        &self,
53        label: Option<&str>,
54        edge_type: Option<&str>,
55    ) -> SccResult;
56
57    /// Breadth-first search from source to target.
58    async fn bfs(
59        &self,
60        source: u64,
61        target: u64,
62        label: Option<&str>,
63        edge_type: Option<&str>,
64    ) -> Option<PathResult>;
65
66    /// Dijkstra's shortest path from source to target (weighted).
67    async fn dijkstra(
68        &self,
69        source: u64,
70        target: u64,
71        label: Option<&str>,
72        edge_type: Option<&str>,
73        weight_prop: Option<&str>,
74    ) -> Option<PathResult>;
75
76    /// Edmonds-Karp maximum flow from source to sink.
77    async fn edmonds_karp(
78        &self,
79        source: u64,
80        sink: u64,
81        label: Option<&str>,
82        edge_type: Option<&str>,
83    ) -> Option<FlowResult>;
84
85    /// Prim's minimum spanning tree.
86    async fn prim_mst(
87        &self,
88        label: Option<&str>,
89        edge_type: Option<&str>,
90        weight_prop: Option<&str>,
91    ) -> MSTResult;
92
93    /// Count triangles in the graph.
94    async fn count_triangles(&self, label: Option<&str>, edge_type: Option<&str>) -> usize;
95
96    /// Find all shortest paths between source and target (BFS).
97    async fn bfs_all_shortest_paths(
98        &self,
99        source: u64,
100        target: u64,
101        label: Option<&str>,
102        edge_type: Option<&str>,
103    ) -> Vec<PathResult>;
104
105    /// Community Detection via Label Propagation (CDLP).
106    async fn cdlp(
107        &self,
108        config: CdlpConfig,
109        label: Option<&str>,
110        edge_type: Option<&str>,
111    ) -> CdlpResult;
112
113    /// Local Clustering Coefficient for all nodes.
114    async fn local_clustering_coefficient(
115        &self,
116        label: Option<&str>,
117        edge_type: Option<&str>,
118    ) -> LccResult;
119
120    /// Principal Component Analysis on node numeric properties.
121    ///
122    /// Extracts the specified numeric properties from nodes matching `label`,
123    /// builds a feature matrix, and runs PCA with the given config.
124    async fn pca(&self, label: Option<&str>, properties: &[&str], config: PcaConfig) -> PcaResult;
125}
126
127#[async_trait]
128impl AlgorithmClient for EmbeddedClient {
129    async fn build_view(
130        &self,
131        label: Option<&str>,
132        edge_type: Option<&str>,
133        weight_prop: Option<&str>,
134    ) -> GraphView {
135        let store = self.store.read().await;
136        build_view(&store, label, edge_type, weight_prop)
137    }
138
139    async fn page_rank(
140        &self,
141        config: PageRankConfig,
142        label: Option<&str>,
143        edge_type: Option<&str>,
144    ) -> HashMap<u64, f64> {
145        let store = self.store.read().await;
146        let view = build_view(&store, label, edge_type, None);
147        page_rank(&view, config)
148    }
149
150    async fn weakly_connected_components(
151        &self,
152        label: Option<&str>,
153        edge_type: Option<&str>,
154    ) -> WccResult {
155        let store = self.store.read().await;
156        let view = build_view(&store, label, edge_type, None);
157        weakly_connected_components(&view)
158    }
159
160    async fn strongly_connected_components(
161        &self,
162        label: Option<&str>,
163        edge_type: Option<&str>,
164    ) -> SccResult {
165        let store = self.store.read().await;
166        let view = build_view(&store, label, edge_type, None);
167        strongly_connected_components(&view)
168    }
169
170    async fn bfs(
171        &self,
172        source: u64,
173        target: u64,
174        label: Option<&str>,
175        edge_type: Option<&str>,
176    ) -> Option<PathResult> {
177        let store = self.store.read().await;
178        let view = build_view(&store, label, edge_type, None);
179        bfs(&view, source, target)
180    }
181
182    async fn dijkstra(
183        &self,
184        source: u64,
185        target: u64,
186        label: Option<&str>,
187        edge_type: Option<&str>,
188        weight_prop: Option<&str>,
189    ) -> Option<PathResult> {
190        let store = self.store.read().await;
191        let view = build_view(&store, label, edge_type, weight_prop);
192        dijkstra(&view, source, target)
193    }
194
195    async fn edmonds_karp(
196        &self,
197        source: u64,
198        sink: u64,
199        label: Option<&str>,
200        edge_type: Option<&str>,
201    ) -> Option<FlowResult> {
202        let store = self.store.read().await;
203        let view = build_view(&store, label, edge_type, None);
204        edmonds_karp(&view, source, sink)
205    }
206
207    async fn prim_mst(
208        &self,
209        label: Option<&str>,
210        edge_type: Option<&str>,
211        weight_prop: Option<&str>,
212    ) -> MSTResult {
213        let store = self.store.read().await;
214        let view = build_view(&store, label, edge_type, weight_prop);
215        prim_mst(&view)
216    }
217
218    async fn count_triangles(&self, label: Option<&str>, edge_type: Option<&str>) -> usize {
219        let store = self.store.read().await;
220        let view = build_view(&store, label, edge_type, None);
221        count_triangles(&view)
222    }
223
224    async fn bfs_all_shortest_paths(
225        &self,
226        source: u64,
227        target: u64,
228        label: Option<&str>,
229        edge_type: Option<&str>,
230    ) -> Vec<PathResult> {
231        let store = self.store.read().await;
232        let view = build_view(&store, label, edge_type, None);
233        bfs_all_shortest_paths(&view, source, target)
234    }
235
236    async fn cdlp(
237        &self,
238        config: CdlpConfig,
239        label: Option<&str>,
240        edge_type: Option<&str>,
241    ) -> CdlpResult {
242        let store = self.store.read().await;
243        let view = build_view(&store, label, edge_type, None);
244        cdlp(&view, &config)
245    }
246
247    async fn local_clustering_coefficient(
248        &self,
249        label: Option<&str>,
250        edge_type: Option<&str>,
251    ) -> LccResult {
252        let store = self.store.read().await;
253        let view = build_view(&store, label, edge_type, None);
254        local_clustering_coefficient(&view)
255    }
256
257    async fn pca(&self, label: Option<&str>, properties: &[&str], config: PcaConfig) -> PcaResult {
258        let store = self.store.read().await;
259        use graphmind::graph::{Label, PropertyValue};
260
261        // Collect nodes matching the label filter
262        let nodes: Vec<_> = if let Some(label_str) = label {
263            let l = Label::new(label_str);
264            store.get_nodes_by_label(&l).into_iter().collect()
265        } else {
266            store.all_nodes().into_iter().collect()
267        };
268
269        // Build feature matrix: flat allocation then reshape into rows
270        let n = nodes.len();
271        let d = properties.len();
272        let mut data_flat = vec![0.0f64; n * d];
273        for (i, node) in nodes.iter().enumerate() {
274            for (j, &prop) in properties.iter().enumerate() {
275                data_flat[i * d + j] = match node.get_property(prop) {
276                    Some(PropertyValue::Integer(v)) => *v as f64,
277                    Some(PropertyValue::Float(v)) => *v,
278                    _ => 0.0,
279                };
280            }
281        }
282        let data: Vec<Vec<f64>> = data_flat.chunks_exact(d).map(|c| c.to_vec()).collect();
283
284        if data.is_empty() {
285            return PcaResult {
286                components: vec![],
287                explained_variance: vec![],
288                explained_variance_ratio: vec![],
289                mean: vec![0.0; properties.len()],
290                std_dev: vec![1.0; properties.len()],
291                n_samples: 0,
292                n_features: properties.len(),
293                iterations_used: 0,
294            };
295        }
296
297        pca(&data, config)
298    }
299}
300
301#[cfg(test)]
302mod tests {
303    use super::*;
304    use crate::{EmbeddedClient, GraphmindClient};
305
306    #[tokio::test]
307    async fn test_page_rank() {
308        let client = EmbeddedClient::new();
309
310        // Create a small graph: A -> B -> C, A -> C
311        client
312            .query("default", r#"CREATE (a:Person {name: "Alice"})"#)
313            .await
314            .unwrap();
315        client
316            .query("default", r#"CREATE (b:Person {name: "Bob"})"#)
317            .await
318            .unwrap();
319        client
320            .query("default", r#"CREATE (c:Person {name: "Carol"})"#)
321            .await
322            .unwrap();
323        client.query("default",
324            r#"MATCH (a:Person {name: "Alice"}), (b:Person {name: "Bob"}) CREATE (a)-[:KNOWS]->(b)"#
325        ).await.unwrap();
326        client.query("default",
327            r#"MATCH (b:Person {name: "Bob"}), (c:Person {name: "Carol"}) CREATE (b)-[:KNOWS]->(c)"#
328        ).await.unwrap();
329        client.query("default",
330            r#"MATCH (a:Person {name: "Alice"}), (c:Person {name: "Carol"}) CREATE (a)-[:KNOWS]->(c)"#
331        ).await.unwrap();
332
333        let scores = client
334            .page_rank(PageRankConfig::default(), Some("Person"), Some("KNOWS"))
335            .await;
336        assert_eq!(scores.len(), 3);
337        // Carol should have highest PageRank (most incoming links)
338        let max_node = scores
339            .iter()
340            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
341            .unwrap();
342        // At least verify we got results
343        assert!(*max_node.1 > 0.0);
344    }
345
346    #[tokio::test]
347    async fn test_wcc() {
348        let client = EmbeddedClient::new();
349
350        // Two disconnected components
351        client
352            .query(
353                "default",
354                r#"CREATE (a:Person {name: "Alice"})-[:KNOWS]->(b:Person {name: "Bob"})"#,
355            )
356            .await
357            .unwrap();
358        client
359            .query(
360                "default",
361                r#"CREATE (c:Person {name: "Carol"})-[:KNOWS]->(d:Person {name: "Dave"})"#,
362            )
363            .await
364            .unwrap();
365
366        let wcc = client
367            .weakly_connected_components(Some("Person"), Some("KNOWS"))
368            .await;
369        assert_eq!(wcc.components.len(), 2);
370    }
371
372    #[tokio::test]
373    async fn test_bfs() {
374        let client = EmbeddedClient::new();
375
376        client
377            .query(
378                "default",
379                r#"CREATE (a:Person {name: "Alice"})-[:KNOWS]->(b:Person {name: "Bob"})"#,
380            )
381            .await
382            .unwrap();
383        client
384            .query(
385                "default",
386                r#"MATCH (b:Person {name: "Bob"}) CREATE (b)-[:KNOWS]->(c:Person {name: "Carol"})"#,
387            )
388            .await
389            .unwrap();
390
391        // Get node IDs
392        let store = client.store().read().await;
393        let all_nodes: Vec<_> = store.all_nodes().iter().map(|n| n.id.as_u64()).collect();
394        drop(store);
395
396        if all_nodes.len() >= 3 {
397            let result = client
398                .bfs(all_nodes[0], all_nodes[2], Some("Person"), Some("KNOWS"))
399                .await;
400            assert!(result.is_some());
401            let path = result.unwrap();
402            assert!(path.path.len() >= 2);
403        }
404    }
405}