Skip to main content

ploidy_core/ir/
graph.rs

1use std::{
2    any::{Any, TypeId},
3    fmt::Debug,
4};
5
6use atomic_refcell::AtomicRefCell;
7use by_address::ByAddress;
8use fixedbitset::FixedBitSet;
9use itertools::Itertools;
10use petgraph::{
11    Direction,
12    adj::UnweightedList,
13    algo::{TarjanScc, tred},
14    data::Build,
15    graph::{DiGraph, NodeIndex},
16    visit::{IntoNeighbors, NodeCount},
17};
18use rustc_hash::{FxHashMap, FxHashSet};
19
20use super::{
21    spec::IrSpec,
22    types::{
23        InlineIrType, IrOperation, IrType, IrTypeRef, IrUntaggedVariant, PrimitiveIrType,
24        SchemaIrType,
25    },
26    views::{operation::IrOperationView, schema::SchemaIrTypeView, wrappers::IrPrimitiveView},
27};
28
29/// The type graph.
30pub(super) type IrGraphG<'a> = DiGraph<IrGraphNode<'a>, (), usize>;
31
32/// A graph of all the types in an [`IrSpec`], where each edge
33/// is a reference from one type to another.
34#[derive(Debug)]
35pub struct IrGraph<'a> {
36    pub(super) spec: &'a IrSpec<'a>,
37    pub(super) g: IrGraphG<'a>,
38    /// An inverted index of nodes to graph indices.
39    pub(super) indices: FxHashMap<IrGraphNode<'a>, NodeIndex<usize>>,
40    /// Edges that are part of a cycle.
41    pub(super) circular_refs: FxHashSet<(NodeIndex<usize>, NodeIndex<usize>)>,
42    /// Additional metadata for each node.
43    pub(super) metadata: IrGraphMetadata<'a>,
44}
45
46impl<'a> IrGraph<'a> {
47    pub fn new(spec: &'a IrSpec<'a>) -> Self {
48        let mut g = IrGraphG::default();
49        let mut indices = FxHashMap::default();
50
51        // All roots (named schemas, parameters, request and response bodies),
52        // and all the types within them (inline schemas, wrappers, primitives).
53        let tys = IrTypeVisitor::new(
54            spec.schemas
55                .values()
56                .chain(spec.operations.iter().flat_map(|op| op.types())),
57        );
58
59        // Add nodes for all types, and edges for references between them.
60        for (parent, child) in tys {
61            use std::collections::hash_map::Entry;
62            let &mut to = match indices.entry(IrGraphNode::from_ref(spec, child.as_ref())) {
63                // We might see the same schema multiple times, if it's
64                // referenced multiple times in the spec. Only add a new node
65                // for the schema if we haven't seen it before.
66                Entry::Occupied(entry) => entry.into_mut(),
67                Entry::Vacant(entry) => {
68                    let index = g.add_node(*entry.key());
69                    entry.insert(index)
70                }
71            };
72            if let Some(parent) = parent {
73                let &mut from = match indices.entry(IrGraphNode::from_ref(spec, parent.as_ref())) {
74                    Entry::Occupied(entry) => entry.into_mut(),
75                    Entry::Vacant(entry) => {
76                        let index = g.add_node(*entry.key());
77                        entry.insert(index)
78                    }
79                };
80                // Add a directed edge from parent to child.
81                g.add_edge(from, to, ());
82            }
83        }
84
85        let sccs = TopoSccs::new(&g);
86
87        // Precompute all circular reference edges, where both endpoints
88        // are in the same SCC, to speed up `needs_indirection()`.
89        let circular_refs = {
90            let mut edges = FxHashSet::default();
91            for members in sccs.iter() {
92                for node in members.ones().map(NodeIndex::new) {
93                    edges.extend(
94                        g.neighbors(node)
95                            .filter(|neighbor| members.contains(neighbor.index()))
96                            .map(|neighbor| (node, neighbor)),
97                    );
98                }
99            }
100            edges
101        };
102
103        let metadata = {
104            let mut metadata = IrGraphMetadata::default();
105
106            // Precompute the set of type indices that each operation
107            // references directly.
108            for op in &spec.operations {
109                metadata.operations.entry(ByAddress(op)).or_default().types = op
110                    .types()
111                    .filter_map(|ty| {
112                        indices
113                            .get(&IrGraphNode::from_ref(spec, ty.as_ref()))
114                            .map(|node| node.index())
115                    })
116                    .collect();
117            }
118
119            // Forward propagation: for each type, compute all the types
120            // that it depends on, directly and transitively.
121            {
122                // Condense each of the original graph's strongly connected components
123                // into a single node, forming a DAG.
124                let condensation = sccs.condensation();
125
126                // Compute the transitive closure; discard the reduction.
127                let (_, closure) = tred::dag_transitive_reduction_closure(&condensation);
128
129                // Expand SCC-level dependencies to node-level: for each SCC,
130                // form a union of all nodes from all the SCCs it depends on.
131                let mut deps_by_scc =
132                    vec![FixedBitSet::with_capacity(g.node_count()); condensation.node_count()];
133                for scc_index in condensation.node_indices() {
134                    for dep_scc_index in closure.neighbors(scc_index) {
135                        deps_by_scc[scc_index].union_with(sccs.members(dep_scc_index));
136                    }
137                    // Include the other members of this SCC; these depend on
138                    // each other because they're in a cycle.
139                    deps_by_scc[scc_index].union_with(sccs.members(scc_index));
140                }
141
142                for node in g.node_indices() {
143                    let topo_index = sccs.topo_index(node);
144                    let mut deps = deps_by_scc[topo_index].clone();
145
146                    // Exclude ourselves from our dependencies and dependents.
147                    deps.remove(node.index());
148
149                    metadata
150                        .schemas
151                        .entry(node)
152                        .or_default()
153                        .dependencies
154                        .union_with(&deps);
155
156                    // Add ourselves to the dependents of all the types
157                    // that we depend on.
158                    for index in deps.into_ones().map(NodeIndex::new) {
159                        metadata
160                            .schemas
161                            .entry(index)
162                            .or_default()
163                            .dependents
164                            .grow_and_insert(node.index());
165                    }
166                }
167            }
168
169            // Backward propagation: propagate each operation to all the types
170            // that it uses, directly and transitively.
171            for op in &spec.operations {
172                let meta = &metadata.operations[&ByAddress(op)];
173
174                // Collect all the types that this operation depends on.
175                let mut transitive_deps = FixedBitSet::with_capacity(g.node_count());
176                for node in meta.types.ones().map(NodeIndex::new) {
177                    transitive_deps.insert(node.index());
178                    if let Some(meta) = metadata.schemas.get(&node) {
179                        transitive_deps.union_with(&meta.dependencies);
180                    }
181                }
182
183                // Mark each type as being used by this operation.
184                for index in transitive_deps.ones().map(NodeIndex::new) {
185                    metadata
186                        .schemas
187                        .entry(index)
188                        .or_default()
189                        .used_by
190                        .insert(ByAddress(op));
191                }
192            }
193
194            metadata
195        };
196
197        Self {
198            spec,
199            indices,
200            g,
201            circular_refs,
202            metadata,
203        }
204    }
205
206    /// Returns the spec used to build this graph.
207    #[inline]
208    pub fn spec(&self) -> &'a IrSpec<'a> {
209        self.spec
210    }
211
212    /// Returns an iterator over all the named schemas in this graph.
213    #[inline]
214    pub fn schemas(&self) -> impl Iterator<Item = SchemaIrTypeView<'_>> {
215        self.g
216            .node_indices()
217            .filter_map(|index| match self.g[index] {
218                IrGraphNode::Schema(ty) => Some(SchemaIrTypeView::new(self, index, ty)),
219                _ => None,
220            })
221    }
222
223    /// Returns an iterator over all the primitive types in this graph. Note that
224    /// a graph contains at most one instance of each primitive type.
225    #[inline]
226    pub fn primitives(&self) -> impl Iterator<Item = IrPrimitiveView<'_>> {
227        self.g
228            .node_indices()
229            .filter_map(|index| match self.g[index] {
230                IrGraphNode::Primitive(ty) => Some(IrPrimitiveView::new(self, index, ty)),
231                _ => None,
232            })
233    }
234
235    /// Returns an iterator over all the operations in this graph.
236    #[inline]
237    pub fn operations(&self) -> impl Iterator<Item = IrOperationView<'_>> {
238        self.spec
239            .operations
240            .iter()
241            .map(move |op| IrOperationView::new(self, op))
242    }
243}
244
245/// A node in the type graph.
246///
247/// The derived [`Hash`][std::hash::Hash] and [`Eq`] implementations
248/// work on the underlying values, so structurally identical types
249/// will be equal. This is important: all types in an [`IrSpec`] are
250/// distinct in memory, but can refer to the same logical type.
251#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
252pub enum IrGraphNode<'a> {
253    Schema(&'a SchemaIrType<'a>),
254    Inline(&'a InlineIrType<'a>),
255    Array(&'a IrType<'a>),
256    Map(&'a IrType<'a>),
257    Optional(&'a IrType<'a>),
258    Primitive(PrimitiveIrType),
259    Any,
260}
261
262impl<'a> IrGraphNode<'a> {
263    /// Converts an [`IrTypeRef`] to an [`IrGraphNode`],
264    /// recursively resolving referenced schemas.
265    pub fn from_ref(spec: &'a IrSpec<'a>, ty: IrTypeRef<'a>) -> Self {
266        match ty {
267            IrTypeRef::Schema(ty) => IrGraphNode::Schema(ty),
268            IrTypeRef::Inline(ty) => IrGraphNode::Inline(ty),
269            IrTypeRef::Array(ty) => IrGraphNode::Array(ty),
270            IrTypeRef::Map(ty) => IrGraphNode::Map(ty),
271            IrTypeRef::Optional(ty) => IrGraphNode::Optional(ty),
272            IrTypeRef::Ref(r) => Self::from_ref(spec, spec.schemas[r.name()].as_ref()),
273            IrTypeRef::Primitive(ty) => IrGraphNode::Primitive(ty),
274            IrTypeRef::Any => IrGraphNode::Any,
275        }
276    }
277}
278
279/// Precomputed metadata for schema types and operations in the graph.
280#[derive(Debug, Default)]
281pub struct IrGraphMetadata<'a> {
282    pub schemas: FxHashMap<NodeIndex<usize>, IrGraphNodeMeta<'a>>,
283    pub operations: FxHashMap<ByAddress<&'a IrOperation<'a>>, IrGraphOperationMeta>,
284}
285
286/// Precomputed metadata for an operation that references
287/// types in the graph.
288#[derive(Debug, Default)]
289pub struct IrGraphOperationMeta {
290    /// Indices of all the types that this operation directly depends on:
291    /// parameters, request body, and response body.
292    pub types: FixedBitSet,
293}
294
295/// Precomputed metadata for a schema type in the graph.
296#[derive(Default)]
297pub(super) struct IrGraphNodeMeta<'a> {
298    /// Operations that use this type.
299    pub used_by: FxHashSet<ByAddress<&'a IrOperation<'a>>>,
300    /// Indices of other types that this type transitively depends on.
301    pub dependencies: FixedBitSet,
302    /// Indices of other types that transitively depend on this type.
303    pub dependents: FixedBitSet,
304    /// Opaque extended data for this type.
305    pub extensions: AtomicRefCell<ExtensionMap>,
306}
307
308impl Debug for IrGraphNodeMeta<'_> {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        f.debug_struct("IrGraphNodeMeta")
311            .field("used_by", &self.used_by)
312            .field("dependencies", &self.dependencies)
313            .field("dependents", &self.dependents)
314            .finish_non_exhaustive()
315    }
316}
317
318/// Visits all the types and references contained within a type.
319#[derive(Debug)]
320struct IrTypeVisitor<'a> {
321    stack: Vec<(Option<&'a IrType<'a>>, &'a IrType<'a>)>,
322}
323
324impl<'a> IrTypeVisitor<'a> {
325    /// Creates a visitor with `root` on the stack of types to visit.
326    #[inline]
327    fn new(roots: impl Iterator<Item = &'a IrType<'a>>) -> Self {
328        let mut stack = roots.map(|root| (None, root)).collect_vec();
329        stack.reverse();
330        Self { stack }
331    }
332}
333
334impl<'a> Iterator for IrTypeVisitor<'a> {
335    type Item = (Option<&'a IrType<'a>>, &'a IrType<'a>);
336
337    fn next(&mut self) -> Option<Self::Item> {
338        let (parent, top) = self.stack.pop()?;
339        match top {
340            IrType::Array(ty) => {
341                self.stack.push((Some(top), ty.as_ref()));
342            }
343            IrType::Map(ty) => {
344                self.stack.push((Some(top), ty.as_ref()));
345            }
346            IrType::Optional(ty) => {
347                self.stack.push((Some(top), ty.as_ref()));
348            }
349            IrType::Schema(SchemaIrType::Struct(_, ty)) => {
350                self.stack
351                    .extend(ty.fields.iter().map(|field| (Some(top), &field.ty)).rev());
352            }
353            IrType::Schema(SchemaIrType::Untagged(_, ty)) => {
354                self.stack.extend(
355                    ty.variants
356                        .iter()
357                        .filter_map(|variant| match variant {
358                            IrUntaggedVariant::Some(_, ty) => Some((Some(top), ty)),
359                            _ => None,
360                        })
361                        .rev(),
362                );
363            }
364            IrType::Schema(SchemaIrType::Tagged(_, ty)) => {
365                self.stack.extend(
366                    ty.variants
367                        .iter()
368                        .map(|variant| (Some(top), &variant.ty))
369                        .rev(),
370                );
371            }
372            IrType::Schema(SchemaIrType::Enum(..)) => (),
373            IrType::Any => (),
374            IrType::Primitive(_) => (),
375            IrType::Inline(ty) => match ty {
376                InlineIrType::Enum(..) => (),
377                InlineIrType::Tagged(_, ty) => {
378                    self.stack.extend(
379                        ty.variants
380                            .iter()
381                            .map(|variant| (Some(top), &variant.ty))
382                            .rev(),
383                    );
384                }
385                InlineIrType::Untagged(_, ty) => {
386                    self.stack.extend(
387                        ty.variants
388                            .iter()
389                            .filter_map(|variant| match variant {
390                                IrUntaggedVariant::Some(_, ty) => Some((Some(top), ty)),
391                                _ => None,
392                            })
393                            .rev(),
394                    );
395                }
396                InlineIrType::Struct(_, ty) => {
397                    self.stack
398                        .extend(ty.fields.iter().map(|field| (Some(top), &field.ty)).rev());
399                }
400            },
401            IrType::Ref(_) => (),
402        }
403        Some((parent, top))
404    }
405}
406
407/// A map that can store one value for each type.
408pub(super) type ExtensionMap = FxHashMap<TypeId, Box<dyn Extension>>;
409
/// A value that can be stored as opaque extension data.
///
/// Every `'static` type that is [`Send`] and [`Sync`] implements this
/// trait via the blanket impl below.
pub trait Extension: Any + Send + Sync {
    /// Converts the boxed extension into a `Box<dyn Any>`, e.g. so that it
    /// can be downcast to its concrete type by value.
    fn into_inner(self: Box<Self>) -> Box<dyn Any>;
}

