1use std::{
2 any::{Any, TypeId},
3 collections::VecDeque,
4 fmt::Debug,
5};
6
7use atomic_refcell::AtomicRefCell;
8use enum_map::{Enum, EnumMap, enum_map};
9use fixedbitset::FixedBitSet;
10use itertools::Itertools;
11use petgraph::{
12 Direction,
13 adj::UnweightedList,
14 algo::{TarjanScc, tred},
15 data::Build,
16 graph::{DiGraph, NodeIndex},
17 stable_graph::StableDiGraph,
18 visit::{DfsPostOrder, EdgeFiltered, EdgeRef, IntoNeighbors, VisitMap, Visitable},
19};
20use rustc_hash::{FxHashMap, FxHashSet};
21
22use crate::{arena::Arena, parse::Info};
23
24use super::{
25 spec::{ResolvedSpecType, Spec},
26 types::{
27 GraphInlineType, GraphOperation, GraphSchemaType, GraphStruct, GraphTagged,
28 GraphTaggedVariant, GraphType, InlineTypePath, InlineTypePathRoot, InlineTypePathSegment,
29 SchemaTypeInfo, SpecInlineType, SpecSchemaType, SpecType, SpecUntaggedVariant,
30 StructFieldName, mapper::TypeMapper,
31 },
32 views::{operation::OperationView, primitive::PrimitiveView, schema::SchemaTypeView},
33};
34
/// Graph used while building and transforming types. `StableDiGraph` keeps
/// node and edge indices stable when edges are removed, which matters when a
/// batch of previously collected edge ids is removed one at a time (as
/// `RawGraph::inline_tagged_variants` does).
type RawDiGraph<'a> = StableDiGraph<GraphType<'a>, EdgeKind, usize>;

