1use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
18
19use crate::graph::graph_classifier::{GraphClass, GraphClassifier, GraphValue};
20use crate::graph::graph_query::{GraphQuery, TraversalDirection, TraversalWeight};
21use crate::pattern::Pattern;
22
23#[inline]
35fn reachable_neighbors<V>(
36 q: &GraphQuery<V>,
37 weight: &TraversalWeight<V>,
38 node: &Pattern<V>,
39) -> Vec<(Pattern<V>, f64)>
40where
41 V: GraphValue + Clone,
42{
43 let node_id = node.value.identify();
44 let rels = (q.query_incident_rels)(node);
45 let mut neighbors = Vec::new();
46
47 for rel in rels {
48 let src = (q.query_source)(&rel);
49 let tgt = (q.query_target)(&rel);
50
51 if let Some(ref s) = src {
53 if s.value.identify() == node_id {
54 let fwd = weight(&rel, TraversalDirection::Forward);
55 if fwd.is_finite() {
56 if let Some(t) = tgt.clone() {
57 neighbors.push((t, fwd));
58 }
59 }
60 }
61 }
62
63 if let Some(ref t) = tgt {
65 if t.value.identify() == node_id {
66 let bwd = weight(&rel, TraversalDirection::Backward);
67 if bwd.is_finite() {
68 if let Some(s) = src.clone() {
69 neighbors.push((s, bwd));
70 }
71 }
72 }
73 }
74 }
75
76 neighbors
77}
78
79pub fn bfs<V>(q: &GraphQuery<V>, weight: &TraversalWeight<V>, start: &Pattern<V>) -> Vec<Pattern<V>>
87where
88 V: GraphValue + Clone,
89 V::Id: Clone + Eq + std::hash::Hash + Ord,
90{
91 let mut visited: HashSet<V::Id> = HashSet::new();
92 let mut queue = VecDeque::new();
93 let mut result = Vec::new();
94
95 let start_id = start.value.identify().clone();
96 visited.insert(start_id);
97 queue.push_back(start.clone());
98
99 while let Some(current) = queue.pop_front() {
100 result.push(current.clone());
101 for (neighbor, _cost) in reachable_neighbors(q, weight, ¤t) {
102 let nid = neighbor.value.identify().clone();
103 if visited.insert(nid) {
104 queue.push_back(neighbor);
105 }
106 }
107 }
108
109 result
110}
111
112pub fn dfs<V>(q: &GraphQuery<V>, weight: &TraversalWeight<V>, start: &Pattern<V>) -> Vec<Pattern<V>>
116where
117 V: GraphValue + Clone,
118 V::Id: Clone + Eq + std::hash::Hash + Ord,
119{
120 let mut visited: HashSet<V::Id> = HashSet::new();
121 let mut stack = vec![start.clone()];
122 let mut result = Vec::new();
123
124 while let Some(current) = stack.pop() {
125 let cid = current.value.identify().clone();
126 if visited.insert(cid) {
127 result.push(current.clone());
128 for (neighbor, _cost) in reachable_neighbors(q, weight, ¤t) {
129 if !visited.contains(neighbor.value.identify()) {
130 stack.push(neighbor);
131 }
132 }
133 }
134 }
135
136 result
137}
138
139pub fn shortest_path<V>(
149 q: &GraphQuery<V>,
150 weight: &TraversalWeight<V>,
151 from: &Pattern<V>,
152 to: &Pattern<V>,
153) -> Option<Vec<Pattern<V>>>
154where
155 V: GraphValue + Clone,
156 V::Id: Clone + Eq + std::hash::Hash + Ord,
157{
158 if from.value.identify() == to.value.identify() {
160 return Some(vec![from.clone()]);
161 }
162
163 let mut dist: HashMap<V::Id, f64> = HashMap::new();
165 let mut prev: HashMap<V::Id, Pattern<V>> = HashMap::new();
167
168 let from_id = from.value.identify().clone();
169 dist.insert(from_id.clone(), 0.0);
170
171 let mut pq: BTreeMap<(u64, V::Id), Pattern<V>> = BTreeMap::new();
174 pq.insert((0u64, from_id.clone()), from.clone());
175
176 while let Some(((cost_bits, uid), node)) = pq.pop_first() {
177 let cost = f64::from_bits(cost_bits);
178
179 if let Some(&best) = dist.get(&uid) {
181 if cost > best {
182 continue;
183 }
184 }
185
186 if uid == *to.value.identify() {
188 let mut path = vec![node.clone()];
189 let mut cur_id = uid.clone();
190 while let Some(p) = prev.get(&cur_id) {
191 path.push(p.clone());
192 cur_id = p.value.identify().clone();
193 }
194 path.reverse();
195 return Some(path);
196 }
197
198 for (neighbor, edge_cost) in reachable_neighbors(q, weight, &node) {
199 if !edge_cost.is_finite() {
200 continue;
201 }
202 let new_cost = cost + edge_cost;
203 let nid = neighbor.value.identify().clone();
204
205 let should_update = dist.get(&nid).map(|&d| new_cost < d).unwrap_or(true);
206 if should_update {
207 dist.insert(nid.clone(), new_cost);
208 prev.insert(nid.clone(), node.clone());
209 pq.insert((new_cost.to_bits(), nid), neighbor);
210 }
211 }
212 }
213
214 None
215}
216
217pub fn has_path<V>(
221 q: &GraphQuery<V>,
222 weight: &TraversalWeight<V>,
223 from: &Pattern<V>,
224 to: &Pattern<V>,
225) -> bool
226where
227 V: GraphValue + Clone,
228 V::Id: Clone + Eq + std::hash::Hash + Ord,
229{
230 shortest_path(q, weight, from, to).is_some()
231}
232
233pub fn all_paths<V>(
238 q: &GraphQuery<V>,
239 weight: &TraversalWeight<V>,
240 from: &Pattern<V>,
241 to: &Pattern<V>,
242) -> Vec<Vec<Pattern<V>>>
243where
244 V: GraphValue + Clone,
245 V::Id: Clone + Eq + std::hash::Hash + Ord,
246{
247 let mut all = Vec::new();
248 let mut current_path = vec![from.clone()];
249 let mut visited: HashSet<V::Id> = HashSet::new();
250 visited.insert(from.value.identify().clone());
251
252 all_paths_dfs(
253 q,
254 weight,
255 from,
256 to,
257 &mut visited,
258 &mut current_path,
259 &mut all,
260 );
261 all
262}
263
264fn all_paths_dfs<V>(
265 q: &GraphQuery<V>,
266 weight: &TraversalWeight<V>,
267 current: &Pattern<V>,
268 to: &Pattern<V>,
269 visited: &mut HashSet<V::Id>,
270 current_path: &mut Vec<Pattern<V>>,
271 all: &mut Vec<Vec<Pattern<V>>>,
272) where
273 V: GraphValue + Clone,
274 V::Id: Clone + Eq + std::hash::Hash + Ord,
275{
276 if current.value.identify() == to.value.identify() {
277 all.push(current_path.clone());
278 return;
279 }
280
281 for (neighbor, _cost) in reachable_neighbors(q, weight, current) {
282 let nid = neighbor.value.identify().clone();
283 if !visited.contains(&nid) {
284 visited.insert(nid.clone());
285 current_path.push(neighbor.clone());
286 all_paths_dfs(q, weight, &neighbor, to, visited, current_path, all);
287 current_path.pop();
288 visited.remove(&nid);
289 }
290 }
291}
292
293pub fn is_neighbor<V>(
299 q: &GraphQuery<V>,
300 weight: &TraversalWeight<V>,
301 a: &Pattern<V>,
302 b: &Pattern<V>,
303) -> bool
304where
305 V: GraphValue + Clone,
306 V::Id: Clone + Eq + std::hash::Hash,
307{
308 let b_id = b.value.identify();
309 reachable_neighbors(q, weight, a)
310 .iter()
311 .any(|(n, _)| n.value.identify() == b_id)
312}
313
314pub fn is_connected<V>(q: &GraphQuery<V>, weight: &TraversalWeight<V>) -> bool
318where
319 V: GraphValue + Clone,
320 V::Id: Clone + Eq + std::hash::Hash + Ord,
321{
322 let nodes = (q.query_nodes)();
323 if nodes.is_empty() {
324 return true;
325 }
326 let visited = bfs(q, weight, &nodes[0]);
327 visited.len() == nodes.len()
328}
329
330pub fn connected_components<V>(
339 q: &GraphQuery<V>,
340 weight: &TraversalWeight<V>,
341) -> Vec<Vec<Pattern<V>>>
342where
343 V: GraphValue + Clone,
344 V::Id: Clone + Eq + std::hash::Hash + Ord,
345{
346 let all_nodes = (q.query_nodes)();
347 let mut visited: HashSet<V::Id> = HashSet::new();
348 let mut components = Vec::new();
349
350 for node in &all_nodes {
351 let nid = node.value.identify().clone();
352 if !visited.contains(&nid) {
353 let component = bfs(q, weight, node);
354 for n in &component {
355 visited.insert(n.value.identify().clone());
356 }
357 components.push(component);
358 }
359 }
360
361 components
362}
363
364pub fn topological_sort<V>(q: &GraphQuery<V>) -> Option<Vec<Pattern<V>>>
370where
371 V: GraphValue + Clone,
372 V::Id: Clone + Eq + std::hash::Hash + Ord,
373{
374 let nodes = (q.query_nodes)();
375
376 let mut in_stack: HashSet<V::Id> = HashSet::new();
377 let mut done: HashSet<V::Id> = HashSet::new();
378 let mut result: Vec<Pattern<V>> = Vec::new();
379
380 let forward_neighbors = |node: &Pattern<V>| -> Vec<Pattern<V>> {
382 let rels = (q.query_incident_rels)(node);
383 let node_id = node.value.identify();
384 rels.into_iter()
385 .filter_map(|rel| {
386 let src = (q.query_source)(&rel)?;
387 if src.value.identify() == node_id {
388 (q.query_target)(&rel)
389 } else {
390 None
391 }
392 })
393 .collect()
394 };
395
396 for start in &nodes {
397 if done.contains(start.value.identify()) {
398 continue;
399 }
400
401 let start_id = start.value.identify().clone();
402 in_stack.insert(start_id);
403 let neighbors = forward_neighbors(start);
404 let mut stack: Vec<(Pattern<V>, Vec<Pattern<V>>, usize)> =
406 vec![(start.clone(), neighbors, 0)];
407
408 while !stack.is_empty() {
409 let cur_idx = stack.last().unwrap().2;
410 let neighbors_len = stack.last().unwrap().1.len();
411
412 if cur_idx < neighbors_len {
413 let neighbor = stack.last().unwrap().1[cur_idx].clone();
414 stack.last_mut().unwrap().2 += 1;
415
416 let nid = neighbor.value.identify().clone();
417 if in_stack.contains(&nid) {
418 return None; }
420 if !done.contains(&nid) {
421 in_stack.insert(nid);
422 let next_neighbors = forward_neighbors(&neighbor);
423 stack.push((neighbor, next_neighbors, 0));
424 }
425 } else {
426 let (node, _, _) = stack.pop().unwrap();
427 let nid = node.value.identify().clone();
428 in_stack.remove(&nid);
429 done.insert(nid);
430 result.push(node);
431 }
432 }
433 }
434
435 result.reverse();
436 Some(result)
437}
438
439pub fn has_cycle<V>(q: &GraphQuery<V>) -> bool
443where
444 V: GraphValue + Clone,
445 V::Id: Clone + Eq + std::hash::Hash + Ord,
446{
447 topological_sort(q).is_none()
448}
449
450pub fn minimum_spanning_tree<V>(q: &GraphQuery<V>, weight: &TraversalWeight<V>) -> Vec<Pattern<V>>
460where
461 V: GraphValue + Clone,
462 V::Id: Clone + Eq + std::hash::Hash + Ord,
463{
464 let nodes = (q.query_nodes)();
465 if nodes.is_empty() {
466 return Vec::new();
467 }
468
469 let mut edges: Vec<(f64, Pattern<V>)> = (q.query_relationships)()
471 .into_iter()
472 .filter_map(|rel| {
473 let fwd = weight(&rel, TraversalDirection::Forward);
474 let bwd = weight(&rel, TraversalDirection::Backward);
475 let cost = fwd.min(bwd);
476 if cost.is_finite() {
477 Some((cost, rel))
478 } else {
479 None
480 }
481 })
482 .collect();
483
484 edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
486
487 let mut parent: HashMap<V::Id, V::Id> = nodes
489 .iter()
490 .map(|n| (n.value.identify().clone(), n.value.identify().clone()))
491 .collect();
492
493 let mut mst_node_ids: HashSet<V::Id> = HashSet::new();
494
495 for (_, rel) in edges {
496 let src = match (q.query_source)(&rel) {
497 Some(s) => s,
498 None => continue,
499 };
500 let tgt = match (q.query_target)(&rel) {
501 Some(t) => t,
502 None => continue,
503 };
504
505 let src_id = src.value.identify().clone();
506 let tgt_id = tgt.value.identify().clone();
507
508 let root_src = uf_find(&mut parent, src_id.clone());
509 let root_tgt = uf_find(&mut parent, tgt_id.clone());
510
511 if root_src != root_tgt {
512 parent.insert(root_src, root_tgt);
514 mst_node_ids.insert(src_id);
515 mst_node_ids.insert(tgt_id);
516 }
517 }
518
519 nodes
520 .into_iter()
521 .filter(|n| mst_node_ids.contains(n.value.identify()))
522 .collect()
523}
524
/// Returns the representative (root) of `x` in a union-find forest. Every
/// node on the path walked from `x` is re-pointed directly at that root
/// (path compression).
///
/// `parent` maps each id to its parent, and a root maps to itself. An id
/// with no entry is treated as its own singleton root: it is returned
/// unchanged and the map is not modified for it. An absent id reached by
/// following a parent link is likewise a root.
///
/// The walk is iterative on purpose. `minimum_spanning_tree` links roots
/// without union-by-rank, so a chain can be as long as the number of ids, and
/// a recursive walk of that depth could overflow the stack.
fn uf_find<Id>(parent: &mut HashMap<Id, Id>, x: Id) -> Id
where
    Id: Clone + Eq + std::hash::Hash,
{
    // First pass: follow parent links until reaching a self-parented id or an
    // id with no entry.
    let mut root = x.clone();
    while let Some(up) = parent.get(&root) {
        if *up == root {
            break;
        }
        root = up.clone();
    }

    // Second pass: point every node on the walked path straight at the root.
    // Every node before the root on that path has an entry, so `insert`
    // returns its previous parent, which is the next node to fix.
    let mut cur = x;
    while cur != root {
        match parent.insert(cur, root.clone()) {
            Some(next) => cur = next,
            None => break,
        }
    }

    root
}
538
539pub fn degree_centrality<V>(q: &GraphQuery<V>) -> HashMap<V::Id, f64>
551where
552 V: GraphValue + Clone,
553 V::Id: Clone + Eq + std::hash::Hash,
554{
555 let nodes = (q.query_nodes)();
556 let n = nodes.len();
557 let mut result = HashMap::new();
558
559 for node in &nodes {
560 let degree = (q.query_degree)(node) as f64;
561 let centrality = if n > 1 { degree / (n - 1) as f64 } else { 0.0 };
562 result.insert(node.value.identify().clone(), centrality);
563 }
564
565 result
566}
567
568pub fn betweenness_centrality<V>(
577 q: &GraphQuery<V>,
578 weight: &TraversalWeight<V>,
579) -> HashMap<V::Id, f64>
580where
581 V: GraphValue + Clone,
582 V::Id: Clone + Eq + std::hash::Hash + Ord,
583{
584 let nodes = (q.query_nodes)();
585 let mut betweenness: HashMap<V::Id, f64> = nodes
586 .iter()
587 .map(|n| (n.value.identify().clone(), 0.0))
588 .collect();
589
590 for s in &nodes {
591 let s_id = s.value.identify().clone();
592
593 let mut stack: Vec<Pattern<V>> = Vec::new();
595 let mut pred: HashMap<V::Id, Vec<Pattern<V>>> = nodes
596 .iter()
597 .map(|n| (n.value.identify().clone(), Vec::new()))
598 .collect();
599 let mut sigma: HashMap<V::Id, f64> = nodes
600 .iter()
601 .map(|n| (n.value.identify().clone(), 0.0))
602 .collect();
603 sigma.insert(s_id.clone(), 1.0);
604 let mut dist: HashMap<V::Id, i64> = nodes
605 .iter()
606 .map(|n| (n.value.identify().clone(), -1))
607 .collect();
608 dist.insert(s_id.clone(), 0);
609
610 let mut queue = VecDeque::new();
611 queue.push_back(s.clone());
612
613 while let Some(v) = queue.pop_front() {
614 stack.push(v.clone());
615 let v_id = v.value.identify().clone();
616 let v_dist = dist[&v_id];
617 let v_sigma = sigma[&v_id];
618
619 for (w, _cost) in reachable_neighbors(q, weight, &v) {
620 let w_id = w.value.identify().clone();
621 if dist[&w_id] < 0 {
623 queue.push_back(w.clone());
624 *dist.get_mut(&w_id).unwrap() = v_dist + 1;
625 }
626 if dist[&w_id] == v_dist + 1 {
628 *sigma.get_mut(&w_id).unwrap() += v_sigma;
629 pred.get_mut(&w_id).unwrap().push(v.clone());
630 }
631 }
632 }
633
634 let mut delta: HashMap<V::Id, f64> = nodes
636 .iter()
637 .map(|n| (n.value.identify().clone(), 0.0))
638 .collect();
639
640 while let Some(w) = stack.pop() {
641 let w_id = w.value.identify().clone();
642 for v in &pred[&w_id] {
643 let v_id = v.value.identify().clone();
644 let sigma_w = sigma[&w_id];
645 if sigma_w != 0.0 {
646 let coeff = sigma[&v_id] / sigma_w * (1.0 + delta[&w_id]);
647 *delta.get_mut(&v_id).unwrap() += coeff;
648 }
649 }
650 if w_id != s_id {
651 *betweenness.get_mut(&w_id).unwrap() += delta[&w_id];
652 }
653 }
654 }
655
656 betweenness
657}
658
659pub fn query_annotations_of<Extra, V>(
665 classifier: &GraphClassifier<Extra, V>,
666 q: &GraphQuery<V>,
667 element: &Pattern<V>,
668) -> Vec<Pattern<V>>
669where
670 V: GraphValue + Clone,
671{
672 (q.query_containers)(element)
673 .into_iter()
674 .filter(|c| matches!((classifier.classify)(c), GraphClass::GAnnotation))
675 .collect()
676}
677
678pub fn query_walks_containing<Extra, V>(
680 classifier: &GraphClassifier<Extra, V>,
681 q: &GraphQuery<V>,
682 element: &Pattern<V>,
683) -> Vec<Pattern<V>>
684where
685 V: GraphValue + Clone,
686{
687 (q.query_containers)(element)
688 .into_iter()
689 .filter(|c| matches!((classifier.classify)(c), GraphClass::GWalk))
690 .collect()
691}
692
693pub fn query_co_members<V>(
698 _q: &GraphQuery<V>,
699 element: &Pattern<V>,
700 container: &Pattern<V>,
701) -> Vec<Pattern<V>>
702where
703 V: GraphValue + Clone,
704 V::Id: Clone + Eq + std::hash::Hash,
705{
706 let elem_id = element.value.identify();
707 container
708 .elements
709 .iter()
710 .filter(|e| e.value.identify() != elem_id)
711 .cloned()
712 .collect()
713}