1use eyre::{Context, OptionExt};
2use std::{collections::HashMap, sync::Arc};
3
4use tokio::sync::{
5 Mutex,
6 broadcast::{Receiver, Sender},
7};
8
9use arrow_array::Array;
10use arrow_data::ArrayData;
11
12use crate::prelude::*;
13
14pub struct RawOutput {
15 clock: Arc<uhlc::HLC>,
16 tx: Sender<DataflowMessage>,
17}
18
19impl RawOutput {
20 pub fn new(clock: Arc<uhlc::HLC>, tx: Sender<DataflowMessage>) -> Self {
21 Self { clock, tx }
22 }
23
24 pub fn send(&self, data: ArrayData) -> eyre::Result<()> {
25 let data = DataflowMessage {
26 header: Header {
27 timestamp: self.clock.new_timestamp(),
28 },
29 data,
30 };
31
32 self.tx
33 .send(data)
34 .map(|_| ())
35 .map_err(eyre::Report::msg)
36 .wrap_err("Failed to send the message")
37 }
38}
39
40pub struct Output<T: ArrowMessage> {
41 raw: RawOutput,
42 _phantom: std::marker::PhantomData<T>,
43}
44
45impl<T: ArrowMessage> Output<T> {
46 pub fn new(clock: Arc<uhlc::HLC>, tx: Sender<DataflowMessage>) -> Self {
47 Self {
48 raw: RawOutput::new(clock, tx),
49 _phantom: std::marker::PhantomData,
50 }
51 }
52
53 pub fn send(&self, data: T) -> eyre::Result<()> {
54 self.raw.send(
55 data.try_into_arrow()
56 .wrap_err("Failed to convert arrow 'data' to message T")?
57 .into_data(),
58 )
59 }
60}
61
62pub struct RawInput {
63 rx: Receiver<DataflowMessage>,
64}
65
66impl RawInput {
67 pub fn new(rx: Receiver<DataflowMessage>) -> Self {
68 Self { rx }
69 }
70
71 pub fn recv(&mut self) -> eyre::Result<(Header, ArrayData)> {
72 let DataflowMessage { header, data } = self
73 .rx
74 .blocking_recv()
75 .map_err(eyre::Report::msg)
76 .wrap_err("Failed to receive from this input")?;
77
78 Ok((header, data))
79 }
80
81 pub async fn recv_async(&mut self) -> eyre::Result<(Header, ArrayData)> {
82 let DataflowMessage { header, data } = self
83 .rx
84 .recv()
85 .await
86 .map_err(eyre::Report::msg)
87 .wrap_err("Failed to receive from this input")?;
88
89 Ok((header, data))
90 }
91}
92
93pub struct Input<T: ArrowMessage> {
94 raw: RawInput,
95
96 _phantom: std::marker::PhantomData<T>,
97}
98
99impl<T: ArrowMessage> Input<T> {
100 pub fn new(rx: Receiver<DataflowMessage>) -> Self {
101 Self {
102 raw: RawInput::new(rx),
103 _phantom: std::marker::PhantomData,
104 }
105 }
106
107 pub fn recv(&mut self) -> eyre::Result<(Header, T)> {
108 let (header, data) = self.raw.recv()?;
109
110 Ok((
111 header,
112 T::try_from_arrow(data).wrap_err("Failed to convert arrow 'data' to message T")?,
113 ))
114 }
115
116 pub async fn recv_async(&mut self) -> eyre::Result<(Header, T)> {
117 let (header, data) = self.raw.recv_async().await?;
118
119 Ok((
120 header,
121 T::try_from_arrow(data).wrap_err("Failed to convert arrow 'data' to message T")?,
122 ))
123 }
124}
125
126pub struct Inputs {
127 node: NodeID,
128 receivers: Arc<Mutex<HashMap<InputID, Receiver<DataflowMessage>>>>,
129}
130
131impl Inputs {
132 pub fn new(
133 node: NodeID,
134 receivers: Arc<Mutex<HashMap<InputID, Receiver<DataflowMessage>>>>,
135 ) -> Self {
136 Self { node, receivers }
137 }
138
139 pub async fn with<T: ArrowMessage>(
140 &mut self,
141 input: impl Into<String>,
142 ) -> eyre::Result<Input<T>> {
143 let id = self.node.input(input);
144
145 let receiver = self
146 .receivers
147 .lock()
148 .await
149 .remove(&id)
150 .ok_or_eyre(format!("Input {} not found", id.0))?;
151
152 Ok(Input::new(receiver))
153 }
154}
155
156pub struct Outputs {
157 node: NodeID,
158 clock: Arc<uhlc::HLC>,
159 senders: Arc<Mutex<HashMap<OutputID, Sender<DataflowMessage>>>>,
160}
161
162impl Outputs {
163 pub fn new(
164 node: NodeID,
165 clock: Arc<uhlc::HLC>,
166 senders: Arc<Mutex<HashMap<OutputID, Sender<DataflowMessage>>>>,
167 ) -> Self {
168 Self {
169 node,
170 clock,
171 senders,
172 }
173 }
174
175 pub async fn with<T: ArrowMessage>(
176 &mut self,
177 output: impl Into<String>,
178 ) -> eyre::Result<Output<T>> {
179 let id = self.node.output(output);
180
181 let sender = self
182 .senders
183 .lock()
184 .await
185 .remove(&id)
186 .ok_or_eyre(format!("Output {} not found", id.0))?;
187
188 Ok(Output::new(self.clock.clone(), sender))
189 }
190}