// radicle_node/worker/channels.rs

1use std::convert::Infallible;
2use std::io::{Read, Write};
3use std::ops::Deref;
4use std::{fmt, io, time};
5
6use crossbeam_channel as chan;
7use radicle::node::config::FetchPackSizeLimit;
8use radicle::node::NodeId;
9
10use crate::runtime::Handle;
11use crate::wire::StreamId;
12
/// Maximum size (in events, not bytes) of a channel used to communicate with
/// a worker.
/// Note that as long as we're using [`std::io::copy`] to copy data from the
/// upload-pack's stdout, the data chunks are of a maximum size of 8192 bytes.
/// With 64 slots of up-to-8192-byte chunks, this bounds in-flight data to
/// roughly 512 KiB per direction — presumably the intended back-pressure
/// budget; confirm if the copy buffer size ever changes.
pub const MAX_WORKER_CHANNEL_SIZE: usize = 64;
17
#[derive(Clone, Copy, Debug)]
pub struct ChannelsConfig {
    // Timeout applied to both channel reads and writes (passed to the
    // reader and the writer in `Channels::new`).
    timeout: time::Duration,
    // Upper bound on the total number of bytes the reader accepts from
    // the remote, enforced by `ReadLimiter`.
    reader_limit: FetchPackSizeLimit,
}
23
24impl ChannelsConfig {
25    pub fn new(timeout: time::Duration) -> Self {
26        Self {
27            timeout,
28            reader_limit: FetchPackSizeLimit::default(),
29        }
30    }
31
32    pub fn with_timeout(self, timeout: time::Duration) -> Self {
33        Self { timeout, ..self }
34    }
35
36    pub fn with_reader_limit(self, reader_limit: FetchPackSizeLimit) -> Self {
37        Self {
38            reader_limit,
39            ..self
40        }
41    }
42}
43
/// A reader and writer pair that can be used in the fetch protocol.
///
/// It implements [`radicle::fetch::transport::ConnectionStream`] to
/// provide its underlying channels for reading and writing.
pub struct ChannelsFlush {
    // Incoming git protocol data from the remote.
    receiver: ChannelReader,
    // Outgoing writer; flushes the wire on every write (see its `Write` impl).
    sender: ChannelFlushWriter,
}
52
53impl ChannelsFlush {
54    pub fn new(handle: Handle, channels: Channels, remote: NodeId, stream: StreamId) -> Self {
55        Self {
56            receiver: channels.receiver,
57            sender: ChannelFlushWriter {
58                writer: channels.sender,
59                stream,
60                handle,
61                remote,
62            },
63        }
64    }
65
66    pub fn split(&mut self) -> (&mut ChannelReader, &mut ChannelFlushWriter) {
67        (&mut self.receiver, &mut self.sender)
68    }
69
70    pub fn timeout(&self) -> time::Duration {
71        self.sender.writer.timeout.max(self.receiver.timeout)
72    }
73}
74
75impl radicle_fetch::transport::ConnectionStream for ChannelsFlush {
76    type Read = ChannelReader;
77    type Write = ChannelFlushWriter;
78    type Error = Infallible;
79
80    fn open(&mut self) -> Result<(&mut Self::Read, &mut Self::Write), Self::Error> {
81        Ok((&mut self.receiver, &mut self.sender))
82    }
83}
84
/// Data that can be sent and received on worker channels.
///
/// The payload type `T` defaults to raw byte buffers (`Vec<u8>`).
pub enum ChannelEvent<T = Vec<u8>> {
    /// Git protocol data.
    Data(T),
    /// A request to close the channel.
    Close,
    /// A signal that the git protocol has ended, eg. when the remote fetch closes the
    /// connection.
    Eof,
}
95
96impl<T> From<T> for ChannelEvent<T> {
97    fn from(value: T) -> Self {
98        Self::Data(value)
99    }
100}
101
102impl<T> fmt::Debug for ChannelEvent<T> {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        match self {
105            Self::Data(_) => write!(f, "ChannelEvent::Data(..)"),
106            Self::Close => write!(f, "ChannelEvent::Close"),
107            Self::Eof => write!(f, "ChannelEvent::Eof"),
108        }
109    }
110}
111
/// Worker channels for communicating through the git stream with the remote.
pub struct Channels<T = Vec<u8>> {
    // Outgoing events; sends time out after the configured duration.
    sender: ChannelWriter<T>,
    // Incoming events; reads time out and are capped by a total byte limit.
    receiver: ChannelReader<T>,
}
117
118impl<T: AsRef<[u8]>> Channels<T> {
119    pub fn new(
120        sender: chan::Sender<ChannelEvent<T>>,
121        receiver: chan::Receiver<ChannelEvent<T>>,
122        config: ChannelsConfig,
123    ) -> Self {
124        let sender = ChannelWriter {
125            sender,
126            timeout: config.timeout,
127        };
128        let receiver = ChannelReader::new(receiver, config.timeout, config.reader_limit);
129
130        Self { sender, receiver }
131    }
132
133    pub fn pair(config: ChannelsConfig) -> io::Result<(Channels<T>, Channels<T>)> {
134        let (l_send, r_recv) = chan::bounded::<ChannelEvent<T>>(MAX_WORKER_CHANNEL_SIZE);
135        let (r_send, l_recv) = chan::bounded::<ChannelEvent<T>>(MAX_WORKER_CHANNEL_SIZE);
136
137        let l = Channels::new(l_send, l_recv, config);
138        let r = Channels::new(r_send, r_recv, config);
139
140        Ok((l, r))
141    }
142
143    pub fn try_iter(&self) -> impl Iterator<Item = ChannelEvent<T>> + '_ {
144        self.receiver.try_iter()
145    }
146
147    pub fn send(&self, event: ChannelEvent<T>) -> io::Result<()> {
148        self.sender.send(event)
149    }
150
151    pub fn close(self) -> Result<(), chan::SendError<ChannelEvent<T>>> {
152        self.sender.close()
153    }
154}
155
#[derive(Clone, Copy, Debug)]
pub struct ReadLimiter {
    // Maximum number of bytes allowed to be read in total.
    limit: FetchPackSizeLimit,
    // Running (saturating) count of bytes read so far.
    total_read: usize,
}
161
162impl ReadLimiter {
163    pub fn new(limit: FetchPackSizeLimit) -> Self {
164        Self {
165            limit,
166            total_read: 0,
167        }
168    }
169
170    pub fn read(&mut self, bytes: usize) -> io::Result<()> {
171        self.total_read = self.total_read.saturating_add(bytes);
172        log::trace!(target: "worker", "limit {}, total bytes read: {}", self.limit, self.total_read);
173        if self.limit.exceeded_by(self.total_read) {
174            Err(io::Error::new(
175                io::ErrorKind::Other,
176                "sender has exceeded number of allowed bytes, aborting read",
177            ))
178        } else {
179            Ok(())
180        }
181    }
182}
183
/// Wraps a [`chan::Receiver`] and provides it with [`io::Read`].
#[derive(Clone)]
pub struct ChannelReader<T = Vec<u8>> {
    // Leftover bytes of the most recently received `Data` event.
    buffer: io::Cursor<Vec<u8>>,
    receiver: chan::Receiver<ChannelEvent<T>>,
    // Maximum time to block waiting for the next event.
    timeout: time::Duration,
    // Caps the total number of bytes accepted from the sender.
    limiter: ReadLimiter,
}
192
impl<T> Deref for ChannelReader<T> {
    type Target = chan::Receiver<ChannelEvent<T>>;

