Skip to main content

wasefire_wire/
lib.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Wasefire wire format.
16//!
17//! This crate provides a binary format for a wire used as an RPC from a large host to a small
18//! device. The format is compact and canonical, in particular it is not self-describing.
19//! Compatibility is encoded with tags of a top-level enum, in particular RPC messages are never
20//! changed but instead duplicated to a new variant. The host supports all variants because it is
21//! not constrained. The device only supports the latest versions to minimize binary footprint. The
22//! host and the device are written in Rust, so wire types are defined in Rust. The data model is
23//! simple and contains builtin types, arrays, slices, structs, enums, and supports recursion.
24//!
25//! Alternatives like serde (with postcard) or protocol buffers solve a more general problem than
26//! this use-case. The main differences are:
27//!
28//! - Not self-describing: the model is simpler and more robust (smaller code footprint on device).
29//! - No special cases for options and maps: those are encoded from basic types.
30//! - No need for tagged and optional fields: full messages are versioned.
31//! - Variant tags can be explicit, and thus feature-gated to reduce device code size.
32//! - Wire types are only used to represent wire data, they are not used as regular data types.
33//! - Wire types only borrow from the wire, and do so in a covariant way.
34//! - Wire types can be inspected programmatically for unit testing.
35//! - Users can't implement the wire trait: they can only derive it.
36
37#![no_std]
38#![feature(array_try_from_fn)]
39#![feature(doc_cfg)]
40#![feature(never_type)]
41#![feature(try_blocks)]
42
43extern crate alloc;
44#[cfg(feature = "std")]
45extern crate std;
46
47use alloc::borrow::{Cow, ToOwned};
48use alloc::boxed::Box;
49use alloc::string::String;
50use alloc::vec::Vec;
51#[cfg(feature = "schema")]
52use core::any::TypeId;
53use core::convert::Infallible;
54use core::mem::{ManuallyDrop, MaybeUninit};
55
56use wasefire_common::platform::Side;
57use wasefire_error::{Code, Error};
58pub use wasefire_wire_derive::Wire;
59use wasefire_wire_derive::internal_wire;
60
61#[cfg(feature = "schema")]
62use crate::internal::{Builtin, Rules};
63use crate::reader::Reader;
64use crate::writer::Writer;
65
66mod helper;
67pub mod internal;
68mod reader;
69#[cfg(feature = "schema")]
70pub mod schema;
71mod writer;
72
73pub trait Wire<'a>: internal::Wire<'a> {}
74impl<'a, T: internal::Wire<'a>> Wire<'a> for T {}
75
76pub fn encode_suffix<'a, T: Wire<'a>>(data: &mut Vec<u8>, value: &T) -> Result<(), Error> {
77    let mut writer = Writer::new();
78    value.encode(&mut writer)?;
79    Ok(writer.finalize(data))
80}
81
82pub fn encode<'a, T: Wire<'a>>(value: &T) -> Result<Box<[u8]>, Error> {
83    let mut data = Vec::new();
84    encode_suffix(&mut data, value)?;
85    Ok(data.into_boxed_slice())
86}
87
88pub fn decode_prefix<'a, T: Wire<'a>>(data: &mut &'a [u8]) -> Result<T, Error> {
89    let mut reader = Reader::new(data);
90    let value = T::decode(&mut reader)?;
91    *data = reader.finalize();
92    Ok(value)
93}
94
95pub fn decode<'a, T: Wire<'a>>(mut data: &'a [u8]) -> Result<T, Error> {
96    let value = decode_prefix(&mut data)?;
97    Error::user(Code::InvalidLength).check(data.is_empty())?;
98    Ok(value)
99}
100
/// A decoded wire value stored together with the owned buffer it was decoded from.
///
/// The value is stored with a fake `'static` lifetime but actually borrows from `data`. Borrowed
/// access goes through [`Yoke::get`], which shortens the lifetime back to the borrow of `self`.
/// The buffer is released in the `Drop` implementation.
pub struct Yoke<T: Wire<'static>> {
    // TODO(https://github.com/rust-lang/rust/issues/118166): Use MaybeDangling.
    // `MaybeUninit` stands in for `MaybeDangling` here, presumably so the compiler does not
    // assume that references inside `value` are valid for `'static`.
    value: MaybeUninit<T>,
    // Pointer obtained from `Box::into_raw` in `decode_yoke`; `value` borrows from it.
    data: *mut [u8],
}
106
107impl<T: Wire<'static>> core::fmt::Debug for Yoke<T>
108where for<'a> T::Type<'a>: core::fmt::Debug
109{
110    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
111        <T::Type<'_> as core::fmt::Debug>::fmt(self.get(), f)
112    }
113}
114
115impl<T: Wire<'static>> Drop for Yoke<T> {
116    fn drop(&mut self) {
117        // SAFETY: data comes from into_raw and has been used linearly since then.
118        drop(unsafe { Box::from_raw(self.data) });
119    }
120}
121
impl<T: Wire<'static>> Yoke<T> {
    /// Splits the yoke into its value and its raw buffer without running `Drop`.
    ///
    /// The caller becomes responsible for both parts. The returned value still borrows from the
    /// returned buffer, which is never freed unless the caller puts it back into a `Yoke`.
    fn take(self) -> (T, *mut [u8]) {
        // Wrapping in `ManuallyDrop` prevents `Drop for Yoke` from running, so the value read
        // below is not dropped a second time and the buffer is not freed.
        let this = ManuallyDrop::new(self);
        // SAFETY (review note): `value` is initialized by every constructor, and it is read only
        // once because `this` is never dropped.
        (unsafe { this.value.assume_init_read() }, this.data)
    }
}
128
impl<T: Wire<'static>> Yoke<T> {
    /// Returns a reference to the decoded value.
    ///
    /// The stored value pretends to borrow for `'static`. The returned type shortens that
    /// lifetime to the borrow of `self`, so the result cannot outlive the buffer.
    pub fn get(&self) -> &<T as internal::Wire<'static>>::Type<'_> {
        // SAFETY: We only read from value which borrows from data.
        // NOTE(review): this transmute relies on `MaybeUninit<T>` and `T::Type<'_>` sharing a
        // layout. The impls in this file define `Type<'b>` as `Self` with the lifetime replaced,
        // but confirm that derived impls uphold the same invariant.
        unsafe { core::mem::transmute(&self.value) }
    }

    /// Replaces the value with `f(value)` while keeping the same backing buffer.
    ///
    /// NOTE(review): `f` receives the value with its fake `'static` lifetime and could in
    /// principle smuggle a borrow of the buffer out of the yoke. Callers must not do so. If `f`
    /// panics, the buffer is leaked (the raw pointer returned by `take` is never freed).
    pub fn map<S: Wire<'static>, F: FnOnce(T) -> S>(self, f: F) -> Yoke<S> {
        let (value, data) = self.take();
        Yoke { value: MaybeUninit::new(f(value)), data }
    }

    /// Fallible version of [`Yoke::map`].
    ///
    /// NOTE(review): when `f` returns `Err`, the `?` below returns without freeing `data`, so
    /// the buffer is leaked. Freeing it instead could leave `E` dangling if `E` borrows from the
    /// buffer, so the leak looks deliberate; confirm before changing it.
    pub fn try_map<S: Wire<'static>, E, F: FnOnce(T) -> Result<S, E>>(
        self, f: F,
    ) -> Result<Yoke<S>, E> {
        let (value, data) = self.take();
        Ok(Yoke { value: MaybeUninit::new(f(value)?), data })
    }
}
147
148pub fn decode_yoke<T: Wire<'static>>(data: Box<[u8]>) -> Result<Yoke<T>, Error> {
149    let data = Box::into_raw(data);
150    // SAFETY: decode does not leak its input in other ways than in its result.
151    let value = MaybeUninit::new(decode::<T>(unsafe { &*data })?);
152    Ok(Yoke { value, data })
153}
154
// Implements the internal wire trait for a builtin scalar type.
//
// - `$t`: the Rust type. It borrows nothing, so it is also its own `Type<'b>`.
// - `$T`: the `Builtin` variant recorded in the schema.
// - `$encode` / `$decode`: the `helper` functions used to write and read the value.
macro_rules! impl_builtin {
    ($t:tt $T:tt $encode:tt $decode:tt) => {
        impl<'a> internal::Wire<'a> for $t {
            type Type<'b> = $t;
            #[cfg(feature = "schema")]
            fn schema(rules: &mut Rules) {
                rules.builtin::<Self::Type<'static>>(Builtin::$T);
            }
            fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
                // The helper receives the scalar by value and returns `()`, hence the `Ok(...)`.
                Ok(helper::$encode(*self, writer))
            }
            fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
                helper::$decode(reader)
            }
        }
    };
}
// Single-byte types use the byte helpers. Wider unsigned integers use varints and wider signed
// integers use zigzag varints (see the `encode_varint` test for concrete byte sequences).
impl_builtin!(bool Bool encode_byte decode_byte);
impl_builtin!(u8 U8 encode_byte decode_byte);
impl_builtin!(i8 I8 encode_byte decode_byte);
impl_builtin!(u16 U16 encode_varint decode_varint);
impl_builtin!(i16 I16 encode_zigzag decode_zigzag);
impl_builtin!(u32 U32 encode_varint decode_varint);
impl_builtin!(i32 I32 encode_zigzag decode_zigzag);
impl_builtin!(u64 U64 encode_varint decode_varint);
impl_builtin!(i64 I64 encode_zigzag decode_zigzag);
impl_builtin!(usize Usize encode_varint decode_varint);
impl_builtin!(isize Isize encode_zigzag decode_zigzag);
183
184impl<'a> internal::Wire<'a> for () {
185    type Type<'b> = ();
186    #[cfg(feature = "schema")]
187    fn schema(rules: &mut Rules) {
188        if rules.struct_::<Self::Type<'static>>(Vec::new()) {}
189    }
190    fn encode(&self, _writer: &mut Writer<'a>) -> Result<(), Error> {
191        Ok(())
192    }
193    fn decode(_reader: &mut Reader<'a>) -> Result<Self, Error> {
194        Ok(())
195    }
196}
197
// Implements the internal wire trait for tuples of wire types.
//
// Each `$i $t` pair is a field index and its type parameter. `$n` is the arity, used only to
// size the schema field list. A tuple is described as an unnamed-field struct and encoded as
// its fields in order, with nothing in between.
macro_rules! impl_tuple {
    (($($i:tt $t:tt),*), $n:tt) => {
        impl<'a, $($t: Wire<'a>),*> internal::Wire<'a> for ($($t),*) {
            type Type<'b> = ($($t::Type<'b>),*);
            #[cfg(feature = "schema")]
            fn schema(rules: &mut Rules) {
                let mut fields = Vec::with_capacity($n);
                $(fields.push((None, internal::type_id::<$t>()));)*
                // Field schemas are registered only when `struct_` returns true.
                if rules.struct_::<Self::Type<'static>>(fields) {
                    $(<$t>::schema(rules);)*
                }
            }
            fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
                $(self.$i.encode(writer)?;)*
                Ok(())
            }
            fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
                // Tuple expression operands are evaluated left to right, matching `encode`.
                Ok(($(<$t>::decode(reader)?),*))
            }
        }
    };
}
impl_tuple!((0 T, 1 S), 2);
impl_tuple!((0 T, 1 S, 2 R), 3);
impl_tuple!((0 T, 1 S, 2 R, 3 Q), 4);
impl_tuple!((0 T, 1 S, 2 R, 3 Q, 4 P), 5);
224
225impl<'a, const N: usize> internal::Wire<'a> for &'a [u8; N] {
226    type Type<'b> = &'b [u8; N];
227    #[cfg(feature = "schema")]
228    fn schema(rules: &mut Rules) {
229        rules.array::<Self::Type<'static>, u8>(N);
230    }
231    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
232        Ok(writer.put_share(*self))
233    }
234    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
235        Ok(reader.get(N)?.try_into().unwrap())
236    }
237}
238
239impl<'a> internal::Wire<'a> for &'a [u8] {
240    type Type<'b> = &'b [u8];
241    #[cfg(feature = "schema")]
242    fn schema(rules: &mut Rules) {
243        rules.slice::<Self::Type<'static>, u8>();
244    }
245    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
246        helper::encode_length(self.len(), writer)?;
247        writer.put_share(self);
248        Ok(())
249    }
250    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
251        let len = helper::decode_length(reader)?;
252        reader.get(len)
253    }
254}
255
256impl<'a> internal::Wire<'a> for &'a str {
257    type Type<'b> = &'b str;
258    #[cfg(feature = "schema")]
259    fn schema(rules: &mut Rules) {
260        rules.alias::<Self::Type<'static>, &[u8]>();
261    }
262    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
263        self.as_bytes().encode(writer)
264    }
265    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
266        let bytes = <&[u8]>::decode(reader)?;
267        core::str::from_utf8(bytes).map_err(|_| Error::user(Code::InvalidArgument))
268    }
269}
270
271impl<'a> internal::Wire<'a> for String {
272    type Type<'b> = String;
273    #[cfg(feature = "schema")]
274    fn schema(rules: &mut Rules) {
275        rules.alias::<Self::Type<'static>, Vec<u8>>();
276    }
277    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
278        helper::encode_length(self.len(), writer)?;
279        writer.put_copy(self.as_bytes());
280        Ok(())
281    }
282    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
283        let len = helper::decode_length(reader)?;
284        let bytes = reader.get(len)?.to_vec();
285        String::from_utf8(bytes).map_err(|_| Error::user(Code::InvalidArgument))
286    }
287}
288
// `Cow<'a, T>` is encoded like `&'a T` when borrowed and like `T::Owned` when owned. The two
// encodings must agree, which `schema` asserts. Decoding always produces the borrowed variant.
impl<'a, T: ?Sized + ToOwned + 'static> internal::Wire<'a> for Cow<'a, T>
where
    for<'b> &'b T: Wire<'b>,
    for<'b> T::Owned: Wire<'b>,
{
    type Type<'b> = Cow<'b, T>;
    #[cfg(feature = "schema")]
    fn schema(rules: &mut Rules) {
        // Described as an alias of `&T`. The owned schema is also registered so the two can be
        // compared.
        rules.alias::<Self::Type<'static>, &T>();
        internal::schema::<T::Owned>(rules);
        // Panics if a value encoded as `Owned` could not be decoded as `Borrowed`.
        assert!(
            rules.get(TypeId::of::<&T>()) == rules.get(TypeId::of::<T::Owned>()),
            "Cow<{}> does not have the same wire format for its borrowed and owned variants",
            core::any::type_name::<T>(),
        );
    }
    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
        match self {
            Cow::Borrowed(x) => x.encode(writer),
            Cow::Owned(x) => x.encode(writer),
        }
    }
    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
        // Borrow from the wire instead of allocating.
        Ok(Cow::Borrowed(<&T>::decode(reader)?))
    }
}
315
316impl<'a, T: Wire<'a>, const N: usize> internal::Wire<'a> for [T; N] {
317    type Type<'b> = [T::Type<'b>; N];
318    #[cfg(feature = "schema")]
319    fn schema(rules: &mut Rules) {
320        rules.array::<Self::Type<'static>, T::Type<'static>>(N);
321    }
322    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
323        helper::encode_array(self, writer, T::encode)
324    }
325    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
326        helper::decode_array(reader, T::decode)
327    }
328}
329
330impl<'a, T: Wire<'a>> internal::Wire<'a> for Vec<T> {
331    type Type<'b> = Vec<T::Type<'b>>;
332    #[cfg(feature = "schema")]
333    fn schema(rules: &mut Rules) {
334        rules.slice::<Self::Type<'static>, T::Type<'static>>();
335    }
336    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
337        helper::encode_slice(self, writer, T::encode)
338    }
339    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
340        helper::decode_slice(reader, T::decode)
341    }
342}
343
344impl<'a, T: Wire<'a>> internal::Wire<'a> for Box<T> {
345    type Type<'b> = Box<T::Type<'b>>;
346    #[cfg(feature = "schema")]
347    fn schema(rules: &mut Rules) {
348        let mut fields = Vec::with_capacity(1);
349        fields.push((None, internal::type_id::<T>()));
350        if rules.struct_::<Self::Type<'static>>(fields) {
351            internal::schema::<T>(rules);
352        }
353    }
354    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
355        T::encode(self, writer)
356    }
357    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
358        Ok(Box::new(T::decode(reader)?))
359    }
360}
361
362impl<'a> internal::Wire<'a> for Error {
363    type Type<'b> = Error;
364    #[cfg(feature = "schema")]
365    fn schema(rules: &mut Rules) {
366        let mut fields = Vec::with_capacity(2);
367        fields.push((Some("space"), internal::type_id::<u8>()));
368        fields.push((Some("code"), internal::type_id::<u16>()));
369        if rules.struct_::<Self::Type<'static>>(fields) {
370            internal::schema::<u8>(rules);
371            internal::schema::<u16>(rules);
372        }
373    }
374    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
375        self.space().encode(writer)?;
376        self.code().encode(writer)
377    }
378    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
379        let space = u8::decode(reader)?;
380        let code = u16::decode(reader)?;
381        Ok(Error::new(space, code))
382    }
383}
384
385impl<'a> internal::Wire<'a> for ! {
386    type Type<'b> = !;
387    #[cfg(feature = "schema")]
388    fn schema(rules: &mut Rules) {
389        if rules.enum_::<Self::Type<'static>>(Vec::new()) {}
390    }
391    fn encode(&self, _: &mut Writer<'a>) -> Result<(), Error> {
392        match *self {}
393    }
394    fn decode(_: &mut Reader<'a>) -> Result<Self, Error> {
395        Err(Error::user(Code::InvalidArgument))
396    }
397}
398
// Wire support for types defined outside this crate.
//
// These declarations presumably do not introduce new types: `#[internal_wire]` appears to derive
// the wire trait for the existing type of the same name. `Infallible` and `Side` are imported at
// the top of this file, and `Option` and `Result` come from the prelude. The `decode_varint`
// test compares against the prelude `Option`, which would not type-check if a new `Option` were
// defined here. The declared variants must mirror the real ones, and their order gives the wire
// tags: the `display_schema` test expects `None=0`/`Some=1` and `Ok=0`/`Err=1`.
#[internal_wire]
#[wire(crate = crate)]
enum Infallible {}

