1use crate::{Negotiated, NegotiationError, Version};
24use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine};
25
26use futures::{future::Either, prelude::*};
27use std::{convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}};
28
29pub fn dialer_select_proto<R, I>(
49 inner: R,
50 protocols: I,
51 version: Version
52) -> DialerSelectFuture<R, I::IntoIter>
53where
54 R: AsyncRead + AsyncWrite,
55 I: IntoIterator,
56 I::Item: AsRef<[u8]>
57{
58 let iter = protocols.into_iter();
59 if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
61 Either::Left(dialer_select_proto_serial(inner, iter, version))
62 } else {
63 Either::Right(dialer_select_proto_parallel(inner, iter, version))
64 }
65}
66
67pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;
72
73pub(crate) fn dialer_select_proto_serial<R, I>(
80 inner: R,
81 protocols: I,
82 version: Version
83) -> DialerSelectSeq<R, I::IntoIter>
84where
85 R: AsyncRead + AsyncWrite,
86 I: IntoIterator,
87 I::Item: AsRef<[u8]>
88{
89 let protocols = protocols.into_iter().peekable();
90 DialerSelectSeq {
91 version,
92 protocols,
93 state: SeqState::SendHeader {
94 io: MessageIO::new(inner),
95 }
96 }
97}
98
99pub(crate) fn dialer_select_proto_parallel<R, I>(
109 inner: R,
110 protocols: I,
111 version: Version
112) -> DialerSelectPar<R, I::IntoIter>
113where
114 R: AsyncRead + AsyncWrite,
115 I: IntoIterator,
116 I::Item: AsRef<[u8]>
117{
118 let protocols = protocols.into_iter();
119 DialerSelectPar {
120 version,
121 protocols,
122 state: ParState::SendHeader {
123 io: MessageIO::new(inner)
124 }
125 }
126}
127
128#[pin_project::pin_project]
131pub struct DialerSelectSeq<R, I>
132where
133 R: AsyncRead + AsyncWrite,
134 I: Iterator,
135 I::Item: AsRef<[u8]>
136{
137 protocols: iter::Peekable<I>,
139 state: SeqState<R, I::Item>,
140 version: Version,
141}
142
143enum SeqState<R, N>
144where
145 R: AsyncRead + AsyncWrite,
146 N: AsRef<[u8]>
147{
148 SendHeader { io: MessageIO<R>, },
149 SendProtocol { io: MessageIO<R>, protocol: N },
150 FlushProtocol { io: MessageIO<R>, protocol: N },
151 AwaitProtocol { io: MessageIO<R>, protocol: N },
152 Done
153}
154
/// Serial negotiation: propose candidates one at a time until the listener
/// confirms one, or the candidates run out.
impl<R, I> Future for DialerSelectSeq<R, I>
where
    R: AsyncRead + AsyncWrite + Unpin,
    I: Iterator,
    I::Item: AsRef<[u8]>
{
    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;

    /// Drives the serial negotiation state machine.
    ///
    /// # Errors
    /// - `NegotiationError::Failed` if there are no candidates, if the
    ///   listener rejects every candidate, or if the stream ends while
    ///   waiting for an answer.
    /// - An I/O or protocol error propagated from the underlying `MessageIO`.
    /// - `ProtocolError::InvalidMessage` for an unexpected message.
    ///
    /// # Panics
    /// Panics if polled again after it has returned `Poll::Ready`. This
    /// includes an early error return via `?`, which leaves the state `Done`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        loop {
            // Move the state out, leaving `Done`. Every branch that does not
            // finish the future must write a state back before it loops or
            // returns `Pending`.
            match mem::replace(this.state, SeqState::Done) {
                SeqState::SendHeader { mut io } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {},
                        Poll::Pending => {
                            *this.state = SeqState::SendHeader { io };
                            return Poll::Pending
                        },
                    }

                    let h = HeaderLine::from(*this.version);
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) {
                        return Poll::Ready(Err(From::from(err)));
                    }

                    // An empty candidate list fails the negotiation right here.
                    let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;

                    // No flush happens between the header and the first
                    // proposal, so the two may be written out together.
                    *this.state = SeqState::SendProtocol { io, protocol };
                }

                SeqState::SendProtocol { mut io, protocol } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {},
                        Poll::Pending => {
                            *this.state = SeqState::SendProtocol { io, protocol };
                            return Poll::Pending
                        },
                    }

                    let p = Protocol::try_from(protocol.as_ref())?;
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
                        return Poll::Ready(Err(From::from(err)));
                    }
                    log::debug!("Dialer: Proposed protocol: {}", p);

                    if this.protocols.peek().is_some() {
                        // More candidates remain: the listener's answer is
                        // needed to decide whether to try the next one.
                        *this.state = SeqState::FlushProtocol { io, protocol }
                    } else {
                        match this.version {
                            Version::V1 => *this.state = SeqState::FlushProtocol { io, protocol },
                            Version::V1Lazy => {
                                // Last candidate under V1Lazy: finish without
                                // flushing or waiting for the answer. The
                                // returned stream is told to expect the header
                                // and the confirmation of `p`. Presumably
                                // `Negotiated` checks them on first use —
                                // NOTE(review): confirm in `Negotiated::expecting`.
                                log::debug!("Dialer: Expecting proposed protocol: {}", p);
                                let hl = HeaderLine::from(Version::V1Lazy);
                                let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
                                return Poll::Ready(Ok((protocol, io)))
                            }
                        }
                    }
                }

                SeqState::FlushProtocol { mut io, protocol } => {
                    match Pin::new(&mut io).poll_flush(cx)? {
                        Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol },
                        Poll::Pending => {
                            *this.state = SeqState::FlushProtocol { io, protocol };
                            return Poll::Pending
                        },
                    }
                }

                SeqState::AwaitProtocol { mut io, protocol } => {
                    let msg = match Pin::new(&mut io).poll_next(cx)? {
                        Poll::Ready(Some(msg)) => msg,
                        Poll::Pending => {
                            *this.state = SeqState::AwaitProtocol { io, protocol };
                            return Poll::Pending
                        }
                        // End of stream is reported as a plain negotiation
                        // failure, not as a protocol error.
                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
                    };

                    match msg {
                        // The listener's header for our version: skip it and
                        // keep waiting for the answer to the proposal.
                        Message::Header(v) if v == HeaderLine::from(*this.version) => {
                            *this.state = SeqState::AwaitProtocol { io, protocol };
                        }
                        // The listener echoed our proposal, which confirms it.
                        Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
                            log::debug!("Dialer: Received confirmation for protocol: {}", p);
                            let io = Negotiated::completed(io.into_inner());
                            return Poll::Ready(Ok((protocol, io)));
                        }
                        // Rejected: move on to the next candidate, or fail if
                        // none are left.
                        Message::NotAvailable => {
                            log::debug!("Dialer: Received rejection of protocol: {}",
                                String::from_utf8_lossy(protocol.as_ref()));
                            let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
                            *this.state = SeqState::SendProtocol { io, protocol }
                        }
                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
                    }
                }

                SeqState::Done => panic!("SeqState::poll called after completion")
            }
        }
    }
}
272
273#[pin_project::pin_project]
277pub struct DialerSelectPar<R, I>
278where
279 R: AsyncRead + AsyncWrite,
280 I: Iterator,
281 I::Item: AsRef<[u8]>
282{
283 protocols: I,
284 state: ParState<R, I::Item>,
285 version: Version,
286}
287
288enum ParState<R, N>
289where
290 R: AsyncRead + AsyncWrite,
291 N: AsRef<[u8]>
292{
293 SendHeader { io: MessageIO<R> },
294 SendProtocolsRequest { io: MessageIO<R> },
295 Flush { io: MessageIO<R> },
296 RecvProtocols { io: MessageIO<R> },
297 SendProtocol { io: MessageIO<R>, protocol: N },
298 Done
299}
300
/// Parallel negotiation: request the listener's supported protocols, then
/// propose the first of our candidates that is in that list.
impl<R, I> Future for DialerSelectPar<R, I>
where
    R: AsyncRead + AsyncWrite + Unpin,
    I: Iterator,
    I::Item: AsRef<[u8]>
{
    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;

    /// Drives the parallel negotiation state machine.
    ///
    /// # Errors
    /// - `NegotiationError::Failed` if the stream ends before the list
    ///   arrives, or if none of our candidates is in the listener's list.
    /// - An I/O or protocol error propagated from the underlying `MessageIO`.
    /// - `ProtocolError::InvalidMessage` for an unexpected message.
    ///
    /// # Panics
    /// Panics if polled again after it has returned `Poll::Ready`. This
    /// includes an early error return via `?`, which leaves the state `Done`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        loop {
            // Move the state out, leaving `Done`. Every branch that does not
            // finish the future must write a state back before it loops or
            // returns `Pending`.
            match mem::replace(this.state, ParState::Done) {
                ParState::SendHeader { mut io } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {},
                        Poll::Pending => {
                            *this.state = ParState::SendHeader { io };
                            return Poll::Pending
                        },
                    }

                    let msg = Message::Header(HeaderLine::from(*this.version));
                    if let Err(err) = Pin::new(&mut io).start_send(msg) {
                        return Poll::Ready(Err(From::from(err)));
                    }

                    *this.state = ParState::SendProtocolsRequest { io };
                }

                ParState::SendProtocolsRequest { mut io } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {},
                        Poll::Pending => {
                            *this.state = ParState::SendProtocolsRequest { io };
                            return Poll::Pending
                        },
                    }

                    if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) {
                        return Poll::Ready(Err(From::from(err)));
                    }

                    log::debug!("Dialer: Requested supported protocols.");
                    *this.state = ParState::Flush { io }
                }

                // The header and the list request are flushed together.
                ParState::Flush { mut io } => {
                    match Pin::new(&mut io).poll_flush(cx)? {
                        Poll::Ready(()) => *this.state = ParState::RecvProtocols { io },
                        Poll::Pending => {
                            *this.state = ParState::Flush { io };
                            return Poll::Pending
                        },
                    }
                }

                ParState::RecvProtocols { mut io } => {
                    let msg = match Pin::new(&mut io).poll_next(cx)? {
                        Poll::Ready(Some(msg)) => msg,
                        Poll::Pending => {
                            *this.state = ParState::RecvProtocols { io };
                            return Poll::Pending
                        }
                        // End of stream is reported as a plain negotiation
                        // failure, not as a protocol error.
                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
                    };

                    match &msg {
                        // The listener's header for our version: skip it and
                        // keep waiting for the protocol list.
                        Message::Header(h) if h == &HeaderLine::from(*this.version) => {
                            *this.state = ParState::RecvProtocols { io }
                        }
                        Message::Protocols(supported) => {
                            // First candidate, in our iteration order, that the
                            // listener lists. `by_ref` consumes candidates up
                            // to and including the match.
                            let protocol = this.protocols.by_ref()
                                .find(|p| supported.iter().any(|s|
                                    s.as_ref() == p.as_ref()))
                                .ok_or(NegotiationError::Failed)?;
                            log::debug!("Dialer: Found supported protocol: {}",
                                String::from_utf8_lossy(protocol.as_ref()));
                            *this.state = ParState::SendProtocol { io, protocol };
                        }
                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
                    }
                }

                ParState::SendProtocol { mut io, protocol } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {},
                        Poll::Pending => {
                            *this.state = ParState::SendProtocol { io, protocol };
                            return Poll::Pending
                        },
                    }

                    let p = Protocol::try_from(protocol.as_ref())?;
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
                        return Poll::Ready(Err(From::from(err)));
                    }

                    // The proposal is not flushed or awaited here. The returned
                    // stream is told to expect the confirmation of `p`, with no
                    // header line. NOTE(review): presumably because any header
                    // was already consumed in `RecvProtocols` — confirm against
                    // `Negotiated::expecting`.
                    log::debug!("Dialer: Expecting proposed protocol: {}", p);
                    let io = Negotiated::expecting(io.into_reader(), p, None);

                    return Poll::Ready(Ok((protocol, io)))
                }

                ParState::Done => panic!("ParState::poll called after completion")
            }
        }
    }
}