wick_packet/
packet.rs

1use serde::de::DeserializeOwned;
2use serde::{Deserialize, Serialize};
3use tokio_stream::Stream;
4use wasmrs::{BoxFlux, Metadata, Payload, PayloadError, RawPayload};
5use wasmrs_runtime::ConditionallySend;
6use wick_interface_types::Type;
7
8use crate::metadata::DONE_FLAG;
9use crate::wrapped_type::coerce;
10use crate::{Base64Bytes, Error, PacketStream, TypeWrapper, WickMetadata, CLOSE_BRACKET, OPEN_BRACKET};
11
/// A single message flowing on a wick port.
///
/// Combines the wasmRS transport [Metadata], the wick-specific [WickMetadata]
/// (port name, signal flags, optional context) and the [PacketPayload] itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[must_use]
pub struct Packet {
  /// wasmRS transport metadata. Only its `index` and `extra` fields take part in equality.
  pub(crate) metadata: Metadata,
  /// Wick-specific metadata: port name, signal flags, and optional context.
  pub(crate) extra: WickMetadata,
  /// The packet's data or error.
  pub payload: PacketPayload,
}
19
20impl PartialEq for Packet {
21  fn eq(&self, other: &Self) -> bool {
22    if self.metadata.index != other.metadata.index || !self.metadata.extra.eq(&other.metadata.extra) {
23      return false;
24    }
25    if self.extra.ne(&other.extra) {
26      return false;
27    }
28    self.payload == other.payload
29  }
30}
31
32impl Packet {
33  /// The port name that indicates a component-wide fatal error.
34  pub const FATAL_ERROR: &str = "<error>";
35  pub const NO_INPUT: &str = "<>";
36
37  /// Create a new packet for the given port with a raw [PacketPayload], wasmRS [Metadata], and [WickMetadata].
38  pub const fn new_raw(payload: PacketPayload, wasmrs: Metadata, metadata: WickMetadata) -> Self {
39    Self {
40      payload,
41      metadata: wasmrs,
42      extra: metadata,
43    }
44  }
45
46  /// Create a new packet for the given port with a raw [PacketPayload] value and given flags.
47  pub fn new_for_port<T: Into<String>>(port: T, payload: PacketPayload, flags: u8) -> Self {
48    let md = Metadata::new(0);
49    let wmd = WickMetadata::new(port, flags);
50    Self {
51      payload,
52      metadata: md,
53      extra: wmd,
54    }
55  }
56
57  pub fn no_input() -> Self {
58    Self::encode(Self::NO_INPUT, ())
59  }
60
61  /// Create a new fatal error packet for the component.
62  pub fn component_error<T: Into<String>>(err: T) -> Self {
63    Self::new_for_port(Self::FATAL_ERROR, PacketPayload::fatal_error(err), 0)
64  }
65
66  /// Create a new success packet for the given port with a raw [RawPayload] value.
67  pub fn ok<T: Into<String>>(port: T, payload: RawPayload) -> Self {
68    Self::new_for_port(port, PacketPayload::Ok(payload.data.map(Into::into)), 0)
69  }
70
71  /// Create a new error packet for the given port with a raw [PacketError] value.
72  pub fn raw_err<T: Into<String>>(port: T, payload: PacketError) -> Self {
73    Self::new_for_port(port, PacketPayload::Err(payload), 0)
74  }
75
76  /// Create a new error packet for the given port.
77  pub fn err<T: Into<String>, E: Into<String>>(port: T, msg: E) -> Self {
78    Self::new_for_port(port, PacketPayload::Err(PacketError::new(msg.into())), 0)
79  }
80
81  /// Create a new done packet for the given port.
82  pub fn done<T: Into<String>>(port: T) -> Self {
83    Self::new_for_port(port, PacketPayload::Ok(None), DONE_FLAG)
84  }
85
86  /// Create a new open bracket packet for the given port.
87  pub fn open_bracket<T: Into<String>>(port: T) -> Self {
88    Self::new_for_port(port, PacketPayload::Ok(None), OPEN_BRACKET)
89  }
90
91  /// Create a close bracket packet for the given port.
92  pub fn close_bracket<T: Into<String>>(port: T) -> Self {
93    Self::new_for_port(port, PacketPayload::Ok(None), CLOSE_BRACKET)
94  }
95
96  /// Get the context of a [crate::ContextTransport] on this packet.
97  pub fn context(&self) -> Option<Base64Bytes> {
98    self.extra.context.clone()
99  }
100
101  /// Set the content of a [crate::ContextTransport] on this packet.
102  pub fn set_context(&mut self, context: Base64Bytes) {
103    self.extra.context = Some(context);
104  }
105
106  /// Encode a value into a [Packet] for the given port.
107  pub fn encode<P: Into<String>, T: Serialize>(port: P, data: T) -> Self {
108    Self::new_for_port(port, PacketPayload::encode(data), 0)
109  }
110
111  /// Get the operation index associated with this packet.
112  pub const fn index(&self) -> Option<u32> {
113    self.metadata.index
114  }
115
116  /// Try to deserialize a [Packet] into the target type.
117  pub fn decode<T: DeserializeOwned>(&self) -> Result<T, Error> {
118    self.payload.decode()
119  }
120
121  /// Partially decode a [Packet] and wrap it into a [TypeWrapper].
122  pub fn to_type_wrapper(self, ty: Type) -> Result<TypeWrapper, Error> {
123    self.payload.type_wrapper(ty)
124  }
125
126  /// Decode a [Packet] into a [serde_json::Value].
127  pub fn decode_value(&self) -> Result<serde_json::Value, Error> {
128    self.payload.decode()
129  }
130
131  /// Set the port for this packet.
132  pub fn to_port<T: Into<String>>(mut self, port: T) -> Self {
133    self.extra.port = port.into();
134    self
135  }
136
137  /// Set the port for this packet.
138  pub fn set_port<T: Into<String>>(&mut self, port: T) {
139    self.extra.port = port.into();
140  }
141
142  /// Return `true` if this is an error packet.
143  pub const fn is_error(&self) -> bool {
144    matches!(self.payload, PacketPayload::Err(_))
145  }
146
147  /// Get the inner payload of this packet.
148  pub const fn payload(&self) -> &PacketPayload {
149    &self.payload
150  }
151
152  /// Returns the payload, panicking if it is an error.
153  pub fn unwrap_payload(self) -> Option<Base64Bytes> {
154    match self.payload {
155      PacketPayload::Ok(v) => v,
156      _ => panic!("Packet is an error"),
157    }
158  }
159
160  /// Returns the error, panicking if the packet was a success packet.
161  pub fn unwrap_err(self) -> PacketError {
162    match self.payload {
163      PacketPayload::Err(err) => err,
164      _ => panic!("Packet is not an error"),
165    }
166  }
167
168  /// Return a simplified JSON representation of this packet.
169  pub fn to_json(&self) -> serde_json::Value {
170    if self.flags() > 0 {
171      let mut map = serde_json::json!({
172        "flags": self.flags(),
173        "port": self.port()
174      });
175      if self.has_data() {
176        map
177          .as_object_mut()
178          .unwrap()
179          .insert("payload".to_owned(), self.payload.to_json());
180      }
181      map
182    } else {
183      serde_json::json!({
184        "port": self.port(),
185        "payload": self.payload.to_json(),
186      })
187    }
188  }
189}
190
191impl PacketExt for Packet {
192  fn has_data(&self) -> bool {
193    match &self.payload {
194      PacketPayload::Ok(Some(data)) => !data.is_empty(),
195      PacketPayload::Ok(None) => false,
196      PacketPayload::Err(_) => false,
197    }
198  }
199
200  fn port(&self) -> &str {
201    &self.extra.port
202  }
203
204  fn flags(&self) -> u8 {
205    self.extra.flags
206  }
207}
208
209pub trait PacketExt {
210  /// Returns `true` if the packet contains data in the payload.
211  fn has_data(&self) -> bool;
212
213  /// Get the port for this packet.
214  fn port(&self) -> &str;
215
216  /// Get the inner payload of this packet.
217  fn flags(&self) -> u8;
218
219  /// Return `true` if this is a No-Op packet. No action should be taken.
220  fn is_noop(&self) -> bool {
221    self.port() == Packet::NO_INPUT
222  }
223
224  /// Return `true` if this is a fatal, component wide error packet.
225  fn is_fatal_error(&self) -> bool {
226    self.port() == Packet::FATAL_ERROR
227  }
228
229  /// Returns true if this packet is a signal packet (i.e. done, open_bracket, close_bracket, etc).
230  fn is_signal(&self) -> bool {
231    self.flags() > 0
232  }
233
234  /// Returns true if this packet is a bracket packet (i.e open_bracket, close_bracket, etc).
235  fn is_bracket(&self) -> bool {
236    self.flags() & (OPEN_BRACKET | CLOSE_BRACKET) > 0
237  }
238
239  /// Returns true if this packet is a done packet.
240  fn is_done(&self) -> bool {
241    self.flags() & DONE_FLAG == DONE_FLAG
242  }
243
244  /// Returns true if this packet is an open bracket packet.
245  fn is_open_bracket(&self) -> bool {
246    self.flags() & OPEN_BRACKET == OPEN_BRACKET
247  }
248
249  /// Returns true if this packet is a close bracket packet.
250  fn is_close_bracket(&self) -> bool {
251    self.flags() & CLOSE_BRACKET == CLOSE_BRACKET
252  }
253}
254
255impl PartialEq for PacketPayload {
256  fn eq(&self, other: &Self) -> bool {
257    match (self, other) {
258      (Self::Ok(l0), Self::Ok(r0)) => l0 == r0,
259      (Self::Err(l0), Self::Err(r0)) => l0.msg == r0.msg,
260      _ => core::mem::discriminant(self) == core::mem::discriminant(other),
261    }
262  }
263}
264
/// The content of a [Packet]: either optional serialized data or an error.
#[derive(Debug, Clone, Serialize, Deserialize)]
// NOTE(review): presumably allowed because these two variants are a deliberate, complete
// contract for downstream matches — confirm and keep this rationale next to the allow.
#[allow(clippy::exhaustive_enums)]
pub enum PacketPayload {
  /// Success. `None` carries no data (used by done/bracket signal packets); `Some`
  /// holds the encoded bytes (MessagePack when produced by `PacketPayload::encode`).
  Ok(Option<Base64Bytes>),
  /// Failure, carrying a [PacketError] message.
  Err(PacketError),
}
271
272impl PacketPayload {
273  pub fn fatal_error<T: Into<String>>(err: T) -> Self {
274    Self::Err(PacketError::new(err))
275  }
276
277  /// Encode a value into a [PacketPayload]
278  pub fn encode<T: Serialize>(data: T) -> Self {
279    match wasmrs_codec::messagepack::serialize(&data) {
280      Ok(bytes) => PacketPayload::Ok(Some(bytes.into())),
281      Err(err) => PacketPayload::err(err.to_string()),
282    }
283  }
284
285  /// Try to deserialize a [Packet] into the target type
286  pub fn decode<T: DeserializeOwned>(&self) -> Result<T, Error> {
287    match self {
288      PacketPayload::Ok(Some(bytes)) => match wasmrs_codec::messagepack::deserialize(bytes) {
289        Ok(data) => Ok(data),
290        Err(err) => Err(crate::Error::Decode {
291          as_json: wasmrs_codec::messagepack::deserialize::<serde_json::Value>(bytes)
292            .map_or_else(|_e| "could not convert".to_owned(), |v| v.to_string()),
293          error: err.to_string(),
294        }),
295      },
296      PacketPayload::Ok(None) => Err(crate::Error::NoData),
297      PacketPayload::Err(err) => Err(crate::Error::PayloadError(err.clone())),
298    }
299  }
300
301  pub fn err<T: Into<String>>(msg: T) -> Self {
302    Self::Err(PacketError::new(msg))
303  }
304
305  /// Partially process a [Packet] as [Type].
306  pub fn type_wrapper(self, sig: Type) -> Result<TypeWrapper, Error> {
307    let val = coerce(self.decode::<serde_json::Value>()?, &sig)?;
308    Ok(TypeWrapper::new(sig, val))
309  }
310
311  pub const fn bytes(&self) -> Option<&Base64Bytes> {
312    match self {
313      Self::Ok(b) => b.as_ref(),
314      _ => None,
315    }
316  }
317
318  #[allow(clippy::missing_const_for_fn)]
319  pub fn into_bytes(self) -> Option<Base64Bytes> {
320    match self {
321      Self::Ok(b) => b,
322      _ => None,
323    }
324  }
325
326  pub fn to_json(&self) -> serde_json::Value {
327    match self {
328      Self::Ok(Some(b)) => match wasmrs_codec::messagepack::deserialize::<serde_json::Value>(b) {
329        Ok(data) => serde_json::json!({ "value": data }),
330        Err(err) => serde_json::json! ({"error" : crate::Error::Jsonify(err.to_string()).to_string()}),
331      },
332      Self::Ok(None) => serde_json::Value::Null,
333      Self::Err(err) => serde_json::json! ({"error" : crate::Error::PayloadError(err.clone()).to_string()}),
334    }
335  }
336}
337
/// An error message carried by a `PacketPayload::Err`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PacketError {
  /// Human-readable error message; also the `Display` output.
  msg: String,
}
342
343impl PacketError {
344  pub fn new<T: Into<String>>(msg: T) -> Self {
345    Self { msg: msg.into() }
346  }
347
348  #[must_use]
349  pub fn msg(&self) -> &str {
350    &self.msg
351  }
352}
353
354impl std::error::Error for PacketError {}
355
356impl std::fmt::Display for PacketError {
357  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358    f.write_str(&self.msg)
359  }
360}
361
362impl From<Result<RawPayload, PayloadError>> for Packet {
363  fn from(p: Result<RawPayload, PayloadError>) -> Self {
364    p.map_or_else(
365      |e| {
366        if let Some(mut metadata) = e.metadata {
367          let md = wasmrs::Metadata::decode(&mut metadata);
368
369          let wmd = md.map_or_else(
370            |_e| WickMetadata::default(),
371            |m| {
372              m.extra
373                .map_or_else(WickMetadata::default, |extra| WickMetadata::decode(extra).unwrap())
374            },
375          );
376          Packet::raw_err(wmd.port, PacketError::new(e.msg))
377        } else {
378          Packet::component_error(e.msg)
379        }
380      },
381      |p| {
382        if let Some(mut metadata) = p.metadata {
383          let md = wasmrs::Metadata::decode(&mut metadata);
384
385          let wmd = md.map_or_else(
386            |_e| WickMetadata::default(),
387            |m| {
388              m.extra
389                .map_or_else(WickMetadata::default, |extra| WickMetadata::decode(extra).unwrap())
390            },
391          );
392          // Potential danger zone: this converts empty payload to None which *should* be the
393          // same thing. Calling this out as a potential source for weird bugs if they pop up.
394          let data = p.data.and_then(|b| (!b.is_empty()).then_some(b));
395          Packet::new_for_port(wmd.port(), PacketPayload::Ok(data.map(Into::into)), wmd.flags())
396        } else {
397          Packet::component_error("invalid wasmrs packet with no metadata.")
398        }
399      },
400    )
401  }
402}
403
404impl From<Result<Payload, PayloadError>> for Packet {
405  fn from(p: Result<Payload, PayloadError>) -> Self {
406    p.map_or_else(
407      |e| {
408        let md = wasmrs::Metadata::decode(&mut e.metadata.unwrap());
409
410        let wmd = md.map_or_else(
411          |_e| WickMetadata::default(),
412          |m| WickMetadata::decode(m.extra.unwrap()).unwrap(),
413        );
414        Packet::raw_err(wmd.port, PacketError::new(e.msg))
415      },
416      |p| {
417        let md = p.metadata;
418        let wmd = WickMetadata::decode(md.extra.unwrap()).unwrap();
419        // Potential danger zone: this converts empty payload to None which *should* be the
420        // same thing. Calling this out as a potential source for weird bugs if they pop up.
421        let data = p.data;
422        Packet::new_for_port(wmd.port(), PacketPayload::Ok(Some(data.into())), wmd.flags())
423      },
424    )
425  }
426}
427
428#[must_use]
429pub fn packetstream_to_wasmrs(index: u32, stream: PacketStream) -> BoxFlux<RawPayload, PayloadError> {
430  let s = tokio_stream::StreamExt::map(stream, move |p| {
431    p.map_or_else(
432      |e| Err(PayloadError::application_error(e.to_string(), None)),
433      |p| {
434        let md = wasmrs::Metadata::new_extra(index, p.extra.encode()).encode();
435        match p.payload {
436          PacketPayload::Ok(b) => Ok(wasmrs::RawPayload::new_data(Some(md), b.map(Into::into))),
437          PacketPayload::Err(e) => Err(wasmrs::PayloadError::application_error(e.msg(), Some(md))),
438        }
439      },
440    )
441  });
442  Box::pin(s)
443}
444
445pub fn from_raw_wasmrs<T: Stream<Item = Result<RawPayload, PayloadError>> + ConditionallySend + Unpin + 'static>(
446  stream: T,
447) -> PacketStream {
448  let s = tokio_stream::StreamExt::map(stream, move |p| Ok(p.into()));
449  PacketStream::new(Box::new(s))
450}
451
452pub fn from_wasmrs<T: Stream<Item = Result<Payload, PayloadError>> + ConditionallySend + Unpin + 'static>(
453  stream: T,
454) -> PacketStream {
455  let s = tokio_stream::StreamExt::map(stream, move |p| Ok(p.into()));
456  PacketStream::new(Box::new(s))
457}
458
459impl From<Payload> for Packet {
460  fn from(mut value: Payload) -> Self {
461    let ex = value.metadata.extra.take();
462
463    Self {
464      extra: WickMetadata::decode(ex.unwrap()).unwrap(),
465      metadata: value.metadata,
466      payload: PacketPayload::Ok(Some(value.data.into())),
467    }
468  }
469}
470
471impl From<Packet> for Result<RawPayload, PayloadError> {
472  fn from(value: Packet) -> Self {
473    let mut md = value.metadata;
474    md.extra = Some(value.extra.encode());
475    match value.payload {
476      PacketPayload::Ok(b) => Ok(RawPayload::new_data(Some(md.encode()), b.map(Into::into))),
477      PacketPayload::Err(e) => Err(PayloadError::application_error(e.msg(), Some(md.encode()))),
478    }
479  }
480}
481
#[cfg(test)]
mod test {
  use anyhow::Result;
  use serde_json::Value;
  use wick_interface_types::Type;

