1use std::{
2 any::{Any, TypeId},
3 fmt::Debug,
4};
5
6use atomic_refcell::AtomicRefCell;
7use by_address::ByAddress;
8use fixedbitset::FixedBitSet;
9use itertools::Itertools;
10use petgraph::{
11 Direction,
12 adj::UnweightedList,
13 algo::{TarjanScc, tred},
14 data::Build,
15 graph::{DiGraph, NodeIndex},
16 visit::{IntoNeighbors, NodeCount},
17};
18use rustc_hash::{FxHashMap, FxHashSet};
19
20use super::{
21 spec::IrSpec,
22 types::{
23 InlineIrType, IrOperation, IrType, IrTypeRef, IrUntaggedVariant, PrimitiveIrType,
24 SchemaIrType,
25 },
26 views::{operation::IrOperationView, schema::SchemaIrTypeView, wrappers::IrPrimitiveView},
27};
28
/// The underlying petgraph graph: one node per [`IrGraphNode`], unweighted
/// edges, and `usize` indices.
pub(super) type IrGraphG<'a> = DiGraph<IrGraphNode<'a>, (), usize>;

/// A directed graph of every type reachable from an [`IrSpec`], together with
/// precomputed per-node and per-operation metadata.
#[derive(Debug)]
pub struct IrGraph<'a> {
    /// The spec this graph was built from.
    pub(super) spec: &'a IrSpec<'a>,
    /// The type graph. An edge `a -> b` means type `a` directly contains type
    /// `b` (a `Ref` child is resolved to the schema it names).
    pub(super) g: IrGraphG<'a>,
    /// Maps each node value back to its index in `g`.
    pub(super) indices: FxHashMap<IrGraphNode<'a>, NodeIndex<usize>>,
    /// Edges `(from, to)` whose endpoints lie in the same strongly connected
    /// component, i.e. edges that participate in a cycle (including self-loops).
    pub(super) circular_refs: FxHashSet<(NodeIndex<usize>, NodeIndex<usize>)>,
    /// Precomputed dependency, dependent, and operation-usage information.
    pub(super) metadata: IrGraphMetadata<'a>,
}
45
46impl<'a> IrGraph<'a> {
47 pub fn new(spec: &'a IrSpec<'a>) -> Self {
48 let mut g = IrGraphG::default();
49 let mut indices = FxHashMap::default();
50
51 let tys = IrTypeVisitor::new(
54 spec.schemas
55 .values()
56 .chain(spec.operations.iter().flat_map(|op| op.types())),
57 );
58
59 for (parent, child) in tys {
61 use std::collections::hash_map::Entry;
62 let &mut to = match indices.entry(IrGraphNode::from_ref(spec, child.as_ref())) {
63 Entry::Occupied(entry) => entry.into_mut(),
67 Entry::Vacant(entry) => {
68 let index = g.add_node(*entry.key());
69 entry.insert(index)
70 }
71 };
72 if let Some(parent) = parent {
73 let &mut from = match indices.entry(IrGraphNode::from_ref(spec, parent.as_ref())) {
74 Entry::Occupied(entry) => entry.into_mut(),
75 Entry::Vacant(entry) => {
76 let index = g.add_node(*entry.key());
77 entry.insert(index)
78 }
79 };
80 g.add_edge(from, to, ());
82 }
83 }
84
85 let sccs = TopoSccs::new(&g);
86
87 let circular_refs = {
90 let mut edges = FxHashSet::default();
91 for members in sccs.iter() {
92 for node in members.ones().map(NodeIndex::new) {
93 edges.extend(
94 g.neighbors(node)
95 .filter(|neighbor| members.contains(neighbor.index()))
96 .map(|neighbor| (node, neighbor)),
97 );
98 }
99 }
100 edges
101 };
102
103 let metadata = {
104 let mut metadata = IrGraphMetadata::default();
105
106 for op in &spec.operations {
109 metadata.operations.entry(ByAddress(op)).or_default().types = op
110 .types()
111 .filter_map(|ty| {
112 indices
113 .get(&IrGraphNode::from_ref(spec, ty.as_ref()))
114 .map(|node| node.index())
115 })
116 .collect();
117 }
118
119 {
122 let condensation = sccs.condensation();
125
126 let (_, closure) = tred::dag_transitive_reduction_closure(&condensation);
128
129 let mut deps_by_scc =
132 vec![FixedBitSet::with_capacity(g.node_count()); condensation.node_count()];
133 for scc_index in condensation.node_indices() {
134 for dep_scc_index in closure.neighbors(scc_index) {
135 deps_by_scc[scc_index].union_with(sccs.members(dep_scc_index));
136 }
137 deps_by_scc[scc_index].union_with(sccs.members(scc_index));
140 }
141
142 for node in g.node_indices() {
143 let topo_index = sccs.topo_index(node);
144 let mut deps = deps_by_scc[topo_index].clone();
145
146 deps.remove(node.index());
148
149 metadata
150 .schemas
151 .entry(node)
152 .or_default()
153 .dependencies
154 .union_with(&deps);
155
156 for index in deps.into_ones().map(NodeIndex::new) {
159 metadata
160 .schemas
161 .entry(index)
162 .or_default()
163 .dependents
164 .grow_and_insert(node.index());
165 }
166 }
167 }
168
169 for op in &spec.operations {
172 let meta = &metadata.operations[&ByAddress(op)];
173
174 let mut transitive_deps = FixedBitSet::with_capacity(g.node_count());
176 for node in meta.types.ones().map(NodeIndex::new) {
177 transitive_deps.insert(node.index());
178 if let Some(meta) = metadata.schemas.get(&node) {
179 transitive_deps.union_with(&meta.dependencies);
180 }
181 }
182
183 for index in transitive_deps.ones().map(NodeIndex::new) {
185 metadata
186 .schemas
187 .entry(index)
188 .or_default()
189 .used_by
190 .insert(ByAddress(op));
191 }
192 }
193
194 metadata
195 };
196
197 Self {
198 spec,
199 indices,
200 g,
201 circular_refs,
202 metadata,
203 }
204 }
205
206 #[inline]
208 pub fn spec(&self) -> &'a IrSpec<'a> {
209 self.spec
210 }
211
212 #[inline]
214 pub fn schemas(&self) -> impl Iterator<Item = SchemaIrTypeView<'_>> {
215 self.g
216 .node_indices()
217 .filter_map(|index| match self.g[index] {
218 IrGraphNode::Schema(ty) => Some(SchemaIrTypeView::new(self, index, ty)),
219 _ => None,
220 })
221 }
222
223 #[inline]
226 pub fn primitives(&self) -> impl Iterator<Item = IrPrimitiveView<'_>> {
227 self.g
228 .node_indices()
229 .filter_map(|index| match self.g[index] {
230 IrGraphNode::Primitive(ty) => Some(IrPrimitiveView::new(self, index, ty)),
231 _ => None,
232 })
233 }
234
235 #[inline]
237 pub fn operations(&self) -> impl Iterator<Item = IrOperationView<'_>> {
238 self.spec
239 .operations
240 .iter()
241 .map(move |op| IrOperationView::new(self, op))
242 }
243}
244
/// A node in the type graph.
///
/// Each variant mirrors the corresponding `IrTypeRef` variant. There is no
/// `Ref` variant: [`IrGraphNode::from_ref`] resolves references to the schema
/// they name, so a reference and its target share one node.
///
/// NOTE: the derived `Hash`/`Eq`/`PartialEq` delegate to the *referenced
/// values'* impls, not their addresses. Two distinct references to equal
/// values therefore map to the same node.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IrGraphNode<'a> {
    /// A schema type.
    Schema(&'a SchemaIrType<'a>),
    /// An inline type.
    Inline(&'a InlineIrType<'a>),
    /// An array type.
    Array(&'a IrType<'a>),
    /// A map type.
    Map(&'a IrType<'a>),
    /// An optional type.
    Optional(&'a IrType<'a>),
    /// A primitive type.
    Primitive(PrimitiveIrType),
    /// The "any" type.
    Any,
}
261
262impl<'a> IrGraphNode<'a> {
263 pub fn from_ref(spec: &'a IrSpec<'a>, ty: IrTypeRef<'a>) -> Self {
266 match ty {
267 IrTypeRef::Schema(ty) => IrGraphNode::Schema(ty),
268 IrTypeRef::Inline(ty) => IrGraphNode::Inline(ty),
269 IrTypeRef::Array(ty) => IrGraphNode::Array(ty),
270 IrTypeRef::Map(ty) => IrGraphNode::Map(ty),
271 IrTypeRef::Optional(ty) => IrGraphNode::Optional(ty),
272 IrTypeRef::Ref(r) => Self::from_ref(spec, spec.schemas[r.name()].as_ref()),
273 IrTypeRef::Primitive(ty) => IrGraphNode::Primitive(ty),
274 IrTypeRef::Any => IrGraphNode::Any,
275 }
276 }
277}
278
/// Precomputed metadata for an [`IrGraph`].
#[derive(Debug, Default)]
pub struct IrGraphMetadata<'a> {
    /// Per-node metadata, keyed by graph node index. Despite the name,
    /// `IrGraph::new` creates an entry for every node, not only schemas.
    pub schemas: FxHashMap<NodeIndex<usize>, IrGraphNodeMeta<'a>>,
    /// Per-operation metadata, keyed by the operation's address.
    pub operations: FxHashMap<ByAddress<&'a IrOperation<'a>>, IrGraphOperationMeta>,
}
285
/// Precomputed metadata for a single operation.
#[derive(Debug, Default)]
pub struct IrGraphOperationMeta {
    /// Node indices of the operation's own types, as returned by
    /// `IrOperation::types`.
    pub types: FixedBitSet,
}
294
/// Precomputed metadata for a single graph node.
#[derive(Default)]
pub(super) struct IrGraphNodeMeta<'a> {
    /// Operations whose types reach this node, directly or through their
    /// dependencies.
    pub used_by: FxHashSet<ByAddress<&'a IrOperation<'a>>>,
    /// Indices of every node reachable from this one, excluding this node.
    pub dependencies: FixedBitSet,
    /// Indices of every node that can reach this one, excluding this node.
    pub dependents: FixedBitSet,
    /// Extension data attached to this node, keyed by `TypeId`. Wrapped in an
    /// `AtomicRefCell` so it can be borrowed mutably through a shared reference.
    pub extensions: AtomicRefCell<ExtensionMap>,
}
307
308impl Debug for IrGraphNodeMeta<'_> {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 f.debug_struct("IrGraphNodeMeta")
311 .field("used_by", &self.used_by)
312 .field("dependencies", &self.dependencies)
313 .field("dependents", &self.dependents)
314 .finish_non_exhaustive()
315 }
316}
317
/// A depth-first, pre-order iterator over IR types and the types they contain.
///
/// Yields `(parent, type)` pairs. Roots have no parent. `Ref`s are yielded but
/// not followed.
#[derive(Debug)]
struct IrTypeVisitor<'a> {
    /// Pending `(parent, type)` pairs; the last element is visited next.
    stack: Vec<(Option<&'a IrType<'a>>, &'a IrType<'a>)>,
}
323
324impl<'a> IrTypeVisitor<'a> {
325 #[inline]
327 fn new(roots: impl Iterator<Item = &'a IrType<'a>>) -> Self {
328 let mut stack = roots.map(|root| (None, root)).collect_vec();
329 stack.reverse();
330 Self { stack }
331 }
332}
333
334impl<'a> Iterator for IrTypeVisitor<'a> {
335 type Item = (Option<&'a IrType<'a>>, &'a IrType<'a>);
336
337 fn next(&mut self) -> Option<Self::Item> {
338 let (parent, top) = self.stack.pop()?;
339 match top {
340 IrType::Array(ty) => {
341 self.stack.push((Some(top), ty.as_ref()));
342 }
343 IrType::Map(ty) => {
344 self.stack.push((Some(top), ty.as_ref()));
345 }
346 IrType::Optional(ty) => {
347 self.stack.push((Some(top), ty.as_ref()));
348 }
349 IrType::Schema(SchemaIrType::Struct(_, ty)) => {
350 self.stack
351 .extend(ty.fields.iter().map(|field| (Some(top), &field.ty)).rev());
352 }
353 IrType::Schema(SchemaIrType::Untagged(_, ty)) => {
354 self.stack.extend(
355 ty.variants
356 .iter()
357 .filter_map(|variant| match variant {
358 IrUntaggedVariant::Some(_, ty) => Some((Some(top), ty)),
359 _ => None,
360 })
361 .rev(),
362 );
363 }
364 IrType::Schema(SchemaIrType::Tagged(_, ty)) => {
365 self.stack.extend(
366 ty.variants
367 .iter()
368 .map(|variant| (Some(top), &variant.ty))
369 .rev(),
370 );
371 }
372 IrType::Schema(SchemaIrType::Enum(..)) => (),
373 IrType::Any => (),
374 IrType::Primitive(_) => (),
375 IrType::Inline(ty) => match ty {
376 InlineIrType::Enum(..) => (),
377 InlineIrType::Tagged(_, ty) => {
378 self.stack.extend(
379 ty.variants
380 .iter()
381 .map(|variant| (Some(top), &variant.ty))
382 .rev(),
383 );
384 }
385 InlineIrType::Untagged(_, ty) => {
386 self.stack.extend(
387 ty.variants
388 .iter()
389 .filter_map(|variant| match variant {
390 IrUntaggedVariant::Some(_, ty) => Some((Some(top), ty)),
391 _ => None,
392 })
393 .rev(),
394 );
395 }
396 InlineIrType::Struct(_, ty) => {
397 self.stack
398 .extend(ty.fields.iter().map(|field| (Some(top), &field.ty)).rev());
399 }
400 },
401 IrType::Ref(_) => (),
402 }
403 Some((parent, top))
404 }
405}
406
/// Type-erased extension data attached to a graph node, keyed by `TypeId`.
/// NOTE(review): presumably the key is the stored value's own type; confirm at
/// insertion sites.
pub(super) type ExtensionMap = FxHashMap<TypeId, Box<dyn Extension>>;
409
/// Thread-safe, type-erased data that can be attached to a graph node.
///
/// The blanket impl below implements this for every `Send + Sync + 'static`
/// type.
pub trait Extension: Any + Send + Sync {
    /// Converts this boxed extension into a `Box<dyn Any>`, e.g. for
    /// `Box::downcast`.
    fn into_inner(self: Box<Self>) -> Box<dyn Any>;
}

