1use {
2 proc_macro2::{Span, TokenStream},
3 std::{collections::HashMap, path::PathBuf},
4 syn::{
5 braced, bracketed,
6 parse::{Parse, ParseStream},
7 punctuated::Punctuated,
8 Error, Ident, LitStr, Result, Token,
9 },
10};
11
12#[derive(Debug, Clone)]
13pub struct Config {
14 pub witx: WitxConf,
15 pub errors: ErrorConf,
16 pub async_: AsyncConf,
17 pub wasmtime: bool,
18 pub tracing: TracingConf,
19 pub mutable: bool,
20}
21
22mod kw {
23 syn::custom_keyword!(witx);
24 syn::custom_keyword!(witx_literal);
25 syn::custom_keyword!(block_on);
26 syn::custom_keyword!(errors);
27 syn::custom_keyword!(target);
28 syn::custom_keyword!(wasmtime);
29 syn::custom_keyword!(mutable);
30 syn::custom_keyword!(tracing);
31 syn::custom_keyword!(disable_for);
32 syn::custom_keyword!(trappable);
33}
34
35#[derive(Debug, Clone)]
36pub enum ConfigField {
37 Witx(WitxConf),
38 Error(ErrorConf),
39 Async(AsyncConf),
40 Wasmtime(bool),
41 Tracing(TracingConf),
42 Mutable(bool),
43}
44
45impl Parse for ConfigField {
46 fn parse(input: ParseStream) -> Result<Self> {
47 let lookahead = input.lookahead1();
48 if lookahead.peek(kw::witx) {
49 input.parse::<kw::witx>()?;
50 input.parse::<Token![:]>()?;
51 Ok(ConfigField::Witx(WitxConf::Paths(input.parse()?)))
52 } else if lookahead.peek(kw::witx_literal) {
53 input.parse::<kw::witx_literal>()?;
54 input.parse::<Token![:]>()?;
55 Ok(ConfigField::Witx(WitxConf::Literal(input.parse()?)))
56 } else if lookahead.peek(kw::errors) {
57 input.parse::<kw::errors>()?;
58 input.parse::<Token![:]>()?;
59 Ok(ConfigField::Error(input.parse()?))
60 } else if lookahead.peek(Token![async]) {
61 input.parse::<Token![async]>()?;
62 input.parse::<Token![:]>()?;
63 Ok(ConfigField::Async(AsyncConf {
64 block_with: None,
65 functions: input.parse()?,
66 }))
67 } else if lookahead.peek(kw::block_on) {
68 input.parse::<kw::block_on>()?;
69 let block_with = if input.peek(syn::token::Bracket) {
70 let content;
71 let _ = bracketed!(content in input);
72 content.parse()?
73 } else {
74 quote::quote!(wiggle::run_in_dummy_executor)
75 };
76 input.parse::<Token![:]>()?;
77 Ok(ConfigField::Async(AsyncConf {
78 block_with: Some(block_with),
79 functions: input.parse()?,
80 }))
81 } else if lookahead.peek(kw::wasmtime) {
82 input.parse::<kw::wasmtime>()?;
83 input.parse::<Token![:]>()?;
84 Ok(ConfigField::Wasmtime(input.parse::<syn::LitBool>()?.value))
85 } else if lookahead.peek(kw::tracing) {
86 input.parse::<kw::tracing>()?;
87 input.parse::<Token![:]>()?;
88 Ok(ConfigField::Tracing(input.parse()?))
89 } else if lookahead.peek(kw::mutable) {
90 input.parse::<kw::mutable>()?;
91 input.parse::<Token![:]>()?;
92 Ok(ConfigField::Mutable(input.parse::<syn::LitBool>()?.value))
93 } else {
94 Err(lookahead.error())
95 }
96 }
97}
98
99impl Config {
100 pub fn build(fields: impl Iterator<Item = ConfigField>, err_loc: Span) -> Result<Self> {
101 let mut witx = None;
102 let mut errors = None;
103 let mut async_ = None;
104 let mut wasmtime = None;
105 let mut tracing = None;
106 let mut mutable = None;
107 for f in fields {
108 match f {
109 ConfigField::Witx(c) => {
110 if witx.is_some() {
111 return Err(Error::new(err_loc, "duplicate `witx` field"));
112 }
113 witx = Some(c);
114 }
115 ConfigField::Error(c) => {
116 if errors.is_some() {
117 return Err(Error::new(err_loc, "duplicate `errors` field"));
118 }
119 errors = Some(c);
120 }
121 ConfigField::Async(c) => {
122 if async_.is_some() {
123 return Err(Error::new(err_loc, "duplicate `async` field"));
124 }
125 async_ = Some(c);
126 }
127 ConfigField::Wasmtime(c) => {
128 if wasmtime.is_some() {
129 return Err(Error::new(err_loc, "duplicate `wasmtime` field"));
130 }
131 wasmtime = Some(c);
132 }
133 ConfigField::Tracing(c) => {
134 if tracing.is_some() {
135 return Err(Error::new(err_loc, "duplicate `tracing` field"));
136 }
137 tracing = Some(c);
138 }
139 ConfigField::Mutable(c) => {
140 if mutable.is_some() {
141 return Err(Error::new(err_loc, "duplicate `mutable` field"));
142 }
143 mutable = Some(c);
144 }
145 }
146 }
147 Ok(Config {
148 witx: witx
149 .take()
150 .ok_or_else(|| Error::new(err_loc, "`witx` field required"))?,
151 errors: errors.take().unwrap_or_default(),
152 async_: async_.take().unwrap_or_default(),
153 wasmtime: wasmtime.unwrap_or(true),
154 tracing: tracing.unwrap_or_default(),
155 mutable: mutable.unwrap_or(true),
156 })
157 }
158
159 pub fn load_document(&self) -> witx::Document {
165 self.witx.load_document()
166 }
167}
168
169impl Parse for Config {
170 fn parse(input: ParseStream) -> Result<Self> {
171 let contents;
172 let _lbrace = braced!(contents in input);
173 let fields: Punctuated<ConfigField, Token![,]> =
174 contents.parse_terminated(ConfigField::parse, Token![,])?;
175 Ok(Config::build(fields.into_iter(), input.span())?)
176 }
177}
178
179#[derive(Debug, Clone)]
185pub enum WitxConf {
186 Paths(Paths),
188 Literal(Literal),
190}
191
192impl WitxConf {
193 pub fn load_document(&self) -> witx::Document {
200 match self {
201 Self::Paths(paths) => witx::load(paths.as_ref()).expect("loading witx"),
202 Self::Literal(doc) => witx::parse(doc.as_ref()).expect("parsing witx"),
203 }
204 }
205}
206
/// An ordered list of witx file paths.
#[derive(Clone, Debug)]
pub struct Paths(Vec<PathBuf>);

