iroh_blobs/util/
stream.rs1use std::{
2 future::Future,
3 io,
4 ops::{Deref, DerefMut},
5};
6
7use bytes::Bytes;
8use iroh::endpoint::{ReadExactError, VarInt};
9use iroh_io::AsyncStreamReader;
10use serde::{de::DeserializeOwned, Serialize};
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12
13pub trait SendStream: Send {
15 fn send_bytes(&mut self, bytes: Bytes) -> impl Future<Output = io::Result<()>> + Send;
19 fn send(&mut self, buf: &[u8]) -> impl Future<Output = io::Result<()>> + Send;
21 fn sync(&mut self) -> impl Future<Output = io::Result<()>> + Send;
23 fn reset(&mut self, code: VarInt) -> io::Result<()>;
25 fn stopped(&mut self) -> impl Future<Output = io::Result<Option<VarInt>>> + Send;
27 fn id(&self) -> u64;
29}
30
31pub trait RecvStream: Send {
33 fn recv_bytes(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>> + Send;
35 fn recv_bytes_exact(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>> + Send;
41 fn recv_exact(&mut self, target: &mut [u8]) -> impl Future<Output = io::Result<()>> + Send;
43 fn stop(&mut self, code: VarInt) -> io::Result<()>;
45 fn id(&self) -> u64;
47}
48
49impl SendStream for iroh::endpoint::SendStream {
50 async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
51 Ok(self.write_chunk(bytes).await?)
52 }
53
54 async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
55 Ok(self.write_all(buf).await?)
56 }
57
58 async fn sync(&mut self) -> io::Result<()> {
59 Ok(())
60 }
61
62 fn reset(&mut self, code: VarInt) -> io::Result<()> {
63 Ok(self.reset(code)?)
64 }
65
66 async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
67 Ok(self.stopped().await?)
68 }
69
70 fn id(&self) -> u64 {
71 self.id().index()
72 }
73}
74
75impl RecvStream for iroh::endpoint::RecvStream {
76 async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
77 let mut buf = vec![0; len];
78 match self.read_exact(&mut buf).await {
79 Err(ReadExactError::FinishedEarly(n)) => {
80 buf.truncate(n);
81 }
82 Err(ReadExactError::ReadError(e)) => {
83 return Err(e.into());
84 }
85 Ok(()) => {}
86 };
87 Ok(buf.into())
88 }
89
90 async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
91 let mut buf = vec![0; len];
92 self.read_exact(&mut buf).await.map_err(|e| match e {
93 ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""),
94 ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""),
95 ReadExactError::ReadError(e) => e.into(),
96 })?;
97 Ok(buf.into())
98 }
99
100 async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
101 self.read_exact(buf).await.map_err(|e| match e {
102 ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""),
103 ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""),
104 ReadExactError::ReadError(e) => e.into(),
105 })
106 }
107
108 fn stop(&mut self, code: VarInt) -> io::Result<()> {
109 Ok(self.stop(code)?)
110 }
111
112 fn id(&self) -> u64 {
113 self.id().index()
114 }
115}
116
117impl<R: RecvStream> RecvStream for &mut R {
118 async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
119 self.deref_mut().recv_bytes(len).await
120 }
121
122 async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
123 self.deref_mut().recv_bytes_exact(len).await
124 }
125
126 async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
127 self.deref_mut().recv_exact(buf).await
128 }
129
130 fn stop(&mut self, code: VarInt) -> io::Result<()> {
131 self.deref_mut().stop(code)
132 }
133
134 fn id(&self) -> u64 {
135 self.deref().id()
136 }
137}
138
139impl<W: SendStream> SendStream for &mut W {
140 async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
141 self.deref_mut().send_bytes(bytes).await
142 }
143
144 async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
145 self.deref_mut().send(buf).await
146 }
147
148 async fn sync(&mut self) -> io::Result<()> {
149 self.deref_mut().sync().await
150 }
151
152 fn reset(&mut self, code: VarInt) -> io::Result<()> {
153 self.deref_mut().reset(code)
154 }
155
156 async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
157 self.deref_mut().stopped().await
158 }
159
160 fn id(&self) -> u64 {
161 self.deref().id()
162 }
163}
164
/// Adapter that turns a type implementing [`AsyncReadRecvStreamExtra`]
/// (a tokio [`AsyncRead`] plus stop/id support) into a [`RecvStream`].
#[derive(Debug)]
pub struct AsyncReadRecvStream<R>(R);
167
168pub trait AsyncReadRecvStreamExtra: Send {
172 fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send);
176 fn stop(&mut self, code: VarInt) -> io::Result<()>;
178 fn id(&self) -> u64;
182}
183
184impl<R> AsyncReadRecvStream<R> {
185 pub fn new(inner: R) -> Self {
186 Self(inner)
187 }
188}
189
190impl<R: AsyncReadRecvStreamExtra> RecvStream for AsyncReadRecvStream<R> {
191 async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
192 let mut res = vec![0; len];
193 let mut n = 0;
194 loop {
195 let read = self.0.inner().read(&mut res[n..]).await?;
196 if read == 0 {
197 res.truncate(n);
198 break;
199 }
200 n += read;
201 if n == len {
202 break;
203 }
204 }
205 Ok(res.into())
206 }
207
208 async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
209 let mut res = vec![0; len];
210 self.0.inner().read_exact(&mut res).await?;
211 Ok(res.into())
212 }
213
214 async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
215 self.0.inner().read_exact(buf).await?;
216 Ok(())
217 }
218
219 fn stop(&mut self, code: VarInt) -> io::Result<()> {
220 self.0.stop(code)
221 }
222
223 fn id(&self) -> u64 {
224 self.0.id()
225 }
226}
227
228impl RecvStream for Bytes {
229 async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
230 let n = len.min(self.len());
231 let res = self.slice(..n);
232 *self = self.slice(n..);
233 Ok(res)
234 }
235
236 async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
237 if self.len() < len {
238 return Err(io::ErrorKind::UnexpectedEof.into());
239 }
240 let res = self.slice(..len);
241 *self = self.slice(len..);
242 Ok(res)
243 }
244
245 async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
246 if self.len() < buf.len() {
247 return Err(io::ErrorKind::UnexpectedEof.into());
248 }
249 buf.copy_from_slice(&self[..buf.len()]);
250 *self = self.slice(buf.len()..);
251 Ok(())
252 }
253
254 fn stop(&mut self, _code: VarInt) -> io::Result<()> {
255 Ok(())
256 }
257
258 fn id(&self) -> u64 {
259 0
260 }
261}
262
/// Adapter that turns a type implementing [`AsyncWriteSendStreamExtra`]
/// (a tokio [`AsyncWrite`] plus reset/stopped/id support) into a [`SendStream`].
#[derive(Debug, Clone)]
pub struct AsyncWriteSendStream<W>(W);
266
267pub trait AsyncWriteSendStreamExtra: Send {
273 fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send);
277 fn reset(&mut self, code: VarInt) -> io::Result<()>;
279 fn stopped(&mut self) -> impl Future<Output = io::Result<Option<VarInt>>> + Send;
281 fn id(&self) -> u64;
285}
286
287impl<W: AsyncWriteSendStreamExtra> AsyncWriteSendStream<W> {
288 pub fn new(inner: W) -> Self {
289 Self(inner)
290 }
291}
292
293impl<W: AsyncWriteSendStreamExtra> AsyncWriteSendStream<W> {
294 pub fn into_inner(self) -> W {
295 self.0
296 }
297}
298
299impl<W: AsyncWriteSendStreamExtra> SendStream for AsyncWriteSendStream<W> {
300 async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
301 self.0.inner().write_all(&bytes).await
302 }
303
304 async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
305 self.0.inner().write_all(buf).await
306 }
307
308 async fn sync(&mut self) -> io::Result<()> {
309 self.0.inner().flush().await
310 }
311
312 fn reset(&mut self, code: VarInt) -> io::Result<()> {
313 self.0.reset(code)?;
314 Ok(())
315 }
316
317 async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
318 let res = self.0.stopped().await?;
319 Ok(res)
320 }
321
322 fn id(&self) -> u64 {
323 self.0.id()
324 }
325}
326
327#[derive(Debug)]
328pub struct RecvStreamAsyncStreamReader<R>(R);
329
330impl<R: RecvStream> RecvStreamAsyncStreamReader<R> {
331 pub fn new(inner: R) -> Self {
332 Self(inner)
333 }
334
335 pub fn into_inner(self) -> R {
336 self.0
337 }
338}
339
340impl<R: RecvStream> AsyncStreamReader for RecvStreamAsyncStreamReader<R> {
341 async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
342 self.0.recv_bytes_exact(len).await
343 }
344
345 async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
346 let mut buf = [0; L];
347 self.0.recv_exact(&mut buf).await?;
348 Ok(buf)
349 }
350}
351
352pub(crate) trait RecvStreamExt: RecvStream {
353 async fn expect_eof(&mut self) -> io::Result<()> {
354 match self.read_u8().await {
355 Ok(_) => Err(io::Error::new(
356 io::ErrorKind::InvalidData,
357 "unexpected data",
358 )),
359 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()),
360 Err(e) => Err(e),
361 }
362 }
363
364 async fn read_u8(&mut self) -> io::Result<u8> {
365 let mut buf = [0; 1];
366 self.recv_exact(&mut buf).await?;
367 Ok(buf[0])
368 }
369
370 async fn read_to_end_as<T: DeserializeOwned>(
371 &mut self,
372 max_size: usize,
373 ) -> io::Result<(T, usize)> {
374 let data = self.recv_bytes(max_size).await?;
375 self.expect_eof().await?;
376 let value = postcard::from_bytes(&data)
377 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
378 Ok((value, data.len()))
379 }
380
381 async fn read_length_prefixed<T: DeserializeOwned>(
382 &mut self,
383 max_size: usize,
384 ) -> io::Result<T> {
385 let Some(n) = self.read_varint_u64().await? else {
386 return Err(io::ErrorKind::UnexpectedEof.into());
387 };
388 if n > max_size as u64 {
389 return Err(io::Error::new(
390 io::ErrorKind::InvalidData,
391 "length prefix too large",
392 ));
393 }
394 let n = n as usize;
395 let data = self.recv_bytes(n).await?;
396 let value = postcard::from_bytes(&data)
397 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
398 Ok(value)
399 }
400
401 async fn read_varint_u64(&mut self) -> io::Result<Option<u64>> {
410 let mut result: u64 = 0;
411 let mut shift: u32 = 0;
412
413 loop {
414 if shift >= 64 {
416 return Err(io::Error::new(
417 io::ErrorKind::InvalidData,
418 "Varint is too large for u64",
419 ));
420 }
421
422 let res = self.read_u8().await;
424 if shift == 0 {
425 if let Err(cause) = res {
426 if cause.kind() == io::ErrorKind::UnexpectedEof {
427 return Ok(None);
428 } else {
429 return Err(cause);
430 }
431 }
432 }
433
434 let byte = res?;
435
436 let value = (byte & 0x7F) as u64;
438
439 result |= value << shift;
441
442 if byte & 0x80 == 0 {
444 break;
445 }
446
447 shift += 7;
449 }
450
451 Ok(Some(result))
452 }
453}
454
455impl<R: RecvStream> RecvStreamExt for R {}
456
457pub(crate) trait SendStreamExt: SendStream {
458 async fn write_length_prefixed<T: Serialize>(&mut self, value: T) -> io::Result<usize> {
459 let size = postcard::experimental::serialized_size(&value)
460 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
461 let mut buf = Vec::with_capacity(size + 9);
462 irpc::util::WriteVarintExt::write_length_prefixed(&mut buf, value)?;
463 let n = buf.len();
464 self.send_bytes(buf.into()).await?;
465 Ok(n)
466 }
467}
468
469impl<W: SendStream> SendStreamExt for W {}