Skip to main content

ploidy_core/ir/
graph.rs

1use std::{
2    any::{Any, TypeId},
3    collections::VecDeque,
4    fmt::Debug,
5};
6
7use atomic_refcell::AtomicRefCell;
8use enum_map::{Enum, EnumMap, enum_map};
9use fixedbitset::FixedBitSet;
10use itertools::Itertools;
11use petgraph::{
12    Direction,
13    adj::UnweightedList,
14    algo::{TarjanScc, tred},
15    data::Build,
16    graph::{DiGraph, NodeIndex},
17    stable_graph::StableDiGraph,
18    visit::{DfsPostOrder, EdgeFiltered, EdgeRef, IntoNeighbors, VisitMap, Visitable},
19};
20use rustc_hash::{FxHashMap, FxHashSet};
21
22use crate::{arena::Arena, parse::Info};
23
24use super::{
25    spec::{ResolvedSpecType, Spec},
26    types::{
27        GraphInlineType, GraphOperation, GraphSchemaType, GraphStruct, GraphTagged,
28        GraphTaggedVariant, GraphType, InlineTypePath, InlineTypePathRoot, InlineTypePathSegment,
29        SchemaTypeInfo, SpecInlineType, SpecSchemaType, SpecType, SpecUntaggedVariant,
30        StructFieldName, mapper::TypeMapper,
31    },
32    views::{operation::OperationView, primitive::PrimitiveView, schema::SchemaTypeView},
33};
34
35/// The mutable, sparse graph used for transformations.
36type RawDiGraph<'a> = StableDiGraph<GraphType<'a>, EdgeKind, usize>;
37
38/// The immutable, dense graph used for code generation.
39type CookedDiGraph<'a> = DiGraph<GraphType<'a>, EdgeKind, usize>;
40
41/// A mutable intermediate dependency graph of all the types in a [`Spec`],
42/// backed by a sparse [`StableDiGraph`].
43///
44/// This graph is constructed directly from a [`Spec`], and represents
45/// type relationships as they exist in the spec. Transformations like
46/// [`inline_tagged_variants`][Self::inline_tagged_variants] rewrite this graph
47/// in place.
48///
49/// After applying all transformations, call [`cook`][Self::cook] to
50/// turn this graph into a [`CookedGraph`] that's ready for code generation.
51#[derive(Debug)]
52pub struct RawGraph<'a> {
53    arena: &'a Arena,
54    spec: &'a Spec<'a>,
55    graph: RawDiGraph<'a>,
56    ops: &'a [&'a GraphOperation<'a>],
57}
58
59impl<'a> RawGraph<'a> {
60    pub fn new(arena: &'a Arena, spec: &'a Spec<'a>) -> Self {
61        // All roots (named schemas, parameters, request and response bodies),
62        // and all the types within them (inline schemas and primitives).
63        let tys = SpecTypeVisitor::new(
64            spec.schemas
65                .values()
66                .chain(spec.operations.iter().flat_map(|op| op.types().copied())),
67        );
68
69        // Build the nodes and edges.
70        let mut indices = FxHashMap::default();
71        let mut schemas = FxHashMap::default();
72        let mut nodes = vec![];
73        let mut edges = vec![];
74        for (parent, kind, child) in tys {
75            use std::collections::hash_map::Entry;
76            let source = spec.resolve(child);
77            let &mut to = match indices.entry(source) {
78                Entry::Occupied(entry) => entry.into_mut(),
79                Entry::Vacant(entry) => {
80                    // We might see the same schema multiple times if it's
81                    // referenced multiple times in the spec. Only add
82                    // a new node for the schema if we haven't seen it before.
83                    let index = NodeIndex::new(nodes.len());
84                    nodes.push(*entry.key());
85                    entry.insert(index)
86                }
87            };
88            // Track schema names for later lookup.
89            if let ResolvedSpecType::Schema(ty) = source {
90                schemas.entry(ty.name()).or_insert(to);
91            }
92            if let Some(parent) = parent {
93                let destination = spec.resolve(parent);
94                let &mut from = match indices.entry(destination) {
95                    Entry::Occupied(entry) => entry.into_mut(),
96                    Entry::Vacant(entry) => {
97                        let index = NodeIndex::new(nodes.len());
98                        nodes.push(*entry.key());
99                        entry.insert(index)
100                    }
101                };
102                if let ResolvedSpecType::Schema(ty) = destination {
103                    schemas.entry(ty.name()).or_insert(from);
104                }
105                edges.push((from, to, kind));
106            }
107        }
108
109        // Construct a graph from the nodes and edges,
110        // mapping schema type references to graph indices.
111        let mut graph = RawDiGraph::with_capacity(nodes.len(), edges.len());
112        let mapper = TypeMapper::new(arena, |ty: &SpecType<'_>| match ty {
113            SpecType::Schema(s) => indices[&ResolvedSpecType::Schema(s)],
114            SpecType::Inline(i) => indices[&ResolvedSpecType::Inline(i)],
115            SpecType::Ref(r) => schemas[&*r.name()],
116        });
117        for node in nodes {
118            let mapped = match node {
119                ResolvedSpecType::Schema(ty) => GraphType::Schema(mapper.schema(ty)),
120                ResolvedSpecType::Inline(ty) => GraphType::Inline(mapper.inline(ty)),
121            };
122            let index = graph.add_node(mapped);
123            debug_assert_eq!(index, indices[&node]);
124        }
125        for (from, to, kind) in edges {
126            graph.add_edge(from, to, kind);
127        }
128
129        // Map schema type references in operations.
130        let ops = arena.alloc_slice_exact(spec.operations.iter().map(|op| mapper.operation(op)));
131
132        Self {
133            arena,
134            spec,
135            graph,
136            ops,
137        }
138    }
139
140    /// Inlines schema types used as variants of multiple tagged unions
141    /// with different tags.
142    ///
143    /// In OpenAPI's model of tagged unions, the tag always references a field
144    /// that's defined on each struct variant. This model works well for Python
145    /// and TypeScript, but not Rust; Serde doesn't allow struct variants to
146    /// declare fields with the same name as the tag. The Rust generator
147    /// excludes tag fields when generating structs, but this introduces a
148    /// new problem: a struct can't appear a variant of multiple unions
149    /// with different tags [^1].
150    ///
151    /// This transformation finds and inlines these structs, so that
152    /// the Rust generator can safely omit their tag fields.
153    ///
154    /// [^1]: If struct A has fields `foo` and `bar`, A is a variant of
155    /// tagged unions C and D, C's tag is `foo`, and D's tag is `bar`...
156    /// only `foo` should be excluded when A is used in C, and only `bar`
157    /// should be excluded when A is used in D; but this can't be modeled
158    /// in Serde without splitting A into two distinct types.
159    pub fn inline_tagged_variants(&mut self) -> &mut Self {
160        struct TaggedPlan<'a> {
161            tagged_index: NodeIndex<usize>,
162            info: SchemaTypeInfo<'a>,
163            tagged: GraphTagged<'a>,
164            inlines: Vec<VariantInline<'a>>,
165        }
166        struct VariantInline<'a> {
167            node: GraphType<'a>,
168            variant_index: NodeIndex<usize>,
169            parent_indices: &'a [NodeIndex<usize>],
170            name: &'a str,
171            aliases: &'a [&'a str],
172        }
173
174        // Compute the set of types used (as query params, request and response
175        // bodies, etc.) by operations. Operations don't create graph edges,
176        // but still need to be considered when deciding whether to inline a
177        // struct variant. Otherwise, a struct that's used by same-tag unions
178        // _and_ an operation wouldn't be inlined, causing the Rust generator to
179        // incorrectly exclude the tag field from the struct.
180        let used_by_ops: FixedBitSet = self
181            .ops
182            .iter()
183            .flat_map(|op| op.types())
184            .map(|node| node.index())
185            .collect();
186
187        // Collect all inlining decisions before mutating the graph,
188        // so that we can check inlinability per variant.
189        let plans = self
190            .graph
191            .node_indices()
192            .filter_map(|index| {
193                let GraphType::Schema(GraphSchemaType::Tagged(info, tagged)) = self.graph[index]
194                else {
195                    return None;
196                };
197                let mut inlines = vec![];
198
199                for variant in tagged.variants {
200                    let variant_index = variant.ty;
201                    let GraphType::Schema(GraphSchemaType::Struct(variant_info, variant_struct)) =
202                        self.graph[variant_index]
203                    else {
204                        continue;
205                    };
206
207                    // A struct variant only needs inlining if it has multiple
208                    // distinct uses. Skip if (1) no operation uses the struct,
209                    // _and_ (2) every incoming edge is from a tagged union with
210                    // the same tag and fields. If both hold, all uses agree, so
211                    // the struct can be used directly without inlining.
212                    if !used_by_ops.contains(variant_index.index()) {
213                        let Some(first) = ({
214                            self.graph
215                                .neighbors_directed(variant_index, Direction::Incoming)
216                                .find_map(|neighbor| match self.graph[neighbor] {
217                                    GraphType::Schema(GraphSchemaType::Tagged(_, t)) => Some(t),
218                                    _ => None,
219                                })
220                        }) else {
221                            continue;
222                        };
223                        // Check that all the variant's inbound edges are from
224                        // tagged unions, and that all their tags and fields
225                        // match the first union we found.
226                        let all_agree = self
227                            .graph
228                            .neighbors_directed(variant_index, Direction::Incoming)
229                            .all(|neighbor| {
230                                matches!(
231                                    self.graph[neighbor],
232                                    GraphType::Schema(GraphSchemaType::Tagged(_, t))
233                                        if t.tag == first.tag && t.fields == first.fields,
234                                )
235                            });
236                        if all_agree {
237                            continue;
238                        }
239                    }
240
241                    // Skip inlining when the inline copy would be
242                    // identical to the original. This happens when
243                    // the variant doesn't declare the tag as a field _and_
244                    // either (a) the union has no own fields, or
245                    // (b) the variant already inherits from this union,
246                    // so its fields are already reachable.
247                    let (has_tag_field, already_inherits) = {
248                        let inherits = EdgeFiltered::from_fn(&self.graph, |e| {
249                            matches!(e.weight(), EdgeKind::Inherits)
250                        });
251                        let mut dfs = DfsPostOrder::new(&inherits, variant_index);
252                        let mut has_tag_field = false;
253                        let mut already_inherits = false;
254                        while let Some(ancestor) = dfs.next(&inherits)
255                            && !(has_tag_field && already_inherits)
256                        {
257                            already_inherits |= ancestor == index;
258                            has_tag_field |= match self.graph[ancestor] {
259                                GraphType::Schema(GraphSchemaType::Struct(_, s))
260                                | GraphType::Inline(GraphInlineType::Struct(_, s)) => {
261                                    // Check own and inherited fields; OpenAPI 3.2
262                                    // clarifies that the tag can be inherited.
263                                    s.fields.iter().any(|f| {
264                                        matches!(
265                                            f.name,
266                                            StructFieldName::Name(n) if n == tagged.tag,
267                                        )
268                                    })
269                                }
270                                _ => false,
271                            };
272                        }
273                        (has_tag_field, already_inherits)
274                    };
275                    if !has_tag_field && (tagged.fields.is_empty() || already_inherits) {
276                        continue;
277                    }
278
279                    // Build our new inline type, with the same attributes
280                    // as the schema type, but a distinct inline type path.
281                    // The inline struct is a clone of the original, plus
282                    // an inheritance edge to the tagged union for its fields.
283                    let parents = self.arena.alloc_slice(itertools::chain!(
284                        variant_struct.parents.iter().copied(),
285                        std::iter::once(index),
286                    ));
287                    let node = GraphType::Inline(GraphInlineType::Struct(
288                        InlineTypePath {
289                            root: InlineTypePathRoot::Type(info.name),
290                            segments: self.arena.alloc_slice_copy(&[
291                                InlineTypePathSegment::TaggedVariant(variant_info.name),
292                            ]),
293                        },
294                        GraphStruct {
295                            description: variant_struct.description,
296                            fields: variant_struct.fields,
297                            parents,
298                        },
299                    ));
300
301                    inlines.push(VariantInline {
302                        node,
303                        variant_index,
304                        parent_indices: parents,
305                        name: variant.name,
306                        aliases: variant.aliases,
307                    });
308                }
309                if inlines.is_empty() {
310                    return None;
311                }
312
313                Some(TaggedPlan {
314                    tagged_index: index,
315                    info,
316                    tagged,
317                    inlines,
318                })
319            })
320            .collect_vec();
321
322        // Apply the plans to the graph.
323        for plan in plans {
324            let mut new_variants = FxHashMap::default();
325
326            // Add nodes and edges for the inline types.
327            for entry in &plan.inlines {
328                let node_index = self.graph.add_node(entry.node);
329
330                // Reference the original variant so that the inline
331                // inherits the original's transitive dependencies and
332                // SCC membership, but not its inline subtree.
333                self.graph
334                    .add_edge(node_index, entry.variant_index, EdgeKind::Reference);
335
336                // Add inheritance edges back to the inline's parents.
337                for &parent_index in entry.parent_indices {
338                    self.graph
339                        .add_edge(node_index, parent_index, EdgeKind::Inherits);
340                }
341
342                new_variants.insert(
343                    entry.variant_index,
344                    GraphTaggedVariant {
345                        name: entry.name,
346                        aliases: entry.aliases,
347                        ty: node_index,
348                    },
349                );
350            }
351
352            // Retarget reference edges from the tagged union to point to
353            // the new inline variants. We only update edges targeting a
354            // replaced variant; other edges stay.
355            let edges_to_retarget = self
356                .graph
357                .edges_directed(plan.tagged_index, Direction::Outgoing)
358                .filter(|e| {
359                    matches!(e.weight(), EdgeKind::Reference)
360                        && new_variants.contains_key(&e.target())
361                })
362                .map(|e| (e.id(), new_variants[&e.target()].ty))
363                .collect_vec();
364            for (edge_id, new_target) in edges_to_retarget {
365                self.graph.remove_edge(edge_id);
366                self.graph
367                    .add_edge(plan.tagged_index, new_target, EdgeKind::Reference);
368            }
369
370            // Replace the node for the tagged union itself.
371            let modified_tagged = GraphTagged {
372                description: plan.tagged.description,
373                tag: plan.tagged.tag,
374                variants: self.arena.alloc_slice_exact(
375                    plan.tagged
376                        .variants
377                        .iter()
378                        .map(|&v| new_variants.get(&v.ty).copied().unwrap_or(v)),
379                ),
380                fields: plan.tagged.fields,
381            };
382            self.graph[plan.tagged_index] =
383                GraphType::Schema(GraphSchemaType::Tagged(plan.info, modified_tagged));
384        }
385
386        self
387    }
388
389    /// Builds an immutable [`CookedGraph`] from this mutable raw graph.
390    #[inline]
391    pub fn cook(&self) -> CookedGraph<'a> {
392        CookedGraph::new(self)
393    }
394}
395
396/// The final dependency graph of all the types in a [`Spec`],
397/// backed by a dense [`DiGraph`].
398///
399/// This graph has all transformations applied, and is ready for
400/// code generation.
401#[derive(Debug)]
402pub struct CookedGraph<'a> {
403    pub(super) graph: CookedDiGraph<'a>,
404    info: &'a Info,
405    ops: &'a [&'a GraphOperation<'a>],
406    /// Additional metadata for each node.
407    pub(super) metadata: CookedGraphMetadata<'a>,
408}
409
410impl<'a> CookedGraph<'a> {
411    fn new(raw: &RawGraph<'a>) -> Self {
412        // Assign a cooked node index to each raw node index.
413        let indices: FxHashMap<_, _> = raw
414            .graph
415            .node_indices()
416            .enumerate()
417            .map(|(cooked, raw)| (raw, NodeIndex::new(cooked)))
418            .collect();
419
420        // Map sparse graph indices to dense cooked indices.
421        let mapper = TypeMapper::new(raw.arena, |index| indices[&index]);
422
423        // Build a dense graph.
424        let mut graph = CookedDiGraph::with_capacity(indices.len(), raw.graph.edge_count());
425        for index in raw.graph.node_indices() {
426            let node = raw.graph[index];
427            let mapped = match node {
428                GraphType::Schema(ty) => GraphType::Schema(mapper.schema(&ty)),
429                GraphType::Inline(ty) => GraphType::Inline(mapper.inline(&ty)),
430            };
431            let cooked = graph.add_node(mapped);
432            debug_assert_eq!(indices[&index], cooked);
433        }
434
435        // Add edges, preserving original insertion order.
436        for index in raw.graph.node_indices() {
437            let from = indices[&index];
438            // `RawDiGraph::edges` yields edges in reverse insertion
439            // order; collect and reverse to preserve the original order.
440            let edges = raw
441                .graph
442                .edges(index)
443                .map(|e| (indices[&e.target()], *e.weight()))
444                .collect_vec();
445            for (to, kind) in edges.into_iter().rev() {
446                graph.add_edge(from, to, kind);
447            }
448        }
449
450        let sccs = TopoSccs::new(&graph);
451
452        let (metadata, ops) = {
453            let mut metadata = CookedGraphMetadata {
454                scc_indices: {
455                    // Precompute SCC indices, using just the reference edges.
456                    // Inheritance edges don't contribute to cycles.
457                    let refs = EdgeFiltered::from_fn(&graph, |e| {
458                        matches!(e.weight(), EdgeKind::Reference)
459                    });
460                    let mut scc = TarjanScc::new();
461                    scc.run(&refs, |_| ());
462                    graph
463                        .node_indices()
464                        .map(|node| scc.node_component_index(&refs, node))
465                        .collect()
466                },
467                // `GraphNodeMeta` can't implement `Clone` because it contains
468                // an `AtomicRefCell`, so we use this idiom instead of `vec!`.
469                schemas: std::iter::repeat_with(GraphNodeMeta::default)
470                    .take(graph.node_count())
471                    .collect(),
472                operations: FxHashMap::default(),
473            };
474
475            // Remap schema type references in operations.
476            let ops: &_ = raw
477                .arena
478                .alloc_slice_exact(raw.ops.iter().map(|&op| mapper.operation(op)));
479
480            // Precompute the set of type indices that each operation
481            // references directly.
482            for &&op in ops {
483                metadata.operations.entry(op).or_default().types =
484                    op.types().map(|node| node.index()).collect();
485            }
486
487            // Forward propagation: for each type, compute all the types
488            // that it depends on, directly and transitively.
489            {
490                // Condense each of the original graph's strongly connected components
491                // into a single node, forming a DAG.
492                let condensation = sccs.condensation();
493
494                // Compute the transitive closure; discard the reduction.
495                let (_, closure) = tred::dag_transitive_reduction_closure(&condensation);
496
497                // Compute dependencies between SCCs.
498                let mut scc_deps =
499                    vec![FixedBitSet::with_capacity(graph.node_count()); sccs.scc_count()];
500                for (scc, deps) in scc_deps.iter_mut().enumerate() {
501                    // Include the SCC itself, so that cycle members appear
502                    // in each other's dependencies; and its
503                    // transitive neighbors.
504                    deps.extend(
505                        std::iter::once(scc)
506                            .chain(closure.neighbors(scc))
507                            .flat_map(|scc| sccs.sccs[scc].ones()),
508                    );
509                }
510
511                // Expand SCC dependencies to node dependencies, and
512                // transpose dependencies to build dependents.
513                let mut node_dependents =
514                    vec![FixedBitSet::with_capacity(graph.node_count()); graph.node_count()];
515                for node in graph.node_indices() {
516                    let mut deps = scc_deps[sccs.topo_index(node)].clone();
517                    deps.remove(node.index()); // We don't depend on ourselves.
518                    for dep in deps.ones() {
519                        node_dependents[dep].insert(node.index());
520                    }
521                    metadata.schemas[node.index()].dependencies = deps;
522                }
523                for (index, dependents) in node_dependents.into_iter().enumerate() {
524                    metadata.schemas[index].dependents = dependents;
525                }
526            }
527
528            // Backward propagation: propagate each operation to all the
529            // types that it uses, directly and transitively.
530            for &&op in ops {
531                let meta = &metadata.operations[&op];
532
533                // Collect all the types that this operation depends on.
534                let mut transitive_deps = FixedBitSet::with_capacity(graph.node_count());
535                for node in meta.types.ones() {
536                    transitive_deps.insert(node);
537                    transitive_deps.union_with(&metadata.schemas[node].dependencies);
538                }
539
540                // Mark each type as being used by this operation.
541                for node in transitive_deps.ones() {
542                    metadata.schemas[node].used_by.insert(op);
543                }
544            }
545
546            (metadata, ops)
547        };
548
549        Self {
550            graph,
551            info: raw.spec.info,
552            ops,
553            metadata,
554        }
555    }
556
557    /// Returns [`Info`] from the [`Document`][crate::parse::Document]
558    /// used to build this graph.
559    #[inline]
560    pub fn info(&self) -> &'a Info {
561        self.info
562    }
563
564    /// Returns an iterator over all the named schemas in this graph.
565    #[inline]
566    pub fn schemas(&self) -> impl Iterator<Item = SchemaTypeView<'_>> {
567        self.graph
568            .node_indices()
569            .filter_map(|index| match self.graph[index] {
570                GraphType::Schema(ty) => Some(SchemaTypeView::new(self, index, ty)),
571                _ => None,
572            })
573    }
574
575    /// Returns an iterator over all primitive type nodes in this graph.
576    #[inline]
577    pub fn primitives(&self) -> impl Iterator<Item = PrimitiveView<'_>> {
578        self.graph
579            .node_indices()
580            .filter_map(|index| match self.graph[index] {
581                GraphType::Schema(GraphSchemaType::Primitive(_, p))
582                | GraphType::Inline(GraphInlineType::Primitive(_, p)) => {
583                    Some(PrimitiveView::new(self, index, p))
584                }
585                _ => None,
586            })
587    }
588
589    /// Returns an iterator over all the operations in this graph.
590    #[inline]
591    pub fn operations(&self) -> impl Iterator<Item = OperationView<'_>> {
592        self.ops.iter().map(move |&op| OperationView::new(self, op))
593    }
594}
595
596/// An edge between two types in the type graph.
597#[derive(Clone, Copy, Debug, Enum, Eq, Hash, PartialEq)]
598pub enum EdgeKind {
599    /// The source type contains or references the target type.
600    Reference,
601    /// The source type inherits from the target type.
602    Inherits,
603}
604
605/// Precomputed metadata for schema types and operations in the graph.
606#[derive(Debug, Default)]
607pub(super) struct CookedGraphMetadata<'a> {
608    /// Maps each node index to its strongly connected component index.
609    /// Nodes in the same SCC form a cycle.
610    pub scc_indices: Vec<usize>,
611    pub schemas: Vec<GraphNodeMeta<'a>>,
612    pub operations: FxHashMap<GraphOperation<'a>, GraphOperationMeta>,
613}
614
615/// Precomputed metadata for an operation that references
616/// types in the graph.
617#[derive(Debug, Default)]
618pub(super) struct GraphOperationMeta {
619    /// Indices of all the types that this operation directly depends on:
620    /// parameters, request body, and response body.
621    pub types: FixedBitSet,
622}
623
624/// Precomputed metadata for a schema type in the graph.
625#[derive(Default)]
626pub(super) struct GraphNodeMeta<'a> {
627    /// Operations that use this type.
628    pub used_by: FxHashSet<GraphOperation<'a>>,
629    /// Indices of other types that this type transitively depends on.
630    pub dependencies: FixedBitSet,
631    /// Indices of other types that transitively depend on this type.
632    pub dependents: FixedBitSet,
633    /// Opaque extended data for this type.
634    pub extensions: AtomicRefCell<ExtensionMap>,
635}
636
637impl Debug for GraphNodeMeta<'_> {
638    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639        f.debug_struct("GraphNodeMeta")
640            .field("used_by", &self.used_by)
641            .field("dependencies", &self.dependencies)
642            .field("dependents", &self.dependents)
643            .finish_non_exhaustive()
644    }
645}
646
647/// Visits all the types and references contained within a [`SpecType`].
648#[derive(Debug)]
649struct SpecTypeVisitor<'a> {
650    stack: Vec<(Option<&'a SpecType<'a>>, EdgeKind, &'a SpecType<'a>)>,
651}
652
653impl<'a> SpecTypeVisitor<'a> {
654    /// Creates a visitor with `roots` on the stack of types to visit.
655    #[inline]
656    fn new(roots: impl Iterator<Item = &'a SpecType<'a>>) -> Self {
657        let mut stack = roots
658            .map(|root| (None, EdgeKind::Reference, root))
659            .collect_vec();
660        stack.reverse();
661        Self { stack }
662    }
663}
664
665impl<'a> Iterator for SpecTypeVisitor<'a> {
666    type Item = (Option<&'a SpecType<'a>>, EdgeKind, &'a SpecType<'a>);
667
668    fn next(&mut self) -> Option<Self::Item> {
669        let (parent, kind, top) = self.stack.pop()?;
670        match top {
671            SpecType::Schema(SpecSchemaType::Struct(_, ty))
672            | SpecType::Inline(SpecInlineType::Struct(_, ty)) => {
673                self.stack.extend(
674                    itertools::chain!(
675                        ty.fields
676                            .iter()
677                            .map(|field| (EdgeKind::Reference, field.ty)),
678                        ty.parents
679                            .iter()
680                            .map(|parent| (EdgeKind::Inherits, *parent)),
681                    )
682                    .map(|(kind, ty)| (Some(top), kind, ty))
683                    .rev(),
684                );
685            }
686            SpecType::Schema(SpecSchemaType::Untagged(_, ty))
687            | SpecType::Inline(SpecInlineType::Untagged(_, ty)) => {
688                self.stack.extend(
689                    itertools::chain!(
690                        ty.fields
691                            .iter()
692                            .map(|field| (EdgeKind::Reference, field.ty)),
693                        ty.variants.iter().filter_map(|variant| match variant {
694                            SpecUntaggedVariant::Some(_, ty) => {
695                                Some((EdgeKind::Reference, *ty))
696                            }
697                            _ => None,
698                        }),
699                    )
700                    .map(|(kind, ty)| (Some(top), kind, ty))
701                    .rev(),
702                );
703            }
704            SpecType::Schema(SpecSchemaType::Tagged(_, ty))
705            | SpecType::Inline(SpecInlineType::Tagged(_, ty)) => {
706                self.stack.extend(
707                    itertools::chain!(
708                        ty.fields
709                            .iter()
710                            .map(|field| (EdgeKind::Reference, field.ty)),
711                        ty.variants
712                            .iter()
713                            .map(|variant| (EdgeKind::Reference, variant.ty)),
714                    )
715                    .map(|(kind, ty)| (Some(top), kind, ty))
716                    .rev(),
717                );
718            }
719            SpecType::Schema(SpecSchemaType::Container(_, container))
720            | SpecType::Inline(SpecInlineType::Container(_, container)) => {
721                self.stack
722                    .push((Some(top), EdgeKind::Reference, container.inner().ty));
723            }
724            SpecType::Schema(
725                SpecSchemaType::Enum(..) | SpecSchemaType::Primitive(..) | SpecSchemaType::Any(_),
726            )
727            | SpecType::Inline(
728                SpecInlineType::Enum(..) | SpecInlineType::Primitive(..) | SpecInlineType::Any(_),
729            ) => (),
730            SpecType::Ref(_) => (),
731        }
732        Some((parent, kind, top))
733    }
734}
735
/// A map that can store one value for each type: entries are keyed by
/// [`TypeId`] and hold a boxed, type-erased [`Extension`].
pub(super) type ExtensionMap = FxHashMap<TypeId, Box<dyn Extension>>;
738
/// A type-erased, thread-safe value that can be boxed as a
/// `Box<dyn Extension>`.
///
/// A blanket implementation covers every sized `Send + Sync + 'static` type,
/// so this trait doesn't need to be implemented by hand.
pub trait Extension: Any + Send + Sync {
    /// Converts this boxed extension into a `Box<dyn Any>`, so that the
    /// caller can downcast it to an owned concrete value.
    fn into_inner(self: Box<Self>) -> Box<dyn Any>;
}

