1use serde::de::DeserializeOwned;
2use serde::{Deserialize, Serialize};
3use tokio_stream::Stream;
4use wasmrs::{BoxFlux, Metadata, Payload, PayloadError, RawPayload};
5use wasmrs_runtime::ConditionallySend;
6use wick_interface_types::Type;
7
8use crate::metadata::DONE_FLAG;
9use crate::wrapped_type::coerce;
10use crate::{Base64Bytes, Error, PacketStream, TypeWrapper, WickMetadata, CLOSE_BRACKET, OPEN_BRACKET};
11
/// A single message travelling over a named port.
///
/// A packet couples a payload (data or an error) with two layers of
/// metadata: the wasmrs frame metadata and the Wick-specific metadata that
/// holds the port name, the signal flags and an optional context blob.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[must_use]
pub struct Packet {
  /// wasmrs-level metadata (exposes `index` and an opaque `extra` blob).
  pub(crate) metadata: Metadata,
  /// Wick-level metadata: port name, flags and optional context.
  pub(crate) extra: WickMetadata,
  /// The data or error carried by this packet.
  pub payload: PacketPayload,
}
19
20impl PartialEq for Packet {
21 fn eq(&self, other: &Self) -> bool {
22 if self.metadata.index != other.metadata.index || !self.metadata.extra.eq(&other.metadata.extra) {
23 return false;
24 }
25 if self.extra.ne(&other.extra) {
26 return false;
27 }
28 self.payload == other.payload
29 }
30}
31
32impl Packet {
33 pub const FATAL_ERROR: &str = "<error>";
35 pub const NO_INPUT: &str = "<>";
36
37 pub const fn new_raw(payload: PacketPayload, wasmrs: Metadata, metadata: WickMetadata) -> Self {
39 Self {
40 payload,
41 metadata: wasmrs,
42 extra: metadata,
43 }
44 }
45
46 pub fn new_for_port<T: Into<String>>(port: T, payload: PacketPayload, flags: u8) -> Self {
48 let md = Metadata::new(0);
49 let wmd = WickMetadata::new(port, flags);
50 Self {
51 payload,
52 metadata: md,
53 extra: wmd,
54 }
55 }
56
57 pub fn no_input() -> Self {
58 Self::encode(Self::NO_INPUT, ())
59 }
60
61 pub fn component_error<T: Into<String>>(err: T) -> Self {
63 Self::new_for_port(Self::FATAL_ERROR, PacketPayload::fatal_error(err), 0)
64 }
65
66 pub fn ok<T: Into<String>>(port: T, payload: RawPayload) -> Self {
68 Self::new_for_port(port, PacketPayload::Ok(payload.data.map(Into::into)), 0)
69 }
70
71 pub fn raw_err<T: Into<String>>(port: T, payload: PacketError) -> Self {
73 Self::new_for_port(port, PacketPayload::Err(payload), 0)
74 }
75
76 pub fn err<T: Into<String>, E: Into<String>>(port: T, msg: E) -> Self {
78 Self::new_for_port(port, PacketPayload::Err(PacketError::new(msg.into())), 0)
79 }
80
81 pub fn done<T: Into<String>>(port: T) -> Self {
83 Self::new_for_port(port, PacketPayload::Ok(None), DONE_FLAG)
84 }
85
86 pub fn open_bracket<T: Into<String>>(port: T) -> Self {
88 Self::new_for_port(port, PacketPayload::Ok(None), OPEN_BRACKET)
89 }
90
91 pub fn close_bracket<T: Into<String>>(port: T) -> Self {
93 Self::new_for_port(port, PacketPayload::Ok(None), CLOSE_BRACKET)
94 }
95
96 pub fn context(&self) -> Option<Base64Bytes> {
98 self.extra.context.clone()
99 }
100
101 pub fn set_context(&mut self, context: Base64Bytes) {
103 self.extra.context = Some(context);
104 }
105
106 pub fn encode<P: Into<String>, T: Serialize>(port: P, data: T) -> Self {
108 Self::new_for_port(port, PacketPayload::encode(data), 0)
109 }
110
111 pub const fn index(&self) -> Option<u32> {
113 self.metadata.index
114 }
115
116 pub fn decode<T: DeserializeOwned>(&self) -> Result<T, Error> {
118 self.payload.decode()
119 }
120
121 pub fn to_type_wrapper(self, ty: Type) -> Result<TypeWrapper, Error> {
123 self.payload.type_wrapper(ty)
124 }
125
126 pub fn decode_value(&self) -> Result<serde_json::Value, Error> {
128 self.payload.decode()
129 }
130
131 pub fn to_port<T: Into<String>>(mut self, port: T) -> Self {
133 self.extra.port = port.into();
134 self
135 }
136
137 pub fn set_port<T: Into<String>>(&mut self, port: T) {
139 self.extra.port = port.into();
140 }
141
142 pub const fn is_error(&self) -> bool {
144 matches!(self.payload, PacketPayload::Err(_))
145 }
146
147 pub const fn payload(&self) -> &PacketPayload {
149 &self.payload
150 }
151
152 pub fn unwrap_payload(self) -> Option<Base64Bytes> {
154 match self.payload {
155 PacketPayload::Ok(v) => v,
156 _ => panic!("Packet is an error"),
157 }
158 }
159
160 pub fn unwrap_err(self) -> PacketError {
162 match self.payload {
163 PacketPayload::Err(err) => err,
164 _ => panic!("Packet is not an error"),
165 }
166 }
167
168 pub fn to_json(&self) -> serde_json::Value {
170 if self.flags() > 0 {
171 let mut map = serde_json::json!({
172 "flags": self.flags(),
173 "port": self.port()
174 });
175 if self.has_data() {
176 map
177 .as_object_mut()
178 .unwrap()
179 .insert("payload".to_owned(), self.payload.to_json());
180 }
181 map
182 } else {
183 serde_json::json!({
184 "port": self.port(),
185 "payload": self.payload.to_json(),
186 })
187 }
188 }
189}
190
191impl PacketExt for Packet {
192 fn has_data(&self) -> bool {
193 match &self.payload {
194 PacketPayload::Ok(Some(data)) => !data.is_empty(),
195 PacketPayload::Ok(None) => false,
196 PacketPayload::Err(_) => false,
197 }
198 }
199
200 fn port(&self) -> &str {
201 &self.extra.port
202 }
203
204 fn flags(&self) -> u8 {
205 self.extra.flags
206 }
207}
208
209pub trait PacketExt {
210 fn has_data(&self) -> bool;
212
213 fn port(&self) -> &str;
215
216 fn flags(&self) -> u8;
218
219 fn is_noop(&self) -> bool {
221 self.port() == Packet::NO_INPUT
222 }
223
224 fn is_fatal_error(&self) -> bool {
226 self.port() == Packet::FATAL_ERROR
227 }
228
229 fn is_signal(&self) -> bool {
231 self.flags() > 0
232 }
233
234 fn is_bracket(&self) -> bool {
236 self.flags() & (OPEN_BRACKET | CLOSE_BRACKET) > 0
237 }
238
239 fn is_done(&self) -> bool {
241 self.flags() & DONE_FLAG == DONE_FLAG
242 }
243
244 fn is_open_bracket(&self) -> bool {
246 self.flags() & OPEN_BRACKET == OPEN_BRACKET
247 }
248
249 fn is_close_bracket(&self) -> bool {
251 self.flags() & CLOSE_BRACKET == CLOSE_BRACKET
252 }
253}
254
255impl PartialEq for PacketPayload {
256 fn eq(&self, other: &Self) -> bool {
257 match (self, other) {
258 (Self::Ok(l0), Self::Ok(r0)) => l0 == r0,
259 (Self::Err(l0), Self::Err(r0)) => l0.msg == r0.msg,
260 _ => core::mem::discriminant(self) == core::mem::discriminant(other),
261 }
262 }
263}
264
/// The content of a [`Packet`]: either (optional) bytes or an error.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::exhaustive_enums)]
pub enum PacketPayload {
  /// A successful payload. `None` means the packet carries no data (as is
  /// the case for the done/bracket packets built in this file).
  Ok(Option<Base64Bytes>),
  /// A failed payload carrying the error that occurred.
  Err(PacketError),
}
271
272impl PacketPayload {
273 pub fn fatal_error<T: Into<String>>(err: T) -> Self {
274 Self::Err(PacketError::new(err))
275 }
276
277 pub fn encode<T: Serialize>(data: T) -> Self {
279 match wasmrs_codec::messagepack::serialize(&data) {
280 Ok(bytes) => PacketPayload::Ok(Some(bytes.into())),
281 Err(err) => PacketPayload::err(err.to_string()),
282 }
283 }
284
285 pub fn decode<T: DeserializeOwned>(&self) -> Result<T, Error> {
287 match self {
288 PacketPayload::Ok(Some(bytes)) => match wasmrs_codec::messagepack::deserialize(bytes) {
289 Ok(data) => Ok(data),
290 Err(err) => Err(crate::Error::Decode {
291 as_json: wasmrs_codec::messagepack::deserialize::<serde_json::Value>(bytes)
292 .map_or_else(|_e| "could not convert".to_owned(), |v| v.to_string()),
293 error: err.to_string(),
294 }),
295 },
296 PacketPayload::Ok(None) => Err(crate::Error::NoData),
297 PacketPayload::Err(err) => Err(crate::Error::PayloadError(err.clone())),
298 }
299 }
300
301 pub fn err<T: Into<String>>(msg: T) -> Self {
302 Self::Err(PacketError::new(msg))
303 }
304
305 pub fn type_wrapper(self, sig: Type) -> Result<TypeWrapper, Error> {
307 let val = coerce(self.decode::<serde_json::Value>()?, &sig)?;
308 Ok(TypeWrapper::new(sig, val))
309 }
310
311 pub const fn bytes(&self) -> Option<&Base64Bytes> {
312 match self {
313 Self::Ok(b) => b.as_ref(),
314 _ => None,
315 }
316 }
317
318 #[allow(clippy::missing_const_for_fn)]
319 pub fn into_bytes(self) -> Option<Base64Bytes> {
320 match self {
321 Self::Ok(b) => b,
322 _ => None,
323 }
324 }
325
326 pub fn to_json(&self) -> serde_json::Value {
327 match self {
328 Self::Ok(Some(b)) => match wasmrs_codec::messagepack::deserialize::<serde_json::Value>(b) {
329 Ok(data) => serde_json::json!({ "value": data }),
330 Err(err) => serde_json::json! ({"error" : crate::Error::Jsonify(err.to_string()).to_string()}),
331 },
332 Self::Ok(None) => serde_json::Value::Null,
333 Self::Err(err) => serde_json::json! ({"error" : crate::Error::PayloadError(err.clone()).to_string()}),
334 }
335 }
336}
337
/// An error carried by a packet, consisting solely of a message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PacketError {
  /// The error message; read it through [`PacketError::msg`].
  msg: String,
}
342
343impl PacketError {
344 pub fn new<T: Into<String>>(msg: T) -> Self {
345 Self { msg: msg.into() }
346 }
347
348 #[must_use]
349 pub fn msg(&self) -> &str {
350 &self.msg
351 }
352}
353
// Lets `PacketError` be used wherever a standard error is expected. No trait
// methods are overridden, so the defaults apply.
impl std::error::Error for PacketError {}
355
356impl std::fmt::Display for PacketError {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 f.write_str(&self.msg)
359 }
360}
361
362impl From<Result<RawPayload, PayloadError>> for Packet {
363 fn from(p: Result<RawPayload, PayloadError>) -> Self {
364 p.map_or_else(
365 |e| {
366 if let Some(mut metadata) = e.metadata {
367 let md = wasmrs::Metadata::decode(&mut metadata);
368
369 let wmd = md.map_or_else(
370 |_e| WickMetadata::default(),
371 |m| {
372 m.extra
373 .map_or_else(WickMetadata::default, |extra| WickMetadata::decode(extra).unwrap())
374 },
375 );
376 Packet::raw_err(wmd.port, PacketError::new(e.msg))
377 } else {
378 Packet::component_error(e.msg)
379 }
380 },
381 |p| {
382 if let Some(mut metadata) = p.metadata {
383 let md = wasmrs::Metadata::decode(&mut metadata);
384
385 let wmd = md.map_or_else(
386 |_e| WickMetadata::default(),
387 |m| {
388 m.extra
389 .map_or_else(WickMetadata::default, |extra| WickMetadata::decode(extra).unwrap())
390 },
391 );
392 let data = p.data.and_then(|b| (!b.is_empty()).then_some(b));
395 Packet::new_for_port(wmd.port(), PacketPayload::Ok(data.map(Into::into)), wmd.flags())
396 } else {
397 Packet::component_error("invalid wasmrs packet with no metadata.")
398 }
399 },
400 )
401 }
402}
403
404impl From<Result<Payload, PayloadError>> for Packet {
405 fn from(p: Result<Payload, PayloadError>) -> Self {
406 p.map_or_else(
407 |e| {
408 let md = wasmrs::Metadata::decode(&mut e.metadata.unwrap());
409
410 let wmd = md.map_or_else(
411 |_e| WickMetadata::default(),
412 |m| WickMetadata::decode(m.extra.unwrap()).unwrap(),
413 );
414 Packet::raw_err(wmd.port, PacketError::new(e.msg))
415 },
416 |p| {
417 let md = p.metadata;
418 let wmd = WickMetadata::decode(md.extra.unwrap()).unwrap();
419 let data = p.data;
422 Packet::new_for_port(wmd.port(), PacketPayload::Ok(Some(data.into())), wmd.flags())
423 },
424 )
425 }
426}
427
428#[must_use]
429pub fn packetstream_to_wasmrs(index: u32, stream: PacketStream) -> BoxFlux<RawPayload, PayloadError> {
430 let s = tokio_stream::StreamExt::map(stream, move |p| {
431 p.map_or_else(
432 |e| Err(PayloadError::application_error(e.to_string(), None)),
433 |p| {
434 let md = wasmrs::Metadata::new_extra(index, p.extra.encode()).encode();
435 match p.payload {
436 PacketPayload::Ok(b) => Ok(wasmrs::RawPayload::new_data(Some(md), b.map(Into::into))),
437 PacketPayload::Err(e) => Err(wasmrs::PayloadError::application_error(e.msg(), Some(md))),
438 }
439 },
440 )
441 });
442 Box::pin(s)
443}
444
445pub fn from_raw_wasmrs<T: Stream<Item = Result<RawPayload, PayloadError>> + ConditionallySend + Unpin + 'static>(
446 stream: T,
447) -> PacketStream {
448 let s = tokio_stream::StreamExt::map(stream, move |p| Ok(p.into()));
449 PacketStream::new(Box::new(s))
450}
451
452pub fn from_wasmrs<T: Stream<Item = Result<Payload, PayloadError>> + ConditionallySend + Unpin + 'static>(
453 stream: T,
454) -> PacketStream {
455 let s = tokio_stream::StreamExt::map(stream, move |p| Ok(p.into()));
456 PacketStream::new(Box::new(s))
457}
458
459impl From<Payload> for Packet {
460 fn from(mut value: Payload) -> Self {
461 let ex = value.metadata.extra.take();
462
463 Self {
464 extra: WickMetadata::decode(ex.unwrap()).unwrap(),
465 metadata: value.metadata,
466 payload: PacketPayload::Ok(Some(value.data.into())),
467 }
468 }
469}
470
471impl From<Packet> for Result<RawPayload, PayloadError> {
472 fn from(value: Packet) -> Self {
473 let mut md = value.metadata;
474 md.extra = Some(value.extra.encode());
475 match value.payload {
476 PacketPayload::Ok(b) => Ok(RawPayload::new_data(Some(md.encode()), b.map(Into::into))),
477 PacketPayload::Err(e) => Err(PayloadError::application_error(e.msg(), Some(md.encode()))),
478 }
479 }
480}
481
// Unit tests covering the encode/decode round trip and type coercion.
#[cfg(test)]
mod test {
  use anyhow::Result;
  use serde_json::Value;
  use wick_interface_types::Type;