    // Exposes the underlying receiver's API (e.g. `try_iter`, used by
    // `Channels::try_iter`) directly on the reader.
    fn deref(&self) -> &Self::Target {
        &self.receiver
    }
}
200
201impl<T: AsRef<[u8]>> ChannelReader<T> {
202    pub fn new(
203        receiver: chan::Receiver<ChannelEvent<T>>,
204        timeout: time::Duration,
205        limit: FetchPackSizeLimit,
206    ) -> Self {
207        Self {
208            buffer: io::Cursor::new(Vec::new()),
209            receiver,
210            timeout,
211            limiter: ReadLimiter::new(limit),
212        }
213    }
214}
215
216impl Read for ChannelReader<Vec<u8>> {
217    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
218        let read = self.buffer.read(buf)?;
219        self.limiter.read(read)?;
220        if read > 0 {
221            return Ok(read);
222        }
223
224        match self.receiver.recv_timeout(self.timeout) {
225            Ok(ChannelEvent::Data(data)) => {
226                self.buffer = io::Cursor::new(data);
227                self.buffer.read(buf)
228            }
229            Ok(ChannelEvent::Eof) => Err(io::ErrorKind::UnexpectedEof.into()),
230            Ok(ChannelEvent::Close) => Err(io::ErrorKind::ConnectionReset.into()),
231
232            Err(chan::RecvTimeoutError::Timeout) => Err(io::Error::new(
233                io::ErrorKind::TimedOut,
234                "error reading from stream: channel timed out",
235            )),
236            Err(chan::RecvTimeoutError::Disconnected) => Err(io::Error::new(
237                io::ErrorKind::BrokenPipe,
238                "error reading from stream: channel is disconnected",
239            )),
240        }
241    }
242}
243
/// Wraps a [`chan::Sender`] and provides it with [`io::Write`].
#[derive(Clone)]
struct ChannelWriter<T = Vec<u8>> {
    sender: chan::Sender<ChannelEvent<T>>,
    // Maximum time to wait for the channel to accept an event.
    timeout: time::Duration,
}
250
/// Wraps a [`ChannelWriter`] alongside the associated [`Handle`] and [`NodeId`].
///
/// This allows the channel to [`Write::flush`] when calling
/// [`Write::write`], which is necessary to signal to the
/// controller to send the wire data.
pub struct ChannelFlushWriter<T = Vec<u8>> {
    // Underlying timed channel writer.
    writer: ChannelWriter<T>,
    // Runtime handle used to request a wire flush.
    handle: Handle,
    // The stream this writer belongs to.
    stream: StreamId,
    // The node on the other end of the stream.
    remote: NodeId,
}
262
impl radicle_fetch::transport::SignalEof for ChannelFlushWriter<Vec<u8>> {
    type Error = io::Error;

