1use crate as mfio;
4use crate::error::Result;
5use crate::io::*;
6use crate::locks::Mutex;
7use crate::std_prelude::*;
8use crate::traits::*;
9use crate::util::{PosShift, UsizeMath};
10use core::future::Future;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use mfio_derive::*;
14
/// A mutable "current position" attached to a stream-like I/O object.
///
/// Every method takes `&self`, so implementors must store the position with interior
/// mutability (the wrappers in this module keep it behind a `Mutex`).
pub trait StreamPos<Param> {
    /// Returns the current position.
    fn get_pos(&self) -> Param;

    /// Overwrites the current position with `pos`.
    fn set_pos(&self, pos: Param);

    /// Replaces the current position with `f(current)`.
    fn update_pos<F: FnOnce(Param) -> Param>(&self, f: F);

    /// Returns the position of the end of the stream, if it is known.
    ///
    /// The default implementation reports `None` (end unknown).
    fn end(&self) -> Option<Param> {
        None
    }
}
26
/// Implements `std::io::Seek::seek` semantics on top of a [`StreamPos<u64>`].
///
/// The computed position is stored with [`StreamPos::set_pos`] and returned.
///
/// # Errors
///
/// * `ErrorKind::Unsupported` for `SeekFrom::End` when `io.end()` is `None`.
/// * `ErrorKind::InvalidInput` when the target position would be negative or would not fit
///   in a `u64`. In that case the stored position is left unchanged.
#[cfg(feature = "std")]
pub fn std_seek(
    io: &(impl StreamPos<u64> + ?Sized),
    pos: std::io::SeekFrom,
) -> std::io::Result<u64> {
    // Applies a signed offset to `base`. Returns `None` instead of panicking/wrapping when
    // the result would be negative or exceed `u64::MAX` (`unsigned_abs` also handles
    // `i64::MIN`, whose negation overflows `i64`).
    fn offset(base: u64, delta: i64) -> Option<u64> {
        if delta < 0 {
            base.checked_sub(delta.unsigned_abs())
        } else {
            base.checked_add(delta as u64)
        }
    }

    let new_pos = match pos {
        std::io::SeekFrom::Start(val) => val,
        std::io::SeekFrom::End(val) => {
            let end = io.end().ok_or(std::io::ErrorKind::Unsupported)?;
            offset(end, val).ok_or(std::io::ErrorKind::InvalidInput)?
        }
        std::io::SeekFrom::Current(val) => {
            offset(io.get_pos(), val).ok_or(std::io::ErrorKind::InvalidInput)?
        }
    };

    io.set_pos(new_pos);
    Ok(new_pos)
}
64
65impl<Param: Copy + UsizeMath, Io: StreamPos<Param>> PosShift<Io> for Param {
66 fn add_pos(&mut self, out: usize, io: &Io) {
67 self.add_assign(out);
68 io.set_pos(*self);
69 }
70
71 fn add_io_pos(io: &Io, out: usize) {
72 io.update_pos(|pos| pos.add(out))
73 }
74}
75
76pub trait AsyncRead<Param: 'static>: IoRead<Param> {
77 fn read<'a>(&'a self, buf: &'a mut [u8]) -> AsyncIoFut<'a, Self, Write, Param, &'a mut [u8]>;
78 fn read_to_end<'a>(&'a self, buf: &'a mut Vec<u8>) -> StdReadToEndFut<'a, Self, Param>;
79}
80
81impl<T: IoRead<Param> + StreamPos<Param>, Param: 'static + Copy> AsyncRead<Param> for T {
82 fn read<'a>(&'a self, buf: &'a mut [u8]) -> AsyncIoFut<'a, Self, Write, Param, &'a mut [u8]> {
83 let len = buf.len();
84 let (pkt, sync) = <&'a mut [u8] as IntoPacket<Write>>::into_packet(buf);
85 AsyncIoFut {
86 io: self,
87 len,
88 fut: self.io(self.get_pos(), pkt),
89 sync: Some(sync),
90 }
91 }
92
93 fn read_to_end<'a>(&'a self, buf: &'a mut Vec<u8>) -> StdReadToEndFut<'a, Self, Param> {
94 StdReadToEndFut {
95 io: self,
96 fut: <Self as IoRead<Param>>::read_to_end(self, self.get_pos(), buf),
97 }
98 }
99}
100
101impl<T: IoRead<NoPos>> AsyncRead<NoPos> for T {
102 fn read<'a>(&'a self, buf: &'a mut [u8]) -> AsyncIoFut<'a, Self, Write, NoPos, &'a mut [u8]> {
103 let len = buf.len();
104 let (pkt, sync) = <&'a mut [u8] as IntoPacket<Write>>::into_packet(buf);
105 AsyncIoFut {
106 io: self,
107 len,
108 fut: self.io(NoPos::new(), pkt),
109 sync: Some(sync),
110 }
111 }
112
113 fn read_to_end<'a>(&'a self, buf: &'a mut Vec<u8>) -> StdReadToEndFut<'a, Self, NoPos> {
114 StdReadToEndFut {
115 io: self,
116 fut: <Self as IoRead<NoPos>>::read_to_end(self, NoPos::new(), buf),
117 }
118 }
119}
120
121pub trait AsyncWrite<Param>: IoWrite<Param> {
122 fn write<'a>(&'a self, buf: &'a [u8]) -> AsyncIoFut<'a, Self, Read, Param, &'a [u8]>;
123}
124
125impl<T: IoWrite<Param> + StreamPos<Param>, Param: Copy> AsyncWrite<Param> for T {
126 fn write<'a>(&'a self, buf: &'a [u8]) -> AsyncIoFut<'a, Self, Read, Param, &'a [u8]> {
127 let len = buf.len();
128 let (pkt, sync) = buf.into_packet();
129 AsyncIoFut {
130 io: self,
131 len,
132 fut: self.io(self.get_pos(), pkt),
133 sync: Some(sync),
134 }
135 }
136}
137
138impl<T: IoWrite<NoPos>> AsyncWrite<NoPos> for T {
139 fn write<'a>(&'a self, buf: &'a [u8]) -> AsyncIoFut<'a, Self, Read, NoPos, &'a [u8]> {
140 let len = buf.len();
141 let (pkt, sync) = buf.into_packet();
142 AsyncIoFut {
143 io: self,
144 len,
145 fut: self.io(NoPos::new(), pkt),
146 sync: Some(sync),
147 }
148 }
149}
150
/// Future returned by [`AsyncRead::read`] and [`AsyncWrite::write`].
///
/// Polls the inner packet future. When it completes, it hands `sync` back through
/// `Obj::sync_back` and advances `io`'s position by the number of bytes processed (capped at
/// `len`). It then resolves to that count.
///
/// NOTE: field declaration order also fixes drop order (`fut` is dropped before `sync`);
/// keep it when editing.
pub struct AsyncIoFut<'a, Io: ?Sized, Perms: PacketPerms, Param: 'a, Obj: IntoPacket<'a, Perms>> {
    /// I/O object whose stream position is advanced on completion.
    io: &'a Io,
    /// In-flight packet operation; structurally pinned by the `Future` impl.
    fut: IoFut<'a, Io, Perms, Param, Obj::Target>,
    /// Handle passed to `Obj::sync_back` once `fut` resolves; `None` afterwards.
    pub(crate) sync: Option<Obj::SyncHandle>,
    /// Length of the caller's buffer; upper bound on the reported progress.
    len: usize,
}
157
158impl<
159 'a,
160 Io: PacketIo<Perms, Param>,
161 Perms: PacketPerms,
162 Param: PosShift<Io>,
163 Obj: IntoPacket<'a, Perms>,
164 > Future for AsyncIoFut<'a, Io, Perms, Param, Obj>
165{
166 type Output = Result<usize>;
167
168 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
169 let this = unsafe { self.get_unchecked_mut() };
170
171 let fut = unsafe { Pin::new_unchecked(&mut this.fut) };
172
173 fut.poll(cx).map(|pkt| {
174 let hdr = <<Obj as IntoPacket<'a, Perms>>::Target as OpaqueStore>::stack_hdr(&pkt);
175 Obj::sync_back(hdr, this.sync.take().unwrap());
177 let progressed = core::cmp::min(hdr.error_clamp() as usize, this.len);
178 Param::add_io_pos(this.io, progressed);
179 Ok(progressed)
181 })
182 }
183}
184
185pub struct StdReadToEndFut<'a, Io: PacketIo<Write, Param>, Param> {
186 io: &'a Io,
187 fut: ReadToEndFut<'a, Io, Param>,
188}
189
190impl<'a, Io: PacketIo<Write, Param>, Param: PosShift<Io>> Future
191 for StdReadToEndFut<'a, Io, Param>
192{
193 type Output = Result<()>;
194
195 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
196 let this = unsafe { self.get_unchecked_mut() };
197
198 match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) {
199 Poll::Ready(Ok(r)) => {
200 Param::add_io_pos(this.io, r);
201 Poll::Ready(Ok(()))
202 }
203 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
204 Poll::Pending => Poll::Pending,
205 }
206 }
207}
208
/// Implements blocking `std::io::{Seek, Read, Write}` for an mfio stream type.
///
/// Invocation forms:
///
/// - `stdio_impl!(<'a, T> Type<'a, T, u64> @ extra_bounds)`: the leading `<...>` lists the
///   impl generics, `Type<...>` is the self type, and everything after `@` is appended to
///   each generated `where` clause.
/// - `stdio_impl!(Type @ extra_bounds)`: shorthand for a non-generic type.
///
/// `Seek` is built on `StreamPos<u64>` (via `stdeq::std_seek`). `Read`/`Write` block on
/// `AsyncRead<u64>`/`AsyncWrite<u64>` through `IoBackendExt::block_on` and map any error
/// to `ErrorKind::Other`. `flush` is a no-op.
///
/// Every crate item is referenced through `$crate`, so the invoking module does not need to
/// import `StreamPos`, `AsyncRead` or `AsyncWrite`.
#[macro_export]
macro_rules! stdio_impl {
    (<$($lt2:lifetime,)* $($ty2:ident),*> $t:ident <$($lt:lifetime,)* $($ty:ident),*> @ $($tt:tt)*) => {
        impl<$($lt2,)* $($ty2),*> std::io::Seek for $t<$($lt,)* $($ty),*>
        where
            $t<$($lt,)* $($ty),*>: $crate::stdeq::StreamPos<u64>,
            $($tt)*
        {
            fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
                $crate::stdeq::std_seek(self, pos)
            }

            fn stream_position(&mut self) -> std::io::Result<u64> {
                Ok(<Self as $crate::stdeq::StreamPos<u64>>::get_pos(self))
            }

            fn rewind(&mut self) -> std::io::Result<()> {
                // Explicit `u64`: a bare `self.set_pos(0)` could fall back to `i32` on types
                // that implement `StreamPos` for many parameter types.
                <Self as $crate::stdeq::StreamPos<u64>>::set_pos(self, 0);
                Ok(())
            }
        }

        impl<$($lt2,)* $($ty2),*> std::io::Read for $t<$($lt,)* $($ty),*>
        where
            $t<$($lt,)* $($ty),*>: $crate::stdeq::AsyncRead<u64> + $crate::backend::IoBackend,
            $($tt)*
        {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                use $crate::backend::IoBackendExt;
                self.block_on(<Self as $crate::stdeq::AsyncRead<u64>>::read(self, buf))
                    .map_err(|_| std::io::ErrorKind::Other.into())
            }

            fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
                use $crate::backend::IoBackendExt;
                let len = buf.len();
                self.block_on(<Self as $crate::stdeq::AsyncRead<u64>>::read_to_end(self, buf))
                    .map_err(|_| std::io::ErrorKind::Other)?;
                Ok(buf.len() - len)
            }
        }

        impl<$($lt2,)* $($ty2),*> std::io::Write for $t<$($lt,)* $($ty),*>
        where
            $t<$($lt,)* $($ty),*>: $crate::stdeq::AsyncWrite<u64> + $crate::backend::IoBackend,
            $($tt)*
        {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                use $crate::backend::IoBackendExt;
                // Fully qualified so callers need not have `AsyncWrite` in scope.
                self.block_on(<Self as $crate::stdeq::AsyncWrite<u64>>::write(self, buf))
                    .map_err(|_| std::io::ErrorKind::Other.into())
            }

            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }
    };
    ($t:ident @ $($tt:tt)*) => {
        // The main arm requires a leading generics list, so pass an empty one explicitly.
        $crate::stdio_impl!(<> $t<> @ $($tt)*);
    };
}
260
/// Owns an I/O handle and adds a stream position to it.
///
/// `SyncIoWrite`/`SyncIoRead` are derive macros from `mfio_derive`; what they generate is not
/// visible here (presumably blocking read/write helpers — confirm).
#[derive(SyncIoWrite, SyncIoRead)]
pub struct Seekable<T, Param> {
    /// Current stream position, behind a `Mutex` so it can be updated through `&self`.
    pos: Mutex<Param>,
    /// Wrapped I/O object; packet I/O is forwarded to it.
    handle: T,
}
266
267impl<T, Param: Default> From<T> for Seekable<T, Param> {
268 fn from(handle: T) -> Self {
269 Self {
270 pos: Default::default(),
271 handle,
272 }
273 }
274}
275
276impl<T: PacketIo<Perms, Param>, Perms: PacketPerms, Param> PacketIo<Perms, Param>
277 for Seekable<T, Param>
278{
279 fn send_io(&self, param: Param, view: BoundPacketView<Perms>) {
280 self.handle.send_io(param, view)
281 }
282}
283
284impl<T, Param: Copy> StreamPos<Param> for Seekable<T, Param> {
285 fn get_pos(&self) -> Param {
286 *self.pos.lock()
287 }
288
289 fn set_pos(&self, pos: Param) {
290 *self.pos.lock() = pos;
291 }
292
293 fn update_pos<F: FnOnce(Param) -> Param>(&self, f: F) {
294 let mut pos = self.pos.lock();
295 *pos = f(*pos);
296 }
297}
298
// Blocking `std::io` adapters for `Seekable` handles that track a `u64` position.
// `Seek` works through the `StreamPos` impl. `Read`/`Write` additionally require the
// `AsyncRead<u64>`/`AsyncWrite<u64>` + `IoBackend` bounds stated inside `stdio_impl!`.
#[cfg(feature = "std")]
stdio_impl!(<T> Seekable<T, u64> @);
301
/// Like `Seekable`, but borrows the I/O handle instead of owning it.
///
/// `SyncIoWrite`/`SyncIoRead` are derive macros from `mfio_derive`; what they generate is not
/// visible here (presumably blocking read/write helpers — confirm).
#[derive(SyncIoWrite, SyncIoRead)]
pub struct SeekableRef<'a, T, Param> {
    /// Current stream position, behind a `Mutex` so it can be updated through `&self`.
    pos: Mutex<Param>,
    /// Borrowed I/O object; packet I/O is forwarded to it.
    handle: &'a T,
}
307
308impl<'a, T, Param: Default> From<&'a T> for SeekableRef<'a, T, Param> {
309 fn from(handle: &'a T) -> Self {
310 Self {
311 pos: Default::default(),
312 handle,
313 }
314 }
315}
316
317impl<T: PacketIo<Perms, Param>, Perms: PacketPerms, Param: core::fmt::Debug> PacketIo<Perms, Param>
318 for SeekableRef<'_, T, Param>
319{
320 fn send_io(&self, param: Param, view: BoundPacketView<Perms>) {
321 self.handle.send_io(param, view)
322 }
323}
324
325impl<T, Param: Copy> StreamPos<Param> for SeekableRef<'_, T, Param> {
326 fn get_pos(&self) -> Param {
327 *self.pos.lock()
328 }
329
330 fn set_pos(&self, pos: Param) {
331 *self.pos.lock() = pos;
332 }
333
334 fn update_pos<F: FnOnce(Param) -> Param>(&self, f: F) {
335 let mut pos = self.pos.lock();
336 *pos = f(*pos);
337 }
338}
339
// Blocking `std::io` adapters for `SeekableRef` handles that track a `u64` position.
// `Read`/`Write` additionally require the `AsyncRead<u64>`/`AsyncWrite<u64>` + `IoBackend`
// bounds stated inside `stdio_impl!`.
#[cfg(feature = "std")]
stdio_impl!(<'a, T> SeekableRef<'a, T, u64> @);
342
/// Gives a handle a dummy stream position that never moves.
///
/// Packet I/O is forwarded to `handle` unchanged. The `StreamPos` impl for this type
/// always reports `!Param::default()` and ignores writes to the position.
#[derive(SyncIoWrite, SyncIoRead)]
pub struct FakeSeek<T> {
    /// Wrapped I/O object.
    handle: T,
}
347
348impl<T> From<T> for FakeSeek<T> {
349 fn from(handle: T) -> Self {
350 Self { handle }
351 }
352}
353
354impl<T: PacketIo<Perms, Param>, Perms: PacketPerms, Param> PacketIo<Perms, Param> for FakeSeek<T> {
355 fn send_io(&self, param: Param, view: BoundPacketView<Perms>) {
356 self.handle.send_io(param, view)
357 }
358}
359
360impl<T, Param: Default + core::ops::Not<Output = Param>> StreamPos<Param> for FakeSeek<T> {
361 fn get_pos(&self) -> Param {
362 !Param::default()
363 }
364
365 fn set_pos(&self, _: Param) {}
366
367 fn update_pos<F: FnOnce(Param) -> Param>(&self, _: F) {}
368}
369
// Blocking `std::io` adapters for `FakeSeek`. Because its position never moves,
// `stream_position` reports `u64::MAX` and seeking does not change the reported position.
// `Read`/`Write` additionally require the `AsyncRead<u64>`/`AsyncWrite<u64>` + `IoBackend`
// bounds stated inside `stdio_impl!`.
#[cfg(feature = "std")]
stdio_impl!(<T> FakeSeek<T> @);