cosmian_wit_bindgen_gen_core/
lib.rs

1use anyhow::Result;
2use std::collections::{btree_map::Entry, BTreeMap, HashMap, HashSet};
3use std::ops::Deref;
4use std::path::Path;
5use cosmian_wit_parser::abi::Abi;
6use cosmian_wit_parser::*;
7
8pub use cosmian_wit_parser;
9mod ns;
10
11pub use ns::Ns;
12
/// This is the direction from the user's perspective. Are we importing
/// functions to call, or defining functions and exporting them to be called?
///
/// This is only used outside of `Generator` implementations. Inside of
/// `Generator` implementations, the `Direction` is translated to an
/// `AbiVariant` instead. The ABI variant is usually the same as the
/// `Direction`, but it's different in the case of the Wasmtime host bindings:
///
/// In a wasm-calling-wasm use case, one wasm module would use the `Import`
/// ABI, the other would use the `Export` ABI, and there would be an adapter
/// layer between the two that translates from one ABI to the other.
///
/// But with wasm-calling-host, we don't go through a separate adapter layer;
/// the binding code we generate on the host side just does everything itself.
/// So when the host is conceptually "exporting" a function to wasm, it uses
/// the `Import` ABI so that wasm can also use the `Import` ABI and import it
/// directly from the host.
///
/// These are all implementation details; from the user perspective, and
/// from the perspective of everything outside of `Generator` implementations,
/// `export` means I'm exporting functions to be called, and `import` means I'm
/// importing functions that I'm going to call, in both wasm modules and host
/// code. The enum here represents this user perspective.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Direction {
    /// The user imports functions in order to call them.
    Import,
    /// The user defines functions and exports them to be called.
    Export,
}
41
42pub trait Generator {
43    fn preprocess_all(&mut self, imports: &[Interface], exports: &[Interface]) {
44        drop((imports, exports));
45    }
46
47    fn preprocess_one(&mut self, iface: &Interface, dir: Direction) {
48        drop((iface, dir));
49    }
50
51    fn type_record(
52        &mut self,
53        iface: &Interface,
54        id: TypeId,
55        name: &str,
56        record: &Record,
57        docs: &Docs,
58    );
59    fn type_variant(
60        &mut self,
61        iface: &Interface,
62        id: TypeId,
63        name: &str,
64        variant: &Variant,
65        docs: &Docs,
66    );
67    fn type_resource(&mut self, iface: &Interface, ty: ResourceId);
68    fn type_alias(&mut self, iface: &Interface, id: TypeId, name: &str, ty: &Type, docs: &Docs);
69    fn type_list(&mut self, iface: &Interface, id: TypeId, name: &str, ty: &Type, docs: &Docs);
70    fn type_pointer(
71        &mut self,
72        iface: &Interface,
73        id: TypeId,
74        name: &str,
75        const_: bool,
76        ty: &Type,
77        docs: &Docs,
78    );
79    fn type_builtin(&mut self, iface: &Interface, id: TypeId, name: &str, ty: &Type, docs: &Docs);
80    fn type_push_buffer(
81        &mut self,
82        iface: &Interface,
83        id: TypeId,
84        name: &str,
85        ty: &Type,
86        docs: &Docs,
87    );
88    fn type_pull_buffer(
89        &mut self,
90        iface: &Interface,
91        id: TypeId,
92        name: &str,
93        ty: &Type,
94        docs: &Docs,
95    );
96    // fn const_(&mut self, iface: &Interface, name: &str, ty: &str, val: u64, docs: &Docs);
97    fn import(&mut self, iface: &Interface, func: &Function);
98    fn export(&mut self, iface: &Interface, func: &Function);
99
100    fn finish_one(&mut self, iface: &Interface, files: &mut Files);
101
102    fn finish_all(&mut self, files: &mut Files) {
103        drop(files);
104    }
105
106    fn generate_one(&mut self, iface: &Interface, dir: Direction, files: &mut Files) {
107        self.preprocess_one(iface, dir);
108
109        for (id, ty) in iface.types.iter() {
110            // assert!(ty.foreign_module.is_none()); // TODO
111            let name = match &ty.name {
112                Some(name) => name,
113                None => continue,
114            };
115            match &ty.kind {
116                TypeDefKind::Record(record) => self.type_record(iface, id, name, record, &ty.docs),
117                TypeDefKind::Variant(variant) => {
118                    self.type_variant(iface, id, name, variant, &ty.docs)
119                }
120                TypeDefKind::List(t) => self.type_list(iface, id, name, t, &ty.docs),
121                TypeDefKind::PushBuffer(t) => self.type_push_buffer(iface, id, name, t, &ty.docs),
122                TypeDefKind::PullBuffer(t) => self.type_pull_buffer(iface, id, name, t, &ty.docs),
123                TypeDefKind::Type(t) => self.type_alias(iface, id, name, t, &ty.docs),
124                TypeDefKind::Pointer(t) => self.type_pointer(iface, id, name, false, t, &ty.docs),
125                TypeDefKind::ConstPointer(t) => {
126                    self.type_pointer(iface, id, name, true, t, &ty.docs)
127                }
128            }
129        }
130
131        for (id, _resource) in iface.resources.iter() {
132            self.type_resource(iface, id);
133        }
134
135        // for c in module.constants() {
136        //     self.const_(&c.name, &c.ty, c.value, &c.docs);
137        // }
138
139        for f in iface.functions.iter() {
140            match dir {
141                Direction::Import => self.import(iface, &f),
142                Direction::Export => self.export(iface, &f),
143            }
144        }
145
146        self.finish_one(iface, files)
147    }
148
149    fn generate_all(&mut self, imports: &[Interface], exports: &[Interface], files: &mut Files) {
150        self.preprocess_all(imports, exports);
151
152        for imp in imports {
153            self.generate_one(imp, Direction::Import, files);
154        }
155
156        for exp in exports {
157            self.generate_one(exp, Direction::Export, files);
158        }
159
160        self.finish_all(files);
161    }
162}
163
164#[derive(Default)]
165pub struct Types {
166    type_info: HashMap<TypeId, TypeInfo>,
167    handle_dtors: HashSet<ResourceId>,
168    dtor_funcs: HashSet<String>,
169}
170
/// Summary of how a type is used and what it (transitively) contains.
///
/// Infos of component types are merged with `|=`.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub struct TypeInfo {
    /// Whether or not this type is ever used (transitively) within the
    /// parameter of a function.
    pub param: bool,

    /// Whether or not this type is ever used (transitively) within the
    /// result of a function.
    pub result: bool,

    /// Whether or not this type (transitively) has a list.
    pub has_list: bool,

    /// Whether or not this type (transitively) has a handle.
    pub has_handle: bool,

    /// Whether or not this type (transitively) has a push buffer.
    pub has_push_buffer: bool,

    /// Whether or not this type (transitively) has a pull buffer.
    pub has_pull_buffer: bool,
}

