xet_runtime/utils/
pipe.rs1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use bytes::Bytes;
6use futures::Stream;
7use tokio::io::AsyncWrite;
8use tokio::sync::mpsc;
9use tokio::sync::mpsc::error::TrySendError;
10use tokio_util::io::StreamReader;
11
12pub fn pipe(buffer_size: usize) -> (ChannelWriter, ChannelStream) {
13 let (sender, receiver) = mpsc::channel(buffer_size);
14 (ChannelWriter::new(sender), ChannelStream::new(receiver))
15}
16
17pub struct ChannelStream(mpsc::Receiver<io::Result<Bytes>>);
19
20impl ChannelStream {
21 fn new(rx: mpsc::Receiver<io::Result<Bytes>>) -> Self {
22 Self(rx)
23 }
24
25 pub fn reader(self) -> ChannelReader {
26 ChannelReader::new(self)
27 }
28}
29
30impl Stream for ChannelStream {
31 type Item = io::Result<Bytes>;
32
33 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34 self.0.poll_recv(cx)
35 }
36}
37
38type ChannelReader = StreamReader<ChannelStream, Bytes>;
39
40pub struct ChannelWriter(mpsc::Sender<io::Result<Bytes>>);
42
43impl ChannelWriter {
44 fn new(tx: mpsc::Sender<io::Result<Bytes>>) -> Self {
45 Self(tx)
46 }
47}
48
49impl AsyncWrite for ChannelWriter {
50 fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
51 let perm = match self.0.try_reserve() {
52 Ok(p) => p,
53 Err(TrySendError::Closed(_)) => {
54 return Poll::Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, "receiver closed")));
55 },
56 Err(TrySendError::Full(_)) => return Poll::Pending,
57 };
58
59 let data = Bytes::copy_from_slice(buf);
60 let len = data.len();
61 perm.send(Ok(data));
62
63 Poll::Ready(Ok(len))
64 }
65
66 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67 Poll::Ready(Ok(()))
69 }
70
71 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72 Poll::Ready(Ok(()))
74 }
75}
76
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Data written to the write half is read back, in order, from the reader
    /// built on the read half.
    #[tokio::test]
    async fn test_channel_read_write() {
        let (mut tx, rx) = pipe(10);
        let mut rd = rx.reader();

        // Two separate writes become two chunks on the channel.
        for chunk in [&b"Hello, "[..], &b"World!"[..]] {
            tx.write_all(chunk).await.unwrap();
        }

        // Dropping the writer closes the channel so `read_to_end` can observe EOF.
        drop(tx);

        let mut received = Vec::new();
        rd.read_to_end(&mut received).await.unwrap();

        assert_eq!(received, b"Hello, World!");
    }
}