Skip to main content

pattern_core/graph/
algorithms.rs

1//! Graph algorithms operating against the `GraphQuery<V>` interface.
2//!
3//! Ported from `Pattern.Graph.Algorithms` in the Haskell reference implementation.
4//!
5//! All functions are representation-independent: they operate on [`GraphQuery<V>`]
6//! closures only. The same code works against `PatternGraph`, in-memory closures,
7//! or any future backing store.
8//!
9//! # Traversal weight semantics
10//!
11//! Pass a [`TraversalWeight<V>`] to control which edges are traversable and at
12//! what cost. Use the canonical functions [`undirected`](crate::graph::graph_query::undirected), [`directed`](crate::graph::graph_query::directed), or
13//! [`directed_reverse`](crate::graph::graph_query::directed_reverse), or supply a custom `Rc<dyn Fn(...)>`.
14//!
15//! An edge with `INFINITY` cost in a given direction is impassable in that direction.
16
17use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
18
19use crate::graph::graph_classifier::{GraphClass, GraphClassifier, GraphValue};
20use crate::graph::graph_query::{GraphQuery, TraversalDirection, TraversalWeight};
21use crate::pattern::Pattern;
22
23// ============================================================================
24// Private helper: reachable_neighbors
25// ============================================================================
26
27/// Returns all immediately reachable neighbors of `node` under `weight`,
28/// together with the traversal cost to reach each neighbor.
29///
30/// A neighbor is reachable via a forward traversal if the node is the source
31/// of an incident relationship and the forward cost is finite. A neighbor is
32/// reachable via a backward traversal if the node is the target and the
33/// backward cost is finite.
34#[inline]
35fn reachable_neighbors<V>(
36    q: &GraphQuery<V>,
37    weight: &TraversalWeight<V>,
38    node: &Pattern<V>,
39) -> Vec<(Pattern<V>, f64)>
40where
41    V: GraphValue + Clone,
42{
43    let node_id = node.value.identify();
44    let rels = (q.query_incident_rels)(node);
45    let mut neighbors = Vec::new();
46
47    for rel in rels {
48        let src = (q.query_source)(&rel);
49        let tgt = (q.query_target)(&rel);
50
51        // Forward: node is the source → neighbor is the target
52        if let Some(ref s) = src {
53            if s.value.identify() == node_id {
54                let fwd = weight(&rel, TraversalDirection::Forward);
55                if fwd.is_finite() {
56                    if let Some(t) = tgt.clone() {
57                        neighbors.push((t, fwd));
58                    }
59                }
60            }
61        }
62
63        // Backward: node is the target → neighbor is the source
64        if let Some(ref t) = tgt {
65            if t.value.identify() == node_id {
66                let bwd = weight(&rel, TraversalDirection::Backward);
67                if bwd.is_finite() {
68                    if let Some(s) = src.clone() {
69                        neighbors.push((s, bwd));
70                    }
71                }
72            }
73        }
74    }
75
76    neighbors
77}
78
79// ============================================================================
80// Traversal algorithms
81// ============================================================================
82
83/// Breadth-first traversal from `start`.
84///
85/// Returns nodes in BFS visit order. The start node is always included.
86pub fn bfs<V>(q: &GraphQuery<V>, weight: &TraversalWeight<V>, start: &Pattern<V>) -> Vec<Pattern<V>>
87where
88    V: GraphValue + Clone,
89    V::Id: Clone + Eq + std::hash::Hash + Ord,
90{
91    let mut visited: HashSet<V::Id> = HashSet::new();
92    let mut queue = VecDeque::new();
93    let mut result = Vec::new();
94
95    let start_id = start.value.identify().clone();
96    visited.insert(start_id);
97    queue.push_back(start.clone());
98
99    while let Some(current) = queue.pop_front() {
100        result.push(current.clone());
101        for (neighbor, _cost) in reachable_neighbors(q, weight, &current) {
102            let nid = neighbor.value.identify().clone();
103            if visited.insert(nid) {
104                queue.push_back(neighbor);
105            }
106        }
107    }
108
109    result
110}
111
112/// Depth-first traversal from `start`.
113///
114/// Returns nodes in DFS visit order. The start node is always included.
115pub fn dfs<V>(q: &GraphQuery<V>, weight: &TraversalWeight<V>, start: &Pattern<V>) -> Vec<Pattern<V>>
116where
117    V: GraphValue + Clone,
118    V::Id: Clone + Eq + std::hash::Hash + Ord,
119{
120    let mut visited: HashSet<V::Id> = HashSet::new();
121    let mut stack = vec![start.clone()];
122    let mut result = Vec::new();
123
124    while let Some(current) = stack.pop() {
125        let cid = current.value.identify().clone();
126        if visited.insert(cid) {
127            result.push(current.clone());
128            for (neighbor, _cost) in reachable_neighbors(q, weight, &current) {
129                if !visited.contains(neighbor.value.identify()) {
130                    stack.push(neighbor);
131                }
132            }
133        }
134    }
135
136    result
137}
138
139// ============================================================================
140// Path algorithms
141// ============================================================================
142
143/// Find the minimum-cost path from `from` to `to` using Dijkstra's algorithm.
144///
145/// - Same node: returns `Some(vec![node])` immediately.
146/// - No path: returns `None`.
147/// - Uses `f64::INFINITY` cost to mark impassable edges.
148pub fn shortest_path<V>(
149    q: &GraphQuery<V>,
150    weight: &TraversalWeight<V>,
151    from: &Pattern<V>,
152    to: &Pattern<V>,
153) -> Option<Vec<Pattern<V>>>
154where
155    V: GraphValue + Clone,
156    V::Id: Clone + Eq + std::hash::Hash + Ord,
157{
158    // Same-node case
159    if from.value.identify() == to.value.identify() {
160        return Some(vec![from.clone()]);
161    }
162
163    // dist[id] = best known cost from `from` to node with that id
164    let mut dist: HashMap<V::Id, f64> = HashMap::new();
165    // prev[id] = predecessor node on the best-known path
166    let mut prev: HashMap<V::Id, Pattern<V>> = HashMap::new();
167
168    let from_id = from.value.identify().clone();
169    dist.insert(from_id.clone(), 0.0);
170
171    // Priority queue: (cost_bits_for_ordering, node_id) → node
172    // For non-negative finite f64, the IEEE 754 bit pattern preserves ordering.
173    let mut pq: BTreeMap<(u64, V::Id), Pattern<V>> = BTreeMap::new();
174    pq.insert((0u64, from_id.clone()), from.clone());
175
176    while let Some(((cost_bits, uid), node)) = pq.pop_first() {
177        let cost = f64::from_bits(cost_bits);
178
179        // Skip stale entries
180        if let Some(&best) = dist.get(&uid) {
181            if cost > best {
182                continue;
183            }
184        }
185
186        // Reached destination
187        if uid == *to.value.identify() {
188            let mut path = vec![node.clone()];
189            let mut cur_id = uid.clone();
190            while let Some(p) = prev.get(&cur_id) {
191                path.push(p.clone());
192                cur_id = p.value.identify().clone();
193            }
194            path.reverse();
195            return Some(path);
196        }
197
198        for (neighbor, edge_cost) in reachable_neighbors(q, weight, &node) {
199            if !edge_cost.is_finite() {
200                continue;
201            }
202            let new_cost = cost + edge_cost;
203            let nid = neighbor.value.identify().clone();
204
205            let should_update = dist.get(&nid).map(|&d| new_cost < d).unwrap_or(true);
206            if should_update {
207                dist.insert(nid.clone(), new_cost);
208                prev.insert(nid.clone(), node.clone());
209                pq.insert((new_cost.to_bits(), nid), neighbor);
210            }
211        }
212    }
213
214    None
215}
216
217/// Returns `true` if a path exists from `from` to `to`.
218///
219/// Delegates to [`shortest_path`].
220pub fn has_path<V>(
221    q: &GraphQuery<V>,
222    weight: &TraversalWeight<V>,
223    from: &Pattern<V>,
224    to: &Pattern<V>,
225) -> bool
226where
227    V: GraphValue + Clone,
228    V::Id: Clone + Eq + std::hash::Hash + Ord,
229{
230    shortest_path(q, weight, from, to).is_some()
231}
232
233/// Enumerate all simple paths from `from` to `to` (no repeated nodes).
234///
235/// Returns a `Vec` of paths. Exponential worst case — use only on small graphs
236/// or bounded subgraphs.
237pub fn all_paths<V>(
238    q: &GraphQuery<V>,
239    weight: &TraversalWeight<V>,
240    from: &Pattern<V>,
241    to: &Pattern<V>,
242) -> Vec<Vec<Pattern<V>>>
243where
244    V: GraphValue + Clone,
245    V::Id: Clone + Eq + std::hash::Hash + Ord,
246{
247    let mut all = Vec::new();
248    let mut current_path = vec![from.clone()];
249    let mut visited: HashSet<V::Id> = HashSet::new();
250    visited.insert(from.value.identify().clone());
251
252    all_paths_dfs(
253        q,
254        weight,
255        from,
256        to,
257        &mut visited,
258        &mut current_path,
259        &mut all,
260    );
261    all
262}
263
264fn all_paths_dfs<V>(
265    q: &GraphQuery<V>,
266    weight: &TraversalWeight<V>,
267    current: &Pattern<V>,
268    to: &Pattern<V>,
269    visited: &mut HashSet<V::Id>,
270    current_path: &mut Vec<Pattern<V>>,
271    all: &mut Vec<Vec<Pattern<V>>>,
272) where
273    V: GraphValue + Clone,
274    V::Id: Clone + Eq + std::hash::Hash + Ord,
275{
276    if current.value.identify() == to.value.identify() {
277        all.push(current_path.clone());
278        return;
279    }
280
281    for (neighbor, _cost) in reachable_neighbors(q, weight, current) {
282        let nid = neighbor.value.identify().clone();
283        if !visited.contains(&nid) {
284            visited.insert(nid.clone());
285            current_path.push(neighbor.clone());
286            all_paths_dfs(q, weight, &neighbor, to, visited, current_path, all);
287            current_path.pop();
288            visited.remove(&nid);
289        }
290    }
291}
292
293// ============================================================================
294// Boolean queries
295// ============================================================================
296
297/// Returns `true` if `b` is directly reachable from `a` in one hop under `weight`.
298pub fn is_neighbor<V>(
299    q: &GraphQuery<V>,
300    weight: &TraversalWeight<V>,
301    a: &Pattern<V>,
302    b: &Pattern<V>,
303) -> bool
304where
305    V: GraphValue + Clone,
306    V::Id: Clone + Eq + std::hash::Hash,
307{
308    let b_id = b.value.identify();
309    reachable_neighbors(q, weight, a)
310        .iter()
311        .any(|(n, _)| n.value.identify() == b_id)
312}
313
314/// Returns `true` if the entire graph is connected under `weight`.
315///
316/// An empty graph is vacuously connected (returns `true`).
317pub fn is_connected<V>(q: &GraphQuery<V>, weight: &TraversalWeight<V>) -> bool
318where
319    V: GraphValue + Clone,
320    V::Id: Clone + Eq + std::hash::Hash + Ord,
321{
322    let nodes = (q.query_nodes)();
323    if nodes.is_empty() {
324        return true;
325    }
326    let visited = bfs(q, weight, &nodes[0]);
327    visited.len() == nodes.len()
328}
329
330// ============================================================================
331// Structural algorithms
332// ============================================================================
333
334/// Partition the graph into connected components.
335///
336/// Returns a `Vec` of `Vec`s; each inner `Vec` is one component.
337/// Uses BFS internally.
338pub fn connected_components<V>(
339    q: &GraphQuery<V>,
340    weight: &TraversalWeight<V>,
341) -> Vec<Vec<Pattern<V>>>
342where
343    V: GraphValue + Clone,
344    V::Id: Clone + Eq + std::hash::Hash + Ord,
345{
346    let all_nodes = (q.query_nodes)();
347    let mut visited: HashSet<V::Id> = HashSet::new();
348    let mut components = Vec::new();
349
350    for node in &all_nodes {
351        let nid = node.value.identify().clone();
352        if !visited.contains(&nid) {
353            let component = bfs(q, weight, node);
354            for n in &component {
355                visited.insert(n.value.identify().clone());
356            }
357            components.push(component);
358        }
359    }
360
361    components
362}
363
364/// Topological sort using iterative DFS post-order with cycle detection.
365///
366/// - Returns `Some(order)` if the graph is a DAG.
367/// - Returns `None` if a directed cycle is detected.
368/// - Ignores `TraversalWeight` — uses relationship endpoint order only.
369pub fn topological_sort<V>(q: &GraphQuery<V>) -> Option<Vec<Pattern<V>>>
370where
371    V: GraphValue + Clone,
372    V::Id: Clone + Eq + std::hash::Hash + Ord,
373{
374    let nodes = (q.query_nodes)();
375
376    let mut in_stack: HashSet<V::Id> = HashSet::new();
377    let mut done: HashSet<V::Id> = HashSet::new();
378    let mut result: Vec<Pattern<V>> = Vec::new();
379
380    // Returns forward neighbors (rels where node is the source)
381    let forward_neighbors = |node: &Pattern<V>| -> Vec<Pattern<V>> {
382        let rels = (q.query_incident_rels)(node);
383        let node_id = node.value.identify();
384        rels.into_iter()
385            .filter_map(|rel| {
386                let src = (q.query_source)(&rel)?;
387                if src.value.identify() == node_id {
388                    (q.query_target)(&rel)
389                } else {
390                    None
391                }
392            })
393            .collect()
394    };
395
396    for start in &nodes {
397        if done.contains(start.value.identify()) {
398            continue;
399        }
400
401        let start_id = start.value.identify().clone();
402        in_stack.insert(start_id);
403        let neighbors = forward_neighbors(start);
404        // Stack: (node, its_forward_neighbors, current_neighbor_index)
405        let mut stack: Vec<(Pattern<V>, Vec<Pattern<V>>, usize)> =
406            vec![(start.clone(), neighbors, 0)];
407
408        while !stack.is_empty() {
409            let cur_idx = stack.last().unwrap().2;
410            let neighbors_len = stack.last().unwrap().1.len();
411
412            if cur_idx < neighbors_len {
413                let neighbor = stack.last().unwrap().1[cur_idx].clone();
414                stack.last_mut().unwrap().2 += 1;
415
416                let nid = neighbor.value.identify().clone();
417                if in_stack.contains(&nid) {
418                    return None; // Back edge — cycle detected
419                }
420                if !done.contains(&nid) {
421                    in_stack.insert(nid);
422                    let next_neighbors = forward_neighbors(&neighbor);
423                    stack.push((neighbor, next_neighbors, 0));
424                }
425            } else {
426                let (node, _, _) = stack.pop().unwrap();
427                let nid = node.value.identify().clone();
428                in_stack.remove(&nid);
429                done.insert(nid);
430                result.push(node);
431            }
432        }
433    }
434
435    result.reverse();
436    Some(result)
437}
438
439/// Returns `true` if the graph contains a directed cycle.
440///
441/// Delegates to [`topological_sort`].
442pub fn has_cycle<V>(q: &GraphQuery<V>) -> bool
443where
444    V: GraphValue + Clone,
445    V::Id: Clone + Eq + std::hash::Hash + Ord,
446{
447    topological_sort(q).is_none()
448}
449
450// ============================================================================
451// Spanning
452// ============================================================================
453
454/// Minimum spanning tree using Kruskal's algorithm with path-compression union-find.
455///
456/// - Edge cost is `min(forward_cost, backward_cost)`.
457/// - Edges with `INFINITY` cost in both directions are excluded.
458/// - Returns the subset of nodes that are included in the MST.
459pub fn minimum_spanning_tree<V>(q: &GraphQuery<V>, weight: &TraversalWeight<V>) -> Vec<Pattern<V>>
460where
461    V: GraphValue + Clone,
462    V::Id: Clone + Eq + std::hash::Hash + Ord,
463{
464    let nodes = (q.query_nodes)();
465    if nodes.is_empty() {
466        return Vec::new();
467    }
468
469    // Collect all edges with their MST cost
470    let mut edges: Vec<(f64, Pattern<V>)> = (q.query_relationships)()
471        .into_iter()
472        .filter_map(|rel| {
473            let fwd = weight(&rel, TraversalDirection::Forward);
474            let bwd = weight(&rel, TraversalDirection::Backward);
475            let cost = fwd.min(bwd);
476            if cost.is_finite() {
477                Some((cost, rel))
478            } else {
479                None
480            }
481        })
482        .collect();
483
484    // Sort edges by cost (ascending)
485    edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
486
487    // Initialize union-find: each node is its own component
488    let mut parent: HashMap<V::Id, V::Id> = nodes
489        .iter()
490        .map(|n| (n.value.identify().clone(), n.value.identify().clone()))
491        .collect();
492
493    let mut mst_node_ids: HashSet<V::Id> = HashSet::new();
494
495    for (_, rel) in edges {
496        let src = match (q.query_source)(&rel) {
497            Some(s) => s,
498            None => continue,
499        };
500        let tgt = match (q.query_target)(&rel) {
501            Some(t) => t,
502            None => continue,
503        };
504
505        let src_id = src.value.identify().clone();
506        let tgt_id = tgt.value.identify().clone();
507
508        let root_src = uf_find(&mut parent, src_id.clone());
509        let root_tgt = uf_find(&mut parent, tgt_id.clone());
510
511        if root_src != root_tgt {
512            // Union: merge the two components
513            parent.insert(root_src, root_tgt);
514            mst_node_ids.insert(src_id);
515            mst_node_ids.insert(tgt_id);
516        }
517    }
518
519    nodes
520        .into_iter()
521        .filter(|n| mst_node_ids.contains(n.value.identify()))
522        .collect()
523}
524
/// Union-find lookup with path compression: returns the root of the set that
/// contains `x`.
///
/// `parent` maps each id to its parent; a root maps to itself. After the
/// call, every id on the path from `x` to the root points directly at the
/// root.
///
/// The lookup is iterative, so long parent chains cannot exhaust the call
/// stack. An id that has no entry in `parent` is registered as the root of a
/// new singleton set and returned, instead of panicking.
fn uf_find<Id>(parent: &mut HashMap<Id, Id>, x: Id) -> Id
where
    Id: Clone + Eq + std::hash::Hash,
{
    // Pass 1: climb parent links until reaching an id that is its own parent.
    let mut root = x.clone();
    loop {
        match parent.get(&root).cloned() {
            Some(up) if up != root => root = up,
            Some(_) => break,
            None => {
                // Unknown id: treat it as a fresh singleton set.
                parent.insert(root.clone(), root.clone());
                break;
            }
        }
    }

    // Pass 2: re-point every id on the climbed path straight at the root.
    // `insert` hands back the previous parent, which is the next id to fix.
    let mut cursor = x;
    while cursor != root {
        match parent.insert(cursor, root.clone()) {
            Some(previous) => cursor = previous,
            None => break,
        }
    }

    root
}
538
539// ============================================================================
540// Centrality
541// ============================================================================
542
543/// Degree centrality for all nodes.
544///
545/// For a graph with `n` nodes, the degree centrality of node `v` is
546/// `degree(v) / (n - 1)`. Returns 0.0 for all nodes in a single-node graph.
547///
548/// Does **not** take a `TraversalWeight` parameter — degree centrality is a
549/// structural property (count of incident relationships, direction-agnostic).
550pub fn degree_centrality<V>(q: &GraphQuery<V>) -> HashMap<V::Id, f64>
551where
552    V: GraphValue + Clone,
553    V::Id: Clone + Eq + std::hash::Hash,
554{
555    let nodes = (q.query_nodes)();
556    let n = nodes.len();
557    let mut result = HashMap::new();
558
559    for node in &nodes {
560        let degree = (q.query_degree)(node) as f64;
561        let centrality = if n > 1 { degree / (n - 1) as f64 } else { 0.0 };
562        result.insert(node.value.identify().clone(), centrality);
563    }
564
565    result
566}
567
568/// Betweenness centrality using the Brandes BFS algorithm (unnormalized).
569///
570/// Returns the unnormalized betweenness score for each node. To normalize for
571/// an undirected graph with `n` nodes, divide by `(n-1)(n-2)/2`.
572///
573/// Uses the `weight` function to determine which edges are traversable (finite
574/// cost = reachable). All reachable edges are treated as unit-weight for the
575/// shortest-path counting phase.
576pub fn betweenness_centrality<V>(
577    q: &GraphQuery<V>,
578    weight: &TraversalWeight<V>,
579) -> HashMap<V::Id, f64>
580where
581    V: GraphValue + Clone,
582    V::Id: Clone + Eq + std::hash::Hash + Ord,
583{
584    let nodes = (q.query_nodes)();
585    let mut betweenness: HashMap<V::Id, f64> = nodes
586        .iter()
587        .map(|n| (n.value.identify().clone(), 0.0))
588        .collect();
589
590    for s in &nodes {
591        let s_id = s.value.identify().clone();
592
593        // BFS phase
594        let mut stack: Vec<Pattern<V>> = Vec::new();
595        let mut pred: HashMap<V::Id, Vec<Pattern<V>>> = nodes
596            .iter()
597            .map(|n| (n.value.identify().clone(), Vec::new()))
598            .collect();
599        let mut sigma: HashMap<V::Id, f64> = nodes
600            .iter()
601            .map(|n| (n.value.identify().clone(), 0.0))
602            .collect();
603        sigma.insert(s_id.clone(), 1.0);
604        let mut dist: HashMap<V::Id, i64> = nodes
605            .iter()
606            .map(|n| (n.value.identify().clone(), -1))
607            .collect();
608        dist.insert(s_id.clone(), 0);
609
610        let mut queue = VecDeque::new();
611        queue.push_back(s.clone());
612
613        while let Some(v) = queue.pop_front() {
614            stack.push(v.clone());
615            let v_id = v.value.identify().clone();
616            let v_dist = dist[&v_id];
617            let v_sigma = sigma[&v_id];
618
619            for (w, _cost) in reachable_neighbors(q, weight, &v) {
620                let w_id = w.value.identify().clone();
621                // First time visiting w?
622                if dist[&w_id] < 0 {
623                    queue.push_back(w.clone());
624                    *dist.get_mut(&w_id).unwrap() = v_dist + 1;
625                }
626                // On a shortest path through v?
627                if dist[&w_id] == v_dist + 1 {
628                    *sigma.get_mut(&w_id).unwrap() += v_sigma;
629                    pred.get_mut(&w_id).unwrap().push(v.clone());
630                }
631            }
632        }
633
634        // Back-propagation
635        let mut delta: HashMap<V::Id, f64> = nodes
636            .iter()
637            .map(|n| (n.value.identify().clone(), 0.0))
638            .collect();
639
640        while let Some(w) = stack.pop() {
641            let w_id = w.value.identify().clone();
642            for v in &pred[&w_id] {
643                let v_id = v.value.identify().clone();
644                let sigma_w = sigma[&w_id];
645                if sigma_w != 0.0 {
646                    let coeff = sigma[&v_id] / sigma_w * (1.0 + delta[&w_id]);
647                    *delta.get_mut(&v_id).unwrap() += coeff;
648                }
649            }
650            if w_id != s_id {
651                *betweenness.get_mut(&w_id).unwrap() += delta[&w_id];
652            }
653        }
654    }
655
656    betweenness
657}
658
659// ============================================================================
660// Context query helpers
661// ============================================================================
662
663/// Returns all containers of `element` that are classified as annotations.
664pub fn query_annotations_of<Extra, V>(
665    classifier: &GraphClassifier<Extra, V>,
666    q: &GraphQuery<V>,
667    element: &Pattern<V>,
668) -> Vec<Pattern<V>>
669where
670    V: GraphValue + Clone,
671{
672    (q.query_containers)(element)
673        .into_iter()
674        .filter(|c| matches!((classifier.classify)(c), GraphClass::GAnnotation))
675        .collect()
676}
677
678/// Returns all containers of `element` that are classified as walks.
679pub fn query_walks_containing<Extra, V>(
680    classifier: &GraphClassifier<Extra, V>,
681    q: &GraphQuery<V>,
682    element: &Pattern<V>,
683) -> Vec<Pattern<V>>
684where
685    V: GraphValue + Clone,
686{
687    (q.query_containers)(element)
688        .into_iter()
689        .filter(|c| matches!((classifier.classify)(c), GraphClass::GWalk))
690        .collect()
691}
692
693/// Returns all elements that share `container` with `element`, excluding `element` itself.
694///
695/// Co-membership is checked by identity (`V::Id`). The container's `elements` field
696/// is traversed directly — O(k) where k = number of elements in the container.
697pub fn query_co_members<V>(
698    _q: &GraphQuery<V>,
699    element: &Pattern<V>,
700    container: &Pattern<V>,
701) -> Vec<Pattern<V>>
702where
703    V: GraphValue + Clone,
704    V::Id: Clone + Eq + std::hash::Hash,
705{
706    let elem_id = element.value.identify();
707    container
708        .elements
709        .iter()
710        .filter(|e| e.value.identify() != elem_id)
711        .cloned()
712        .collect()
713}