1use tracing::warn;
2use wasmrs::{PayloadError, RawPayload};
3use wasmrs_runtime::ConditionallySend;
4use wasmrs_rx::{FluxChannel, Observer};
5
6use crate::{Packet, PacketPayload};
7
8pub struct OutgoingPort<T> {
9 channel: FluxChannel<RawPayload, PayloadError>,
10 name: String,
11 _phantom: std::marker::PhantomData<T>,
12}
13
14impl<T> std::fmt::Debug for OutgoingPort<T> {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 f.debug_struct("Output").field("name", &self.name).finish()
17 }
18}
19
20pub trait WasmRsChannel: ConditionallySend {
21 fn channel(&self) -> FluxChannel<RawPayload, PayloadError>;
22}
23
24pub trait Port: WasmRsChannel + ConditionallySend {
25 fn name(&self) -> &str;
26
27 fn send_raw_result(&mut self, value: Result<RawPayload, PayloadError>) {
28 if let Err(e) = self.channel().send_result(value) {
29 warn!(
30 port = self.name(),
31 error = %e,
32 "failed sending packet on output channel, this is a bug"
33 );
34 };
35 }
36
37 fn send_packet(&mut self, value: Packet) {
38 let value = value.to_port(self.name());
39 self.send_raw_result(value.into());
40 }
41
42 fn send_raw_payload(&mut self, value: PacketPayload) {
43 self.send_packet(Packet::new_for_port(self.name(), value, 0));
44 }
45
46 fn open_bracket(&mut self) {
47 self.send_packet(Packet::open_bracket(self.name()));
48 }
49
50 fn close_bracket(&mut self) {
51 self.send_packet(Packet::close_bracket(self.name()));
52 }
53
54 fn done(&mut self) {
55 self.send_packet(Packet::done(self.name()));
56 }
57
58 fn error(&mut self, err: &str) {
59 self.send_packet(Packet::err(self.name(), err));
60 }
61}
62
63pub trait ValuePort<T>: Port {
64 fn send(&mut self, value: T);
65
66 fn send_result(&mut self, value: Result<T, impl std::fmt::Display>);
67}
68
69impl<T> ValuePort<T> for OutgoingPort<T>
70where
71 T: serde::Serialize + ConditionallySend,
72{
73 fn send(&mut self, value: T) {
74 self.send_packet(Packet::encode(self.name(), value));
75 }
76
77 fn send_result(&mut self, value: Result<T, impl std::fmt::Display>) {
78 match value {
79 Ok(value) => self.send(value),
80 Err(err) => self.error(err.to_string().as_str()),
81 }
82 }
83}
84
85impl<T> ValuePort<&T> for OutgoingPort<T>
86where
87 T: serde::Serialize + ConditionallySend,
88{
89 fn send(&mut self, value: &T) {
90 self.send_packet(Packet::encode(self.name(), value));
91 }
92
93 fn send_result(&mut self, value: Result<&T, impl std::fmt::Display>) {
94 match value {
95 Ok(value) => self.send(value),
96 Err(err) => self.error(err.to_string().as_str()),
97 }
98 }
99}
100
101impl ValuePort<&str> for OutgoingPort<String> {
102 fn send(&mut self, value: &str) {
103 self.send_packet(Packet::encode(self.name(), value));
104 }
105
106 fn send_result(&mut self, value: Result<&str, impl std::fmt::Display>) {
107 match value {
108 Ok(value) => self.send(value),
109 Err(err) => self.error(err.to_string().as_str()),
110 }
111 }
112}
113
114impl<T> Port for OutgoingPort<T>
115where
116 T: serde::Serialize + ConditionallySend,
117{
118 fn name(&self) -> &str {
119 &self.name
120 }
121}
122
123impl<T> WasmRsChannel for OutgoingPort<T>
124where
125 T: serde::Serialize + ConditionallySend,
126{
127 fn channel(&self) -> FluxChannel<RawPayload, PayloadError> {
128 self.channel.clone()
129 }
130}
131
132impl<T> OutgoingPort<T>
133where
134 T: serde::Serialize,
135{
136 pub fn new<K: Into<String>>(name: K, channel: FluxChannel<RawPayload, PayloadError>) -> Self {
137 Self {
138 channel,
139 name: name.into(),
140 _phantom: Default::default(),
141 }
142 }
143}
144
145#[must_use]
147pub struct OutputIterator<'a> {
148 outputs: Vec<&'a mut dyn Port>,
149}
150
151impl<'a> std::fmt::Debug for OutputIterator<'a> {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 f.debug_struct("OutputIterator")
154 .field("outputs", &self.outputs.iter().map(|a| a.name()).collect::<Vec<_>>())
155 .finish()
156 }
157}
158impl<'a> OutputIterator<'a> {
159 pub fn new(outputs: Vec<&'a mut dyn Port>) -> Self {
161 Self { outputs }
162 }
163}
164
165impl<'a> IntoIterator for OutputIterator<'a> {
166 type Item = &'a mut dyn Port;
167
168 type IntoIter = std::vec::IntoIter<Self::Item>;
169
170 fn into_iter(self) -> Self::IntoIter {
171 self.outputs.into_iter()
172 }
173}
174
#[cfg(test)]
mod test {
    use anyhow::Result;
    use tokio::task::JoinHandle;
    use tokio_stream::StreamExt;