impl dyn Extension {
    /// Returns a reference to the underlying value if its concrete type is
    /// `T`, or `None` otherwise.
    #[inline]
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        let any: &dyn Any = self;
        any.downcast_ref::<T>()
    }
}

impl<T> Extension for T
where
    T: Send + Sync + 'static,
{
    #[inline]
    fn into_inner(self: Box<Self>) -> Box<dyn Any> {
        self as Box<dyn Any>
    }
}
427
428/// Strongly connected components (SCCs) in topological order.
429///
430/// [`TopoSccs`] uses Tarjan's single-pass algorithm to find all SCCs,
431/// and provides topological ordering, efficient membership testing, and
432/// condensation for computing the transitive closure. These are
433/// building blocks for cycle detection and dependency propagation.
434struct TopoSccs<'a, N, E> {
435    graph: &'a DiGraph<N, E, usize>,
436    tarjan: TarjanScc<NodeIndex<usize>>,
437    sccs: Vec<FixedBitSet>,
438}
439
440impl<'a, N, E> TopoSccs<'a, N, E> {
441    fn new(graph: &'a DiGraph<N, E, usize>) -> Self {
442        let mut sccs = Vec::new();
443        let mut tarjan = TarjanScc::new();
444        tarjan.run(graph, |scc_nodes| {
445            sccs.push(scc_nodes.iter().map(|node| node.index()).collect());
446        });
447        // Tarjan's algorithm returns SCCs in reverse topological order;
448        // reverse them to get the topological order.
449        sccs.reverse();
450        Self {
451            graph,
452            tarjan,
453            sccs,
454        }
455    }
456
457    /// Returns the topological index of the SCC that contains the given node.
458    #[inline]
459    fn topo_index(&self, node: NodeIndex<usize>) -> usize {
460        // Tarjan's algorithm returns indices in reverse topological order;
461        // inverting the component index gets us the topological index.
462        self.sccs.len() - 1 - self.tarjan.node_component_index(self.graph, node)
463    }
464
465    /// Returns the members of the SCC at the given topological index.
466    #[inline]
467    fn members(&self, index: usize) -> &FixedBitSet {
468        &self.sccs[index]
469    }
470
471    /// Iterates over the SCCs in topological order.
472    #[inline]
473    fn iter(&self) -> std::slice::Iter<'_, FixedBitSet> {
474        self.sccs.iter()
475    }
476
477    /// Builds a condensed DAG of SCCs.
478    ///
479    /// The condensed graph is represented as an adjacency list, where both
480    /// the node indices and the neighbors of each node are stored in
481    /// topological order. This specific ordering is required by
482    /// [`tred::dag_transitive_reduction_closure`].
483    fn condensation(&self) -> UnweightedList<usize> {
484        let scc_count = self.sccs.len();
485        let mut dag = UnweightedList::with_capacity(scc_count);
486        for to in 0..scc_count {
487            dag.add_node();
488            for index in self.sccs[to].ones().map(NodeIndex::new) {
489                for neighbor in self.graph.neighbors_directed(index, Direction::Incoming) {
490                    let from = self.topo_index(neighbor);
491                    if from != to {
492                        dag.update_edge(from, to, ());
493                    }
494                }
495            }
496        }
497        dag
498    }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::assert_matches;

    /// Creates a simple graph: `A -> B -> C`.
    fn linear_graph() -> DiGraph<(), (), usize> {
        let mut g = DiGraph::default();
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.extend_with_edges([(a, b), (b, c)]);
        g
    }

    /// Creates a cyclic graph: `A -> B -> C -> A`, with `D -> A`.
    fn cyclic_graph() -> DiGraph<(), (), usize> {
        let mut g = DiGraph::default();
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        let d = g.add_node(());
        g.extend_with_edges([(a, b), (b, c), (c, a), (d, a)]);
        g
    }

    // MARK: SCC detection

    #[test]
    fn test_empty_graph_has_no_sccs() {
        let g = DiGraph::<(), (), usize>::default();
        let sccs = TopoSccs::new(&g);
        assert_eq!(sccs.iter().len(), 0);
        assert_eq!(sccs.condensation().node_count(), 0);
    }

    #[test]
    fn test_linear_graph_has_singleton_sccs() {
        let g = linear_graph();
        let sccs = TopoSccs::new(&g);
        let sizes = sccs.iter().map(|scc| scc.count_ones(..)).collect_vec();
        assert_matches!(&*sizes, [1, 1, 1]);
    }

    #[test]
    fn test_cyclic_graph_has_one_multi_node_scc() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);

        // A-B-C form one SCC; D is its own SCC. Since D has an edge to
        // the cycle, D must precede the cycle in topological order.
        let sizes = sccs.iter().map(|scc| scc.count_ones(..)).collect_vec();
        assert_matches!(&*sizes, [1, 3]);
    }

    #[test]
    fn test_self_loop_is_singleton_scc() {
        // `A -> A`, `A -> B`: the self-loop doesn't merge A with anything.
        let mut g = DiGraph::<(), (), usize>::default();
        let a = g.add_node(());
        let b = g.add_node(());
        g.extend_with_edges([(a, a), (a, b)]);

        let sccs = TopoSccs::new(&g);
        let sizes = sccs.iter().map(|scc| scc.count_ones(..)).collect_vec();
        assert_matches!(&*sizes, [1, 1]);
        assert!(sccs.topo_index(a) < sccs.topo_index(b));
    }

    // MARK: Topological ordering

    #[test]
    fn test_sccs_are_in_topological_order() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);

        let d_topo = sccs.topo_index(3.into());
        let a_topo = sccs.topo_index(0.into());
        assert!(
            d_topo < a_topo,
            "D should precede A-B-C in topological order"
        );
    }

    #[test]
    fn test_topo_index_consistent_within_scc() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);

        // A, B, C are in the same SCC, so they should have
        // the same topological index.
        let a_topo = sccs.topo_index(0.into());
        let b_topo = sccs.topo_index(1.into());
        let c_topo = sccs.topo_index(2.into());

        assert_eq!(a_topo, b_topo);
        assert_eq!(b_topo, c_topo);
    }

    // MARK: Condensation

    #[test]
    fn test_condensation_has_correct_node_count() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        assert_eq!(dag.node_count(), 2);
    }

    #[test]
    fn test_condensation_has_correct_edges() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        // D should have an edge to the A-B-C SCC, and
        // A-B-C shouldn't create a self-loop.
        let d_topo = sccs.topo_index(3.into());
        let abc_topo = sccs.topo_index(0.into());

        let d_neighbors = dag.neighbors(d_topo).collect_vec();
        assert_eq!(&*d_neighbors, [abc_topo]);

        let abc_neighbors = dag.neighbors(abc_topo).collect_vec();
        assert!(abc_neighbors.is_empty());
    }

    #[test]
    fn test_condensation_drops_self_loops() {
        // `A -> A`, `A -> B`: the condensation must be a DAG, so A's
        // self-loop must not become an SCC-level self-edge.
        let mut g = DiGraph::<(), (), usize>::default();
        let a = g.add_node(());
        let b = g.add_node(());
        g.extend_with_edges([(a, a), (a, b)]);

        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        let a_topo = sccs.topo_index(a);
        let b_topo = sccs.topo_index(b);
        let a_neighbors = dag.neighbors(a_topo).collect_vec();
        assert_eq!(&*a_neighbors, [b_topo]);

        let b_neighbors = dag.neighbors(b_topo).collect_vec();
        assert!(b_neighbors.is_empty());
    }

    #[test]
    fn test_condensation_neighbors_in_topological_order() {
        // Matches Petgraph's `dag_to_toposorted_adjacency_list` example:
        // edges added as `(top, second), (top, first)`, but neighbors should be
        // `[first, second]` in topological order.
        let mut g = DiGraph::<(), (), usize>::default();
        let second = g.add_node(());
        let top = g.add_node(());
        let first = g.add_node(());
        g.extend_with_edges([(top, second), (top, first), (first, second)]);

        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        let top_topo = sccs.topo_index(top);
        assert_eq!(top_topo, 0);

        let first_topo = sccs.topo_index(first);
        assert_eq!(first_topo, 1);

        let second_topo = sccs.topo_index(second);
        assert_eq!(second_topo, 2);

        let neighbors = dag.neighbors(top_topo).collect_vec();
        assert_eq!(&*neighbors, [first_topo, second_topo]);
    }
}