witx_codegen/rust/
function.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
use std::io::Write;

use super::*;

impl RustGenerator {
    /// Emits the Rust binding for one WASI Preview1 function imported from
    /// `module_name`.
    ///
    /// Every witx parameter is decomposed into its lowered components. The
    /// function must have exactly one result, of `Result` type:
    ///
    /// * its `ok` type becomes output pointer parameters of the raw import: one
    ///   `result{i}_ptr` per member when the `ok` type's leaf is a tuple,
    ///   otherwise a single `result_ptr`;
    /// * its error type is what the raw import returns.
    ///
    /// The code itself is written by `define_func_raw`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying `PrettyWriter`.
    ///
    /// # Panics
    ///
    /// Panics if the function does not use the Preview1 ABI, or does not have
    /// exactly one result. Also panics if that result is not a `Result`, or if
    /// the number of lowered parameters disagrees with witx's own wasm signature
    /// for the import.
    pub fn define_func<T: Write>(
        w: &mut PrettyWriter<T>,
        module_name: &str,
        func_witx: &witx::Function,
    ) -> Result<(), Error> {
        assert_eq!(func_witx.abi, witx::Abi::Preview1);
        let name = func_witx.name.as_str();

        let params: Vec<_> = func_witx
            .params
            .iter()
            .map(|param_witx| {
                (
                    param_witx.name.as_str().to_string(),
                    ASType::from(&param_witx.tref),
                )
            })
            .collect();

        let results_witx = &func_witx.results;
        assert_eq!(results_witx.len(), 1);
        let result = match ASType::from(&results_witx[0].tref) {
            ASType::Result(result) => result,
            _ => unreachable!("Preview1 function `{}` must return a Result", name),
        };
        let ok_type = result.ok_type.clone();

        let params_decomposed: Vec<_> = params
            .iter()
            .flat_map(|(param_name, param_type)| param_type.decompose(param_name, false))
            .collect();

        // A tuple in a result is expanded into additional parameters, transformed to
        // pointers
        let results: Vec<(String, _)> =
            if let ASType::Tuple(tuple_members) = ok_type.as_ref().leaf() {
                tuple_members
                    .iter()
                    .enumerate()
                    .map(|(i, tuple_member)| {
                        (format!("result{}_ptr", i), tuple_member.type_.clone())
                    })
                    .collect()
            } else {
                vec![("result_ptr".to_string(), ok_type)]
            };
        let results_decomposed: Vec<_> = results
            .iter()
            .flat_map(|(result_name, result_type)| result_type.decompose(result_name, true))
            .collect();

        // Check the decomposition against witx's own lowering before emitting any
        // code. The extra `1` accounts for the error code returned by the raw import.
        let signature_witx = func_witx.wasm_signature(witx::CallMode::DefinedImport);
        let params_count_witx = signature_witx.params.len() + signature_witx.results.len();
        assert_eq!(
            params_count_witx,
            params_decomposed.len() + results_decomposed.len() + 1,
            "lowered signature of `{}` does not match witx",
            name
        );

        let docs = &func_witx.docs;
        if !docs.is_empty() {
            Self::write_docs(w, docs)?;
        }

        Self::define_func_raw(
            w,
            module_name,
            name,
            &params_decomposed,
            &results_decomposed,
            &result,
        )
    }