  use super::PacketPayload;
  use crate::{Base64Bytes, Packet};

  /// An integer survives a `Packet::encode` / `Packet::decode` round trip.
  #[test]
  fn test_basic() -> Result<()> {
    let decoded: i32 = Packet::encode("test", 10).decode()?;
    assert_eq!(decoded, 10);
    Ok(())
  }

  /// Encoded values decode back into the expected generic JSON value.
  #[rstest::rstest]
  #[case(u64::MIN, Value::Number(serde_json::Number::from(u64::MIN)))]
  #[case(u64::MAX, Value::Number(serde_json::Number::from(u64::MAX)))]
  #[case(&[1,2,3,4,5,6], vec![1,2,3,4,5,6].into())]
  #[case("test", Value::String("test".to_owned()))]
  #[case(Base64Bytes::new(b"test".as_slice()), Value::String("dGVzdA==".to_owned()))]
  fn test_encode_to_generic<T>(#[case] value: T, #[case] expected: Value) -> Result<()>
  where
    T: serde::Serialize + std::fmt::Debug,
  {
    let packet = Packet::encode("test", value);
    println!("{packet:?}");
    assert_eq!(packet.decode_value()?, expected);
    Ok(())
  }

  /// Payloads are coerced to the requested wick type.
  #[rstest::rstest]
  #[case("2", Type::String, Value::String("2".into()))]
  #[case(2, Type::String, Value::String("2".into()))]
  fn test_type_wrapper<T>(#[case] value: T, #[case] ty: Type, #[case] expected: Value) -> Result<()>
  where
    T: serde::Serialize + std::fmt::Debug,
  {
    let payload = PacketPayload::encode(value);
    println!("{payload:?}");
    let wrapped = payload.type_wrapper(ty)?;
    assert_eq!(wrapped.into_inner(), expected);
    Ok(())
  }

  /// A base64 string decodes into the matching raw bytes via `Base64Bytes`.
  #[rstest::rstest]
  #[case("dGVzdA==", b"test")]
  fn test_from_b64(#[case] value: &str, #[case] expected: &[u8]) -> Result<()> {
    let packet = Packet::encode("test", value);
    println!("{packet:?}");
    let json = packet.decode_value()?;
    let decoded: Base64Bytes = serde_json::from_value(json).unwrap();
    assert_eq!(decoded, expected);
    Ok(())
  }
}