cosmian_wit_bindgen_gen_rust/
lib.rs

1use cosmian_wit_bindgen_gen_core::cosmian_wit_parser::abi::{Bitcast, LiftLower, WasmType};
2use cosmian_wit_bindgen_gen_core::{cosmian_wit_parser::*, TypeInfo, Types};
3use heck::*;
4
/// How wit types are spelled when printed into generated Rust code.
///
/// The borrowed variants carry the lifetime text (for example `"'a"` or
/// `"'_"`) placed on references produced in that mode.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TypeMode {
    /// Nothing is borrowed: strings print as `String`, lists as `Vec<T>`,
    /// handles by value.
    Owned,
    /// Strings, lists, handles and buffers are all printed as borrows.
    AllBorrowed(&'static str),
    /// Strings are borrowed, lists are borrowed only when their element type
    /// has all bits valid (otherwise `Vec<T>`); handles and buffers are borrowed.
    LeafBorrowed(&'static str),
    /// Only handles and buffers are borrowed; strings and lists are owned.
    HandlesBorrowed(&'static str),
}
12
13pub trait RustGenerator {
14    fn push_str(&mut self, s: &str);
15    fn info(&self, ty: TypeId) -> TypeInfo;
16    fn types_mut(&mut self) -> &mut Types;
17    fn print_usize(&mut self);
18    fn print_pointer(&mut self, iface: &Interface, const_: bool, ty: &Type);
19    fn print_borrowed_slice(
20        &mut self,
21        iface: &Interface,
22        mutbl: bool,
23        ty: &Type,
24        lifetime: &'static str,
25    );
26    fn print_borrowed_str(&mut self, lifetime: &'static str);
27    fn print_lib_buffer(
28        &mut self,
29        iface: &Interface,
30        push: bool,
31        ty: &Type,
32        mode: TypeMode,
33        lt: &'static str,
34    );
35    fn default_param_mode(&self) -> TypeMode;
36    fn handle_projection(&self) -> Option<(&'static str, String)>;
37    fn handle_wrapper(&self) -> Option<&'static str>;
38    fn handle_in_super(&self) -> bool {
39        false
40    }
41
42    fn rustdoc(&mut self, docs: &Docs) {
43        let docs = match &docs.contents {
44            Some(docs) => docs,
45            None => return,
46        };
47        for line in docs.trim().lines() {
48            self.push_str("/// ");
49            self.push_str(line);
50            self.push_str("\n");
51        }
52    }
53
54    fn rustdoc_params(&mut self, docs: &[(String, Type)], header: &str) {
55        drop((docs, header));
56        // let docs = docs
57        //     .iter()
58        //     .filter(|param| param.docs.trim().len() > 0)
59        //     .collect::<Vec<_>>();
60        // if docs.len() == 0 {
61        //     return;
62        // }
63
64        // self.push_str("///\n");
65        // self.push_str("/// ## ");
66        // self.push_str(header);
67        // self.push_str("\n");
68        // self.push_str("///\n");
69
70        // for param in docs {
71        //     for (i, line) in param.docs.lines().enumerate() {
72        //         self.push_str("/// ");
73        //         // Currently wasi only has at most one return value, so there's no
74        //         // need to indent it or name it.
75        //         if header != "Return" {
76        //             if i == 0 {
77        //                 self.push_str("* `");
78        //                 self.push_str(to_rust_ident(param.name.as_str()));
79        //                 self.push_str("` - ");
80        //             } else {
81        //                 self.push_str("  ");
82        //             }
83        //         }
84        //         self.push_str(line);
85        //         self.push_str("\n");
86        //     }
87        // }
88    }
89
90    fn print_signature(
91        &mut self,
92        iface: &Interface,
93        func: &Function,
94        param_mode: TypeMode,
95        sig: &FnSig,
96    ) -> Vec<String> {
97        let params = self.print_docs_and_params(iface, func, param_mode, &sig);
98        if func.results.len() > 0 {
99            self.push_str(" -> ");
100            self.print_results(iface, func);
101        }
102        params
103    }
104
105    fn print_docs_and_params(
106        &mut self,
107        iface: &Interface,
108        func: &Function,
109        param_mode: TypeMode,
110        sig: &FnSig,
111    ) -> Vec<String> {
112        self.rustdoc(&func.docs);
113        self.rustdoc_params(&func.params, "Parameters");
114        self.rustdoc_params(&func.results, "Return");
115
116        if !sig.private {
117            self.push_str("pub ");
118        }
119        if sig.unsafe_ {
120            self.push_str("unsafe ");
121        }
122        if sig.async_ {
123            self.push_str("async ");
124        }
125        self.push_str("fn ");
126        let func_name = if sig.use_item_name {
127            func.item_name()
128        } else {
129            &func.name
130        };
131        self.push_str(&to_rust_ident(&func_name));
132        if let Some(generics) = &sig.generics {
133            self.push_str(generics);
134        }
135        self.push_str("(");
136        if let Some(arg) = &sig.self_arg {
137            self.push_str(arg);
138            self.push_str(",");
139        }
140        let mut params = Vec::new();
141        for (i, (name, param)) in func.params.iter().enumerate() {
142            if i == 0 && sig.self_is_first_param {
143                params.push("self".to_string());
144                continue;
145            }
146            let name = to_rust_ident(name);
147            self.push_str(&name);
148            params.push(name);
149            self.push_str(": ");
150            self.print_ty(iface, param, param_mode);
151            self.push_str(",");
152        }
153        self.push_str(")");
154        params
155    }
156
157    fn print_results(&mut self, iface: &Interface, func: &Function) {
158        match func.results.len() {
159            0 => self.push_str("()"),
160            1 => {
161                self.print_ty(iface, &func.results[0].1, TypeMode::Owned);
162            }
163            _ => {
164                self.push_str("(");
165                for (_, result) in func.results.iter() {
166                    self.print_ty(iface, result, TypeMode::Owned);
167                    self.push_str(", ");
168                }
169                self.push_str(")");
170            }
171        }
172    }
173
174    fn print_ty(&mut self, iface: &Interface, ty: &Type, mode: TypeMode) {
175        match ty {
176            Type::Id(t) => self.print_tyid(iface, *t, mode),
177            Type::Handle(r) => {
178                let mut info = TypeInfo::default();
179                info.has_handle = true;
180                let lt = self.lifetime_for(&info, mode);
181                // Borrowed handles are always behind a reference since
182                // in that case we never take ownership of the handle.
183                if let Some(lt) = lt {
184                    self.push_str("&");
185                    if lt != "'_" {
186                        self.push_str(lt);
187                    }
188                    self.push_str(" ");
189                }
190
191                let suffix = match self.handle_wrapper() {
192                    Some(wrapper) => {
193                        self.push_str(wrapper);
194                        self.push_str("<");
195                        ">"
196                    }
197                    None => "",
198                };
199                if self.handle_in_super() {
200                    self.push_str("super::");
201                }
202                if let Some((proj, _)) = self.handle_projection() {
203                    self.push_str(proj);
204                    self.push_str("::");
205                }
206                self.push_str(&iface.resources[*r].name.to_camel_case());
207                self.push_str(suffix);
208            }
209
210            Type::U8 => self.push_str("u8"),
211            Type::CChar => self.push_str("u8"),
212            Type::U16 => self.push_str("u16"),
213            Type::U32 => self.push_str("u32"),
214            Type::Usize => self.print_usize(),
215            Type::U64 => self.push_str("u64"),
216            Type::S8 => self.push_str("i8"),
217            Type::S16 => self.push_str("i16"),
218            Type::S32 => self.push_str("i32"),
219            Type::S64 => self.push_str("i64"),
220            Type::F32 => self.push_str("f32"),
221            Type::F64 => self.push_str("f64"),
222            Type::Char => self.push_str("char"),
223        }
224    }
225
226    fn print_tyid(&mut self, iface: &Interface, id: TypeId, mode: TypeMode) {
227        let info = self.info(id);
228        let lt = self.lifetime_for(&info, mode);
229        let ty = &iface.types[id];
230        if ty.name.is_some() {
231            let name = if lt.is_some() {
232                self.param_name(iface, id)
233            } else {
234                self.result_name(iface, id)
235            };
236            self.push_str(&name);
237
238            // If the type recursively owns data and it's a
239            // variant/record/list, then we need to place the
240            // lifetime parameter on the type as well.
241            if info.owns_data() && needs_generics(iface, &ty.kind) {
242                self.print_generics(&info, lt, false);
243            }
244
245            return;
246
247            fn needs_generics(iface: &Interface, ty: &TypeDefKind) -> bool {
248                match ty {
249                    TypeDefKind::Variant(_)
250                    | TypeDefKind::Record(_)
251                    | TypeDefKind::List(_)
252                    | TypeDefKind::PushBuffer(_)
253                    | TypeDefKind::PullBuffer(_) => true,
254                    TypeDefKind::Type(Type::Id(t)) => needs_generics(iface, &iface.types[*t].kind),
255                    TypeDefKind::Type(Type::Handle(_)) => true,
256                    _ => false,
257                }
258            }
259        }
260
261        match &ty.kind {
262            TypeDefKind::List(t) => self.print_list(iface, t, mode),
263
264            TypeDefKind::Pointer(t) => self.print_pointer(iface, false, t),
265            TypeDefKind::ConstPointer(t) => self.print_pointer(iface, true, t),
266
267            // Variants can be printed natively if they're `Option`,
268            // `Result` , or `bool`, otherwise they must be named for now.
269            TypeDefKind::Variant(v) if v.is_bool() => self.push_str("bool"),
270            TypeDefKind::Variant(v) => match v.as_expected() {
271                Some((ok, err)) => {
272                    self.push_str("Result<");
273                    match ok {
274                        Some(ty) => self.print_ty(iface, ty, mode),
275                        None => self.push_str("()"),
276                    }
277                    self.push_str(",");
278                    match err {
279                        Some(ty) => self.print_ty(iface, ty, mode),
280                        None => self.push_str("()"),
281                    }
282                    self.push_str(">");
283                }
284                None => match v.as_option() {
285                    Some(ty) => {
286                        self.push_str("Option<");
287                        self.print_ty(iface, ty, mode);
288                        self.push_str(">");
289                    }
290                    None => panic!("unsupported anonymous variant"),
291                },
292            },
293
294            // Tuple-like records are mapped directly to Rust tuples of
295            // types. Note the trailing comma after each member to
296            // appropriately handle 1-tuples.
297            TypeDefKind::Record(r) if r.is_tuple() => {
298                self.push_str("(");
299                for field in r.fields.iter() {
300                    self.print_ty(iface, &field.ty, mode);
301                    self.push_str(",");
302                }
303                self.push_str(")");
304            }
305            TypeDefKind::Record(_) => {
306                panic!("unsupported anonymous type reference: record")
307            }
308
309            TypeDefKind::PushBuffer(r) => self.print_buffer(iface, true, r, mode),
310            TypeDefKind::PullBuffer(r) => self.print_buffer(iface, false, r, mode),
311
312            TypeDefKind::Type(t) => self.print_ty(iface, t, mode),
313        }
314    }
315
316    fn print_list(&mut self, iface: &Interface, ty: &Type, mode: TypeMode) {
317        match ty {
318            Type::Char => match mode {
319                TypeMode::AllBorrowed(lt) | TypeMode::LeafBorrowed(lt) => {
320                    self.print_borrowed_str(lt)
321                }
322                TypeMode::Owned | TypeMode::HandlesBorrowed(_) => self.push_str("String"),
323            },
324            t => match mode {
325                TypeMode::AllBorrowed(lt) => {
326                    let mutbl = self.needs_mutable_slice(iface, ty);
327                    self.print_borrowed_slice(iface, mutbl, ty, lt);
328                }
329                TypeMode::LeafBorrowed(lt) => {
330                    if iface.all_bits_valid(t) {
331                        let mutbl = self.needs_mutable_slice(iface, ty);
332                        self.print_borrowed_slice(iface, mutbl, ty, lt);
333                    } else {
334                        self.push_str("Vec<");
335                        self.print_ty(iface, ty, mode);
336                        self.push_str(">");
337                    }
338                }
339                TypeMode::HandlesBorrowed(_) | TypeMode::Owned => {
340                    self.push_str("Vec<");
341                    self.print_ty(iface, ty, mode);
342                    self.push_str(">");
343                }
344            },
345        }
346    }
347
348    fn print_buffer(&mut self, iface: &Interface, push: bool, ty: &Type, mode: TypeMode) {
349        let lt = match mode {
350            TypeMode::AllBorrowed(s) | TypeMode::HandlesBorrowed(s) | TypeMode::LeafBorrowed(s) => {
351                s
352            }
353            TypeMode::Owned => unimplemented!(),
354        };
355        if iface.all_bits_valid(ty) {
356            self.print_borrowed_slice(iface, push, ty, lt)
357        } else {
358            self.print_lib_buffer(iface, push, ty, mode, lt)
359        }
360    }
361
362    fn print_rust_slice(
363        &mut self,
364        iface: &Interface,
365        mutbl: bool,
366        ty: &Type,
367        lifetime: &'static str,
368    ) {
369        self.push_str("&");
370        if lifetime != "'_" {
371            self.push_str(lifetime);
372            self.push_str(" ");
373        }
374        if mutbl {
375            self.push_str(" mut ");
376        }
377        self.push_str("[");
378        self.print_ty(iface, ty, TypeMode::AllBorrowed(lifetime));
379        self.push_str("]");
380    }
381
382    fn print_generics(&mut self, info: &TypeInfo, lifetime: Option<&str>, bound: bool) {
383        let proj = if info.has_handle {
384            self.handle_projection()
385        } else {
386            None
387        };
388        if lifetime.is_none() && proj.is_none() {
389            return;
390        }
391        self.push_str("<");
392        if let Some(lt) = lifetime {
393            self.push_str(lt);
394            self.push_str(",");
395        }
396        if let Some((proj, trait_bound)) = proj {
397            self.push_str(proj);
398            if bound {
399                self.push_str(": ");
400                self.push_str(&trait_bound);
401            }
402        }
403        self.push_str(">");
404    }
405
406    fn int_repr(&mut self, repr: Int) {
407        self.push_str(int_repr(repr));
408    }
409
410    fn wasm_type(&mut self, ty: WasmType) {
411        self.push_str(wasm_type(ty));
412    }
413
414    fn modes_of(&self, iface: &Interface, ty: TypeId) -> Vec<(String, TypeMode)> {
415        let info = self.info(ty);
416        let mut result = Vec::new();
417        if info.param {
418            result.push((self.param_name(iface, ty), self.default_param_mode()));
419        }
420        if info.result && (!info.param || self.uses_two_names(&info)) {
421            result.push((self.result_name(iface, ty), TypeMode::Owned));
422        }
423        return result;
424    }
425
426    fn print_typedef_record(
427        &mut self,
428        iface: &Interface,
429        id: TypeId,
430        record: &Record,
431        docs: &Docs,
432    ) {
433        let info = self.info(id);
434        for (name, mode) in self.modes_of(iface, id) {
435            let lt = self.lifetime_for(&info, mode);
436            self.rustdoc(docs);
437            if record.is_tuple() {
438                self.push_str(&format!("pub type {}", name));
439                self.print_generics(&info, lt, true);
440                self.push_str(" = (");
441                for field in record.fields.iter() {
442                    self.print_ty(iface, &field.ty, mode);
443                    self.push_str(",");
444                }
445                self.push_str(");\n");
446            } else {
447                if info.has_pull_buffer || info.has_push_buffer {
448                    // skip copy/clone ...
449                } else if !info.owns_data() {
450                    self.push_str("#[repr(C)]\n");
451                    self.push_str("#[derive(Copy, Clone)]\n");
452                } else if !info.has_handle {
453                    self.push_str("#[derive(Clone)]\n");
454                }
455                self.push_str(&format!("pub struct {}", name));
456                self.print_generics(&info, lt, true);
457                self.push_str(" {\n");
458                for field in record.fields.iter() {
459                    self.rustdoc(&field.docs);
460                    self.push_str("pub ");
461                    self.push_str(&to_rust_ident(&field.name));
462                    self.push_str(": ");
463                    self.print_ty(iface, &field.ty, mode);
464                    self.push_str(",\n");
465                }
466                self.push_str("}\n");
467
468                self.push_str("impl");
469                self.print_generics(&info, lt, true);
470                self.push_str(" std::fmt::Debug for ");
471                self.push_str(&name);
472                self.print_generics(&info, lt, false);
473                self.push_str(" {\n");
474                self.push_str(
475                    "fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n",
476                );
477                self.push_str(&format!("f.debug_struct(\"{}\")", name));
478                for field in record.fields.iter() {
479                    self.push_str(&format!(
480                        ".field(\"{}\", &self.{})",
481                        field.name,
482                        to_rust_ident(&field.name)
483                    ));
484                }
485                self.push_str(".finish()");
486                self.push_str("}\n");
487                self.push_str("}\n");
488            }
489        }
490    }
491
492    fn print_typedef_variant(
493        &mut self,
494        iface: &Interface,
495        id: TypeId,
496        name: &str,
497        variant: &Variant,
498        docs: &Docs,
499    ) {
500        // TODO: should this perhaps be an attribute in the wit file?
501        let is_error = name.contains("errno") && variant.is_enum();
502        let info = self.info(id);
503
504        for (name, mode) in self.modes_of(iface, id) {
505            self.rustdoc(docs);
506            let lt = self.lifetime_for(&info, mode);
507            if variant.is_bool() {
508                self.push_str(&format!("pub type {} = bool;\n", name));
509                continue;
510            } else if let Some(ty) = variant.as_option() {
511                self.push_str(&format!("pub type {}", name));
512                self.print_generics(&info, lt, true);
513                self.push_str("= Option<");
514                self.print_ty(iface, ty, mode);
515                self.push_str(">;\n");
516                continue;
517            } else if let Some((ok, err)) = variant.as_expected() {
518                self.push_str(&format!("pub type {}", name));
519                self.print_generics(&info, lt, true);
520                self.push_str("= Result<");
521                match ok {
522                    Some(ty) => self.print_ty(iface, ty, mode),
523                    None => self.push_str("()"),
524                }
525                self.push_str(",");
526                match err {
527                    Some(ty) => self.print_ty(iface, ty, mode),
528                    None => self.push_str("()"),
529                }
530                self.push_str(">;\n");
531                continue;
532            }
533            if variant.is_enum() {
534                self.push_str("#[repr(");
535                self.int_repr(variant.tag);
536                self.push_str(")]\n#[derive(Clone, Copy, PartialEq, Eq)]\n");
537            } else if info.has_pull_buffer || info.has_push_buffer {
538                // skip copy/clone
539            } else if !info.owns_data() {
540                self.push_str("#[derive(Clone, Copy)]\n");
541            }
542            self.push_str(&format!("pub enum {}", name.to_camel_case()));
543            self.print_generics(&info, lt, true);
544            self.push_str("{\n");
545            for case in variant.cases.iter() {
546                self.rustdoc(&case.docs);
547                self.push_str(&case_name(&case.name));
548                if let Some(ty) = &case.ty {
549                    self.push_str("(");
550                    self.print_ty(iface, ty, mode);
551                    self.push_str(")")
552                }
553                self.push_str(",\n");
554            }
555            self.push_str("}\n");
556
557            // Auto-synthesize an implementation of the standard `Error` trait for
558            // error-looking types based on their name.
559            if is_error {
560                self.push_str("impl ");
561                self.push_str(&name);
562                self.push_str("{\n");
563
564                self.push_str("pub fn name(&self) -> &'static str {\n");
565                self.push_str("match self {\n");
566                for case in variant.cases.iter() {
567                    self.push_str(&name);
568                    self.push_str("::");
569                    self.push_str(&case_name(&case.name));
570                    self.push_str(" => \"");
571                    self.push_str(case.name.as_str());
572                    self.push_str("\",\n");
573                }
574                self.push_str("}\n");
575                self.push_str("}\n");
576
577                self.push_str("pub fn message(&self) -> &'static str {\n");
578                self.push_str("match self {\n");
579                for case in variant.cases.iter() {
580                    self.push_str(&name);
581                    self.push_str("::");
582                    self.push_str(&case_name(&case.name));
583                    self.push_str(" => \"");
584                    if let Some(contents) = &case.docs.contents {
585                        self.push_str(contents.trim());
586                    }
587                    self.push_str("\",\n");
588                }
589                self.push_str("}\n");
590                self.push_str("}\n");
591
592                self.push_str("}\n");
593
594                self.push_str("impl std::fmt::Debug for ");
595                self.push_str(&name);
596                self.push_str(
597                    "{\nfn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n",
598                );
599                self.push_str("f.debug_struct(\"");
600                self.push_str(&name);
601                self.push_str("\")\n");
602                self.push_str(".field(\"code\", &(*self as i32))\n");
603                self.push_str(".field(\"name\", &self.name())\n");
604                self.push_str(".field(\"message\", &self.message())\n");
605                self.push_str(".finish()\n");
606                self.push_str("}\n");
607                self.push_str("}\n");
608
609                self.push_str("impl std::fmt::Display for ");
610                self.push_str(&name);
611                self.push_str(
612                    "{\nfn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n",
613                );
614                self.push_str("write!(f, \"{} (error {})\", self.name(), *self as i32)");
615                self.push_str("}\n");
616                self.push_str("}\n");
617                self.push_str("\n");
618                self.push_str("impl std::error::Error for ");
619                self.push_str(&name);
620                self.push_str("{}\n");
621            } else {
622                self.push_str("impl");
623                self.print_generics(&info, lt, true);
624                self.push_str(" std::fmt::Debug for ");
625                self.push_str(&name);
626                self.print_generics(&info, lt, false);
627                self.push_str(" {\n");
628                self.push_str(
629                    "fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n",
630                );
631                self.push_str("match self {\n");
632                for case in variant.cases.iter() {
633                    self.push_str(&name);
634                    self.push_str("::");
635                    self.push_str(&case_name(&case.name));
636                    if case.ty.is_some() {
637                        self.push_str("(e)");
638                    }
639                    self.push_str(" => {\n");
640                    self.push_str(&format!(
641                        "f.debug_tuple(\"{}::{}\")",
642                        name,
643                        case_name(&case.name)
644                    ));
645                    if case.ty.is_some() {
646                        self.push_str(".field(e)");
647                    }
648                    self.push_str(".finish()\n");
649                    self.push_str("}\n");
650                }
651                self.push_str("}\n");
652                self.push_str("}\n");
653                self.push_str("}\n");
654            }
655        }
656    }
657
658    fn print_typedef_alias(&mut self, iface: &Interface, id: TypeId, ty: &Type, docs: &Docs) {
659        let info = self.info(id);
660        for (name, mode) in self.modes_of(iface, id) {
661            self.rustdoc(docs);
662            self.push_str(&format!("pub type {}", name));
663            let lt = self.lifetime_for(&info, mode);
664            self.print_generics(&info, lt, true);
665            self.push_str(" = ");
666            self.print_ty(iface, ty, mode);
667            self.push_str(";\n");
668        }
669    }
670
671    fn print_type_list(&mut self, iface: &Interface, id: TypeId, ty: &Type, docs: &Docs) {
672        let info = self.info(id);
673        for (name, mode) in self.modes_of(iface, id) {
674            let lt = self.lifetime_for(&info, mode);
675            self.rustdoc(docs);
676            self.push_str(&format!("pub type {}", name));
677            self.print_generics(&info, lt, true);
678            self.push_str(" = ");
679            self.print_list(iface, ty, mode);
680            self.push_str(";\n");
681        }
682    }
683
684    fn print_typedef_buffer(
685        &mut self,
686        iface: &Interface,
687        id: TypeId,
688        push: bool,
689        ty: &Type,
690        docs: &Docs,
691    ) {
692        let info = self.info(id);
693        for (name, mode) in self.modes_of(iface, id) {
694            let lt = self.lifetime_for(&info, mode);
695            self.rustdoc(docs);
696            self.push_str(&format!("pub type {}", name));
697            self.print_generics(&info, lt, true);
698            self.push_str(" = ");
699            self.print_buffer(iface, push, ty, mode);
700            self.push_str(";\n");
701        }
702    }
703
704    fn param_name(&self, iface: &Interface, ty: TypeId) -> String {
705        let info = self.info(ty);
706        let name = iface.types[ty].name.as_ref().unwrap().to_camel_case();
707        if self.uses_two_names(&info) {
708            format!("{}Param", name)
709        } else {
710            name
711        }
712    }
713
714    fn result_name(&self, iface: &Interface, ty: TypeId) -> String {
715        let info = self.info(ty);
716        let name = iface.types[ty].name.as_ref().unwrap().to_camel_case();
717        if self.uses_two_names(&info) {
718            format!("{}Result", name)
719        } else {
720            name
721        }
722    }
723
724    fn uses_two_names(&self, info: &TypeInfo) -> bool {
725        info.owns_data()
726            && info.param
727            && info.result
728            && match self.default_param_mode() {
729                TypeMode::AllBorrowed(_) | TypeMode::LeafBorrowed(_) => true,
730                TypeMode::HandlesBorrowed(_) => info.has_handle,
731                TypeMode::Owned => false,
732            }
733    }
734
735    fn lifetime_for(&self, info: &TypeInfo, mode: TypeMode) -> Option<&'static str> {
736        match mode {
737            TypeMode::AllBorrowed(s) | TypeMode::LeafBorrowed(s)
738                if info.has_list
739                    || info.has_handle
740                    || info.has_push_buffer
741                    || info.has_pull_buffer =>
742            {
743                Some(s)
744            }
745            TypeMode::HandlesBorrowed(s)
746                if info.has_handle || info.has_pull_buffer || info.has_push_buffer =>
747            {
748                Some(s)
749            }
750            _ => None,
751        }
752    }
753
754    fn needs_mutable_slice(&mut self, iface: &Interface, ty: &Type) -> bool {
755        let info = self.types_mut().type_info(iface, ty);
756        // If there's any out-buffers transitively then a mutable slice is
757        // required because the out-buffers could be modified. Otherwise a
758        // mutable slice is also required if, transitively, `InBuffer` is used
759        // which is used when we're a buffer of a type where not all bits are
760        // valid (e.g. the rust representation and the canonical abi may differ).
761        info.has_push_buffer || self.has_pull_buffer_invalid_bits(iface, ty)
762    }
763
764    fn has_pull_buffer_invalid_bits(&self, iface: &Interface, ty: &Type) -> bool {
765        let id = match ty {
766            Type::Id(id) => *id,
767            _ => return false,
768        };
769        match &iface.types[id].kind {
770            TypeDefKind::Type(t)
771            | TypeDefKind::Pointer(t)
772            | TypeDefKind::ConstPointer(t)
773            | TypeDefKind::PushBuffer(t)
774            | TypeDefKind::List(t) => self.has_pull_buffer_invalid_bits(iface, t),
775            TypeDefKind::Record(r) => r
776                .fields
777                .iter()
778                .any(|t| self.has_pull_buffer_invalid_bits(iface, &t.ty)),
779            TypeDefKind::Variant(v) => v
780                .cases
781                .iter()
782                .filter_map(|c| c.ty.as_ref())
783                .any(|t| self.has_pull_buffer_invalid_bits(iface, t)),
784            TypeDefKind::PullBuffer(t) => {
785                !iface.all_bits_valid(t) || self.has_pull_buffer_invalid_bits(iface, t)
786            }
787        }
788    }
789}
790
/// Options controlling how `RustGenerator::print_docs_and_params` (and hence
/// `print_signature`) renders a function signature.
#[derive(Debug, Clone, Default)]
pub struct FnSig {
    /// Emit the `async` qualifier.
    pub async_: bool,
    /// Emit the `unsafe` qualifier.
    pub unsafe_: bool,
    /// Omit the leading `pub`.
    pub private: bool,
    /// Name the function after `Function::item_name()` instead of its full
    /// `name`.
    pub use_item_name: bool,
    /// Text printed verbatim immediately after the function name.
    pub generics: Option<String>,
    /// Receiver text printed verbatim as the first parameter (followed by `,`).
    pub self_arg: Option<String>,
    /// Treat the function's first wit parameter as `self`: it is not printed
    /// and is reported as `"self"` in the returned parameter names.
    pub self_is_first_param: bool,
}
801
/// Shared helpers for generating the body of a Rust function binding.
///
/// Implementors supply the four required methods. The provided methods emit
/// Rust source for binding results and for lowering/lifting records and
/// variants. Output goes out in one of two ways: it is emitted through
/// `push_str`, or it is built as expression strings and pushed onto a
/// `results` vector for the caller to splice in.
pub trait RustFunctionGenerator {
    /// Emits a fragment of generated Rust source.
    ///
    /// NOTE(review): the provided methods call this with many small,
    /// deliberately split fragments (e.g. `"{\n"`, `" => { "`, `"}\n"`).
    /// The implementor may track state per call, such as indentation. If so,
    /// merging or splitting these calls could change the output, so keep
    /// call boundaries stable when editing.
    fn push_str(&mut self, s: &str);
    /// Returns a number used as a suffix for generated temporary names
    /// (`result{N}`, `t{N}_{i}`, `{field}{N}`). Presumably distinct on every
    /// call; confirm with the implementor.
    fn tmp(&mut self) -> usize;
    /// The type-level generator, used here to resolve type names.
    fn rust_gen(&self) -> &dyn RustGenerator;
    /// Direction of the current binding. It decides whether `typename_lower`
    /// and `typename_lift` use the parameter or the result spelling of a type.
    fn lift_lower(&self) -> LiftLower;

    /// Emits the start of a `let` statement binding `amt` values and pushes
    /// the bound variable names onto `results`.
    ///
    /// - `0`: emits nothing and pushes nothing.
    /// - `1`: emits `let result{N} = `.
    /// - `n`: emits `let (result{N}_0,result{N}_1,...,) = `. The trailing
    ///   comma is valid tuple-pattern syntax.
    ///
    /// The caller emits the right-hand side and the terminating `;`.
    fn let_results(&mut self, amt: usize, results: &mut Vec<String>) {
        match amt {
            0 => {}
            1 => {
                let tmp = self.tmp();
                let res = format!("result{}", tmp);
                self.push_str("let ");
                self.push_str(&res);
                results.push(res);
                self.push_str(" = ");
            }
            n => {
                let tmp = self.tmp();
                self.push_str("let (");
                for i in 0..n {
                    let arg = format!("result{}_{}", tmp, i);
                    self.push_str(&arg);
                    self.push_str(",");
                    results.push(arg);
                }
                self.push_str(") = ");
            }
        }
    }

    /// Destructures the record value `operand` into one local per field and
    /// pushes the local names onto `results`, in field order.
    ///
    /// Tuples emit `let (t{N}_0, t{N}_1, ) = operand;`. Named records emit
    /// `let Name{ field:field{N}, ... } = operand;`, where `Name` comes from
    /// `typename_lower` and field names pass through `to_rust_ident`.
    fn record_lower(
        &mut self,
        iface: &Interface,
        id: TypeId,
        record: &Record,
        operand: &str,
        results: &mut Vec<String>,
    ) {
        let tmp = self.tmp();
        if record.is_tuple() {
            self.push_str("let (");
            for i in 0..record.fields.len() {
                let arg = format!("t{}_{}", tmp, i);
                self.push_str(&arg);
                self.push_str(", ");
                results.push(arg);
            }
            self.push_str(") = ");
            self.push_str(operand);
            self.push_str(";\n");
        } else {
            self.push_str("let ");
            let name = self.typename_lower(iface, id);
            self.push_str(&name);
            self.push_str("{ ");
            for field in record.fields.iter() {
                let name = to_rust_ident(&field.name);
                // Suffix with `tmp` so nested destructurings don't shadow each other's locals.
                let arg = format!("{}{}", name, tmp);
                self.push_str(&name);
                self.push_str(":");
                self.push_str(&arg);
                self.push_str(", ");
                results.push(arg);
            }
            self.push_str("} = ");
            self.push_str(operand);
            self.push_str(";\n");
        }
    }

    /// Builds an expression that constructs the record from `operands` and
    /// pushes it onto `results`. There is one operand per field, in field order.
    ///
    /// Tuples become `(a, b, ...)`. Named records become
    /// `Name{field:val, ...}`, with `Name` from `typename_lift`. For named
    /// records, `zip` stops at the shorter of fields and operands.
    fn record_lift(
        &mut self,
        iface: &Interface,
        id: TypeId,
        ty: &Record,
        operands: &[String],
        results: &mut Vec<String>,
    ) {
        if ty.is_tuple() {
            if operands.len() == 1 {
                // Without the trailing comma, `(x)` is a parenthesized expression, not a tuple.
                results.push(format!("({},)", operands[0]));
            } else {
                results.push(format!("({})", operands.join(", ")));
            }
        } else {
            let mut result = self.typename_lift(iface, id);
            result.push_str("{");
            for (field, val) in ty.fields.iter().zip(operands) {
                result.push_str(&to_rust_ident(&field.name));
                result.push_str(":");
                result.push_str(&val);
                result.push_str(", ");
            }
            result.push_str("}");
            results.push(result);
        }
    }

    /// Rust name of type `id` on the lowering side. This is the parameter
    /// spelling when lowering arguments and the result spelling when lowering
    /// results.
    fn typename_lower(&self, iface: &Interface, id: TypeId) -> String {
        match self.lift_lower() {
            LiftLower::LowerArgsLiftResults => self.rust_gen().param_name(iface, id),
            LiftLower::LiftArgsLowerResults => self.rust_gen().result_name(iface, id),
        }
    }

    /// Rust name of type `id` on the lifting side. This is the mirror image of
    /// `typename_lower`: the parameter spelling when lifting arguments and the
    /// result spelling when lifting results.
    fn typename_lift(&self, iface: &Interface, id: TypeId) -> String {
        match self.lift_lower() {
            LiftLower::LiftArgsLowerResults => self.rust_gen().param_name(iface, id),
            LiftLower::LowerArgsLiftResults => self.rust_gen().result_name(iface, id),
        }
    }

    /// Lowers the variant value `operand` into `nresults` values.
    ///
    /// It emits the `let` prefix from `let_results`, then
    /// `match operand{ Pattern => { block}\n ... };\n`. There is one arm per
    /// case, and `ty.cases` is paired with `blocks` in order. A case payload,
    /// if present, is bound as `e` for use in its block.
    ///
    /// Patterns by variant shape:
    /// - bools: the raw case name (presumably `false`/`true`);
    /// - `expected`: the camel-cased case name (presumably `Ok`/`Err`), always
    ///   with a payload pattern;
    /// - `option`: the camel-cased case name (presumably `None`/`Some`);
    /// - named variants: `Type::Case`.
    ///
    /// Anonymous variants of any other shape hit `unimplemented!()`.
    fn variant_lower(
        &mut self,
        iface: &Interface,
        id: TypeId,
        ty: &Variant,
        nresults: usize,
        operand: &str,
        results: &mut Vec<String>,
        blocks: Vec<String>,
    ) {
        // If this is a named enum with no type payloads and we're
        // producing a singular result, then we know we're directly
        // converting from the Rust enum to the integer discriminant. In
        // this scenario we can optimize a bit and use just `as i32`
        // instead of letting LLVM figure out it can do the same with
        // optimizing the `match` generated below.
        let has_name = iface.types[id].name.is_some();
        if nresults == 1 && has_name && ty.cases.iter().all(|c| c.ty.is_none()) {
            results.push(format!("{} as i32", operand));
            return;
        }

        self.let_results(nresults, results);
        self.push_str("match ");
        self.push_str(operand);
        self.push_str("{\n");
        for (case, block) in ty.cases.iter().zip(blocks) {
            if ty.is_bool() {
                self.push_str(case.name.as_str());
            } else if ty.as_expected().is_some() {
                // `expected` arms always take a payload pattern; payload-less cases match `()`.
                self.push_str(&case.name.to_camel_case());
                self.push_str("(");
                self.push_str(if case.ty.is_some() { "e" } else { "()" });
                self.push_str(")");
            } else if ty.as_option().is_some() {
                self.push_str(&case.name.to_camel_case());
                if case.ty.is_some() {
                    self.push_str("(e)");
                }
            } else if has_name {
                let name = self.typename_lower(iface, id);
                self.push_str(&name);
                self.push_str("::");
                self.push_str(&case_name(&case.name));
                if case.ty.is_some() {
                    self.push_str("(e)");
                }
            } else {
                unimplemented!()
            }
            self.push_str(" => { ");
            self.push_str(&block);
            self.push_str("}\n");
        }
        self.push_str("};\n");
    }

    /// Appends to `result` the constructor expression for `case` of variant
    /// `ty`. `block` is the lifted payload expression and becomes its argument.
    ///
    /// - bools: only the case name is appended and `block` is ignored;
    /// - `expected`: the payload is always wrapped, so `block` must be a valid
    ///   expression even for payload-less cases (presumably `()`);
    /// - `option` and named variants (`Type::Case`): `block` is wrapped only
    ///   when the case has a type.
    ///
    /// Anonymous variants of any other shape hit `unimplemented!()`.
    fn variant_lift_case(
        &mut self,
        iface: &Interface,
        id: TypeId,
        ty: &Variant,
        case: &Case,
        block: &str,
        result: &mut String,
    ) {
        if ty.is_bool() {
            result.push_str(case.name.as_str());
        } else if ty.as_expected().is_some() {
            result.push_str(&case.name.to_camel_case());
            result.push_str("(");
            result.push_str(block);
            result.push_str(")");
        } else if ty.as_option().is_some() {
            result.push_str(&case.name.to_camel_case());
            if case.ty.is_some() {
                result.push_str("(");
                result.push_str(block);
                result.push_str(")");
            }
        } else if iface.types[id].name.is_some() {
            result.push_str(&self.typename_lift(iface, id));
            result.push_str("::");
            result.push_str(&case_name(&case.name));
            if case.ty.is_some() {
                result.push_str("(");
                result.push_str(block);
                result.push_str(")");
            }
        } else {
            unimplemented!()
        }
    }
}
1009
1010pub fn to_rust_ident(name: &str) -> String {
1011    match name {
1012        "in" => "in_".into(),
1013        "type" => "type_".into(),
1014        "where" => "where_".into(),
1015        "yield" => "yield_".into(),
1016        "async" => "async_".into(),
1017        "self" => "self_".into(),
1018        s => s.to_snake_case(),
1019    }
1020}
1021
1022pub fn wasm_type(ty: WasmType) -> &'static str {
1023    match ty {
1024        WasmType::I32 => "i32",
1025        WasmType::I64 => "i64",
1026        WasmType::F32 => "f32",
1027        WasmType::F64 => "f64",
1028    }
1029}
1030
1031pub fn int_repr(repr: Int) -> &'static str {
1032    match repr {
1033        Int::U8 => "u8",
1034        Int::U16 => "u16",
1035        Int::U32 => "u32",
1036        Int::U64 => "u64",
1037    }
1038}
1039
1040trait TypeInfoExt {
1041    fn owns_data(&self) -> bool;
1042}
1043
1044impl TypeInfoExt for TypeInfo {
1045    fn owns_data(&self) -> bool {
1046        self.has_list || self.has_handle || self.has_pull_buffer || self.has_push_buffer
1047    }
1048}
1049
1050pub fn case_name(id: &str) -> String {
1051    if id.chars().next().unwrap().is_alphabetic() {
1052        id.to_camel_case()
1053    } else {
1054        format!("V{}", id)
1055    }
1056}
1057
1058pub fn bitcast(casts: &[Bitcast], operands: &[String], results: &mut Vec<String>) {
1059    for (cast, operand) in casts.iter().zip(operands) {
1060        results.push(match cast {
1061            Bitcast::None => operand.clone(),
1062            Bitcast::F32ToF64 => format!("f64::from({})", operand),
1063            Bitcast::I32ToI64 => format!("i64::from({})", operand),
1064            Bitcast::F32ToI32 => format!("({}).to_bits() as i32", operand),
1065            Bitcast::F64ToI64 => format!("({}).to_bits() as i64", operand),
1066            Bitcast::F64ToF32 => format!("{} as f32", operand),
1067            Bitcast::I64ToI32 => format!("{} as i32", operand),
1068            Bitcast::I32ToF32 => format!("f32::from_bits({} as u32)", operand),
1069            Bitcast::I64ToF64 => format!("f64::from_bits({} as u64)", operand),
1070            Bitcast::F32ToI64 => format!("i64::from(({}).to_bits())", operand),
1071            Bitcast::I64ToF32 => format!("f32::from_bits({} as u32)", operand),
1072        });
1073    }
1074}