    /// Signal to the remote that the git protocol has ended, then flush so
    /// the controller sends the event over the wire immediately.
    fn eof(&mut self) -> io::Result<()> {
        self.writer.send(ChannelEvent::Eof)?;
        self.flush()
    }
}
271
272impl Write for ChannelFlushWriter<Vec<u8>> {
273    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
274        let n = buf.len();
275        self.writer.send(buf.to_vec())?;
276        self.flush()?;
277        Ok(n)
278    }
279
280    fn flush(&mut self) -> io::Result<()> {
281        self.handle.flush(self.remote, self.stream)
282    }
283}
284
285impl<T: AsRef<[u8]>> ChannelWriter<T> {
286    pub fn send(&self, event: impl Into<ChannelEvent<T>>) -> io::Result<()> {
287        match self.sender.send_timeout(event.into(), self.timeout) {
288            Ok(()) => Ok(()),
289            Err(chan::SendTimeoutError::Timeout(_)) => Err(io::Error::new(
290                io::ErrorKind::TimedOut,
291                "error writing to stream: channel timed out",
292            )),
293            Err(chan::SendTimeoutError::Disconnected(_)) => Err(io::Error::new(
294                io::ErrorKind::BrokenPipe,
295                "error writing to stream: channel is disconnected",
296            )),
297        }
298    }
299
300    /// Permanently close this stream.
301    pub fn close(self) -> Result<(), chan::SendError<ChannelEvent<T>>> {
302        self.sender.send(ChannelEvent::Close)
303    }
304}