impl dyn Extension {
    /// Returns a reference to the underlying value if it's a `T`,
    /// or `None` if it's some other type.
    #[inline]
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        // Upcast to the `Any` supertrait, then use its checked downcast.
        let any: &dyn Any = self;
        any.downcast_ref()
    }
}

impl<T> Extension for T
where
    T: Send + Sync + 'static,
{
    #[inline]
    fn into_inner(self: Box<Self>) -> Box<dyn Any> {
        self as Box<dyn Any>
    }
}
756
757/// Strongly connected components (SCCs) in topological order.
758///
759/// [`TopoSccs`] uses Tarjan's single-pass algorithm to find all SCCs,
760/// and provides topological ordering, efficient membership testing, and
761/// condensation for computing the transitive closure. These are
762/// building blocks for cycle detection and dependency propagation.
763struct TopoSccs<'a, N, E> {
764    graph: &'a DiGraph<N, E, usize>,
765    tarjan: TarjanScc<NodeIndex<usize>>,
766    sccs: Vec<FixedBitSet>,
767}
768
769impl<'a, N, E> TopoSccs<'a, N, E> {
770    fn new(graph: &'a DiGraph<N, E, usize>) -> Self {
771        let mut sccs = Vec::new();
772        let mut tarjan = TarjanScc::new();
773        tarjan.run(graph, |scc_nodes| {
774            sccs.push(scc_nodes.iter().map(|node| node.index()).collect());
775        });
776        // Tarjan's algorithm returns SCCs in reverse topological order;
777        // reverse them to get the topological order.
778        sccs.reverse();
779        Self {
780            graph,
781            tarjan,
782            sccs,
783        }
784    }
785
786    #[inline]
787    fn scc_count(&self) -> usize {
788        self.sccs.len()
789    }
790
791    /// Returns the topological index of the SCC that contains the given node.
792    #[inline]
793    fn topo_index(&self, node: NodeIndex<usize>) -> usize {
794        // Tarjan's algorithm returns indices in reverse topological order;
795        // inverting the component index gets us the topological index.
796        self.sccs.len() - 1 - self.tarjan.node_component_index(self.graph, node)
797    }
798
799    /// Iterates over the SCCs in topological order.
800    #[cfg(test)]
801    fn iter(&self) -> std::slice::Iter<'_, FixedBitSet> {
802        self.sccs.iter()
803    }
804
805    /// Builds a condensed DAG of SCCs.
806    ///
807    /// The condensed graph is represented as an adjacency list, where both
808    /// the node indices and the neighbors of each node are stored in
809    /// topological order. This specific ordering is required by
810    /// [`tred::dag_transitive_reduction_closure`].
811    fn condensation(&self) -> UnweightedList<usize> {
812        let mut dag = UnweightedList::with_capacity(self.scc_count());
813        for to in 0..self.scc_count() {
814            dag.add_node();
815            for index in self.sccs[to].ones().map(NodeIndex::new) {
816                for neighbor in self.graph.neighbors_directed(index, Direction::Incoming) {
817                    let from = self.topo_index(neighbor);
818                    if from != to {
819                        dag.update_edge(from, to, ());
820                    }
821                }
822            }
823        }
824        dag
825    }
826}
827
/// Controls how to continue traversing the graph when at a node.
///
/// This is the value returned by the filter passed to [`Traverse::run`]
/// for each node that it dequeues. The four variants cover every
/// combination of yielding the node and exploring its neighbors.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Traversal {
    /// Yield this node, then explore its neighbors.
    Visit,
    /// Yield this node, but skip its neighbors.
    Stop,
    /// Don't yield this node, but explore its neighbors.
    Skip,
    /// Don't yield this node, and skip its neighbors.
    Ignore,
}
840
841/// Edge-kind-aware breadth-first traversal of the type graph.
842///
843/// [`Traverse`] tracks discovered nodes separately per [`EdgeKind`],
844/// so a node that's reachable via both reference and inheritance edges
845/// is visited once for each kind.
846///
847/// Use [`Traverse::run`] with a filter to control which nodes are
848/// yielded and explored.
849pub struct Traverse<'a> {
850    graph: &'a CookedDiGraph<'a>,
851    stack: VecDeque<(EdgeKind, NodeIndex<usize>)>,
852    discovered: EnumMap<EdgeKind, FixedBitSet>,
853    direction: Direction,
854}
855
856impl<'a> Traverse<'a> {
857    /// Starts a breadth-first traversal at a `root` node,
858    /// including `root` in the traversal.
859    pub fn at_root(
860        graph: &'a CookedDiGraph<'a>,
861        root: NodeIndex<usize>,
862        direction: Direction,
863    ) -> Self {
864        let mut discovered = enum_map!(_ => graph.visit_map());
865        discovered[EdgeKind::Reference].grow_and_insert(root.index());
866        Self {
867            graph,
868            stack: VecDeque::from([(EdgeKind::Reference, root)]),
869            discovered,
870            direction,
871        }
872    }
873
874    /// Starts a breadth-first traversal at multiple `roots`,
875    /// including each root in the traversal.
876    pub fn at_roots(
877        graph: &'a CookedDiGraph<'a>,
878        roots: &FixedBitSet,
879        direction: Direction,
880    ) -> Self {
881        let mut stack = VecDeque::new();
882        let mut discovered = enum_map!(_ => graph.visit_map());
883        stack.extend(
884            roots
885                .ones()
886                .map(|index| (EdgeKind::Reference, NodeIndex::new(index))),
887        );
888        discovered[EdgeKind::Reference].union_with(roots);
889        Self {
890            graph,
891            stack,
892            discovered,
893            direction,
894        }
895    }
896
897    /// Starts a breadth-first traversal from the immediate neighbors of `root`,
898    /// excluding `root` itself from the traversal.
899    pub fn from_neighbors(
900        graph: &'a CookedDiGraph<'a>,
901        root: NodeIndex<usize>,
902        direction: Direction,
903    ) -> Self {
904        let mut stack = VecDeque::new();
905        let mut discovered = enum_map! {
906            _ => {
907                let mut map = graph.visit_map();
908                map.visit(root);
909                map
910            }
911        };
912        for (kind, neighbors) in neighbors(graph, root, direction) {
913            stack.extend(
914                neighbors
915                    .difference(&discovered[kind])
916                    .map(|index| (kind, NodeIndex::new(index))),
917            );
918            discovered[kind].union_with(&neighbors);
919        }
920        Self {
921            graph,
922            stack,
923            discovered,
924            direction,
925        }
926    }
927
928    pub fn run<F>(mut self, filter: F) -> impl Iterator<Item = NodeIndex<usize>> + use<'a, F>
929    where
930        F: Fn(EdgeKind, NodeIndex<usize>) -> Traversal,
931    {
932        std::iter::from_fn(move || {
933            while let Some((kind, index)) = self.stack.pop_front() {
934                let traversal = filter(kind, index);
935
936                if matches!(traversal, Traversal::Visit | Traversal::Skip) {
937                    for (kind, neighbors) in neighbors(self.graph, index, self.direction) {
938                        for neighbor in neighbors.difference(&self.discovered[kind]) {
939                            self.stack.push_back((kind, NodeIndex::new(neighbor)));
940                        }
941                        self.discovered[kind].union_with(&neighbors);
942                    }
943                }
944
945                if matches!(traversal, Traversal::Visit | Traversal::Stop) {
946                    return Some(index);
947                }
948
949                // `Skip` and `Ignore` continue the loop without yielding.
950            }
951            None
952        })
953    }
954}
955
956/// Returns the neighbors of `node` in the given `direction`,
957/// grouped by their [`EdgeKind`].
958fn neighbors(
959    graph: &CookedDiGraph<'_>,
960    node: NodeIndex<usize>,
961    direction: Direction,
962) -> EnumMap<EdgeKind, FixedBitSet> {
963    let mut neighbors = enum_map!(_ => graph.visit_map());
964    for edge in graph.edges_directed(node, direction) {
965        let neighbor = match direction {
966            Direction::Outgoing => edge.target(),
967            Direction::Incoming => edge.source(),
968        };
969        neighbors[*edge.weight()].insert(neighbor.index());
970    }
971    neighbors
972}
973
#[cfg(test)]
mod tests {
    use super::*;