impl Paths {
    /// Creates an empty path list.
    pub fn new() -> Self {
        Self(Vec::new())
    }
}

impl Default for Paths {
    /// An empty path list.
    fn default() -> Self {
        Paths(Vec::new())
    }
}

impl AsRef<[PathBuf]> for Paths {
    fn as_ref(&self) -> &[PathBuf] {
        &self.0
    }
}

impl AsMut<[PathBuf]> for Paths {
    fn as_mut(&mut self) -> &mut [PathBuf] {
        self.0.as_mut_slice()
    }
}

impl FromIterator<PathBuf> for Paths {
    /// Collects paths in iteration order.
    fn from_iter<I: IntoIterator<Item = PathBuf>>(iter: I) -> Self {
        let collected: Vec<PathBuf> = iter.into_iter().collect();
        Paths(collected)
    }
}
244
245impl Parse for Paths {
246 fn parse(input: ParseStream) -> Result<Self> {
247 let content;
248 let _ = bracketed!(content in input);
249 let path_lits: Punctuated<LitStr, Token![,]> =
250 content.parse_terminated(Parse::parse, Token![,])?;
251
252 let expanded_paths = path_lits
253 .iter()
254 .map(|lit| {
255 PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()).join(lit.value())
256 })
257 .collect::<Vec<PathBuf>>();
258
259 Ok(Paths(expanded_paths))
260 }
261}
262
/// The text of an inline witx document.
#[derive(Clone, Debug)]
pub struct Literal(String);

