wasefire_wire/
schema.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Provides a schema of the wire format.
16//!
17//! This is only meant to be used during testing, to ensure a wire format is monotonic.
18
19use alloc::boxed::Box;
20use alloc::vec;
21use alloc::vec::Vec;
22use core::any::TypeId;
23use core::fmt::Display;
24use std::sync::Mutex;
25
26use wasefire_error::{Code, Error};
27
28use crate::internal::{Builtin, Rule, RuleEnum, RuleStruct, Rules, Wire};
29use crate::reader::Reader;
30use crate::{helper, internal};
31
/// Structural description of a wire format.
///
/// Built from a type with [`View::new`], rendered with `Display`, normalized with
/// [`View::simplify`], and checked against serialized bytes with [`View::validate`].
// NOTE(review): this enum itself derives `Wire`; the derive presumably assigns tags from the
// variant order, so do not reorder variants without checking.
#[derive(Debug, Clone, PartialEq, Eq, wasefire_wire_derive::Wire)]
#[wire(crate = crate)]
pub enum View<'a> {
    /// A builtin type (integers, `bool`, `str`).
    Builtin(Builtin),
    /// A fixed-size array of the given length.
    Array(Box<View<'a>>, usize),
    /// A slice, whose length is not fixed by the view.
    Slice(Box<View<'a>>),
    /// A sequence of fields, optionally named.
    Struct(ViewStruct<'a>),
    /// A tagged union. The empty enum `{}` stands for an uninhabited type.
    Enum(ViewEnum<'a>),
    /// A reference to the enclosing recursion binder with the given number.
    RecUse(usize),
    /// A recursion binder with the given number, whose body may refer to it with `RecUse`.
    RecNew(usize, Box<View<'a>>),
}

/// Fields of a struct view: optional field name (`None` for unnamed fields) and field view.
pub type ViewStruct<'a> = Vec<(Option<&'a str>, View<'a>)>;
/// Variants of an enum view: variant name, wire tag, and variant fields.
pub type ViewEnum<'a> = Vec<(&'a str, u32, ViewStruct<'a>)>;
46
47impl View<'static> {
48    pub fn new<'a, T: Wire<'a>>() -> View<'static> {
49        let mut rules = Rules::default();
50        T::schema(&mut rules);
51        Traverse::new(&rules).extract_or_empty(TypeId::of::<T::Type<'static>>())
52    }
53}
54
55struct Traverse<'a> {
56    rules: &'a Rules,
57    next: usize,
58    path: Vec<(TypeId, Option<usize>)>,
59}
60
61impl<'a> Traverse<'a> {
62    fn new(rules: &'a Rules) -> Self {
63        Traverse { rules, next: 0, path: Vec::new() }
64    }
65
66    fn extract_or_empty(&mut self, id: TypeId) -> View<'static> {
67        match self.extract(id) {
68            Some(x) => x,
69            None => View::Enum(Vec::new()),
70        }
71    }
72
73    fn extract(&mut self, id: TypeId) -> Option<View<'static>> {
74        if let Some((_, rec)) = self.path.iter_mut().find(|(x, _)| *x == id) {
75            let rec = rec.get_or_insert_with(|| {
76                self.next += 1;
77                self.next
78            });
79            return Some(View::RecUse(*rec));
80        }
81        self.path.push((id, None));
82        let result: Option<_> = try {
83            match self.rules.get(id) {
84                Rule::Builtin(x) => View::Builtin(*x),
85                Rule::Array(x, n) => View::Array(Box::new(self.extract(*x)?), *n),
86                Rule::Slice(x) => View::Slice(Box::new(self.extract_or_empty(*x))),
87                Rule::Struct(xs) => View::Struct(self.extract_struct(xs)?),
88                Rule::Enum(xs) => View::Enum(self.extract_enum(xs)),
89            }
90        };
91        let (id_, rec) = self.path.pop().unwrap();
92        assert_eq!(id_, id);
93        let result = result?;
94        Some(match rec {
95            Some(rec) => View::RecNew(rec, Box::new(result)),
96            None => result,
97        })
98    }
99
100    fn extract_struct(&mut self, xs: &RuleStruct) -> Option<ViewStruct<'static>> {
101        xs.iter()
102            .map(|(n, x)| Some((*n, self.extract(*x)?)))
103            .filter(|x| !matches!(x, Some((None, View::Struct(xs))) if xs.is_empty()))
104            .collect()
105    }
106
107    fn extract_enum(&mut self, xs: &RuleEnum) -> ViewEnum<'static> {
108        xs.iter().filter_map(|(n, i, xs)| Some((*n, *i, self.extract_struct(xs)?))).collect()
109    }
110}
111
112impl core::fmt::Display for Builtin {
113    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
114        match self {
115            Builtin::Bool => write!(f, "bool"),
116            Builtin::U8 => write!(f, "u8"),
117            Builtin::I8 => write!(f, "i8"),
118            Builtin::U16 => write!(f, "u16"),
119            Builtin::I16 => write!(f, "i16"),
120            Builtin::U32 => write!(f, "u32"),
121            Builtin::I32 => write!(f, "i32"),
122            Builtin::U64 => write!(f, "u64"),
123            Builtin::I64 => write!(f, "i64"),
124            Builtin::Usize => write!(f, "usize"),
125            Builtin::Isize => write!(f, "isize"),
126            Builtin::Str => write!(f, "str"),
127        }
128    }
129}
130
131impl core::fmt::Display for View<'_> {
132    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
133        match self {
134            View::Builtin(x) => write!(f, "{x}"),
135            View::Array(x, n) => write!(f, "[{x}; {n}]"),
136            View::Slice(x) => write!(f, "[{x}]"),
137            View::Struct(xs) => write_fields(f, xs),
138            View::Enum(xs) => write_list(f, xs),
139            View::RecUse(n) => write!(f, "<{n}>"),
140            View::RecNew(n, x) => write!(f, "<{n}>:{x}"),
141        }
142    }
143}
144
145#[derive(Debug, Copy, Clone, PartialEq, Eq)]
146pub struct ViewFields<'a>(pub &'a ViewStruct<'a>);
147
148impl core::fmt::Display for ViewFields<'_> {
149    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
150        write_fields(f, self.0)
151    }
152}
153
154trait List {
155    const BEG: char;
156    const END: char;
157    fn fmt_name(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result;
158    fn fmt_item(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result;
159}
160
161impl<'a> List for (Option<&'a str>, View<'a>) {
162    const BEG: char = '(';
163    const END: char = ')';
164    fn fmt_name(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
165        match self.0 {
166            Some(x) => write!(f, "{x}:"),
167            None => Ok(()),
168        }
169    }
170    fn fmt_item(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
171        self.1.fmt(f)
172    }
173}
174
175impl<'a> List for (&'a str, u32, ViewStruct<'a>) {
176    const BEG: char = '{';
177    const END: char = '}';
178    fn fmt_name(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
179        match self.0 {
180            "" => write!(f, "{}:", self.1),
181            _ => write!(f, "{}={}:", self.0, self.1),
182        }
183    }
184    fn fmt_item(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
185        write_fields(f, &self.2)
186    }
187}
188
189fn write_fields(f: &mut core::fmt::Formatter, xs: &[(Option<&str>, View)]) -> core::fmt::Result {
190    if xs.len() == 1 && xs[0].0.is_none() { xs[0].1.fmt(f) } else { write_list(f, xs) }
191}
192
193fn write_list<T: List>(f: &mut core::fmt::Formatter, xs: &[T]) -> core::fmt::Result {
194    write!(f, "{}", T::BEG)?;
195    let mut first = true;
196    for x in xs.iter() {
197        if !first {
198            write!(f, " ")?;
199        }
200        first = false;
201        x.fmt_name(f)?;
202        x.fmt_item(f)?;
203    }
204    write!(f, "{}", T::END)
205}
206
207impl View<'_> {
208    /// Simplifies a view preserving wire compatibility.
209    ///
210    /// Performs the following simplifications:
211    /// - Remove field names and use empty names for variants
212    /// - `[x; 0]` becomes `()`
213    /// - `[x; 1]` becomes `x`
214    /// - `[[x; n]; m]` becomes `[x; n * m]`
215    /// - `[{}; 2+]` becomes `{}`
216    /// - `(xs.. (ys..) zs..)` becomes `(xs.. ys.. zs..)`
217    /// - `(xs.. {} zs..)` becomes `{}`
218    /// - `(xs.. ys.. zs..)` becomes `(xs.. [y; n] zs..)` if `ys..` is `n` times `y`
219    /// - `(x)` becomes `x`
220    /// - `{xs.. =t:{} zs..}` becomes `{xs.. zs..}`
221    /// - `{xs..}` becomes `{ys..}` where `ys..` is `xs..` sorted by tags
222    /// - `<n>:x` (resp `<n>`) becomes `<k>:x` (resp. `<k>`) with `k` the number of recursion
223    ///   binders from the root
224    pub fn simplify(&self) -> View<'static> {
225        self.simplify_(RecStack::Root)
226    }
227
228    pub fn simplify_struct(xs: &ViewStruct) -> View<'static> {
229        View::simplify_struct_(xs, RecStack::Root)
230    }
231
232    fn simplify_(&self, rec: RecStack) -> View<'static> {
233        match self {
234            View::Builtin(x) => View::Builtin(*x),
235            View::Array(_, 0) => View::Struct(Vec::new()),
236            View::Array(x, 1) => x.simplify_(rec),
237            View::Array(x, n) => match x.simplify_(rec) {
238                View::Array(x, m) => View::Array(x, n * m),
239                View::Enum(xs) if xs.is_empty() => View::Enum(xs),
240                x => View::Array(Box::new(x), *n),
241            },
242            View::Slice(x) => View::Slice(Box::new(x.simplify_(rec))),
243            View::Struct(xs) => View::simplify_struct_(xs, rec),
244            View::Enum(xs) => {
245                let mut ys = Vec::new();
246                for (_, t, xs) in xs {
247                    let xs = match View::simplify_struct_(xs, rec) {
248                        View::Struct(xs) => xs,
249                        View::Enum(xs) if xs.is_empty() => continue,
250                        x => vec![(None, x)],
251                    };
252                    ys.push(("", *t, xs));
253                }
254                ys.sort_by_key(|(_, t, _)| *t);
255                View::Enum(ys)
256            }
257            View::RecUse(n) => View::RecUse(rec.use_(*n)),
258            View::RecNew(n, x) => View::RecNew(rec.len(), Box::new(x.simplify_(rec.new(*n)))),
259        }
260    }
261
262    fn simplify_struct_(xs: &ViewStruct, rec: RecStack) -> View<'static> {
263        let mut ys = Vec::new();
264        for (_, x) in xs {
265            match x.simplify_(rec) {
266                View::Struct(mut xs) => ys.append(&mut xs),
267                View::Enum(xs) if xs.is_empty() => return View::Enum(xs),
268                y => ys.push((None, y)),
269            }
270        }
271        let mut zs = Vec::new();
272        for (_, y) in ys {
273            let z = match zs.last_mut() {
274                Some((_, z)) => z,
275                None => {
276                    zs.push((None, y));
277                    continue;
278                }
279            };
280            match (z, y) {
281                (View::Array(x, n), View::Array(y, m)) if *x == y => *n += m,
282                (View::Array(x, n), y) if **x == y => *n += 1,
283                (x, View::Array(y, m)) if *x == *y => *x = View::Array(y, m + 1),
284                (x, y) if *x == y => *x = View::Array(Box::new(y), 2),
285                (_, y) => zs.push((None, y)),
286            }
287        }
288        match zs.len() {
289            1 => zs.pop().unwrap().1,
290            _ => View::Struct(zs),
291        }
292    }
293
294    /// Validates that serialized data matches the view.
295    ///
296    /// This function requires a global lock.
297    pub fn validate(&self, data: &[u8]) -> Result<(), Error> {
298        static GLOBAL_LOCK: Mutex<()> = Mutex::new(());
299        let _global_lock = GLOBAL_LOCK.lock().unwrap();
300        let _lock = ViewFrameLock::new(None, self);
301        let _ = crate::decode::<ViewDecoder>(data)?;
302        Ok(())
303    }
304
305    fn decode(&self, reader: &mut Reader) -> Result<(), Error> {
306        match self {
307            View::Builtin(Builtin::Bool) => drop(bool::decode(reader)?),
308            View::Builtin(Builtin::U8) => drop(u8::decode(reader)?),
309            View::Builtin(Builtin::I8) => drop(i8::decode(reader)?),
310            View::Builtin(Builtin::U16) => drop(u16::decode(reader)?),
311            View::Builtin(Builtin::I16) => drop(i16::decode(reader)?),
312            View::Builtin(Builtin::U32) => drop(u32::decode(reader)?),
313            View::Builtin(Builtin::I32) => drop(i32::decode(reader)?),
314            View::Builtin(Builtin::U64) => drop(u64::decode(reader)?),
315            View::Builtin(Builtin::I64) => drop(i64::decode(reader)?),
316            View::Builtin(Builtin::Usize) => drop(usize::decode(reader)?),
317            View::Builtin(Builtin::Isize) => drop(isize::decode(reader)?),
318            View::Builtin(Builtin::Str) => drop(<&str>::decode(reader)?),
319            View::Array(x, n) => {
320                let _lock = ViewFrameLock::new(None, x);
321                let _ = helper::decode_array_dyn(*n, reader, decode_view)?;
322            }
323            View::Slice(x) => {
324                let _lock = ViewFrameLock::new(None, x);
325                let _ = helper::decode_slice(reader, decode_view)?;
326            }
327            View::Struct(xs) => {
328                for (_, x) in xs {
329                    x.decode(reader)?;
330                }
331            }
332            View::Enum(xs) => {
333                let tag = internal::decode_tag(reader)?;
334                let mut found = false;
335                for (_, i, xs) in xs {
336                    if tag == *i {
337                        assert!(!std::mem::replace(&mut found, true));
338                        for (_, x) in xs {
339                            x.decode(reader)?;
340                        }
341                    }
342                }
343                if !found {
344                    return Err(Error::user(Code::InvalidArgument));
345                }
346            }
347            View::RecUse(rec) => {
348                let view = VIEW_STACK
349                    .lock()
350                    .unwrap()
351                    .iter()
352                    .find(|x| x.rec == Some(*rec))
353                    .ok_or(Error::user(Code::InvalidArgument))?
354                    .view;
355                view.decode(reader)?;
356            }
357            View::RecNew(rec, x) => {
358                let _lock = ViewFrameLock::new(Some(*rec), x);
359                x.decode(reader)?;
360            }
361        }
362        Ok(())
363    }
364}
365
/// Frames of the views currently being decoded by [`View::validate`].
///
/// `decode_view` decodes with the view of the top frame, and `RecUse` is resolved by searching
/// the frames carrying a binder number.
static VIEW_STACK: Mutex<Vec<ViewFrame>> = Mutex::new(Vec::new());

