fluke_maybe_uring/io/
chan.rs

1use std::{cell::RefCell, rc::Rc};
2
3use crate::{
4    buf::{IoBuf, IoBufMut},
5    io::WriteOwned,
6    BufResult,
7};
8use tokio::sync::mpsc;
9
10use super::ReadOwned;
11
12/// Allows sending `Vec<u8>` chunks, which can be read through its [ReadOwned]
13/// implementation.
14pub struct ChanRead {
15    inner: Rc<ChanReadInner>,
16}
17
18pub struct ChanReadSend {
19    inner: Rc<ChanReadInner>,
20}
21
22struct ChanReadInner {
23    notify: tokio::sync::Notify,
24    guarded: RefCell<ChanReadGuarded>,
25}
26
27struct ChanReadGuarded {
28    state: ChanReadState,
29    pos: usize,
30    buf: Vec<u8>,
31}
32
/// Lifecycle of the sending side, as observed by the reader.
enum ChanReadState {
    /// The sender is alive; more data may still come in.
    Live,

    /// [`ChanReadSend`] was dropped without calling `reset`: no more data is
    /// coming, and reads return `Ok(0)` once the buffered chunk is drained.
    Eof,

    /// [`ChanReadSend::reset`] was called: reads fail with `ConnectionReset`
    /// once the buffered chunk is drained.
    Reset,
}
43
44impl ChanRead {
45    pub fn new() -> (ChanReadSend, Self) {
46        let inner = Rc::new(ChanReadInner {
47            notify: Default::default(),
48            guarded: RefCell::new(ChanReadGuarded {
49                state: ChanReadState::Live,
50                pos: 0,
51                buf: Vec::new(),
52            }),
53        });
54        (
55            ChanReadSend {
56                inner: inner.clone(),
57            },
58            Self { inner },
59        )
60    }
61}
62
63impl ChanReadSend {
64    /// Sever this connection abnormally - read will eventually return [std::io::ErrorKind::ConnectionReset]
65    pub fn reset(self) {
66        let mut guarded = self.inner.guarded.borrow_mut();
67        guarded.state = ChanReadState::Reset;
68        // let it drop, which will notify waiters
69    }
70
71    /// Send a chunk of data. Readers will not be able to read _more_ than the
72    /// length of this chunk in a single call, but may read less (if their buffer
73    /// is too small).
74    pub async fn send(&self, next_buf: impl Into<Vec<u8>>) -> Result<(), std::io::Error> {
75        let next_buf = next_buf.into();
76
77        loop {
78            {
79                let mut guarded = self.inner.guarded.borrow_mut();
80                match guarded.state {
81                    ChanReadState::Live => {
82                        if guarded.pos == guarded.buf.len() {
83                            guarded.pos = 0;
84                            guarded.buf = next_buf;
85                            self.inner.notify.notify_waiters();
86                            return Ok(());
87                        } else {
88                            // wait for read
89                        }
90                    }
91
92                    // can't send after dropping
93                    ChanReadState::Eof => unreachable!(),
94
95                    // can't send after calling abort
96                    ChanReadState::Reset => unreachable!(),
97                }
98            }
99            self.inner.notify.notified().await
100        }
101    }
102}
103
104impl Drop for ChanReadSend {
105    fn drop(&mut self) {
106        let mut guarded = self.inner.guarded.borrow_mut();
107        if let ChanReadState::Live = guarded.state {
108            guarded.state = ChanReadState::Eof;
109        }
110        self.inner.notify.notify_waiters();
111    }
112}
113
114impl ReadOwned for ChanRead {
115    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
116        let out =
117            unsafe { std::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total()) };
118
119        loop {
120            {
121                let mut guarded = self.inner.guarded.borrow_mut();
122                let remain = guarded.buf.len() - guarded.pos;
123
124                if remain > 0 {
125                    let n = std::cmp::min(remain, out.len());
126
127                    out[..n].copy_from_slice(&guarded.buf[guarded.pos..guarded.pos + n]);
128                    guarded.pos += n;
129
130                    self.inner.notify.notify_waiters();
131
132                    unsafe {
133                        buf.set_init(n);
134                    }
135                    return (Ok(n), buf);
136                }
137
138                match guarded.state {
139                    ChanReadState::Live => {
140                        // muffin
141                    }
142                    ChanReadState::Eof => {
143                        return (Ok(0), buf);
144                    }
145                    ChanReadState::Reset => {
146                        return (Err(std::io::ErrorKind::ConnectionReset.into()), buf);
147                    }
148                }
149            }
150
151            self.inner.notify.notified().await;
152        }
153    }
154}
155
156pub struct ChanWrite {
157    tx: mpsc::Sender<Vec<u8>>,
158}
159
160impl ChanWrite {
161    pub fn new() -> (mpsc::Receiver<Vec<u8>>, Self) {
162        let (tx, rx) = mpsc::channel(1);
163        (rx, Self { tx })
164    }
165}
166
167impl WriteOwned for ChanWrite {
168    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
169        let slice = unsafe { std::slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init()) };
170        match self.tx.send(slice.to_vec()).await {
171            Ok(()) => (Ok(buf.bytes_init()), buf),
172            Err(_) => (Err(std::io::ErrorKind::BrokenPipe.into()), buf),
173        }
174    }
175}
176
#[cfg(all(test, not(feature = "miri")))]
mod tests {
    use super::{ChanRead, ReadOwned};
    use std::{cell::RefCell, rc::Rc};

