Skip to main content

mfio/
stdeq.rs

1//! `std::io` equivalent Read/Write traits.
2
3use crate as mfio;
4use crate::error::Result;
5use crate::io::*;
6use crate::locks::Mutex;
7use crate::std_prelude::*;
8use crate::traits::*;
9use crate::util::{PosShift, UsizeMath};
10use core::future::Future;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use mfio_derive::*;
14
/// A stream position that can be read and changed through a shared reference.
///
/// Every method takes `&self`, so implementors need interior mutability to actually store a
/// position (e.g. `Seekable` keeps it behind a `Mutex`).
pub trait StreamPos<Param> {
    /// Returns the current position.
    fn get_pos(&self) -> Param;

    /// Overwrites the current position with `pos`.
    fn set_pos(&self, pos: Param);

    /// Replaces the current position `p` with `f(p)`.
    fn update_pos<F: FnOnce(Param) -> Param>(&self, f: F);

    /// Position of the end of the stream, when it is known.
    ///
    /// The default reports `None`; `std_seek` treats that as "`SeekFrom::End` unsupported".
    fn end(&self) -> Option<Param> {
        None
    }
}
26
/// Applies a `std::io::SeekFrom` request to a `StreamPos<u64>` implementor.
///
/// This is the `std::io::Seek::seek` implementation used by `stdio_impl!`. On success the
/// new absolute position is stored with `StreamPos::set_pos` and returned.
///
/// # Errors
///
/// * `std::io::ErrorKind::Unsupported` for `SeekFrom::End` when `io.end()` is `None`.
/// * `std::io::ErrorKind::InvalidInput` when the target would be negative or would not fit in
///   a `u64`. The stored position is left unchanged in that case.
#[cfg(feature = "std")]
pub fn std_seek(
    io: &(impl StreamPos<u64> + ?Sized),
    pos: std::io::SeekFrom,
) -> std::io::Result<u64> {
    // Offsets `base` by the signed `delta`, yielding `None` on underflow or overflow.
    fn offset(base: u64, delta: i64) -> Option<u64> {
        if delta < 0 {
            // `unsigned_abs` is exact even for `i64::MIN`, where `-delta` would overflow.
            base.checked_sub(delta.unsigned_abs())
        } else {
            // `delta` is non-negative here, so the cast is lossless. `checked_add` matters for
            // wrappers such as `FakeSeek`, whose reported position is `u64::MAX`.
            base.checked_add(delta as u64)
        }
    }

    let target = match pos {
        std::io::SeekFrom::Start(val) => val,
        std::io::SeekFrom::End(val) => {
            let end = io.end().ok_or(std::io::ErrorKind::Unsupported)?;
            offset(end, val).ok_or(std::io::ErrorKind::InvalidInput)?
        }
        std::io::SeekFrom::Current(val) => {
            offset(io.get_pos(), val).ok_or(std::io::ErrorKind::InvalidInput)?
        }
    };

    io.set_pos(target);
    Ok(target)
}
64
65impl<Param: Copy + UsizeMath, Io: StreamPos<Param>> PosShift<Io> for Param {
66    fn add_pos(&mut self, out: usize, io: &Io) {
67        self.add_assign(out);
68        io.set_pos(*self);
69    }
70
71    fn add_io_pos(io: &Io, out: usize) {
72        io.update_pos(|pos| pos.add(out))
73    }
74}
75
76pub trait AsyncRead<Param: 'static>: IoRead<Param> {
77    fn read<'a>(&'a self, buf: &'a mut [u8]) -> AsyncIoFut<'a, Self, Write, Param, &'a mut [u8]>;
78    fn read_to_end<'a>(&'a self, buf: &'a mut Vec<u8>) -> StdReadToEndFut<'a, Self, Param>;
79}
80
81impl<T: IoRead<Param> + StreamPos<Param>, Param: 'static + Copy> AsyncRead<Param> for T {
82    fn read<'a>(&'a self, buf: &'a mut [u8]) -> AsyncIoFut<'a, Self, Write, Param, &'a mut [u8]> {
83        let len = buf.len();
84        let (pkt, sync) = <&'a mut [u8] as IntoPacket<Write>>::into_packet(buf);
85        AsyncIoFut {
86            io: self,
87            len,
88            fut: self.io(self.get_pos(), pkt),
89            sync: Some(sync),
90        }
91    }
92
93    fn read_to_end<'a>(&'a self, buf: &'a mut Vec<u8>) -> StdReadToEndFut<'a, Self, Param> {
94        StdReadToEndFut {
95            io: self,
96            fut: <Self as IoRead<Param>>::read_to_end(self, self.get_pos(), buf),
97        }
98    }
99}
100
101impl<T: IoRead<NoPos>> AsyncRead<NoPos> for T {
102    fn read<'a>(&'a self, buf: &'a mut [u8]) -> AsyncIoFut<'a, Self, Write, NoPos, &'a mut [u8]> {
103        let len = buf.len();
104        let (pkt, sync) = <&'a mut [u8] as IntoPacket<Write>>::into_packet(buf);
105        AsyncIoFut {
106            io: self,
107            len,
108            fut: self.io(NoPos::new(), pkt),
109            sync: Some(sync),
110        }
111    }
112
113    fn read_to_end<'a>(&'a self, buf: &'a mut Vec<u8>) -> StdReadToEndFut<'a, Self, NoPos> {
114        StdReadToEndFut {
115            io: self,
116            fut: <Self as IoRead<NoPos>>::read_to_end(self, NoPos::new(), buf),
117        }
118    }
119}
120
121pub trait AsyncWrite<Param>: IoWrite<Param> {
122    fn write<'a>(&'a self, buf: &'a [u8]) -> AsyncIoFut<'a, Self, Read, Param, &'a [u8]>;
123}
124
125impl<T: IoWrite<Param> + StreamPos<Param>, Param: Copy> AsyncWrite<Param> for T {
126    fn write<'a>(&'a self, buf: &'a [u8]) -> AsyncIoFut<'a, Self, Read, Param, &'a [u8]> {
127        let len = buf.len();
128        let (pkt, sync) = buf.into_packet();
129        AsyncIoFut {
130            io: self,
131            len,
132            fut: self.io(self.get_pos(), pkt),
133            sync: Some(sync),
134        }
135    }
136}
137
138impl<T: IoWrite<NoPos>> AsyncWrite<NoPos> for T {
139    fn write<'a>(&'a self, buf: &'a [u8]) -> AsyncIoFut<'a, Self, Read, NoPos, &'a [u8]> {
140        let len = buf.len();
141        let (pkt, sync) = buf.into_packet();
142        AsyncIoFut {
143            io: self,
144            len,
145            fut: self.io(NoPos::new(), pkt),
146            sync: Some(sync),
147        }
148    }
149}
150
151pub struct AsyncIoFut<'a, Io: ?Sized, Perms: PacketPerms, Param: 'a, Obj: IntoPacket<'a, Perms>> {
152    io: &'a Io,
153    fut: IoFut<'a, Io, Perms, Param, Obj::Target>,
154    pub(crate) sync: Option<Obj::SyncHandle>,
155    len: usize,
156}
157
158impl<
159        'a,
160        Io: PacketIo<Perms, Param>,
161        Perms: PacketPerms,
162        Param: PosShift<Io>,
163        Obj: IntoPacket<'a, Perms>,
164    > Future for AsyncIoFut<'a, Io, Perms, Param, Obj>
165{
166    type Output = Result<usize>;
167
168    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
169        let this = unsafe { self.get_unchecked_mut() };
170
171        let fut = unsafe { Pin::new_unchecked(&mut this.fut) };
172
173        fut.poll(cx).map(|pkt| {
174            let hdr = <<Obj as IntoPacket<'a, Perms>>::Target as OpaqueStore>::stack_hdr(&pkt);
175            // TODO: put this after error checking
176            Obj::sync_back(hdr, this.sync.take().unwrap());
177            let progressed = core::cmp::min(hdr.error_clamp() as usize, this.len);
178            Param::add_io_pos(this.io, progressed);
179            // TODO: actual error checking
180            Ok(progressed)
181        })
182    }
183}
184
185pub struct StdReadToEndFut<'a, Io: PacketIo<Write, Param>, Param> {
186    io: &'a Io,
187    fut: ReadToEndFut<'a, Io, Param>,
188}
189
190impl<'a, Io: PacketIo<Write, Param>, Param: PosShift<Io>> Future
191    for StdReadToEndFut<'a, Io, Param>
192{
193    type Output = Result<()>;
194
195    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
196        let this = unsafe { self.get_unchecked_mut() };
197
198        match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) {
199            Poll::Ready(Ok(r)) => {
200                Param::add_io_pos(this.io, r);
201                Poll::Ready(Ok(()))
202            }
203            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
204            Poll::Pending => Poll::Pending,
205        }
206    }
207}
208
/// Implements `std::io::{Seek, Read, Write}` on a compatible type.
///
/// * `std::io::Seek` requires the type to implement `StreamPos<u64>`.
/// * `std::io::Read` requires `AsyncRead<u64>` and `IoBackend`.
/// * `std::io::Write` requires `AsyncWrite<u64>` and `IoBackend`.
///
/// Reads and writes are driven to completion with `IoBackendExt::block_on`. Any mfio error is
/// reported as `std::io::ErrorKind::Other`, and `flush` is a no-op.
///
/// Every trait the expansion uses is named through `$crate`, so the invoking module does not
/// need `StreamPos`, `AsyncRead` or `AsyncWrite` in scope.
///
/// # Syntax
///
/// ```ignore
/// // `<impl generics> Type<type arguments> @ extra where-clause predicates (may be empty)`
/// stdio_impl!(<'a, T> SeekableRef<'a, T, u64> @);
/// // Shorthand for a type without generic parameters.
/// stdio_impl!(MyIo @);
/// ```
#[macro_export]
macro_rules! stdio_impl {
    (<$($lt2:lifetime,)* $($ty2:ident),*> $t:ident <$($lt:lifetime,)* $($ty:ident),*> @ $($tt:tt)*) => {
        impl<$($lt2,)* $($ty2),*> std::io::Seek for $t<$($lt,)* $($ty),*>
        where
            $t<$($lt,)* $($ty),*>: $crate::stdeq::StreamPos<u64>,
            $($tt)*
        {
            fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
                $crate::stdeq::std_seek(self, pos)
            }

            fn stream_position(&mut self) -> std::io::Result<u64> {
                Ok(<Self as $crate::stdeq::StreamPos<u64>>::get_pos(self))
            }

            fn rewind(&mut self) -> std::io::Result<()> {
                <Self as $crate::stdeq::StreamPos<u64>>::set_pos(self, 0);
                Ok(())
            }
        }

        impl<$($lt2,)* $($ty2),*> std::io::Read for $t<$($lt,)* $($ty),*>
        where
            $t<$($lt,)* $($ty),*>: $crate::stdeq::AsyncRead<u64> + $crate::backend::IoBackend,
            $($tt)*
        {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                use $crate::backend::IoBackendExt;
                self.block_on(<Self as $crate::stdeq::AsyncRead<u64>>::read(self, buf))
                    .map_err(|_| std::io::ErrorKind::Other.into())
            }

            fn read_to_end(&mut self, buf: &mut std::vec::Vec<u8>) -> std::io::Result<usize> {
                use $crate::backend::IoBackendExt;
                // `std::io::Read::read_to_end` reports how many bytes were appended.
                let len = buf.len();
                self.block_on(<Self as $crate::stdeq::AsyncRead<u64>>::read_to_end(self, buf))
                    .map_err(|_| std::io::ErrorKind::Other)?;
                Ok(buf.len() - len)
            }
        }

        impl<$($lt2,)* $($ty2),*> std::io::Write for $t<$($lt,)* $($ty),*>
        where
            $t<$($lt,)* $($ty),*>: $crate::stdeq::AsyncWrite<u64> + $crate::backend::IoBackend,
            $($tt)*
        {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                use $crate::backend::IoBackendExt;
                // Fully qualified: a bare `AsyncWrite` would have to be in scope at the call site.
                self.block_on(<Self as $crate::stdeq::AsyncWrite<u64>>::write(self, buf))
                    .map_err(|_| std::io::ErrorKind::Other.into())
            }

            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }
    };
    ($t:ident @ $($tt:tt)*) => {
        // Forward with an empty impl-generics list so the first arm (which requires a leading
        // `<...>`) actually matches.
        $crate::stdio_impl!(<> $t<> @ $($tt)*);
    };
}
260
/// Owns an I/O handle and adds a stream position to it.
///
/// The position is stored behind a `Mutex`, so it can be read and updated through `&self`.
/// Packet requests are forwarded to `handle` unchanged.
#[derive(SyncIoWrite, SyncIoRead)]
pub struct Seekable<T, Param> {
    /// Current stream position.
    pos: Mutex<Param>,
    /// The wrapped I/O object.
    handle: T,
}
266
267impl<T, Param: Default> From<T> for Seekable<T, Param> {
268    fn from(handle: T) -> Self {
269        Self {
270            pos: Default::default(),
271            handle,
272        }
273    }
274}
275
276impl<T: PacketIo<Perms, Param>, Perms: PacketPerms, Param> PacketIo<Perms, Param>
277    for Seekable<T, Param>
278{
279    fn send_io(&self, param: Param, view: BoundPacketView<Perms>) {
280        self.handle.send_io(param, view)
281    }
282}
283
284impl<T, Param: Copy> StreamPos<Param> for Seekable<T, Param> {
285    fn get_pos(&self) -> Param {
286        *self.pos.lock()
287    }
288
289    fn set_pos(&self, pos: Param) {
290        *self.pos.lock() = pos;
291    }
292
293    fn update_pos<F: FnOnce(Param) -> Param>(&self, f: F) {
294        let mut pos = self.pos.lock();
295        *pos = f(*pos);
296    }
297}
298
// Bridge `Seekable<T, u64>` to `std::io::{Seek, Read, Write}`. `Read`/`Write` are only
// implemented when the macro's `AsyncRead<u64>`/`AsyncWrite<u64>` + `IoBackend` bounds hold.
#[cfg(feature = "std")]
stdio_impl!(<T> Seekable<T, u64> @);
301
/// Like `Seekable`, but borrows the I/O handle instead of owning it.
///
/// The position is stored behind a `Mutex`, so it can be read and updated through `&self`.
#[derive(SyncIoWrite, SyncIoRead)]
pub struct SeekableRef<'a, T, Param> {
    /// Current stream position.
    pos: Mutex<Param>,
    /// The borrowed I/O object.
    handle: &'a T,
}
307
308impl<'a, T, Param: Default> From<&'a T> for SeekableRef<'a, T, Param> {
309    fn from(handle: &'a T) -> Self {
310        Self {
311            pos: Default::default(),
312            handle,
313        }
314    }
315}
316
317impl<T: PacketIo<Perms, Param>, Perms: PacketPerms, Param: core::fmt::Debug> PacketIo<Perms, Param>
318    for SeekableRef<'_, T, Param>
319{
320    fn send_io(&self, param: Param, view: BoundPacketView<Perms>) {
321        self.handle.send_io(param, view)
322    }
323}
324
325impl<T, Param: Copy> StreamPos<Param> for SeekableRef<'_, T, Param> {
326    fn get_pos(&self) -> Param {
327        *self.pos.lock()
328    }
329
330    fn set_pos(&self, pos: Param) {
331        *self.pos.lock() = pos;
332    }
333
334    fn update_pos<F: FnOnce(Param) -> Param>(&self, f: F) {
335        let mut pos = self.pos.lock();
336        *pos = f(*pos);
337    }
338}
339
// Bridge `SeekableRef<'a, T, u64>` to `std::io::{Seek, Read, Write}`. `Read`/`Write` are only
// implemented when the macro's `AsyncRead<u64>`/`AsyncWrite<u64>` + `IoBackend` bounds hold.
#[cfg(feature = "std")]
stdio_impl!(<'a, T> SeekableRef<'a, T, u64> @);
342
/// Wraps an I/O handle and pretends to be seekable without tracking any position.
///
/// Position updates are discarded, and the reported position is always
/// `!Param::default()` (`u64::MAX` for `u64`).
#[derive(SyncIoWrite, SyncIoRead)]
pub struct FakeSeek<T> {
    /// The wrapped I/O object.
    handle: T,
}
347
348impl<T> From<T> for FakeSeek<T> {
349    fn from(handle: T) -> Self {
350        Self { handle }
351    }
352}
353
354impl<T: PacketIo<Perms, Param>, Perms: PacketPerms, Param> PacketIo<Perms, Param> for FakeSeek<T> {
355    fn send_io(&self, param: Param, view: BoundPacketView<Perms>) {
356        self.handle.send_io(param, view)
357    }
358}
359
360impl<T, Param: Default + core::ops::Not<Output = Param>> StreamPos<Param> for FakeSeek<T> {
361    fn get_pos(&self) -> Param {
362        !Param::default()
363    }
364
365    fn set_pos(&self, _: Param) {}
366
367    fn update_pos<F: FnOnce(Param) -> Param>(&self, _: F) {}
368}
369
// Bridge `FakeSeek<T>` to `std::io::{Seek, Read, Write}`. Seeking is effectively inert:
// `set_pos` is ignored and the reported position is always `u64::MAX`.
#[cfg(feature = "std")]
stdio_impl!(<T> FakeSeek<T> @);