// wiggle_generate/config.rs

use {
    proc_macro2::{Span, TokenStream},
    std::{collections::HashMap, path::PathBuf},
    syn::{
        Error, Ident, LitStr, Result, Token, braced, bracketed,
        parse::{Parse, ParseStream},
        punctuated::Punctuated,
    },
};

11#[derive(Debug, Clone)]
12pub struct Config {
13    pub witx: WitxConf,
14    pub errors: ErrorConf,
15    pub async_: AsyncConf,
16    pub wasmtime: bool,
17    pub tracing: TracingConf,
18    pub mutable: bool,
19}
20
21mod kw {
22    syn::custom_keyword!(witx);
23    syn::custom_keyword!(witx_literal);
24    syn::custom_keyword!(block_on);
25    syn::custom_keyword!(errors);
26    syn::custom_keyword!(target);
27    syn::custom_keyword!(wasmtime);
28    syn::custom_keyword!(mutable);
29    syn::custom_keyword!(tracing);
30    syn::custom_keyword!(disable_for);
31    syn::custom_keyword!(trappable);
32}
33
34#[derive(Debug, Clone)]
35pub enum ConfigField {
36    Witx(WitxConf),
37    Error(ErrorConf),
38    Async(AsyncConf),
39    Wasmtime(bool),
40    Tracing(TracingConf),
41    Mutable(bool),
42}
43
44impl Parse for ConfigField {
45    fn parse(input: ParseStream) -> Result<Self> {
46        let lookahead = input.lookahead1();
47        if lookahead.peek(kw::witx) {
48            input.parse::<kw::witx>()?;
49            input.parse::<Token![:]>()?;
50            Ok(ConfigField::Witx(WitxConf::Paths(input.parse()?)))
51        } else if lookahead.peek(kw::witx_literal) {
52            input.parse::<kw::witx_literal>()?;
53            input.parse::<Token![:]>()?;
54            Ok(ConfigField::Witx(WitxConf::Literal(input.parse()?)))
55        } else if lookahead.peek(kw::errors) {
56            input.parse::<kw::errors>()?;
57            input.parse::<Token![:]>()?;
58            Ok(ConfigField::Error(input.parse()?))
59        } else if lookahead.peek(Token![async]) {
60            input.parse::<Token![async]>()?;
61            input.parse::<Token![:]>()?;
62            Ok(ConfigField::Async(AsyncConf {
63                block_with: None,
64                functions: input.parse()?,
65            }))
66        } else if lookahead.peek(kw::block_on) {
67            input.parse::<kw::block_on>()?;
68            let block_with = if input.peek(syn::token::Bracket) {
69                let content;
70                let _ = bracketed!(content in input);
71                content.parse()?
72            } else {
73                quote::quote!(wiggle::run_in_dummy_executor)
74            };
75            input.parse::<Token![:]>()?;
76            Ok(ConfigField::Async(AsyncConf {
77                block_with: Some(block_with),
78                functions: input.parse()?,
79            }))
80        } else if lookahead.peek(kw::wasmtime) {
81            input.parse::<kw::wasmtime>()?;
82            input.parse::<Token![:]>()?;
83            Ok(ConfigField::Wasmtime(input.parse::<syn::LitBool>()?.value))
84        } else if lookahead.peek(kw::tracing) {
85            input.parse::<kw::tracing>()?;
86            input.parse::<Token![:]>()?;
87            Ok(ConfigField::Tracing(input.parse()?))
88        } else if lookahead.peek(kw::mutable) {
89            input.parse::<kw::mutable>()?;
90            input.parse::<Token![:]>()?;
91            Ok(ConfigField::Mutable(input.parse::<syn::LitBool>()?.value))
92        } else {
93            Err(lookahead.error())
94        }
95    }
96}
97
98impl Config {
99    pub fn build(fields: impl Iterator<Item = ConfigField>, err_loc: Span) -> Result<Self> {
100        let mut witx = None;
101        let mut errors = None;
102        let mut async_ = None;
103        let mut wasmtime = None;
104        let mut tracing = None;
105        let mut mutable = None;
106        for f in fields {
107            match f {
108                ConfigField::Witx(c) => {
109                    if witx.is_some() {
110                        return Err(Error::new(err_loc, "duplicate `witx` field"));
111                    }
112                    witx = Some(c);
113                }
114                ConfigField::Error(c) => {
115                    if errors.is_some() {
116                        return Err(Error::new(err_loc, "duplicate `errors` field"));
117                    }
118                    errors = Some(c);
119                }
120                ConfigField::Async(c) => {
121                    if async_.is_some() {
122                        return Err(Error::new(err_loc, "duplicate `async` field"));
123                    }
124                    async_ = Some(c);
125                }
126                ConfigField::Wasmtime(c) => {
127                    if wasmtime.is_some() {
128                        return Err(Error::new(err_loc, "duplicate `wasmtime` field"));
129                    }
130                    wasmtime = Some(c);
131                }
132                ConfigField::Tracing(c) => {
133                    if tracing.is_some() {
134                        return Err(Error::new(err_loc, "duplicate `tracing` field"));
135                    }
136                    tracing = Some(c);
137                }
138                ConfigField::Mutable(c) => {
139                    if mutable.is_some() {
140                        return Err(Error::new(err_loc, "duplicate `mutable` field"));
141                    }
142                    mutable = Some(c);
143                }
144            }
145        }
146        Ok(Config {
147            witx: witx
148                .take()
149                .ok_or_else(|| Error::new(err_loc, "`witx` field required"))?,
150            errors: errors.take().unwrap_or_default(),
151            async_: async_.take().unwrap_or_default(),
152            wasmtime: wasmtime.unwrap_or(true),
153            tracing: tracing.unwrap_or_default(),
154            mutable: mutable.unwrap_or(true),
155        })
156    }
157
158    /// Load the `witx` document for the configuration.
159    ///
160    /// # Panics
161    ///
162    /// This method will panic if the paths given in the `witx` field were not valid documents.
163    pub fn load_document(&self) -> witx::Document {
164        self.witx.load_document()
165    }
166}
167
168impl Parse for Config {
169    fn parse(input: ParseStream) -> Result<Self> {
170        let contents;
171        let _lbrace = braced!(contents in input);
172        let fields: Punctuated<ConfigField, Token![,]> =
173            contents.parse_terminated(ConfigField::parse, Token![,])?;
174        Ok(Config::build(fields.into_iter(), input.span())?)
175    }
176}
177
178/// The witx document(s) that will be loaded from a [`Config`](struct.Config.html).
179///
180/// A witx interface definition can be provided either as a collection of relative paths to
181/// documents, or as a single inlined string literal. Note that `(use ...)` directives are not
182/// permitted when providing a string literal.
183#[derive(Debug, Clone)]
184pub enum WitxConf {
185    /// A collection of paths pointing to witx files.
186    Paths(Paths),
187    /// A single witx document, provided as a string literal.
188    Literal(Literal),
189}
190
191impl WitxConf {
192    /// Load the `witx` document.
193    ///
194    /// # Panics
195    ///
196    /// This method will panic if the paths given in the `witx` field were not valid documents, or
197    /// if any of the given documents were not syntactically valid.
198    pub fn load_document(&self) -> witx::Document {
199        match self {
200            Self::Paths(paths) => witx::load(paths.as_ref()).expect("loading witx"),
201            Self::Literal(doc) => witx::parse(doc.as_ref()).expect("parsing witx"),
202        }
203    }
204}
205
/// An ordered collection of paths pointing to witx documents.
#[derive(Debug, Clone, Default)]
pub struct Paths(Vec<PathBuf>);