/// Compact graph produced by `RawGraph::cook`, with contiguous node indices.
type CookedDiGraph<'a> = DiGraph<GraphType<'a>, EdgeKind, usize>;
40
/// A mutable graph of the types in a [`Spec`].
///
/// It is built by [`RawGraph::new`], can be transformed in place (for example
/// by [`RawGraph::inline_tagged_variants`]), and is then frozen into a
/// [`CookedGraph`] with [`RawGraph::cook`].
#[derive(Debug)]
pub struct RawGraph<'a> {
    /// Allocator for the slices that graph nodes and operations borrow.
    arena: &'a Arena,
    spec: &'a Spec<'a>,
    /// One node per resolved type; edges are labeled with an [`EdgeKind`].
    graph: RawDiGraph<'a>,
    /// The spec's operations, with their type references mapped to node
    /// indices in `graph`.
    ops: &'a [&'a GraphOperation<'a>],
}
58
59impl<'a> RawGraph<'a> {
60 pub fn new(arena: &'a Arena, spec: &'a Spec<'a>) -> Self {
61 let tys = SpecTypeVisitor::new(
64 spec.schemas
65 .values()
66 .chain(spec.operations.iter().flat_map(|op| op.types().copied())),
67 );
68
69 let mut indices = FxHashMap::default();
71 let mut schemas = FxHashMap::default();
72 let mut nodes = vec![];
73 let mut edges = vec![];
74 for (parent, kind, child) in tys {
75 use std::collections::hash_map::Entry;
76 let source = spec.resolve(child);
77 let &mut to = match indices.entry(source) {
78 Entry::Occupied(entry) => entry.into_mut(),
79 Entry::Vacant(entry) => {
80 let index = NodeIndex::new(nodes.len());
84 nodes.push(*entry.key());
85 entry.insert(index)
86 }
87 };
88 if let ResolvedSpecType::Schema(ty) = source {
90 schemas.entry(ty.name()).or_insert(to);
91 }
92 if let Some(parent) = parent {
93 let destination = spec.resolve(parent);
94 let &mut from = match indices.entry(destination) {
95 Entry::Occupied(entry) => entry.into_mut(),
96 Entry::Vacant(entry) => {
97 let index = NodeIndex::new(nodes.len());
98 nodes.push(*entry.key());
99 entry.insert(index)
100 }
101 };
102 if let ResolvedSpecType::Schema(ty) = destination {
103 schemas.entry(ty.name()).or_insert(from);
104 }
105 edges.push((from, to, kind));
106 }
107 }
108
109 let mut graph = RawDiGraph::with_capacity(nodes.len(), edges.len());
112 let mapper = TypeMapper::new(arena, |ty: &SpecType<'_>| match ty {
113 SpecType::Schema(s) => indices[&ResolvedSpecType::Schema(s)],
114 SpecType::Inline(i) => indices[&ResolvedSpecType::Inline(i)],
115 SpecType::Ref(r) => schemas[&*r.name()],
116 });
117 for node in nodes {
118 let mapped = match node {
119 ResolvedSpecType::Schema(ty) => GraphType::Schema(mapper.schema(ty)),
120 ResolvedSpecType::Inline(ty) => GraphType::Inline(mapper.inline(ty)),
121 };
122 let index = graph.add_node(mapped);
123 debug_assert_eq!(index, indices[&node]);
124 }
125 for (from, to, kind) in edges {
126 graph.add_edge(from, to, kind);
127 }
128
129 let ops = arena.alloc_slice_exact(spec.operations.iter().map(|op| mapper.operation(op)));
131
132 Self {
133 arena,
134 spec,
135 graph,
136 ops,
137 }
138 }
139
140 pub fn inline_tagged_variants(&mut self) -> &mut Self {
160 struct TaggedPlan<'a> {
161 tagged_index: NodeIndex<usize>,
162 info: SchemaTypeInfo<'a>,
163 tagged: GraphTagged<'a>,
164 inlines: Vec<VariantInline<'a>>,
165 }
166 struct VariantInline<'a> {
167 node: GraphType<'a>,
168 variant_index: NodeIndex<usize>,
169 parent_indices: &'a [NodeIndex<usize>],
170 name: &'a str,
171 aliases: &'a [&'a str],
172 }
173
174 let used_by_ops: FixedBitSet = self
181 .ops
182 .iter()
183 .flat_map(|op| op.types())
184 .map(|node| node.index())
185 .collect();
186
187 let plans = self
190 .graph
191 .node_indices()
192 .filter_map(|index| {
193 let GraphType::Schema(GraphSchemaType::Tagged(info, tagged)) = self.graph[index]
194 else {
195 return None;
196 };
197 let mut inlines = vec![];
198
199 for variant in tagged.variants {
200 let variant_index = variant.ty;
201 let GraphType::Schema(GraphSchemaType::Struct(variant_info, variant_struct)) =
202 self.graph[variant_index]
203 else {
204 continue;
205 };
206
207 if !used_by_ops.contains(variant_index.index()) {
213 let Some(first) = ({
214 self.graph
215 .neighbors_directed(variant_index, Direction::Incoming)
216 .find_map(|neighbor| match self.graph[neighbor] {
217 GraphType::Schema(GraphSchemaType::Tagged(_, t)) => Some(t),
218 _ => None,
219 })
220 }) else {
221 continue;
222 };
223 let all_agree = self
227 .graph
228 .neighbors_directed(variant_index, Direction::Incoming)
229 .all(|neighbor| {
230 matches!(
231 self.graph[neighbor],
232 GraphType::Schema(GraphSchemaType::Tagged(_, t))
233 if t.tag == first.tag && t.fields == first.fields,
234 )
235 });
236 if all_agree {
237 continue;
238 }
239 }
240
241 let (has_tag_field, already_inherits) = {
248 let inherits = EdgeFiltered::from_fn(&self.graph, |e| {
249 matches!(e.weight(), EdgeKind::Inherits)
250 });
251 let mut dfs = DfsPostOrder::new(&inherits, variant_index);
252 let mut has_tag_field = false;
253 let mut already_inherits = false;
254 while let Some(ancestor) = dfs.next(&inherits)
255 && !(has_tag_field && already_inherits)
256 {
257 already_inherits |= ancestor == index;
258 has_tag_field |= match self.graph[ancestor] {
259 GraphType::Schema(GraphSchemaType::Struct(_, s))
260 | GraphType::Inline(GraphInlineType::Struct(_, s)) => {
261 s.fields.iter().any(|f| {
264 matches!(
265 f.name,
266 StructFieldName::Name(n) if n == tagged.tag,
267 )
268 })
269 }
270 _ => false,
271 };
272 }
273 (has_tag_field, already_inherits)
274 };
275 if !has_tag_field && (tagged.fields.is_empty() || already_inherits) {
276 continue;
277 }
278
279 let parents = self.arena.alloc_slice(itertools::chain!(
284 variant_struct.parents.iter().copied(),
285 std::iter::once(index),
286 ));
287 let node = GraphType::Inline(GraphInlineType::Struct(
288 InlineTypePath {
289 root: InlineTypePathRoot::Type(info.name),
290 segments: self.arena.alloc_slice_copy(&[
291 InlineTypePathSegment::TaggedVariant(variant_info.name),
292 ]),
293 },
294 GraphStruct {
295 description: variant_struct.description,
296 fields: variant_struct.fields,
297 parents,
298 },
299 ));
300
301 inlines.push(VariantInline {
302 node,
303 variant_index,
304 parent_indices: parents,
305 name: variant.name,
306 aliases: variant.aliases,
307 });
308 }
309 if inlines.is_empty() {
310 return None;
311 }
312
313 Some(TaggedPlan {
314 tagged_index: index,
315 info,
316 tagged,
317 inlines,
318 })
319 })
320 .collect_vec();
321
322 for plan in plans {
324 let mut new_variants = FxHashMap::default();
325
326 for entry in &plan.inlines {
328 let node_index = self.graph.add_node(entry.node);
329
330 self.graph
334 .add_edge(node_index, entry.variant_index, EdgeKind::Reference);
335
336 for &parent_index in entry.parent_indices {
338 self.graph
339 .add_edge(node_index, parent_index, EdgeKind::Inherits);
340 }
341
342 new_variants.insert(
343 entry.variant_index,
344 GraphTaggedVariant {
345 name: entry.name,
346 aliases: entry.aliases,
347 ty: node_index,
348 },
349 );
350 }
351
352 let edges_to_retarget = self
356 .graph
357 .edges_directed(plan.tagged_index, Direction::Outgoing)
358 .filter(|e| {
359 matches!(e.weight(), EdgeKind::Reference)
360 && new_variants.contains_key(&e.target())
361 })
362 .map(|e| (e.id(), new_variants[&e.target()].ty))
363 .collect_vec();
364 for (edge_id, new_target) in edges_to_retarget {
365 self.graph.remove_edge(edge_id);
366 self.graph
367 .add_edge(plan.tagged_index, new_target, EdgeKind::Reference);
368 }
369
370 let modified_tagged = GraphTagged {
372 description: plan.tagged.description,
373 tag: plan.tagged.tag,
374 variants: self.arena.alloc_slice_exact(
375 plan.tagged
376 .variants
377 .iter()
378 .map(|&v| new_variants.get(&v.ty).copied().unwrap_or(v)),
379 ),
380 fields: plan.tagged.fields,
381 };
382 self.graph[plan.tagged_index] =
383 GraphType::Schema(GraphSchemaType::Tagged(plan.info, modified_tagged));
384 }
385
386 self
387 }
388
389 #[inline]
391 pub fn cook(&self) -> CookedGraph<'a> {
392 CookedGraph::new(self)
393 }
394}
395
/// An immutable, compacted type graph with precomputed metadata, produced
/// by [`RawGraph::cook`].
#[derive(Debug)]
pub struct CookedGraph<'a> {
    /// The type graph; node indices are contiguous.
    pub(super) graph: CookedDiGraph<'a>,
    info: &'a Info,
    /// Operations with their type references remapped to `graph`'s indices.
    ops: &'a [&'a GraphOperation<'a>],
    /// Per-node and per-operation facts computed during cooking.
    pub(super) metadata: CookedGraphMetadata<'a>,
}
409
impl<'a> CookedGraph<'a> {
    /// Compacts `raw` into a `DiGraph` and precomputes per-node and
    /// per-operation metadata.
    fn new(raw: &RawGraph<'a>) -> Self {
        // Raw (stable, possibly sparse) index -> dense cooked index, assigned
        // in raw iteration order.
        let indices: FxHashMap<_, _> = raw
            .graph
            .node_indices()
            .enumerate()
            .map(|(cooked, raw)| (raw, NodeIndex::new(cooked)))
            .collect();

        // Rewrites node indices embedded in types from raw to cooked.
        let mapper = TypeMapper::new(raw.arena, |index| indices[&index]);

        let mut graph = CookedDiGraph::with_capacity(indices.len(), raw.graph.edge_count());
        for index in raw.graph.node_indices() {
            let node = raw.graph[index];
            let mapped = match node {
                GraphType::Schema(ty) => GraphType::Schema(mapper.schema(&ty)),
                GraphType::Inline(ty) => GraphType::Inline(mapper.inline(&ty)),
            };
            let cooked = graph.add_node(mapped);
            debug_assert_eq!(indices[&index], cooked);
        }

        for index in raw.graph.node_indices() {
            let from = indices[&index];
            // Edges are re-added in reverse. Presumably this keeps the raw
            // graph's per-node edge iteration order, since petgraph lists a
            // node's edges most-recently-added first. TODO confirm intent.
            let edges = raw
                .graph
                .edges(index)
                .map(|e| (indices[&e.target()], *e.weight()))
                .collect_vec();
            for (to, kind) in edges.into_iter().rev() {
                graph.add_edge(from, to, kind);
            }
        }

        // Topologically ordered SCCs over all edges (references and
        // inheritance); used for the dependency closure below.
        let sccs = TopoSccs::new(&graph);

        let (metadata, ops) = {
            let mut metadata = CookedGraphMetadata {
                // Per node: the index of its SCC when only `Reference` edges
                // are considered.
                scc_indices: {
                    let refs = EdgeFiltered::from_fn(&graph, |e| {
                        matches!(e.weight(), EdgeKind::Reference)
                    });
                    let mut scc = TarjanScc::new();
                    scc.run(&refs, |_| ());
                    graph
                        .node_indices()
                        .map(|node| scc.node_component_index(&refs, node))
                        .collect()
                },
                schemas: std::iter::repeat_with(GraphNodeMeta::default)
                    .take(graph.node_count())
                    .collect(),
                operations: FxHashMap::default(),
            };

            let ops: &_ = raw
                .arena
                .alloc_slice_exact(raw.ops.iter().map(|&op| mapper.operation(op)));

            // Record the nodes each operation references directly.
            for &&op in ops {
                metadata.operations.entry(op).or_default().types =
                    op.types().map(|node| node.index()).collect();
            }

            {
                // DAG whose nodes are SCCs, indexed by topological position.
                let condensation = sccs.condensation();

                // Transitive closure of the SCC DAG: each SCC -> every SCC
                // reachable from it.
                let (_, closure) = tred::dag_transitive_reduction_closure(&condensation);

                // For each SCC: the nodes in it plus the nodes of every SCC
                // reachable from it.
                let mut scc_deps =
                    vec![FixedBitSet::with_capacity(graph.node_count()); sccs.scc_count()];
                for (scc, deps) in scc_deps.iter_mut().enumerate() {
                    deps.extend(
                        std::iter::once(scc)
                            .chain(closure.neighbors(scc))
                            .flat_map(|scc| sccs.sccs[scc].ones()),
                    );
                }

                // A node's dependencies are everything its SCC reaches,
                // minus the node itself. Dependents are the inverse relation.
                let mut node_dependents =
                    vec![FixedBitSet::with_capacity(graph.node_count()); graph.node_count()];
                for node in graph.node_indices() {
                    let mut deps = scc_deps[sccs.topo_index(node)].clone();
                    deps.remove(node.index());
                    for dep in deps.ones() {
                        node_dependents[dep].insert(node.index());
                    }
                    metadata.schemas[node.index()].dependencies = deps;
                }
                for (index, dependents) in node_dependents.into_iter().enumerate() {
                    metadata.schemas[index].dependents = dependents;
                }
            }

            // Mark every node an operation reaches, directly or transitively,
            // as used by that operation.
            for &&op in ops {
                let meta = &metadata.operations[&op];

                let mut transitive_deps = FixedBitSet::with_capacity(graph.node_count());
                for node in meta.types.ones() {
                    transitive_deps.insert(node);
                    transitive_deps.union_with(&metadata.schemas[node].dependencies);
                }

                for node in transitive_deps.ones() {
                    metadata.schemas[node].used_by.insert(op);
                }
            }

            (metadata, ops)
        };

        Self {
            graph,
            info: raw.spec.info,
            ops,
            metadata,
        }
    }