/// A view being decoded, with its binder number when it is the body of a `RecNew`.
#[derive(Copy, Clone)]
struct ViewFrame {
    rec: Option<usize>,
    // The real lifetime is erased by `ViewFrameLock::new`: this reference is only valid while
    // the guard that pushed this frame is alive.
    view: &'static View<'static>,
}

/// Guard that keeps a frame on `VIEW_STACK`: pushed on creation, popped on drop.
struct ViewFrameLock(ViewFrame);

impl ViewFrame {
    /// Identity of the frame (binder number and view address), used to check that guards pop
    /// their own frame.
    fn key(self) -> (Option<usize>, *const View<'static>) {
        (self.rec, self.view as *const _)
    }
}

impl ViewFrameLock {
    /// Pushes a frame for `view` with binder number `rec` and returns the guard popping it.
    ///
    /// NOTE(review): this safe function erases the lifetime of `view`. Soundness relies on every
    /// caller binding the guard to a local (`let _lock = ...`) that is dropped before `view`
    /// goes out of scope, and never leaking it (e.g. with `core::mem::forget`).
    fn new(rec: Option<usize>, view: &View) -> Self {
        // TODO(https://github.com/rust-lang/rust-clippy/issues/12860): Remove when fixed.
        #[expect(clippy::unnecessary_cast)]
        // SAFETY: The returned guard removes this frame from `VIEW_STACK` when dropped, and every
        // caller drops the guard before `view` goes out of scope, so the `'static` reference is
        // never used after the real lifetime ends.
        let view = unsafe { &*(view as *const _ as *const View<'static>) };
        let frame = ViewFrame { rec, view };
        VIEW_STACK.lock().unwrap().push(frame);
        ViewFrameLock(frame)
    }
}

impl Drop for ViewFrameLock {
    /// Pops the top frame and asserts it is the one pushed by this guard (guards must be dropped
    /// in reverse creation order).
    fn drop(&mut self) {
        assert_eq!(VIEW_STACK.lock().unwrap().pop().unwrap().key(), self.0.key());
    }
}
399
400struct ViewDecoder;
401
402impl<'a> internal::Wire<'a> for ViewDecoder {
403    type Type<'b> = ViewDecoder;
404    fn schema(_rules: &mut Rules) {
405        unreachable!()
406    }
407    fn encode(&self, _: &mut internal::Writer<'a>) -> internal::Result<()> {
408        unreachable!()
409    }
410    fn decode(reader: &mut Reader<'a>) -> internal::Result<Self> {
411        decode_view(reader)?;
412        Ok(ViewDecoder)
413    }
414}
415
416fn decode_view(reader: &mut Reader) -> Result<(), Error> {
417    let view = VIEW_STACK.lock().unwrap().last().unwrap().view;
418    view.decode(reader)
419}
420
/// Recursion binders enclosing the current position, innermost first.
///
/// Used by [`View::simplify`] to renumber binders by their depth from the root.
#[derive(Copy, Clone)]
enum RecStack<'a> {
    /// No enclosing binder.
    Root,
    /// A binder with its original number, followed by the binders enclosing it.
    Binder(usize, &'a RecStack<'a>),
}

impl<'a> RecStack<'a> {
    /// Returns the new number of the innermost binder originally numbered `x`, namely the number
    /// of binders enclosing it.
    ///
    /// Panics if no binder is numbered `x`.
    fn use_(&self, x: usize) -> usize {
        let mut cur = self;
        loop {
            match *cur {
                RecStack::Root => unreachable!(),
                RecStack::Binder(y, outer) if y == x => return outer.len(),
                RecStack::Binder(_, outer) => cur = outer,
            }
        }
    }

    /// Returns the number of binders on the stack.
    fn len(&self) -> usize {
        let mut count = 0;
        let mut cur = self;
        while let RecStack::Binder(_, outer) = *cur {
            count += 1;
            cur = outer;
        }
        count
    }

    /// Returns this stack with a binder originally numbered `x` pushed on top.
    // Named `new` because it introduces a new binder, even though it borrows `self`.
    #[allow(clippy::wrong_self_convention)]
    fn new(&'a self, x: usize) -> RecStack<'a> {
        RecStack::Binder(x, self)
    }
}