async_ssh2_lite/session_stream/
impl_async_io.rs

1use core::{
2    task::{Context, Poll},
3    time::Duration,
4};
5use std::io::{Error as IoError, ErrorKind as IoErrorKind};
6
7use async_io::{Async, Timer};
8use async_trait::async_trait;
9use futures_util::{future, pin_mut, ready};
10use ssh2::{BlockDirections, Error as Ssh2Error, Session};
11
12use super::{AsyncSessionStream, BlockDirectionsExt as _};
13use crate::{error::Error, util::ssh2_error_is_would_block};
14
15//
16#[async_trait]
17impl<S> AsyncSessionStream for Async<S>
18where
19    S: Send + Sync,
20{
21    async fn x_with<R>(
22        &self,
23        mut op: impl FnMut() -> Result<R, Ssh2Error> + Send,
24        sess: &Session,
25        expected_block_directions: BlockDirections,
26        sleep_dur: Option<Duration>,
27    ) -> Result<R, Error> {
28        loop {
29            match op() {
30                Ok(x) => return Ok(x),
31                Err(err) => {
32                    if !ssh2_error_is_would_block(&err) {
33                        return Err(err.into());
34                    }
35                }
36            }
37
38            match sess.block_directions() {
39                BlockDirections::None => continue,
40                BlockDirections::Inbound => {
41                    assert!(expected_block_directions.is_readable());
42
43                    self.readable().await?
44                }
45                BlockDirections::Outbound => {
46                    assert!(expected_block_directions.is_writable());
47
48                    self.writable().await?
49                }
50                BlockDirections::Both => {
51                    assert!(expected_block_directions.is_readable());
52                    assert!(expected_block_directions.is_writable());
53
54                    let (ret, _) = future::select(self.readable(), self.writable())
55                        .await
56                        .factor_first();
57                    ret?
58                }
59            }
60
61            if let Some(dur) = sleep_dur {
62                sleep_async_fn(dur).await;
63            }
64        }
65    }
66
67    fn poll_x_with<R>(
68        &self,
69        cx: &mut Context,
70        mut op: impl FnMut() -> Result<R, IoError> + Send,
71        sess: &Session,
72        expected_block_directions: BlockDirections,
73        sleep_dur: Option<Duration>,
74    ) -> Poll<Result<R, IoError>> {
75        match op() {
76            Err(err) if err.kind() == IoErrorKind::WouldBlock => {}
77            ret => return Poll::Ready(ret),
78        }
79
80        match sess.block_directions() {
81            BlockDirections::None => return Poll::Pending,
82            BlockDirections::Inbound => {
83                assert!(expected_block_directions.is_readable());
84
85                ready!(self.poll_readable(cx))?;
86            }
87            BlockDirections::Outbound => {
88                assert!(expected_block_directions.is_writable());
89
90                ready!(self.poll_writable(cx))?;
91            }
92            BlockDirections::Both => {
93                assert!(expected_block_directions.is_readable());
94                assert!(expected_block_directions.is_writable());
95
96                // Must first poll_writable, because session__scp_send_and_scp_recv.rs
97                ready!(self.poll_writable(cx))?;
98                ready!(self.poll_readable(cx))?;
99            }
100        }
101
102        if let Some(dur) = sleep_dur {
103            let waker = cx.waker().clone();
104            // TODO, maybe wrong
105            let timer = sleep(dur);
106            pin_mut!(timer);
107            ready!(future::Future::poll(timer, cx));
108            waker.wake();
109        } else {
110            let waker = cx.waker().clone();
111            waker.wake();
112        }
113
114        Poll::Pending
115    }
116}
117
118//
119//
120//
121async fn sleep_async_fn(dur: Duration) {
122    sleep(dur).await;
123}
124
125async fn sleep(dur: Duration) -> Timer {
126    Timer::after(dur)
127}