Skip to main content

duchess_reflect/
upcasts.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use crate::{
4    class_info::{ClassInfo, ClassRef, DotId, Id},
5    substitution::{Substitute, Substitution},
6};
7
8/// A map storing the transitive upcasts for each class that we are generating (and potentially additional classes).
9///
10/// There is one caveat: we only compute the transitive superclasses based on the classes that are input to
11/// the proc macro. The problem is that we can only inspect the tokens presented to us. While we could reflect
12/// on the Java classes directly, we don't know what subset of the supertypes the user has chosen to reflect into
13/// Rust. Therefore, we stop our transitive upcasts at the "water's edge" -- i.e., at the point where we
14/// encounter classes that are outside our package.
15#[derive(Default, Debug)]
16pub struct Upcasts {
17    map: BTreeMap<DotId, ClassUpcasts>,
18}
19
20#[derive(Debug)]
21pub struct ClassUpcasts {
22    generics: Vec<Id>,
23    extends: BTreeSet<ClassRef>,
24}
25
26impl<'a> FromIterator<&'a ClassInfo> for Upcasts {
27    fn from_iter<T: IntoIterator<Item = &'a ClassInfo>>(iter: T) -> Self {
28        let mut upcasts = Upcasts::default();
29
30        for class_info in iter {
31            upcasts.insert_direct_upcasts(class_info);
32        }
33
34        upcasts.insert_hardcoded_upcasts();
35
36        upcasts.compute_transitive_upcasts();
37
38        upcasts
39    }
40}
41
42impl Upcasts {
43    /// Returns the transitive superclasses / interfaces of `name`.
44    /// These will reference generic parameters from in the class declaration of `name`.
45    /// This does NOT include the "reflexive" upcast from `T` to `T`.
46    pub fn upcasts_for_generated_class(&self, name: &DotId) -> &BTreeSet<ClassRef> {
47        &self.map[name].extends
48    }
49
50    /// Insert the direct (declared by user) superclasses of `class` into the map.
51    fn insert_direct_upcasts(&mut self, class: &ClassInfo) {
52        let mut upcasts = ClassUpcasts {
53            generics: class.generics.iter().map(|g| g.id.clone()).collect(),
54            extends: BTreeSet::default(),
55        };
56
57        // Include direct upcasts declared by the user.
58        for c in class.extends.iter().chain(&class.implements) {
59            upcasts.extends.insert(c.clone());
60        }
61
62        // Everything can be upcast to object.
63        let object = DotId::object();
64        if class.name != object {
65            upcasts.extends.insert(ClassRef {
66                name: object,
67                generics: vec![],
68            });
69        }
70
71        let old_value = self.map.insert(class.name.clone(), upcasts);
72        assert!(old_value.is_none());
73    }
74
75    fn insert_hardcoded_upcasts(&mut self) {
76        let mut insert = |c: DotId, d: DotId| {
77            self.map
78                .entry(c)
79                .or_insert_with(|| ClassUpcasts {
80                    generics: vec![],
81                    extends: BTreeSet::default(),
82                })
83                .extends
84                .insert(ClassRef {
85                    name: d,
86                    generics: vec![],
87                });
88        };
89
90        insert(DotId::runtime_exception(), DotId::exception());
91        insert(DotId::exception(), DotId::throwable());
92        insert(DotId::throwable(), DotId::object());
93    }
94
95    /// Extend the map with transitive upcasts for each of its entries. i.e., if class `A` extends `B`,
96    /// and `B` extends `C`, then `A` extends `C`.
97    fn compute_transitive_upcasts(&mut self) {
98        let class_names: Vec<DotId> = self.map.keys().cloned().collect();
99        loop {
100            let mut changed = false;
101
102            for n in &class_names {
103                // Extend by one step: for each class `c` extended by `n`,
104                // find superclasses of `c`.
105                let indirect_upcasts: Vec<ClassRef> = self.map[n]
106                    .extends
107                    .iter()
108                    .flat_map(|c| self.upcasts(c))
109                    .collect();
110
111                // Insert those into the set of superclasses for `n`.
112                // If the set changed size, then we added a new entry,
113                // so we have to iterate again.
114                let c_u = self.map.get_mut(n).unwrap();
115                let len_before = c_u.extends.len();
116                c_u.extends.extend(indirect_upcasts);
117                changed |= c_u.extends.len() != len_before;
118            }
119
120            if !changed {
121                break;
122            }
123        }
124    }
125
126    /// Find the upcasts for `class_ref`: look up the current map entry for
127    /// the given class and substitute the given values for its generic parameters.
128    fn upcasts(&self, class_ref: &ClassRef) -> Vec<ClassRef> {
129        let Some(c_u) = self.map.get(&class_ref.name) else {
130            // Upcasts to classes outside our translation unit:
131            // no visibility, just return empty vector.
132            return vec![];
133        };
134
135        assert_eq!(class_ref.generics.len(), c_u.generics.len());
136
137        let subst: Substitution<'_> = c_u.generics.iter().zip(&class_ref.generics).collect();
138
139        c_u.extends.iter().map(|c| c.substitute(&subst)).collect()
140    }
141}