impl std::ops::BitOrAssign for TypeInfo {
    /// ORs every flag of `rhs` into `self`.
    fn bitor_assign(&mut self, rhs: Self) {
        // Exhaustive destructuring: adding a field to `TypeInfo` without
        // merging it here becomes a compile error instead of a silent bug.
        let TypeInfo {
            param,
            result,
            has_list,
            has_handle,
            has_push_buffer,
            has_pull_buffer,
        } = rhs;
        self.param |= param;
        self.result |= result;
        self.has_list |= has_list;
        self.has_handle |= has_handle;
        self.has_push_buffer |= has_push_buffer;
        self.has_pull_buffer |= has_pull_buffer;
    }
}
204
205impl Types {
206    pub fn analyze(&mut self, iface: &Interface) {
207        for (t, _) in iface.types.iter() {
208            self.type_id_info(iface, t);
209        }
210        for f in iface.functions.iter() {
211            for (_, ty) in f.params.iter() {
212                self.set_param_result_ty(iface, ty, true, false);
213            }
214            for (_, ty) in f.results.iter() {
215                self.set_param_result_ty(iface, ty, false, true);
216            }
217            self.maybe_set_preview1_dtor(iface, f);
218        }
219    }
220
221    fn maybe_set_preview1_dtor(&mut self, iface: &Interface, f: &Function) {
222        match f.abi {
223            Abi::Preview1 => {}
224            _ => return,
225        }
226
227        // Dtors only happen when the function has a singular parameter
228        if f.params.len() != 1 {
229            return;
230        }
231
232        // Dtors are inferred to be `${type}_close` right now.
233        let name = f.name.as_str();
234        let prefix = match name.strip_suffix("_close") {
235            Some(prefix) => prefix,
236            None => return,
237        };
238
239        // The singular parameter type name must be the prefix of this
240        // function's own name.
241        let resource = match find_handle(iface, &f.params[0].1) {
242            Some(id) => id,
243            None => return,
244        };
245        if iface.resources[resource].name != prefix {
246            return;
247        }
248
249        self.handle_dtors.insert(resource);
250        self.dtor_funcs.insert(f.name.to_string());
251
252        fn find_handle(iface: &Interface, ty: &Type) -> Option<ResourceId> {
253            match ty {
254                Type::Handle(r) => Some(*r),
255                Type::Id(id) => match &iface.types[*id].kind {
256                    TypeDefKind::Type(t) => find_handle(iface, t),
257                    _ => None,
258                },
259                _ => None,
260            }
261        }
262    }
263
264    pub fn get(&self, id: TypeId) -> TypeInfo {
265        self.type_info[&id]
266    }
267
268    pub fn has_preview1_dtor(&self, resource: ResourceId) -> bool {
269        self.handle_dtors.contains(&resource)
270    }
271
272    pub fn is_preview1_dtor_func(&self, func: &Function) -> bool {
273        self.dtor_funcs.contains(&func.name)
274    }
275
276    pub fn type_id_info(&mut self, iface: &Interface, ty: TypeId) -> TypeInfo {
277        if let Some(info) = self.type_info.get(&ty) {
278            return *info;
279        }
280        let mut info = TypeInfo::default();
281        match &iface.types[ty].kind {
282            TypeDefKind::Record(r) => {
283                for field in r.fields.iter() {
284                    info |= self.type_info(iface, &field.ty);
285                }
286            }
287            TypeDefKind::Variant(v) => {
288                for case in v.cases.iter() {
289                    if let Some(ty) = &case.ty {
290                        info |= self.type_info(iface, ty);
291                    }
292                }
293            }
294            TypeDefKind::List(ty) => {
295                info = self.type_info(iface, ty);
296                info.has_list = true;
297            }
298            TypeDefKind::PushBuffer(ty) => {
299                info = self.type_info(iface, ty);
300                info.has_push_buffer = true;
301            }
302            TypeDefKind::PullBuffer(ty) => {
303                info = self.type_info(iface, ty);
304                info.has_pull_buffer = true;
305            }
306            TypeDefKind::ConstPointer(ty) | TypeDefKind::Pointer(ty) | TypeDefKind::Type(ty) => {
307                info = self.type_info(iface, ty)
308            }
309        }
310        self.type_info.insert(ty, info);
311        return info;
312    }
313
314    pub fn type_info(&mut self, iface: &Interface, ty: &Type) -> TypeInfo {
315        let mut info = TypeInfo::default();
316        match ty {
317            Type::Handle(_) => info.has_handle = true,
318            Type::Id(id) => return self.type_id_info(iface, *id),
319            _ => {}
320        }
321        info
322    }
323
324    fn set_param_result_id(&mut self, iface: &Interface, ty: TypeId, param: bool, result: bool) {
325        match &iface.types[ty].kind {
326            TypeDefKind::Record(r) => {
327                for field in r.fields.iter() {
328                    self.set_param_result_ty(iface, &field.ty, param, result)
329                }
330            }
331            TypeDefKind::Variant(v) => {
332                for case in v.cases.iter() {
333                    if let Some(ty) = &case.ty {
334                        self.set_param_result_ty(iface, ty, param, result)
335                    }
336                }
337            }
338            TypeDefKind::List(ty)
339            | TypeDefKind::PushBuffer(ty)
340            | TypeDefKind::PullBuffer(ty)
341            | TypeDefKind::Pointer(ty)
342            | TypeDefKind::ConstPointer(ty) => self.set_param_result_ty(iface, ty, param, result),
343            TypeDefKind::Type(ty) => self.set_param_result_ty(iface, ty, param, result),
344        }
345    }
346
347    fn set_param_result_ty(&mut self, iface: &Interface, ty: &Type, param: bool, result: bool) {
348        match ty {
349            Type::Id(id) => {
350                self.type_id_info(iface, *id);
351                let info = self.type_info.get_mut(id).unwrap();
352                if (param && !info.param) || (result && !info.result) {
353                    info.param = info.param || param;
354                    info.result = info.result || result;
355                    self.set_param_result_id(iface, *id, param, result);
356                }
357            }
358            _ => {}
359        }
360    }
361}
362
363#[derive(Default)]
364pub struct Files {
365    files: BTreeMap<String, Vec<u8>>,
366}
367
368impl Files {
369    pub fn push(&mut self, name: &str, contents: &[u8]) {
370        match self.files.entry(name.to_owned()) {
371            Entry::Vacant(entry) => {
372                entry.insert(contents.to_owned());
373            }
374            Entry::Occupied(ref mut entry) => {
375                entry.get_mut().extend_from_slice(contents);
376            }
377        }
378    }
379
380    pub fn iter(&self) -> impl Iterator<Item = (&'_ str, &'_ [u8])> {
381        self.files.iter().map(|p| (p.0.as_str(), p.1.as_slice()))
382    }
383}
384
385pub fn load(path: impl AsRef<Path>) -> Result<Interface> {
386    Interface::parse_file(path)
387}
388
/// A string buffer for emitting generated source code with automatic,
/// brace-driven indentation (two spaces per level).
#[derive(Default)]
pub struct Source {
    s: String,
    indent: usize,
}