impl dyn Extension {
    /// Returns a reference to the underlying value if it is a `T`, or `None`
    /// otherwise.
    #[inline]
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        // Upcast to `dyn Any` so the standard type-id check can do the work.
        let erased: &dyn Any = self;
        erased.downcast_ref()
    }
}

impl<T: Send + Sync + 'static> Extension for T {
    #[inline]
    fn into_inner(self: Box<Self>) -> Box<dyn Any> {
        // Unsizing coercion from `Box<T>` to `Box<dyn Any>`.
        self
    }
}
427
/// The strongly connected components (SCCs) of a graph, stored in topological
/// order.
///
/// Every edge between two different components goes from a lower component
/// index to a higher one.
struct TopoSccs<'a, N, E> {
    /// The graph the components were computed from.
    graph: &'a DiGraph<N, E, usize>,
    /// Tarjan state from the run, kept to map nodes to their component index.
    tarjan: TarjanScc<NodeIndex<usize>>,
    /// Member sets (bits are node indices), indexed by topological position.
    sccs: Vec<FixedBitSet>,
}
439
440impl<'a, N, E> TopoSccs<'a, N, E> {
441 fn new(graph: &'a DiGraph<N, E, usize>) -> Self {
442 let mut sccs = Vec::new();
443 let mut tarjan = TarjanScc::new();
444 tarjan.run(graph, |scc_nodes| {
445 sccs.push(scc_nodes.iter().map(|node| node.index()).collect());
446 });
447 sccs.reverse();
450 Self {
451 graph,
452 tarjan,
453 sccs,
454 }
455 }
456
457 #[inline]
459 fn topo_index(&self, node: NodeIndex<usize>) -> usize {
460 self.sccs.len() - 1 - self.tarjan.node_component_index(self.graph, node)
463 }
464
465 #[inline]
467 fn members(&self, index: usize) -> &FixedBitSet {
468 &self.sccs[index]
469 }
470
471 #[inline]
473 fn iter(&self) -> std::slice::Iter<'_, FixedBitSet> {
474 self.sccs.iter()
475 }
476
477 fn condensation(&self) -> UnweightedList<usize> {
484 let scc_count = self.sccs.len();
485 let mut dag = UnweightedList::with_capacity(scc_count);
486 for to in 0..scc_count {
487 dag.add_node();
488 for index in self.sccs[to].ones().map(NodeIndex::new) {
489 for neighbor in self.graph.neighbors_directed(index, Direction::Incoming) {
490 let from = self.topo_index(neighbor);
491 if from != to {
492 dag.update_edge(from, to, ());
493 }
494 }
495 }
496 }
497 dag
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::assert_matches;

    /// Builds `a -> b -> c`: three nodes and no cycles.
    fn linear_graph() -> DiGraph<(), (), usize> {
        let mut g = DiGraph::default();
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.extend_with_edges([(a, b), (b, c)]);
        g
    }

    /// Builds the cycle `a -> b -> c -> a` plus an extra edge `d -> a`.
    /// Node indices are `a = 0`, `b = 1`, `c = 2`, `d = 3`.
    fn cyclic_graph() -> DiGraph<(), (), usize> {
        let mut g = DiGraph::default();
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        let d = g.add_node(());
        g.extend_with_edges([(a, b), (b, c), (c, a), (d, a)]);
        g
    }

    #[test]
    // Without cycles, every node is its own component.
    fn test_linear_graph_has_singleton_sccs() {
        let g = linear_graph();
        let sccs = TopoSccs::new(&g);
        let sizes = sccs.iter().map(|scc| scc.count_ones(..)).collect_vec();
        assert_matches!(&*sizes, [1, 1, 1]);
    }

    #[test]
    fn test_cyclic_graph_has_one_multi_node_scc() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);

        // `d` comes first as a singleton, followed by the three-node cycle,
        // because `d` has an edge into the cycle.
        let sizes = sccs.iter().map(|scc| scc.count_ones(..)).collect_vec();
        assert_matches!(&*sizes, [1, 3]);
    }

    #[test]
    // A component with an edge into another must precede it.
    fn test_sccs_are_in_topological_order() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);

        let d_topo = sccs.topo_index(3.into());
        let a_topo = sccs.topo_index(0.into());
        assert!(
            d_topo < a_topo,
            "D should precede A-B-C in topological order"
        );
    }

    #[test]
    fn test_topo_index_consistent_within_scc() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);

        // `a`, `b`, and `c` form one cycle, so they share a component index.
        let a_topo = sccs.topo_index(0.into());
        let b_topo = sccs.topo_index(1.into());
        let c_topo = sccs.topo_index(2.into());

        assert_eq!(a_topo, b_topo);
        assert_eq!(b_topo, c_topo);
    }

    #[test]
    // The condensation has one node per component: `{d}` and `{a, b, c}`.
    fn test_condensation_has_correct_node_count() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        assert_eq!(dag.node_count(), 2);
    }

    #[test]
    fn test_condensation_has_correct_edges() {
        let g = cyclic_graph();
        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        // Any member of the cycle identifies its component; `a` (index 0) is used.
        let d_topo = sccs.topo_index(3.into());
        let abc_topo = sccs.topo_index(0.into());

        let d_neighbors = dag.neighbors(d_topo).collect_vec();
        assert_eq!(&*d_neighbors, [abc_topo]);

        // Edges inside the cycle must not become condensation edges.
        let abc_neighbors = dag.neighbors(abc_topo).collect_vec();
        assert!(abc_neighbors.is_empty());
    }

    #[test]
    fn test_condensation_neighbors_in_topological_order() {
        let mut g = DiGraph::<(), (), usize>::default();
        // Nodes are deliberately added out of topological order
        // (second = 0, top = 1, first = 2), so node indices differ from
        // topological positions.
        let second = g.add_node(());
        let top = g.add_node(());
        let first = g.add_node(());
        g.extend_with_edges([(top, second), (top, first), (first, second)]);

        let sccs = TopoSccs::new(&g);
        let dag = sccs.condensation();

        let top_topo = sccs.topo_index(top);
        assert_eq!(top_topo, 0);

        let first_topo = sccs.topo_index(first);
        assert_eq!(first_topo, 1);

        let second_topo = sccs.topo_index(second);
        assert_eq!(second_topo, 2);

        // `dag_transitive_reduction_closure` needs successor lists sorted
        // topologically.
        let neighbors = dag.neighbors(top_topo).collect_vec();
        assert_eq!(&*neighbors, [first_topo, second_topo]);
    }
}