async_ssh2_lite/session_stream/impl_tokio.rs

1use core::{
2    task::{Context, Poll},
3    time::Duration,
4};
5use std::io::{Error as IoError, ErrorKind as IoErrorKind};
6
7use async_trait::async_trait;
8use futures_util::ready;
9use ssh2::{BlockDirections, Error as Ssh2Error, Session};
10use tokio::net::TcpStream;
11#[cfg(unix)]
12use tokio::net::UnixStream;
13
14use super::{AsyncSessionStream, BlockDirectionsExt as _};
15use crate::{error::Error, util::ssh2_error_is_would_block};
16
17//
18#[async_trait]
19impl AsyncSessionStream for TcpStream {
20    async fn x_with<R>(
21        &self,
22        mut op: impl FnMut() -> Result<R, Ssh2Error> + Send,
23        sess: &Session,
24        expected_block_directions: BlockDirections,
25        sleep_dur: Option<Duration>,
26    ) -> Result<R, Error> {
27        loop {
28            match op() {
29                Ok(x) => return Ok(x),
30                Err(err) => {
31                    if !ssh2_error_is_would_block(&err) {
32                        return Err(err.into());
33                    }
34                }
35            }
36
37            match sess.block_directions() {
38                BlockDirections::None => continue,
39                BlockDirections::Inbound => {
40                    assert!(expected_block_directions.is_readable());
41
42                    self.readable().await?
43                }
44                BlockDirections::Outbound => {
45                    assert!(expected_block_directions.is_writable());
46
47                    self.writable().await?
48                }
49                BlockDirections::Both => {
50                    assert!(expected_block_directions.is_readable());
51                    assert!(expected_block_directions.is_writable());
52
53                    self.ready(tokio::io::Interest::READABLE | tokio::io::Interest::WRITABLE)
54                        .await?;
55                }
56            }
57
58            if let Some(dur) = sleep_dur {
59                sleep_async_fn(dur).await;
60            }
61        }
62    }
63
64    fn poll_x_with<R>(
65        &self,
66        cx: &mut Context,
67        mut op: impl FnMut() -> Result<R, IoError> + Send,
68        sess: &Session,
69        expected_block_directions: BlockDirections,
70        sleep_dur: Option<Duration>,
71    ) -> Poll<Result<R, IoError>> {
72        match op() {
73            Err(err) if err.kind() == IoErrorKind::WouldBlock => {}
74            ret => return Poll::Ready(ret),
75        }
76
77        match sess.block_directions() {
78            BlockDirections::None => return Poll::Pending,
79            BlockDirections::Inbound => {
80                assert!(expected_block_directions.is_readable());
81
82                ready!(self.poll_read_ready(cx))?;
83            }
84            BlockDirections::Outbound => {
85                assert!(expected_block_directions.is_writable());
86
87                ready!(self.poll_write_ready(cx))?;
88            }
89            BlockDirections::Both => {
90                assert!(expected_block_directions.is_readable());
91                assert!(expected_block_directions.is_writable());
92
93                ready!(self.poll_write_ready(cx))?;
94                ready!(self.poll_read_ready(cx))?;
95            }
96        }
97
98        if let Some(dur) = sleep_dur {
99            let waker = cx.waker().clone();
100            tokio::spawn(async move {
101                sleep_async_fn(dur).await;
102                waker.wake();
103            });
104        } else {
105            let waker = cx.waker().clone();
106            waker.wake();
107        }
108
109        Poll::Pending
110    }
111}
112
113#[cfg(unix)]
114#[async_trait]
115impl AsyncSessionStream for UnixStream {
116    async fn x_with<R>(
117        &self,
118        mut op: impl FnMut() -> Result<R, Ssh2Error> + Send,
119        sess: &Session,
120        expected_block_directions: BlockDirections,
121        sleep_dur: Option<Duration>,
122    ) -> Result<R, Error> {
123        loop {
124            match op() {
125                Ok(x) => return Ok(x),
126                Err(err) => {
127                    if !ssh2_error_is_would_block(&err) {
128                        return Err(err.into());
129                    }
130                }
131            }
132
133            match sess.block_directions() {
134                BlockDirections::None => continue,
135                BlockDirections::Inbound => {
136                    assert!(expected_block_directions.is_readable());
137
138                    self.readable().await?
139                }
140                BlockDirections::Outbound => {
141                    assert!(expected_block_directions.is_writable());
142
143                    self.writable().await?
144                }
145                BlockDirections::Both => {
146                    assert!(expected_block_directions.is_readable());
147                    assert!(expected_block_directions.is_writable());
148
149                    self.ready(tokio::io::Interest::READABLE | tokio::io::Interest::WRITABLE)
150                        .await?;
151                }
152            }
153
154            if let Some(dur) = sleep_dur {
155                sleep_async_fn(dur).await;
156            }
157        }
158    }
159
160    fn poll_x_with<R>(
161        &self,
162        cx: &mut Context,
163        mut op: impl FnMut() -> Result<R, IoError> + Send,
164        sess: &Session,
165        expected_block_directions: BlockDirections,
166        sleep_dur: Option<Duration>,
167    ) -> Poll<Result<R, IoError>> {
168        match op() {
169            Err(err) if err.kind() == IoErrorKind::WouldBlock => {}
170            ret => return Poll::Ready(ret),
171        }
172
173        match sess.block_directions() {
174            BlockDirections::None => return Poll::Pending,
175            BlockDirections::Inbound => {
176                assert!(expected_block_directions.is_readable());
177
178                ready!(self.poll_read_ready(cx))?;
179            }
180            BlockDirections::Outbound => {
181                assert!(expected_block_directions.is_writable());
182
183                ready!(self.poll_write_ready(cx))?;
184            }
185            BlockDirections::Both => {
186                assert!(expected_block_directions.is_readable());
187                assert!(expected_block_directions.is_writable());
188
189                ready!(self.poll_write_ready(cx))?;
190                ready!(self.poll_read_ready(cx))?;
191            }
192        }
193
194        if let Some(dur) = sleep_dur {
195            let waker = cx.waker().clone();
196            tokio::spawn(async move {
197                sleep_async_fn(dur).await;
198                waker.wake();
199            });
200        } else {
201            let waker = cx.waker().clone();
202            waker.wake();
203        }
204
205        Poll::Pending
206    }
207}
208
209//
210//
211//
212async fn sleep_async_fn(dur: Duration) {
213    sleep(dur).await;
214}
215
216fn sleep(dur: Duration) -> tokio::time::Sleep {
217    tokio::time::sleep(tokio::time::Duration::from_millis(dur.as_millis() as u64))
218}