  use super::PacketPayload;
  use crate::{Base64Bytes, Packet};

  // An encoded integer decodes back to the same integer.
  #[test]
  fn test_basic() -> Result<()> {
    let packet = Packet::encode("test", 10);
    let res: i32 = packet.decode()?;
    assert_eq!(res, 10);
    Ok(())
  }

  // Values of several Rust types decode to the expected generic JSON value.
  // The last case expects `Base64Bytes` to surface as a base64 string.
  #[rstest::rstest]
  #[case(u64::MIN, Value::Number(serde_json::Number::from(u64::MIN)))]
  #[case(u64::MAX, Value::Number(serde_json::Number::from(u64::MAX)))]
  #[case(&[1,2,3,4,5,6], vec![1,2,3,4,5,6].into())]
  #[case("test", Value::String("test".to_owned()))]
  #[case(Base64Bytes::new(b"test".as_slice()), Value::String("dGVzdA==".to_owned()))]
  fn test_encode_to_generic<T>(#[case] value: T, #[case] expected: Value) -> Result<()>
  where
    T: serde::Serialize + std::fmt::Debug,
  {
    let packet = Packet::encode("test", value);
    println!("{:?}", packet);
    let res = packet.decode_value()?;
    assert_eq!(res, expected);
    Ok(())
  }

  // `type_wrapper` coerces the decoded value to the requested type; the
  // second case expects the integer 2 to become the string "2".
  #[rstest::rstest]
  #[case("2", Type::String, Value::String("2".into()))]
  #[case(2, Type::String, Value::String("2".into()))]
  fn test_type_wrapper<T>(#[case] value: T, #[case] ty: Type, #[case] expected: Value) -> Result<()>
  where
    T: serde::Serialize + std::fmt::Debug,
  {
    let packet = PacketPayload::encode(value);
    println!("{:?}", packet);
    let wrapper = packet.type_wrapper(ty)?;
    assert_eq!(wrapper.into_inner(), expected);
    Ok(())
  }

  // A base64 string that went through a packet deserializes into the
  // `Base64Bytes` holding the decoded bytes.
  #[rstest::rstest]
  #[case("dGVzdA==", b"test")]
  fn test_from_b64(#[case] value: &str, #[case] expected: &[u8]) -> Result<()> {
    let packet = Packet::encode("test", value);
    println!("{:?}", packet);
    let res = packet.decode_value()?;
    let bytes: Base64Bytes = serde_json::from_value(res).unwrap();
    assert_eq!(bytes, expected);
    Ok(())
  }
}
540}