Skip to main content

swh_graph/views/
subgraph.rs

1// Copyright (C) 2023-2026  The Software Heritage developers
2// See the AUTHORS file at the top-level directory of this distribution
3// License: GNU General Public License version 3, or any later version
4// See top-level LICENSE file for more information
5
6use std::collections::HashMap;
7use std::path::Path;
8
9use anyhow::{anyhow, Result};
10use webgraph::traits::labels::SortedIterator;
11
12use crate::arc_iterators::FlattenedSuccessorsIterator;
13use crate::graph::*;
14use crate::properties;
15use crate::{NodeConstraint, NodeType};
16
// Generates an unlabeled filtering iterator type.
//
// * `$name` is the name of the generated struct.
// * `$inner` is the name given to the generic parameter holding the wrapped
//   iterator (which must yield `NodeId`s).
// * The remaining tokens are pasted verbatim into the `Iterator` impl and are
//   expected to define `fn next`. They may use the struct's fields: `inner`,
//   `node`, `node_filter` and `arc_filter`.
macro_rules! make_filtered_arcs_iterator {
    ($name:ident, $inner:ident, $( $next:tt )*) => {
        /// Wrapper around an iterator of [`NodeId`], carrying the node the
        /// iterator was created for and borrowed node/arc predicates.
        ///
        /// Which items are yielded is decided by the `next` implementation
        /// supplied at the macro call site.
        pub struct $name<'a, $inner, NodeFilter, ArcFilter>
        where
            $inner: Iterator<Item = NodeId> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        {
            inner: $inner,
            node: NodeId,
            node_filter: &'a NodeFilter,
            arc_filter: &'a ArcFilter,
        }

        impl<'a, $inner, NodeFilter, ArcFilter> Iterator
            for $name<'a, $inner, NodeFilter, ArcFilter>
        where
            $inner: Iterator<Item = NodeId> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        {
            type Item = <$inner as Iterator>::Item;

            $( $next )*
        }

        // SAFETY: provided the pasted `next` only drops items of `inner`
        // (never reorders or invents them), the output is a subsequence of a
        // sorted sequence, so it is sorted too.
        unsafe impl<'a, $inner, NodeFilter, ArcFilter> SortedIterator
            for $name<'a, $inner, NodeFilter, ArcFilter>
        where
            $inner: SortedIterator<Item = NodeId> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        {
        }
    }
}
52
53make_filtered_arcs_iterator! {
54    FilteredSuccessors,
55    Successors,
56    fn next(&mut self) -> Option<Self::Item> {
57        if !(self.node_filter)(self.node) {
58            return None;
59        }
60
61        self.inner
62            .by_ref()
63            .find(|&dst| (self.node_filter)(dst) && (self.arc_filter)(self.node, dst))
64    }
65}
66make_filtered_arcs_iterator! {
67    FilteredPredecessors,
68    Predecessors,
69    fn next(&mut self) -> Option<Self::Item> {
70        if !(self.node_filter)(self.node) {
71            return None;
72        }
73
74        self.inner
75            .by_ref()
76            .find(|&src| (self.node_filter)(src) && (self.arc_filter)(src, self.node))
77    }
78}
79
// Generates a labeled filtering iterator type.
//
// * `$name` is the name of the generated struct.
// * `$inner` is the name given to the generic parameter holding the wrapped
//   iterator (which must yield `(NodeId, Labels)` pairs).
// * The remaining tokens are pasted verbatim into the `Iterator` impl and are
//   expected to define `fn next`. They may use the struct's fields: `inner`,
//   `node`, `node_filter` and `arc_filter`.
macro_rules! make_filtered_labeled_arcs_iterator {
    ($name:ident, $inner:ident, $( $next:tt )*) => {
        /// Wrapper around an iterator of `(NodeId, Labels)` pairs, carrying the
        /// node the iterator was created for and borrowed node/arc predicates.
        ///
        /// Which items are yielded is decided by the `next` implementation
        /// supplied at the macro call site.
        pub struct $name<'a, Labels, $inner, NodeFilter, ArcFilter>
        where
            $inner: Iterator<Item = (NodeId, Labels)> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        {
            inner: $inner,
            node: NodeId,
            node_filter: &'a NodeFilter,
            arc_filter: &'a ArcFilter,
        }

        impl<'a, Labels, $inner, NodeFilter, ArcFilter> Iterator
            for $name<'a, Labels, $inner, NodeFilter, ArcFilter>
        where
            $inner: Iterator<Item = (NodeId, Labels)> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        {
            type Item = <$inner as Iterator>::Item;

            $( $next )*
        }

        // SAFETY: provided the pasted `next` only drops items of `inner`
        // (never reorders or invents them), the output is a subsequence of a
        // sorted sequence, so it is sorted too.
        // No sortedness requirement is placed on `Labels`: `SortedIterator` is
        // implemented for the outer iterator only, not for the per-arc labels.
        unsafe impl<'a, Labels, $inner, NodeFilter, ArcFilter> SortedIterator
            for $name<'a, Labels, $inner, NodeFilter, ArcFilter>
        where
            $inner: SortedIterator<Item = (NodeId, Labels)> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        {
        }

        // Lets callers turn the iterator of `(node, labels)` pairs into a
        // flattened one by wrapping it in `FlattenedSuccessorsIterator`.
        impl<'a, Labels, $inner, NodeFilter, ArcFilter>
            IntoFlattenedLabeledArcsIterator<<Labels as IntoIterator>::Item>
            for $name<'a, Labels, $inner, NodeFilter, ArcFilter>
        where
            Labels: IntoIterator,
            $inner: Iterator<Item = (NodeId, Labels)> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        {
            type Flattened = FlattenedSuccessorsIterator<Self>;

            fn flatten_labels(self) -> Self::Flattened {
                FlattenedSuccessorsIterator::new(self)
            }
        }
    }
}
135
136make_filtered_labeled_arcs_iterator! {
137    FilteredLabeledSuccessors,
138    LabeledSuccessors,
139    fn next(&mut self) -> Option<Self::Item> {
140        if !(self.node_filter)(self.node) {
141            return None;
142        }
143        for (dst, label) in self.inner.by_ref() {
144            if (self.node_filter)(dst) && (self.arc_filter)(self.node, dst) {
145                return Some((dst, label))
146            }
147        }
148        None
149    }
150}
151make_filtered_labeled_arcs_iterator! {
152    FilteredLabeledPredecessors,
153    LabeledPredecessors,
154    fn next(&mut self) -> Option<Self::Item> {
155        if !(self.node_filter)(self.node) {
156            return None;
157        }
158        for (src, label) in self.inner.by_ref() {
159            if (self.node_filter)(src) && (self.arc_filter)(src, self.node) {
160                return Some((src, label))
161            }
162        }
163        None
164    }
165}
166
167/// A view over [`SwhGraph`] and related traits, that filters out some nodes and arcs
168/// based on arbitrary closures.
169#[derive(Clone, Debug)]
170pub struct Subgraph<G: SwhGraph, NodeFilter: Fn(usize) -> bool, ArcFilter: Fn(usize, usize) -> bool>
171{
172    pub graph: G,
173    pub node_filter: NodeFilter,
174    pub arc_filter: ArcFilter,
175    pub num_nodes_by_type: Option<HashMap<NodeType, usize>>,
176    pub num_arcs_by_type: Option<HashMap<(NodeType, NodeType), usize>>,
177}
178
179impl<G: SwhGraph, NodeFilter: Fn(usize) -> bool> Subgraph<G, NodeFilter, fn(usize, usize) -> bool> {
180    /// Create a [Subgraph] keeping only nodes matching a given node filter function.
181    ///
182    /// Shorthand for `Subgraph { graph, node_filter, arc_filter: |_src, _dst| true }`
183    pub fn with_node_filter(
184        graph: G,
185        node_filter: NodeFilter,
186    ) -> Subgraph<G, NodeFilter, fn(usize, usize) -> bool> {
187        Subgraph {
188            graph,
189            node_filter,
190            arc_filter: |_src, _dst| true,
191            num_nodes_by_type: None,
192            num_arcs_by_type: None,
193        }
194    }
195}
196
197impl<G: SwhGraph, ArcFilter: Fn(usize, usize) -> bool> Subgraph<G, fn(usize) -> bool, ArcFilter> {
198    /// Create a [Subgraph] keeping only arcs matching a arc filter function.
199    ///
200    /// Shorthand for `Subgraph { graph, node_filter: |_node| true, arc_filter }`
201    pub fn with_arc_filter(
202        graph: G,
203        arc_filter: ArcFilter,
204    ) -> Subgraph<G, fn(usize) -> bool, ArcFilter> {
205        Subgraph {
206            graph,
207            node_filter: |_node| true,
208            arc_filter,
209            num_nodes_by_type: None,
210            num_arcs_by_type: None,
211        }
212    }
213}
214
215impl<G> Subgraph<G, fn(usize) -> bool, fn(usize, usize) -> bool>
216where
217    G: SwhGraphWithProperties + Clone,
218    <G as SwhGraphWithProperties>::Maps: properties::Maps,
219{
220    /// Create a [Subgraph] keeping only nodes matching a given node constraint.
221    #[allow(clippy::type_complexity)]
222    pub fn with_node_constraint(
223        graph: G,
224        node_constraint: NodeConstraint,
225    ) -> Subgraph<G, impl Fn(NodeId) -> bool, fn(usize, usize) -> bool> {
226        Subgraph {
227            graph: graph.clone(),
228            num_nodes_by_type: graph.num_nodes_by_type().ok().map(|counts| {
229                counts
230                    .into_iter()
231                    .filter(|&(type_, _count)| node_constraint.matches(type_))
232                    .collect()
233            }),
234            num_arcs_by_type: graph.num_arcs_by_type().ok().map(|counts| {
235                counts
236                    .into_iter()
237                    .filter(|&((src_type, dst_type), _count)| {
238                        node_constraint.matches(src_type) && node_constraint.matches(dst_type)
239                    })
240                    .collect()
241            }),
242            node_filter: move |node| node_constraint.matches(graph.properties().node_type(node)),
243            arc_filter: |_src, _dst| true,
244        }
245    }
246}
247
248impl<G: SwhGraph, NodeFilter: Fn(usize) -> bool, ArcFilter: Fn(usize, usize) -> bool> SwhGraph
249    for Subgraph<G, NodeFilter, ArcFilter>
250{
251    #[inline(always)]
252    fn path(&self) -> &Path {
253        self.graph.path()
254    }
255    #[inline(always)]
256    fn is_transposed(&self) -> bool {
257        self.graph.is_transposed()
258    }
259    // Note: this return the number or nodes in the original graph, before
260    // subgraph filtering.
261    #[inline(always)]
262    fn num_nodes(&self) -> usize {
263        self.graph.num_nodes()
264    }
265    #[inline(always)]
266    fn actual_num_nodes(&self) -> Result<usize> {
267        self.num_nodes_by_type
268            .as_ref()
269            .map(|num_nodes_by_type| num_nodes_by_type.values().sum())
270            .ok_or_else(|| anyhow!("Subgraph::actual_num_nodes() is only available when constructed with Subgraph::with_node_constraint on a graph with num_nodes_by_type defined"))
271    }
272    #[inline(always)]
273    fn has_node(&self, node_id: NodeId) -> bool {
274        (self.node_filter)(node_id)
275    }
276    // Note: this return the number or arcs in the original graph, before
277    // subgraph filtering.
278    #[inline(always)]
279    fn num_arcs(&self) -> u64 {
280        self.graph.num_arcs()
281    }
282    fn num_nodes_by_type(&self) -> Result<HashMap<NodeType, usize>> {
283        self.num_nodes_by_type.clone().ok_or(anyhow!(
284            "num_nodes_by_type is not supported by this Subgraph (if possible, use Subgraph::with_node_constraint to build it)"
285        ))
286    }
287    fn num_arcs_by_type(&self) -> Result<HashMap<(NodeType, NodeType), usize>> {
288        self.num_arcs_by_type.clone().ok_or(anyhow!(
289            "num_arcs_by_type is not supported by this Subgraph (if possible, use Subgraph::with_node_constraint to build it)"
290        ))
291    }
292    #[inline(always)]
293    fn has_arc(&self, src_node_id: NodeId, dst_node_id: NodeId) -> bool {
294        (self.node_filter)(src_node_id)
295            && (self.node_filter)(dst_node_id)
296            && (self.arc_filter)(src_node_id, dst_node_id)
297            && self.graph.has_arc(src_node_id, dst_node_id)
298    }
299}
300
301impl<G: SwhForwardGraph, NodeFilter: Fn(usize) -> bool, ArcFilter: Fn(usize, usize) -> bool>
302    SwhForwardGraph for Subgraph<G, NodeFilter, ArcFilter>
303{
304    type Successors<'succ>
305        = FilteredSuccessors<
306        'succ,
307        <<G as SwhForwardGraph>::Successors<'succ> as IntoIterator>::IntoIter,
308        NodeFilter,
309        ArcFilter,
310    >
311    where
312        Self: 'succ;
313
314    #[inline(always)]
315    fn successors(&self, node_id: NodeId) -> Self::Successors<'_> {
316        FilteredSuccessors {
317            inner: self.graph.successors(node_id).into_iter(),
318            node: node_id,
319            node_filter: &self.node_filter,
320            arc_filter: &self.arc_filter,
321        }
322    }
323    #[inline(always)]
324    fn outdegree(&self, node_id: NodeId) -> usize {
325        self.successors(node_id).count()
326    }
327}
328
329impl<G: SwhBackwardGraph, NodeFilter: Fn(usize) -> bool, ArcFilter: Fn(usize, usize) -> bool>
330    SwhBackwardGraph for Subgraph<G, NodeFilter, ArcFilter>
331{
332    type Predecessors<'succ>
333        = FilteredPredecessors<
334        'succ,
335        <<G as SwhBackwardGraph>::Predecessors<'succ> as IntoIterator>::IntoIter,
336        NodeFilter,
337        ArcFilter,
338    >
339    where
340        Self: 'succ;
341
342    #[inline(always)]
343    fn predecessors(&self, node_id: NodeId) -> Self::Predecessors<'_> {
344        FilteredPredecessors {
345            inner: self.graph.predecessors(node_id).into_iter(),
346            node: node_id,
347            node_filter: &self.node_filter,
348            arc_filter: &self.arc_filter,
349        }
350    }
351    #[inline(always)]
352    fn indegree(&self, node_id: NodeId) -> usize {
353        self.predecessors(node_id).count()
354    }
355}
356
357impl<
358        G: SwhLabeledForwardGraph,
359        NodeFilter: Fn(usize) -> bool,
360        ArcFilter: Fn(usize, usize) -> bool,
361    > SwhLabeledForwardGraph for Subgraph<G, NodeFilter, ArcFilter>
362{
363    type LabeledArcs<'arc>
364        = <G as SwhLabeledForwardGraph>::LabeledArcs<'arc>
365    where
366        Self: 'arc;
367    type LabeledSuccessors<'node>
368        = FilteredLabeledSuccessors<
369        'node,
370        Self::LabeledArcs<'node>,
371        <<G as SwhLabeledForwardGraph>::LabeledSuccessors<'node> as IntoIterator>::IntoIter,
372        NodeFilter,
373        ArcFilter,
374    >
375    where
376        Self: 'node;
377
378    #[inline(always)]
379    fn untyped_labeled_successors(&self, node_id: NodeId) -> Self::LabeledSuccessors<'_> {
380        FilteredLabeledSuccessors {
381            inner: self.graph.untyped_labeled_successors(node_id).into_iter(),
382            node: node_id,
383            node_filter: &self.node_filter,
384            arc_filter: &self.arc_filter,
385        }
386    }
387}
388
389impl<
390        G: SwhLabeledBackwardGraph,
391        NodeFilter: Fn(usize) -> bool,
392        ArcFilter: Fn(usize, usize) -> bool,
393    > SwhLabeledBackwardGraph for Subgraph<G, NodeFilter, ArcFilter>
394{
395    type LabeledArcs<'arc>
396        = <G as SwhLabeledBackwardGraph>::LabeledArcs<'arc>
397    where
398        Self: 'arc;
399    type LabeledPredecessors<'node>
400        = FilteredLabeledPredecessors<
401        'node,
402        Self::LabeledArcs<'node>,
403        <<G as SwhLabeledBackwardGraph>::LabeledPredecessors<'node> as IntoIterator>::IntoIter,
404        NodeFilter,
405        ArcFilter,
406    >
407    where
408        Self: 'node;
409
410    #[inline(always)]
411    fn untyped_labeled_predecessors(&self, node_id: NodeId) -> Self::LabeledPredecessors<'_> {
412        FilteredLabeledPredecessors {
413            inner: self.graph.untyped_labeled_predecessors(node_id).into_iter(),
414            node: node_id,
415            node_filter: &self.node_filter,
416            arc_filter: &self.arc_filter,
417        }
418    }
419}
420
421impl<
422        G: SwhGraphWithProperties,
423        NodeFilter: Fn(usize) -> bool,
424        ArcFilter: Fn(usize, usize) -> bool,
425    > SwhGraphWithProperties for Subgraph<G, NodeFilter, ArcFilter>
426{
427    type Maps = <G as SwhGraphWithProperties>::Maps;
428    type Timestamps = <G as SwhGraphWithProperties>::Timestamps;
429    type Persons = <G as SwhGraphWithProperties>::Persons;
430    type Contents = <G as SwhGraphWithProperties>::Contents;
431    type Strings = <G as SwhGraphWithProperties>::Strings;
432    type LabelNames = <G as SwhGraphWithProperties>::LabelNames;
433
434    #[inline(always)]
435    fn properties(
436        &self,
437    ) -> &properties::SwhGraphProperties<
438        Self::Maps,
439        Self::Timestamps,
440        Self::Persons,
441        Self::Contents,
442        Self::Strings,
443        Self::LabelNames,
444    > {
445        self.graph.properties()
446    }
447}