    /// Writes a safe wrapper around a raw import of `name` from the wasm module
    /// `module_name`.
    ///
    /// The generated code has the shape:
    ///
    /// ```text
    /// pub fn name(params...) -> Result<Outputs, Error> {
    ///     #[link(wasm_import_module = "module_name")]
    ///     extern "C" {
    ///         fn name(params..., output_ptrs...) -> ErrorType;
    ///     }
    ///     // one MaybeUninit slot per output, the call, an error check,
    ///     // then Ok(outputs)
    /// }
    /// ```
    ///
    /// * `params_decomposed` - lowered input parameters. The wrapper exposes them
    ///   and forwards them unchanged to the import.
    /// * `results_decomposed` - lowered output parameters of the import. Each must
    ///   be a mutable pointer. The wrapper returns their pointees: `()` for none,
    ///   the bare value for one, a tuple for several.
    /// * `result` - provides the error type returned by the raw import.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying `PrettyWriter`.
    ///
    /// # Panics
    ///
    /// Panics if an entry of `results_decomposed` is not a mutable pointer.
    fn define_func_raw<T: Write>(
        w: &mut PrettyWriter<T>,
        module_name: &str,
        name: &str,
        params_decomposed: &[ASTypeDecomposed],
        results_decomposed: &[ASTypeDecomposed],
        result: &ASResult,
    ) -> Result<(), Error> {
        // Writes `{head}(`, then one `name: type,` line per parameter, then `){tail}`.
        // An empty parameter list stays on a single line, so no stray newline or
        // indentation ends up between the parentheses.
        fn write_signature<'a, W: Write>(
            w: &mut PrettyWriter<W>,
            head: &str,
            params: impl Iterator<Item = &'a ASTypeDecomposed>,
            tail: &str,
        ) -> Result<(), Error> {
            let mut params = params.peekable();
            w.indent()?.write(format!("{}(", head))?;
            if params.peek().is_none() {
                w.write(format!("){}", tail))?;
                w.eol()?;
                return Ok(());
            }
            w.eol()?;
            for param in params {
                w.write_line_continued(format!(
                    "{}: {},",
                    param.name.as_var(),
                    param.type_.as_lang(),
                ))?;
            }
            w.write_line(format!("){}", tail))?;
            Ok(())
        }

        // Outputs are passed to the import as `*mut T`; the wrapper returns the `T`s.
        let results_decomposed_deref: Vec<_> = results_decomposed
            .iter()
            .map(|result_ptr_type| match result_ptr_type.type_.as_ref() {
                ASType::MutPtr(result_type) => ASTypeDecomposed {
                    name: result_ptr_type.name.clone(),
                    type_: result_type.clone(),
                },
                _ => panic!("Result type is not a pointer"),
            })
            .collect();
        let results_set: Vec<_> = results_decomposed_deref
            .iter()
            .map(|output| output.type_.as_lang())
            .collect();
        let rust_fn_result_str = match results_set.as_slice() {
            [] => "()".to_string(),
            [single] => single.clone(),
            many => format!("({})", many.join(", ")),
        };

        // Safe wrapper: only the inputs are parameters; the outputs are returned.
        write_signature(
            &mut *w,
            &format!("pub fn {}", name.as_fn()),
            params_decomposed.iter(),
            &format!(" -> Result<{}, Error> {{", rust_fn_result_str),
        )?;
        {
            let mut w = w.new_block();

            // Raw import, declared inside the wrapper's body. This block-scoped item
            // shadows the wrapper's own name, so the call below targets the import.
            w.write_line(format!("#[link(wasm_import_module = \"{}\")]", module_name))?;
            w.write_line("extern \"C\" {")?;
            {
                let mut w = w.new_block();
                write_signature(
                    &mut w,
                    &format!("fn {}", name.as_fn()),
                    params_decomposed.iter().chain(results_decomposed.iter()),
                    &format!(" -> {};", result.error_type.as_lang()),
                )?;
            }
            w.write_line("}")?;

            // One uninitialized slot per output. The import is expected to fill them
            // in on success; they are only read (`assume_init`) after the error check.
            for output in &results_decomposed_deref {
                w.write_line(format!(
                    "let mut {} = std::mem::MaybeUninit::uninit();",
                    output.name.as_var()
                ))?;
            }

            w.write_line(format!("let res = unsafe {{ {}(", name.as_fn()))?;
            for param in params_decomposed {
                w.write_line_continued(format!("{},", param.name.as_var()))?;
            }
            for output in &results_decomposed_deref {
                w.write_line_continued(format!("{}.as_mut_ptr(),", output.name.as_var()))?;
            }
            w.write_line(")};")?;
            w.write_lines(
                "if res != 0 {
    return Err(Error::WasiError(res as _));
}",
            )?;

            let res_str = match results_decomposed_deref.as_slice() {
                [] => "()".to_string(),
                [single] => format!("unsafe {{ {}.assume_init() }}", single.name.as_var()),
                many => format!(
                    "unsafe {{ ({}) }}",
                    many.iter()
                        .map(|output| format!("{}.assume_init()", output.name.as_var()))
                        .collect::<Vec<_>>()
                        .join(", ")
                ),
            };
            w.write_line(format!("Ok({})", res_str))?;
        }
        w.write_line("}")?;
        w.eob()?;

        Ok(())
    }
}