iroh_blobs/util/stream.rs

1use std::{
2    future::Future,
3    io,
4    ops::{Deref, DerefMut},
5};
6
7use bytes::Bytes;
8use iroh::endpoint::{ReadExactError, VarInt};
9use iroh_io::AsyncStreamReader;
10use serde::{de::DeserializeOwned, Serialize};
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12
/// An abstract `iroh::endpoint::SendStream`.
///
/// Besides the iroh stream itself, this module implements it for `&mut` references to any
/// `SendStream` and for [`AsyncWriteSendStream`], so protocol code can write to other
/// transports as well.
pub trait SendStream: Send {
    /// Send bytes to the stream. This takes a `Bytes` because iroh can directly use them.
    ///
    /// This method is not cancellation safe. Even if this does not resolve, some bytes may have been written when previously polled.
    fn send_bytes(&mut self, bytes: Bytes) -> impl Future<Output = io::Result<()>> + Send;
    /// Send the entire contents of `buf` to the stream.
    ///
    /// NOTE(review): like `send_bytes`, this is presumably not cancellation safe — confirm.
    fn send(&mut self, buf: &[u8]) -> impl Future<Output = io::Result<()>> + Send;
    /// Sync the stream. Not needed for iroh, but needed for intermediate buffered streams such as compression.
    ///
    /// The iroh implementation is a no-op; the `AsyncWrite`-backed implementation flushes.
    fn sync(&mut self) -> impl Future<Output = io::Result<()>> + Send;
    /// Reset the stream with the given error code.
    fn reset(&mut self, code: VarInt) -> io::Result<()>;
    /// Wait for the stream to be stopped.
    ///
    /// Resolves to `Ok(Some(code))` if the stream was stopped with an error code. The exact
    /// meaning of `Ok(None)` is defined by the implementation.
    fn stopped(&mut self) -> impl Future<Output = io::Result<Option<VarInt>>> + Send;
    /// Get the stream id.
    fn id(&self) -> u64;
}
30
/// An abstract `iroh::endpoint::RecvStream`.
///
/// Besides the iroh stream itself, this module implements it for `&mut` references to any
/// `RecvStream`, for [`AsyncReadRecvStream`], and for an in-memory [`Bytes`] buffer.
pub trait RecvStream: Send {
    /// Receive up to `len` bytes from the stream, directly into a `Bytes`.
    ///
    /// The implementations in this module only return fewer than `len` bytes when the
    /// stream ends first.
    fn recv_bytes(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>> + Send;
    /// Receive exactly `len` bytes from the stream, directly into a `Bytes`.
    ///
    /// This will return an error if the stream ends before `len` bytes are read.
    ///
    /// Note that this is different from `recv_bytes`, which will return fewer bytes if the stream ends.
    fn recv_bytes_exact(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>> + Send;
    /// Receive exactly `target.len()` bytes from the stream, filling `target`.
    ///
    /// Returns an error if the stream ends before `target` is filled.
    fn recv_exact(&mut self, target: &mut [u8]) -> impl Future<Output = io::Result<()>> + Send;
    /// Stop the stream with the given error code.
    fn stop(&mut self, code: VarInt) -> io::Result<()>;
    /// Get the stream id.
    fn id(&self) -> u64;
}
48
49impl SendStream for iroh::endpoint::SendStream {
50    async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
51        Ok(self.write_chunk(bytes).await?)
52    }
53
54    async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
55        Ok(self.write_all(buf).await?)
56    }
57
58    async fn sync(&mut self) -> io::Result<()> {
59        Ok(())
60    }
61
62    fn reset(&mut self, code: VarInt) -> io::Result<()> {
63        Ok(self.reset(code)?)
64    }
65
66    async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
67        Ok(self.stopped().await?)
68    }
69
70    fn id(&self) -> u64 {
71        self.id().index()
72    }
73}
74
75impl RecvStream for iroh::endpoint::RecvStream {
76    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
77        let mut buf = vec![0; len];
78        match self.read_exact(&mut buf).await {
79            Err(ReadExactError::FinishedEarly(n)) => {
80                buf.truncate(n);
81            }
82            Err(ReadExactError::ReadError(e)) => {
83                return Err(e.into());
84            }
85            Ok(()) => {}
86        };
87        Ok(buf.into())
88    }
89
90    async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
91        let mut buf = vec![0; len];
92        self.read_exact(&mut buf).await.map_err(|e| match e {
93            ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""),
94            ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""),
95            ReadExactError::ReadError(e) => e.into(),
96        })?;
97        Ok(buf.into())
98    }
99
100    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
101        self.read_exact(buf).await.map_err(|e| match e {
102            ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""),
103            ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""),
104            ReadExactError::ReadError(e) => e.into(),
105        })
106    }
107
108    fn stop(&mut self, code: VarInt) -> io::Result<()> {
109        Ok(self.stop(code)?)
110    }
111
112    fn id(&self) -> u64 {
113        self.id().index()
114    }
115}
116
117impl<R: RecvStream> RecvStream for &mut R {
118    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
119        self.deref_mut().recv_bytes(len).await
120    }
121
122    async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
123        self.deref_mut().recv_bytes_exact(len).await
124    }
125
126    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
127        self.deref_mut().recv_exact(buf).await
128    }
129
130    fn stop(&mut self, code: VarInt) -> io::Result<()> {
131        self.deref_mut().stop(code)
132    }
133
134    fn id(&self) -> u64 {
135        self.deref().id()
136    }
137}
138
139impl<W: SendStream> SendStream for &mut W {
140    async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
141        self.deref_mut().send_bytes(bytes).await
142    }
143
144    async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
145        self.deref_mut().send(buf).await
146    }
147
148    async fn sync(&mut self) -> io::Result<()> {
149        self.deref_mut().sync().await
150    }
151
152    fn reset(&mut self, code: VarInt) -> io::Result<()> {
153        self.deref_mut().reset(code)
154    }
155
156    async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
157        self.deref_mut().stopped().await
158    }
159
160    fn id(&self) -> u64 {
161        self.deref().id()
162    }
163}
164
/// Adapter that exposes an `AsyncRead` as a [`RecvStream`].
///
/// The [`RecvStream`] implementation requires the wrapped value to implement
/// [`AsyncReadRecvStreamExtra`], which supplies the reader plus `stop` and `id`.
#[derive(Debug)]
pub struct AsyncReadRecvStream<R>(R);
167
/// This is a helper trait to work with [`AsyncReadRecvStream`]. If you have an
/// `AsyncRead + Unpin + Send`, you can implement these additional methods and wrap the result
/// in an `AsyncReadRecvStream` to get a `RecvStream` that reads from the underlying `AsyncRead`.
pub trait AsyncReadRecvStreamExtra: Send {
    /// Get a mutable reference to the inner `AsyncRead`.
    ///
    /// Getting a reference is easier than implementing all methods on `AsyncRead` with forwarders to the inner instance.
    /// `AsyncReadRecvStream` calls this for every read it performs.
    fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send);
    /// Stop the stream with the given error code.
    fn stop(&mut self, code: VarInt) -> io::Result<()>;
    /// A local unique identifier for the stream.
    ///
    /// This allows distinguishing between streams, but once the stream is closed, the id may be reused.
    fn id(&self) -> u64;
}
183
184impl<R> AsyncReadRecvStream<R> {
185    pub fn new(inner: R) -> Self {
186        Self(inner)
187    }
188}
189
190impl<R: AsyncReadRecvStreamExtra> RecvStream for AsyncReadRecvStream<R> {
191    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
192        let mut res = vec![0; len];
193        let mut n = 0;
194        loop {
195            let read = self.0.inner().read(&mut res[n..]).await?;
196            if read == 0 {
197                res.truncate(n);
198                break;
199            }
200            n += read;
201            if n == len {
202                break;
203            }
204        }
205        Ok(res.into())
206    }
207
208    async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
209        let mut res = vec![0; len];
210        self.0.inner().read_exact(&mut res).await?;
211        Ok(res.into())
212    }
213
214    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
215        self.0.inner().read_exact(buf).await?;
216        Ok(())
217    }
218
219    fn stop(&mut self, code: VarInt) -> io::Result<()> {
220        self.0.stop(code)
221    }
222
223    fn id(&self) -> u64 {
224        self.0.id()
225    }
226}
227
228impl RecvStream for Bytes {
229    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
230        let n = len.min(self.len());
231        let res = self.slice(..n);
232        *self = self.slice(n..);
233        Ok(res)
234    }
235
236    async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
237        if self.len() < len {
238            return Err(io::ErrorKind::UnexpectedEof.into());
239        }
240        let res = self.slice(..len);
241        *self = self.slice(len..);
242        Ok(res)
243    }
244
245    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
246        if self.len() < buf.len() {
247            return Err(io::ErrorKind::UnexpectedEof.into());
248        }
249        buf.copy_from_slice(&self[..buf.len()]);
250        *self = self.slice(buf.len()..);
251        Ok(())
252    }
253
254    fn stop(&mut self, _code: VarInt) -> io::Result<()> {
255        Ok(())
256    }
257
258    fn id(&self) -> u64 {
259        0
260    }
261}
262
/// Utility to convert a [`tokio::io::AsyncWrite`] into a [`SendStream`].
///
/// The [`SendStream`] implementation requires the wrapped value to implement
/// [`AsyncWriteSendStreamExtra`], which supplies the writer plus `reset`, `stopped` and `id`.
#[derive(Debug, Clone)]
pub struct AsyncWriteSendStream<W>(W);
266
/// This is a helper trait to work with [`AsyncWriteSendStream`].
///
/// If you have an `AsyncWrite + Unpin + Send`, you can implement these additional
/// methods and wrap the result in an `AsyncWriteSendStream` to get a `SendStream`
/// that writes to the underlying `AsyncWrite`.
pub trait AsyncWriteSendStreamExtra: Send {
    /// Get a mutable reference to the inner `AsyncWrite`.
    ///
    /// Getting a reference is easier than implementing all methods on `AsyncWrite` with forwarders to the inner instance.
    /// `AsyncWriteSendStream` writes through this reference and flushes it on `sync`.
    fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send);
    /// Reset the stream with the given error code.
    fn reset(&mut self, code: VarInt) -> io::Result<()>;
    /// Wait for the stream to be stopped, returning the optional error code if it was.
    fn stopped(&mut self) -> impl Future<Output = io::Result<Option<VarInt>>> + Send;
    /// A local unique identifier for the stream.
    ///
    /// This allows distinguishing between streams, but once the stream is closed, the id may be reused.
    fn id(&self) -> u64;
}
286
287impl<W: AsyncWriteSendStreamExtra> AsyncWriteSendStream<W> {
288    pub fn new(inner: W) -> Self {
289        Self(inner)
290    }
291}
292
293impl<W: AsyncWriteSendStreamExtra> AsyncWriteSendStream<W> {
294    pub fn into_inner(self) -> W {
295        self.0
296    }
297}
298
299impl<W: AsyncWriteSendStreamExtra> SendStream for AsyncWriteSendStream<W> {
300    async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
301        self.0.inner().write_all(&bytes).await
302    }
303
304    async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
305        self.0.inner().write_all(buf).await
306    }
307
308    async fn sync(&mut self) -> io::Result<()> {
309        self.0.inner().flush().await
310    }
311
312    fn reset(&mut self, code: VarInt) -> io::Result<()> {
313        self.0.reset(code)?;
314        Ok(())
315    }
316
317    async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
318        let res = self.0.stopped().await?;
319        Ok(res)
320    }
321
322    fn id(&self) -> u64 {
323        self.0.id()
324    }
325}
326
/// Adapter that exposes a [`RecvStream`] as an `iroh_io::AsyncStreamReader`.
///
/// Its `read_bytes` and `read` methods use the exact-length receive methods, so a stream
/// that ends early produces an error rather than a short result.
#[derive(Debug)]
pub struct RecvStreamAsyncStreamReader<R>(R);
329
330impl<R: RecvStream> RecvStreamAsyncStreamReader<R> {
331    pub fn new(inner: R) -> Self {
332        Self(inner)
333    }
334
335    pub fn into_inner(self) -> R {
336        self.0
337    }
338}
339
340impl<R: RecvStream> AsyncStreamReader for RecvStreamAsyncStreamReader<R> {
341    async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
342        self.0.recv_bytes_exact(len).await
343    }
344
345    async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
346        let mut buf = [0; L];
347        self.0.recv_exact(&mut buf).await?;
348        Ok(buf)
349    }
350}
351
352pub(crate) trait RecvStreamExt: RecvStream {
353    async fn expect_eof(&mut self) -> io::Result<()> {
354        match self.read_u8().await {
355            Ok(_) => Err(io::Error::new(
356                io::ErrorKind::InvalidData,
357                "unexpected data",
358            )),
359            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()),
360            Err(e) => Err(e),
361        }
362    }
363
364    async fn read_u8(&mut self) -> io::Result<u8> {
365        let mut buf = [0; 1];
366        self.recv_exact(&mut buf).await?;
367        Ok(buf[0])
368    }
369
370    async fn read_to_end_as<T: DeserializeOwned>(
371        &mut self,
372        max_size: usize,
373    ) -> io::Result<(T, usize)> {
374        let data = self.recv_bytes(max_size).await?;
375        self.expect_eof().await?;
376        let value = postcard::from_bytes(&data)
377            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
378        Ok((value, data.len()))
379    }
380
381    async fn read_length_prefixed<T: DeserializeOwned>(
382        &mut self,
383        max_size: usize,
384    ) -> io::Result<T> {
385        let Some(n) = self.read_varint_u64().await? else {
386            return Err(io::ErrorKind::UnexpectedEof.into());
387        };
388        if n > max_size as u64 {
389            return Err(io::Error::new(
390                io::ErrorKind::InvalidData,
391                "length prefix too large",
392            ));
393        }
394        let n = n as usize;
395        let data = self.recv_bytes(n).await?;
396        let value = postcard::from_bytes(&data)
397            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
398        Ok(value)
399    }
400
401    /// Reads a u64 varint from an AsyncRead source, using the Postcard/LEB128 format.
402    ///
403    /// In Postcard's varint format (LEB128):
404    /// - Each byte uses 7 bits for the value
405    /// - The MSB (most significant bit) of each byte indicates if there are more bytes (1) or not (0)
406    /// - Values are stored in little-endian order (least significant group first)
407    ///
408    /// Returns the decoded u64 value.
409    async fn read_varint_u64(&mut self) -> io::Result<Option<u64>> {
410        let mut result: u64 = 0;
411        let mut shift: u32 = 0;
412
413        loop {
414            // We can only shift up to 63 bits (for a u64)
415            if shift >= 64 {
416                return Err(io::Error::new(
417                    io::ErrorKind::InvalidData,
418                    "Varint is too large for u64",
419                ));
420            }
421
422            // Read a single byte
423            let res = self.read_u8().await;
424            if shift == 0 {
425                if let Err(cause) = res {
426                    if cause.kind() == io::ErrorKind::UnexpectedEof {
427                        return Ok(None);
428                    } else {
429                        return Err(cause);
430                    }
431                }
432            }
433
434            let byte = res?;
435
436            // Extract the 7 value bits (bits 0-6, excluding the MSB which is the continuation bit)
437            let value = (byte & 0x7F) as u64;
438
439            // Add the bits to our result at the current shift position
440            result |= value << shift;
441
442            // If the high bit is not set (0), this is the last byte
443            if byte & 0x80 == 0 {
444                break;
445            }
446
447            // Move to the next 7 bits
448            shift += 7;
449        }
450
451        Ok(Some(result))
452    }
453}
454
455impl<R: RecvStream> RecvStreamExt for R {}
456
457pub(crate) trait SendStreamExt: SendStream {
458    async fn write_length_prefixed<T: Serialize>(&mut self, value: T) -> io::Result<usize> {
459        let size = postcard::experimental::serialized_size(&value)
460            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
461        let mut buf = Vec::with_capacity(size + 9);
462        irpc::util::WriteVarintExt::write_length_prefixed(&mut buf, value)?;
463        let n = buf.len();
464        self.send_bytes(buf.into()).await?;
465        Ok(n)
466    }
467}
468
469impl<W: SendStream> SendStreamExt for W {}