    /// Returns the spec's `Info`.
    #[inline]
    pub fn info(&self) -> &'a Info {
        self.info
    }

    /// Iterates over the named-schema nodes (not inline types), in node
    /// index order.
    #[inline]
    pub fn schemas(&self) -> impl Iterator<Item = SchemaTypeView<'_>> {
        self.graph
            .node_indices()
            .filter_map(|index| match self.graph[index] {
                GraphType::Schema(ty) => Some(SchemaTypeView::new(self, index, ty)),
                _ => None,
            })
    }

    /// Iterates over every primitive node, whether a named schema or
    /// inline, in node index order.
    #[inline]
    pub fn primitives(&self) -> impl Iterator<Item = PrimitiveView<'_>> {
        self.graph
            .node_indices()
            .filter_map(|index| match self.graph[index] {
                GraphType::Schema(GraphSchemaType::Primitive(_, p))
                | GraphType::Inline(GraphInlineType::Primitive(_, p)) => {
                    Some(PrimitiveView::new(self, index, p))
                }
                _ => None,
            })
    }

    /// Iterates over the operations in their original spec order.
    #[inline]
    pub fn operations(&self) -> impl Iterator<Item = OperationView<'_>> {
        self.ops.iter().map(move |&op| OperationView::new(self, op))
    }
}
595
/// The relationship that an edge in the type graph represents.
#[derive(Clone, Copy, Debug, Enum, Eq, Hash, PartialEq)]
pub enum EdgeKind {
    /// The source type refers to the target type, for example through a
    /// field, a variant, or a container's element type.
    Reference,
    /// The source struct has the target type as one of its parents.
    Inherits,
}
604
/// Precomputed per-node and per-operation facts about a [`CookedGraph`].
#[derive(Debug, Default)]
pub(super) struct CookedGraphMetadata<'a> {
    /// For each node (by index), the index of its strongly connected
    /// component when only [`EdgeKind::Reference`] edges are considered.
    pub scc_indices: Vec<usize>,
    /// Per-node metadata, indexed by node index.
    pub schemas: Vec<GraphNodeMeta<'a>>,
    /// Per-operation metadata.
    pub operations: FxHashMap<GraphOperation<'a>, GraphOperationMeta>,
}
614
/// Metadata for a single operation.
#[derive(Debug, Default)]
pub(super) struct GraphOperationMeta {
    /// Indices of the nodes returned by the operation's `types()`, that is,
    /// the types it references directly.
    pub types: FixedBitSet,
}
623
/// Metadata for a single node of a [`CookedGraph`].
#[derive(Default)]
pub(super) struct GraphNodeMeta<'a> {
    /// Operations that reach this node, directly or transitively.
    pub used_by: FxHashSet<GraphOperation<'a>>,
    /// Indices of the nodes reachable from this node over edges of any kind,
    /// excluding the node itself.
    pub dependencies: FixedBitSet,
    /// Indices of the nodes that have this node among their `dependencies`.
    pub dependents: FixedBitSet,
    /// Type-erased extension values attached to this node; see
    /// [`ExtensionMap`].
    pub extensions: AtomicRefCell<ExtensionMap>,
}
636
637impl Debug for GraphNodeMeta<'_> {
638 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639 f.debug_struct("GraphNodeMeta")
640 .field("used_by", &self.used_by)
641 .field("dependencies", &self.dependencies)
642 .field("dependents", &self.dependents)
643 .finish_non_exhaustive()
644 }
645}
646
/// Depth-first, pre-order iterator over the spec types reachable from a set
/// of roots.
///
/// It yields `(parent, edge kind, type)` triples; roots are yielded with no
/// parent. `SpecType::Ref`s are yielded but not expanded. There is no visited
/// set, so each occurrence of a type in the tree is yielded.
#[derive(Debug)]
struct SpecTypeVisitor<'a> {
    /// Pending `(parent, kind, type)` entries; the next one is at the end.
    stack: Vec<(Option<&'a SpecType<'a>>, EdgeKind, &'a SpecType<'a>)>,
}
652
653impl<'a> SpecTypeVisitor<'a> {
654 #[inline]
656 fn new(roots: impl Iterator<Item = &'a SpecType<'a>>) -> Self {
657 let mut stack = roots
658 .map(|root| (None, EdgeKind::Reference, root))
659 .collect_vec();
660 stack.reverse();
661 Self { stack }
662 }
663}
664
impl<'a> Iterator for SpecTypeVisitor<'a> {
    type Item = (Option<&'a SpecType<'a>>, EdgeKind, &'a SpecType<'a>);

