Skip to main content

pattern_core/graph/
graph_query.rs

//! GraphQuery: portable, composable graph query interface.
//!
//! Ported from `Pattern.Graph.GraphQuery` in the Haskell reference implementation.
//!
//! # Overview
//!
//! `GraphQuery<V>` is a struct-of-closures representing the complete query interface
//! over a graph. Algorithms operate against this interface, not against any specific
//! backing representation. This enables the same algorithm code to run against
//! `PatternGraph`, database-backed stores, or any other structure that can produce
//! the nine required closures.
//!
//! # Structural Invariants
//!
//! Implementations of `GraphQuery<V>` must uphold these invariants:
//! 1. `query_source(r) = Some(s)` implies `s ∈ query_nodes()`
//! 2. `query_target(r) = Some(t)` implies `t ∈ query_nodes()`
//! 3. `r ∈ query_incident_rels(n)` implies `query_source(r) = Some(n) || query_target(r) = Some(n)`
//! 4. `query_degree(n) == query_incident_rels(n).len()` (default; may be faster indexed)
//! 5. `query_node_by_id(n.value.identify()) = Some(n)` for all `n ∈ query_nodes()`
//! 6. `query_relationship_by_id(r.value.identify()) = Some(r)` for all `r ∈ query_relationships()`
//! 7. `query_containers` returns only **direct** containers — not transitive containment
23
24use std::collections::HashMap;
25
26use crate::graph::graph_classifier::GraphValue;
27use crate::pattern::Pattern;
28
29// ============================================================================
30// TraversalDirection
31// ============================================================================
32
/// Which direction along a directed relationship is being traversed.
///
/// Used by [`TraversalWeight`] functions to return per-direction costs.
///
/// The type is a plain two-variant `Copy` enum. It also derives `Hash`, so a
/// direction can be used as (part of) a key in `HashMap` / `HashSet` based
/// bookkeeping, e.g. a per-direction visited set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraversalDirection {
    /// Source → Target: follow the relationship in its declared direction.
    Forward,
    /// Target → Source: traverse the relationship in reverse.
    Backward,
}
43
44// ============================================================================
45// TraversalWeight type alias  (Rc default; Arc under thread-safe feature)
46// ============================================================================
47
/// A cost function for traversing a relationship in a given direction.
///
/// The callable is handed the relationship pattern being crossed together with
/// the [`TraversalDirection`] of travel, and answers with the cost of that step.
///
/// Returns a non-negative `f64`:
/// - Finite ≥ 0 — traversal allowed at that cost
/// - `f64::INFINITY` — traversal is blocked (impassable)
/// - Negative values are not supported; algorithm behavior is undefined
///
/// # Thread Safety
///
/// By default uses `Rc` (single-threaded). Enable the `thread-safe` feature
/// to use `Arc` with `Send + Sync` bounds.
#[cfg(not(feature = "thread-safe"))]
pub type TraversalWeight<V> = std::rc::Rc<dyn Fn(&Pattern<V>, TraversalDirection) -> f64>;
61
/// A cost function for traversing a relationship in a given direction
/// (`thread-safe` build: `Arc`-backed, and the callable must be `Send + Sync`).
///
/// Returns a non-negative `f64`:
/// - Finite ≥ 0 — traversal allowed at that cost
/// - `f64::INFINITY` — traversal is blocked (impassable)
/// - Negative values are not supported; algorithm behavior is undefined
#[cfg(feature = "thread-safe")]
pub type TraversalWeight<V> =
    std::sync::Arc<dyn Fn(&Pattern<V>, TraversalDirection) -> f64 + Send + Sync>;
