1use crate::protocol::{Protocol, MessageReader, Message, ProtocolError, HeaderLine};
22
23use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
24use pin_project::pin_project;
25use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
26
27#[pin_project]
39#[derive(Debug)]
40pub struct Negotiated<TInner> {
41 #[pin]
42 state: State<TInner>
43}
44
45#[derive(Debug)]
47pub struct NegotiatedComplete<TInner> {
48 inner: Option<Negotiated<TInner>>,
49}
50
51impl<TInner> Future for NegotiatedComplete<TInner>
52where
53 TInner: AsyncRead + AsyncWrite + Unpin,
56{
57 type Output = Result<Negotiated<TInner>, NegotiationError>;
58
59 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
60 let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
61 match Negotiated::poll(Pin::new(&mut io), cx) {
62 Poll::Pending => {
63 self.inner = Some(io);
64 Poll::Pending
65 },
66 Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
67 Poll::Ready(Err(err)) => {
68 self.inner = Some(io);
69 Poll::Ready(Err(err))
70 }
71 }
72 }
73}
74
75impl<TInner> Negotiated<TInner> {
76 pub(crate) fn completed(io: TInner) -> Self {
78 Negotiated { state: State::Completed { io } }
79 }
80
81 pub(crate) fn expecting(
84 io: MessageReader<TInner>,
85 protocol: Protocol,
86 header: Option<HeaderLine>
87 ) -> Self {
88 Negotiated { state: State::Expecting { io, protocol, header } }
89 }
90
91 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
93 where
94 TInner: AsyncRead + AsyncWrite + Unpin
95 {
96 match self.as_mut().poll_flush(cx) {
98 Poll::Ready(Ok(())) => {},
99 Poll::Pending => return Poll::Pending,
100 Poll::Ready(Err(e)) => {
101 if e.kind() != io::ErrorKind::WriteZero {
104 return Poll::Ready(Err(e.into()))
105 }
106 }
107 }
108
109 let mut this = self.project();
110
111 if let StateProj::Completed { .. } = this.state.as_mut().project() {
112 return Poll::Ready(Ok(()));
113 }
114
115 loop {
117 match mem::replace(&mut *this.state, State::Invalid) {
118 State::Expecting { mut io, header, protocol } => {
119 let msg = match Pin::new(&mut io).poll_next(cx)? {
120 Poll::Ready(Some(msg)) => msg,
121 Poll::Pending => {
122 *this.state = State::Expecting { io, header, protocol };
123 return Poll::Pending
124 },
125 Poll::Ready(None) => {
126 return Poll::Ready(Err(ProtocolError::IoError(
127 io::ErrorKind::UnexpectedEof.into()).into()));
128 }
129 };
130
131 if let Message::Header(h) = &msg {
132 if Some(h) == header.as_ref() {
133 *this.state = State::Expecting { io, protocol, header: None };
134 continue
135 }
136 }
137
138 if let Message::Protocol(p) = &msg {
139 if p.as_ref() == protocol.as_ref() {
140 log::debug!("Negotiated: Received confirmation for protocol: {}", p);
141 *this.state = State::Completed { io: io.into_inner() };
142 return Poll::Ready(Ok(()));
143 }
144 }
145
146 return Poll::Ready(Err(NegotiationError::Failed));
147 }
148
149 _ => panic!("Negotiated: Invalid state")
150 }
151 }
152 }
153
154 pub fn complete(self) -> NegotiatedComplete<TInner> {
157 NegotiatedComplete { inner: Some(self) }
158 }
159}
160
161#[pin_project(project = StateProj)]
163#[derive(Debug)]
164enum State<R> {
165 Expecting {
169 #[pin]
171 io: MessageReader<R>,
172 header: Option<HeaderLine>,
175 protocol: Protocol,
177 },
178
179 Completed { #[pin] io: R },
182
183 Invalid,
186}
187
188impl<TInner> AsyncRead for Negotiated<TInner>
189where
190 TInner: AsyncRead + AsyncWrite + Unpin
191{
192 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
193 -> Poll<Result<usize, io::Error>>
194 {
195 loop {
196 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
197 return io.poll_read(cx, buf);
199 }
200
201 match self.as_mut().poll(cx) {
204 Poll::Ready(Ok(())) => {},
205 Poll::Pending => return Poll::Pending,
206 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
207 }
208 }
209 }
210
211 fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>])
221 -> Poll<Result<usize, io::Error>>
222 {
223 loop {
224 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
225 return io.poll_read_vectored(cx, bufs)
227 }
228
229 match self.as_mut().poll(cx) {
232 Poll::Ready(Ok(())) => {},
233 Poll::Pending => return Poll::Pending,
234 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
235 }
236 }
237 }
238}
239
240impl<TInner> AsyncWrite for Negotiated<TInner>
241where
242 TInner: AsyncWrite + AsyncRead + Unpin
243{
244 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
245 match self.project().state.project() {
246 StateProj::Completed { io } => io.poll_write(cx, buf),
247 StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
248 StateProj::Invalid => panic!("Negotiated: Invalid state"),
249 }
250 }
251
252 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
253 match self.project().state.project() {
254 StateProj::Completed { io } => io.poll_flush(cx),
255 StateProj::Expecting { io, .. } => io.poll_flush(cx),
256 StateProj::Invalid => panic!("Negotiated: Invalid state"),
257 }
258 }
259
260 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
261 ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
264 ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
265
266 match self.project().state.project() {
268 StateProj::Completed { io, .. } => io.poll_close(cx),
269 StateProj::Expecting { io, .. } => io.poll_close(cx),
270 StateProj::Invalid => panic!("Negotiated: Invalid state"),
271 }
272 }
273
274 fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
275 -> Poll<Result<usize, io::Error>>
276 {
277 match self.project().state.project() {
278 StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
279 StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
280 StateProj::Invalid => panic!("Negotiated: Invalid state"),
281 }
282 }
283}
284
285#[derive(Debug)]
287pub enum NegotiationError {
288 ProtocolError(ProtocolError),
290
291 Failed,
293}
294
295impl From<ProtocolError> for NegotiationError {
296 fn from(err: ProtocolError) -> NegotiationError {
297 NegotiationError::ProtocolError(err)
298 }
299}
300
301impl From<io::Error> for NegotiationError {
302 fn from(err: io::Error) -> NegotiationError {
303 ProtocolError::from(err).into()
304 }
305}
306
307impl From<NegotiationError> for io::Error {
308 fn from(err: NegotiationError) -> io::Error {
309 if let NegotiationError::ProtocolError(e) = err {
310 return e.into()
311 }
312 io::Error::new(io::ErrorKind::Other, err)
313 }
314}
315
316impl Error for NegotiationError {
317 fn source(&self) -> Option<&(dyn Error + 'static)> {
318 match self {
319 NegotiationError::ProtocolError(err) => Some(err),
320 _ => None,
321 }
322 }
323}
324
325impl fmt::Display for NegotiationError {
326 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
327 match self {
328 NegotiationError::ProtocolError(p) =>
329 fmt.write_fmt(format_args!("Protocol error: {}", p)),
330 NegotiationError::Failed =>
331 fmt.write_str("Protocol negotiation failed.")
332 }
333 }
334}