    /// Pops the next pending type and schedules its children. Returns the
    /// type together with its parent and the kind of edge from that parent.
    fn next(&mut self) -> Option<Self::Item> {
        let (parent, kind, top) = self.stack.pop()?;
        match top {
            // Structs: fields are references, parents are inheritance edges.
            SpecType::Schema(SpecSchemaType::Struct(_, ty))
            | SpecType::Inline(SpecInlineType::Struct(_, ty)) => {
                self.stack.extend(
                    itertools::chain!(
                        ty.fields
                            .iter()
                            .map(|field| (EdgeKind::Reference, field.ty)),
                        ty.parents
                            .iter()
                            .map(|parent| (EdgeKind::Inherits, *parent)),
                    )
                    .map(|(kind, ty)| (Some(top), kind, ty))
                    // Pushed in reverse so they pop in declaration order
                    // (fields first, then parents).
                    .rev(),
                );
            }
            // Untagged unions: fields, then the variants that carry a type.
            SpecType::Schema(SpecSchemaType::Untagged(_, ty))
            | SpecType::Inline(SpecInlineType::Untagged(_, ty)) => {
                self.stack.extend(
                    itertools::chain!(
                        ty.fields
                            .iter()
                            .map(|field| (EdgeKind::Reference, field.ty)),
                        ty.variants.iter().filter_map(|variant| match variant {
                            SpecUntaggedVariant::Some(_, ty) => {
                                Some((EdgeKind::Reference, *ty))
                            }
                            _ => None,
                        }),
                    )
                    .map(|(kind, ty)| (Some(top), kind, ty))
                    .rev(),
                );
            }
            // Tagged unions: common fields, then every variant.
            SpecType::Schema(SpecSchemaType::Tagged(_, ty))
            | SpecType::Inline(SpecInlineType::Tagged(_, ty)) => {
                self.stack.extend(
                    itertools::chain!(
                        ty.fields
                            .iter()
                            .map(|field| (EdgeKind::Reference, field.ty)),
                        ty.variants
                            .iter()
                            .map(|variant| (EdgeKind::Reference, variant.ty)),
                    )
                    .map(|(kind, ty)| (Some(top), kind, ty))
                    .rev(),
                );
            }
            // Containers: the single inner type.
            SpecType::Schema(SpecSchemaType::Container(_, container))
            | SpecType::Inline(SpecInlineType::Container(_, container)) => {
                self.stack
                    .push((Some(top), EdgeKind::Reference, container.inner().ty));
            }
            // Leaf types have no children.
            SpecType::Schema(
                SpecSchemaType::Enum(..) | SpecSchemaType::Primitive(..) | SpecSchemaType::Any(_),
            )
            | SpecType::Inline(
                SpecInlineType::Enum(..) | SpecInlineType::Primitive(..) | SpecInlineType::Any(_),
            ) => (),
            // References are not followed; the target is visited as its own root.
            SpecType::Ref(_) => (),
        }
        Some((parent, kind, top))
    }
}
735
/// Type-erased extension values keyed by [`TypeId`]. Presumably the key is
/// the `TypeId` of the stored value's concrete type; verify against callers.
pub(super) type ExtensionMap = FxHashMap<TypeId, Box<dyn Extension>>;
738
/// A value that can be stored in a node's extension map.
///
/// Implemented automatically for every `Send + Sync + 'static` type.
pub trait Extension: Any + Send + Sync {
    /// Converts this boxed extension into a type-erased `Box<dyn Any>`, from
    /// which the concrete value can be recovered with `Box::downcast`.
    fn into_inner(self: Box<Self>) -> Box<dyn Any>;
}

