// ploidy_core/ir/graph.rs

use std::{
    any::{Any, TypeId},
    fmt::Debug,
};

use atomic_refcell::AtomicRefCell;
use by_address::ByAddress;
use itertools::Itertools;
use petgraph::algo::tarjan_scc;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::{Bfs, VisitMap, Visitable};
use rustc_hash::{FxHashMap, FxHashSet};

use super::{
    spec::IrSpec,
    types::{
        InlineIrType, IrOperation, IrType, IrTypeRef, IrUntaggedVariant, PrimitiveIrType,
        SchemaIrType,
    },
    views::{operation::IrOperationView, schema::SchemaIrTypeView, wrappers::IrPrimitiveView},
};
22
23/// The type graph.
24pub type IrGraphG<'a> = DiGraph<IrGraphNode<'a>, ()>;
25
26/// A graph of all the types in an [`IrSpec`], where each edge
27/// is a reference from one type to another.
28#[derive(Debug)]
29pub struct IrGraph<'a> {
30    pub(super) spec: &'a IrSpec<'a>,
31    pub(super) g: IrGraphG<'a>,
32    /// An inverted index of nodes to graph indices.
33    pub(super) indices: FxHashMap<IrGraphNode<'a>, NodeIndex>,
34    /// Edges that are part of a cycle.
35    pub(super) circular_refs: FxHashSet<(NodeIndex, NodeIndex)>,
36    /// Additional metadata for each node.
37    pub(super) metadata: FxHashMap<NodeIndex, IrGraphNodeMeta<'a>>,
38}
39
40impl<'a> IrGraph<'a> {
41    pub fn new(spec: &'a IrSpec<'a>) -> Self {
42        let mut g = DiGraph::new();
43        let mut indices = FxHashMap::default();
44
45        // All roots (named schemas, parameters, request and response bodies),
46        // and all the types within them (inline schemas, wrappers, primitives).
47        let tys = IrTypeVisitor::new(
48            spec.schemas
49                .values()
50                .chain(spec.operations.iter().flat_map(|op| op.types())),
51        );
52
53        // Add nodes for all types, and edges for references between them.
54        for (parent, child) in tys {
55            use std::collections::hash_map::Entry;
56            let &mut to = match indices.entry(IrGraphNode::from_ref(spec, child.as_ref())) {
57                // We might see the same schema multiple times, if it's
58                // referenced multiple times in the spec. Only add a new node
59                // for the schema if we haven't seen it before.
60                Entry::Occupied(entry) => entry.into_mut(),
61                Entry::Vacant(entry) => {
62                    let index = g.add_node(*entry.key());
63                    entry.insert(index)
64                }
65            };
66            if let Some(parent) = parent {
67                let &mut from = match indices.entry(IrGraphNode::from_ref(spec, parent.as_ref())) {
68                    Entry::Occupied(entry) => entry.into_mut(),
69                    Entry::Vacant(entry) => {
70                        let index = g.add_node(*entry.key());
71                        entry.insert(index)
72                    }
73                };
74                // Add a directed edge from parent to child.
75                g.add_edge(from, to, ());
76            }
77        }
78
79        // Precompute all circular reference edges, where each edge forms a cycle
80        // that requires indirection to break. This speeds up `needs_indirection_to()`:
81        // Tarjan's algorithm runs in O(V + E) time over the entire graph; a naive DFS
82        // in `needs_indirection_to()` would run in O(N * (V + E)) time, where N is
83        // the total number of fields in all structs.
84        let circular_refs = {
85            let mut edges = FxHashSet::default();
86            for scc in tarjan_scc(&g) {
87                let scc = FxHashSet::from_iter(scc);
88                for &node in &scc {
89                    edges.extend(
90                        g.neighbors(node)
91                            .filter(|neighbor| scc.contains(neighbor))
92                            .map(|neighbor| (node, neighbor)),
93                    );
94                }
95            }
96            edges
97        };
98
99        // Create empty metadata slots for all types.
100        let mut metadata = g
101            .node_indices()
102            .map(|index| (index, IrGraphNodeMeta::default()))
103            .collect::<FxHashMap<_, _>>();
104
105        // Precompute a mapping of types to all the operations that use them.
106        // This speeds up `used_by()`: precomputing runs in O(P * (V + E)) time,
107        // where P is the number of operations; a BFS in `used_by()` would
108        // run in O(C * P * (V + E)) time, where C is the number of calls to
109        // `used_by()`.
110        for op in spec.operations.iter() {
111            let stack = op
112                .types()
113                .map(|ty| IrGraphNode::from_ref(spec, ty.as_ref()))
114                .map(|node| indices[&node])
115                .collect();
116            let mut discovered = g.visit_map();
117            for &index in &stack {
118                discovered.visit(index);
119            }
120            let mut bfs = Bfs { stack, discovered };
121            while let Some(index) = bfs.next(&g) {
122                let meta = metadata.get_mut(&index).unwrap();
123                meta.operations.insert(ByAddress(op));
124            }
125        }
126
127        Self {
128            spec,
129            indices,
130            g,
131            circular_refs,
132            metadata,
133        }
134    }
135
136    /// Returns the spec used to build this graph.
137    #[inline]
138    pub fn spec(&self) -> &'a IrSpec<'a> {
139        self.spec
140    }
141
142    /// Returns an iterator over all the named schemas in this graph.
143    #[inline]
144    pub fn schemas(&self) -> impl Iterator<Item = SchemaIrTypeView<'_>> {
145        self.g
146            .node_indices()
147            .filter_map(|index| match self.g[index] {
148                IrGraphNode::Schema(ty) => Some(SchemaIrTypeView::new(self, index, ty)),
149                _ => None,
150            })
151    }
152
153    /// Returns an iterator over all the primitive types in this graph. Note that
154    /// a graph contains at most one instance of each primitive type.
155    #[inline]
156    pub fn primitives(&self) -> impl Iterator<Item = IrPrimitiveView<'_>> {
157        self.g
158            .node_indices()
159            .filter_map(|index| match self.g[index] {
160                IrGraphNode::Primitive(ty) => Some(IrPrimitiveView::new(self, index, ty)),
161                _ => None,
162            })
163    }
164
165    /// Returns an iterator over all the operations in this graph.
166    #[inline]
167    pub fn operations(&self) -> impl Iterator<Item = IrOperationView<'_>> {
168        self.spec
169            .operations
170            .iter()
171            .map(move |op| IrOperationView::new(self, op))
172    }
173}
174
175/// A node in the type graph.
176///
177/// The derived [`Hash`][std::hash::Hash] and [`Eq`] implementations
178/// work on the underlying values, so structurally identical types
179/// will be equal. This is important: all types in an [`IrSpec`] are
180/// distinct in memory, but can refer to the same logical type.
181#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
182pub enum IrGraphNode<'a> {
183    Schema(&'a SchemaIrType<'a>),
184    Inline(&'a InlineIrType<'a>),
185    Array(&'a IrType<'a>),
186    Map(&'a IrType<'a>),
187    Optional(&'a IrType<'a>),
188    Primitive(PrimitiveIrType),
189    Any,
190}
191
192impl<'a> IrGraphNode<'a> {
193    /// Converts an [`IrTypeRef`] to an [`IrGraphNode`],
194    /// recursively resolving referenced schemas.
195    pub fn from_ref(spec: &'a IrSpec<'a>, ty: IrTypeRef<'a>) -> Self {
196        match ty {
197            IrTypeRef::Schema(ty) => IrGraphNode::Schema(ty),
198            IrTypeRef::Inline(ty) => IrGraphNode::Inline(ty),
199            IrTypeRef::Array(ty) => IrGraphNode::Array(ty),
200            IrTypeRef::Map(ty) => IrGraphNode::Map(ty),
201            IrTypeRef::Optional(ty) => IrGraphNode::Optional(ty),
202            IrTypeRef::Ref(r) => Self::from_ref(spec, spec.schemas[r.name()].as_ref()),
203            IrTypeRef::Primitive(ty) => IrGraphNode::Primitive(ty),
204            IrTypeRef::Any => IrGraphNode::Any,
205        }
206    }
207}
208
209#[derive(Default)]
210pub(super) struct IrGraphNodeMeta<'a> {
211    /// The set of operations that transitively use this type.
212    pub operations: FxHashSet<ByAddress<&'a IrOperation<'a>>>,
213    /// Opaque extended data for this type.
214    pub extensions: AtomicRefCell<ExtensionMap>,
215}
216
217impl Debug for IrGraphNodeMeta<'_> {
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        f.debug_struct("IrGraphNodeMeta")
220            .field("operations", &self.operations)
221            .finish_non_exhaustive()
222    }
223}
224
225/// Visits all the types and references contained within a type.
226#[derive(Debug)]
227struct IrTypeVisitor<'a> {
228    stack: Vec<(Option<&'a IrType<'a>>, &'a IrType<'a>)>,
229}
230
231impl<'a> IrTypeVisitor<'a> {
232    /// Creates a visitor with `root` on the stack of types to visit.
233    #[inline]
234    fn new(roots: impl Iterator<Item = &'a IrType<'a>>) -> Self {
235        let mut stack = roots.map(|root| (None, root)).collect_vec();
236        stack.reverse();
237        Self { stack }
238    }
239}
240
241impl<'a> Iterator for IrTypeVisitor<'a> {
242    type Item = (Option<&'a IrType<'a>>, &'a IrType<'a>);
243
244    fn next(&mut self) -> Option<Self::Item> {
245        let (parent, top) = self.stack.pop()?;
246        match top {
247            IrType::Array(ty) => {
248                self.stack.push((Some(top), ty.as_ref()));
249            }
250            IrType::Map(ty) => {
251                self.stack.push((Some(top), ty.as_ref()));
252            }
253            IrType::Optional(ty) => {
254                self.stack.push((Some(top), ty.as_ref()));
255            }
256            IrType::Schema(SchemaIrType::Struct(_, ty)) => {
257                self.stack
258                    .extend(ty.fields.iter().map(|field| (Some(top), &field.ty)).rev());
259            }
260            IrType::Schema(SchemaIrType::Untagged(_, ty)) => {
261                self.stack.extend(
262                    ty.variants
263                        .iter()
264                        .filter_map(|variant| match variant {
265                            IrUntaggedVariant::Some(_, ty) => Some((Some(top), ty)),
266                            _ => None,
267                        })
268                        .rev(),
269                );
270            }
271            IrType::Schema(SchemaIrType::Tagged(_, ty)) => {
272                self.stack.extend(
273                    ty.variants
274                        .iter()
275                        .map(|variant| (Some(top), &variant.ty))
276                        .rev(),
277                );
278            }
279            IrType::Schema(SchemaIrType::Enum(..)) => (),
280            IrType::Any => (),
281            IrType::Primitive(_) => (),
282            IrType::Inline(ty) => match ty {
283                InlineIrType::Enum(..) => (),
284                InlineIrType::Tagged(_, ty) => {
285                    self.stack.extend(
286                        ty.variants
287                            .iter()
288                            .map(|variant| (Some(top), &variant.ty))
289                            .rev(),
290                    );
291                }
292                InlineIrType::Untagged(_, ty) => {
293                    self.stack.extend(
294                        ty.variants
295                            .iter()
296                            .filter_map(|variant| match variant {
297                                IrUntaggedVariant::Some(_, ty) => Some((Some(top), ty)),
298                                _ => None,
299                            })
300                            .rev(),
301                    );
302                }
303                InlineIrType::Struct(_, ty) => {
304                    self.stack
305                        .extend(ty.fields.iter().map(|field| (Some(top), &field.ty)).rev());
306                }
307            },
308            IrType::Ref(_) => (),
309        }
310        Some((parent, top))
311    }
312}
313
314/// A map that can store one value for each type.
315pub(super) type ExtensionMap = FxHashMap<TypeId, Box<dyn Extension>>;
316
/// A value that can be stored as opaque extended data.
///
/// This trait is implemented for every `Send + Sync + 'static` type,
/// so it never needs to be implemented by hand.
pub trait Extension: Any + Send + Sync {
    /// Converts this boxed extension into a boxed [`Any`],
    /// which can then be downcast to its concrete type.
    fn into_inner(self: Box<Self>) -> Box<dyn Any>;
}

impl dyn Extension {
    /// Returns a reference to the underlying value if it's of type `T`,
    /// or `None` if it isn't.
    #[inline]
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        // `Any` is a supertrait of `Extension`, so the trait object
        // can be upcast.
        let any: &dyn Any = self;
        any.downcast_ref::<T>()
    }
}

impl<T: Send + Sync + 'static> Extension for T {
    #[inline]
    fn into_inner(self: Box<Self>) -> Box<dyn Any> {
        // Unsizes the `Box<T>` to a `Box<dyn Any>`.
        self
    }
}