impl Source {
    /// Appends `src`, re-indenting it line by line.
    ///
    /// Indentation rules:
    /// - Multi-line `src`: each line's leading whitespace is dropped in
    ///   favour of the current indentation.
    /// - Single-line `src`: appended without trimming.
    /// - A line ending in `{` opens a level.
    /// - A line starting with `}` closes a level. It is also pulled back two
    ///   spaces when the buffer currently ends with indentation.
    ///
    /// Newlines are emitted between the lines of `src`, plus one at the end
    /// if `src` itself ends with `'\n'`.
    ///
    /// Unbalanced closing braces never drive the indentation below zero.
    pub fn push_str(&mut self, src: &str) {
        let lines = src.lines().collect::<Vec<_>>();
        for (i, &line) in lines.iter().enumerate() {
            let trimmed = line.trim();
            if trimmed.starts_with('}') && self.s.ends_with("  ") {
                self.s.pop();
                self.s.pop();
            }
            self.s.push_str(if lines.len() == 1 {
                line
            } else {
                line.trim_start()
            });
            if trimmed.ends_with('{') {
                self.indent += 1;
            }
            if trimmed.starts_with('}') {
                // Saturate: a stray `}` would otherwise underflow. This can
                // happen, for example, after text written through
                // `as_mut_string`, which bypasses tracking. Underflow panics
                // in debug builds and wraps to `usize::MAX` in release.
                self.indent = self.indent.saturating_sub(1);
            }
            if i != lines.len() - 1 || src.ends_with('\n') {
                self.newline();
            }
        }
    }