impl AsRef<str> for Literal {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
272
273impl Parse for Literal {
274 fn parse(input: ParseStream) -> Result<Self> {
275 Ok(Self(input.parse::<syn::LitStr>()?.value()))
276 }
277}
278
279#[derive(Clone, Default, Debug)]
280pub struct ErrorConf(HashMap<Ident, ErrorConfField>);
282
283impl ErrorConf {
284 pub fn iter(&self) -> impl Iterator<Item = (&Ident, &ErrorConfField)> {
285 self.0.iter()
286 }
287}
288
289impl Parse for ErrorConf {
290 fn parse(input: ParseStream) -> Result<Self> {
291 let content;
292 let _ = braced!(content in input);
293 let items: Punctuated<ErrorConfField, Token![,]> =
294 content.parse_terminated(Parse::parse, Token![,])?;
295 let mut m = HashMap::new();
296 for i in items {
297 match m.insert(i.abi_error().clone(), i.clone()) {
298 None => {}
299 Some(prev_def) => {
300 return Err(Error::new(
301 *i.err_loc(),
302 format!(
303 "duplicate definition of rich error type for {:?}: previously defined at {:?}",
304 i.abi_error(), prev_def.err_loc(),
305 ),
306 ))
307 }
308 }
309 }
310 Ok(ErrorConf(m))
311 }
312}
313
314#[derive(Debug, Clone)]
315pub enum ErrorConfField {
316 Trappable(TrappableErrorConfField),
317 User(UserErrorConfField),
318}
319impl ErrorConfField {
320 pub fn abi_error(&self) -> &Ident {
321 match self {
322 Self::Trappable(t) => &t.abi_error,
323 Self::User(u) => &u.abi_error,
324 }
325 }
326 pub fn err_loc(&self) -> &Span {
327 match self {
328 Self::Trappable(t) => &t.err_loc,
329 Self::User(u) => &u.err_loc,
330 }
331 }
332}
333
334impl Parse for ErrorConfField {
335 fn parse(input: ParseStream) -> Result<Self> {
336 let err_loc = input.span();
337 let abi_error = input.parse::<Ident>()?;
338 let _arrow: Token![=>] = input.parse()?;
339
340 let lookahead = input.lookahead1();
341 if lookahead.peek(kw::trappable) {
342 let _ = input.parse::<kw::trappable>()?;
343 let rich_error = input.parse()?;
344 Ok(ErrorConfField::Trappable(TrappableErrorConfField {
345 abi_error,
346 rich_error,
347 err_loc,
348 }))
349 } else {
350 let rich_error = input.parse::<syn::Path>()?;
351 Ok(ErrorConfField::User(UserErrorConfField {
352 abi_error,
353 rich_error,
354 err_loc,
355 }))
356 }
357 }
358}
359
360#[derive(Clone, Debug)]
361pub struct TrappableErrorConfField {
362 pub abi_error: Ident,
363 pub rich_error: Ident,
364 pub err_loc: Span,
365}
366
367#[derive(Clone)]
368pub struct UserErrorConfField {
369 pub abi_error: Ident,
370 pub rich_error: syn::Path,
371 pub err_loc: Span,
372}
373
374impl std::fmt::Debug for UserErrorConfField {
375 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 f.debug_struct("ErrorConfField")
377 .field("abi_error", &self.abi_error)
378 .field("rich_error", &"(...)")
379 .field("err_loc", &self.err_loc)
380 .finish()
381 }
382}
383
384#[derive(Clone, Default, Debug)]
385pub struct AsyncConf {
387 block_with: Option<TokenStream>,
388 functions: AsyncFunctions,
389}
390
391#[derive(Clone, Debug)]
392pub enum Asyncness {
393 Sync,
395 Blocking { block_with: TokenStream },
397 Async,
399}
400
401impl Asyncness {
402 pub fn is_async(&self) -> bool {
403 match self {
404 Self::Async => true,
405 _ => false,
406 }
407 }
408 pub fn blocking(&self) -> Option<&TokenStream> {
409 match self {
410 Self::Blocking { block_with } => Some(block_with),
411 _ => None,
412 }
413 }
414 pub fn is_sync(&self) -> bool {
415 match self {
416 Self::Sync => true,
417 _ => false,
418 }
419 }
420}
421
/// The set of functions selected by `async:` / `block_on:`.
#[derive(Clone, Debug)]
pub enum AsyncFunctions {
    /// Only the listed functions: module name -> function names.
    Some(HashMap<String, Vec<String>>),
    /// Every function (`*`).
    All,
}