impl Paths {
    /// Create a new, empty collection of paths.
    pub fn new() -> Self {
        Self::default()
    }
}

impl AsRef<[PathBuf]> for Paths {
    fn as_ref(&self) -> &[PathBuf] {
        &self.0
    }
}

impl AsMut<[PathBuf]> for Paths {
    fn as_mut(&mut self) -> &mut [PathBuf] {
        &mut self.0
    }
}

impl FromIterator<PathBuf> for Paths {
    /// Collect paths in iteration order.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = PathBuf>,
    {
        Self(iter.into_iter().collect())
    }
}

244impl Parse for Paths {
245    fn parse(input: ParseStream) -> Result<Self> {
246        let content;
247        let _ = bracketed!(content in input);
248        let path_lits: Punctuated<LitStr, Token![,]> =
249            content.parse_terminated(Parse::parse, Token![,])?;
250
251        let expanded_paths = path_lits
252            .iter()
253            .map(|lit| {
254                PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()).join(lit.value())
255            })
256            .collect::<Vec<PathBuf>>();
257
258        Ok(Paths(expanded_paths))
259    }
260}
261
/// One witx document supplied inline as a string literal.
#[derive(Debug, Clone)]
pub struct Literal(String);

impl AsRef<str> for Literal {
    /// Borrow the document source text.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

272impl Parse for Literal {
273    fn parse(input: ParseStream) -> Result<Self> {
274        Ok(Self(input.parse::<syn::LitStr>()?.value()))
275    }
276}
277
278#[derive(Clone, Default, Debug)]
279/// Map from abi error type to rich error type
280pub struct ErrorConf(HashMap<Ident, ErrorConfField>);
281
282impl ErrorConf {
283    pub fn iter(&self) -> impl Iterator<Item = (&Ident, &ErrorConfField)> {
284        self.0.iter()
285    }
286}
287
288impl Parse for ErrorConf {
289    fn parse(input: ParseStream) -> Result<Self> {
290        let content;
291        let _ = braced!(content in input);
292        let items: Punctuated<ErrorConfField, Token![,]> =
293            content.parse_terminated(Parse::parse, Token![,])?;
294        let mut m = HashMap::new();
295        for i in items {
296            match m.insert(i.abi_error().clone(), i.clone()) {
297                None => {}
298                Some(prev_def) => {
299                    return Err(Error::new(
300                        *i.err_loc(),
301                        format!(
302                            "duplicate definition of rich error type for {:?}: previously defined at {:?}",
303                            i.abi_error(),
304                            prev_def.err_loc(),
305                        ),
306                    ));
307                }
308            }
309        }
310        Ok(ErrorConf(m))
311    }
312}
313
314#[derive(Debug, Clone)]
315pub enum ErrorConfField {
316    Trappable(TrappableErrorConfField),
317    User(UserErrorConfField),
318}
319impl ErrorConfField {
320    pub fn abi_error(&self) -> &Ident {
321        match self {
322            Self::Trappable(t) => &t.abi_error,
323            Self::User(u) => &u.abi_error,
324        }
325    }
326    pub fn err_loc(&self) -> &Span {
327        match self {
328            Self::Trappable(t) => &t.err_loc,
329            Self::User(u) => &u.err_loc,
330        }
331    }
332}
333
334impl Parse for ErrorConfField {
335    fn parse(input: ParseStream) -> Result<Self> {
336        let err_loc = input.span();
337        let abi_error = input.parse::<Ident>()?;
338        let _arrow: Token![=>] = input.parse()?;
339
340        let lookahead = input.lookahead1();
341        if lookahead.peek(kw::trappable) {
342            let _ = input.parse::<kw::trappable>()?;
343            let rich_error = input.parse()?;
344            Ok(ErrorConfField::Trappable(TrappableErrorConfField {
345                abi_error,
346                rich_error,
347                err_loc,
348            }))
349        } else {
350            let rich_error = input.parse::<syn::Path>()?;
351            Ok(ErrorConfField::User(UserErrorConfField {
352                abi_error,
353                rich_error,
354                err_loc,
355            }))
356        }
357    }
358}
359
360#[derive(Clone, Debug)]
361pub struct TrappableErrorConfField {
362    pub abi_error: Ident,
363    pub rich_error: Ident,
364    pub err_loc: Span,
365}
366
367#[derive(Clone)]
368pub struct UserErrorConfField {
369    pub abi_error: Ident,
370    pub rich_error: syn::Path,
371    pub err_loc: Span,
372}
373
374impl std::fmt::Debug for UserErrorConfField {
375    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        f.debug_struct("ErrorConfField")
377            .field("abi_error", &self.abi_error)
378            .field("rich_error", &"(...)")
379            .field("err_loc", &self.err_loc)
380            .finish()
381    }
382}
383
384#[derive(Clone, Default, Debug)]
385/// Modules and funcs that have async signatures
386pub struct AsyncConf {
387    block_with: Option<TokenStream>,
388    functions: AsyncFunctions,
389}
390
391#[derive(Clone, Debug)]
392pub enum Asyncness {
393    /// Wiggle function is synchronous, wasmtime Func is synchronous
394    Sync,
395    /// Wiggle function is asynchronous, but wasmtime Func is synchronous
396    Blocking { block_with: TokenStream },
397    /// Wiggle function and wasmtime Func are asynchronous.
398    Async,
399}
400
401impl Asyncness {
402    pub fn is_async(&self) -> bool {
403        match self {
404            Self::Async => true,
405            _ => false,
406        }
407    }
408    pub fn blocking(&self) -> Option<&TokenStream> {
409        match self {
410            Self::Blocking { block_with } => Some(block_with),
411            _ => None,
412        }
413    }
414    pub fn is_sync(&self) -> bool {
415        match self {
416            Self::Sync => true,
417            _ => false,
418        }
419    }
420}
421
/// The set of functions selected by an `async` / `block_on` field.
#[derive(Clone, Debug)]
pub enum AsyncFunctions {
    /// Function names grouped by module name.
    Some(HashMap<String, Vec<String>>),
    /// Every function (written `*`).
    All,
}