impl dyn Extension {
    /// Returns a reference to the concrete value if it has type `T`, or
    /// `None` otherwise.
    #[inline]
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        // Upcast to `dyn Any` so the check uses the concrete type's `TypeId`.
        let erased: &dyn Any = self;
        erased.downcast_ref::<T>()
    }
}

impl<T> Extension for T
where
    T: Send + Sync + 'static,
{
    #[inline]
    fn into_inner(self: Box<Self>) -> Box<dyn Any> {
        let erased: Box<dyn Any> = self;
        erased
    }
}
756
757struct TopoSccs<'a, N, E> {
764 graph: &'a DiGraph<N, E, usize>,
765 tarjan: TarjanScc<NodeIndex<usize>>,
766 sccs: Vec<FixedBitSet>,
767}
768
769impl<'a, N, E> TopoSccs<'a, N, E> {
770 fn new(graph: &'a DiGraph<N, E, usize>) -> Self {
771 let mut sccs = Vec::new();
772 let mut tarjan = TarjanScc::new();
773 tarjan.run(graph, |scc_nodes| {
774 sccs.push(scc_nodes.iter().map(|node| node.index()).collect());
775 });
776 sccs.reverse();
779 Self {
780 graph,
781 tarjan,
782 sccs,
783 }
784 }
785
786 #[inline]
787 fn scc_count(&self) -> usize {
788 self.sccs.len()
789 }
790
791 #[inline]
793 fn topo_index(&self, node: NodeIndex<usize>) -> usize {
794 self.sccs.len() - 1 - self.tarjan.node_component_index(self.graph, node)
797 }
798
799 #[cfg(test)]
801 fn iter(&self) -> std::slice::Iter<'_, FixedBitSet> {
802 self.sccs.iter()
803 }
804
805 fn condensation(&self) -> UnweightedList<usize> {
812 let mut dag = UnweightedList::with_capacity(self.scc_count());
813 for to in 0..self.scc_count() {
814 dag.add_node();
815 for index in self.sccs[to].ones().map(NodeIndex::new) {
816 for neighbor in self.graph.neighbors_directed(index, Direction::Incoming) {
817 let from = self.topo_index(neighbor);
818 if from != to {
819 dag.update_edge(from, to, ());
820 }
821 }
822 }
823 }
824 dag
825 }
826}
827
/// What [`Traverse::run`] should do with a node the traversal reaches.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Traversal {
    /// Yield the node and continue into its neighbors.
    Visit,
    /// Yield the node, but don't continue into its neighbors.
    Stop,
    /// Don't yield the node, but continue into its neighbors.
    Skip,
    /// Neither yield the node nor continue into its neighbors.
    Ignore,
}
840
/// Breadth-first traversal of a cooked graph that tracks discovery
/// separately for each [`EdgeKind`].
///
/// Because discovery is tracked per kind, a node reachable through both a
/// `Reference` and an `Inherits` edge can be reached once for each kind.
pub struct Traverse<'a> {
    graph: &'a CookedDiGraph<'a>,
    /// FIFO queue of `(kind of edge it was reached through, node)`.
    stack: VecDeque<(EdgeKind, NodeIndex<usize>)>,
    /// For each edge kind, the nodes that have already been enqueued.
    discovered: EnumMap<EdgeKind, FixedBitSet>,
    /// Whether edges are followed forward (`Outgoing`) or backward
    /// (`Incoming`).
    direction: Direction,
}
855
856impl<'a> Traverse<'a> {
857 pub fn at_root(
860 graph: &'a CookedDiGraph<'a>,
861 root: NodeIndex<usize>,
862 direction: Direction,
863 ) -> Self {
864 let mut discovered = enum_map!(_ => graph.visit_map());
865 discovered[EdgeKind::Reference].grow_and_insert(root.index());
866 Self {
867 graph,
868 stack: VecDeque::from([(EdgeKind::Reference, root)]),
869 discovered,
870 direction,
871 }
872 }
873
874 pub fn at_roots(
877 graph: &'a CookedDiGraph<'a>,
878 roots: &FixedBitSet,
879 direction: Direction,
880 ) -> Self {
881 let mut stack = VecDeque::new();
882 let mut discovered = enum_map!(_ => graph.visit_map());
883 stack.extend(
884 roots
885 .ones()
886 .map(|index| (EdgeKind::Reference, NodeIndex::new(index))),
887 );
888 discovered[EdgeKind::Reference].union_with(roots);
889 Self {
890 graph,
891 stack,
892 discovered,
893 direction,
894 }
895 }
896
897 pub fn from_neighbors(
900 graph: &'a CookedDiGraph<'a>,
901 root: NodeIndex<usize>,
902 direction: Direction,
903 ) -> Self {
904 let mut stack = VecDeque::new();
905 let mut discovered = enum_map! {
906 _ => {
907 let mut map = graph.visit_map();
908 map.visit(root);
909 map
910 }
911 };
912 for (kind, neighbors) in neighbors(graph, root, direction) {
913 stack.extend(
914 neighbors
915 .difference(&discovered[kind])
916 .map(|index| (kind, NodeIndex::new(index))),
917 );
918 discovered[kind].union_with(&neighbors);
919 }
920 Self {
921 graph,
922 stack,
923 discovered,
924 direction,
925 }
926 }
927
928 pub fn run<F>(mut self, filter: F) -> impl Iterator<Item = NodeIndex<usize>> + use<'a, F>
929 where
930 F: Fn(EdgeKind, NodeIndex<usize>) -> Traversal,
931 {
932 std::iter::from_fn(move || {
933 while let Some((kind, index)) = self.stack.pop_front() {
934 let traversal = filter(kind, index);
935
936 if matches!(traversal, Traversal::Visit | Traversal::Skip) {
937 for (kind, neighbors) in neighbors(self.graph, index, self.direction) {
938 for neighbor in neighbors.difference(&self.discovered[kind]) {
939 self.stack.push_back((kind, NodeIndex::new(neighbor)));
940 }
941 self.discovered[kind].union_with(&neighbors);
942 }
943 }
944
945 if matches!(traversal, Traversal::Visit | Traversal::Stop) {
946 return Some(index);
947 }
948
949 }
951 None
952 })
953 }
954}
955
956fn neighbors(
959 graph: &CookedDiGraph<'_>,
960 node: NodeIndex<usize>,
961 direction: Direction,
962) -> EnumMap<EdgeKind, FixedBitSet> {
963 let mut neighbors = enum_map!(_ => graph.visit_map());
964 for edge in graph.edges_directed(node, direction) {
965 let neighbor = match direction {
966 Direction::Outgoing => edge.target(),
967 Direction::Incoming => edge.source(),
968 };
969 neighbors[*edge.weight()].insert(neighbor.index());
970 }
971 neighbors
972}
973
#[cfg(test)]
mod tests {
    use super::*;