    use petgraph::visit::NodeCount;

    use crate::tests::assert_matches;

    /// Builds the chain `A -> B -> C`.
    fn linear_graph() -> DiGraph<(), (), usize> {
        let mut graph = DiGraph::default();
        let [a, b, c] = [(); 3].map(|()| graph.add_node(()));
        graph.extend_with_edges([(a, b), (b, c)]);
        graph
    }

    /// Builds the cycle `A -> B -> C -> A`, plus an entry edge `D -> A`.
    fn cyclic_graph() -> DiGraph<(), (), usize> {
        let mut graph = DiGraph::default();
        let [a, b, c, d] = [(); 4].map(|()| graph.add_node(()));
        graph.extend_with_edges([(a, b), (b, c), (c, a), (d, a)]);
        graph
    }

    /// Returns the number of nodes in each SCC, in topological order.
    fn scc_sizes<N, E>(sccs: &TopoSccs<'_, N, E>) -> Vec<usize> {
        sccs.iter().map(|scc| scc.count_ones(..)).collect()
    }

    // MARK: SCC detection

    #[test]
    fn test_linear_graph_has_singleton_sccs() {
        let graph = linear_graph();
        let sccs = TopoSccs::new(&graph);

        // Without cycles, every node is its own component.
        let sizes = scc_sizes(&sccs);
        assert_matches!(&*sizes, [1, 1, 1]);
    }

