futures_loco_protocol/
lib.rs

1pub mod secure;
2pub mod session;
3
4pub use loco_protocol;
5
6use futures_core::Future;
7use futures_io::{AsyncRead, AsyncWrite};
8use loco_protocol::command::{
9    client::{LocoSink, LocoStream, StreamState},
10    BoxedCommand, Command, Header, Method,
11};
12use std::{
13    future::poll_fn,
14    io::{self, ErrorKind},
15    mem,
16    pin::Pin,
17    task::{ready, Context, Poll},
18};
19
20pin_project_lite::pin_project!(
21    #[derive(Debug)]
22    pub struct LocoClient<T> {
23        current_id: u32,
24
25        sink: LocoSink,
26        stream: LocoStream,
27
28        read_state: ReadState,
29
30        #[pin]
31        inner: T,
32    }
33);
34
35impl<T> LocoClient<T> {
36    pub const MAX_READ_SIZE: u64 = 16 * 1024 * 1024;
37
38    pub const fn new(inner: T) -> Self {
39        Self {
40            current_id: 0,
41
42            sink: LocoSink::new(),
43            stream: LocoStream::new(),
44
45            read_state: ReadState::Pending,
46
47            inner,
48        }
49    }
50
51    pub const fn inner(&self) -> &T {
52        &self.inner
53    }
54
55    pub fn inner_mut(&mut self) -> &mut T {
56        &mut self.inner
57    }
58
59    pub fn inner_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
60        self.project().inner
61    }
62
63    pub fn into_inner(self) -> T {
64        self.inner
65    }
66}
67
68impl<T: AsyncRead> LocoClient<T> {
69    pub async fn read(&mut self) -> io::Result<BoxedCommand>
70    where
71        T: Unpin,
72    {
73        let mut this = Pin::new(self);
74
75        poll_fn(|cx| this.as_mut().poll_read(cx)).await
76    }
77
78    pub fn poll_read(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<BoxedCommand>> {
79        let mut this = self.project();
80
81        let mut buffer = [0_u8; 1024];
82        loop {
83            match mem::replace(this.read_state, ReadState::Corrupted) {
84                ReadState::Pending => match this.stream.read() {
85                    Some(command) => {
86                        *this.read_state = ReadState::Pending;
87                        break Poll::Ready(Ok(command));
88                    }
89
90                    None => {
91                        if let StreamState::Header(header) = this.stream.state() {
92                            if header.data_size as u64 > Self::MAX_READ_SIZE {
93                                *this.read_state = ReadState::PacketTooLarge;
94                                continue;
95                            }
96                        }
97
98                        *this.read_state = ReadState::Pending;
99
100                        let read = ready!(this.inner.as_mut().poll_read(cx, &mut buffer))?;
101                        if read == 0 {
102                            *this.read_state = ReadState::Done;
103                            continue;
104                        }
105
106                        this.stream.read_buffer.extend(&buffer[..read]);
107                    }
108                },
109
110                ReadState::PacketTooLarge => {
111                    *this.read_state = ReadState::PacketTooLarge;
112
113                    break Poll::Ready(Err(io::Error::new(
114                        ErrorKind::InvalidData,
115                        "packet is too large",
116                    )));
117                }
118
119                ReadState::Done => break Poll::Ready(Err(ErrorKind::UnexpectedEof.into())),
120
121                ReadState::Corrupted => unreachable!(),
122            }
123        }
124    }
125}
126
127impl<T: AsyncWrite> LocoClient<T> {
128    pub async fn send(&mut self, method: Method, data: &[u8]) -> io::Result<u32>
129    where
130        T: Unpin,
131    {
132        let mut this = Pin::new(self);
133
134        let id = this.as_mut().write(method, data);
135
136        poll_fn(|cx| this.as_mut().poll_flush(cx)).await?;
137
138        Ok(id)
139    }
140
141    pub fn write(self: Pin<&mut Self>, method: Method, data: &[u8]) -> u32 {
142        let this = self.project();
143
144        let id = {
145            *this.current_id += 1;
146
147            *this.current_id
148        };
149
150        this.sink.send(Command {
151            header: Header {
152                id,
153                status: 0,
154                method,
155                data_type: 0,
156            },
157            data,
158        });
159
160        id
161    }
162
163    pub fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
164        let mut this = self.project();
165
166        while !this.sink.write_buffer.is_empty() {
167            let written = ready!(this.inner.as_mut().poll_write(cx, {
168                let slices = this.sink.write_buffer.as_slices();
169
170                if !slices.0.is_empty() {
171                    slices.0
172                } else {
173                    slices.1
174                }
175            }))?;
176
177            this.sink.write_buffer.drain(..written);
178        }
179
180        ready!(this.inner.poll_flush(cx))?;
181
182        Poll::Ready(Ok(()))
183    }
184}
185
186impl<T: AsyncRead + AsyncWrite + Unpin> LocoClient<T> {
187    pub async fn request(
188        &mut self,
189        method: Method,
190        data: &[u8],
191    ) -> io::Result<impl Future<Output = io::Result<BoxedCommand>> + '_> {
192        let mut this = Pin::new(self);
193
194        let id = this.as_mut().write(method, data);
195
196        poll_fn(|cx| this.as_mut().poll_flush(cx)).await?;
197
198        let read_task = async move {
199            Ok(loop {
200                let read = poll_fn(|cx| this.as_mut().poll_read(cx)).await?;
201
202                if read.header.id == id {
203                    break read;
204                }
205            })
206        };
207
208        Ok(read_task)
209    }
210}
211
/// Read-side state of a `LocoClient`, driven by its `poll_read` method.
#[derive(Debug)]
enum ReadState {
    /// Normal operation: decode buffered bytes, reading more from the transport as needed.
    Pending,
    /// A decoded header announced a body larger than `LocoClient::MAX_READ_SIZE`;
    /// `poll_read` keeps returning `InvalidData`.
    PacketTooLarge,
    /// The transport reported end of stream; `poll_read` reports `UnexpectedEof`.
    Done,
    /// Placeholder stored while `poll_read` has temporarily moved the real state out with
    /// `mem::replace`; `poll_read` treats finding it as unreachable.
    Corrupted,
}