#[internal_wire]
#[wire(crate = crate, where = T: Wire<'wire>)]
enum Option<T> {
    None,
    Some(T),
}

#[internal_wire]
#[wire(crate = crate, where = T: Wire<'wire>, where = E: Wire<'wire>)]
enum Result<T, E> {
    Ok(T),
    Err(E),
}

#[internal_wire]
#[wire(crate = crate)]
enum Side {
    A,
    B,
}
423
// Unit tests.
//
// NOTE(review): these use `crate::schema::View`, which is only compiled with the `schema`
// feature, and `std::` paths, while `std` is only linked with the `std` feature. They presumably
// need both features enabled to build.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::View;

    // Exact bytes for unsigned varints and signed zigzag varints.
    #[test]
    fn encode_varint() {
        #[track_caller]
        fn test<T: Wire<'static>>(value: T, expected: &[u8]) {
            assert_eq!(&encode(&value).unwrap()[..], expected);
        }
        test::<u16>(0x00, &[0x00]);
        test::<u16>(0x01, &[0x01]);
        test::<u16>(0x7f, &[0x7f]);
        test::<u16>(0x80, &[0x80, 0x01]);
        test::<u16>(0xff, &[0xff, 0x01]);
        test::<u16>(0xfffe, &[0xfe, 0xff, 0x03]);
        test::<i16>(0, &[0x00]);
        test::<i16>(-1, &[0x01]);
        test::<i16>(1, &[0x02]);
        test::<i16>(-2, &[0x03]);
        test::<i16>(i16::MAX, &[0xfe, 0xff, 0x03]);
        test::<i16>(i16::MIN, &[0xff, 0xff, 0x03]);
    }

    // Decoding must accept exactly the canonical encodings. The `None` cases cover a redundant
    // zero final byte, values above `u16::MAX`, and encodings that do not end within three
    // bytes.
    #[test]
    fn decode_varint() {
        #[track_caller]
        fn test<T: Wire<'static> + Eq + std::fmt::Debug>(data: &'static [u8], expected: Option<T>) {
            assert_eq!(decode(data).ok(), expected);
        }
        test::<u16>(&[0x00], Some(0x00));
        test::<u16>(&[0x01], Some(0x01));
        test::<u16>(&[0x7f], Some(0x7f));
        test::<u16>(&[0x80, 0x01], Some(0x80));
        test::<u16>(&[0xff, 0x01], Some(0xff));
        test::<u16>(&[0xfe, 0xff, 0x03], Some(0xfffe));
        test::<u16>(&[0xfe, 0x00], None);
        test::<u16>(&[0xfe, 0xff, 0x00], None);
        test::<u16>(&[0xfe, 0xff, 0x04], None);
        test::<u16>(&[0xfe, 0xff, 0x40], None);
        test::<u16>(&[0xfe, 0xff, 0x80], None);
        test::<u16>(&[0xfe, 0xff, 0x80, 0x01], None);
        test::<i16>(&[0x00], Some(0));
        test::<i16>(&[0x01], Some(-1));
        test::<i16>(&[0x02], Some(1));
        test::<i16>(&[0x03], Some(-2));
        test::<i16>(&[0xfe, 0xff, 0x03], Some(i16::MAX));
        test::<i16>(&[0xff, 0xff, 0x03], Some(i16::MIN));
    }

    // Compares the `Display` rendering of a type's schema view with `expected`.
    #[track_caller]
    fn assert_schema<'a, T: Wire<'a>>(expected: &str) {
        let x = View::new::<T>();
        assert_eq!(std::format!("{x}"), expected);
    }

    #[test]
    fn display_schema() {
        assert_schema::<bool>("bool");
        assert_schema::<u8>("u8");
        assert_schema::<&str>("[u8]");
        assert_schema::<Result<&str, &[u8]>>("{Ok=0:[u8] Err=1:[u8]}");
        assert_schema::<Option<[u8; 42]>>("{None=0:() Some=1:[u8; 42]}");
    }

    #[test]
    fn derive_struct() {
        #[derive(Wire)]
        #[wire(crate = crate)]
        struct Foo1 {
            bar: u8,
            baz: u32,
        }
        assert_schema::<Foo1>("(bar:u8 baz:u32)");

        #[derive(Wire)]
        #[wire(crate = crate)]
        struct Foo2<'a> {
            bar: Cow<'a, str>,
            baz: Option<&'a [u8]>,
        }
        assert_schema::<Foo2>("(bar:[u8] baz:{None=0:() Some=1:[u8]})");
    }

    // Implicit tags follow declaration order; explicit `tag` attributes override them.
    #[test]
    fn derive_enum() {
        #[derive(Wire)]
        #[wire(crate = crate)]
        enum Foo1 {
            Bar,
            Baz(u32),
        }
        assert_schema::<Foo1>("{Bar=0:() Baz=1:u32}");

        #[derive(Wire)]
        #[wire(crate = crate)]
        enum Foo2<'a> {
            #[wire(tag = 1)]
            Bar(Cow<'a, [u8]>),
            #[wire(tag = 0)]
            Baz((), Option<&'a [u8]>),
        }
        assert_schema::<Foo2>("{Bar=1:[u8] Baz=0:{None=0:() Some=1:[u8]}}");
    }

    // A recursive type yields a schema with a back-reference (`<1>`), and encoded data
    // round-trips and validates against the view.
    #[test]
    fn recursive_view() {
        #[derive(Debug, Wire, PartialEq, Eq)]
        #[wire(crate = crate)]
        enum List {
            Nil,
            Cons(u8, Box<List>),
        }
        assert_schema::<List>("<1>:{Nil=0:() Cons=1:(u8 <1>)}");
        let value = List::Cons(13, Box::new(List::Cons(42, Box::new(List::Nil))));
        let data = encode(&value).unwrap();
        let view = View::new::<List>();
        assert!(view.validate(&data).is_ok());
        assert_eq!(decode::<List>(&data).unwrap(), value);
    }

    // A yoked value can be mapped to a borrowed sub-part while the buffer stays alive.
    #[test]
    fn yoke() {
        type T = Result<&'static [u8], ()>;
        let value: T = Ok(b"hello");
        let data = encode(&value).unwrap();
        let yoke = decode_yoke::<T>(data).unwrap();
        let bytes = yoke.try_map(|x| x).unwrap();
        assert_eq!(bytes.get(), b"hello");
    }

    // Borrowed and owned `Cow`s encode identically, and decoding yields the borrowed variant.
    #[test]
    fn cow() {
        #[track_caller]
        fn test<T: ?Sized + ToOwned + Eq + core::fmt::Debug + 'static>(data: &'static T)
        where
            for<'b> &'b T: Wire<'b>,
            for<'b> T::Owned: Wire<'b>,
            T::Owned: core::fmt::Debug,
        {
            let encoded = encode(&Cow::Borrowed(data)).unwrap();
            assert_eq!(encode(&Cow::Owned(data.to_owned())).unwrap(), encoded);
            let decoded = decode::<Cow<T>>(&encoded).unwrap();
            assert_eq!(decoded, Cow::Borrowed(data));
        }
        test("hello");
        test(b"hello");
    }
}