wasm_wave/
ast.rs

1//! Abstract syntax tree types
2
3use core::str::FromStr;
4
5use alloc::{
6    borrow::Cow,
7    collections::BTreeMap,
8    string::{String, ToString},
9    vec::Vec,
10};
11
12use crate::{
13    lex::Span,
14    parser::{ParserError, ParserErrorKind},
15    strings::{StringPartsIter, unescape},
16    wasm::{WasmType, WasmTypeKind, WasmValue, WasmValueError},
17};
18
19/// A WAVE AST node.
20#[derive(Clone, Debug)]
21pub struct Node {
22    ty: NodeType,
23    span: Span,
24    children: Vec<Node>,
25}
26
27impl Node {
28    pub(crate) fn new(
29        ty: NodeType,
30        span: impl Into<Span>,
31        children: impl IntoIterator<Item = Node>,
32    ) -> Self {
33        Self {
34            ty,
35            span: span.into(),
36            children: Vec::from_iter(children),
37        }
38    }
39
40    /// Returns this node's type.
41    pub fn ty(&self) -> NodeType {
42        self.ty
43    }
44
45    /// Returns this node's span.
46    pub fn span(&self) -> Span {
47        self.span.clone()
48    }
49
50    /// Returns a bool value if this node represents a bool.
51    pub fn as_bool(&self) -> Result<bool, ParserError> {
52        match self.ty {
53            NodeType::BoolTrue => Ok(true),
54            NodeType::BoolFalse => Ok(false),
55            _ => Err(self.error(ParserErrorKind::InvalidType)),
56        }
57    }
58
59    /// Returns a number value of the given type (integer or float) if this node
60    /// can represent a number of that type.
61    pub fn as_number<T: FromStr>(&self, src: &str) -> Result<T, ParserError> {
62        self.ensure_type(NodeType::Number)?;
63        self.slice(src)
64            .parse()
65            .map_err(|_| self.error(ParserErrorKind::InvalidValue))
66    }
67
68    /// Returns a char value if this node represents a valid char.
69    pub fn as_char(&self, src: &str) -> Result<char, ParserError> {
70        self.ensure_type(NodeType::Char)?;
71        let inner = &src[self.span.start + 1..self.span.end - 1];
72        let (ch, len) = if inner.starts_with('\\') {
73            unescape(inner).ok_or_else(|| self.error(ParserErrorKind::InvalidEscape))?
74        } else {
75            let ch = inner.chars().next().unwrap();
76            (ch, ch.len_utf8())
77        };
78        // Verify length
79        if len != inner.len() {
80            return Err(self.error(ParserErrorKind::MultipleChars));
81        }
82        Ok(ch)
83    }
84
85    /// Returns a str value if this node represents a valid string.
86    pub fn as_str<'src>(&self, src: &'src str) -> Result<Cow<'src, str>, ParserError> {
87        let mut parts = self.iter_str(src)?;
88        let Some(first) = parts.next().transpose()? else {
89            return Ok("".into());
90        };
91        match parts.next().transpose()? {
92            // Single part may be borrowed
93            None => Ok(first),
94            // Multiple parts must be collected into a single owned String
95            Some(second) => {
96                let s: String = [Ok(first), Ok(second)]
97                    .into_iter()
98                    .chain(parts)
99                    .collect::<Result<_, _>>()?;
100                Ok(s.into())
101            }
102        }
103    }
104
105    /// Returns an iterator of string "parts" which together form a decoded
106    /// string value if this node represents a valid string.
107    pub fn iter_str<'src>(
108        &self,
109        src: &'src str,
110    ) -> Result<impl Iterator<Item = Result<Cow<'src, str>, ParserError>>, ParserError> {
111        match self.ty {
112            NodeType::String => {
113                let span = self.span.start + 1..self.span.end - 1;
114                Ok(StringPartsIter::new(&src[span.clone()], span.start))
115            }
116            NodeType::MultilineString => {
117                let span = self.span.start + 3..self.span.end - 3;
118                Ok(StringPartsIter::new_multiline(
119                    &src[span.clone()],
120                    span.start,
121                )?)
122            }
123            _ => Err(self.error(ParserErrorKind::InvalidType)),
124        }
125    }
126
127    /// Returns an iterator of value nodes if this node represents a tuple.
128    pub fn as_tuple(&self) -> Result<impl ExactSizeIterator<Item = &Node>, ParserError> {
129        self.ensure_type(NodeType::Tuple)?;
130        Ok(self.children.iter())
131    }
132
133    /// Returns an iterator of value nodes if this node represents a list.
134    pub fn as_list(&self) -> Result<impl ExactSizeIterator<Item = &Node>, ParserError> {
135        self.ensure_type(NodeType::List)?;
136        Ok(self.children.iter())
137    }
138
139    /// Returns an iterator of field name and value node pairs if this node
140    /// represents a record value.
141    pub fn as_record<'this, 'src>(
142        &'this self,
143        src: &'src str,
144    ) -> Result<impl ExactSizeIterator<Item = (&'src str, &'this Node)>, ParserError> {
145        self.ensure_type(NodeType::Record)?;
146        Ok(self
147            .children
148            .chunks(2)
149            .map(|chunk| (chunk[0].as_label(src).unwrap(), &chunk[1])))
150    }
151
152    /// Returns a variant label and optional payload if this node can represent
153    /// a variant value.
154    pub fn as_variant<'this, 'src>(
155        &'this self,
156        src: &'src str,
157    ) -> Result<(&'src str, Option<&'this Node>), ParserError> {
158        match self.ty {
159            NodeType::Label => Ok((self.as_label(src)?, None)),
160            NodeType::VariantWithPayload => {
161                let label = self.children[0].as_label(src)?;
162                let value = &self.children[1];
163                Ok((label, Some(value)))
164            }
165            _ => Err(self.error(ParserErrorKind::InvalidType)),
166        }
167    }
168
169    /// Returns an enum value label if this node represents a label.
170    pub fn as_enum<'src>(&self, src: &'src str) -> Result<&'src str, ParserError> {
171        self.as_label(src)
172    }
173
174    /// Returns an option value if this node represents an option.
175    pub fn as_option(&self) -> Result<Option<&Node>, ParserError> {
176        match self.ty {
177            NodeType::OptionSome => Ok(Some(&self.children[0])),
178            NodeType::OptionNone => Ok(None),
179            _ => Err(self.error(ParserErrorKind::InvalidType)),
180        }
181    }
182
183    /// Returns a result value with optional payload value if this node
184    /// represents a result.
185    pub fn as_result(&self) -> Result<Result<Option<&Node>, Option<&Node>>, ParserError> {
186        let payload = self.children.first();
187        match self.ty {
188            NodeType::ResultOk => Ok(Ok(payload)),
189            NodeType::ResultErr => Ok(Err(payload)),
190            _ => Err(self.error(ParserErrorKind::InvalidType)),
191        }
192    }
193
194    /// Returns an iterator of flag labels if this node represents flags.
195    pub fn as_flags<'this, 'src: 'this>(
196        &'this self,
197        src: &'src str,
198    ) -> Result<impl Iterator<Item = &'src str> + 'this, ParserError> {
199        self.ensure_type(NodeType::Flags)?;
200        Ok(self.children.iter().map(|node| {
201            debug_assert_eq!(node.ty, NodeType::Label);
202            node.slice(src)
203        }))
204    }
205
206    fn as_label<'src>(&self, src: &'src str) -> Result<&'src str, ParserError> {
207        self.ensure_type(NodeType::Label)?;
208        let label = self.slice(src);
209        let label = label.strip_prefix('%').unwrap_or(label);
210        Ok(label)
211    }
212
213    /// Converts this node into the given typed value from the given input source.
214    pub fn to_wasm_value<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
215        Ok(match ty.kind() {
216            WasmTypeKind::Bool => V::make_bool(self.as_bool()?),
217            WasmTypeKind::S8 => V::make_s8(self.as_number(src)?),
218            WasmTypeKind::S16 => V::make_s16(self.as_number(src)?),
219            WasmTypeKind::S32 => V::make_s32(self.as_number(src)?),
220            WasmTypeKind::S64 => V::make_s64(self.as_number(src)?),
221            WasmTypeKind::U8 => V::make_u8(self.as_number(src)?),
222            WasmTypeKind::U16 => V::make_u16(self.as_number(src)?),
223            WasmTypeKind::U32 => V::make_u32(self.as_number(src)?),
224            WasmTypeKind::U64 => V::make_u64(self.as_number(src)?),
225            WasmTypeKind::F32 => V::make_f32(self.as_number(src)?),
226            WasmTypeKind::F64 => V::make_f64(self.as_number(src)?),
227            WasmTypeKind::Char => V::make_char(self.as_char(src)?),
228            WasmTypeKind::String => V::make_string(self.as_str(src)?),
229            WasmTypeKind::List => self.to_wasm_list(ty, src)?,
230            WasmTypeKind::Record => self.to_wasm_record(ty, src)?,
231            WasmTypeKind::Tuple => self.to_wasm_tuple(ty, src)?,
232            WasmTypeKind::Variant => self.to_wasm_variant(ty, src)?,
233            WasmTypeKind::Enum => self.to_wasm_enum(ty, src)?,
234            WasmTypeKind::Option => self.to_wasm_option(ty, src)?,
235            WasmTypeKind::Result => self.to_wasm_result(ty, src)?,
236            WasmTypeKind::Flags => self.to_wasm_flags(ty, src)?,
237            other => {
238                return Err(
239                    self.wasm_value_error(WasmValueError::UnsupportedType(other.to_string()))
240                );
241            }
242        })
243    }
244
245    /// Converts this node into the given types.
246    /// See [`crate::untyped::UntypedFuncCall::to_wasm_params`].
247    pub fn to_wasm_params<'types, V: WasmValue + 'static>(
248        &self,
249        types: impl IntoIterator<Item = &'types V::Type>,
250        src: &str,
251    ) -> Result<Vec<V>, ParserError> {
252        let mut types = types.into_iter();
253        let mut values = self
254            .as_tuple()?
255            .map(|node| {
256                let ty = types.next().ok_or_else(|| {
257                    ParserError::with_detail(
258                        ParserErrorKind::InvalidParams,
259                        node.span().clone(),
260                        "more param(s) than expected",
261                    )
262                })?;
263                node.to_wasm_value::<V>(ty, src)
264            })
265            .collect::<Result<Vec<_>, _>>()?;
266        // Handle trailing optional fields
267        for ty in types {
268            if ty.kind() == WasmTypeKind::Option {
269                values.push(V::make_option(ty, None).map_err(|err| self.wasm_value_error(err))?);
270            } else {
271                return Err(ParserError::with_detail(
272                    ParserErrorKind::InvalidParams,
273                    self.span.end - 1..self.span.end,
274                    "missing required param(s)",
275                ));
276            }
277        }
278        Ok(values)
279    }
280
281    fn to_wasm_list<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
282        let element_type = ty.list_element_type().unwrap();
283        let elements = self
284            .as_list()?
285            .map(|node| node.to_wasm_value(&element_type, src))
286            .collect::<Result<Vec<_>, _>>()?;
287        V::make_list(ty, elements).map_err(|err| self.wasm_value_error(err))
288    }
289
290    fn to_wasm_record<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
291        let values = self.as_record(src)?.collect::<BTreeMap<_, _>>();
292        let record_fields = ty.record_fields().collect::<Vec<_>>();
293        let fields = record_fields
294            .iter()
295            .map(|(name, field_type)| {
296                let value = match (values.get(name.as_ref()), field_type.kind()) {
297                    (Some(node), _) => node.to_wasm_value(field_type, src)?,
298                    (None, WasmTypeKind::Option) => V::make_option(field_type, None)
299                        .map_err(|err| self.wasm_value_error(err))?,
300                    _ => {
301                        return Err(
302                            self.wasm_value_error(WasmValueError::MissingField(name.to_string()))
303                        );
304                    }
305                };
306                Ok((name.as_ref(), value))
307            })
308            .collect::<Result<Vec<_>, _>>()?;
309        V::make_record(ty, fields).map_err(|err| self.wasm_value_error(err))
310    }
311
312    fn to_wasm_tuple<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
313        let types = ty.tuple_element_types().collect::<Vec<_>>();
314        let values = self.as_tuple()?;
315        if types.len() != values.len() {
316            return Err(
317                self.wasm_value_error(WasmValueError::WrongNumberOfTupleValues {
318                    want: types.len(),
319                    got: values.len(),
320                }),
321            );
322        }
323        let values = ty
324            .tuple_element_types()
325            .zip(self.as_tuple()?)
326            .map(|(ty, node)| node.to_wasm_value(&ty, src))
327            .collect::<Result<Vec<_>, _>>()?;
328        V::make_tuple(ty, values).map_err(|err| self.wasm_value_error(err))
329    }
330
331    fn to_wasm_variant<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
332        let (label, payload) = self.as_variant(src)?;
333        let payload_type = ty
334            .variant_cases()
335            .find_map(|(case, payload)| (case == label).then_some(payload))
336            .ok_or_else(|| self.wasm_value_error(WasmValueError::UnknownCase(label.into())))?;
337        let payload_value = self.to_wasm_maybe_payload(label, &payload_type, payload, src)?;
338        V::make_variant(ty, label, payload_value).map_err(|err| self.wasm_value_error(err))
339    }
340
341    fn to_wasm_enum<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
342        V::make_enum(ty, self.as_enum(src)?).map_err(|err| self.wasm_value_error(err))
343    }
344
345    fn to_wasm_option<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
346        let payload_type = ty.option_some_type().unwrap();
347        let value = match self.ty {
348            NodeType::OptionSome => {
349                self.to_wasm_maybe_payload("some", &Some(payload_type), self.as_option()?, src)?
350            }
351            NodeType::OptionNone => {
352                self.to_wasm_maybe_payload("none", &None, self.as_option()?, src)?
353            }
354            _ if flattenable(payload_type.kind()) => Some(self.to_wasm_value(&payload_type, src)?),
355            _ => {
356                return Err(self.error(ParserErrorKind::InvalidType));
357            }
358        };
359        V::make_option(ty, value).map_err(|err| self.wasm_value_error(err))
360    }
361
362    fn to_wasm_result<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
363        let (ok_type, err_type) = ty.result_types().unwrap();
364        let value = match self.ty {
365            NodeType::ResultOk => {
366                Ok(self.to_wasm_maybe_payload("ok", &ok_type, self.as_result()?.unwrap(), src)?)
367            }
368            NodeType::ResultErr => Err(self.to_wasm_maybe_payload(
369                "err",
370                &err_type,
371                self.as_result()?.unwrap_err(),
372                src,
373            )?),
374            _ => match ok_type {
375                Some(ty) if flattenable(ty.kind()) => Ok(Some(self.to_wasm_value(&ty, src)?)),
376                _ => return Err(self.error(ParserErrorKind::InvalidType)),
377            },
378        };
379        V::make_result(ty, value).map_err(|err| self.wasm_value_error(err))
380    }
381
382    fn to_wasm_flags<V: WasmValue>(&self, ty: &V::Type, src: &str) -> Result<V, ParserError> {
383        V::make_flags(ty, self.as_flags(src)?).map_err(|err| self.wasm_value_error(err))
384    }
385
386    fn to_wasm_maybe_payload<V: WasmValue>(
387        &self,
388        case: &str,
389        payload_type: &Option<V::Type>,
390        payload: Option<&Node>,
391        src: &str,
392    ) -> Result<Option<V>, ParserError> {
393        match (payload_type.as_ref(), payload) {
394            (Some(ty), Some(node)) => Ok(Some(node.to_wasm_value(ty, src)?)),
395            (None, None) => Ok(None),
396            (Some(_), None) => {
397                Err(self.wasm_value_error(WasmValueError::MissingPayload(case.into())))
398            }
399            (None, Some(_)) => {
400                Err(self.wasm_value_error(WasmValueError::UnexpectedPayload(case.into())))
401            }
402        }
403    }
404
405    fn wasm_value_error(&self, err: WasmValueError) -> ParserError {
406        ParserError::with_source(ParserErrorKind::WasmValueError, self.span(), err)
407    }
408
409    pub(crate) fn slice<'src>(&self, src: &'src str) -> &'src str {
410        &src[self.span()]
411    }
412
413    fn ensure_type(&self, ty: NodeType) -> Result<(), ParserError> {
414        if self.ty == ty {
415            Ok(())
416        } else {
417            Err(self.error(ParserErrorKind::InvalidType))
418        }
419    }
420
421    fn error(&self, kind: ParserErrorKind) -> ParserError {
422        ParserError::new(kind, self.span())
423    }
424}
425
426fn flattenable(kind: WasmTypeKind) -> bool {
427    // TODO: Consider wither to allow flattening an option in a result or vice-versa.
428    !matches!(kind, WasmTypeKind::Option | WasmTypeKind::Result)
429}
430
/// The type of a WAVE AST [`Node`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NodeType {
    /// Boolean `true`
    BoolTrue,
    /// Boolean `false`
    BoolFalse,
    /// Number
    ///
    /// May be an integer or float, including `nan`, `inf`, `-inf`
    Number,
    /// Char
    ///
    /// Span includes delimiters.
    Char,
    /// String
    ///
    /// Span includes delimiters.
    String,
    /// Multiline String
    ///
    /// Span includes delimiters.
    MultilineString,
    /// Tuple
    ///
    /// Child nodes are the tuple values.
    Tuple,
    /// List
    ///
    /// Child nodes are the list values.
    List,
    /// Record
    ///
    /// Child nodes are field Label, value pairs, e.g.
    /// `[<field 1>, <value 1>, <field 2>, <value 2>, ...]`
    Record,
    /// Label
    ///
    /// In value position may represent an enum value or variant case (without payload).
    Label,
    /// Variant case with payload
    ///
    /// Child nodes are variant case Label and payload value.
    VariantWithPayload,
    /// Option `some`
    ///
    /// Child node is the payload value.
    OptionSome,
    /// Option `none`
    OptionNone,
    /// Result `ok`
    ///
    /// Has zero or one child node for the payload value.
    ResultOk,
    /// Result `err`
    ///
    /// Has zero or one child node for the payload value.
    ResultErr,
    /// Flags
    ///
    /// Child nodes are flag Labels.
    Flags,
}