    #[test]
    fn test_cyclic_graph_has_one_multi_node_scc() {
        let graph = cyclic_graph();
        let sccs = TopoSccs::new(&graph);

        // The cycle A-B-C collapses into one SCC, and D stays on its own.
        // D points into the cycle, so D comes first in topological order.
        let sizes = scc_sizes(&sccs);
        assert_matches!(&*sizes, [1, 3]);
    }

    // MARK: Topological ordering

    #[test]
    fn test_sccs_are_in_topological_order() {
        let graph = cyclic_graph();
        let sccs = TopoSccs::new(&graph);

        let d_position = sccs.topo_index(NodeIndex::new(3));
        let a_position = sccs.topo_index(NodeIndex::new(0));
        assert!(
            d_position < a_position,
            "D should precede A-B-C in topological order"
        );
    }

    #[test]
    fn test_topo_index_consistent_within_scc() {
        let graph = cyclic_graph();
        let sccs = TopoSccs::new(&graph);

        // A, B, and C share an SCC, so all three must map to
        // the same topological index.
        let [a_position, b_position, c_position] =
            [0, 1, 2].map(|index| sccs.topo_index(NodeIndex::new(index)));

        assert_eq!(a_position, b_position);
        assert_eq!(b_position, c_position);
    }

    // MARK: Condensation

