Skip to main content

wasefire_wire/
schema.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Provides a schema of the wire format.
16//!
17//! This is only meant to be used during testing, to ensure a wire format is monotonic.
18
19use alloc::boxed::Box;
20use alloc::vec;
21use alloc::vec::Vec;
22use core::any::TypeId;
23use core::fmt::Display;
24use std::sync::Mutex;
25
26use wasefire_error::{Code, Error};
27
28use crate::internal::{Builtin, Rule, RuleEnum, RuleStruct, Rules, Wire};
29use crate::reader::Reader;
30use crate::{helper, internal};
31
32#[derive(Debug, Clone, PartialEq, Eq, wasefire_wire_derive::Wire)]
33#[wire(crate = crate)]
34pub enum View<'a> {
35    Builtin(Builtin),
36    Array(Box<View<'a>>, usize),
37    Slice(Box<View<'a>>),
38    Struct(ViewStruct<'a>),
39    Enum(ViewEnum<'a>),
40    RecUse(usize),
41    RecNew(usize, Box<View<'a>>),
42}
43
44pub type ViewStruct<'a> = Vec<(Option<&'a str>, View<'a>)>;
45pub type ViewEnum<'a> = Vec<(&'a str, u32, ViewStruct<'a>)>;
46
47impl View<'static> {
48    pub fn new<'a, T: Wire<'a>>() -> View<'static> {
49        let mut rules = Rules::default();
50        T::schema(&mut rules);
51        Traverse::new(&rules).extract_or_empty(TypeId::of::<T::Type<'static>>())
52    }
53}
54
55struct Traverse<'a> {
56    rules: &'a Rules,
57    next: usize,
58    path: Vec<(TypeId, Option<usize>)>,
59}
60
61impl<'a> Traverse<'a> {
62    fn new(rules: &'a Rules) -> Self {
63        Traverse { rules, next: 0, path: Vec::new() }
64    }
65
66    fn extract_or_empty(&mut self, id: TypeId) -> View<'static> {
67        match self.extract(id) {
68            Some(x) => x,
69            None => View::Enum(Vec::new()),
70        }
71    }
72
73    fn extract(&mut self, id: TypeId) -> Option<View<'static>> {
74        if let Some((_, rec)) = self.path.iter_mut().find(|(x, _)| *x == id) {
75            let rec = rec.get_or_insert_with(|| {
76                self.next += 1;
77                self.next
78            });
79            return Some(View::RecUse(*rec));
80        }
81        self.path.push((id, None));
82        let result: Option<_> = try {
83            match self.rules.get(id) {
84                Rule::Builtin(x) => View::Builtin(*x),
85                Rule::Alias(_) => unreachable!(),
86                Rule::Array(x, n) => View::Array(Box::new(self.extract(*x)?), *n),
87                Rule::Slice(x) => View::Slice(Box::new(self.extract_or_empty(*x))),
88                Rule::Struct(xs) => View::Struct(self.extract_struct(xs)?),
89                Rule::Enum(xs) => View::Enum(self.extract_enum(xs)),
90            }
91        };
92        let (id_, rec) = self.path.pop().unwrap();
93        assert_eq!(id_, id);
94        let result = result?;
95        Some(match rec {
96            Some(rec) => View::RecNew(rec, Box::new(result)),
97            None => result,
98        })
99    }
100
101    fn extract_struct(&mut self, xs: &RuleStruct) -> Option<ViewStruct<'static>> {
102        xs.iter()
103            .map(|(n, x)| Some((*n, self.extract(*x)?)))
104            .filter(|x| !matches!(x, Some((None, View::Struct(xs))) if xs.is_empty()))
105            .collect()
106    }
107
108    fn extract_enum(&mut self, xs: &RuleEnum) -> ViewEnum<'static> {
109        xs.iter().filter_map(|(n, i, xs)| Some((*n, *i, self.extract_struct(xs)?))).collect()
110    }
111}
112
113impl core::fmt::Display for Builtin {
114    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
115        match self {
116            Builtin::Bool => write!(f, "bool"),
117            Builtin::U8 => write!(f, "u8"),
118            Builtin::I8 => write!(f, "i8"),
119            Builtin::U16 => write!(f, "u16"),
120            Builtin::I16 => write!(f, "i16"),
121            Builtin::U32 => write!(f, "u32"),
122            Builtin::I32 => write!(f, "i32"),
123            Builtin::U64 => write!(f, "u64"),
124            Builtin::I64 => write!(f, "i64"),
125            Builtin::Usize => write!(f, "usize"),
126            Builtin::Isize => write!(f, "isize"),
127        }
128    }
129}
130
131impl core::fmt::Display for View<'_> {
132    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
133        match self {
134            View::Builtin(x) => write!(f, "{x}"),
135            View::Array(x, n) => write!(f, "[{x}; {n}]"),
136            View::Slice(x) => write!(f, "[{x}]"),
137            View::Struct(xs) => write_fields(f, xs),
138            View::Enum(xs) => write_list(f, xs),
139            View::RecUse(n) => write!(f, "<{n}>"),
140            View::RecNew(n, x) => write!(f, "<{n}>:{x}"),
141        }
142    }
143}
144
145#[derive(Debug, Copy, Clone, PartialEq, Eq)]
146pub struct ViewFields<'a>(pub &'a ViewStruct<'a>);
147
148impl core::fmt::Display for ViewFields<'_> {
149    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
150        write_fields(f, self.0)
151    }
152}
153
/// An element of a delimited, space-separated list, written as its name then its item.
trait List {
    /// Opening delimiter of the list.
    const BEG: char;
    /// Closing delimiter of the list.
    const END: char;
    /// Writes the element name with its separator, or nothing if the element is unnamed.
    fn fmt_name(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result;
    /// Writes the element item.
    fn fmt_item(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result;
}
160
161impl<'a> List for (Option<&'a str>, View<'a>) {
162    const BEG: char = '(';
163    const END: char = ')';
164    fn fmt_name(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
165        match self.0 {
166            Some(x) => write!(f, "{x}:"),
167            None => Ok(()),
168        }
169    }
170    fn fmt_item(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
171        self.1.fmt(f)
172    }
173}
174
175impl<'a> List for (&'a str, u32, ViewStruct<'a>) {
176    const BEG: char = '{';
177    const END: char = '}';
178    fn fmt_name(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
179        match self.0 {
180            "" => write!(f, "{}:", self.1),
181            _ => write!(f, "{}={}:", self.0, self.1),
182        }
183    }
184    fn fmt_item(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
185        write_fields(f, &self.2)
186    }
187}
188
189fn write_fields(f: &mut core::fmt::Formatter, xs: &[(Option<&str>, View)]) -> core::fmt::Result {
190    if xs.len() == 1 && xs[0].0.is_none() { xs[0].1.fmt(f) } else { write_list(f, xs) }
191}
192
193fn write_list<T: List>(f: &mut core::fmt::Formatter, xs: &[T]) -> core::fmt::Result {
194    write!(f, "{}", T::BEG)?;
195    let mut first = true;
196    for x in xs.iter() {
197        if !first {
198            write!(f, " ")?;
199        }
200        first = false;
201        x.fmt_name(f)?;
202        x.fmt_item(f)?;
203    }
204    write!(f, "{}", T::END)
205}
206
207impl View<'_> {
208    /// Simplifies a view preserving wire compatibility.
209    ///
210    /// Performs the following simplifications:
211    /// - Remove field names and use empty names for variants
212    /// - `[x; 0]` becomes `()`
213    /// - `[x; 1]` becomes `x`
214    /// - `[[x; n]; m]` becomes `[x; n * m]`
215    /// - `[{}; 2+]` becomes `{}`
216    /// - `(xs.. (ys..) zs..)` becomes `(xs.. ys.. zs..)`
217    /// - `(xs.. {} zs..)` becomes `{}`
218    /// - `(xs.. ys.. zs..)` becomes `(xs.. [y; n] zs..)` if `ys..` is `n` times `y`
219    /// - `(x)` becomes `x`
220    /// - `{xs.. =t:{} zs..}` becomes `{xs.. zs..}`
221    /// - `{xs..}` becomes `{ys..}` where `ys..` is `xs..` sorted by tags
222    /// - `<n>:x` (resp `<n>`) becomes `<k>:x` (resp. `<k>`) with `k` the number of recursion
223    ///   binders from the root
224    pub fn simplify(&self) -> View<'static> {
225        self.simplify_(RecStack::Root)
226    }
227
228    pub fn simplify_struct(xs: &ViewStruct) -> View<'static> {
229        View::simplify_struct_(xs, RecStack::Root)
230    }
231
232    fn simplify_(&self, rec: RecStack) -> View<'static> {
233        match self {
234            View::Builtin(x) => View::Builtin(*x),
235            View::Array(_, 0) => View::Struct(Vec::new()),
236            View::Array(x, 1) => x.simplify_(rec),
237            View::Array(x, n) => match x.simplify_(rec) {
238                View::Array(x, m) => View::Array(x, n * m),
239                View::Enum(xs) if xs.is_empty() => View::Enum(xs),
240                x => View::Array(Box::new(x), *n),
241            },
242            View::Slice(x) => View::Slice(Box::new(x.simplify_(rec))),
243            View::Struct(xs) => View::simplify_struct_(xs, rec),
244            View::Enum(xs) => {
245                let mut ys = Vec::new();
246                for (_, t, xs) in xs {
247                    let xs = match View::simplify_struct_(xs, rec) {
248                        View::Struct(xs) => xs,
249                        View::Enum(xs) if xs.is_empty() => continue,
250                        x => vec![(None, x)],
251                    };
252                    ys.push(("", *t, xs));
253                }
254                ys.sort_by_key(|(_, t, _)| *t);
255                View::Enum(ys)
256            }
257            View::RecUse(n) => View::RecUse(rec.use_(*n)),
258            View::RecNew(n, x) => View::RecNew(rec.len(), Box::new(x.simplify_(rec.new(*n)))),
259        }
260    }
261
262    fn simplify_struct_(xs: &ViewStruct, rec: RecStack) -> View<'static> {
263        let mut ys = Vec::new();
264        for (_, x) in xs {
265            match x.simplify_(rec) {
266                View::Struct(mut xs) => ys.append(&mut xs),
267                View::Enum(xs) if xs.is_empty() => return View::Enum(xs),
268                y => ys.push((None, y)),
269            }
270        }
271        let mut zs = Vec::new();
272        for (_, y) in ys {
273            let z = match zs.last_mut() {
274                Some((_, z)) => z,
275                None => {
276                    zs.push((None, y));
277                    continue;
278                }
279            };
280            match (z, y) {
281                (View::Array(x, n), View::Array(y, m)) if *x == y => *n += m,
282                (View::Array(x, n), y) if **x == y => *n += 1,
283                (x, View::Array(y, m)) if *x == *y => *x = View::Array(y, m + 1),
284                (x, y) if *x == y => *x = View::Array(Box::new(y), 2),
285                (_, y) => zs.push((None, y)),
286            }
287        }
288        match zs.len() {
289            1 => zs.pop().unwrap().1,
290            _ => View::Struct(zs),
291        }
292    }
293
294    /// Validates that serialized data matches the view.
295    ///
296    /// This function requires a global lock.
297    pub fn validate(&self, data: &[u8]) -> Result<(), Error> {
298        static GLOBAL_LOCK: Mutex<()> = Mutex::new(());
299        let _global_lock = GLOBAL_LOCK.lock().unwrap();
300        let _lock = ViewFrameLock::new(None, self);
301        let _ = crate::decode::<ViewDecoder>(data)?;
302        Ok(())
303    }
304
305    fn decode(&self, reader: &mut Reader) -> Result<(), Error> {
306        match self {
307            View::Builtin(Builtin::Bool) => drop(bool::decode(reader)?),
308            View::Builtin(Builtin::U8) => drop(u8::decode(reader)?),
309            View::Builtin(Builtin::I8) => drop(i8::decode(reader)?),
310            View::Builtin(Builtin::U16) => drop(u16::decode(reader)?),
311            View::Builtin(Builtin::I16) => drop(i16::decode(reader)?),
312            View::Builtin(Builtin::U32) => drop(u32::decode(reader)?),
313            View::Builtin(Builtin::I32) => drop(i32::decode(reader)?),
314            View::Builtin(Builtin::U64) => drop(u64::decode(reader)?),
315            View::Builtin(Builtin::I64) => drop(i64::decode(reader)?),
316            View::Builtin(Builtin::Usize) => drop(usize::decode(reader)?),
317            View::Builtin(Builtin::Isize) => drop(isize::decode(reader)?),
318            View::Array(x, n) => {
319                let _lock = ViewFrameLock::new(None, x);
320                let _ = helper::decode_array_dyn(*n, reader, decode_view)?;
321            }
322            View::Slice(x) => {
323                let _lock = ViewFrameLock::new(None, x);
324                let _ = helper::decode_slice(reader, decode_view)?;
325            }
326            View::Struct(xs) => {
327                for (_, x) in xs {
328                    x.decode(reader)?;
329                }
330            }
331            View::Enum(xs) => {
332                let tag = internal::decode_tag(reader)?;
333                let mut found = false;
334                for (_, i, xs) in xs {
335                    if tag == *i {
336                        assert!(!std::mem::replace(&mut found, true));
337                        for (_, x) in xs {
338                            x.decode(reader)?;
339                        }
340                    }
341                }
342                if !found {
343                    return Err(Error::user(Code::InvalidArgument));
344                }
345            }
346            View::RecUse(rec) => {
347                let view = VIEW_STACK
348                    .lock()
349                    .unwrap()
350                    .iter()
351                    .find(|x| x.rec == Some(*rec))
352                    .ok_or(Error::user(Code::InvalidArgument))?
353                    .view;
354                view.decode(reader)?;
355            }
356            View::RecNew(rec, x) => {
357                let _lock = ViewFrameLock::new(Some(*rec), x);
358                x.decode(reader)?;
359            }
360        }
361        Ok(())
362    }
363}
364
365static VIEW_STACK: Mutex<Vec<ViewFrame>> = Mutex::new(Vec::new());
366
367#[derive(Copy, Clone)]
368struct ViewFrame {
369    rec: Option<usize>,
370    view: &'static View<'static>,
371}
372
373struct ViewFrameLock(ViewFrame);
374
375impl ViewFrame {
376    fn key(self) -> (Option<usize>, *const View<'static>) {
377        (self.rec, self.view as *const _)
378    }
379}
380
381impl ViewFrameLock {
382    fn new(rec: Option<usize>, view: &View) -> Self {
383        // TODO(https://github.com/rust-lang/rust-clippy/issues/12860): Remove when fixed.
384        #[expect(clippy::unnecessary_cast)]
385        // SAFETY: This function is only called when drop is called before the lifetime ends.
386        let view = unsafe { &*(view as *const _ as *const View<'static>) };
387        let frame = ViewFrame { rec, view };
388        VIEW_STACK.lock().unwrap().push(frame);
389        ViewFrameLock(frame)
390    }
391}
392
393impl Drop for ViewFrameLock {
394    fn drop(&mut self) {
395        assert_eq!(VIEW_STACK.lock().unwrap().pop().unwrap().key(), self.0.key());
396    }
397}
398
399struct ViewDecoder;
400
401impl<'a> internal::Wire<'a> for ViewDecoder {
402    type Type<'b> = ViewDecoder;
403    fn schema(_rules: &mut Rules) {
404        unreachable!()
405    }
406    fn encode(&self, _: &mut internal::Writer<'a>) -> internal::Result<()> {
407        unreachable!()
408    }
409    fn decode(reader: &mut Reader<'a>) -> internal::Result<Self> {
410        decode_view(reader)?;
411        Ok(ViewDecoder)
412    }
413}
414
415fn decode_view(reader: &mut Reader) -> Result<(), Error> {
416    let view = VIEW_STACK.lock().unwrap().last().unwrap().view;
417    view.decode(reader)
418}
419
/// Recursion binders in scope while simplifying a view, innermost first.
///
/// Each binder remembers the original recursion number it was introduced with. The stack lives
/// on the call stack: each binder borrows its parent.
#[derive(Copy, Clone)]
enum RecStack<'a> {
    Root,
    Binder(usize, &'a RecStack<'a>),
}

impl<'a> RecStack<'a> {
    /// Returns the new number of the innermost binder introduced for the original number `x`.
    ///
    /// That number is the count of binders below it, matching what `len` returned when the
    /// binder was created.
    ///
    /// # Panics
    ///
    /// Panics if no binder for `x` is in scope.
    fn use_(&self, x: usize) -> usize {
        let mut current = self;
        loop {
            match current {
                RecStack::Root => unreachable!(),
                RecStack::Binder(y, parent) => {
                    if *y == x {
                        return parent.len();
                    }
                    current = *parent;
                }
            }
        }
    }

    /// Returns the number of binders in scope.
    fn len(&self) -> usize {
        let mut count = 0;
        let mut current = self;
        while let RecStack::Binder(_, parent) = current {
            count += 1;
            current = *parent;
        }
        count
    }

    /// Returns this stack extended with a binder for the original number `x`.
    // `new` conventionally takes no `self`; here it extends an existing stack, which clippy's
    // `wrong_self_convention` would otherwise flag.
    #[allow(clippy::wrong_self_convention)]
    fn new(&'a self, x: usize) -> RecStack<'a> {
        RecStack::Binder(x, self)
    }
}