    /// Increases the indentation by `amt` levels.
    pub fn indent(&mut self, amt: usize) {
        self.indent += amt;
    }

    /// Decreases the indentation by `amt` levels, stopping at zero.
    pub fn deindent(&mut self, amt: usize) {
        self.indent = self.indent.saturating_sub(amt);
    }

    /// Starts a new line indented to the current level.
    fn newline(&mut self) {
        self.s.push('\n');
        for _ in 0..self.indent {
            self.s.push_str("  ");
        }
    }

    /// Gives direct access to the underlying buffer. Text written this way
    /// does not update the indentation level.
    pub fn as_mut_string(&mut self) -> &mut String {
        &mut self.s
    }
}

impl Deref for Source {
    type Target = str;
    fn deref(&self) -> &str {
        &self.s
    }
}

impl From<Source> for String {
    fn from(s: Source) -> String {
        s.s
    }
}
453
#[cfg(test)]
mod tests {
    use super::{Generator, Source};

    /// Pushes every chunk, in order, into a fresh `Source` and returns the
    /// accumulated text.
    fn render(chunks: &[&str]) -> String {
        let mut out = Source::default();
        for chunk in chunks {
            out.push_str(chunk);
        }
        out.s
    }

    #[test]
    fn simple_append() {
        // Single-line pushes are appended as-is; check the buffer after each.
        let steps: [(&str, &str); 5] = [
            ("x", "x"),
            ("y", "xy"),
            ("z ", "xyz "),
            (" a ", "xyz  a "),
            ("\na", "xyz  a \na"),
        ];
        let mut src = Source::default();
        for &(chunk, expected) in steps.iter() {
            src.push_str(chunk);
            assert_eq!(src.s, expected);
        }
    }

    #[test]
    fn newline_remap() {
        let text = render(&["function() {\n", "y\n", "}\n"]);
        assert_eq!(text, "function() {\n  y\n}\n");
    }

    #[test]
    fn if_else() {
        let text = render(&["if() {\n", "y\n", "} else if () {\n", "z\n", "}\n"]);
        assert_eq!(text, "if() {\n  y\n} else if () {\n  z\n}\n");
    }

    #[test]
    fn trim_ws() {
        // Multi-line input has its own leading whitespace replaced.
        let text = render(&["function() {\n                x\n        }"]);
        assert_eq!(text, "function() {\n  x\n}");
    }

    #[test]
    fn generator_is_object_safe() {
        // Compiles only if `Generator` can be used as a trait object.
        fn _assert(_: &dyn Generator) {}
    }
}