impl Default for AsyncFunctions {
    /// Selects no functions.
    fn default() -> Self {
        Self::Some(HashMap::new())
    }
}

433impl AsyncConf {
434    pub fn get(&self, module: &str, function: &str) -> Asyncness {
435        let a = match &self.block_with {
436            Some(block_with) => Asyncness::Blocking {
437                block_with: block_with.clone(),
438            },
439            None => Asyncness::Async,
440        };
441        match &self.functions {
442            AsyncFunctions::Some(fs) => {
443                if fs
444                    .get(module)
445                    .and_then(|fs| fs.iter().find(|f| *f == function))
446                    .is_some()
447                {
448                    a
449                } else {
450                    Asyncness::Sync
451                }
452            }
453            AsyncFunctions::All => a,
454        }
455    }
456
457    pub fn contains_async(&self, module: &witx::Module) -> bool {
458        for f in module.funcs() {
459            if self.get(module.name.as_str(), f.name.as_str()).is_async() {
460                return true;
461            }
462        }
463        false
464    }
465}
466
467impl Parse for AsyncFunctions {
468    fn parse(input: ParseStream) -> Result<Self> {
469        let content;
470        let lookahead = input.lookahead1();
471        if lookahead.peek(syn::token::Brace) {
472            let _ = braced!(content in input);
473            let items: Punctuated<FunctionField, Token![,]> =
474                content.parse_terminated(Parse::parse, Token![,])?;
475            let mut functions: HashMap<String, Vec<String>> = HashMap::new();
476            use std::collections::hash_map::Entry;
477            for i in items {
478                let function_names = i
479                    .function_names
480                    .iter()
481                    .map(|i| i.to_string())
482                    .collect::<Vec<String>>();
483                match functions.entry(i.module_name.to_string()) {
484                    Entry::Occupied(o) => o.into_mut().extend(function_names),
485                    Entry::Vacant(v) => {
486                        v.insert(function_names);
487                    }
488                }
489            }
490            Ok(AsyncFunctions::Some(functions))
491        } else if lookahead.peek(Token![*]) {
492            let _: Token![*] = input.parse().unwrap();
493            Ok(AsyncFunctions::All)
494        } else {
495            Err(lookahead.error())
496        }
497    }
498}
499
500#[derive(Clone)]
501pub struct FunctionField {
502    pub module_name: Ident,
503    pub function_names: Vec<Ident>,
504    pub err_loc: Span,
505}
506
507impl Parse for FunctionField {
508    fn parse(input: ParseStream) -> Result<Self> {
509        let err_loc = input.span();
510        let module_name = input.parse::<Ident>()?;
511        let _doublecolon: Token![::] = input.parse()?;
512        let lookahead = input.lookahead1();
513        if lookahead.peek(syn::token::Brace) {
514            let content;
515            let _ = braced!(content in input);
516            let function_names: Punctuated<Ident, Token![,]> =
517                content.parse_terminated(Parse::parse, Token![,])?;
518            Ok(FunctionField {
519                module_name,
520                function_names: function_names.iter().cloned().collect(),
521                err_loc,
522            })
523        } else if lookahead.peek(Ident) {
524            let name = input.parse()?;
525            Ok(FunctionField {
526                module_name,
527                function_names: vec![name],
528                err_loc,
529            })
530        } else {
531            Err(lookahead.error())
532        }
533    }
534}
535
536#[derive(Clone)]
537pub struct WasmtimeConfig {
538    pub c: Config,
539    pub target: syn::Path,
540}
541
542#[derive(Clone)]
543pub enum WasmtimeConfigField {
544    Core(ConfigField),
545    Target(syn::Path),
546}
547impl WasmtimeConfig {
548    pub fn build(fields: impl Iterator<Item = WasmtimeConfigField>, err_loc: Span) -> Result<Self> {
549        let mut target = None;
550        let mut cs = Vec::new();
551        for f in fields {
552            match f {
553                WasmtimeConfigField::Target(c) => {
554                    if target.is_some() {
555                        return Err(Error::new(err_loc, "duplicate `target` field"));
556                    }
557                    target = Some(c);
558                }
559                WasmtimeConfigField::Core(c) => cs.push(c),
560            }
561        }
562        let c = Config::build(cs.into_iter(), err_loc)?;
563        Ok(WasmtimeConfig {
564            c,
565            target: target
566                .take()
567                .ok_or_else(|| Error::new(err_loc, "`target` field required"))?,
568        })
569    }
570}
571
572impl Parse for WasmtimeConfig {
573    fn parse(input: ParseStream) -> Result<Self> {
574        let contents;
575        let _lbrace = braced!(contents in input);
576        let fields: Punctuated<WasmtimeConfigField, Token![,]> =
577            contents.parse_terminated(WasmtimeConfigField::parse, Token![,])?;
578        Ok(WasmtimeConfig::build(fields.into_iter(), input.span())?)
579    }
580}
581
582impl Parse for WasmtimeConfigField {
583    fn parse(input: ParseStream) -> Result<Self> {
584        if input.peek(kw::target) {
585            input.parse::<kw::target>()?;
586            input.parse::<Token![:]>()?;
587            Ok(WasmtimeConfigField::Target(input.parse()?))
588        } else {
589            Ok(WasmtimeConfigField::Core(input.parse()?))
590        }
591    }
592}
593
/// Per-function tracing settings from the `tracing` field.
#[derive(Clone, Debug)]
pub struct TracingConf {
    /// Leading boolean of the field; when `false`, tracing is off for every function.
    enabled: bool,
    /// Functions listed under `disable_for`, grouped by module name.
    excluded_functions: HashMap<String, Vec<String>>,
}