    use petgraph::visit::NodeCount;

    use crate::tests::assert_matches;

    /// `a -> b -> c`: three singleton SCCs.
    fn linear_graph() -> DiGraph<(), (), usize> {
        let mut graph = DiGraph::default();
        let [a, b, c] = [(); 3].map(|()| graph.add_node(()));
        graph.extend_with_edges([(a, b), (b, c)]);
        graph
    }

    /// The cycle `a -> b -> c -> a`, fed by a source `d -> a`.
    fn cyclic_graph() -> DiGraph<(), (), usize> {
        let mut graph = DiGraph::default();
        let [a, b, c, d] = [(); 4].map(|()| graph.add_node(()));
        graph.extend_with_edges([(a, b), (b, c), (c, a), (d, a)]);
        graph
    }

    /// Sizes of the SCCs, in topological order.
    fn scc_sizes<N, E>(sccs: &TopoSccs<'_, N, E>) -> Vec<usize> {
        sccs.iter().map(|scc| scc.count_ones(..)).collect()
    }

    #[test]
    fn test_linear_graph_has_singleton_sccs() {
        let graph = linear_graph();
        let sizes = scc_sizes(&TopoSccs::new(&graph));
        assert_matches!(&*sizes, [1, 1, 1]);
    }

    #[test]
    fn test_cyclic_graph_has_one_multi_node_scc() {
        let graph = cyclic_graph();
        // The source `d` comes first, then the 3-node cycle.
        let sizes = scc_sizes(&TopoSccs::new(&graph));
        assert_matches!(&*sizes, [1, 3]);
    }