65
66// ============================================================================
67// Canonical weight functions
68// ============================================================================
69
70/// Uniform cost 1.0 in both directions — treat all edges as bidirectional.
71#[cfg(not(feature = "thread-safe"))]
72pub fn undirected<V>() -> TraversalWeight<V> {
73    std::rc::Rc::new(|_rel: &Pattern<V>, _dir: TraversalDirection| 1.0)
74}
75
/// Build a weight that charges `1.0` for every step, regardless of the
/// relationship or the direction it is crossed in — edges act as bidirectional.
///
/// `thread-safe` build: the weight is returned behind an `Arc`.
#[cfg(feature = "thread-safe")]
pub fn undirected<V: Send + Sync + 'static>() -> TraversalWeight<V> {
    // Both arguments are ignored: the cost never varies.
    let flat_cost = |_: &Pattern<V>, _: TraversalDirection| -> f64 { 1.0 };
    std::sync::Arc::new(flat_cost)
}
80
81/// Forward cost 1.0, Backward cost INFINITY — follow edge direction only.
82#[cfg(not(feature = "thread-safe"))]
83pub fn directed<V>() -> TraversalWeight<V> {
84    std::rc::Rc::new(|_rel: &Pattern<V>, dir: TraversalDirection| match dir {
85        TraversalDirection::Forward => 1.0,
86        TraversalDirection::Backward => f64::INFINITY,
87    })
88}
89
/// Build a weight that only permits travel along a relationship's declared
/// direction: `Forward` costs `1.0`, `Backward` costs `f64::INFINITY`.
///
/// `thread-safe` build: the weight is returned behind an `Arc`.
#[cfg(feature = "thread-safe")]
pub fn directed<V: Send + Sync + 'static>() -> TraversalWeight<V> {
    // The relationship itself is not inspected; only the heading matters.
    let with_the_arrow = |_: &Pattern<V>, heading: TraversalDirection| -> f64 {
        match heading {
            TraversalDirection::Forward => 1.0,
            TraversalDirection::Backward => f64::INFINITY,
        }
    };
    std::sync::Arc::new(with_the_arrow)
}
97
98/// Forward cost INFINITY, Backward cost 1.0 — follow edges in reverse only.
99#[cfg(not(feature = "thread-safe"))]
100pub fn directed_reverse<V>() -> TraversalWeight<V> {
101    std::rc::Rc::new(|_rel: &Pattern<V>, dir: TraversalDirection| match dir {
102        TraversalDirection::Forward => f64::INFINITY,
103        TraversalDirection::Backward => 1.0,
104    })
105}
106
/// Build a weight that only permits travel against a relationship's declared
/// direction: `Forward` costs `f64::INFINITY`, `Backward` costs `1.0`.
///
/// `thread-safe` build: the weight is returned behind an `Arc`.
#[cfg(feature = "thread-safe")]
pub fn directed_reverse<V: Send + Sync + 'static>() -> TraversalWeight<V> {
    // The relationship itself is not inspected; only the heading matters.
    let against_the_arrow = |_: &Pattern<V>, heading: TraversalDirection| -> f64 {
        match heading {
            TraversalDirection::Forward => f64::INFINITY,
            TraversalDirection::Backward => 1.0,
        }
    };
    std::sync::Arc::new(against_the_arrow)
}
114
115// ============================================================================
116// GraphQuery struct (Rc default; Arc under thread-safe feature)
117// ============================================================================
118
/// Portable graph query interface: a struct of nine closures.
///
/// All graph algorithms operate against `GraphQuery<V>`, not against any specific
/// backing representation. Cloning is cheap — it increments reference counts only.
///
/// Every field is a public, reference-counted `dyn Fn`, so a query can be
/// assembled by hand, or derived from another query by swapping individual
/// fields (as the combinators in this module do).
///
/// # Construction
///
/// Use [`crate::from_pattern_graph`] to wrap a [`crate::PatternGraph`], or build
/// manually by providing all nine closure fields.
///
/// # Thread Safety
///
/// By default uses `Rc` (single-threaded). Enable the `thread-safe` Cargo feature
/// to use `Arc` with `Send + Sync` bounds throughout.
#[cfg(not(feature = "thread-safe"))]
// The closure-typed fields below are what trips clippy's type-complexity lint.
#[allow(clippy::type_complexity)]
pub struct GraphQuery<V: GraphValue> {
    /// Returns all node patterns in the graph.
    pub query_nodes: std::rc::Rc<dyn Fn() -> Vec<Pattern<V>>>,
    /// Returns all relationship patterns in the graph.
    pub query_relationships: std::rc::Rc<dyn Fn() -> Vec<Pattern<V>>>,
    /// Returns all relationships incident to the given node (as source or target).
    pub query_incident_rels: std::rc::Rc<dyn Fn(&Pattern<V>) -> Vec<Pattern<V>>>,
    /// Returns the source node of a relationship, or `None` if not available.
    pub query_source: std::rc::Rc<dyn Fn(&Pattern<V>) -> Option<Pattern<V>>>,
    /// Returns the target node of a relationship, or `None` if not available.
    pub query_target: std::rc::Rc<dyn Fn(&Pattern<V>) -> Option<Pattern<V>>>,
    /// Returns the count of incident relationships for a node.
    pub query_degree: std::rc::Rc<dyn Fn(&Pattern<V>) -> usize>,
    /// Returns the node with the given identity, or `None`.
    pub query_node_by_id: std::rc::Rc<dyn Fn(&V::Id) -> Option<Pattern<V>>>,
    /// Returns the relationship with the given identity, or `None`.
    pub query_relationship_by_id: std::rc::Rc<dyn Fn(&V::Id) -> Option<Pattern<V>>>,
    /// Returns all direct containers of the given element (relationships, walks, annotations).
    pub query_containers: std::rc::Rc<dyn Fn(&Pattern<V>) -> Vec<Pattern<V>>>,
}
155
/// Portable graph query interface: a struct of nine closures
/// (`thread-safe` build).
///
/// Field-for-field the same interface as the default build, except that each
/// closure is held in an `Arc` and must be `Send + Sync`. Cloning increments
/// reference counts only.
#[cfg(feature = "thread-safe")]
// The closure-typed fields below are what trips clippy's type-complexity lint.
#[allow(clippy::type_complexity)]
pub struct GraphQuery<V: GraphValue> {
    /// Returns all node patterns in the graph.
    pub query_nodes: std::sync::Arc<dyn Fn() -> Vec<Pattern<V>> + Send + Sync>,
    /// Returns all relationship patterns in the graph.
    pub query_relationships: std::sync::Arc<dyn Fn() -> Vec<Pattern<V>> + Send + Sync>,
    /// Returns all relationships incident to the given node (as source or target).
    pub query_incident_rels: std::sync::Arc<dyn Fn(&Pattern<V>) -> Vec<Pattern<V>> + Send + Sync>,
    /// Returns the source node of a relationship, or `None` if not available.
    pub query_source: std::sync::Arc<dyn Fn(&Pattern<V>) -> Option<Pattern<V>> + Send + Sync>,
    /// Returns the target node of a relationship, or `None` if not available.
    pub query_target: std::sync::Arc<dyn Fn(&Pattern<V>) -> Option<Pattern<V>> + Send + Sync>,
    /// Returns the count of incident relationships for a node.
    pub query_degree: std::sync::Arc<dyn Fn(&Pattern<V>) -> usize + Send + Sync>,
    /// Returns the node with the given identity, or `None`.
    pub query_node_by_id: std::sync::Arc<dyn Fn(&V::Id) -> Option<Pattern<V>> + Send + Sync>,
    /// Returns the relationship with the given identity, or `None`.
    pub query_relationship_by_id:
        std::sync::Arc<dyn Fn(&V::Id) -> Option<Pattern<V>> + Send + Sync>,
    /// Returns all direct containers of the given element (relationships, walks, annotations).
    pub query_containers: std::sync::Arc<dyn Fn(&Pattern<V>) -> Vec<Pattern<V>> + Send + Sync>,
}
179
180// ============================================================================
181// Manual Clone for GraphQuery (pointer clone only — no data copy)
182// ============================================================================
183
184#[cfg(not(feature = "thread-safe"))]
185impl<V: GraphValue> Clone for GraphQuery<V> {
186    fn clone(&self) -> Self {
187        GraphQuery {
188            query_nodes: std::rc::Rc::clone(&self.query_nodes),
189            query_relationships: std::rc::Rc::clone(&self.query_relationships),
190            query_incident_rels: std::rc::Rc::clone(&self.query_incident_rels),
191            query_source: std::rc::Rc::clone(&self.query_source),
192            query_target: std::rc::Rc::clone(&self.query_target),
193            query_degree: std::rc::Rc::clone(&self.query_degree),
194            query_node_by_id: std::rc::Rc::clone(&self.query_node_by_id),
195            query_relationship_by_id: std::rc::Rc::clone(&self.query_relationship_by_id),
196            query_containers: std::rc::Rc::clone(&self.query_containers),
197        }
198    }
199}
200
#[cfg(feature = "thread-safe")]
impl<V: GraphValue> Clone for GraphQuery<V> {
    /// Produce a second handle onto the same nine closures.
    ///
    /// Written by hand rather than derived so that no `V: Clone` bound is
    /// imposed; each field is duplicated with `Arc::clone`, which only bumps a
    /// reference count.
    fn clone(&self) -> Self {
        use std::sync::Arc;

        // Destructure so that every field is named exactly once on each side.
        let Self {
            query_nodes,
            query_relationships,
            query_incident_rels,
            query_source,
            query_target,
            query_degree,
            query_node_by_id,
            query_relationship_by_id,
            query_containers,
        } = self;

        Self {
            query_nodes: Arc::clone(query_nodes),
            query_relationships: Arc::clone(query_relationships),
            query_incident_rels: Arc::clone(query_incident_rels),
            query_source: Arc::clone(query_source),
            query_target: Arc::clone(query_target),
            query_degree: Arc::clone(query_degree),
            query_node_by_id: Arc::clone(query_node_by_id),
            query_relationship_by_id: Arc::clone(query_relationship_by_id),
            query_containers: Arc::clone(query_containers),
        }
    }
}
217
218// ============================================================================
219// frame_query combinator
220// ============================================================================
221
222/// Restrict a `GraphQuery<V>` to elements satisfying `include`.
223///
224/// The returned `GraphQuery<V>` is itself a full query interface. All seven
225/// structural invariants are preserved if they hold for `base`.
226///
227/// - `query_nodes` / `query_relationships` — filtered by predicate
228/// - `query_incident_rels(n)` — base incident rels where both source AND target satisfy predicate
229/// - `query_source` / `query_target` — delegated unchanged to base
230/// - `query_degree(n)` — count of filtered incident rels
231/// - `query_node_by_id(i)` — base lookup; returns `None` if result doesn't satisfy predicate
232/// - `query_relationship_by_id(i)` — base lookup; returns `None` if result doesn't satisfy predicate
233/// - `query_containers(p)` — base containers filtered by predicate
234///
235/// Rc and Arc variants are intentionally separate (no macro): only one is compiled per build,
236/// and Rust does not abstract over Rc/Arc here without macros or runtime indirection.
237#[cfg(not(feature = "thread-safe"))]
238#[allow(clippy::type_complexity)]
239pub fn frame_query<V>(
240    include: std::rc::Rc<dyn Fn(&Pattern<V>) -> bool>,
241    base: GraphQuery<V>,
242) -> GraphQuery<V>
243where
244    V: GraphValue + Clone + 'static,
245{
246    use std::rc::Rc;
247
248    let inc1 = Rc::clone(&include);
249    let query_nodes = Rc::new(move || {
250        (base.query_nodes)()
251            .into_iter()
252            .filter(|n| inc1(n))
253            .collect()
254    });
255
256    let inc2 = Rc::clone(&include);
257    let base_rels = Rc::clone(&base.query_relationships);
258    let query_relationships =
259        Rc::new(move || base_rels().into_iter().filter(|r| inc2(r)).collect());
260
261    let inc3 = Rc::clone(&include);
262    let base_inc = Rc::clone(&base.query_incident_rels);
263    let base_src = Rc::clone(&base.query_source);
264    let base_tgt = Rc::clone(&base.query_target);
265    let query_incident_rels = Rc::new(move |node: &Pattern<V>| {
266        base_inc(node)
267            .into_iter()
268            .filter(|rel| {
269                let src_ok = base_src(rel).as_ref().map(|s| inc3(s)).unwrap_or(false);
270                let tgt_ok = base_tgt(rel).as_ref().map(|t| inc3(t)).unwrap_or(false);
271                src_ok && tgt_ok
272            })
273            .collect()
274    });
275
276    let query_source = Rc::clone(&base.query_source);
277    let query_target = Rc::clone(&base.query_target);
278
279    let inc4 = Rc::clone(&include);
280    let base_inc2 = Rc::clone(&base.query_incident_rels);
281    let base_src2 = Rc::clone(&base.query_source);
282    let base_tgt2 = Rc::clone(&base.query_target);
283    let query_degree = Rc::new(move |node: &Pattern<V>| {
284        base_inc2(node)
285            .into_iter()
286            .filter(|rel| {
287                let src_ok = base_src2(rel).as_ref().map(|s| inc4(s)).unwrap_or(false);
288                let tgt_ok = base_tgt2(rel).as_ref().map(|t| inc4(t)).unwrap_or(false);
289                src_ok && tgt_ok
290            })
291            .count()
292    });
293
294    let inc5 = Rc::clone(&include);
295    let base_nbi = Rc::clone(&base.query_node_by_id);
296    let query_node_by_id = Rc::new(move |id: &V::Id| base_nbi(id).filter(|n| inc5(n)));
297
298    let inc6 = Rc::clone(&include);
299    let base_rbi = Rc::clone(&base.query_relationship_by_id);
300    let query_relationship_by_id = Rc::new(move |id: &V::Id| base_rbi(id).filter(|r| inc6(r)));
301
302    let inc7 = Rc::clone(&include);
303    let base_cont = Rc::clone(&base.query_containers);
304    let query_containers = Rc::new(move |element: &Pattern<V>| {
305        base_cont(element).into_iter().filter(|c| inc7(c)).collect()
306    });
307
308    GraphQuery {
309        query_nodes,
310        query_relationships,
311        query_incident_rels,
312        query_source,
313        query_target,
314        query_degree,
315        query_node_by_id,
316        query_relationship_by_id,
317        query_containers,
318    }
319}
320
/// Restrict a `GraphQuery<V>` to elements satisfying `include`
/// (`thread-safe` build: `Arc`-backed closures, `Send + Sync` predicate).
///
/// Each field of the result is derived from `base` as follows:
///
/// - `query_nodes` / `query_relationships` — base results, keeping only the
///   patterns for which `include` returns `true`
/// - `query_incident_rels(n)` — base incident rels whose source AND target both
///   resolve (through `base`) to patterns accepted by `include`; a relationship
///   with a missing endpoint is dropped
/// - `query_source` / `query_target` — the base closures, passed on unchanged
/// - `query_degree(n)` — number of base incident rels that pass the same
///   endpoint test as `query_incident_rels`
/// - `query_node_by_id(i)` / `query_relationship_by_id(i)` — base lookup, turned
///   into `None` when the hit is rejected by `include`
/// - `query_containers(p)` — base containers, keeping only accepted ones
///
/// NOTE(review): `query_source` / `query_target` are not filtered, and
/// `query_incident_rels` tests the endpoints rather than the relationship
/// itself. Structural invariants 1 and 2 therefore appear to carry over from
/// `base` only when `include` accepts both endpoints of every relationship it
/// accepts — confirm that callers' predicates satisfy this.
#[cfg(feature = "thread-safe")]
#[allow(clippy::type_complexity)]
pub fn frame_query<V>(
    include: std::sync::Arc<dyn Fn(&Pattern<V>) -> bool + Send + Sync>,
    base: GraphQuery<V>,
) -> GraphQuery<V>
where
    V: GraphValue + Clone + Send + Sync + 'static,
    V::Id: Clone + Send + Sync + 'static,
{
    use std::sync::Arc;

    // Endpoint test shared by `query_incident_rels` and `query_degree`: a
    // relationship stays in the frame only if both of its endpoints resolve
    // and are accepted. Both lookups always run (no short-circuit).
    let endpoints_in_frame = {
        let accept = Arc::clone(&include);
        let source_of = Arc::clone(&base.query_source);
        let target_of = Arc::clone(&base.query_target);
        Arc::new(move |rel: &Pattern<V>| -> bool {
            let source_kept = match source_of(rel) {
                Some(source) => accept(&source),
                None => false,
            };
            let target_kept = match target_of(rel) {
                Some(target) => accept(&target),
                None => false,
            };
            source_kept && target_kept
        })
    };

    let query_nodes = {
        let accept = Arc::clone(&include);
        let all_nodes = Arc::clone(&base.query_nodes);
        Arc::new(move || {
            let mut kept = all_nodes();
            kept.retain(|node| accept(node));
            kept
        })
    };

    let query_relationships = {
        let accept = Arc::clone(&include);
        let all_rels = Arc::clone(&base.query_relationships);
        Arc::new(move || {
            let mut kept = all_rels();
            kept.retain(|rel| accept(rel));
            kept
        })
    };

    let query_incident_rels = {
        let in_frame = Arc::clone(&endpoints_in_frame);
        let incident_to = Arc::clone(&base.query_incident_rels);
        Arc::new(move |node: &Pattern<V>| {
            let mut kept = incident_to(node);
            kept.retain(|rel| in_frame(rel));
            kept
        })
    };

    let query_degree = {
        // Last user of the shared endpoint test, so it is moved, not cloned.
        let in_frame = endpoints_in_frame;
        let incident_to = Arc::clone(&base.query_incident_rels);
        Arc::new(move |node: &Pattern<V>| {
            incident_to(node)
                .into_iter()
                .filter(|rel| in_frame(rel))
                .count()
        })
    };

    let query_node_by_id = {
        let accept = Arc::clone(&include);
        let find_node = Arc::clone(&base.query_node_by_id);
        Arc::new(move |id: &V::Id| find_node(id).filter(|node| accept(node)))
    };

    let query_relationship_by_id = {
        let accept = Arc::clone(&include);
        let find_rel = Arc::clone(&base.query_relationship_by_id);
        Arc::new(move |id: &V::Id| find_rel(id).filter(|rel| accept(rel)))
    };

    let query_containers = {
        // Last user of the predicate, so it is moved, not cloned.
        let accept = include;
        let containers_of = Arc::clone(&base.query_containers);
        Arc::new(move |element: &Pattern<V>| {
            let mut kept = containers_of(element);
            kept.retain(|container| accept(container));
            kept
        })
    };

    GraphQuery {
        query_nodes,
        query_relationships,
        query_incident_rels,
        // Endpoint resolution is handed over from `base` untouched.
        query_source: base.query_source,
        query_target: base.query_target,
        query_degree,
        query_node_by_id,
        query_relationship_by_id,
        query_containers,
    }
}
405
406// ============================================================================
407// memoize_incident_rels combinator
408// ============================================================================
409
410/// Wrap `query_incident_rels` and `query_degree` with an eager HashMap cache.
411///
412/// The cache is built upfront at construction time by calling `query_nodes()`
413/// and `query_incident_rels(n)` for each node. All other fields pass through
414/// unchanged.
415///
416/// Recommended for algorithms that call `query_incident_rels` repeatedly
417/// (e.g., betweenness centrality).
418///
419/// # Cache Semantics
420///
421/// - Eager: the full cache is built when `memoize_incident_rels` is called.
422/// - Per-`GraphQuery` cache — not global.
423/// - No `RefCell` needed (immutable after construction).
424#[cfg(not(feature = "thread-safe"))]
425pub fn memoize_incident_rels<V>(base: GraphQuery<V>) -> GraphQuery<V>
426where
427    V: GraphValue + Clone + 'static,
428    V::Id: Clone + Eq + std::hash::Hash + 'static,
429{
430    use std::rc::Rc;
431
432    // Build the cache eagerly from all nodes.
433    let nodes = (base.query_nodes)();
434    let mut cache: HashMap<V::Id, Vec<Pattern<V>>> = HashMap::new();
435    for node in &nodes {
436        let id = node.value.identify().clone();
437        let rels = (base.query_incident_rels)(node);
438        cache.insert(id, rels);
439    }
440    let cache = Rc::new(cache);
441
442    let cache1 = Rc::clone(&cache);
443    let query_incident_rels = Rc::new(move |node: &Pattern<V>| {
444        cache1
445            .get(node.value.identify())
446            .cloned()
447            .unwrap_or_default()
448    });
449
450    let cache2 = Rc::clone(&cache);
451    let query_degree = Rc::new(move |node: &Pattern<V>| {
452        cache2
453            .get(node.value.identify())
454            .map(|v| v.len())
455            .unwrap_or(0)
456    });
457
458    GraphQuery {
459        query_nodes: base.query_nodes,
460        query_relationships: base.query_relationships,
461        query_incident_rels,
462        query_source: base.query_source,
463        query_target: base.query_target,
464        query_degree,
465        query_node_by_id: base.query_node_by_id,
466        query_relationship_by_id: base.query_relationship_by_id,
467        query_containers: base.query_containers,
468    }
469}
470
/// Wrap `query_incident_rels` and `query_degree` with an eager HashMap cache
/// (`thread-safe` build: the cache is shared through an `Arc`).
///
/// At construction time `base.query_nodes()` is called once, and
/// `base.query_incident_rels(n)` once per returned node; the results are stored
/// keyed by `n.value.identify()`. The seven remaining fields are moved over
/// from `base` unchanged.
///
/// # Cache Semantics
///
/// - Eager: the full cache is built when `memoize_incident_rels` is called.
/// - Per-`GraphQuery` cache — not global; immutable after construction.
/// - A pattern whose identity is not in the cache gets an empty `Vec` from
///   `query_incident_rels` and `0` from `query_degree`; `base` is not consulted.
/// - `query_degree` reports the length of the cached list rather than calling
///   `base.query_degree`.
#[cfg(feature = "thread-safe")]
pub fn memoize_incident_rels<V>(base: GraphQuery<V>) -> GraphQuery<V>
where
    V: GraphValue + Clone + Send + Sync + 'static,
    V::Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
{
    use std::sync::Arc;

    // One pass over the nodes: identity -> incident relationships. If two nodes
    // share an identity, the later one overwrites the earlier entry.
    let adjacency: HashMap<V::Id, Vec<Pattern<V>>> = (base.query_nodes)()
        .iter()
        .map(|node| {
            let key = node.value.identify().clone();
            (key, (base.query_incident_rels)(node))
        })
        .collect();
    let adjacency = Arc::new(adjacency);

    let query_incident_rels = {
        let snapshot = Arc::clone(&adjacency);
        Arc::new(
            move |node: &Pattern<V>| match snapshot.get(node.value.identify()) {
                Some(rels) => rels.clone(),
                None => Vec::new(),
            },
        )
    };

    // Last user of the cache handle, so it is moved into the closure.
    let query_degree = Arc::new(move |node: &Pattern<V>| {
        adjacency.get(node.value.identify()).map_or(0, Vec::len)
    });

    GraphQuery {
        query_incident_rels,
        query_degree,
        query_nodes: base.query_nodes,
        query_relationships: base.query_relationships,
        query_source: base.query_source,
        query_target: base.query_target,
        query_node_by_id: base.query_node_by_id,
        query_relationship_by_id: base.query_relationship_by_id,
        query_containers: base.query_containers,
    }
}