1#![no_std]
38#![feature(array_try_from_fn)]
39#![feature(doc_cfg)]
40#![feature(never_type)]
41#![feature(try_blocks)]
42
43extern crate alloc;
44#[cfg(feature = "std")]
45extern crate std;
46
47use alloc::borrow::{Cow, ToOwned};
48use alloc::boxed::Box;
49use alloc::string::String;
50use alloc::vec::Vec;
51#[cfg(feature = "schema")]
52use core::any::TypeId;
53use core::convert::Infallible;
54use core::mem::{ManuallyDrop, MaybeUninit};
55
56use wasefire_common::platform::Side;
57use wasefire_error::{Code, Error};
58pub use wasefire_wire_derive::Wire;
59use wasefire_wire_derive::internal_wire;
60
61#[cfg(feature = "schema")]
62use crate::internal::{Builtin, Rules};
63use crate::reader::Reader;
64use crate::writer::Writer;
65
66mod helper;
67pub mod internal;
68mod reader;
69#[cfg(feature = "schema")]
70pub mod schema;
71mod writer;
72
73pub trait Wire<'a>: internal::Wire<'a> {}
74impl<'a, T: internal::Wire<'a>> Wire<'a> for T {}
75
76pub fn encode_suffix<'a, T: Wire<'a>>(data: &mut Vec<u8>, value: &T) -> Result<(), Error> {
77 let mut writer = Writer::new();
78 value.encode(&mut writer)?;
79 Ok(writer.finalize(data))
80}
81
82pub fn encode<'a, T: Wire<'a>>(value: &T) -> Result<Box<[u8]>, Error> {
83 let mut data = Vec::new();
84 encode_suffix(&mut data, value)?;
85 Ok(data.into_boxed_slice())
86}
87
88pub fn decode_prefix<'a, T: Wire<'a>>(data: &mut &'a [u8]) -> Result<T, Error> {
89 let mut reader = Reader::new(data);
90 let value = T::decode(&mut reader)?;
91 *data = reader.finalize();
92 Ok(value)
93}
94
95pub fn decode<'a, T: Wire<'a>>(mut data: &'a [u8]) -> Result<T, Error> {
96 let value = decode_prefix(&mut data)?;
97 Error::user(Code::InvalidLength).check(data.is_empty())?;
98 Ok(value)
99}
100
/// A decoded value kept together with the heap buffer it was decoded from.
///
/// The value is stored with a `'static` lifetime, but its borrowed parts (if any) actually point
/// into `data` (see `decode_yoke`). [`Yoke::get`] hands the value out with a lifetime tied to
/// the borrow of the yoke instead.
pub struct Yoke<T: Wire<'static>> {
    // The decoded value, possibly borrowing from `data`. Every constructor in this file stores an
    // initialized value with `MaybeUninit::new`.
    // NOTE(review): `MaybeUninit` presumably keeps the compiler from assuming the fake `'static`
    // references stay valid; it also means the value is never dropped implicitly.
    value: MaybeUninit<T>,
    // Buffer obtained from `Box::into_raw` in `decode_yoke`, released with `Box::from_raw` when
    // the yoke is dropped.
    data: *mut [u8],
}
106
107impl<T: Wire<'static>> core::fmt::Debug for Yoke<T>
108where for<'a> T::Type<'a>: core::fmt::Debug
109{
110 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
111 <T::Type<'_> as core::fmt::Debug>::fmt(self.get(), f)
112 }
113}
114
115impl<T: Wire<'static>> Drop for Yoke<T> {
116 fn drop(&mut self) {
117 drop(unsafe { Box::from_raw(self.data) });
119 }
120}
121
122impl<T: Wire<'static>> Yoke<T> {
123 fn take(self) -> (T, *mut [u8]) {
124 let this = ManuallyDrop::new(self);
125 (unsafe { this.value.assume_init_read() }, this.data)
126 }
127}
128
129impl<T: Wire<'static>> Yoke<T> {
130 pub fn get(&self) -> &<T as internal::Wire<'static>>::Type<'_> {
131 unsafe { core::mem::transmute(&self.value) }
133 }
134
135 pub fn map<S: Wire<'static>, F: FnOnce(T) -> S>(self, f: F) -> Yoke<S> {
136 let (value, data) = self.take();
137 Yoke { value: MaybeUninit::new(f(value)), data }
138 }
139
140 pub fn try_map<S: Wire<'static>, E, F: FnOnce(T) -> Result<S, E>>(
141 self, f: F,
142 ) -> Result<Yoke<S>, E> {
143 let (value, data) = self.take();
144 Ok(Yoke { value: MaybeUninit::new(f(value)?), data })
145 }
146}
147
148pub fn decode_yoke<T: Wire<'static>>(data: Box<[u8]>) -> Result<Yoke<T>, Error> {
149 let data = Box::into_raw(data);
150 let value = MaybeUninit::new(decode::<T>(unsafe { &*data })?);
152 Ok(Yoke { value, data })
153}
154
155macro_rules! impl_builtin {
156 ($t:tt $T:tt $encode:tt $decode:tt) => {
157 impl<'a> internal::Wire<'a> for $t {
158 type Type<'b> = $t;
159 #[cfg(feature = "schema")]
160 fn schema(rules: &mut Rules) {
161 rules.builtin::<Self::Type<'static>>(Builtin::$T);
162 }
163 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
164 Ok(helper::$encode(*self, writer))
165 }
166 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
167 helper::$decode(reader)
168 }
169 }
170 };
171}
172impl_builtin!(bool Bool encode_byte decode_byte);
173impl_builtin!(u8 U8 encode_byte decode_byte);
174impl_builtin!(i8 I8 encode_byte decode_byte);
175impl_builtin!(u16 U16 encode_varint decode_varint);
176impl_builtin!(i16 I16 encode_zigzag decode_zigzag);
177impl_builtin!(u32 U32 encode_varint decode_varint);
178impl_builtin!(i32 I32 encode_zigzag decode_zigzag);
179impl_builtin!(u64 U64 encode_varint decode_varint);
180impl_builtin!(i64 I64 encode_zigzag decode_zigzag);
181impl_builtin!(usize Usize encode_varint decode_varint);
182impl_builtin!(isize Isize encode_zigzag decode_zigzag);
183
184impl<'a> internal::Wire<'a> for () {
185 type Type<'b> = ();
186 #[cfg(feature = "schema")]
187 fn schema(rules: &mut Rules) {
188 if rules.struct_::<Self::Type<'static>>(Vec::new()) {}
189 }
190 fn encode(&self, _writer: &mut Writer<'a>) -> Result<(), Error> {
191 Ok(())
192 }
193 fn decode(_reader: &mut Reader<'a>) -> Result<Self, Error> {
194 Ok(())
195 }
196}
197
198macro_rules! impl_tuple {
199 (($($i:tt $t:tt),*), $n:tt) => {
200 impl<'a, $($t: Wire<'a>),*> internal::Wire<'a> for ($($t),*) {
201 type Type<'b> = ($($t::Type<'b>),*);
202 #[cfg(feature = "schema")]
203 fn schema(rules: &mut Rules) {
204 let mut fields = Vec::with_capacity($n);
205 $(fields.push((None, internal::type_id::<$t>()));)*
206 if rules.struct_::<Self::Type<'static>>(fields) {
207 $(<$t>::schema(rules);)*
208 }
209 }
210 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
211 $(self.$i.encode(writer)?;)*
212 Ok(())
213 }
214 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
215 Ok(($(<$t>::decode(reader)?),*))
216 }
217 }
218 };
219}
220impl_tuple!((0 T, 1 S), 2);
221impl_tuple!((0 T, 1 S, 2 R), 3);
222impl_tuple!((0 T, 1 S, 2 R, 3 Q), 4);
223impl_tuple!((0 T, 1 S, 2 R, 3 Q, 4 P), 5);
224
225impl<'a, const N: usize> internal::Wire<'a> for &'a [u8; N] {
226 type Type<'b> = &'b [u8; N];
227 #[cfg(feature = "schema")]
228 fn schema(rules: &mut Rules) {
229 rules.array::<Self::Type<'static>, u8>(N);
230 }
231 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
232 Ok(writer.put_share(*self))
233 }
234 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
235 Ok(reader.get(N)?.try_into().unwrap())
236 }
237}
238
239impl<'a> internal::Wire<'a> for &'a [u8] {
240 type Type<'b> = &'b [u8];
241 #[cfg(feature = "schema")]
242 fn schema(rules: &mut Rules) {
243 rules.slice::<Self::Type<'static>, u8>();
244 }
245 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
246 helper::encode_length(self.len(), writer)?;
247 writer.put_share(self);
248 Ok(())
249 }
250 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
251 let len = helper::decode_length(reader)?;
252 reader.get(len)
253 }
254}
255
256impl<'a> internal::Wire<'a> for &'a str {
257 type Type<'b> = &'b str;
258 #[cfg(feature = "schema")]
259 fn schema(rules: &mut Rules) {
260 rules.alias::<Self::Type<'static>, &[u8]>();
261 }
262 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
263 self.as_bytes().encode(writer)
264 }
265 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
266 let bytes = <&[u8]>::decode(reader)?;
267 core::str::from_utf8(bytes).map_err(|_| Error::user(Code::InvalidArgument))
268 }
269}
270
271impl<'a> internal::Wire<'a> for String {
272 type Type<'b> = String;
273 #[cfg(feature = "schema")]
274 fn schema(rules: &mut Rules) {
275 rules.alias::<Self::Type<'static>, Vec<u8>>();
276 }
277 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
278 helper::encode_length(self.len(), writer)?;
279 writer.put_copy(self.as_bytes());
280 Ok(())
281 }
282 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
283 let len = helper::decode_length(reader)?;
284 let bytes = reader.get(len)?.to_vec();
285 String::from_utf8(bytes).map_err(|_| Error::user(Code::InvalidArgument))
286 }
287}
288
289impl<'a, T: ?Sized + ToOwned + 'static> internal::Wire<'a> for Cow<'a, T>
290where
291 for<'b> &'b T: Wire<'b>,
292 for<'b> T::Owned: Wire<'b>,
293{
294 type Type<'b> = Cow<'b, T>;
295 #[cfg(feature = "schema")]
296 fn schema(rules: &mut Rules) {
297 rules.alias::<Self::Type<'static>, &T>();
298 internal::schema::<T::Owned>(rules);
299 assert!(
300 rules.get(TypeId::of::<&T>()) == rules.get(TypeId::of::<T::Owned>()),
301 "Cow<{}> does not have the same wire format for its borrowed and owned variants",
302 core::any::type_name::<T>(),
303 );
304 }
305 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
306 match self {
307 Cow::Borrowed(x) => x.encode(writer),
308 Cow::Owned(x) => x.encode(writer),
309 }
310 }
311 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
312 Ok(Cow::Borrowed(<&T>::decode(reader)?))
313 }
314}
315
316impl<'a, T: Wire<'a>, const N: usize> internal::Wire<'a> for [T; N] {
317 type Type<'b> = [T::Type<'b>; N];
318 #[cfg(feature = "schema")]
319 fn schema(rules: &mut Rules) {
320 rules.array::<Self::Type<'static>, T::Type<'static>>(N);
321 }
322 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
323 helper::encode_array(self, writer, T::encode)
324 }
325 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
326 helper::decode_array(reader, T::decode)
327 }
328}
329
330impl<'a, T: Wire<'a>> internal::Wire<'a> for Vec<T> {
331 type Type<'b> = Vec<T::Type<'b>>;
332 #[cfg(feature = "schema")]
333 fn schema(rules: &mut Rules) {
334 rules.slice::<Self::Type<'static>, T::Type<'static>>();
335 }
336 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
337 helper::encode_slice(self, writer, T::encode)
338 }
339 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
340 helper::decode_slice(reader, T::decode)
341 }
342}
343
344impl<'a, T: Wire<'a>> internal::Wire<'a> for Box<T> {
345 type Type<'b> = Box<T::Type<'b>>;
346 #[cfg(feature = "schema")]
347 fn schema(rules: &mut Rules) {
348 let mut fields = Vec::with_capacity(1);
349 fields.push((None, internal::type_id::<T>()));
350 if rules.struct_::<Self::Type<'static>>(fields) {
351 internal::schema::<T>(rules);
352 }
353 }
354 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
355 T::encode(self, writer)
356 }
357 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
358 Ok(Box::new(T::decode(reader)?))
359 }
360}
361
362impl<'a> internal::Wire<'a> for Error {
363 type Type<'b> = Error;
364 #[cfg(feature = "schema")]
365 fn schema(rules: &mut Rules) {
366 let mut fields = Vec::with_capacity(2);
367 fields.push((Some("space"), internal::type_id::<u8>()));
368 fields.push((Some("code"), internal::type_id::<u16>()));
369 if rules.struct_::<Self::Type<'static>>(fields) {
370 internal::schema::<u8>(rules);
371 internal::schema::<u16>(rules);
372 }
373 }
374 fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
375 self.space().encode(writer)?;
376 self.code().encode(writer)
377 }
378 fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
379 let space = u8::decode(reader)?;
380 let code = u16::decode(reader)?;
381 Ok(Error::new(space, code))
382 }
383}
384
385impl<'a> internal::Wire<'a> for ! {
386 type Type<'b> = !;
387 #[cfg(feature = "schema")]
388 fn schema(rules: &mut Rules) {
389 if rules.enum_::<Self::Type<'static>>(Vec::new()) {}
390 }
391 fn encode(&self, _: &mut Writer<'a>) -> Result<(), Error> {
392 match *self {}
393 }
394 fn decode(_: &mut Reader<'a>) -> Result<Self, Error> {
395 Err(Error::user(Code::InvalidArgument))
396 }
397}
398
// Wire support for `core::convert::Infallible` (imported at the top of this file), described to
// `#[internal_wire]` by this mirror declaration.
// NOTE(review): the attribute presumably expands to an `internal::Wire` impl for the existing
// type of the same name rather than defining a new enum — confirm in `wasefire_wire_derive`.
#[internal_wire]
#[wire(crate = crate)]
enum Infallible {}
402
// Wire support for the prelude `Option<T>`, described to `#[internal_wire]` by this mirror
// declaration. `'wire` is presumably the lifetime parameter introduced by the macro.
// NOTE(review): the attribute presumably emits an `internal::Wire` impl for the existing type
// instead of defining a new `Option` (which would shadow the prelude one used in this file).
// The `display_schema` test expects tags `None=0` and `Some=1`, so do not reorder the variants.
#[internal_wire]
#[wire(crate = crate, where = T: Wire<'wire>)]
enum Option<T> {
    None,
    Some(T),
}
409
// Wire support for the prelude `Result<T, E>`, described to `#[internal_wire]` by this mirror
// declaration; both parameters must themselves be wire types.
// NOTE(review): as for `Option`, this presumably generates an impl for the existing type rather
// than a new enum. The `display_schema` test expects tags `Ok=0` and `Err=1`, so do not reorder
// the variants.
#[internal_wire]
#[wire(crate = crate, where = T: Wire<'wire>, where = E: Wire<'wire>)]
enum Result<T, E> {
    Ok(T),
    Err(E),
}
416
// Wire support for `wasefire_common::platform::Side` (imported at the top of this file),
// described to `#[internal_wire]` by this mirror declaration.
// NOTE(review): tags presumably follow declaration order (`A=0`, `B=1`), as the derived enums in
// the tests do without explicit `#[wire(tag)]`. Keep this in sync with the real `Side` enum.
#[internal_wire]
#[wire(crate = crate)]
enum Side {
    A,
    B,
}
423
424#[cfg(test)]
425mod tests {
426 use super::*;
427 use crate::schema::View;
428
429 #[test]
430 fn encode_varint() {
431 #[track_caller]
432 fn test<T: Wire<'static>>(value: T, expected: &[u8]) {
433 assert_eq!(&encode(&value).unwrap()[..], expected);
434 }
435 test::<u16>(0x00, &[0x00]);
436 test::<u16>(0x01, &[0x01]);
437 test::<u16>(0x7f, &[0x7f]);
438 test::<u16>(0x80, &[0x80, 0x01]);
439 test::<u16>(0xff, &[0xff, 0x01]);
440 test::<u16>(0xfffe, &[0xfe, 0xff, 0x03]);
441 test::<i16>(0, &[0x00]);
442 test::<i16>(-1, &[0x01]);
443 test::<i16>(1, &[0x02]);
444 test::<i16>(-2, &[0x03]);
445 test::<i16>(i16::MAX, &[0xfe, 0xff, 0x03]);
446 test::<i16>(i16::MIN, &[0xff, 0xff, 0x03]);
447 }
448
449 #[test]
450 fn decode_varint() {
451 #[track_caller]
452 fn test<T: Wire<'static> + Eq + std::fmt::Debug>(data: &'static [u8], expected: Option<T>) {
453 assert_eq!(decode(data).ok(), expected);
454 }
455 test::<u16>(&[0x00], Some(0x00));
456 test::<u16>(&[0x01], Some(0x01));
457 test::<u16>(&[0x7f], Some(0x7f));
458 test::<u16>(&[0x80, 0x01], Some(0x80));
459 test::<u16>(&[0xff, 0x01], Some(0xff));
460 test::<u16>(&[0xfe, 0xff, 0x03], Some(0xfffe));
461 test::<u16>(&[0xfe, 0x00], None);
462 test::<u16>(&[0xfe, 0xff, 0x00], None);
463 test::<u16>(&[0xfe, 0xff, 0x04], None);
464 test::<u16>(&[0xfe, 0xff, 0x40], None);
465 test::<u16>(&[0xfe, 0xff, 0x80], None);
466 test::<u16>(&[0xfe, 0xff, 0x80, 0x01], None);
467 test::<i16>(&[0x00], Some(0));
468 test::<i16>(&[0x01], Some(-1));
469 test::<i16>(&[0x02], Some(1));
470 test::<i16>(&[0x03], Some(-2));
471 test::<i16>(&[0xfe, 0xff, 0x03], Some(i16::MAX));
472 test::<i16>(&[0xff, 0xff, 0x03], Some(i16::MIN));
473 }
474
475 #[track_caller]
476 fn assert_schema<'a, T: Wire<'a>>(expected: &str) {
477 let x = View::new::<T>();
478 assert_eq!(std::format!("{x}"), expected);
479 }
480
481 #[test]
482 fn display_schema() {
483 assert_schema::<bool>("bool");
484 assert_schema::<u8>("u8");
485 assert_schema::<&str>("[u8]");
486 assert_schema::<Result<&str, &[u8]>>("{Ok=0:[u8] Err=1:[u8]}");
487 assert_schema::<Option<[u8; 42]>>("{None=0:() Some=1:[u8; 42]}");
488 }
489
490 #[test]
491 fn derive_struct() {
492 #[derive(Wire)]
493 #[wire(crate = crate)]
494 struct Foo1 {
495 bar: u8,
496 baz: u32,
497 }
498 assert_schema::<Foo1>("(bar:u8 baz:u32)");
499
500 #[derive(Wire)]
501 #[wire(crate = crate)]
502 struct Foo2<'a> {
503 bar: Cow<'a, str>,
504 baz: Option<&'a [u8]>,
505 }
506 assert_schema::<Foo2>("(bar:[u8] baz:{None=0:() Some=1:[u8]})");
507 }
508
509 #[test]
510 fn derive_enum() {
511 #[derive(Wire)]
512 #[wire(crate = crate)]
513 enum Foo1 {
514 Bar,
515 Baz(u32),
516 }
517 assert_schema::<Foo1>("{Bar=0:() Baz=1:u32}");
518
519 #[derive(Wire)]
520 #[wire(crate = crate)]
521 enum Foo2<'a> {
522 #[wire(tag = 1)]
523 Bar(Cow<'a, [u8]>),
524 #[wire(tag = 0)]
525 Baz((), Option<&'a [u8]>),
526 }
527 assert_schema::<Foo2>("{Bar=1:[u8] Baz=0:{None=0:() Some=1:[u8]}}");
528 }
529
530 #[test]
531 fn recursive_view() {
532 #[derive(Debug, Wire, PartialEq, Eq)]
533 #[wire(crate = crate)]
534 enum List {
535 Nil,
536 Cons(u8, Box<List>),
537 }
538 assert_schema::<List>("<1>:{Nil=0:() Cons=1:(u8 <1>)}");
539 let value = List::Cons(13, Box::new(List::Cons(42, Box::new(List::Nil))));
540 let data = encode(&value).unwrap();
541 let view = View::new::<List>();
542 assert!(view.validate(&data).is_ok());
543 assert_eq!(decode::<List>(&data).unwrap(), value);
544 }
545
546 #[test]
547 fn yoke() {
548 type T = Result<&'static [u8], ()>;
549 let value: T = Ok(b"hello");
550 let data = encode(&value).unwrap();
551 let yoke = decode_yoke::<T>(data).unwrap();
552 let bytes = yoke.try_map(|x| x).unwrap();
553 assert_eq!(bytes.get(), b"hello");
554 }
555
556 #[test]
557 fn cow() {
558 #[track_caller]
559 fn test<T: ?Sized + ToOwned + Eq + core::fmt::Debug + 'static>(data: &'static T)
560 where
561 for<'b> &'b T: Wire<'b>,
562 for<'b> T::Owned: Wire<'b>,
563 T::Owned: core::fmt::Debug,
564 {
565 let encoded = encode(&Cow::Borrowed(data)).unwrap();
566 assert_eq!(encode(&Cow::Owned(data.to_owned())).unwrap(), encoded);
567 let decoded = decode::<Cow<T>>(&encoded).unwrap();
568 assert_eq!(decoded, Cow::Borrowed(data));
569 }
570 test("hello");
571 test(b"hello");
572 }
573}