    use super::*;
    use crate::{packet_stream, PacketExt, PacketStream};

    /// Sends owned and borrowed values through several typed ports that share
    /// one channel, then checks that every packet decodes back to the value
    /// that was sent, in the order it was sent.
    #[test_logger::test(tokio::test)]
    async fn test_outputs() -> Result<()> {
        struct Outputs {
            a: OutgoingPort<i32>,
            b: OutgoingPort<String>,
            c: OutgoingPort<SomeStruct>,
        }

        let (stream, rx) = FluxChannel::new_parts();

        let mut outputs = Outputs {
            a: OutgoingPort::new("a", stream.clone()),
            b: OutgoingPort::new("b", stream.clone()),
            c: OutgoingPort::new("c", stream),
        };

        #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq)]
        struct SomeStruct {
            a: String,
        }
        let some_struct = SomeStruct { a: "hey".to_owned() };

        // Exercise each `ValuePort` flavour: `&T`, `T`, `&str` and `String`.
        outputs.a.send(&42);
        outputs.a.send(42);
        outputs.b.send("hey");
        outputs.b.send("hey".to_owned());
        let kinda_string = std::borrow::Cow::Borrowed("hey");
        outputs.b.send(kinda_string.as_ref());
        // A reference to a temporary, then a reference to a long-lived value.
        outputs.c.send(&some_struct.clone());
        outputs.c.send(&some_struct);
        // Drop every port (and with them every channel handle created above)
        // before collecting; presumably this is what lets the stream end.
        drop(outputs);

        // Walk the received packets front to back with a consuming iterator
        // instead of repeatedly removing index 0, which shifts the whole
        // vector on each call.
        let mut packets = rx.collect::<Vec<_>>().await.into_iter();
        let mut next_packet = || -> Packet {
            packets
                .next()
                .expect("expected another packet on the output stream")
                .into()
        };

        assert_eq!(next_packet().decode::<i32>()?, 42);
        assert_eq!(next_packet().decode::<i32>()?, 42);
        assert_eq!(next_packet().decode::<String>()?, "hey");
        assert_eq!(next_packet().decode::<String>()?, "hey");
        assert_eq!(next_packet().decode::<String>()?, "hey");
        assert_eq!(next_packet().decode::<SomeStruct>()?, some_struct);
        assert_eq!(next_packet().decode::<SomeStruct>()?, some_struct);

        Ok(())
    }

    /// Splits one packet stream into per-port streams by port name and checks
    /// that each per-port stream yields the value addressed to it.
    #[test_logger::test(tokio::test)]
    async fn test_inputs() -> Result<()> {
        struct Inputs {
            #[allow(unused)]
            task: JoinHandle<()>,
            a: PacketStream,
            b: PacketStream,
        }
        impl Inputs {
            /// Spawns a task that forwards packets from `stream` to the `a`
            /// or `b` stream according to their port name.
            fn new(mut stream: PacketStream) -> Self {
                let (a_tx, a_rx) = PacketStream::new_channels();
                let (b_tx, b_rx) = PacketStream::new_channels();
                let task = tokio::spawn(async move {
                    while let Some(next) = stream.next().await {
                        // Send failures are deliberately ignored here.
                        let _ = match next {
                            Ok(packet) => match packet.port() {
                                "a" => a_tx.send(packet),
                                "b" => b_tx.send(packet),
                                // Fatal errors are broadcast to both streams;
                                // only the first send needs a clone.
                                crate::Packet::FATAL_ERROR => {
                                    let _ = a_tx.send(packet.clone());
                                    b_tx.send(packet)
                                }
                                // Packets for any other port are skipped.
                                _ => continue,
                            },
                            // Stream-level errors also go to both streams.
                            Err(e) => {
                                let _ = a_tx.error(e.clone());
                                b_tx.error(e)
                            }
                        };
                    }
                });

                Self { task, a: a_rx, b: b_rx }
            }
        }

        let stream = packet_stream!(("a", 32), ("b", "Hey"));

        let mut inputs = Inputs::new(stream);

        assert_eq!(inputs.a.next().await.unwrap()?.decode::<i32>()?, 32);
        assert_eq!(inputs.b.next().await.unwrap()?.decode::<String>()?, "Hey");

        Ok(())
    }
}