impl Default for AsyncFunctions {
    /// An empty selection, so no function is selected.
    fn default() -> Self {
        Self::Some(HashMap::new())
    }
}
432
433impl AsyncConf {
434 pub fn get(&self, module: &str, function: &str) -> Asyncness {
435 let a = match &self.block_with {
436 Some(block_with) => Asyncness::Blocking {
437 block_with: block_with.clone(),
438 },
439 None => Asyncness::Async,
440 };
441 match &self.functions {
442 AsyncFunctions::Some(fs) => {
443 if fs
444 .get(module)
445 .and_then(|fs| fs.iter().find(|f| *f == function))
446 .is_some()
447 {
448 a
449 } else {
450 Asyncness::Sync
451 }
452 }
453 AsyncFunctions::All => a,
454 }
455 }
456
457 pub fn contains_async(&self, module: &witx::Module) -> bool {
458 for f in module.funcs() {
459 if self.get(module.name.as_str(), f.name.as_str()).is_async() {
460 return true;
461 }
462 }
463 false
464 }
465}
466
467impl Parse for AsyncFunctions {
468 fn parse(input: ParseStream) -> Result<Self> {
469 let content;
470 let lookahead = input.lookahead1();
471 if lookahead.peek(syn::token::Brace) {
472 let _ = braced!(content in input);
473 let items: Punctuated<FunctionField, Token![,]> =
474 content.parse_terminated(Parse::parse, Token![,])?;
475 let mut functions: HashMap<String, Vec<String>> = HashMap::new();
476 use std::collections::hash_map::Entry;
477 for i in items {
478 let function_names = i
479 .function_names
480 .iter()
481 .map(|i| i.to_string())
482 .collect::<Vec<String>>();
483 match functions.entry(i.module_name.to_string()) {
484 Entry::Occupied(o) => o.into_mut().extend(function_names),
485 Entry::Vacant(v) => {
486 v.insert(function_names);
487 }
488 }
489 }
490 Ok(AsyncFunctions::Some(functions))
491 } else if lookahead.peek(Token![*]) {
492 let _: Token![*] = input.parse().unwrap();
493 Ok(AsyncFunctions::All)
494 } else {
495 Err(lookahead.error())
496 }
497 }
498}
499
500#[derive(Clone)]
501pub struct FunctionField {
502 pub module_name: Ident,
503 pub function_names: Vec<Ident>,
504 pub err_loc: Span,
505}
506
507impl Parse for FunctionField {
508 fn parse(input: ParseStream) -> Result<Self> {
509 let err_loc = input.span();
510 let module_name = input.parse::<Ident>()?;
511 let _doublecolon: Token![::] = input.parse()?;
512 let lookahead = input.lookahead1();
513 if lookahead.peek(syn::token::Brace) {
514 let content;
515 let _ = braced!(content in input);
516 let function_names: Punctuated<Ident, Token![,]> =
517 content.parse_terminated(Parse::parse, Token![,])?;
518 Ok(FunctionField {
519 module_name,
520 function_names: function_names.iter().cloned().collect(),
521 err_loc,
522 })
523 } else if lookahead.peek(Ident) {
524 let name = input.parse()?;
525 Ok(FunctionField {
526 module_name,
527 function_names: vec![name],
528 err_loc,
529 })
530 } else {
531 Err(lookahead.error())
532 }
533 }
534}
535
536#[derive(Clone)]
537pub struct WasmtimeConfig {
538 pub c: Config,
539 pub target: syn::Path,
540}
541
542#[derive(Clone)]
543pub enum WasmtimeConfigField {
544 Core(ConfigField),
545 Target(syn::Path),
546}
547impl WasmtimeConfig {
548 pub fn build(fields: impl Iterator<Item = WasmtimeConfigField>, err_loc: Span) -> Result<Self> {
549 let mut target = None;
550 let mut cs = Vec::new();
551 for f in fields {
552 match f {
553 WasmtimeConfigField::Target(c) => {
554 if target.is_some() {
555 return Err(Error::new(err_loc, "duplicate `target` field"));
556 }
557 target = Some(c);
558 }
559 WasmtimeConfigField::Core(c) => cs.push(c),
560 }
561 }
562 let c = Config::build(cs.into_iter(), err_loc)?;
563 Ok(WasmtimeConfig {
564 c,
565 target: target
566 .take()
567 .ok_or_else(|| Error::new(err_loc, "`target` field required"))?,
568 })
569 }
570}
571
572impl Parse for WasmtimeConfig {
573 fn parse(input: ParseStream) -> Result<Self> {
574 let contents;
575 let _lbrace = braced!(contents in input);
576 let fields: Punctuated<WasmtimeConfigField, Token![,]> =
577 contents.parse_terminated(WasmtimeConfigField::parse, Token![,])?;
578 Ok(WasmtimeConfig::build(fields.into_iter(), input.span())?)
579 }
580}
581
582impl Parse for WasmtimeConfigField {
583 fn parse(input: ParseStream) -> Result<Self> {
584 if input.peek(kw::target) {
585 input.parse::<kw::target>()?;
586 input.parse::<Token![:]>()?;
587 Ok(WasmtimeConfigField::Target(input.parse()?))
588 } else {
589 Ok(WasmtimeConfigField::Core(input.parse()?))
590 }
591 }
592}
593
/// Tracing configuration: a global switch plus per-module function exclusions.
#[derive(Clone, Debug)]
pub struct TracingConf {
    /// Global switch; when `false`, nothing is traced.
    enabled: bool,
    /// Module name -> function names for which tracing is disabled.
    excluded_functions: HashMap<String, Vec<String>>,
}

