wit_parser/
live.rs

1use crate::{
2    Function, FunctionKind, InterfaceId, Resolve, Type, TypeDef, TypeDefKind, TypeId, WorldId,
3    WorldItem,
4};
5use indexmap::IndexSet;
6
/// Collection of `TypeId`s gathered by walking WIT items.
///
/// Ids are added through the `add_*` methods, which walk an item using the
/// `TypeIdVisitor` implementation for this type and record each type id that
/// the walk reaches. A default-constructed value holds no ids.
#[derive(Default)]
pub struct LiveTypes {
    /// Type ids recorded so far; each id is stored at most once.
    set: IndexSet<TypeId>,
}
11
12impl LiveTypes {
13    pub fn iter(&self) -> impl Iterator<Item = TypeId> + '_ {
14        self.set.iter().copied()
15    }
16
17    pub fn len(&self) -> usize {
18        self.set.len()
19    }
20
21    pub fn add_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
22        self.visit_interface(resolve, iface);
23    }
24
25    pub fn add_world(&mut self, resolve: &Resolve, world: WorldId) {
26        self.visit_world(resolve, world);
27    }
28
29    pub fn add_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
30        self.visit_world_item(resolve, item);
31    }
32
33    pub fn add_func(&mut self, resolve: &Resolve, func: &Function) {
34        self.visit_func(resolve, func);
35    }
36
37    pub fn add_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
38        self.visit_type_id(resolve, ty);
39    }
40
41    pub fn add_type(&mut self, resolve: &Resolve, ty: &Type) {
42        self.visit_type(resolve, ty);
43    }
44}
45
46impl TypeIdVisitor for LiveTypes {
47    fn before_visit_type_id(&mut self, id: TypeId) -> bool {
48        !self.set.contains(&id)
49    }
50
51    fn after_visit_type_id(&mut self, id: TypeId) {
52        assert!(self.set.insert(id));
53    }
54}
55
56/// Helper trait to walk the structure of a type and visit all `TypeId`s that
57/// it refers to, possibly transitively.
58pub trait TypeIdVisitor {
59    /// Callback invoked just before a type is visited.
60    ///
61    /// If this function returns `false` the type is not visited, otherwise it's
62    /// recursed into.
63    fn before_visit_type_id(&mut self, id: TypeId) -> bool {
64        let _ = id;
65        true
66    }
67
68    /// Callback invoked once a type is finished being visited.
69    fn after_visit_type_id(&mut self, id: TypeId) {
70        let _ = id;
71    }
72
73    fn visit_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
74        let iface = &resolve.interfaces[iface];
75        for (_, id) in iface.types.iter() {
76            self.visit_type_id(resolve, *id);
77        }
78        for (_, func) in iface.functions.iter() {
79            self.visit_func(resolve, func);
80        }
81    }
82
83    fn visit_world(&mut self, resolve: &Resolve, world: WorldId) {
84        let world = &resolve.worlds[world];
85        for (_, item) in world.imports.iter().chain(world.exports.iter()) {
86            self.visit_world_item(resolve, item);
87        }
88    }
89
90    fn visit_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
91        match item {
92            WorldItem::Interface { id, .. } => self.visit_interface(resolve, *id),
93            WorldItem::Function(f) => self.visit_func(resolve, f),
94            WorldItem::Type(t) => self.visit_type_id(resolve, *t),
95        }
96    }
97
98    fn visit_func(&mut self, resolve: &Resolve, func: &Function) {
99        match func.kind {
100            // This resource is live as it's attached to a static method but
101            // it's not guaranteed to be present in either params or results, so
102            // be sure to attach it here.
103            FunctionKind::Static(id) | FunctionKind::AsyncStatic(id) => {
104                self.visit_type_id(resolve, id)
105            }
106
107            // The resource these are attached to is in the params/results, so
108            // no need to re-add it here.
109            FunctionKind::Method(_)
110            | FunctionKind::AsyncMethod(_)
111            | FunctionKind::Constructor(_) => {}
112
113            FunctionKind::Freestanding | FunctionKind::AsyncFreestanding => {}
114        }
115
116        for (_, ty) in func.params.iter() {
117            self.visit_type(resolve, ty);
118        }
119        if let Some(ty) = &func.result {
120            self.visit_type(resolve, ty);
121        }
122    }
123
124    fn visit_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
125        if self.before_visit_type_id(ty) {
126            self.visit_type_def(resolve, &resolve.types[ty]);
127            self.after_visit_type_id(ty);
128        }
129    }
130
131    fn visit_type_def(&mut self, resolve: &Resolve, ty: &TypeDef) {
132        match &ty.kind {
133            TypeDefKind::Type(t)
134            | TypeDefKind::List(t)
135            | TypeDefKind::FixedSizeList(t, ..)
136            | TypeDefKind::Option(t)
137            | TypeDefKind::Future(Some(t))
138            | TypeDefKind::Stream(Some(t)) => self.visit_type(resolve, t),
139            TypeDefKind::Handle(handle) => match handle {
140                crate::Handle::Own(ty) => self.visit_type_id(resolve, *ty),
141                crate::Handle::Borrow(ty) => self.visit_type_id(resolve, *ty),
142            },
143            TypeDefKind::Resource => {}
144            TypeDefKind::Record(r) => {
145                for field in r.fields.iter() {
146                    self.visit_type(resolve, &field.ty);
147                }
148            }
149            TypeDefKind::Tuple(r) => {
150                for ty in r.types.iter() {
151                    self.visit_type(resolve, ty);
152                }
153            }
154            TypeDefKind::Variant(v) => {
155                for case in v.cases.iter() {
156                    if let Some(ty) = &case.ty {
157                        self.visit_type(resolve, ty);
158                    }
159                }
160            }
161            TypeDefKind::Result(r) => {
162                if let Some(ty) = &r.ok {
163                    self.visit_type(resolve, ty);
164                }
165                if let Some(ty) = &r.err {
166                    self.visit_type(resolve, ty);
167                }
168            }
169            TypeDefKind::Flags(_)
170            | TypeDefKind::Enum(_)
171            | TypeDefKind::Future(None)
172            | TypeDefKind::Stream(None) => {}
173            TypeDefKind::Unknown => unreachable!(),
174        }
175    }
176
177    fn visit_type(&mut self, resolve: &Resolve, ty: &Type) {
178        match ty {
179            Type::Id(id) => self.visit_type_id(resolve, *id),
180            _ => {}
181        }
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::{LiveTypes, Resolve};

    /// Parses `wit`, looks up the type named `ty` in the last interface
    /// yielded by `resolve.interfaces`, and returns the names of all types
    /// collected from it, in the order `LiveTypes` yields them. Types without
    /// a name are skipped.
    fn live(wit: &str, ty: &str) -> Vec<String> {
        let mut resolve = Resolve::default();
        resolve.push_str("test.wit", wit).unwrap();

        let root = {
            let (_, last_interface) = resolve.interfaces.iter().next_back().unwrap();
            last_interface.types[ty]
        };

        let mut collected = LiveTypes::default();
        collected.add_type_id(&resolve, root);

        let mut names = Vec::new();
        for id in collected.iter() {
            if let Some(name) = &resolve.types[id].name {
                names.push(name.clone());
            }
        }
        names
    }

    /// A type aliasing a primitive yields only itself.
    #[test]
    fn no_deps() {
        let names = live(
            "
                package foo:bar;

                interface foo {
                    type t = u32;
                }
            ",
            "t",
        );
        assert_eq!(names, ["t"]);
    }

    /// An alias of another named type yields the target first, then itself.
    #[test]
    fn one_dep() {
        let names = live(
            "
                package foo:bar;

                interface foo {
                    type t = u32;
                    type u = t;
                }
            ",
            "u",
        );
        assert_eq!(names, ["t", "u"]);
    }

    /// A tuple pulls in its elements in order, each preceded by the types it
    /// refers to.
    #[test]
    fn chain() {
        let names = live(
            "
                package foo:bar;

                interface foo {
                    resource t1;
                    record t2 {
                        x: t1,
                    }
                    variant t3 {
                        x(t2),
                    }
                    flags t4 { a }
                    enum t5 { a }
                    type t6 = tuple<t5, t4, t3>;
                }
            ",
            "t6",
        );
        assert_eq!(names, ["t5", "t4", "t1", "t2", "t3", "t6"]);
    }
}