1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
use std::{fmt::Debug, marker::PhantomData};
use tokio::{
io::Ready,
select,
sync::{
broadcast::Sender as BroadcastSender,
mpsc::{UnboundedReceiver, UnboundedSender},
},
task::JoinHandle,
};
use tracing::{debug, error, trace};
use crate::{codecs::Codec, error::Error};
/// Generic [Readable] trait over different (split) transport types.
///
/// The connection reader task alternates between [`Readable::ready_internal`] (await
/// readiness) and [`Readable::try_read_buf_internal`] (pull whatever bytes are available).
pub(crate) trait Readable: Send + Sync + 'static {
    /// Waits until the underlying stream is ready for reading, or has been closed or errored.
    fn ready_internal(&mut self) -> impl Future<Output = std::io::Result<Ready>> + Send;

    /// Attempts to read data from the underlying stream into `buf`, returning the number of
    /// bytes read.
    ///
    /// The reader task passes an empty buffer and treats its first `n` bytes as the newly
    /// read data. It interprets `std::io::ErrorKind::WouldBlock` as "no data available yet".
    fn try_read_buf_internal(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize>;
}
/// Generic [Writeable] trait over different (split) transport types.
///
/// The connection writer task hands each encoded message to
/// [`Writeable::write_all_internal`] in one call.
pub(crate) trait Writeable: Send + Sync + 'static {
    /// Writes the whole of `buf` to the underlying stream, ensuring all data is sent.
    ///
    /// The returned future resolves to an I/O error if the write fails.
    fn write_all_internal(
        &mut self,
        buf: &[u8],
    ) -> impl Future<Output = std::io::Result<()>> + Send;
}
/// A handle for an active connection, used for both the client and server implementations.
///
/// It owns the broadcast sender used to stop the background reader/writer tasks and keeps
/// their join handles. The codec and message types only appear at the type level.
pub(crate) struct Handle<C, OUT, IN, ADDR>
where
    C: Codec<OUT, IN>,
    OUT: Debug + Send + 'static,
    IN: Debug + Send + 'static,
    ADDR: Debug + Clone + Sync + Send + 'static,
{
    /// Address associated with this connection (the value given to `Handle::new`).
    addr: ADDR,
    /// Sender for the exit signal shared by the background tasks and `on_closed` callbacks.
    exit_tx: BroadcastSender<()>,
    /// Join handle of the spawned reader task (held, not awaited in this file).
    _rx_handle: JoinHandle<()>,
    /// Join handle of the spawned writer task (held, not awaited in this file).
    _tx_handle: JoinHandle<()>,
    /// Marks the codec type parameter as used.
    _c: PhantomData<C>,
    /// Marks the outgoing message type parameter as used.
    _out: PhantomData<OUT>,
    /// Marks the incoming message type parameter as used.
    _in: PhantomData<IN>,
}
impl<
    C: Codec<OUT, IN>,
    OUT: Debug + Send + 'static,
    IN: Debug + Send + 'static,
    ADDR: Debug + Clone + Sync + Send + 'static,
