ploidy_core/ir/
graph.rs

1use std::{
2    any::{Any, TypeId},
3    fmt::Debug,
4};
5
6use atomic_refcell::AtomicRefCell;
7use by_address::ByAddress;
8use itertools::Itertools;
9use petgraph::algo::tarjan_scc;
10use petgraph::graph::{DiGraph, NodeIndex};
11use petgraph::visit::{Bfs, VisitMap, Visitable};
12use rustc_hash::{FxHashMap, FxHashSet};
13
14use super::{
15    spec::IrSpec,
16    types::{
17        InlineIrType, IrOperation, IrType, IrTypeRef, IrUntaggedVariant, PrimitiveIrType,
18        SchemaIrType,
19    },
20    views::{operation::IrOperationView, schema::SchemaIrTypeView},
21};
22
/// The concrete graph type backing an [`IrGraph`]: a directed graph whose
/// node weights are [`IrGraphNode`]s and whose edges carry no weight.
pub type IrGraphG<'a> = DiGraph<IrGraphNode<'a>, ()>;
25
/// A graph of all the types in an [`IrSpec`], where each edge
/// is a reference from one type to another.
///
/// Built once by [`IrGraph::new`], which also precomputes the circular
/// reference edges and the per-node metadata stored here.
#[derive(Debug)]
pub struct IrGraph<'a> {
    /// The spec that this graph was built from.
    pub(super) spec: &'a IrSpec<'a>,
    /// The underlying directed graph; each edge points from a parent type
    /// to a type that it references.
    pub(super) g: IrGraphG<'a>,
    /// An inverted index of nodes to graph indices.
    pub(super) indices: FxHashMap<IrGraphNode<'a>, NodeIndex>,
    /// Edges that are part of a cycle, stored as `(from, to)` index pairs.
    pub(super) circular_refs: FxHashSet<(NodeIndex, NodeIndex)>,
    /// Additional metadata for each node.
    pub(super) metadata: FxHashMap<NodeIndex, IrGraphNodeMeta<'a>>,
}
39
40impl<'a> IrGraph<'a> {
41    pub fn new(spec: &'a IrSpec<'a>) -> Self {
42        let mut g = DiGraph::new();
43        let mut indices = FxHashMap::default();
44
45        // All roots (named schemas, parameters, request and response bodies),
46        // and all the types within them (inline schemas, wrappers, primitives).
47        let tys = IrTypeVisitor::new(
48            spec.schemas
49                .values()
50                .chain(spec.operations.iter().flat_map(|op| op.types())),
51        );
52
53        // Add nodes for all types, and edges for references between them.
54        for (parent, child) in tys {
55            use std::collections::hash_map::Entry;
56            let &mut to = match indices.entry(IrGraphNode::from_ref(spec, child.as_ref())) {
57                // We might see the same schema multiple times, if it's
58                // referenced multiple times in the spec. Only add a new node
59                // for the schema if we haven't seen it before.
60                Entry::Occupied(entry) => entry.into_mut(),
61                Entry::Vacant(entry) => {
62                    let index = g.add_node(*entry.key());
63                    entry.insert(index)
64                }
65            };
66            if let Some(parent) = parent {
67                let &mut from = match indices.entry(IrGraphNode::from_ref(spec, parent.as_ref())) {
68                    Entry::Occupied(entry) => entry.into_mut(),
69                    Entry::Vacant(entry) => {
70                        let index = g.add_node(*entry.key());
71                        entry.insert(index)
72                    }
73                };
74                // Add a directed edge from parent to child.
75                g.add_edge(from, to, ());
76            }
77        }
78
79        // Precompute all circular reference edges, where each edge forms a cycle
80        // that requires indirection to break. This speeds up `needs_indirection_to()`:
81        // Tarjan's algorithm runs in O(V + E) time over the entire graph; a naive DFS
82        // in `needs_indirection_to()` would run in O(N * (V + E)) time, where N is
83        // the total number of fields in all structs.
84        let circular_refs = {
85            let mut edges = FxHashSet::default();
86            for scc in tarjan_scc(&g) {
87                let scc = FxHashSet::from_iter(scc);
88                for &node in &scc {
89                    edges.extend(
90                        g.neighbors(node)
91                            .filter(|neighbor| scc.contains(neighbor))
92                            .map(|neighbor| (node, neighbor)),
93                    );
94                }
95            }
96            edges
97        };
98
99        // Create empty metadata slots for all types.
100        let mut metadata = g
101            .node_indices()
102            .map(|index| (index, IrGraphNodeMeta::default()))
103            .collect::<FxHashMap<_, _>>();
104
105        // Precompute a mapping of types to all the operations that use them.
106        // This speeds up `used_by()`: precomputing runs in O(P * (V + E)) time,
107        // where P is the number of operations; a BFS in `used_by()` would
108        // run in O(C * P * (V + E)) time, where C is the number of calls to
109        // `used_by()`.
110        for op in spec.operations.iter() {
111            let stack = op
112                .types()
113                .map(|ty| IrGraphNode::from_ref(spec, ty.as_ref()))
114                .map(|node| indices[&node])
115                .collect();
116            let mut discovered = g.visit_map();
117            for &index in &stack {
118                discovered.visit(index);
119            }
120            let mut bfs = Bfs { stack, discovered };
121            while let Some(index) = bfs.next(&g) {
122                let meta = metadata.get_mut(&index).unwrap();
123                meta.operations.insert(ByAddress(op));
124            }
125        }
126
127        Self {
128            spec,
129            indices,
130            g,
131            circular_refs,
132            metadata,
133        }
134    }
135
136    /// Returns the spec used to build this graph.
137    #[inline]
138    pub fn spec(&self) -> &'a IrSpec<'a> {
139        self.spec
140    }
141
142    /// Returns an iterator over all the named schemas in this graph.
143    #[inline]
144    pub fn schemas(&self) -> impl Iterator<Item = SchemaIrTypeView<'_>> {
145        self.g
146            .node_indices()
147            .filter_map(|index| match self.g[index] {
148                IrGraphNode::Schema(ty) => Some(SchemaIrTypeView::new(self, index, ty)),
149                _ => None,
150            })
151    }
152
153    /// Returns an iterator over all the operations in this graph.
154    #[inline]
155    pub fn operations(&self) -> impl Iterator<Item = IrOperationView<'_>> {
156        self.spec
157            .operations
158            .iter()
159            .map(move |op| IrOperationView::new(self, op))
160    }
161}
162
/// A node in the type graph.
///
/// The derived [`Hash`][std::hash::Hash] and [`Eq`] implementations
/// work on the underlying values, so structurally identical types
/// will be equal. This is important: all types in an [`IrSpec`] are
/// distinct in memory, but can refer to the same logical type.
///
/// There is no variant for references: [`IrGraphNode::from_ref`] resolves
/// them to the node of the type that they refer to.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IrGraphNode<'a> {
    /// A named schema type.
    Schema(&'a SchemaIrType<'a>),
    /// An inline schema type.
    Inline(&'a InlineIrType<'a>),
    /// An array wrapper; the payload is the wrapped type.
    Array(&'a IrType<'a>),
    /// A map wrapper; the payload is the wrapped type.
    Map(&'a IrType<'a>),
    /// A nullable wrapper; the payload is the wrapped type.
    Nullable(&'a IrType<'a>),
    /// A primitive type.
    Primitive(PrimitiveIrType),
    /// The `Any` type, which carries no further type information here.
    Any,
}
179
180impl<'a> IrGraphNode<'a> {
181    /// Converts an [`IrTypeRef`] to an [`IrGraphNode`],
182    /// recursively resolving referenced schemas.
183    pub fn from_ref(spec: &'a IrSpec<'a>, ty: IrTypeRef<'a>) -> Self {
184        match ty {
185            IrTypeRef::Schema(ty) => IrGraphNode::Schema(ty),
186            IrTypeRef::Inline(ty) => IrGraphNode::Inline(ty),
187            IrTypeRef::Array(ty) => IrGraphNode::Array(ty),
188            IrTypeRef::Map(ty) => IrGraphNode::Map(ty),
189            IrTypeRef::Nullable(ty) => IrGraphNode::Nullable(ty),
190            IrTypeRef::Ref(r) => Self::from_ref(spec, spec.schemas[r.name()].as_ref()),
191            IrTypeRef::Primitive(ty) => IrGraphNode::Primitive(ty),
192            IrTypeRef::Any => IrGraphNode::Any,
193        }
194    }
195}
196
/// Metadata that an [`IrGraph`] keeps for each of its nodes.
///
/// A default-constructed value has no operations and no extensions.
#[derive(Default)]
pub(super) struct IrGraphNodeMeta<'a> {
    /// The set of operations that transitively use this type.
    /// Operations are wrapped in [`ByAddress`], so set membership is
    /// decided by the address of the operation, not by its contents.
    pub operations: FxHashSet<ByAddress<&'a IrOperation<'a>>>,
    /// Opaque extended data for this type, behind an [`AtomicRefCell`]
    /// so that it can be changed through a shared reference.
    pub extensions: AtomicRefCell<ExtensionMap>,
}
204
205impl Debug for IrGraphNodeMeta<'_> {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        f.debug_struct("IrGraphNodeMeta")
208            .field("operations", &self.operations)
209            .finish_non_exhaustive()
210    }
211}
212
/// Visits all the types and references contained within a type.
///
/// Iterating yields `(parent, type)` pairs, where `parent` is the type
/// that directly contains `type`, or `None` if `type` is a root.
#[derive(Debug)]
struct IrTypeVisitor<'a> {
    /// The `(parent, type)` pairs that are still waiting to be visited.
    /// The next pair to visit is popped off the end.
    stack: Vec<(Option<&'a IrType<'a>>, &'a IrType<'a>)>,
}
218
219impl<'a> IrTypeVisitor<'a> {
220    /// Creates a visitor with `root` on the stack of types to visit.
221    #[inline]
222    fn new(roots: impl Iterator<Item = &'a IrType<'a>>) -> Self {
223        let mut stack = roots.map(|root| (None, root)).collect_vec();
224        stack.reverse();
225        Self { stack }
226    }
227}
228
229impl<'a> Iterator for IrTypeVisitor<'a> {
230    type Item = (Option<&'a IrType<'a>>, &'a IrType<'a>);
231
232    fn next(&mut self) -> Option<Self::Item> {
233        let (parent, top) = self.stack.pop()?;
234        match top {
235            IrType::Array(ty) => {
236                self.stack.push((Some(top), ty.as_ref()));
237            }
238            IrType::Map(ty) => {
239                self.stack.push((Some(top), ty.as_ref()));
240            }
241            IrType::Nullable(ty) => {
242                self.stack.push((Some(top), ty.as_ref()));
243            }
244            IrType::Schema(SchemaIrType::Struct(_, ty)) => {
245                self.stack
246                    .extend(ty.fields.iter().map(|field| (Some(top), &field.ty)).rev());
247            }
248            IrType::Schema(SchemaIrType::Untagged(_, ty)) => {
249                self.stack.extend(
250                    ty.variants
251                        .iter()
252                        .filter_map(|variant| match variant {
253                            IrUntaggedVariant::Some(_, ty) => Some((Some(top), ty)),
254                            _ => None,
255                        })
256                        .rev(),
257                );
258            }
259            IrType::Schema(SchemaIrType::Tagged(_, ty)) => {
260                self.stack.extend(
261                    ty.variants
262                        .iter()
263                        .map(|variant| (Some(top), &variant.ty))
264                        .rev(),
265                );
266            }
267            IrType::Schema(SchemaIrType::Enum(..)) => (),
268            IrType::Any => (),
269            IrType::Primitive(_) => (),
270            IrType::Inline(ty) => match ty {
271                InlineIrType::Enum(..) => (),
272                InlineIrType::Tagged(_, ty) => {
273                    self.stack.extend(
274                        ty.variants
275                            .iter()
276                            .map(|variant| (Some(top), &variant.ty))
277                            .rev(),
278                    );
279                }
280                InlineIrType::Untagged(_, ty) => {
281                    self.stack.extend(
282                        ty.variants
283                            .iter()
284                            .filter_map(|variant| match variant {
285                                IrUntaggedVariant::Some(_, ty) => Some((Some(top), ty)),
286                                _ => None,
287                            })
288                            .rev(),
289                    );
290                }
291                InlineIrType::Struct(_, ty) => {
292                    self.stack
293                        .extend(ty.fields.iter().map(|field| (Some(top), &field.ty)).rev());
294                }
295            },
296            IrType::Ref(_) => (),
297        }
298        Some((parent, top))
299    }
300}
301
/// A map that can store one value for each type.
///
/// Values are boxed [`Extension`] trait objects, keyed by a [`TypeId`];
/// presumably the `TypeId` of the stored value's own type, though the
/// code that inserts into the map isn't in this file.
pub(super) type ExtensionMap = FxHashMap<TypeId, Box<dyn Extension>>;
304
/// A value that can be stored as a `Box<dyn Extension>` in an
/// [`ExtensionMap`].
///
/// A blanket implementation in this file covers every type that is
/// `Send + Sync + 'static`, so this trait doesn't need to be
/// implemented by hand.
pub trait Extension: Any + Send + Sync {
    /// Converts this boxed extension into a `Box<dyn Any>`, which the
    /// caller can then downcast to the concrete type.
    fn into_inner(self: Box<Self>) -> Box<dyn Any>;
}
308
309impl dyn Extension {
310    #[inline]
311    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
312        (self as &dyn Any).downcast_ref::<T>()
313    }
314}
315
316impl<T: Send + Sync + 'static> Extension for T {
317    #[inline]
318    fn into_inner(self: Box<Self>) -> Box<dyn Any> {
319        self
320    }
321}