impl TracingConf {
    /// Returns `true` when tracing is globally enabled and `module::function` is not
    /// excluded.
    pub fn enabled_for(&self, module: &str, function: &str) -> bool {
        self.enabled
            && !self
                .excluded_functions
                .get(module)
                .map_or(false, |names| names.iter().any(|name| name == function))
    }
}

impl Default for TracingConf {
    /// Tracing enabled for every function.
    fn default() -> Self {
        TracingConf {
            enabled: true,
            excluded_functions: HashMap::default(),
        }
    }
}
620
621impl Parse for TracingConf {
622 fn parse(input: ParseStream) -> Result<Self> {
623 let enabled = input.parse::<syn::LitBool>()?.value;
624
625 let lookahead = input.lookahead1();
626 if lookahead.peek(kw::disable_for) {
627 input.parse::<kw::disable_for>()?;
628 let content;
629 let _ = braced!(content in input);
630 let items: Punctuated<FunctionField, Token![,]> =
631 content.parse_terminated(Parse::parse, Token![,])?;
632 let mut functions: HashMap<String, Vec<String>> = HashMap::new();
633 use std::collections::hash_map::Entry;
634 for i in items {
635 let function_names = i
636 .function_names
637 .iter()
638 .map(|i| i.to_string())
639 .collect::<Vec<String>>();
640 match functions.entry(i.module_name.to_string()) {
641 Entry::Occupied(o) => o.into_mut().extend(function_names),
642 Entry::Vacant(v) => {
643 v.insert(function_names);
644 }
645 }
646 }
647
648 Ok(TracingConf {
649 enabled,
650 excluded_functions: functions,
651 })
652 } else {
653 Ok(TracingConf {
654 enabled,
655 excluded_functions: HashMap::new(),
656 })
657 }
658 }
659}