wick_packet/
output.rs

1use tracing::warn;
2use wasmrs::{PayloadError, RawPayload};
3use wasmrs_runtime::ConditionallySend;
4use wasmrs_rx::{FluxChannel, Observer};
5
6use crate::{Packet, PacketPayload};
7
8pub struct OutgoingPort<T> {
9  channel: FluxChannel<RawPayload, PayloadError>,
10  name: String,
11  _phantom: std::marker::PhantomData<T>,
12}
13
14impl<T> std::fmt::Debug for OutgoingPort<T> {
15  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16    f.debug_struct("Output").field("name", &self.name).finish()
17  }
18}
19
20pub trait WasmRsChannel: ConditionallySend {
21  fn channel(&self) -> FluxChannel<RawPayload, PayloadError>;
22}
23
24pub trait Port: WasmRsChannel + ConditionallySend {
25  fn name(&self) -> &str;
26
27  fn send_raw_result(&mut self, value: Result<RawPayload, PayloadError>) {
28    if let Err(e) = self.channel().send_result(value) {
29      warn!(
30        port = self.name(),
31        error = %e,
32        "failed sending packet on output channel, this is a bug"
33      );
34    };
35  }
36
37  fn send_packet(&mut self, value: Packet) {
38    let value = value.to_port(self.name());
39    self.send_raw_result(value.into());
40  }
41
42  fn send_raw_payload(&mut self, value: PacketPayload) {
43    self.send_packet(Packet::new_for_port(self.name(), value, 0));
44  }
45
46  fn open_bracket(&mut self) {
47    self.send_packet(Packet::open_bracket(self.name()));
48  }
49
50  fn close_bracket(&mut self) {
51    self.send_packet(Packet::close_bracket(self.name()));
52  }
53
54  fn done(&mut self) {
55    self.send_packet(Packet::done(self.name()));
56  }
57
58  fn error(&mut self, err: &str) {
59    self.send_packet(Packet::err(self.name(), err));
60  }
61}
62
63pub trait ValuePort<T>: Port {
64  fn send(&mut self, value: T);
65
66  fn send_result(&mut self, value: Result<T, impl std::fmt::Display>);
67}
68
69impl<T> ValuePort<T> for OutgoingPort<T>
70where
71  T: serde::Serialize + ConditionallySend,
72{
73  fn send(&mut self, value: T) {
74    self.send_packet(Packet::encode(self.name(), value));
75  }
76
77  fn send_result(&mut self, value: Result<T, impl std::fmt::Display>) {
78    match value {
79      Ok(value) => self.send(value),
80      Err(err) => self.error(err.to_string().as_str()),
81    }
82  }
83}
84
85impl<T> ValuePort<&T> for OutgoingPort<T>
86where
87  T: serde::Serialize + ConditionallySend,
88{
89  fn send(&mut self, value: &T) {
90    self.send_packet(Packet::encode(self.name(), value));
91  }
92
93  fn send_result(&mut self, value: Result<&T, impl std::fmt::Display>) {
94    match value {
95      Ok(value) => self.send(value),
96      Err(err) => self.error(err.to_string().as_str()),
97    }
98  }
99}
100
101impl ValuePort<&str> for OutgoingPort<String> {
102  fn send(&mut self, value: &str) {
103    self.send_packet(Packet::encode(self.name(), value));
104  }
105
106  fn send_result(&mut self, value: Result<&str, impl std::fmt::Display>) {
107    match value {
108      Ok(value) => self.send(value),
109      Err(err) => self.error(err.to_string().as_str()),
110    }
111  }
112}
113
114impl<T> Port for OutgoingPort<T>
115where
116  T: serde::Serialize + ConditionallySend,
117{
118  fn name(&self) -> &str {
119    &self.name
120  }
121}
122
123impl<T> WasmRsChannel for OutgoingPort<T>
124where
125  T: serde::Serialize + ConditionallySend,
126{
127  fn channel(&self) -> FluxChannel<RawPayload, PayloadError> {
128    self.channel.clone()
129  }
130}
131
132impl<T> OutgoingPort<T>
133where
134  T: serde::Serialize,
135{
136  pub fn new<K: Into<String>>(name: K, channel: FluxChannel<RawPayload, PayloadError>) -> Self {
137    Self {
138      channel,
139      name: name.into(),
140      _phantom: Default::default(),
141    }
142  }
143}
144
145/// Iterator over a mutable set of output ports
146#[must_use]
147pub struct OutputIterator<'a> {
148  outputs: Vec<&'a mut dyn Port>,
149}
150
151impl<'a> std::fmt::Debug for OutputIterator<'a> {
152  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153    f.debug_struct("OutputIterator")
154      .field("outputs", &self.outputs.iter().map(|a| a.name()).collect::<Vec<_>>())
155      .finish()
156  }
157}
158impl<'a> OutputIterator<'a> {
159  /// Create a new [OutputIterator]
160  pub fn new(outputs: Vec<&'a mut dyn Port>) -> Self {
161    Self { outputs }
162  }
163}
164
165impl<'a> IntoIterator for OutputIterator<'a> {
166  type Item = &'a mut dyn Port;
167
168  type IntoIter = std::vec::IntoIter<Self::Item>;
169
170  fn into_iter(self) -> Self::IntoIter {
171    self.outputs.into_iter()
172  }
173}
174
#[cfg(test)]
mod test {
  use anyhow::Result;
  use tokio::task::JoinHandle;
  use tokio_stream::StreamExt;

