futures_loco_protocol/
lib.rs1pub mod secure;
2pub mod session;
3
4pub use loco_protocol;
5
6use futures_core::Future;
7use futures_io::{AsyncRead, AsyncWrite};
8use loco_protocol::command::{
9 client::{LocoSink, LocoStream, StreamState},
10 BoxedCommand, Command, Header, Method,
11};
12use std::{
13 future::poll_fn,
14 io::{self, ErrorKind},
15 mem,
16 pin::Pin,
17 task::{ready, Context, Poll},
18};
19
20pin_project_lite::pin_project!(
21 #[derive(Debug)]
22 pub struct LocoClient<T> {
23 current_id: u32,
24
25 sink: LocoSink,
26 stream: LocoStream,
27
28 read_state: ReadState,
29
30 #[pin]
31 inner: T,
32 }
33);
34
35impl<T> LocoClient<T> {
36 pub const MAX_READ_SIZE: u64 = 16 * 1024 * 1024;
37
38 pub const fn new(inner: T) -> Self {
39 Self {
40 current_id: 0,
41
42 sink: LocoSink::new(),
43 stream: LocoStream::new(),
44
45 read_state: ReadState::Pending,
46
47 inner,
48 }
49 }
50
51 pub const fn inner(&self) -> &T {
52 &self.inner
53 }
54
55 pub fn inner_mut(&mut self) -> &mut T {
56 &mut self.inner
57 }
58
59 pub fn inner_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
60 self.project().inner
61 }
62
63 pub fn into_inner(self) -> T {
64 self.inner
65 }
66}
67
68impl<T: AsyncRead> LocoClient<T> {
69 pub async fn read(&mut self) -> io::Result<BoxedCommand>
70 where
71 T: Unpin,
72 {
73 let mut this = Pin::new(self);
74
75 poll_fn(|cx| this.as_mut().poll_read(cx)).await
76 }
77
78 pub fn poll_read(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<BoxedCommand>> {
79 let mut this = self.project();
80
81 let mut buffer = [0_u8; 1024];
82 loop {
83 match mem::replace(this.read_state, ReadState::Corrupted) {
84 ReadState::Pending => match this.stream.read() {
85 Some(command) => {
86 *this.read_state = ReadState::Pending;
87 break Poll::Ready(Ok(command));
88 }
89
90 None => {
91 if let StreamState::Header(header) = this.stream.state() {
92 if header.data_size as u64 > Self::MAX_READ_SIZE {
93 *this.read_state = ReadState::PacketTooLarge;
94 continue;
95 }
96 }
97
98 *this.read_state = ReadState::Pending;
99
100 let read = ready!(this.inner.as_mut().poll_read(cx, &mut buffer))?;
101 if read == 0 {
102 *this.read_state = ReadState::Done;
103 continue;
104 }
105
106 this.stream.read_buffer.extend(&buffer[..read]);
107 }
108 },
109
110 ReadState::PacketTooLarge => {
111 *this.read_state = ReadState::PacketTooLarge;
112
113 break Poll::Ready(Err(io::Error::new(
114 ErrorKind::InvalidData,
115 "packet is too large",
116 )));
117 }
118
119 ReadState::Done => break Poll::Ready(Err(ErrorKind::UnexpectedEof.into())),
120
121 ReadState::Corrupted => unreachable!(),
122 }
123 }
124 }
125}
126
127impl<T: AsyncWrite> LocoClient<T> {
128 pub async fn send(&mut self, method: Method, data: &[u8]) -> io::Result<u32>
129 where
130 T: Unpin,
131 {
132 let mut this = Pin::new(self);
133
134 let id = this.as_mut().write(method, data);
135
136 poll_fn(|cx| this.as_mut().poll_flush(cx)).await?;
137
138 Ok(id)
139 }
140
141 pub fn write(self: Pin<&mut Self>, method: Method, data: &[u8]) -> u32 {
142 let this = self.project();
143
144 let id = {
145 *this.current_id += 1;
146
147 *this.current_id
148 };
149
150 this.sink.send(Command {
151 header: Header {
152 id,
153 status: 0,
154 method,
155 data_type: 0,
156 },
157 data,
158 });
159
160 id
161 }
162
163 pub fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
164 let mut this = self.project();
165
166 while !this.sink.write_buffer.is_empty() {
167 let written = ready!(this.inner.as_mut().poll_write(cx, {
168 let slices = this.sink.write_buffer.as_slices();
169
170 if !slices.0.is_empty() {
171 slices.0
172 } else {
173 slices.1
174 }
175 }))?;
176
177 this.sink.write_buffer.drain(..written);
178 }
179
180 ready!(this.inner.poll_flush(cx))?;
181
182 Poll::Ready(Ok(()))
183 }
184}
185
186impl<T: AsyncRead + AsyncWrite + Unpin> LocoClient<T> {
187 pub async fn request(
188 &mut self,
189 method: Method,
190 data: &[u8],
191 ) -> io::Result<impl Future<Output = io::Result<BoxedCommand>> + '_> {
192 let mut this = Pin::new(self);
193
194 let id = this.as_mut().write(method, data);
195
196 poll_fn(|cx| this.as_mut().poll_flush(cx)).await?;
197
198 let read_task = async move {
199 Ok(loop {
200 let read = poll_fn(|cx| this.as_mut().poll_read(cx)).await?;
201
202 if read.header.id == id {
203 break read;
204 }
205 })
206 };
207
208 Ok(read_task)
209 }
210}
211
/// Read-side state of a [`LocoClient`].
#[derive(Debug)]
enum ReadState {
    /// Normal operation: decode buffered commands, reading more bytes when needed.
    Pending,
    /// The transport signalled end of stream (a read returned 0 bytes).
    Done,
    /// A packet header announced a payload above [`LocoClient::MAX_READ_SIZE`].
    PacketTooLarge,
    /// Placeholder swapped in while the current state is being handled; finding it at
    /// the start of a poll means the previous state was never put back.
    Corrupted,
}