1use async_trait::async_trait;
8use std::collections::HashMap;
9
10use graphmind::algo::{
11 bfs, bfs_all_shortest_paths, build_view, cdlp, count_triangles, dijkstra, edmonds_karp,
12 local_clustering_coefficient, page_rank, pca, prim_mst, strongly_connected_components,
13 weakly_connected_components, CdlpConfig, CdlpResult, FlowResult, LccResult, MSTResult,
14 PageRankConfig, PathResult, PcaConfig, PcaResult, SccResult, WccResult,
15};
16use graphmind_graph_algorithms::GraphView;
17
18use crate::embedded::EmbeddedClient;
19
/// Graph-algorithm operations exposed by a Graphmind client.
///
/// Each method works on a projection of the graph selected by an optional node
/// `label` (e.g. `"Person"`) and an optional `edge_type` (e.g. `"KNOWS"`).
/// NOTE(review): `None` presumably means "no filter" (the embedded `pca`
/// implementation selects all nodes for `None`) — confirm against
/// `graphmind::algo::build_view`.
#[async_trait]
pub trait AlgorithmClient {
    /// Builds the [`GraphView`] projection that the other algorithms run on.
    ///
    /// `weight_prop` names the edge property used as the edge weight;
    /// unweighted algorithms in the embedded implementation pass `None`.
    async fn build_view(
        &self,
        label: Option<&str>,
        edge_type: Option<&str>,
        weight_prop: Option<&str>,
    ) -> GraphView;

    /// Computes PageRank scores for the projected graph.
    ///
    /// The result map is keyed by node id (as `u64`), with one entry per node
    /// in the projection.
    async fn page_rank(
        &self,
        config: PageRankConfig,
        label: Option<&str>,
        edge_type: Option<&str>,
    ) -> HashMap<u64, f64>;

    /// Computes the weakly connected components of the projected graph.
    async fn weakly_connected_components(
        &self,
        label: Option<&str>,
        edge_type: Option<&str>,
    ) -> WccResult;

    /// Computes the strongly connected components of the projected graph.
    async fn strongly_connected_components(
        &self,
        label: Option<&str>,
        edge_type: Option<&str>,
    ) -> SccResult;

    /// Finds an unweighted shortest path from `source` to `target` by BFS.
    ///
    /// Returns `None` presumably when `target` is unreachable — confirm in
    /// `graphmind::algo::bfs`.
    async fn bfs(
        &self,
        source: u64,
        target: u64,
        label: Option<&str>,
        edge_type: Option<&str>,
    ) -> Option<PathResult>;

    /// Finds a weighted shortest path from `source` to `target` with
    /// Dijkstra's algorithm, using `weight_prop` as the edge weight.
    async fn dijkstra(
        &self,
        source: u64,
        target: u64,
        label: Option<&str>,
        edge_type: Option<&str>,
        weight_prop: Option<&str>,
    ) -> Option<PathResult>;

    /// Computes a maximum flow from `source` to `sink` (Edmonds–Karp).
    ///
    /// NOTE(review): the embedded implementation builds its view without a
    /// weight property — confirm how edge capacities are derived.
    async fn edmonds_karp(
        &self,
        source: u64,
        sink: u64,
        label: Option<&str>,
        edge_type: Option<&str>,
    ) -> Option<FlowResult>;

    /// Computes a minimum spanning tree with Prim's algorithm, using
    /// `weight_prop` as the edge weight.
    async fn prim_mst(
        &self,
        label: Option<&str>,
        edge_type: Option<&str>,
        weight_prop: Option<&str>,
    ) -> MSTResult;

    /// Counts the triangles in the projected graph.
    async fn count_triangles(&self, label: Option<&str>, edge_type: Option<&str>) -> usize;

    /// Returns every shortest (fewest-hop) path from `source` to `target`.
    async fn bfs_all_shortest_paths(
        &self,
        source: u64,
        target: u64,
        label: Option<&str>,
        edge_type: Option<&str>,
    ) -> Vec<PathResult>;

    /// Runs community detection by label propagation (CDLP) with `config`.
    async fn cdlp(
        &self,
        config: CdlpConfig,
        label: Option<&str>,
        edge_type: Option<&str>,
    ) -> CdlpResult;

    /// Computes the local clustering coefficient of the projected graph's nodes.
    async fn local_clustering_coefficient(
        &self,
        label: Option<&str>,
        edge_type: Option<&str>,
    ) -> LccResult;

