async_ssh2_lite/session_stream/
impl_tokio.rs1use core::{
2 task::{Context, Poll},
3 time::Duration,
4};
5use std::io::{Error as IoError, ErrorKind as IoErrorKind};
6
7use async_trait::async_trait;
8use futures_util::ready;
9use ssh2::{BlockDirections, Error as Ssh2Error, Session};
10use tokio::net::TcpStream;
11#[cfg(unix)]
12use tokio::net::UnixStream;
13
14use super::{AsyncSessionStream, BlockDirectionsExt as _};
15use crate::{error::Error, util::ssh2_error_is_would_block};
16
17#[async_trait]
19impl AsyncSessionStream for TcpStream {
20 async fn x_with<R>(
21 &self,
22 mut op: impl FnMut() -> Result<R, Ssh2Error> + Send,
23 sess: &Session,
24 expected_block_directions: BlockDirections,
25 sleep_dur: Option<Duration>,
26 ) -> Result<R, Error> {
27 loop {
28 match op() {
29 Ok(x) => return Ok(x),
30 Err(err) => {
31 if !ssh2_error_is_would_block(&err) {
32 return Err(err.into());
33 }
34 }
35 }
36
37 match sess.block_directions() {
38 BlockDirections::None => continue,
39 BlockDirections::Inbound => {
40 assert!(expected_block_directions.is_readable());
41
42 self.readable().await?
43 }
44 BlockDirections::Outbound => {
45 assert!(expected_block_directions.is_writable());
46
47 self.writable().await?
48 }
49 BlockDirections::Both => {
50 assert!(expected_block_directions.is_readable());
51 assert!(expected_block_directions.is_writable());
52
53 self.ready(tokio::io::Interest::READABLE | tokio::io::Interest::WRITABLE)
54 .await?;
55 }
56 }
57
58 if let Some(dur) = sleep_dur {
59 sleep_async_fn(dur).await;
60 }
61 }
62 }
63
64 fn poll_x_with<R>(
65 &self,
66 cx: &mut Context,
67 mut op: impl FnMut() -> Result<R, IoError> + Send,
68 sess: &Session,
69 expected_block_directions: BlockDirections,
70 sleep_dur: Option<Duration>,
71 ) -> Poll<Result<R, IoError>> {
72 match op() {
73 Err(err) if err.kind() == IoErrorKind::WouldBlock => {}
74 ret => return Poll::Ready(ret),
75 }
76
77 match sess.block_directions() {
78 BlockDirections::None => return Poll::Pending,
79 BlockDirections::Inbound => {
80 assert!(expected_block_directions.is_readable());
81
82 ready!(self.poll_read_ready(cx))?;
83 }
84 BlockDirections::Outbound => {
85 assert!(expected_block_directions.is_writable());
86
87 ready!(self.poll_write_ready(cx))?;
88 }
89 BlockDirections::Both => {
90 assert!(expected_block_directions.is_readable());
91 assert!(expected_block_directions.is_writable());
92
93 ready!(self.poll_write_ready(cx))?;
94 ready!(self.poll_read_ready(cx))?;
95 }
96 }
97
98 if let Some(dur) = sleep_dur {
99 let waker = cx.waker().clone();
100 tokio::spawn(async move {
101 sleep_async_fn(dur).await;
102 waker.wake();
103 });
104 } else {
105 let waker = cx.waker().clone();
106 waker.wake();
107 }
108
109 Poll::Pending
110 }
111}
112
113#[cfg(unix)]
114#[async_trait]
115impl AsyncSessionStream for UnixStream {
116 async fn x_with<R>(
117 &self,
118 mut op: impl FnMut() -> Result<R, Ssh2Error> + Send,
119 sess: &Session,
120 expected_block_directions: BlockDirections,
121 sleep_dur: Option<Duration>,
122 ) -> Result<R, Error> {
123 loop {
124 match op() {
125 Ok(x) => return Ok(x),
126 Err(err) => {
127 if !ssh2_error_is_would_block(&err) {
128 return Err(err.into());
129 }
130 }
131 }
132
133 match sess.block_directions() {
134 BlockDirections::None => continue,
135 BlockDirections::Inbound => {
136 assert!(expected_block_directions.is_readable());
137
138 self.readable().await?
139 }
140 BlockDirections::Outbound => {
141 assert!(expected_block_directions.is_writable());
142
143 self.writable().await?
144 }
145 BlockDirections::Both => {
146 assert!(expected_block_directions.is_readable());
147 assert!(expected_block_directions.is_writable());
148
149 self.ready(tokio::io::Interest::READABLE | tokio::io::Interest::WRITABLE)
150 .await?;
151 }
152 }
153
154 if let Some(dur) = sleep_dur {
155 sleep_async_fn(dur).await;
156 }
157 }
158 }
159
160 fn poll_x_with<R>(
161 &self,
162 cx: &mut Context,
163 mut op: impl FnMut() -> Result<R, IoError> + Send,
164 sess: &Session,
165 expected_block_directions: BlockDirections,
166 sleep_dur: Option<Duration>,
167 ) -> Poll<Result<R, IoError>> {
168 match op() {
169 Err(err) if err.kind() == IoErrorKind::WouldBlock => {}
170 ret => return Poll::Ready(ret),
171 }
172
173 match sess.block_directions() {
174 BlockDirections::None => return Poll::Pending,
175 BlockDirections::Inbound => {
176 assert!(expected_block_directions.is_readable());
177
178 ready!(self.poll_read_ready(cx))?;
179 }
180 BlockDirections::Outbound => {
181 assert!(expected_block_directions.is_writable());
182
183 ready!(self.poll_write_ready(cx))?;
184 }
185 BlockDirections::Both => {
186 assert!(expected_block_directions.is_readable());
187 assert!(expected_block_directions.is_writable());
188
189 ready!(self.poll_write_ready(cx))?;
190 ready!(self.poll_read_ready(cx))?;
191 }
192 }
193
194 if let Some(dur) = sleep_dur {
195 let waker = cx.waker().clone();
196 tokio::spawn(async move {
197 sleep_async_fn(dur).await;
198 waker.wake();
199 });
200 } else {
201 let waker = cx.waker().clone();
202 waker.wake();
203 }
204
205 Poll::Pending
206 }
207}
208
209async fn sleep_async_fn(dur: Duration) {
213 sleep(dur).await;
214}
215
216fn sleep(dur: Duration) -> tokio::time::Sleep {
217 tokio::time::sleep(tokio::time::Duration::from_millis(dur.as_millis() as u64))
218}