cosmian_wit_parser/
lib.rs

1use anyhow::{anyhow, bail, Context, Result};
2use id_arena::{Arena, Id};
3use pulldown_cmark::{CodeBlockKind, CowStr, Event, Options, Parser, Tag};
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::{Path, PathBuf};
7
8pub mod abi;
9mod ast;
10mod sizealign;
11pub use sizealign::*;
12
/// A parsed and resolved interface.
///
/// Produced by [`Interface::parse`], [`Interface::parse_file`] and
/// [`Interface::parse_with`].
#[derive(Debug)]
pub struct Interface {
    /// The interface name. Presumably the name `_parse_with` passes to
    /// `ast.resolve` (derived from the file stem) — confirm in `ast`.
    pub name: String,
    /// All type definitions; a `Type::Id` indexes into this arena.
    pub types: Arena<TypeDef>,
    /// Maps names to entries of `types` (presumably only named types).
    pub type_lookup: HashMap<String, TypeId>,
    /// All resources; a `ResourceId` (as in `Type::Handle`) indexes into
    /// this arena.
    pub resources: Arena<Resource>,
    /// Maps names to entries of `resources`.
    pub resource_lookup: HashMap<String, ResourceId>,
    /// Other interfaces referenced by this one — presumably the modules
    /// brought in with `use`; confirm in `ast::resolve`.
    pub interfaces: Arena<Interface>,
    /// Maps names to entries of `interfaces`.
    pub interface_lookup: HashMap<String, InterfaceId>,
    /// Functions declared in the interface.
    pub functions: Vec<Function>,
    /// Globals declared in the interface.
    pub globals: Vec<Global>,
}
25
/// Index of a [`TypeDef`] in an `Arena<TypeDef>` such as [`Interface::types`].
pub type TypeId = Id<TypeDef>;
/// Index of a [`Resource`] in an `Arena<Resource>` such as
/// [`Interface::resources`].
pub type ResourceId = Id<Resource>;
/// Index of an [`Interface`] in an `Arena<Interface>` such as
/// [`Interface::interfaces`].
pub type InterfaceId = Id<Interface>;
29
/// A type definition, stored in [`Interface::types`].
#[derive(Debug)]
pub struct TypeDef {
    /// Documentation attached to the definition.
    pub docs: Docs,
    /// The shape of the type.
    pub kind: TypeDefKind,
    /// The type's name, or `None` when it has none (presumably anonymous
    /// types such as an inline list — confirm in `ast`).
    pub name: Option<String>,
    /// `None` if this type is originally declared in this instance or
    /// otherwise `Some` if it was originally defined in a different module.
    pub foreign_module: Option<String>,
}
39
/// The different shapes a [`TypeDef`] can take.
#[derive(Debug)]
pub enum TypeDefKind {
    /// A record: plain struct, flags or tuple (see [`RecordKind`]).
    Record(Record),
    /// A tagged union of [`Case`]s.
    Variant(Variant),
    /// A list whose elements have the given type.
    List(Type),
    /// A pointer to the given type (the non-const counterpart of
    /// `ConstPointer`).
    Pointer(Type),
    /// A const pointer to the given type.
    ConstPointer(Type),
    /// A push buffer with the given element type.
    PushBuffer(Type),
    /// A pull buffer with the given element type.
    PullBuffer(Type),
    /// An alias of another type; the queries on [`Interface`] (e.g.
    /// `all_bits_valid`) look through it to the aliased type.
    Type(Type),
}
51
/// A reference to a type: either a primitive, a resource handle, or a
/// definition in [`Interface::types`].
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub enum Type {
    // Fixed-width integers: `U*` unsigned, `S*` signed.
    U8,
    U16,
    U32,
    U64,
    S8,
    S16,
    S32,
    S64,
    // Floating-point numbers.
    F32,
    F64,
    /// A character. `Interface::all_bits_valid` reports that not every bit
    /// pattern is a valid `Char` (presumably a Unicode scalar value).
    Char,
    /// A C `char` (presumably a single byte).
    CChar,
    /// Presumably a pointer-sized unsigned integer.
    Usize,
    /// A handle to a resource in [`Interface::resources`].
    Handle(ResourceId),
    /// A type defined in [`Interface::types`].
    Id(TypeId),
}
70
/// An unsigned integer width, used for [`Variant::tag`] and
/// [`RecordKind::Flags`].
#[derive(PartialEq, Debug, Copy, Clone)]
pub enum Int {
    U8,
    U16,
    U32,
    U64,
}
78
/// A record type: an ordered list of fields plus its representation.
#[derive(Debug)]
pub struct Record {
    /// The record's fields, in order.
    pub fields: Vec<Field>,
    /// How the record is represented (plain, flags or tuple).
    pub kind: RecordKind,
}
84
/// How a [`Record`] is represented.
#[derive(Copy, Clone, Debug)]
pub enum RecordKind {
    /// An ordinary record. `RecordKind::infer` yields this for empty records
    /// and for anything that is neither flags nor a tuple.
    Other,
    /// A record whose fields are all `bool`-like variants, represented as
    /// bitflags. The `Int` is presumably the storage width when known;
    /// `RecordKind::infer` always produces `None`.
    Flags(Option<Int>),
    /// A record whose fields are named `0`, `1`, `2`, … in order.
    Tuple,
}
91
/// A single field of a [`Record`].
#[derive(Debug)]
pub struct Field {
    /// Documentation attached to the field.
    pub docs: Docs,
    /// The field name; tuple records use `0`, `1`, … (see
    /// [`RecordKind::Tuple`]).
    pub name: String,
    /// The field's type.
    pub ty: Type,
}
98
99impl Record {
100    pub fn is_tuple(&self) -> bool {
101        matches!(self.kind, RecordKind::Tuple)
102    }
103
104    pub fn is_flags(&self) -> bool {
105        matches!(self.kind, RecordKind::Flags(_))
106    }
107
108    pub fn num_i32s(&self) -> usize {
109        (self.fields.len() + 31) / 32
110    }
111}
112
113impl RecordKind {
114    fn infer(types: &Arena<TypeDef>, fields: &[Field]) -> RecordKind {
115        if fields.is_empty() {
116            return RecordKind::Other;
117        }
118
119        // Structs-of-bools are classified to get represented as bitflags.
120        if fields.iter().all(|t| is_bool(&t.ty, types)) {
121            return RecordKind::Flags(None);
122        }
123
124        // fields with consecutive integer names get represented as tuples.
125        if fields
126            .iter()
127            .enumerate()
128            .all(|(i, m)| m.name.as_str().parse().ok() == Some(i))
129        {
130            return RecordKind::Tuple;
131        }
132
133        return RecordKind::Other;
134
135        fn is_bool(t: &Type, types: &Arena<TypeDef>) -> bool {
136            match t {
137                Type::Id(v) => match &types[*v].kind {
138                    TypeDefKind::Variant(v) => v.is_bool(),
139                    TypeDefKind::Type(t) => is_bool(t, types),
140                    _ => false,
141                },
142                _ => false,
143            }
144        }
145    }
146}
147
/// A variant (tagged union) type.
#[derive(Debug)]
pub struct Variant {
    /// The cases, in order. `is_bool`, `as_option` and `as_expected` depend
    /// on this order.
    pub cases: Vec<Case>,
    /// The bit representation of the width of this variant's tag when the
    /// variant is stored in memory.
    pub tag: Int,
}
155
/// One case of a [`Variant`].
#[derive(Debug)]
pub struct Case {
    /// Documentation attached to the case.
    pub docs: Docs,
    /// The case name.
    pub name: String,
    /// The payload type, or `None` for a case without a payload.
    pub ty: Option<Type>,
}
162
163impl Variant {
164    pub fn infer_tag(cases: usize) -> Int {
165        match cases {
166            n if n <= u8::max_value() as usize => Int::U8,
167            n if n <= u16::max_value() as usize => Int::U16,
168            n if n <= u32::max_value() as usize => Int::U32,
169            n if n <= u64::max_value() as usize => Int::U64,
170            _ => panic!("too many cases to fit in a repr"),
171        }
172    }
173
174    pub fn is_bool(&self) -> bool {
175        self.cases.len() == 2
176            && self.cases[0].name == "false"
177            && self.cases[1].name == "true"
178            && self.cases[0].ty.is_none()
179            && self.cases[1].ty.is_none()
180    }
181
182    pub fn is_enum(&self) -> bool {
183        self.cases.iter().all(|c| c.ty.is_none())
184    }
185
186    pub fn as_option(&self) -> Option<&Type> {
187        if self.cases.len() != 2 {
188            return None;
189        }
190        if self.cases[0].name != "none" || self.cases[0].ty.is_some() {
191            return None;
192        }
193        if self.cases[1].name != "some" {
194            return None;
195        }
196        self.cases[1].ty.as_ref()
197    }
198
199    pub fn as_expected(&self) -> Option<(Option<&Type>, Option<&Type>)> {
200        if self.cases.len() != 2 {
201            return None;
202        }
203        if self.cases[0].name != "ok" {
204            return None;
205        }
206        if self.cases[1].name != "err" {
207            return None;
208        }
209        Some((self.cases[0].ty.as_ref(), self.cases[1].ty.as_ref()))
210    }
211}
212
/// Documentation attached to an item.
#[derive(Clone, Default, Debug)]
pub struct Docs {
    /// The documentation text, or `None` when the item has none (presumably
    /// collected from doc comments in the source — confirm in `ast`).
    pub contents: Option<String>,
}
217
/// A resource, stored in [`Interface::resources`] and referenced through
/// `Type::Handle`.
#[derive(Debug)]
pub struct Resource {
    /// Documentation attached to the resource.
    pub docs: Docs,
    /// The resource name.
    pub name: String,
    /// `None` if this resource is defined within the containing instance,
    /// otherwise `Some` if it's defined in an instance named here.
    pub foreign_module: Option<String>,
}
226
/// A named global of type `ty`, listed in [`Interface::globals`].
#[derive(Debug)]
pub struct Global {
    /// Documentation attached to the global.
    pub docs: Docs,
    /// The global's name.
    pub name: String,
    /// The global's type.
    pub ty: Type,
}
233
/// A function, either freestanding or attached to a resource
/// (see [`FunctionKind`]).
#[derive(Debug)]
pub struct Function {
    /// The ABI used for this function (see the `abi` module).
    pub abi: abi::Abi,
    /// Presumably whether the function was declared `async`.
    pub is_async: bool,
    /// Documentation attached to the function.
    pub docs: Docs,
    /// The function's name as stored by the parser. For static functions and
    /// methods, `Function::item_name` returns the `name` in `kind` instead.
    pub name: String,
    /// Whether the function is freestanding, static, or a method.
    pub kind: FunctionKind,
    /// Parameters as `(name, type)` pairs, in order.
    pub params: Vec<(String, Type)>,
    /// Results as `(name, type)` pairs, in order.
    pub results: Vec<(String, Type)>,
}
244
/// Whether a [`Function`] stands alone or belongs to a resource.
#[derive(Debug)]
pub enum FunctionKind {
    /// A function not attached to any resource.
    Freestanding,
    /// A static function of `resource`, named `name` within it.
    Static { resource: ResourceId, name: String },
    /// A method of `resource`, named `name` within it.
    Method { resource: ResourceId, name: String },
}
251
252impl Function {
253    pub fn item_name(&self) -> &str {
254        match &self.kind {
255            FunctionKind::Freestanding => &self.name,
256            FunctionKind::Static { name, .. } => name,
257            FunctionKind::Method { name, .. } => name,
258        }
259    }
260}
261
262fn unwrap_md(contents: &str) -> String {
263    let mut wit = String::new();
264    let mut last_pos = 0;
265    let mut in_wit_code_block = false;
266    Parser::new_ext(contents, Options::empty())
267        .into_offset_iter()
268        .for_each(|(event, range)| match (event, range) {
269            (Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced(CowStr::Borrowed("wit")))), _) => {
270                in_wit_code_block = true;
271            }
272            (Event::Text(text), range) if in_wit_code_block => {
273                // Ensure that offsets are correct by inserting newlines to
274                // cover the Markdown content outside of wit code blocks.
275                for _ in contents[last_pos..range.start].lines() {
276                    wit.push_str("\n");
277                }
278                wit.push_str(&text);
279                last_pos = range.end;
280            }
281            (Event::End(Tag::CodeBlock(CodeBlockKind::Fenced(CowStr::Borrowed("wit")))), _) => {
282                in_wit_code_block = false;
283            }
284            _ => {}
285        });
286    wit
287}
288
289impl Interface {
290    pub fn parse(name: &str, input: &str) -> Result<Interface> {
291        Interface::parse_with(name, input, |f| {
292            Err(anyhow!("cannot load submodule `{}`", f))
293        })
294    }
295
296    pub fn parse_file(path: impl AsRef<Path>) -> Result<Interface> {
297        let path = path.as_ref();
298        let parent = path.parent().unwrap();
299        let contents = std::fs::read_to_string(&path)
300            .with_context(|| format!("failed to read: {}", path.display()))?;
301        Interface::parse_with(path, &contents, |path| load_fs(parent, path))
302    }
303
304    pub fn parse_with(
305        filename: impl AsRef<Path>,
306        contents: &str,
307        mut load: impl FnMut(&str) -> Result<(PathBuf, String)>,
308    ) -> Result<Interface> {
309        Interface::_parse_with(
310            filename.as_ref(),
311            contents,
312            &mut load,
313            &mut HashSet::new(),
314            &mut HashMap::new(),
315        )
316    }
317
318    fn _parse_with(
319        filename: &Path,
320        contents: &str,
321        load: &mut dyn FnMut(&str) -> Result<(PathBuf, String)>,
322        visiting: &mut HashSet<PathBuf>,
323        map: &mut HashMap<String, Interface>,
324    ) -> Result<Interface> {
325        let mut name = filename.file_stem().unwrap();
326        let mut contents = contents;
327
328        // If we have a ".md" file, it's a wit file wrapped in a markdown file;
329        // parse the markdown to extract the `wit` code blocks.
330        let md_contents;
331        if filename.extension().and_then(|s| s.to_str()) == Some("md") {
332            md_contents = unwrap_md(contents);
333            contents = &md_contents[..];
334
335            // Also strip the inner ".wit" extension.
336            name = Path::new(name).file_stem().unwrap();
337        }
338
339        // Parse the `contents `into an AST
340        let ast = match ast::Ast::parse(contents) {
341            Ok(ast) => ast,
342            Err(mut e) => {
343                let file = filename.display().to_string();
344                ast::rewrite_error(&mut e, &file, contents);
345                return Err(e);
346            }
347        };
348
349        // Load up any modules into our `map` that have not yet been parsed.
350        if !visiting.insert(filename.to_path_buf()) {
351            bail!("file `{}` recursively imports itself", filename.display())
352        }
353        for item in ast.items.iter() {
354            let u = match item {
355                ast::Item::Use(u) => u,
356                _ => continue,
357            };
358            if map.contains_key(&*u.from[0].name) {
359                continue;
360            }
361            let (filename, contents) = load(&u.from[0].name)
362                // TODO: insert context here about `u.name.span` and `filename`
363                ?;
364            let instance = Interface::_parse_with(&filename, &contents, load, visiting, map)?;
365            map.insert(u.from[0].name.to_string(), instance);
366        }
367        visiting.remove(filename);
368
369        // and finally resolve everything into our final instance
370        match ast.resolve(name.to_str().unwrap(), map) {
371            Ok(i) => Ok(i),
372            Err(mut e) => {
373                let file = filename.display().to_string();
374                ast::rewrite_error(&mut e, &file, contents);
375                Err(e)
376            }
377        }
378    }
379
380    pub fn topological_types(&self) -> Vec<TypeId> {
381        let mut ret = Vec::new();
382        let mut visited = HashSet::new();
383        for (id, _) in self.types.iter() {
384            self.topo_visit(id, &mut ret, &mut visited);
385        }
386        ret
387    }
388
389    fn topo_visit(&self, id: TypeId, list: &mut Vec<TypeId>, visited: &mut HashSet<TypeId>) {
390        if !visited.insert(id) {
391            return;
392        }
393        match &self.types[id].kind {
394            TypeDefKind::Type(t)
395            | TypeDefKind::List(t)
396            | TypeDefKind::PushBuffer(t)
397            | TypeDefKind::PullBuffer(t)
398            | TypeDefKind::Pointer(t)
399            | TypeDefKind::ConstPointer(t) => self.topo_visit_ty(t, list, visited),
400            TypeDefKind::Record(r) => {
401                for f in r.fields.iter() {
402                    self.topo_visit_ty(&f.ty, list, visited);
403                }
404            }
405            TypeDefKind::Variant(v) => {
406                for v in v.cases.iter() {
407                    if let Some(ty) = &v.ty {
408                        self.topo_visit_ty(ty, list, visited);
409                    }
410                }
411            }
412        }
413        list.push(id);
414    }
415
416    fn topo_visit_ty(&self, ty: &Type, list: &mut Vec<TypeId>, visited: &mut HashSet<TypeId>) {
417        if let Type::Id(id) = ty {
418            self.topo_visit(*id, list, visited);
419        }
420    }
421
422    pub fn all_bits_valid(&self, ty: &Type) -> bool {
423        match ty {
424            Type::U8
425            | Type::S8
426            | Type::U16
427            | Type::S16
428            | Type::U32
429            | Type::S32
430            | Type::U64
431            | Type::S64
432            | Type::F32
433            | Type::F64
434            | Type::CChar
435            | Type::Usize => true,
436
437            Type::Char | Type::Handle(_) => false,
438
439            Type::Id(id) => match &self.types[*id].kind {
440                TypeDefKind::List(_)
441                | TypeDefKind::Variant(_)
442                | TypeDefKind::PushBuffer(_)
443                | TypeDefKind::PullBuffer(_) => false,
444                TypeDefKind::Type(t) => self.all_bits_valid(t),
445                TypeDefKind::Record(r) => r.fields.iter().all(|f| self.all_bits_valid(&f.ty)),
446                TypeDefKind::Pointer(_) | TypeDefKind::ConstPointer(_) => true,
447            },
448        }
449    }
450
451    pub fn has_preview1_pointer(&self, ty: &Type) -> bool {
452        match ty {
453            Type::Id(id) => match &self.types[*id].kind {
454                TypeDefKind::List(t) | TypeDefKind::PushBuffer(t) | TypeDefKind::PullBuffer(t) => {
455                    self.has_preview1_pointer(t)
456                }
457                TypeDefKind::Type(t) => self.has_preview1_pointer(t),
458                TypeDefKind::Pointer(_) | TypeDefKind::ConstPointer(_) => true,
459                TypeDefKind::Record(r) => r.fields.iter().any(|f| self.has_preview1_pointer(&f.ty)),
460                TypeDefKind::Variant(v) => v.cases.iter().any(|c| match &c.ty {
461                    Some(ty) => self.has_preview1_pointer(ty),
462                    None => false,
463                }),
464            },
465            _ => false,
466        }
467    }
468}
469
470fn load_fs(root: &Path, name: &str) -> Result<(PathBuf, String)> {
471    // TODO: only read one, not both
472    let wit = root.join(name).with_extension("wit");
473    let witx = root.join(name).with_extension("witx");
474    let contents = fs::read_to_string(&wit)
475        .or_else(|_| fs::read_to_string(&witx))
476        .context(format!("failed to read `{}`", wit.display()))?;
477    Ok((wit, contents))
478}