1use std::io::{Read, Write};
2use std::ops::Deref;
3use std::{fmt, io, time};
4
5use crossbeam_channel as chan;
6use radicle::node::config::FetchPackSizeLimit;
7use radicle::node::NodeId;
8
9use crate::runtime::Handle;
10use crate::wire::StreamId;
11
/// Capacity, in events, of each bounded channel created by [`Channels::pair`].
///
/// Once this many events are queued and not yet received, timed sends wait
/// (up to the configured timeout) for the receiving side to make room.
pub const MAX_WORKER_CHANNEL_SIZE: usize = 64;
16
/// Settings applied to both halves of a [`Channels`] endpoint.
#[derive(Clone, Copy, Debug)]
pub struct ChannelsConfig {
    /// Timeout used for each timed send (by the writer) and each receive (by
    /// the [`ChannelReader`]); expiry surfaces as `io::ErrorKind::TimedOut`.
    timeout: time::Duration,
    /// Byte budget handed to the reader's [`ReadLimiter`].
    reader_limit: FetchPackSizeLimit,
}
22
23impl ChannelsConfig {
24 pub fn new(timeout: time::Duration) -> Self {
25 Self {
26 timeout,
27 reader_limit: FetchPackSizeLimit::default(),
28 }
29 }
30
31 pub fn with_timeout(self, timeout: time::Duration) -> Self {
32 Self { timeout, ..self }
33 }
34
35 pub fn with_reader_limit(self, reader_limit: FetchPackSizeLimit) -> Self {
36 Self {
37 reader_limit,
38 ..self
39 }
40 }
41}
42
/// A [`Channels`] endpoint whose writing half calls the runtime [`Handle`]'s
/// `flush` after every write and after signalling EOF.
pub struct ChannelsFlush {
    /// Incoming half: yields what the peer endpoint sends.
    receiver: ChannelReader,
    /// Outgoing half: sends data, then flushes `stream` for `remote`.
    sender: ChannelFlushWriter,
}
51
52impl ChannelsFlush {
53 pub fn new(handle: Handle, channels: Channels, remote: NodeId, stream: StreamId) -> Self {
54 Self {
55 receiver: channels.receiver,
56 sender: ChannelFlushWriter {
57 writer: channels.sender,
58 stream,
59 handle,
60 remote,
61 },
62 }
63 }
64
65 pub fn split(&mut self) -> (&mut ChannelReader, &mut ChannelFlushWriter) {
66 (&mut self.receiver, &mut self.sender)
67 }
68
69 pub fn timeout(&self) -> time::Duration {
70 self.sender.writer.timeout.max(self.receiver.timeout)
71 }
72}
73
74impl radicle_fetch::transport::ConnectionStream for ChannelsFlush {
75 type Read = ChannelReader;
76 type Write = ChannelFlushWriter;
77
78 fn open(&mut self) -> (&mut Self::Read, &mut Self::Write) {
79 (&mut self.receiver, &mut self.sender)
80 }
81}
82
/// A single message passed over a worker channel.
pub enum ChannelEvent<T = Vec<u8>> {
    /// A chunk of stream data.
    Data(T),
    /// The sender closed its side (sent by `ChannelWriter::close`); a
    /// [`ChannelReader`] reports it as `io::ErrorKind::ConnectionReset`.
    Close,
    /// The sender finished writing (sent by the `SignalEof` impl); a
    /// [`ChannelReader`] reports it as `io::ErrorKind::UnexpectedEof`.
    Eof,
}
93
94impl<T> From<T> for ChannelEvent<T> {
95 fn from(value: T) -> Self {
96 Self::Data(value)
97 }
98}
99
100impl<T> fmt::Debug for ChannelEvent<T> {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match self {
103 Self::Data(_) => write!(f, "ChannelEvent::Data(..)"),
104 Self::Close => write!(f, "ChannelEvent::Close"),
105 Self::Eof => write!(f, "ChannelEvent::Eof"),
106 }
107 }
108}
109
/// One endpoint of a bidirectional channel. Events sent here are received by
/// the peer endpoint, e.g. the one created alongside it by [`Channels::pair`].
pub struct Channels<T = Vec<u8>> {
    /// Outgoing half (timed sends).
    sender: ChannelWriter<T>,
    /// Incoming half (timed, byte-limited receives).
    receiver: ChannelReader<T>,
}
115
116impl<T: AsRef<[u8]>> Channels<T> {
117 pub fn new(
118 sender: chan::Sender<ChannelEvent<T>>,
119 receiver: chan::Receiver<ChannelEvent<T>>,
120 config: ChannelsConfig,
121 ) -> Self {
122 let sender = ChannelWriter {
123 sender,
124 timeout: config.timeout,
125 };
126 let receiver = ChannelReader::new(receiver, config.timeout, config.reader_limit);
127
128 Self { sender, receiver }
129 }
130
131 pub fn pair(config: ChannelsConfig) -> io::Result<(Channels<T>, Channels<T>)> {
132 let (l_send, r_recv) = chan::bounded::<ChannelEvent<T>>(MAX_WORKER_CHANNEL_SIZE);
133 let (r_send, l_recv) = chan::bounded::<ChannelEvent<T>>(MAX_WORKER_CHANNEL_SIZE);
134
135 let l = Channels::new(l_send, l_recv, config);
136 let r = Channels::new(r_send, r_recv, config);
137
138 Ok((l, r))
139 }
140
141 pub fn try_iter(&self) -> impl Iterator<Item = ChannelEvent<T>> + '_ {
142 self.receiver.try_iter()
143 }
144
145 pub fn send(&self, event: ChannelEvent<T>) -> io::Result<()> {
146 self.sender.send(event)
147 }
148
149 pub fn close(self) -> Result<(), chan::SendError<ChannelEvent<T>>> {
150 self.sender.close()
151 }
152}
153
/// Counts the bytes reported to it and fails once the running total exceeds a
/// [`FetchPackSizeLimit`].
#[derive(Clone, Copy, Debug)]
pub struct ReadLimiter {
    /// The byte budget.
    limit: FetchPackSizeLimit,
    /// Running total of reported bytes; saturates at `usize::MAX` instead of
    /// overflowing.
    total_read: usize,
}
159
160impl ReadLimiter {
161 pub fn new(limit: FetchPackSizeLimit) -> Self {
162 Self {
163 limit,
164 total_read: 0,
165 }
166 }
167
168 pub fn read(&mut self, bytes: usize) -> io::Result<()> {
169 self.total_read = self.total_read.saturating_add(bytes);
170 log::trace!(target: "worker", "limit {}, total bytes read: {}", self.limit, self.total_read);
171 if self.limit.exceeded_by(self.total_read) {
172 Err(io::Error::other(
173 "sender has exceeded number of allowed bytes, aborting read",
174 ))
175 } else {
176 Ok(())
177 }
178 }
179}
180
/// Receiving half of a channel. For `Vec<u8>` payloads it implements
/// [`std::io::Read`], serving bytes from `Data` events.
///
/// NOTE(review): clones share the same underlying crossbeam receiver, so each
/// event goes to only one clone, but each clone keeps its own buffer and
/// limiter.
#[derive(Clone)]
pub struct ChannelReader<T = Vec<u8>> {
    /// Unread remainder of the most recently received `Data` payload.
    buffer: io::Cursor<Vec<u8>>,
    /// Source of incoming events.
    receiver: chan::Receiver<ChannelEvent<T>>,
    /// Maximum wait for the next event once `buffer` is exhausted.
    timeout: time::Duration,
    /// Byte budget checked on reads.
    limiter: ReadLimiter,
}
189
190impl<T> Deref for ChannelReader<T> {
191 type Target = chan::Receiver<ChannelEvent<T>>;
192
193 fn deref(&self) -> &Self::Target {
194 &self.receiver
195 }
196}
197
198impl<T: AsRef<[u8]>> ChannelReader<T> {
199 pub fn new(
200 receiver: chan::Receiver<ChannelEvent<T>>,
201 timeout: time::Duration,
202 limit: FetchPackSizeLimit,
203 ) -> Self {
204 Self {
205 buffer: io::Cursor::new(Vec::new()),
206 receiver,
207 timeout,
208 limiter: ReadLimiter::new(limit),
209 }
210 }
211}
212
213impl Read for ChannelReader<Vec<u8>> {
214 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
215 let read = self.buffer.read(buf)?;
216 self.limiter.read(read)?;
217 if read > 0 {
218 return Ok(read);
219 }
220
221 match self.receiver.recv_timeout(self.timeout) {
222 Ok(ChannelEvent::Data(data)) => {
223 self.buffer = io::Cursor::new(data);
224 self.buffer.read(buf)
225 }
226 Ok(ChannelEvent::Eof) => Err(io::ErrorKind::UnexpectedEof.into()),
227 Ok(ChannelEvent::Close) => Err(io::ErrorKind::ConnectionReset.into()),
228
229 Err(chan::RecvTimeoutError::Timeout) => Err(io::Error::new(
230 io::ErrorKind::TimedOut,
231 "error reading from stream: channel timed out",
232 )),
233 Err(chan::RecvTimeoutError::Disconnected) => Err(io::Error::new(
234 io::ErrorKind::BrokenPipe,
235 "error reading from stream: channel is disconnected",
236 )),
237 }
238 }
239}
240
/// Sending half of a channel. `send` gives up after `timeout`; `close` does
/// not use the timeout.
#[derive(Clone)]
struct ChannelWriter<T = Vec<u8>> {
    /// Underlying crossbeam sender.
    sender: chan::Sender<ChannelEvent<T>>,
    /// Maximum time `send` waits before failing with `TimedOut`.
    timeout: time::Duration,
}
247
/// Writing half of a [`ChannelsFlush`]. It sends each event over the channel
/// and then calls `handle.flush(remote, stream)`.
///
/// NOTE(review): presumably `Handle::flush` prompts the runtime to push the
/// queued data for `stream` out to `remote`. Confirm against the runtime.
pub struct ChannelFlushWriter<T = Vec<u8>> {
    /// Timed channel sender carrying the events.
    writer: ChannelWriter<T>,
    /// Runtime handle whose `flush` is called after each send.
    handle: Handle,
    /// Stream id passed to `Handle::flush`.
    stream: StreamId,
    /// Remote node id passed to `Handle::flush` alongside `stream`.
    remote: NodeId,
}
259
260impl radicle_fetch::transport::SignalEof for ChannelFlushWriter<Vec<u8>> {
261 fn eof(&mut self) -> io::Result<()> {
262 self.writer.send(ChannelEvent::Eof)?;
263 self.flush()
264 }
265}
266
267impl Write for ChannelFlushWriter<Vec<u8>> {
268 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
269 let n = buf.len();
270 self.writer.send(buf.to_vec())?;
271 self.flush()?;
272 Ok(n)
273 }
274
275 fn flush(&mut self) -> io::Result<()> {
276 self.handle.flush(self.remote, self.stream)
277 }
278}
279
280impl<T: AsRef<[u8]>> ChannelWriter<T> {
281 pub fn send(&self, event: impl Into<ChannelEvent<T>>) -> io::Result<()> {
282 match self.sender.send_timeout(event.into(), self.timeout) {
283 Ok(()) => Ok(()),
284 Err(chan::SendTimeoutError::Timeout(_)) => Err(io::Error::new(
285 io::ErrorKind::TimedOut,
286 "error writing to stream: channel timed out",
287 )),
288 Err(chan::SendTimeoutError::Disconnected(_)) => Err(io::Error::new(
289 io::ErrorKind::BrokenPipe,
290 "error writing to stream: channel is disconnected",
291 )),
292 }
293 }
294
295 pub fn close(self) -> Result<(), chan::SendError<ChannelEvent<T>>> {
297 self.sender.send(ChannelEvent::Close)
298 }
299}