  use super::*;
  use crate::{packet_stream, PacketExt, PacketStream};

  /// Every `send` overload on `OutgoingPort` yields a packet that decodes back to the sent
  /// value, in send order.
  #[test_logger::test(tokio::test)]
  async fn test_outputs() -> Result<()> {
    #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq)]
    struct SomeStruct {
      a: String,
    }

    struct Outputs {
      a: OutgoingPort<i32>,
      b: OutgoingPort<String>,
      c: OutgoingPort<SomeStruct>,
    }

    let (tx, rx) = FluxChannel::new_parts();

    let mut outputs = Outputs {
      a: OutgoingPort::new("a", tx.clone()),
      b: OutgoingPort::new("b", tx.clone()),
      c: OutgoingPort::new("c", tx),
    };

    let some_struct = SomeStruct { a: "hey".to_owned() };

    // i32 port: by reference, then by value.
    outputs.a.send(&42);
    outputs.a.send(42);
    // String port: &str, owned String, and a &str borrowed out of a Cow.
    outputs.b.send("hey");
    outputs.b.send("hey".to_owned());
    let kinda_string = std::borrow::Cow::Borrowed("hey");
    outputs.b.send(kinda_string.as_ref());
    // Struct port: reference to a temporary clone, then reference to the original.
    outputs.c.send(&some_struct.clone());
    outputs.c.send(&some_struct);
    // Dropping every port drops every sender, so the receiver stream can complete.
    drop(outputs);

    let mut packets = rx.collect::<Vec<_>>().await.into_iter();

    for _ in 0..2 {
      let p: Packet = packets.next().unwrap().into();
      assert_eq!(p.decode::<i32>()?, 42);
    }
    for _ in 0..3 {
      let p: Packet = packets.next().unwrap().into();
      assert_eq!(p.decode::<String>()?, "hey");
    }
    for _ in 0..2 {
      let p: Packet = packets.next().unwrap().into();
      assert_eq!(p.decode::<SomeStruct>()?, some_struct);
    }

    Ok(())
  }

  /// A mixed-port packet stream can be fanned out into one stream per port.
  #[test_logger::test(tokio::test)]
  async fn test_inputs() -> Result<()> {
    struct Inputs {
      // Holds the router task's handle next to its output streams; never read.
      #[allow(unused)]
      task: JoinHandle<()>,
      a: PacketStream,
      b: PacketStream,
    }

    impl Inputs {
      fn new(mut source: PacketStream) -> Self {
        let (a_sender, a_stream) = PacketStream::new_channels();
        let (b_sender, b_stream) = PacketStream::new_channels();
        let task = tokio::spawn(async move {
          while let Some(item) = source.next().await {
            let _ = match item {
              Ok(packet) => match packet.port() {
                "a" => a_sender.send(packet),
                "b" => b_sender.send(packet),
                // Fatal errors are broadcast to every port.
                crate::Packet::FATAL_ERROR => {
                  let _ = a_sender.send(packet.clone());
                  b_sender.send(packet.clone())
                }
                _ => continue,
              },
              // Stream errors are also broadcast to every port.
              Err(err) => {
                let _ = a_sender.error(err.clone());
                b_sender.error(err)
              }
            };
          }
        });

        Self {
          task,
          a: a_stream,
          b: b_stream,
        }
      }
    }

    let mut inputs = Inputs::new(packet_stream!(("a", 32), ("b", "Hey")));

    let first_a = inputs.a.next().await.unwrap()?;
    assert_eq!(first_a.decode::<i32>()?, 32);
    let first_b = inputs.b.next().await.unwrap()?;
    assert_eq!(first_b.decode::<String>()?, "Hey");

    Ok(())
  }
}