    /// Runs principal component analysis over node properties.
    ///
    /// Each selected node (filtered by `label`) contributes one sample whose
    /// features are the values of `properties`, in order. The embedded
    /// implementation converts integer and float properties to `f64` and
    /// uses `0.0` for missing or non-numeric values.
    async fn pca(&self, label: Option<&str>, properties: &[&str], config: PcaConfig) -> PcaResult;
}
126
127#[async_trait]
128impl AlgorithmClient for EmbeddedClient {
129 async fn build_view(
130 &self,
131 label: Option<&str>,
132 edge_type: Option<&str>,
133 weight_prop: Option<&str>,
134 ) -> GraphView {
135 let store = self.store.read().await;
136 build_view(&store, label, edge_type, weight_prop)
137 }
138
139 async fn page_rank(
140 &self,
141 config: PageRankConfig,
142 label: Option<&str>,
143 edge_type: Option<&str>,
144 ) -> HashMap<u64, f64> {
145 let store = self.store.read().await;
146 let view = build_view(&store, label, edge_type, None);
147 page_rank(&view, config)
148 }
149
150 async fn weakly_connected_components(
151 &self,
152 label: Option<&str>,
153 edge_type: Option<&str>,
154 ) -> WccResult {
155 let store = self.store.read().await;
156 let view = build_view(&store, label, edge_type, None);
157 weakly_connected_components(&view)
158 }
159
160 async fn strongly_connected_components(
161 &self,
162 label: Option<&str>,
163 edge_type: Option<&str>,
164 ) -> SccResult {
165 let store = self.store.read().await;
166 let view = build_view(&store, label, edge_type, None);
167 strongly_connected_components(&view)
168 }
169
170 async fn bfs(
171 &self,
172 source: u64,
173 target: u64,
174 label: Option<&str>,
175 edge_type: Option<&str>,
176 ) -> Option<PathResult> {
177 let store = self.store.read().await;
178 let view = build_view(&store, label, edge_type, None);
179 bfs(&view, source, target)
180 }
181
182 async fn dijkstra(
183 &self,
184 source: u64,
185 target: u64,
186 label: Option<&str>,
187 edge_type: Option<&str>,
188 weight_prop: Option<&str>,
189 ) -> Option<PathResult> {
190 let store = self.store.read().await;
191 let view = build_view(&store, label, edge_type, weight_prop);
192 dijkstra(&view, source, target)
193 }
194
195 async fn edmonds_karp(
196 &self,
197 source: u64,
198 sink: u64,
199 label: Option<&str>,
200 edge_type: Option<&str>,
201 ) -> Option<FlowResult> {
202 let store = self.store.read().await;
203 let view = build_view(&store, label, edge_type, None);
204 edmonds_karp(&view, source, sink)
205 }
206
207 async fn prim_mst(
208 &self,
209 label: Option<&str>,
210 edge_type: Option<&str>,
211 weight_prop: Option<&str>,
212 ) -> MSTResult {
213 let store = self.store.read().await;
214 let view = build_view(&store, label, edge_type, weight_prop);
215 prim_mst(&view)
216 }
217
218 async fn count_triangles(&self, label: Option<&str>, edge_type: Option<&str>) -> usize {
219 let store = self.store.read().await;
220 let view = build_view(&store, label, edge_type, None);
221 count_triangles(&view)
222 }
223
224 async fn bfs_all_shortest_paths(
225 &self,
226 source: u64,
227 target: u64,
228 label: Option<&str>,
229 edge_type: Option<&str>,
230 ) -> Vec<PathResult> {
231 let store = self.store.read().await;
232 let view = build_view(&store, label, edge_type, None);
233 bfs_all_shortest_paths(&view, source, target)
234 }
235
236 async fn cdlp(
237 &self,
238 config: CdlpConfig,
239 label: Option<&str>,
240 edge_type: Option<&str>,
241 ) -> CdlpResult {
242 let store = self.store.read().await;
243 let view = build_view(&store, label, edge_type, None);
244 cdlp(&view, &config)
245 }
246
247 async fn local_clustering_coefficient(
248 &self,
249 label: Option<&str>,
250 edge_type: Option<&str>,
251 ) -> LccResult {
252 let store = self.store.read().await;
253 let view = build_view(&store, label, edge_type, None);
254 local_clustering_coefficient(&view)
255 }
256
257 async fn pca(&self, label: Option<&str>, properties: &[&str], config: PcaConfig) -> PcaResult {
258 let store = self.store.read().await;
259 use graphmind::graph::{Label, PropertyValue};
260
261 let nodes: Vec<_> = if let Some(label_str) = label {
263 let l = Label::new(label_str);
264 store.get_nodes_by_label(&l).into_iter().collect()
265 } else {
266 store.all_nodes().into_iter().collect()
267 };
268
269 let n = nodes.len();
271 let d = properties.len();
272 let mut data_flat = vec![0.0f64; n * d];
273 for (i, node) in nodes.iter().enumerate() {
274 for (j, &prop) in properties.iter().enumerate() {
275 data_flat[i * d + j] = match node.get_property(prop) {
276 Some(PropertyValue::Integer(v)) => *v as f64,
277 Some(PropertyValue::Float(v)) => *v,
278 _ => 0.0,
279 };
280 }
281 }
282 let data: Vec<Vec<f64>> = data_flat.chunks_exact(d).map(|c| c.to_vec()).collect();
283
284 if data.is_empty() {
285 return PcaResult {
286 components: vec![],
287 explained_variance: vec![],
288 explained_variance_ratio: vec![],
289 mean: vec![0.0; properties.len()],
290 std_dev: vec![1.0; properties.len()],
291 n_samples: 0,
292 n_features: properties.len(),
293 iterations_used: 0,
294 };
295 }
296
297 pca(&data, config)
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EmbeddedClient, GraphmindClient};

    /// Three `Person` nodes with KNOWS edges Alice->Bob, Bob->Carol and
    /// Alice->Carol: PageRank must score every node, with a positive maximum.
    #[tokio::test]
    async fn test_page_rank() {
        let client = EmbeddedClient::new();

        client
            .query("default", r#"CREATE (a:Person {name: "Alice"})"#)
            .await
            .unwrap();
        client
            .query("default", r#"CREATE (b:Person {name: "Bob"})"#)
            .await
            .unwrap();
        client
            .query("default", r#"CREATE (c:Person {name: "Carol"})"#)
            .await
            .unwrap();
        client.query("default",
            r#"MATCH (a:Person {name: "Alice"}), (b:Person {name: "Bob"}) CREATE (a)-[:KNOWS]->(b)"#
        ).await.unwrap();
        client.query("default",
            r#"MATCH (b:Person {name: "Bob"}), (c:Person {name: "Carol"}) CREATE (b)-[:KNOWS]->(c)"#
        ).await.unwrap();
        client.query("default",
            r#"MATCH (a:Person {name: "Alice"}), (c:Person {name: "Carol"}) CREATE (a)-[:KNOWS]->(c)"#
        ).await.unwrap();

        let scores = client
            .page_rank(PageRankConfig::default(), Some("Person"), Some("KNOWS"))
            .await;
        assert_eq!(scores.len(), 3);

        // `total_cmp` gives a total order on f64, so a NaN score cannot panic
        // the comparison itself.
        let max_score = scores
            .values()
            .copied()
            .max_by(f64::total_cmp)
            .expect("scores has three entries");
        assert!(max_score > 0.0);
    }

    /// Two disjoint KNOWS pairs form exactly two weakly connected components.
    #[tokio::test]
    async fn test_wcc() {
        let client = EmbeddedClient::new();

        client
            .query(
                "default",
                r#"CREATE (a:Person {name: "Alice"})-[:KNOWS]->(b:Person {name: "Bob"})"#,
            )
            .await
            .unwrap();
        client
            .query(
                "default",
                r#"CREATE (c:Person {name: "Carol"})-[:KNOWS]->(d:Person {name: "Dave"})"#,
            )
            .await
            .unwrap();

        let wcc = client
            .weakly_connected_components(Some("Person"), Some("KNOWS"))
            .await;
        assert_eq!(wcc.components.len(), 2);
    }

    /// Alice->Bob->Carol: BFS must find a multi-node path from the first to the
    /// third node.
    #[tokio::test]
    async fn test_bfs() {
        let client = EmbeddedClient::new();

        client
            .query(
                "default",
                r#"CREATE (a:Person {name: "Alice"})-[:KNOWS]->(b:Person {name: "Bob"})"#,
            )
            .await
            .unwrap();
        client
            .query(
                "default",
                r#"MATCH (b:Person {name: "Bob"}) CREATE (b)-[:KNOWS]->(c:Person {name: "Carol"})"#,
            )
            .await
            .unwrap();

        let store = client.store().read().await;
        let node_ids: Vec<u64> = store.all_nodes().iter().map(|n| n.id.as_u64()).collect();
        drop(store);

        // The setup above creates exactly Alice, Bob and Carol. A different
        // count means setup failed, and that must fail the test instead of
        // silently skipping the BFS check.
        assert_eq!(node_ids.len(), 3, "expected Alice, Bob and Carol to exist");

        // NOTE(review): assumes `all_nodes()` yields nodes in creation order
        // (Alice, Bob, Carol), so index 0 -> 2 is the two-hop path. Confirm
        // against the store's iteration order.
        let path = client
            .bfs(node_ids[0], node_ids[2], Some("Person"), Some("KNOWS"))
            .await
            .expect("BFS should find a path from the first to the third node");
        assert!(path.path.len() >= 2);
    }
}