1use std::convert::Infallible;
2use std::io::{Read, Write};
3use std::ops::Deref;
4use std::{fmt, io, time};
5
6use crossbeam_channel as chan;
7use radicle::node::config::FetchPackSizeLimit;
8use radicle::node::NodeId;
9
10use crate::runtime::Handle;
11use crate::wire::StreamId;
12
/// Capacity of each bounded channel created by `Channels::pair`.
pub const MAX_WORKER_CHANNEL_SIZE: usize = 64;
17
18#[derive(Clone, Copy, Debug)]
19pub struct ChannelsConfig {
20 timeout: time::Duration,
21 reader_limit: FetchPackSizeLimit,
22}
23
24impl ChannelsConfig {
25 pub fn new(timeout: time::Duration) -> Self {
26 Self {
27 timeout,
28 reader_limit: FetchPackSizeLimit::default(),
29 }
30 }
31
32 pub fn with_timeout(self, timeout: time::Duration) -> Self {
33 Self { timeout, ..self }
34 }
35
36 pub fn with_reader_limit(self, reader_limit: FetchPackSizeLimit) -> Self {
37 Self {
38 reader_limit,
39 ..self
40 }
41 }
42}
43
44pub struct ChannelsFlush {
49 receiver: ChannelReader,
50 sender: ChannelFlushWriter,
51}
52
53impl ChannelsFlush {
54 pub fn new(handle: Handle, channels: Channels, remote: NodeId, stream: StreamId) -> Self {
55 Self {
56 receiver: channels.receiver,
57 sender: ChannelFlushWriter {
58 writer: channels.sender,
59 stream,
60 handle,
61 remote,
62 },
63 }
64 }
65
66 pub fn split(&mut self) -> (&mut ChannelReader, &mut ChannelFlushWriter) {
67 (&mut self.receiver, &mut self.sender)
68 }
69
70 pub fn timeout(&self) -> time::Duration {
71 self.sender.writer.timeout.max(self.receiver.timeout)
72 }
73}
74
75impl radicle_fetch::transport::ConnectionStream for ChannelsFlush {
76 type Read = ChannelReader;
77 type Write = ChannelFlushWriter;
78 type Error = Infallible;
79
80 fn open(&mut self) -> Result<(&mut Self::Read, &mut Self::Write), Self::Error> {
81 Ok((&mut self.receiver, &mut self.sender))
82 }
83}
84
/// A message travelling over a worker channel.
pub enum ChannelEvent<T = Vec<u8>> {
    /// A chunk of payload.
    Data(T),
    /// The peer closed the channel; `ChannelReader` reports it as `ConnectionReset`.
    Close,
    /// The peer signalled end of stream; `ChannelReader` reports it as `UnexpectedEof`.
    Eof,
}

/// Any payload value converts into a `Data` event.
impl<T> From<T> for ChannelEvent<T> {
    fn from(payload: T) -> Self {
        ChannelEvent::Data(payload)
    }
}