> Handle<C, OUT, IN, ADDR>
{
    /// Spawns the reader and writer tasks that service an already-established connection.
    ///
    /// * The reader task waits for `reader` to become readable and decodes complete messages
    ///   with the codec `C`. It forwards each one to `in_tx`, paired with a clone of `addr`.
    /// * The writer task encodes every message received on `out_rx` with `C` and writes the
    ///   bytes to `writer`. Messages that fail to encode are logged and skipped.
    ///
    /// A task stops on a stream close, an I/O or decode error, or the outgoing channel
    /// closing. When it stops, it broadcasts on the exit channel. That stops the other task
    /// and wakes callbacks already registered through [`Handle::on_closed`].
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub(crate) async fn new(
        addr: ADDR,
        mut reader: impl Readable + Unpin + Send + 'static,
        mut writer: impl Writeable + Unpin + Send + 'static,
        mut out_rx: UnboundedReceiver<OUT>,
        in_tx: UnboundedSender<(IN, ADDR)>,
    ) -> Result<Self, Error> {
        // Exit channel shared by both tasks, `close()` and `on_closed()`.
        let (exit_tx, _exit_rx) = tokio::sync::broadcast::channel::<()>(1);

        // Subscribe before spawning. A broadcast receiver only sees values sent after it was
        // created. If the tasks subscribed themselves, a `close()` issued right after `new()`
        // returns could be missed by a task that has not started running yet.
        let mut reader_exit_rx = exit_tx.subscribe();
        let mut writer_exit_rx = exit_tx.subscribe();

        // Reader task: stream -> accumulator -> codec -> `in_tx`.
        let reader_exit_tx = exit_tx.clone();
        let addr_ = addr.clone();
        let _rx_handle = tokio::task::spawn(async move {
            let mut accumulator = Vec::with_capacity(1024);
            debug!("Reader task started for {addr_:?}");
            loop {
                select! {
                    biased;
                    // Shutdown requested by `close()` or by the writer task.
                    _ = reader_exit_rx.recv() => {
                        debug!("Client reader exiting");
                        break;
                    },
                    r = reader.ready_internal() => match r {
                        Ok(ready) => {
                            // Drain pending bytes before acting on a close. A peer that writes
                            // and then shuts down may report READABLE and READ_CLOSED in the
                            // same readiness set, and its final bytes must still be decoded.
                            if ready.is_readable() {
                                if let Err(e) = Self::handle_read(&mut reader, &mut accumulator, &addr_, &in_tx).await {
                                    error!("Error handling read for {addr_:?}: {e:?}");
                                    break;
                                }
                            }
                            if ready.is_read_closed() {
                                debug!("Stream for {addr_:?} closed");
                                break;
                            }
                            if ready.is_error() {
                                debug!("Stream for {addr_:?} encountered an error");
                                break;
                            }
                            if !ready.is_readable() {
                                // Nothing actionable in this readiness set; poll again.
                                debug!("Unexpected readiness state for {addr_:?}: {ready:?}");
                            }
                        }
                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                            // Spurious wake-up with nothing to read; poll again.
                        }
                        Err(e) => {
                            error!("Failed to read from stream: {e:?}");
                            break;
                        }
                    },
                }
            }
            // Every exit path funnels through here. The writer task and any `on_closed`
            // callbacks are notified by this single send.
            reader_exit_tx.send(()).ok();
            drop(reader);
        });

        // Writer task: `out_rx` -> codec -> stream.
        let writer_exit_tx = exit_tx.clone();
        let addr_ = addr.clone();
        let _tx_handle = tokio::task::spawn(async move {
            debug!("Writer task started for {addr_:?}");
            loop {
                select! {
                    biased;
                    // Shutdown requested by `close()` or by the reader task.
                    _ = writer_exit_rx.recv() => {
                        debug!("Client writer exiting");
                        break;
                    }
                    e = out_rx.recv() => match e {
                        Some(event) => {
                            // Serialize the event with the codec; skip events that cannot be encoded.
                            let data = match C::encode(&event) {
                                Ok(data) => data,
                                Err(e) => {
                                    error!("Failed to serialize event: {:?}", e);
                                    continue;
                                }
                            };
                            if let Err(e) = writer.write_all_internal(&data).await {
                                error!("Failed to write to stream: {:?}", e);
                                break;
                            }
                        },
                        None => {
                            // Every sender for outgoing messages is gone.
                            debug!("Failed to receive event, channel closed");
                            break;
                        }
                    },
                }
            }
            // Single notification for every exit path (reader task and `on_closed` callbacks).
            writer_exit_tx.send(()).ok();
            drop(writer);
        });

        Ok(Self {
            addr,
            _rx_handle,
            _tx_handle,
            exit_tx,
            _c: PhantomData,
            _in: PhantomData,
            _out: PhantomData,
        })
    }

    /// Reads all bytes currently available from `reader` into `accumulator`. Then it decodes
    /// and forwards every complete message to `in_tx`, paired with `addr_`.
    ///
    /// Decoding repeatedly calls `C::try_decode(accumulator)` until it reports `Ok(None)`.
    /// NOTE(review): this assumes the codec removes consumed bytes from `accumulator` and
    /// leaves incomplete data in place — confirm against the `Codec` implementations.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] if a read fails with anything other than `WouldBlock`. Any error
    /// from `C::try_decode` is propagated.
    async fn handle_read<READER: Readable>(
        reader: &mut READER,
        accumulator: &mut Vec<u8>,
        addr_: &ADDR,
        in_tx: &UnboundedSender<(IN, ADDR)>,
    ) -> Result<(), Error> {
        /// Capacity of the scratch buffer handed to `try_read_buf_internal`. A read that fills
        /// it completely triggers another read.
        const READ_CHUNK_SIZE: usize = 1024 * 1024;

        let mut total = 0;
        let mut chunk = Vec::with_capacity(READ_CHUNK_SIZE);
        loop {
            chunk.clear();
            let n = match reader.try_read_buf_internal(&mut chunk) {
                Ok(n) => n,
                // No more data to read right now.
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                Err(e) => {
                    error!("Failed to read from stream: {:?}", e);
                    return Err(Error::Io(e));
                }
            };
            // Append the new data to the accumulator buffer for decoding.
            accumulator.extend_from_slice(&chunk[..n]);
            total += n;
            // A read that does not fill the buffer (including a 0-byte read) is taken to mean
            // nothing more is buffered at the moment.
            if n < chunk.capacity() {
                break;
            }
        }
        if total == 0 {
            return Ok(());
        }
        trace!("Read {total} bytes from stream");
        trace!("Accumulated buffer size: {}", accumulator.len());

        // Decode as many complete messages as the accumulated bytes allow.
        loop {
            match C::try_decode(accumulator) {
                Ok(Some(cmd)) => {
                    debug!("Decoded message from {:?}: {:?}", addr_, cmd);
                    if in_tx.send((cmd, addr_.clone())).is_err() {
                        // The consumer of incoming messages has gone away; the message is dropped.
                        debug!("Receiver for {addr_:?} dropped, discarding decoded message");
                    }
                }
                // Not enough data for another message yet; wait for more.
                Ok(None) => return Ok(()),
                Err(e) => {
                    error!("Failed to decode message: {:?}", e);
                    return Err(e);
                }
            }
        }
    }

    /// Returns a clone of the address this connection was created with.
    pub fn addr(&self) -> ADDR {
        self.addr.clone()
    }

    /// Registers `callback` to run on a spawned task once the exit signal is received.
    ///
    /// The signal can come from `close()` or from the reader/writer task stopping. The
    /// callback also runs if the broadcast channel reports it is closed or lagged.
    ///
    /// NOTE(review): the subscription is created when this method is called. A callback
    /// registered after the exit signal was already broadcast will not run until every exit
    /// sender (including the one held by this `Handle`) has been dropped.
    pub fn on_closed<F: FnOnce() + Send + 'static>(&self, callback: F) {
        let mut exit_rx = self.exit_tx.subscribe();
        tokio::task::spawn(async move {
            debug!("Registering on_closed callback");
            // Any outcome (signal, lag or channel closed) means the connection is going away.
            let _ = exit_rx.recv().await;
            debug!("Connection closed, executing callback");
            callback();
        });
    }

    /// Signals the reader and writer tasks to exit.
    ///
    /// The tasks stop asynchronously after this returns and drop their stream halves as they
    /// finish. Always returns `Ok`.
    pub fn close(self) -> Result<(), Error> {
        // Err only means no receiver is left, i.e. both tasks already stopped.
        let _ = self.exit_tx.send(());
        Ok(())
    }
}