    /// End-to-end check of `ChanRead`. It covers:
    /// - one-chunk-at-a-time delivery with back-pressure on the sender;
    /// - splitting one chunk across several reads;
    /// - EOF once the sender is dropped;
    /// - `ConnectionReset` after the sender calls `reset`, reported only once
    ///   buffered data is drained.
    #[test]
    fn test_chan_reader() {
        crate::start(async move {
            let (send, mut cr) = ChanRead::new();
            // Set by the sender task right after `send("three")` completes.
            let wrote_three = Rc::new(RefCell::new(false));

            crate::spawn({
                let wrote_three = wrote_three.clone();
                async move {
                    send.send("one").await.unwrap();
                    send.send("two").await.unwrap();
                    send.send("three").await.unwrap();
                    *wrote_three.borrow_mut() = true;
                    send.send("splitread").await.unwrap();
                    // `send` is dropped when this task finishes, signalling EOF.
                }
            });

            {
                let buf = vec![0u8; 256];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"one");
            }

            // The sender is held in `send("two")` until `one` is drained and
            // (the test assumes) has not run since, so it cannot reach `three`.
            assert!(!*wrote_three.borrow());

            {
                let buf = vec![0u8; 256];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"two");
            }

            // `two` is drained, so yielding lets the sender hand over `three`.
            tokio::task::yield_now().await;
            assert!(*wrote_three.borrow());

            {
                let buf = vec![0u8; 256];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"three");
            }

            // A 5-byte buffer takes only part of the chunk; the next read
            // returns the rest.
            {
                let buf = vec![0u8; 5];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"split");

                let buf = vec![0u8; 256];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"read");
            }

            // The sender task has finished and dropped its handle.
            {
                let buf = vec![0u8; 0];
                let (res, _) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(n, 0, "reached EOF");
            }

            let (send, mut cr) = ChanRead::new();

            crate::spawn({
                async move {
                    send.send("two-part").await.unwrap();
                    send.reset();
                }
            });

            // Give the sender task a chance to run to completion (send, then
            // reset) before reading.
            for _ in 0..5 {
                tokio::task::yield_now().await;
            }

            // Data buffered before the reset is still delivered...
            {
                let buf = vec![0u8; 4];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"two-");
            }

            {
                let buf = vec![0u8; 4];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"part");
            }

            // ...and only then is the reset reported.
            {
                let buf = vec![0u8; 0];
                let (res, _) = cr.read(buf).await;
                let err = res.unwrap_err();
                assert_eq!(
                    err.kind(),
                    std::io::ErrorKind::ConnectionReset,
                    "reset reported after buffered data is drained"
                );
            }
        })
    }
}