/// Prints only the variant name; the payload is elided so `T` needs no `Debug` bound.
impl<T> fmt::Debug for ChannelEvent<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ChannelEvent::Data(_) => "ChannelEvent::Data(..)",
            ChannelEvent::Close => "ChannelEvent::Close",
            ChannelEvent::Eof => "ChannelEvent::Eof",
        };
        f.write_str(label)
    }
}
111
112pub struct Channels<T = Vec<u8>> {
114 sender: ChannelWriter<T>,
115 receiver: ChannelReader<T>,
116}
117
118impl<T: AsRef<[u8]>> Channels<T> {
119 pub fn new(
120 sender: chan::Sender<ChannelEvent<T>>,
121 receiver: chan::Receiver<ChannelEvent<T>>,
122 config: ChannelsConfig,
123 ) -> Self {
124 let sender = ChannelWriter {
125 sender,
126 timeout: config.timeout,
127 };
128 let receiver = ChannelReader::new(receiver, config.timeout, config.reader_limit);
129
130 Self { sender, receiver }
131 }
132
133 pub fn pair(config: ChannelsConfig) -> io::Result<(Channels<T>, Channels<T>)> {
134 let (l_send, r_recv) = chan::bounded::<ChannelEvent<T>>(MAX_WORKER_CHANNEL_SIZE);
135 let (r_send, l_recv) = chan::bounded::<ChannelEvent<T>>(MAX_WORKER_CHANNEL_SIZE);
136
137 let l = Channels::new(l_send, l_recv, config);
138 let r = Channels::new(r_send, r_recv, config);
139
140 Ok((l, r))
141 }
142
143 pub fn try_iter(&self) -> impl Iterator<Item = ChannelEvent<T>> + '_ {
144 self.receiver.try_iter()
145 }
146
147 pub fn send(&self, event: ChannelEvent<T>) -> io::Result<()> {
148 self.sender.send(event)
149 }
150
151 pub fn close(self) -> Result<(), chan::SendError<ChannelEvent<T>>> {
152 self.sender.close()
153 }
154}
155
156#[derive(Clone, Copy, Debug)]
157pub struct ReadLimiter {
158 limit: FetchPackSizeLimit,
159 total_read: usize,
160}
161
162impl ReadLimiter {
163 pub fn new(limit: FetchPackSizeLimit) -> Self {
164 Self {
165 limit,
166 total_read: 0,
167 }
168 }
169
170 pub fn read(&mut self, bytes: usize) -> io::Result<()> {
171 self.total_read = self.total_read.saturating_add(bytes);
172 log::trace!(target: "worker", "limit {}, total bytes read: {}", self.limit, self.total_read);
173 if self.limit.exceeded_by(self.total_read) {
174 Err(io::Error::new(
175 io::ErrorKind::Other,
176 "sender has exceeded number of allowed bytes, aborting read",
177 ))
178 } else {
179 Ok(())
180 }
181 }
182}
183
184#[derive(Clone)]
186pub struct ChannelReader<T = Vec<u8>> {
187 buffer: io::Cursor<Vec<u8>>,
188 receiver: chan::Receiver<ChannelEvent<T>>,
189 timeout: time::Duration,
190 limiter: ReadLimiter,
191}
192
193impl<T> Deref for ChannelReader<T> {
194 type Target = chan::Receiver<ChannelEvent<T>>;
195
196 fn deref(&self) -> &Self::Target {
197 &self.receiver
198 }
199}
200
201impl<T: AsRef<[u8]>> ChannelReader<T> {
202 pub fn new(
203 receiver: chan::Receiver<ChannelEvent<T>>,
204 timeout: time::Duration,
205 limit: FetchPackSizeLimit,
206 ) -> Self {
207 Self {
208 buffer: io::Cursor::new(Vec::new()),
209 receiver,
210 timeout,
211 limiter: ReadLimiter::new(limit),
212 }
213 }
214}
215
216impl Read for ChannelReader<Vec<u8>> {
217 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
218 let read = self.buffer.read(buf)?;
219 self.limiter.read(read)?;
220 if read > 0 {
221 return Ok(read);
222 }
223
224 match self.receiver.recv_timeout(self.timeout) {
225 Ok(ChannelEvent::Data(data)) => {
226 self.buffer = io::Cursor::new(data);
227 self.buffer.read(buf)
228 }
229 Ok(ChannelEvent::Eof) => Err(io::ErrorKind::UnexpectedEof.into()),
230 Ok(ChannelEvent::Close) => Err(io::ErrorKind::ConnectionReset.into()),
231
232 Err(chan::RecvTimeoutError::Timeout) => Err(io::Error::new(
233 io::ErrorKind::TimedOut,
234 "error reading from stream: channel timed out",
235 )),
236 Err(chan::RecvTimeoutError::Disconnected) => Err(io::Error::new(
237 io::ErrorKind::BrokenPipe,
238 "error reading from stream: channel is disconnected",
239 )),
240 }
241 }
242}
243
244#[derive(Clone)]
246struct ChannelWriter<T = Vec<u8>> {
247 sender: chan::Sender<ChannelEvent<T>>,
248 timeout: time::Duration,
249}
250
251pub struct ChannelFlushWriter<T = Vec<u8>> {
257 writer: ChannelWriter<T>,
258 handle: Handle,
259 stream: StreamId,
260 remote: NodeId,
261}
262
263impl radicle_fetch::transport::SignalEof for ChannelFlushWriter<Vec<u8>> {
264 type Error = io::Error;
265
266 fn eof(&mut self) -> io::Result<()> {
267 self.writer.send(ChannelEvent::Eof)?;
268 self.flush()
269 }
270}
271
272impl Write for ChannelFlushWriter<Vec<u8>> {
273 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
274 let n = buf.len();
275 self.writer.send(buf.to_vec())?;
276 self.flush()?;
277 Ok(n)
278 }
279
280 fn flush(&mut self) -> io::Result<()> {
281 self.handle.flush(self.remote, self.stream)
282 }
283}
284
285impl<T: AsRef<[u8]>> ChannelWriter<T> {
286 pub fn send(&self, event: impl Into<ChannelEvent<T>>) -> io::Result<()> {
287 match self.sender.send_timeout(event.into(), self.timeout) {
288 Ok(()) => Ok(()),
289 Err(chan::SendTimeoutError::Timeout(_)) => Err(io::Error::new(
290 io::ErrorKind::TimedOut,
291 "error writing to stream: channel timed out",
292 )),
293 Err(chan::SendTimeoutError::Disconnected(_)) => Err(io::Error::new(
294 io::ErrorKind::BrokenPipe,
295 "error writing to stream: channel is disconnected",
296 )),
297 }
298 }
299
300 pub fn close(self) -> Result<(), chan::SendError<ChannelEvent<T>>> {
302 self.sender.send(ChannelEvent::Close)
303 }
304}