    #[test]
    fn test_condensation_has_correct_node_count() {
        let graph = cyclic_graph();
        let condensed = TopoSccs::new(&graph).condensation();

        assert_eq!(condensed.node_count(), 2);
    }

    #[test]
    fn test_condensation_has_correct_edges() {
        let graph = cyclic_graph();
        let sccs = TopoSccs::new(&graph);
        let condensed = sccs.condensation();

        let d_position = sccs.topo_index(NodeIndex::new(3));
        let cycle_position = sccs.topo_index(NodeIndex::new(0));

        // D's only edge leads into the A-B-C component.
        let from_d = condensed.neighbors(d_position).collect_vec();
        assert_eq!(&*from_d, [cycle_position]);

        // Edges inside the cycle must not turn into a self-loop.
        let from_cycle = condensed.neighbors(cycle_position).collect_vec();
        assert!(from_cycle.is_empty());
    }

    #[test]
    fn test_condensation_neighbors_in_topological_order() {
        // Mirrors Petgraph's `dag_to_toposorted_adjacency_list` example:
        // the edges are added as `(top, second), (top, first)`, but the
        // neighbors of `top` should come back as `[first, second]`, in
        // topological order.
        let mut graph = DiGraph::<(), (), usize>::default();
        let second = graph.add_node(());
        let top = graph.add_node(());
        let first = graph.add_node(());
        graph.extend_with_edges([(top, second), (top, first), (first, second)]);

        let sccs = TopoSccs::new(&graph);
        let condensed = sccs.condensation();

        let top_position = sccs.topo_index(top);
        assert_eq!(top_position, 0);

        let first_position = sccs.topo_index(first);
        assert_eq!(first_position, 1);

        let second_position = sccs.topo_index(second);
        assert_eq!(second_position, 2);

        let from_top = condensed.neighbors(top_position).collect_vec();
        assert_eq!(&*from_top, [first_position, second_position]);
    }
}