    #[test]
    fn test_sccs_are_in_topological_order() {
        let graph = cyclic_graph();
        let sccs = TopoSccs::new(&graph);

        let (a, d) = (NodeIndex::new(0), NodeIndex::new(3));
        assert!(
            sccs.topo_index(d) < sccs.topo_index(a),
            "D should precede A-B-C in topological order"
        );
    }

    #[test]
    fn test_topo_index_consistent_within_scc() {
        let graph = cyclic_graph();
        let sccs = TopoSccs::new(&graph);

        // `a`, `b`, and `c` form one cycle, so they share a component.
        let [a, b, c] = [0, 1, 2].map(|i| sccs.topo_index(NodeIndex::new(i)));
        assert_eq!(a, b);
        assert_eq!(b, c);
    }

    #[test]
    fn test_condensation_has_correct_node_count() {
        let graph = cyclic_graph();
        let dag = TopoSccs::new(&graph).condensation();
        assert_eq!(dag.node_count(), 2);
    }

    #[test]
    fn test_condensation_has_correct_edges() {
        let graph = cyclic_graph();
        let sccs = TopoSccs::new(&graph);
        let dag = sccs.condensation();

        let d = sccs.topo_index(NodeIndex::new(3));
        let abc = sccs.topo_index(NodeIndex::new(0));

        // The only edge is from `d`'s component into the cycle's component.
        assert_eq!(&*dag.neighbors(d).collect_vec(), [abc]);
        assert!(dag.neighbors(abc).next().is_none());
    }

    #[test]
    fn test_condensation_neighbors_in_topological_order() {
        // Insert nodes out of topological order so the condensation's
        // neighbor order can't simply follow insertion order.
        let mut graph = DiGraph::<(), (), usize>::default();
        let second = graph.add_node(());
        let top = graph.add_node(());
        let first = graph.add_node(());
        graph.extend_with_edges([(top, second), (top, first), (first, second)]);

        let sccs = TopoSccs::new(&graph);
        let dag = sccs.condensation();

        let [top_topo, first_topo, second_topo] =
            [top, first, second].map(|node| sccs.topo_index(node));
        assert_eq!(top_topo, 0);
        assert_eq!(first_topo, 1);
        assert_eq!(second_topo, 2);

        assert_eq!(
            &*dag.neighbors(top_topo).collect_vec(),
            [first_topo, second_topo]
        );
    }
}