fluke_maybe_uring/io/
chan.rs1use std::{cell::RefCell, rc::Rc};
2
3use crate::{
4 buf::{IoBuf, IoBufMut},
5 io::WriteOwned,
6 BufResult,
7};
8use tokio::sync::mpsc;
9
10use super::ReadOwned;
11
12pub struct ChanRead {
15 inner: Rc<ChanReadInner>,
16}
17
18pub struct ChanReadSend {
19 inner: Rc<ChanReadInner>,
20}
21
22struct ChanReadInner {
23 notify: tokio::sync::Notify,
24 guarded: RefCell<ChanReadGuarded>,
25}
26
27struct ChanReadGuarded {
28 state: ChanReadState,
29 pos: usize,
30 buf: Vec<u8>,
31}
32
/// Lifecycle of the sending half, as observed by the reader.
enum ChanReadState {
    /// The sender is still alive and may hand over more chunks.
    Live,
    /// The sender was dropped; once pending bytes are drained, reads return `Ok(0)`.
    Eof,
    /// The sender called `reset`; once pending bytes are drained, reads fail
    /// with `ConnectionReset`.
    Reset,
}
43
44impl ChanRead {
45 pub fn new() -> (ChanReadSend, Self) {
46 let inner = Rc::new(ChanReadInner {
47 notify: Default::default(),
48 guarded: RefCell::new(ChanReadGuarded {
49 state: ChanReadState::Live,
50 pos: 0,
51 buf: Vec::new(),
52 }),
53 });
54 (
55 ChanReadSend {
56 inner: inner.clone(),
57 },
58 Self { inner },
59 )
60 }
61}
62
63impl ChanReadSend {
64 pub fn reset(self) {
66 let mut guarded = self.inner.guarded.borrow_mut();
67 guarded.state = ChanReadState::Reset;
68 }
70
71 pub async fn send(&self, next_buf: impl Into<Vec<u8>>) -> Result<(), std::io::Error> {
75 let next_buf = next_buf.into();
76
77 loop {
78 {
79 let mut guarded = self.inner.guarded.borrow_mut();
80 match guarded.state {
81 ChanReadState::Live => {
82 if guarded.pos == guarded.buf.len() {
83 guarded.pos = 0;
84 guarded.buf = next_buf;
85 self.inner.notify.notify_waiters();
86 return Ok(());
87 } else {
88 }
90 }
91
92 ChanReadState::Eof => unreachable!(),
94
95 ChanReadState::Reset => unreachable!(),
97 }
98 }
99 self.inner.notify.notified().await
100 }
101 }
102}
103
104impl Drop for ChanReadSend {
105 fn drop(&mut self) {
106 let mut guarded = self.inner.guarded.borrow_mut();
107 if let ChanReadState::Live = guarded.state {
108 guarded.state = ChanReadState::Eof;
109 }
110 self.inner.notify.notify_waiters();
111 }
112}
113
114impl ReadOwned for ChanRead {
115 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
116 let out =
117 unsafe { std::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total()) };
118
119 loop {
120 {
121 let mut guarded = self.inner.guarded.borrow_mut();
122 let remain = guarded.buf.len() - guarded.pos;
123
124 if remain > 0 {
125 let n = std::cmp::min(remain, out.len());
126
127 out[..n].copy_from_slice(&guarded.buf[guarded.pos..guarded.pos + n]);
128 guarded.pos += n;
129
130 self.inner.notify.notify_waiters();
131
132 unsafe {
133 buf.set_init(n);
134 }
135 return (Ok(n), buf);
136 }
137
138 match guarded.state {
139 ChanReadState::Live => {
140 }
142 ChanReadState::Eof => {
143 return (Ok(0), buf);
144 }
145 ChanReadState::Reset => {
146 return (Err(std::io::ErrorKind::ConnectionReset.into()), buf);
147 }
148 }
149 }
150
151 self.inner.notify.notified().await;
152 }
153 }
154}
155
156pub struct ChanWrite {
157 tx: mpsc::Sender<Vec<u8>>,
158}
159
160impl ChanWrite {
161 pub fn new() -> (mpsc::Receiver<Vec<u8>>, Self) {
162 let (tx, rx) = mpsc::channel(1);
163 (rx, Self { tx })
164 }
165}
166
167impl WriteOwned for ChanWrite {
168 async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
169 let slice = unsafe { std::slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init()) };
170 match self.tx.send(slice.to_vec()).await {
171 Ok(()) => (Ok(buf.bytes_init()), buf),
172 Err(_) => (Err(std::io::ErrorKind::BrokenPipe.into()), buf),
173 }
174 }
175}
176
// Tests for `ChanRead`. They check the bytes each read returns, and also when
// the spawned sender task has made progress relative to the reader. They
// therefore depend on the order in which tasks are polled.
#[cfg(all(test, not(feature = "miri")))]
mod tests {
    use super::{ChanRead, ReadOwned};
    use std::{cell::RefCell, rc::Rc};

    #[test]
    fn test_chan_reader() {
        crate::start(async move {
            let (send, mut cr) = ChanRead::new();
            // Set by the sender task right after `send("three")` returns.
            let wrote_three = Rc::new(RefCell::new(false));

            crate::spawn({
                let wrote_three = wrote_three.clone();
                async move {
                    send.send("one").await.unwrap();
                    send.send("two").await.unwrap();
                    send.send("three").await.unwrap();
                    *wrote_three.borrow_mut() = true;
                    send.send("splitread").await.unwrap();
                    // `send` is dropped when this task ends, which marks the
                    // channel as EOF.
                }
            });

            {
                let buf = vec![0u8; 256];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"one");
            }

            // `send("three")` can only return once "two" has been fully read,
            // which has not happened yet.
            assert!(!*wrote_three.borrow());

            {
                let buf = vec![0u8; 256];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"two");
            }

            // Reading "two" unblocks `send("three")`; yielding gives the
            // sender task a turn to complete it.
            tokio::task::yield_now().await;
            assert!(*wrote_three.borrow());

            {
                let buf = vec![0u8; 256];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"three");
            }

            // A 5-byte buffer takes only part of "splitread"; the next read
            // returns the remainder.
            {
                let buf = vec![0u8; 5];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"split");

                let buf = vec![0u8; 256];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"read");
            }

            // All data consumed and the sender dropped: the read reports EOF.
            {
                let buf = vec![0u8; 0];
                let (res, _) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(n, 0, "reached EOF");
            }

            // Second channel: data sent before `reset` is still delivered,
            // and only afterwards do reads fail with ConnectionReset.
            let (send, mut cr) = ChanRead::new();

            crate::spawn({
                async move {
                    send.send("two-part").await.unwrap();
                    send.reset();
                }
            });

            // Give the spawned sender turns to run; presumably enough for it
            // to finish, including `reset`, before the reads below.
            for _ in 0..5 {
                tokio::task::yield_now().await;
            }

            {
                let buf = vec![0u8; 4];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"two-");
            }

            {
                let buf = vec![0u8; 4];
                let (res, buf) = cr.read(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"part");
            }

            {
                let buf = vec![0u8; 0];
                let (res, _) = cr.read(buf).await;
                let err = res.unwrap_err();
                // NOTE(review): the "reached EOF" message looks copied from the
                // EOF check above; this assertion is about ConnectionReset.
                assert_eq!(
                    err.kind(),
                    std::io::ErrorKind::ConnectionReset,
                    "reached EOF"
                );
            }
        })
    }
}