// ploidy_core/ir/graph.rs

1use std::{
2    any::{Any, TypeId},
3    collections::VecDeque,
4    fmt::Debug,
5};
6
7use atomic_refcell::AtomicRefCell;
8use fixedbitset::FixedBitSet;
9use itertools::Itertools;
10use petgraph::{
11    Direction,
12    adj::UnweightedList,
13    algo::{TarjanScc, tred},
14    data::Build,
15    graph::{DiGraph, NodeIndex},
16    stable_graph::StableDiGraph,
17    visit::{
18        DfsPostOrder, EdgeFiltered, EdgeRef, IntoNeighbors, IntoNeighborsDirected,
19        IntoNodeIdentifiers, NodeCount, NodeIndexable,
20    },
21};
22use rustc_hash::{FxBuildHasher, FxHashMap};
23
24use crate::{
25    arena::Arena,
26    ir::{SchemaTypeInfo, UntaggedVariantMeta},
27    parse::Info,
28};
29
30use super::{
31    spec::{ResolvedSpecType, Spec},
32    types::{
33        FieldMeta, GraphContainer, GraphInlineType, GraphOperation, GraphSchemaType, GraphStruct,
34        GraphTagged, GraphType, InlineTypePath, InlineTypePathRoot, InlineTypePathSegment,
35        PrimitiveType, SpecInlineType, SpecSchemaType, SpecType, SpecUntaggedVariant,
36        StructFieldName, TaggedVariantMeta, VariantMeta,
37        shape::{Operation, Parameter, ParameterInfo, Request, Response},
38    },
39    views::{operation::OperationView, primitive::PrimitiveView, schema::SchemaTypeView},
40};
41
42/// The mutable, sparse graph used for transformations.
43type RawDiGraph<'a> = StableDiGraph<GraphType<'a>, GraphEdge<'a>, usize>;
44
45/// The immutable, dense graph used for code generation.
46type CookedDiGraph<'a> = DiGraph<GraphType<'a>, GraphEdge<'a>, usize>;
47
48/// A mutable intermediate dependency graph of all the types in a [`Spec`],
49/// backed by a sparse [`StableDiGraph`].
50///
51/// This graph is constructed directly from a [`Spec`], and represents
52/// type relationships as they exist in the spec. Transformations like
53/// [`inline_tagged_variants`][Self::inline_tagged_variants] rewrite this graph
54/// in place.
55///
56/// After applying all transformations, call [`cook`][Self::cook] to
57/// turn this graph into a [`CookedGraph`] that's ready for code generation.
58#[derive(Debug)]
59pub struct RawGraph<'a> {
60    arena: &'a Arena,
61    spec: &'a Spec<'a>,
62    graph: RawDiGraph<'a>,
63    schemas: FxHashMap<&'a str, NodeIndex<usize>>,
64    ops: &'a [&'a GraphOperation<'a>],
65}
66
67impl<'a> RawGraph<'a> {
68    /// Builds a raw type graph from the given spec.
69    pub fn new(arena: &'a Arena, spec: &'a Spec<'a>) -> Self {
70        // All roots (named schemas, parameters, request and response bodies),
71        // and all the types within them (inline schemas and primitives).
72        let tys = SpecTypeVisitor::new(
73            spec.schemas
74                .values()
75                .chain(spec.operations.iter().flat_map(|op| op.types().copied())),
76        );
77
78        // Inflate a graph from the traversal.
79        let mut indices = FxHashMap::default();
80        let mut schemas = FxHashMap::default();
81        let mut graph = RawDiGraph::default();
82        for (parent, child) in tys {
83            use std::collections::hash_map::Entry;
84
85            let source = spec.resolve(child);
86            let &mut to = match indices.entry(source) {
87                Entry::Occupied(entry) => entry.into_mut(),
88                Entry::Vacant(entry) => {
89                    // We might see the same schema multiple times if it's
90                    // referenced multiple times in the spec. Only add
91                    // a new node for the schema if we haven't seen it before.
92                    let index = graph.add_node(match *entry.key() {
93                        ResolvedSpecType::Schema(&ty) => GraphType::Schema(ty.into()),
94                        ResolvedSpecType::Inline(&ty) => GraphType::Inline(ty.into()),
95                    });
96                    if let ResolvedSpecType::Schema(ty) = source {
97                        schemas.entry(ty.name()).or_insert(index);
98                    }
99                    entry.insert(index)
100                }
101            };
102
103            if let Some((parent, edge)) = parent {
104                let destination = spec.resolve(parent);
105                let &mut from = match indices.entry(destination) {
106                    Entry::Occupied(entry) => entry.into_mut(),
107                    Entry::Vacant(entry) => {
108                        let index = graph.add_node(match *entry.key() {
109                            ResolvedSpecType::Schema(&ty) => GraphType::Schema(ty.into()),
110                            ResolvedSpecType::Inline(&ty) => GraphType::Inline(ty.into()),
111                        });
112                        if let ResolvedSpecType::Schema(ty) = destination {
113                            schemas.entry(ty.name()).or_insert(index);
114                        }
115                        entry.insert(index)
116                    }
117                };
118                graph.add_edge(from, to, edge);
119            }
120        }
121
122        // Map type references in operations to graph indices.
123        let ops = arena.alloc_slice_exact(spec.operations.iter().map(|op| {
124            let params = arena.alloc_slice_exact(op.params.iter().map(|param| match param {
125                Parameter::Path(info) => Parameter::Path(ParameterInfo {
126                    name: info.name,
127                    ty: match info.ty {
128                        SpecType::Schema(s) => indices[&ResolvedSpecType::Schema(s)],
129                        SpecType::Inline(i) => indices[&ResolvedSpecType::Inline(i)],
130                        SpecType::Ref(r) => schemas[&*r.name()],
131                    },
132                    required: info.required,
133                    description: info.description,
134                    style: info.style,
135                }),
136                Parameter::Query(info) => Parameter::Query(ParameterInfo {
137                    name: info.name,
138                    ty: match info.ty {
139                        SpecType::Schema(s) => indices[&ResolvedSpecType::Schema(s)],
140                        SpecType::Inline(i) => indices[&ResolvedSpecType::Inline(i)],
141                        SpecType::Ref(r) => schemas[&*r.name()],
142                    },
143                    required: info.required,
144                    description: info.description,
145                    style: info.style,
146                }),
147            }));
148
149            let request = op.request.as_ref().map(|r| match r {
150                Request::Json(ty) => Request::Json(match ty {
151                    SpecType::Schema(s) => indices[&ResolvedSpecType::Schema(s)],
152                    SpecType::Inline(i) => indices[&ResolvedSpecType::Inline(i)],
153                    SpecType::Ref(r) => schemas[&*r.name()],
154                }),
155                Request::Multipart => Request::Multipart,
156            });
157
158            let response = op.response.as_ref().map(|r| match r {
159                Response::Json(ty) => Response::Json(match ty {
160                    SpecType::Schema(s) => indices[&ResolvedSpecType::Schema(s)],
161                    SpecType::Inline(i) => indices[&ResolvedSpecType::Inline(i)],
162                    SpecType::Ref(r) => schemas[&*r.name()],
163                }),
164            });
165
166            &*arena.alloc(Operation {
167                id: op.id,
168                method: op.method,
169                path: op.path,
170                resource: op.resource,
171                description: op.description,
172                params,
173                request,
174                response,
175            })
176        }));
177
178        Self {
179            arena,
180            spec,
181            graph,
182            schemas,
183            ops,
184        }
185    }
186
187    /// Inlines schema types used as variants of multiple tagged unions
188    /// with different tags.
189    ///
190    /// In OpenAPI's model of tagged unions, the tag always references a field
191    /// that's defined on each variant struct. This model works well for Python
192    /// and TypeScript, but not Rust; Serde doesn't allow variant structs to
193    /// declare fields with the same name as the tag. The Rust generator
194    /// excludes tag fields when generating structs, but this introduces a
195    /// new problem: a struct can't appear as a variant of multiple unions
196    /// with different tags [^1].
197    ///
198    /// This transformation finds and inlines these structs, so that
199    /// the Rust generator can safely omit their tag fields.
200    ///
201    /// [^1]: If struct A has fields `foo` and `bar`, A is a variant of
202    /// tagged unions C and D, C's tag is `foo`, and D's tag is `bar`...
203    /// only `foo` should be excluded when A is used in C, and only `bar`
204    /// should be excluded when A is used in D; but this can't be modeled
205    /// in Serde without splitting A into two distinct types.
206    pub fn inline_tagged_variants(&mut self) -> &mut Self {
207        // Collect all inlining decisions before mutating the graph,
208        // so that we can check inlinability per variant.
209        let inlinables = self.inlinable_tagged_variants().collect_vec();
210
211        let mut retargets = FxHashMap::default();
212        retargets.reserve(inlinables.len());
213
214        // Add nodes for the inlined variant structs,
215        // and their outgoing edges.
216        for InlinableVariant { tagged, variant } in inlinables {
217            // Duplicate the variant struct as an inline type,
218            // with its original metadata.
219            let index = self
220                .graph
221                .add_node(GraphType::Inline(GraphInlineType::Struct(
222                    InlineTypePath {
223                        root: InlineTypePathRoot::Type(tagged.info.name),
224                        segments: self.arena.alloc_slice_copy(&[
225                            InlineTypePathSegment::TaggedVariant(variant.info.name),
226                        ]),
227                    },
228                    variant.ty,
229                )));
230
231            // Create shadow edges to the original variant struct's fields.
232            // These serve two purposes:
233            //
234            // 1. If a field is recursive, the duplicate joins the field's SCC,
235            //    not the original's SCC, so field edges to the original type
236            //    won't be treated as cyclic.
237            // 2. Hiding the originals' inlines from the duplicate's inlines.
238            //
239            // `fields()` yields edges in reverse order of addition;
240            // we collect and reverse to add them in their original order.
241            let original_field_edges = self.fields(variant.index).collect_vec();
242            for edge in original_field_edges.into_iter().rev() {
243                self.graph.add_edge(
244                    index,
245                    edge.target,
246                    GraphEdge::Field {
247                        meta: edge.meta,
248                        shadow: true,
249                    },
250                );
251            }
252
253            // Inherit from the tagged union (to pick up its own fields)
254            // and the original variant struct (to pick up its ancestors).
255            // The union is added first so that its fields appear first _and_
256            // can be overridden by the variant's fields.
257            self.graph
258                .add_edge(index, tagged.index, GraphEdge::Inherits { shadow: true });
259            self.graph
260                .add_edge(index, variant.index, GraphEdge::Inherits { shadow: true });
261
262            retargets.insert((tagged.index, variant.index), index);
263        }
264
265        // Retarget every tagged union's variant edges to the new structs.
266        let taggeds: FixedBitSet = retargets
267            .keys()
268            .map(|&(tagged, _)| tagged.index())
269            .collect();
270        for index in taggeds.ones().map(NodeIndex::new) {
271            let old_edges = self
272                .graph
273                .edges_directed(index, Direction::Outgoing)
274                .filter(|e| matches!(e.weight(), GraphEdge::Variant(_)))
275                .map(|e| (e.id(), *e.weight(), e.target()))
276                .collect_vec();
277            for &(id, _, _) in &old_edges {
278                self.graph.remove_edge(id);
279            }
280            // Re-add edges. `edges_directed` yields edges in reverse order
281            // of addition; reversing them adds edges in their original order.
282            for (_, weight, target) in old_edges.into_iter().rev() {
283                let new_target = retargets.get(&(index, target)).copied().unwrap_or(target);
284                self.graph.add_edge(index, new_target, weight);
285            }
286        }
287
288        self
289    }
290
291    /// Builds an immutable [`CookedGraph`] from this mutable raw graph.
292    #[inline]
293    pub fn cook(&self) -> CookedGraph<'a> {
294        CookedGraph::new(self)
295    }
296
297    /// Returns an iterator over all the fields of a struct or union type,
298    /// in reverse insertion order.
299    fn fields(&self, node: NodeIndex<usize>) -> impl Iterator<Item = OutgoingEdge<FieldMeta<'a>>> {
300        self.graph
301            .edges_directed(node, Direction::Outgoing)
302            .filter_map(|e| match e.weight() {
303                &GraphEdge::Field { meta, .. } => {
304                    let target = e.target();
305                    Some(OutgoingEdge { meta, target })
306                }
307                _ => None,
308            })
309    }
310
311    /// Returns an iterator over all the tagged union variant structs
312    /// that should be inlined.
313    fn inlinable_tagged_variants(&self) -> impl Iterator<Item = InlinableVariant<'a>> {
314        // Compute the set of types used by all operations.
315        // Operations don't participate in the graph, but
316        // still need to be considered when deciding
317        // whether to inline a variant struct.
318        //
319        // Otherwise, a struct that's used by same-tag unions
320        // _and_ an operation wouldn't be inlined, incorrectly
321        // removing its tag field.
322        let used_by_ops: FixedBitSet = self
323            .ops
324            .iter()
325            .flat_map(|op| op.types())
326            .map(|index| index.index())
327            .collect();
328
329        self.graph
330            .node_indices()
331            .filter_map(|index| match self.graph[index] {
332                GraphType::Schema(GraphSchemaType::Tagged(info, ty)) => {
333                    Some(Node { index, info, ty })
334                }
335                _ => None,
336            })
337            .flat_map(move |tagged| {
338                self.graph
339                    .edges_directed(tagged.index, Direction::Outgoing)
340                    .filter(|e| matches!(e.weight(), GraphEdge::Variant(_)))
341                    .filter_map(move |e| match self.graph[e.target()] {
342                        GraphType::Schema(GraphSchemaType::Struct(info, ty)) => {
343                            let index = e.target();
344                            Some((tagged, Node { index, info, ty }))
345                        }
346                        _ => None,
347                    })
348            })
349            .filter_map(move |(tagged, variant)| {
350                // A variant struct only needs inlining if it has multiple
351                // distinct uses. Skip if (1) no operation uses the struct,
352                // _and_ (2) every incoming edge is from a tagged union with
353                // the same tag and fields. If both hold, all uses agree, so
354                // the struct can be used directly without inlining.
355                if used_by_ops[variant.index.index()] {
356                    return Some((tagged, variant));
357                }
358
359                // Check that all the variant's inbound edges are from
360                // tagged unions, and that all their tags and field
361                // edges match the first union we found.
362                let first_tagged = self
363                    .graph
364                    .neighbors_directed(variant.index, Direction::Incoming)
365                    .find_map(|index| match self.graph[index] {
366                        GraphType::Schema(GraphSchemaType::Tagged(info, ty)) => {
367                            Some(Node { index, info, ty })
368                        }
369                        _ => None,
370                    })?;
371                let all_agree = self
372                    .graph
373                    .neighbors_directed(variant.index, Direction::Incoming)
374                    .all(|index| match self.graph[index] {
375                        GraphType::Schema(GraphSchemaType::Tagged(_, ty)) => {
376                            ty.tag == first_tagged.ty.tag
377                                && self.fields(index).eq(self.fields(first_tagged.index))
378                        }
379                        _ => false,
380                    });
381                if all_agree {
382                    return None;
383                }
384                Some((tagged, variant))
385            })
386            .filter_map(|(tagged, variant)| {
387                // Skip inlining when the inline copy would be identical
388                // to the original. This happens when the variant
389                // doesn't declare the tag as a field _and_ either
390                // (a) the union has no own fields, or (b) the variant
391                // inherits from the union.
392                let ancestors = EdgeFiltered::from_fn(&self.graph, |e| {
393                    matches!(e.weight(), GraphEdge::Inherits { .. })
394                });
395                let mut dfs = DfsPostOrder::new(&ancestors, variant.index);
396                let has_tag_field = std::iter::from_fn(|| dfs.next(&ancestors))
397                    .filter(|&n| {
398                        matches!(
399                            self.graph[n],
400                            GraphType::Schema(GraphSchemaType::Struct(..))
401                                | GraphType::Inline(GraphInlineType::Struct(..))
402                        )
403                    })
404                    .any(|n| {
405                        self.fields(n).any(|f| {
406                            matches!(f.meta.name, StructFieldName::Name(name)
407                                if name == tagged.ty.tag)
408                        })
409                    });
410
411                // If the variant declares or inherits the tag field,
412                // we must inline, so that the inline copy can safely
413                // omit the tag.
414                if has_tag_field {
415                    return Some(InlinableVariant { tagged, variant });
416                }
417
418                // If the DFS visited the union, the variant already inherits
419                // its fields; the inline copy would be identical.
420                if dfs.discovered[tagged.index.index()] {
421                    return None;
422                }
423
424                // If the variant doesn't inherit from the union, but the union
425                // has no fields of its own, the inline copy would be identical.
426                self.fields(tagged.index).next()?;
427
428                Some(InlinableVariant { tagged, variant })
429            })
430    }
431}
432
433/// The final dependency graph of all the types in a [`Spec`],
434/// backed by a dense [`DiGraph`].
435///
436/// This graph has all transformations applied, and is ready for
437/// code generation.
438#[derive(Debug)]
439pub struct CookedGraph<'a> {
440    pub(super) graph: CookedDiGraph<'a>,
441    info: &'a Info,
442    schemas: FxHashMap<&'a str, NodeIndex<usize>>,
443    ops: &'a [&'a GraphOperation<'a>],
444    /// Additional metadata for each node.
445    pub(super) metadata: CookedGraphMetadata<'a>,
446}
447
448impl<'a> CookedGraph<'a> {
449    fn new(raw: &RawGraph<'a>) -> Self {
450        // Build a dense graph, mapping sparse raw node indices to
451        // dense cooked node indices.
452        let mut graph =
453            CookedDiGraph::with_capacity(raw.graph.node_count(), raw.graph.edge_count());
454        let mut indices =
455            FxHashMap::with_capacity_and_hasher(raw.graph.node_count(), FxBuildHasher);
456        for raw_index in raw.graph.node_indices() {
457            let cooked_index = graph.add_node(raw.graph[raw_index]);
458            indices.insert(raw_index, cooked_index);
459        }
460
461        // Copy edges.
462        //
463        // `raw.graph.edges()` yields edges in reverse order of addition.
464        // The raw graph adds edges in declaration order, so `edges()`
465        // yields them reversed. Re-adding them to the cooked graph in that
466        // reversed order means they're now stored in reverse-declaration order,
467        // letting the cooked graph's accessors yield edges in declaration order
468        // without any extra work.
469        for index in raw.graph.node_indices() {
470            let from = indices[&index];
471            let edges = raw
472                .graph
473                .edges(index)
474                .map(|e| (indices[&e.target()], *e.weight()));
475            for (to, kind) in edges {
476                graph.add_edge(from, to, kind);
477            }
478        }
479
480        // Remap schema type references in operations.
481        let ops: &_ = raw.arena.alloc_slice_exact(raw.ops.iter().map(|&op| {
482            &*raw.arena.alloc(Operation {
483                id: op.id,
484                method: op.method,
485                path: op.path,
486                resource: op.resource,
487                description: op.description,
488                params: raw
489                    .arena
490                    .alloc_slice_exact(op.params.iter().map(|p| match p {
491                        Parameter::Path(info) => Parameter::Path(ParameterInfo {
492                            name: info.name,
493                            ty: indices[&info.ty],
494                            required: info.required,
495                            description: info.description,
496                            style: info.style,
497                        }),
498                        Parameter::Query(info) => Parameter::Query(ParameterInfo {
499                            name: info.name,
500                            ty: indices[&info.ty],
501                            required: info.required,
502                            description: info.description,
503                            style: info.style,
504                        }),
505                    })),
506                request: op.request.as_ref().map(|r| match r {
507                    Request::Json(ty) => Request::Json(indices[ty]),
508                    Request::Multipart => Request::Multipart,
509                }),
510                response: op.response.as_ref().map(|r| match r {
511                    Response::Json(ty) => Response::Json(indices[ty]),
512                }),
513            })
514        }));
515
516        let metadata = MetadataBuilder::new(&graph, ops).build();
517
518        Self {
519            graph,
520            info: raw.spec.info,
521            schemas: raw
522                .schemas
523                .iter()
524                .map(|(&name, index)| (name, indices[index]))
525                .collect(),
526            ops,
527            metadata,
528        }
529    }
530
531    /// Returns [`Info`] from the [`Document`][crate::parse::Document]
532    /// used to build this graph.
533    #[inline]
534    pub fn info(&self) -> &'a Info {
535        self.info
536    }
537
538    /// Returns an iterator over all the named schemas in this graph.
539    #[inline]
540    pub fn schemas(&self) -> impl Iterator<Item = SchemaTypeView<'_, 'a>> + use<'_, 'a> {
541        self.graph
542            .node_indices()
543            .filter_map(|index| match self.graph[index] {
544                GraphType::Schema(ty) => Some(SchemaTypeView::new(self, index, ty)),
545                _ => None,
546            })
547    }
548
549    /// Looks up and returns a schema by name.
550    #[inline]
551    pub fn schema(&self, name: &str) -> Option<SchemaTypeView<'_, 'a>> {
552        self.schemas
553            .get(name)
554            .and_then(|&index| match self.graph[index] {
555                GraphType::Schema(ty) => Some(SchemaTypeView::new(self, index, ty)),
556                _ => None,
557            })
558    }
559
560    /// Returns an iterator over all primitive type nodes in this graph.
561    #[inline]
562    pub fn primitives(&self) -> impl Iterator<Item = PrimitiveView<'_, 'a>> + use<'_, 'a> {
563        self.graph
564            .node_indices()
565            .filter_map(|index| match self.graph[index] {
566                GraphType::Schema(GraphSchemaType::Primitive(_, p))
567                | GraphType::Inline(GraphInlineType::Primitive(_, p)) => {
568                    Some(PrimitiveView::new(self, index, p))
569                }
570                _ => None,
571            })
572    }
573
574    /// Returns an iterator over all the operations in this graph.
575    #[inline]
576    pub fn operations(&self) -> impl Iterator<Item = OperationView<'_, 'a>> + use<'_, 'a> {
577        self.ops.iter().map(|&op| OperationView::new(self, op))
578    }
579
580    #[inline]
581    pub(super) fn inherits(
582        &self,
583        node: NodeIndex<usize>,
584    ) -> impl Iterator<Item = OutgoingEdge<()>> {
585        self.graph
586            .edges_directed(node, Direction::Outgoing)
587            .filter(|e| matches!(e.weight(), GraphEdge::Inherits { .. }))
588            .map(|e| OutgoingEdge {
589                meta: (),
590                target: e.target(),
591            })
592    }
593
594    #[inline]
595    pub(super) fn fields(
596        &self,
597        node: NodeIndex<usize>,
598    ) -> impl Iterator<Item = OutgoingEdge<FieldMeta<'a>>> {
599        self.graph
600            .edges_directed(node, Direction::Outgoing)
601            .filter_map(|e| match e.weight() {
602                &GraphEdge::Field { meta, .. } => {
603                    let target = e.target();
604                    Some(OutgoingEdge { meta, target })
605                }
606                _ => None,
607            })
608    }
609
610    #[inline]
611    pub(super) fn variants(
612        &self,
613        node: NodeIndex<usize>,
614    ) -> impl Iterator<Item = OutgoingEdge<VariantMeta<'a>>> {
615        self.graph
616            .edges_directed(node, Direction::Outgoing)
617            .filter_map(|e| match e.weight() {
618                &GraphEdge::Variant(meta) => {
619                    let target = e.target();
620                    Some(OutgoingEdge { meta, target })
621                }
622                _ => None,
623            })
624    }
625}
626
627/// A variant that should be inlined into its tagged union.
628struct InlinableVariant<'a> {
629    /// The tagged union that owns this variant.
630    tagged: Node<'a, GraphTagged<'a>>,
631    /// The original variant struct node.
632    variant: Node<'a, GraphStruct<'a>>,
633}
634
635/// An edge between two types in the type graph.
636///
637/// Edges describe the relationship between their source and target types.
638#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
639pub enum GraphEdge<'a> {
640    /// The source type inherits from the target type.
641    Inherits { shadow: bool },
642    /// The source struct, tagged union, or untagged union
643    /// has the target type as a field.
644    Field { shadow: bool, meta: FieldMeta<'a> },
645    /// The source union has the target type as a variant.
646    Variant(VariantMeta<'a>),
647    /// The source type is an array, map, or optional that contains
648    /// the target type.
649    Contains,
650}
651
652impl GraphEdge<'_> {
653    /// Returns `true` if the target type should be excluded from
654    /// the source type's [inlines], but still considered a dependency.
655    ///
656    /// Shadow edges prevent inlined variant structs from claiming
657    /// their originals' inlines.
658    ///
659    /// [inlines]: crate::ir::views::View::inlines
660    #[inline]
661    pub fn shadow(&self) -> bool {
662        matches!(
663            self,
664            GraphEdge::Field { shadow: true, .. } | GraphEdge::Inherits { shadow: true }
665        )
666    }
667}
668
669/// Metadata describing an edge from a source to a target type.
670#[derive(Clone, Copy, Debug, Eq, PartialEq)]
671pub struct OutgoingEdge<T> {
672    pub meta: T,
673    pub target: NodeIndex<usize>,
674}
675
676#[derive(Clone, Copy)]
677struct Node<'a, Ty> {
678    index: NodeIndex<usize>,
679    info: SchemaTypeInfo<'a>,
680    ty: Ty,
681}
682
683/// Precomputed metadata for schema types and operations in the graph.
684pub(super) struct CookedGraphMetadata<'a> {
685    /// Transitive closure over the type graph.
686    pub closure: Closure,
687    /// Maps each type to its SCC equivalence class for boxing decisions.
688    /// Two types in the same class form a cycle that requires `Box<T>`.
689    pub box_sccs: Vec<usize>,
690    /// Whether each type can implement `Eq` and `Hash`.
691    pub hashable: FixedBitSet,
692    /// Whether each type can implement `Default`.
693    pub defaultable: FixedBitSet,
694    /// Maps each type to the operations that use it.
695    pub used_by: Vec<Vec<GraphOperation<'a>>>,
696    /// Maps each operation to the types that it uses.
697    pub uses: FxHashMap<GraphOperation<'a>, FixedBitSet>,
698    /// Opaque extended data for each type.
699    pub extensions: Vec<AtomicRefCell<ExtensionMap>>,
700}
701
702impl Debug for CookedGraphMetadata<'_> {
703    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
704        f.debug_struct("CookedGraphMetadata")
705            .field("closure", &self.closure)
706            .field("box_sccs", &self.box_sccs)
707            .field("hashable", &self.hashable)
708            .field("defaultable", &self.defaultable)
709            .field("used_by", &self.used_by)
710            .field("uses", &self.uses)
711            .finish_non_exhaustive()
712    }
713}
714
715/// Precomputed bitsets indicating which types can derive
716/// `Eq` / `Hash` and `Default`.
717struct HashDefault {
718    hashable: FixedBitSet,
719    defaultable: FixedBitSet,
720}
721
722/// Precomputed metadata for an operation that references
723/// types in the graph.
724struct Operations<'a> {
725    /// All the types that each operation depends on, directly and transitively.
726    pub uses: FxHashMap<GraphOperation<'a>, FixedBitSet>,
727    /// All the operations that use each type, directly and transitively.
728    pub used_by: Vec<Vec<GraphOperation<'a>>>,
729}
730
731struct MetadataBuilder<'graph, 'a> {
732    graph: &'graph CookedDiGraph<'a>,
733    ops: &'graph [&'graph GraphOperation<'a>],
734    /// The full transitive closure of each type's dependencies.
735    closure: Closure,
736}
737
738impl<'graph, 'a> MetadataBuilder<'graph, 'a> {
739    fn new(graph: &'graph CookedDiGraph<'a>, ops: &'graph [&'graph GraphOperation<'a>]) -> Self {
740        Self {
741            graph,
742            ops,
743            closure: Closure::new(graph),
744        }
745    }
746
747    fn build(self) -> CookedGraphMetadata<'a> {
748        let operations = self.operations();
749        let HashDefault {
750            hashable,
751            defaultable,
752        } = self.hash_default();
753        let box_sccs = self.box_sccs();
754        CookedGraphMetadata {
755            closure: self.closure,
756            box_sccs,
757            hashable,
758            defaultable,
759            used_by: operations.used_by,
760            uses: operations.uses,
761            // `AtomicRefCell` doesn't implement `Clone`,
762            // so we use this idiom instead of `vec!`.
763            extensions: std::iter::repeat_with(AtomicRefCell::default)
764                .take(self.graph.node_count())
765                .collect(),
766        }
767    }
768
769    fn operations(&self) -> Operations<'a> {
770        let mut operations = Operations {
771            uses: FxHashMap::default(),
772            used_by: vec![vec![]; self.graph.node_count()],
773        };
774
775        for &&op in self.ops {
776            // Forward propagation: start from the direct types, then
777            // expand to the full transitive dependency set.
778            let mut dependencies = FixedBitSet::with_capacity(self.graph.node_count());
779            for &node in op.types() {
780                dependencies.extend(self.closure.dependencies_of(node).map(|n| n.index()));
781            }
782            operations.uses.entry(op).insert_entry(dependencies);
783        }
784
785        // Backward propagation: mark types as used by their operations.
786        for (op, deps) in &operations.uses {
787            for node in deps.ones() {
788                operations.used_by[node].push(*op);
789            }
790        }
791
792        operations
793    }
794
795    fn box_sccs(&self) -> Vec<usize> {
796        let box_edges = EdgeFiltered::from_fn(self.graph, |e| match e.weight() {
797            // Inheritance edges don't contribute to cycles;
798            // a type can't inherit from itself.
799            GraphEdge::Inherits { .. } => false,
800            GraphEdge::Contains => match self.graph[e.source()] {
801                GraphType::Schema(GraphSchemaType::Container(_, c))
802                | GraphType::Inline(GraphInlineType::Container(_, c)) => {
803                    // Array and map containers are heap-allocated,
804                    // cycles through these edges don't need `Box`.
805                    !matches!(c, GraphContainer::Array { .. } | GraphContainer::Map { .. })
806                }
807                _ => true,
808            },
809            _ => true,
810        });
811        let mut scc = TarjanScc::new();
812        scc.run(&box_edges, |_| ());
813        self.graph
814            .node_indices()
815            .map(|node| scc.node_component_index(&box_edges, node))
816            .collect()
817    }
818
819    fn hash_default(&self) -> HashDefault {
820        // Mark all leaf types that can't derive `Eq` / `Hash` or `Default`.
821        let n = self.graph.node_count();
822        let mut unhashable = FixedBitSet::with_capacity(n);
823        let mut undefaultable = FixedBitSet::with_capacity(n);
824        for node in self.graph.node_indices() {
825            use {GraphType::*, PrimitiveType::*};
826            match &self.graph[node] {
827                Schema(GraphSchemaType::Primitive(_, F32 | F64))
828                | Inline(GraphInlineType::Primitive(_, F32 | F64)) => {
829                    unhashable.insert(node.index());
830                }
831                Schema(
832                    GraphSchemaType::Primitive(_, Url)
833                    | GraphSchemaType::Tagged(_, _)
834                    | GraphSchemaType::Untagged(_, _),
835                )
836                | Inline(
837                    GraphInlineType::Primitive(_, Url)
838                    | GraphInlineType::Tagged(_, _)
839                    | GraphInlineType::Untagged(_, _),
840                ) => {
841                    undefaultable.insert(node.index());
842                }
843                _ => (),
844            }
845        }
846
847        // Compute the transitive closure over the inheritance subgraph.
848        let inherits = Closure::new(&EdgeFiltered::from_fn(self.graph, |e| {
849            matches!(e.weight(), GraphEdge::Inherits { .. })
850        }));
851
852        // Propagate unhashability backward, from leaves to roots.
853        //
854        // This is conservative: if a descendant overrides an inherited
855        // unhashable or undefaultable field with a different hashable or
856        // defaultable type, that descendant is still marked.
857        let mut queue: VecDeque<_> = unhashable.ones().map(NodeIndex::new).collect();
858        while let Some(node) = queue.pop_front() {
859            for edge in self.graph.edges_directed(node, Direction::Incoming) {
860                let source = edge.source();
861                match edge.weight() {
862                    GraphEdge::Contains | GraphEdge::Variant(_) => {
863                        if !unhashable.put(source.index()) {
864                            queue.push_back(source);
865                        }
866                    }
867                    GraphEdge::Field { .. } => {
868                        if !unhashable.put(source.index()) {
869                            queue.push_back(source);
870                        }
871                        // Every type that inherits from `source` also
872                        // inherits this unhashable field, so mark all
873                        // descendants of `source` as unhashable.
874                        for desc in inherits.dependents_of(source).filter(|&d| d != source) {
875                            if !unhashable.put(desc.index()) {
876                                queue.push_back(desc);
877                            }
878                        }
879                    }
880                    // Don't follow inheritance edges: a parent's intrinsic
881                    // unhashability (e.g., being a tagged union) doesn't
882                    // make its children unhashable, because children only
883                    // inherit the parent's fields, not its shape.
884                    GraphEdge::Inherits { .. } => {}
885                }
886            }
887        }
888
889        // Propagate undefaultability backward.
890        let mut queue: VecDeque<_> = undefaultable.ones().map(NodeIndex::new).collect();
891        while let Some(node) = queue.pop_front() {
892            for edge in self.graph.edges_directed(node, Direction::Incoming) {
893                if !matches!(
894                    edge.weight(),
895                    GraphEdge::Field { meta, .. } if meta.required
896                ) {
897                    // Optional fields become `AbsentOr<T>`,
898                    // which is always `Default`.
899                    continue;
900                }
901                let source = edge.source();
902                if !undefaultable.put(source.index()) {
903                    queue.push_back(source);
904                }
905                // Every type that inherits from `source` also
906                // inherits this undefaultable field, so mark all
907                // descendants of `source` as undefaultable.
908                for desc in inherits.dependents_of(source).filter(|&d| d != source) {
909                    if !undefaultable.put(desc.index()) {
910                        queue.push_back(desc);
911                    }
912                }
913            }
914        }
915
916        HashDefault {
917            hashable: invert(unhashable),
918            defaultable: invert(undefaultable),
919        }
920    }
921}
922
923/// Inverts every bit in the bitset.
924fn invert(mut bits: FixedBitSet) -> FixedBitSet {
925    bits.toggle_range(..);
926    bits
927}
928
/// Visits all the types and references contained within a [`SpecType`].
///
/// The traversal is a depth-first, pre-order walk driven by an explicit
/// stack.
#[derive(Debug)]
struct SpecTypeVisitor<'a> {
    /// Types still to visit. Each type is paired with its containing type
    /// and the edge from that container, or `None` for a root.
    stack: Vec<(Option<(&'a SpecType<'a>, GraphEdge<'a>)>, &'a SpecType<'a>)>,
}
934
935impl<'a> SpecTypeVisitor<'a> {
936    /// Creates a visitor with `roots` on the stack of types to visit.
937    #[inline]
938    fn new(roots: impl Iterator<Item = &'a SpecType<'a>>) -> Self {
939        let mut stack = roots.map(|root| (None, root)).collect_vec();
940        stack.reverse();
941        Self { stack }
942    }
943}
944
impl<'a> Iterator for SpecTypeVisitor<'a> {
    /// A visited type, paired with the type that contains it and the edge
    /// from that container, or `None` for a root.
    type Item = (Option<(&'a SpecType<'a>, GraphEdge<'a>)>, &'a SpecType<'a>);

    /// Pops the next type off the stack, queues its children, and returns it.
    ///
    /// Children are pushed in reverse so that they're popped in declaration
    /// order. Fields come first, then parents (for structs) or variants (for
    /// tagged and untagged unions). `SpecType::Ref` targets aren't followed.
    fn next(&mut self) -> Option<Self::Item> {
        let (parent, top) = self.stack.pop()?;
        if matches!(
            parent,
            Some((
                _,
                GraphEdge::Variant(VariantMeta::Untagged(UntaggedVariantMeta::Null))
            ))
        ) {
            // Unit (`null`) variants are self-edges back to the untagged
            // type. Yield the type without expanding it again, which would
            // otherwise loop forever.
            return Some((parent, top));
        }
        match top {
            SpecType::Schema(SpecSchemaType::Struct(_, ty))
            | SpecType::Inline(SpecInlineType::Struct(_, ty)) => {
                // Struct: one `Field` edge per field, then one `Inherits`
                // edge per parent.
                self.stack.extend(
                    itertools::chain!(
                        ty.fields.iter().map(|field| (
                            GraphEdge::Field {
                                shadow: false,
                                meta: FieldMeta {
                                    name: field.name,
                                    required: field.required,
                                    description: field.description,
                                    flattened: field.flattened,
                                },
                            },
                            field.ty
                        )),
                        ty.parents
                            .iter()
                            .map(|parent| (GraphEdge::Inherits { shadow: false }, *parent)),
                    )
                    .map(|(edge, ty)| (Some((top, edge)), ty))
                    // Reversed so the first child ends up on top of the stack.
                    .rev(),
                );
            }
            SpecType::Schema(SpecSchemaType::Untagged(_, ty))
            | SpecType::Inline(SpecInlineType::Untagged(_, ty)) => {
                // Untagged union: `Field` edges, then one `Variant` edge
                // per variant.
                self.stack.extend(
                    itertools::chain!(
                        ty.fields.iter().map(|field| (
                            GraphEdge::Field {
                                shadow: false,
                                meta: FieldMeta {
                                    name: field.name,
                                    required: field.required,
                                    description: field.description,
                                    flattened: field.flattened,
                                },
                            },
                            field.ty
                        )),
                        ty.variants.iter().map(|variant| match variant {
                            &SpecUntaggedVariant::Some(hint, ty) => {
                                let meta = UntaggedVariantMeta::Type { hint };
                                (GraphEdge::Variant(meta.into()), ty)
                            }
                            // `null` variants have no target type;
                            // we represent these variants as self-edges.
                            SpecUntaggedVariant::Null => {
                                (GraphEdge::Variant(UntaggedVariantMeta::Null.into()), top)
                            }
                        }),
                    )
                    .map(|(edge, ty)| (Some((top, edge)), ty))
                    .rev(),
                );
            }
            SpecType::Schema(SpecSchemaType::Tagged(_, ty))
            | SpecType::Inline(SpecInlineType::Tagged(_, ty)) => {
                // Tagged union: `Field` edges, then one named `Variant`
                // edge per variant.
                self.stack.extend(
                    itertools::chain!(
                        ty.fields.iter().map(|field| (
                            GraphEdge::Field {
                                shadow: false,
                                meta: FieldMeta {
                                    name: field.name,
                                    required: field.required,
                                    description: field.description,
                                    flattened: field.flattened,
                                },
                            },
                            field.ty
                        )),
                        ty.variants.iter().map(|variant| (
                            GraphEdge::Variant(
                                TaggedVariantMeta {
                                    name: variant.name,
                                    aliases: variant.aliases,
                                }
                                .into()
                            ),
                            variant.ty
                        )),
                    )
                    .map(|(edge, ty)| (Some((top, edge)), ty))
                    .rev(),
                );
            }
            SpecType::Schema(SpecSchemaType::Container(_, container))
            | SpecType::Inline(SpecInlineType::Container(_, container)) => {
                // Containers have a single child: their inner type.
                self.stack
                    .push((Some((top, GraphEdge::Contains)), container.inner().ty));
            }
            // Leaf types have no children to visit.
            SpecType::Schema(
                SpecSchemaType::Enum(..) | SpecSchemaType::Primitive(..) | SpecSchemaType::Any(_),
            )
            | SpecType::Inline(
                SpecInlineType::Enum(..) | SpecInlineType::Primitive(..) | SpecInlineType::Any(_),
            ) => (),
            // References are yielded but not followed.
            SpecType::Ref(_) => (),
        }
        Some((parent, top))
    }
}
1065
/// A map that can store one value for each type.
///
/// Values are boxed [`Extension`]s keyed by [`TypeId`]. Presumably each key
/// is the `TypeId` of the stored value's type; confirm at the insertion
/// sites.
pub(super) type ExtensionMap = FxHashMap<TypeId, Box<dyn Extension>>;
1068
/// A type-erased value that can be stored in an [`ExtensionMap`].
///
/// Every `Send + Sync + 'static` type is an [`Extension`], through the
/// blanket implementation in this module.
pub trait Extension: Any + Send + Sync {
    /// Converts this boxed extension into a `Box<dyn Any>`, which can then
    /// be downcast to its concrete type.
    fn into_inner(self: Box<Self>) -> Box<dyn Any>;
}
1072
1073impl dyn Extension {
1074    #[inline]
1075    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
1076        (self as &dyn Any).downcast_ref::<T>()
1077    }
1078}
1079
1080impl<T: Send + Sync + 'static> Extension for T {
1081    #[inline]
1082    fn into_inner(self: Box<Self>) -> Box<dyn Any> {
1083        self
1084    }
1085}
1086
/// Strongly connected components (SCCs) in topological order.
///
/// [`TopoSccs`] uses Tarjan's single-pass algorithm to find all SCCs,
/// and provides topological ordering, efficient membership testing, and
/// condensation for computing the transitive closure. These are
/// building blocks for cycle detection and dependency propagation.
struct TopoSccs<G> {
    /// The graph that the SCCs were computed from.
    graph: G,
    /// Tarjan's state after the run. It is kept to look up each node's
    /// component index, which is in reverse topological order.
    tarjan: TarjanScc<NodeIndex<usize>>,
    /// The node indices of each SCC's members, with the SCCs themselves
    /// in topological order.
    sccs: Vec<Vec<usize>>,
}
1098
1099impl<G> TopoSccs<G>
1100where
1101    G: Closable<NodeIndex<usize>> + Copy,
1102{
1103    fn new(graph: G) -> Self {
1104        let mut sccs = Vec::new();
1105        let mut tarjan = TarjanScc::new();
1106        tarjan.run(graph, |scc_nodes| {
1107            sccs.push(scc_nodes.iter().map(|node| node.index()).collect());
1108        });
1109        // Tarjan's algorithm returns SCCs in reverse topological order;
1110        // reverse them to get the topological order.
1111        sccs.reverse();
1112        Self {
1113            graph,
1114            tarjan,
1115            sccs,
1116        }
1117    }
1118
1119    #[inline]
1120    fn scc_count(&self) -> usize {
1121        self.sccs.len()
1122    }
1123
1124    /// Returns the topological index of the SCC that contains the given node.
1125    #[inline]
1126    fn topo_index(&self, node: NodeIndex<usize>) -> usize {
1127        // Tarjan's algorithm returns indices in reverse topological order;
1128        // inverting the component index gets us the topological index.
1129        self.sccs.len() - 1 - self.tarjan.node_component_index(self.graph, node)
1130    }
1131
1132    /// Builds a condensed DAG of SCCs.
1133    ///
1134    /// The condensed graph is represented as an adjacency list, where both
1135    /// the node indices and the neighbors of each node are stored in
1136    /// topological order. This specific ordering is required by
1137    /// [`tred::dag_transitive_reduction_closure`].
1138    fn condensation(&self) -> UnweightedList<usize> {
1139        let mut dag = UnweightedList::with_capacity(self.scc_count());
1140        for to in 0..self.scc_count() {
1141            dag.add_node();
1142            for neighbor in self.sccs[to].iter().flat_map(|&index| {
1143                self.graph
1144                    .neighbors_directed(NodeIndex::new(index), Direction::Incoming)
1145            }) {
1146                let from = self.topo_index(neighbor);
1147                if from != to {
1148                    dag.update_edge(from, to, ());
1149                }
1150            }
1151        }
1152        dag
1153    }
1154}
1155
/// The transitive closure of a graph.
///
/// Nodes are grouped into strongly connected components (SCCs), and
/// reachability is stored per SCC. Node-level queries expand SCCs back
/// into their member nodes.
#[derive(Debug)]
pub(super) struct Closure {
    /// Maps each node index to its SCC's topological index.
    scc_indices: Vec<usize>,
    /// Members of each SCC, indexed by topological SCC index.
    scc_members: Vec<Vec<usize>>,
    /// Maps each SCC to a list of all the SCCs that it transitively depends on,
    /// excluding itself.
    scc_deps: Vec<Vec<usize>>,
    /// Maps each SCC to a list of all the SCCs that transitively depend on it,
    /// excluding itself.
    scc_rdeps: Vec<Vec<usize>>,
}
1170
1171impl Closure {
1172    /// Computes the transitive closure of a graph.
1173    fn new<G>(graph: G) -> Self
1174    where
1175        G: Closable<NodeIndex<usize>> + Copy,
1176    {
1177        let sccs = TopoSccs::new(graph);
1178        let condensation = sccs.condensation();
1179        let (_, closure) = tred::dag_transitive_reduction_closure(&condensation);
1180
1181        // Build the forward and reverse adjacency lists
1182        // from the transitive closure graph.
1183        let scc_deps = (0..sccs.scc_count())
1184            .map(|scc| closure.neighbors(scc).collect_vec())
1185            .collect_vec();
1186        let mut scc_rdeps = vec![vec![]; sccs.scc_count()];
1187        for (scc, deps) in scc_deps.iter().enumerate() {
1188            for &dep in deps {
1189                scc_rdeps[dep].push(scc);
1190            }
1191        }
1192
1193        let mut scc_indices = vec![0; graph.node_count()];
1194        for node in graph.node_identifiers() {
1195            scc_indices[node.index()] = sccs.topo_index(node);
1196        }
1197
1198        Closure {
1199            scc_indices,
1200            scc_members: sccs.sccs.iter().cloned().collect_vec(),
1201            scc_deps,
1202            scc_rdeps,
1203        }
1204    }
1205
1206    /// Returns the topological SCC index for the given node.
1207    #[inline]
1208    pub fn scc_index_of(&self, node: NodeIndex<usize>) -> usize {
1209        self.scc_indices[node.index()]
1210    }
1211
1212    /// Iterates over all nodes that `node` transitively depends on,
1213    /// including `node` and all members of its SCC.
1214    pub fn dependencies_of(
1215        &self,
1216        node: NodeIndex<usize>,
1217    ) -> impl Iterator<Item = NodeIndex<usize>> {
1218        let scc = self.scc_index_of(node);
1219        std::iter::once(scc)
1220            .chain(self.scc_deps[scc].iter().copied())
1221            .flat_map(|s| self.scc_members[s].iter().copied()) // Expand SCCs to nodes.
1222            .map(NodeIndex::new)
1223    }
1224
1225    /// Iterates over all nodes that transitively depend on `node`,
1226    /// including `node` and all members of its SCC.
1227    pub fn dependents_of(&self, node: NodeIndex<usize>) -> impl Iterator<Item = NodeIndex<usize>> {
1228        let scc = self.scc_index_of(node);
1229        std::iter::once(scc)
1230            .chain(self.scc_rdeps[scc].iter().copied())
1231            .flat_map(|s| self.scc_members[s].iter().copied())
1232            .map(NodeIndex::new)
1233    }
1234
1235    /// Returns whether `node` transitively depends on `other`,
1236    /// or `false` when `node == other`.
1237    #[inline]
1238    pub fn depends_on(&self, node: NodeIndex<usize>, other: NodeIndex<usize>) -> bool {
1239        if node == other {
1240            return false;
1241        }
1242        let scc = self.scc_index_of(node);
1243        let other_scc = self.scc_index_of(other);
1244        scc == other_scc || self.scc_deps[scc].contains(&other_scc)
1245    }
1246}
1247
/// Trait requirements for computing a transitive closure.
///
/// This bundles the `petgraph` visitor traits that [`TopoSccs`] and
/// [`Closure`] rely on:
/// - counting and enumerating nodes;
/// - iterating neighbors, optionally by direction;
/// - converting between node identifiers and indices.
trait Closable<N>:
    NodeCount
    + IntoNodeIdentifiers<NodeId = N>
    + IntoNeighbors<NodeId = N>
    + IntoNeighborsDirected<NodeId = N>
    + NodeIndexable<NodeId = N>
{
}
1257
// Blanket implementation: every graph type that satisfies the bounds is
// `Closable`, so no type needs to implement it by hand.
impl<N, G> Closable<N> for G where
    G: NodeCount
        + IntoNodeIdentifiers<NodeId = N>
        + IntoNeighbors<NodeId = N>
        + IntoNeighborsDirected<NodeId = N>
        + NodeIndexable<NodeId = N>
{
}
1266
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::assert_matches;

    /// Creates a simple graph: `A -> B -> C`.
    fn linear_graph() -> DiGraph<(), (), usize> {
        let mut g = DiGraph::default();
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.extend_with_edges([(a, b), (b, c)]);
        g
    }

    /// Creates a cyclic graph: `A -> B -> C -> A`, with `D -> A`.
    fn cyclic_graph() -> DiGraph<(), (), usize> {
        let mut g = DiGraph::default();
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        let d = g.add_node(());
        g.extend_with_edges([(a, b), (b, c), (c, a), (d, a)]);
        g
    }

    /// Collects node indices into a sorted `Vec`, for assertions
    /// that shouldn't depend on iteration order.
    fn sorted_indices(nodes: impl Iterator<Item = NodeIndex<usize>>) -> Vec<usize> {
        nodes.map(|node| node.index()).sorted().collect_vec()
    }

    // MARK: SCC detection

    #[test]
    fn test_linear_graph_has_singleton_sccs() {
        let g = linear_graph();
        let sccs = TopoSccs::new(&g);
        let sizes = sccs.sccs.iter().map(|scc| scc.len()).collect_vec();
        assert_matches!(&*sizes, [1, 1, 1]);
    }

    #[test]
    fn test_cyclic_graph_has_one_multi_node_scc() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);

        // A-B-C form one SCC; D is its own SCC. Since D has an edge to
        // the cycle, D must precede the cycle in topological order.
        let sizes = sccs.sccs.iter().map(|scc| scc.len()).collect_vec();
        assert_matches!(&*sizes, [1, 3]);
    }

    // MARK: Topological ordering

    #[test]
    fn test_sccs_are_in_topological_order() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);

        let d_topo = sccs.topo_index(3.into());
        let a_topo = sccs.topo_index(0.into());
        assert!(
            d_topo < a_topo,
            "D should precede A-B-C in topological order"
        );
    }

    #[test]
    fn test_topo_index_consistent_within_scc() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);

        // A, B, C are in the same SCC, so they should have
        // the same topological index.
        let a_topo = sccs.topo_index(0.into());
        let b_topo = sccs.topo_index(1.into());
        let c_topo = sccs.topo_index(2.into());

        assert_eq!(a_topo, b_topo);
        assert_eq!(b_topo, c_topo);
    }

    // MARK: Condensation

    #[test]
    fn test_condensation_has_correct_node_count() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        assert_eq!(dag.node_count(), 2);
    }

    #[test]
    fn test_condensation_has_correct_edges() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        // D should have an edge to the A-B-C SCC, and
        // A-B-C shouldn't create a self-loop.
        let d_topo = sccs.topo_index(3.into());
        let abc_topo = sccs.topo_index(0.into());

        let d_neighbors = dag.neighbors(d_topo).collect_vec();
        assert_eq!(&*d_neighbors, [abc_topo]);

        let abc_neighbors = dag.neighbors(abc_topo).collect_vec();
        assert!(abc_neighbors.is_empty());
    }

    #[test]
    fn test_condensation_neighbors_in_topological_order() {
        // Matches Petgraph's `dag_to_toposorted_adjacency_list` example:
        // edges added as `(top, second), (top, first)`, but neighbors should be
        // `[first, second]` in topological order.
        let mut g = DiGraph::<(), (), usize>::default();
        let second = g.add_node(());
        let top = g.add_node(());
        let first = g.add_node(());
        g.extend_with_edges([(top, second), (top, first), (first, second)]);

        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        let top_topo = sccs.topo_index(top);
        assert_eq!(top_topo, 0);

        let first_topo = sccs.topo_index(first);
        assert_eq!(first_topo, 1);

        let second_topo = sccs.topo_index(second);
        assert_eq!(second_topo, 2);

        let neighbors = dag.neighbors(top_topo).collect_vec();
        assert_eq!(&*neighbors, [first_topo, second_topo]);
    }

    // MARK: Transitive closure

    #[test]
    fn test_closure_dependencies_of_linear_graph() {
        let g = linear_graph();
        let closure = Closure::new(&g);

        // Each node depends on itself and everything downstream.
        assert_eq!(sorted_indices(closure.dependencies_of(0.into())), [0, 1, 2]);
        assert_eq!(sorted_indices(closure.dependencies_of(1.into())), [1, 2]);
        assert_eq!(sorted_indices(closure.dependencies_of(2.into())), [2]);
    }

    #[test]
    fn test_closure_dependents_of_linear_graph() {
        let g = linear_graph();
        let closure = Closure::new(&g);

        // Each node is depended on by itself and everything upstream.
        assert_eq!(sorted_indices(closure.dependents_of(0.into())), [0]);
        assert_eq!(sorted_indices(closure.dependents_of(2.into())), [0, 1, 2]);
    }

    #[test]
    fn test_closure_expands_cycles() {
        let g = cyclic_graph();
        let closure = Closure::new(&g);

        // A, B, and C form one SCC, so each depends on all three.
        // D reaches the cycle, but nothing in the cycle reaches D.
        assert_eq!(sorted_indices(closure.dependencies_of(1.into())), [0, 1, 2]);
        assert_eq!(sorted_indices(closure.dependencies_of(3.into())), [0, 1, 2, 3]);
        assert_eq!(sorted_indices(closure.dependents_of(0.into())), [0, 1, 2, 3]);
    }

    #[test]
    fn test_closure_depends_on() {
        let g = cyclic_graph();
        let closure = Closure::new(&g);

        // Distinct nodes in the same SCC depend on each other.
        assert!(closure.depends_on(0.into(), 2.into()));
        assert!(closure.depends_on(2.into(), 0.into()));

        // D depends on the cycle, but not the other way around.
        assert!(closure.depends_on(3.into(), 1.into()));
        assert!(!closure.depends_on(1.into(), 3.into()));

        // A node never depends on itself.
        assert!(!closure.depends_on(0.into(), 0.into()));
    }
}