impl TracingConf {
    /// Whether tracing is enabled for `module::function`.
    ///
    /// Returns `false` when tracing is globally disabled or the function is listed under
    /// `disable_for`; `true` otherwise.
    pub fn enabled_for(&self, module: &str, function: &str) -> bool {
        if !self.enabled {
            return false;
        }
        let excluded = match self.excluded_functions.get(module) {
            Some(names) => names.iter().any(|name| name == function),
            None => false,
        };
        !excluded
    }
}

impl Default for TracingConf {
    /// Tracing enabled for every function.
    fn default() -> Self {
        Self {
            enabled: true,
            excluded_functions: HashMap::new(),
        }
    }
}

621impl Parse for TracingConf {
622    fn parse(input: ParseStream) -> Result<Self> {
623        let enabled = input.parse::<syn::LitBool>()?.value;
624
625        let lookahead = input.lookahead1();
626        if lookahead.peek(kw::disable_for) {
627            input.parse::<kw::disable_for>()?;
628            let content;
629            let _ = braced!(content in input);
630            let items: Punctuated<FunctionField, Token![,]> =
631                content.parse_terminated(Parse::parse, Token![,])?;
632            let mut functions: HashMap<String, Vec<String>> = HashMap::new();
633            use std::collections::hash_map::Entry;
634            for i in items {
635                let function_names = i
636                    .function_names
637                    .iter()
638                    .map(|i| i.to_string())
639                    .collect::<Vec<String>>();
640                match functions.entry(i.module_name.to_string()) {
641                    Entry::Occupied(o) => o.into_mut().extend(function_names),
642                    Entry::Vacant(v) => {
643                        v.insert(function_names);
644                    }
645                }
646            }
647
648            Ok(TracingConf {
649                enabled,
650                excluded_functions: functions,
651            })
652        } else {
653            Ok(TracingConf {
654                enabled,
655                excluded_functions: HashMap::new(),
656            })
657        }
658    }
659}