async_ssh2_lite/
channel.rs

1use std::sync::Arc;
2
3use ssh2::{Channel, ExitSignal, ExtendedData, PtyModes, ReadWindow, Session, Stream, WriteWindow};
4
5use crate::{error::Error, session_stream::AsyncSessionStream};
6
7//
8pub struct AsyncChannel<S> {
9    inner: Channel,
10    sess: Session,
11    stream: Arc<S>,
12}
13
14impl<S> AsyncChannel<S> {
15    pub(crate) fn from_parts(inner: Channel, sess: Session, stream: Arc<S>) -> Self {
16        Self {
17            inner,
18            sess,
19            stream,
20        }
21    }
22}
23
24impl<S> AsyncChannel<S>
25where
26    S: AsyncSessionStream + Send + Sync + 'static,
27{
28    pub async fn setenv(&mut self, var: &str, val: &str) -> Result<(), Error> {
29        self.stream
30            .rw_with(|| self.inner.setenv(var, val), &self.sess)
31            .await
32    }
33
34    pub async fn request_pty(
35        &mut self,
36        term: &str,
37        mode: Option<PtyModes>,
38        dim: Option<(u32, u32, u32, u32)>,
39    ) -> Result<(), Error> {
40        self.stream
41            .rw_with(
42                || self.inner.request_pty(term, mode.clone(), dim),
43                &self.sess,
44            )
45            .await
46    }
47
48    pub async fn request_pty_size(
49        &mut self,
50        width: u32,
51        height: u32,
52        width_px: Option<u32>,
53        height_px: Option<u32>,
54    ) -> Result<(), Error> {
55        self.stream
56            .rw_with(
57                || {
58                    self.inner
59                        .request_pty_size(width, height, width_px, height_px)
60                },
61                &self.sess,
62            )
63            .await
64    }
65
66    pub async fn request_auth_agent_forwarding(&mut self) -> Result<(), Error> {
67        self.stream
68            .rw_with(|| self.inner.request_auth_agent_forwarding(), &self.sess)
69            .await
70    }
71
72    pub async fn exec(&mut self, command: &str) -> Result<(), Error> {
73        self.stream
74            .rw_with(|| self.inner.exec(command), &self.sess)
75            .await
76    }
77
78    pub async fn shell(&mut self) -> Result<(), Error> {
79        self.stream.rw_with(|| self.inner.shell(), &self.sess).await
80    }
81
82    pub async fn subsystem(&mut self, system: &str) -> Result<(), Error> {
83        self.stream
84            .rw_with(|| self.inner.subsystem(system), &self.sess)
85            .await
86    }
87
88    pub async fn process_startup(
89        &mut self,
90        request: &str,
91        message: Option<&str>,
92    ) -> Result<(), Error> {
93        self.stream
94            .rw_with(|| self.inner.process_startup(request, message), &self.sess)
95            .await
96    }
97
98    pub fn stderr(&self) -> AsyncStream<S> {
99        AsyncStream::from_parts(self.inner.stderr(), self.sess.clone(), self.stream.clone())
100    }
101
102    pub fn stream(&self, stream_id: i32) -> AsyncStream<S> {
103        AsyncStream::from_parts(
104            self.inner.stream(stream_id),
105            self.sess.clone(),
106            self.stream.clone(),
107        )
108    }
109
110    pub async fn handle_extended_data(&mut self, mode: ExtendedData) -> Result<(), Error> {
111        self.stream
112            .rw_with(|| self.inner.handle_extended_data(mode), &self.sess)
113            .await
114    }
115
116    pub fn exit_status(&self) -> Result<i32, Error> {
117        self.inner.exit_status().map_err(Into::into)
118    }
119
120    pub async fn exit_signal(&self) -> Result<ExitSignal, Error> {
121        self.inner.exit_signal().map_err(Into::into)
122    }
123
124    pub fn read_window(&self) -> ReadWindow {
125        self.inner.read_window()
126    }
127    pub fn write_window(&self) -> WriteWindow {
128        self.inner.write_window()
129    }
130
131    pub async fn adjust_receive_window(&mut self, adjust: u64, force: bool) -> Result<u64, Error> {
132        self.stream
133            .rw_with(
134                || self.inner.adjust_receive_window(adjust, force),
135                &self.sess,
136            )
137            .await
138    }
139
140    pub fn eof(&self) -> bool {
141        self.inner.eof()
142    }
143
144    pub async fn send_eof(&mut self) -> Result<(), Error> {
145        self.stream
146            .rw_with(|| self.inner.send_eof(), &self.sess)
147            .await
148    }
149
150    pub async fn wait_eof(&mut self) -> Result<(), Error> {
151        self.stream
152            .rw_with(|| self.inner.wait_eof(), &self.sess)
153            .await
154    }
155
156    pub async fn close(&mut self) -> Result<(), Error> {
157        self.stream.rw_with(|| self.inner.close(), &self.sess).await
158    }
159
160    pub async fn wait_close(&mut self) -> Result<(), Error> {
161        self.stream
162            .rw_with(|| self.inner.wait_close(), &self.sess)
163            .await
164    }
165}
166
167//
168pub struct AsyncStream<S> {
169    inner: Stream,
170    sess: Session,
171    stream: Arc<S>,
172}
173
174impl<S> AsyncStream<S> {
175    pub(crate) fn from_parts(inner: Stream, sess: Session, stream: Arc<S>) -> Self {
176        Self {
177            inner,
178            sess,
179            stream,
180        }
181    }
182}
183
184mod impl_futures_util {
185    use core::{
186        pin::Pin,
187        task::{Context, Poll},
188    };
189    use std::io::{Error as IoError, Read as _, Write as _};
190
191    use futures_util::io::{AsyncRead, AsyncWrite};
192
193    use super::{AsyncChannel, AsyncStream};
194    use crate::session_stream::AsyncSessionStream;
195
196    //
197    impl<S> AsyncRead for AsyncChannel<S>
198    where
199        S: AsyncSessionStream + Send + Sync + 'static,
200    {
201        fn poll_read(
202            self: Pin<&mut Self>,
203            cx: &mut Context<'_>,
204            buf: &mut [u8],
205        ) -> Poll<Result<usize, IoError>> {
206            Pin::new(&mut self.stream(0)).poll_read(cx, buf)
207        }
208    }
209
210    impl<S> AsyncWrite for AsyncChannel<S>
211    where
212        S: AsyncSessionStream + Send + Sync + 'static,
213    {
214        fn poll_write(
215            self: Pin<&mut Self>,
216            cx: &mut Context,
217            buf: &[u8],
218        ) -> Poll<Result<usize, IoError>> {
219            Pin::new(&mut self.stream(0)).poll_write(cx, buf)
220        }
221
222        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
223            Pin::new(&mut self.stream(0)).poll_flush(cx)
224        }
225
226        fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
227            Pin::new(&mut self.stream(0)).poll_close(cx)
228        }
229    }
230
231    //
232    impl<S> AsyncRead for AsyncStream<S>
233    where
234        S: AsyncSessionStream + Send + Sync + 'static,
235    {
236        fn poll_read(
237            self: Pin<&mut Self>,
238            cx: &mut Context<'_>,
239            buf: &mut [u8],
240        ) -> Poll<Result<usize, IoError>> {
241            let this = self.get_mut();
242            let sess = this.sess.clone();
243            let inner = &mut this.inner;
244
245            this.stream.poll_read_with(cx, || inner.read(buf), &sess)
246        }
247    }
248
249    impl<S> AsyncWrite for AsyncStream<S>
250    where
251        S: AsyncSessionStream + Send + Sync + 'static,
252    {
253        fn poll_write(
254            self: Pin<&mut Self>,
255            cx: &mut Context,
256            buf: &[u8],
257        ) -> Poll<Result<usize, IoError>> {
258            let this = self.get_mut();
259            let sess = this.sess.clone();
260            let inner = &mut this.inner;
261
262            this.stream.poll_write_with(cx, || inner.write(buf), &sess)
263        }
264
265        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
266            let this = self.get_mut();
267            let sess = this.sess.clone();
268            let inner = &mut this.inner;
269
270            this.stream.poll_write_with(cx, || inner.flush(), &sess)
271        }
272
273        fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
274            self.poll_flush(cx)
275        }
276    }
277}
278
#[cfg(feature = "tokio")]
mod impl_tokio {
    //! `tokio::io::{AsyncRead, AsyncWrite}` implementations for `AsyncChannel`
    //! and `AsyncStream`, compiled only with the `tokio` feature.

    use core::{
        pin::Pin,
        task::{Context, Poll},
    };
    use std::io::{Error as IoError, Read as _, Write as _};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::{AsyncChannel, AsyncStream};
    use crate::session_stream::AsyncSessionStream;

    // Channel-level I/O is delegated to a fresh `AsyncStream` over stream id 0,
    // the id ssh2 uses for a channel's primary (non-extended) data.
    impl<S> AsyncRead for AsyncChannel<S>
    where
        S: AsyncSessionStream + Send + Sync + 'static,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<Result<(), IoError>> {
            Pin::new(&mut self.stream(0)).poll_read(cx, buf)
        }
    }

    impl<S> AsyncWrite for AsyncChannel<S>
    where
        S: AsyncSessionStream + Send + Sync + 'static,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, IoError>> {
            Pin::new(&mut self.stream(0)).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
            Pin::new(&mut self.stream(0)).poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
            Pin::new(&mut self.stream(0)).poll_shutdown(cx)
        }
    }

    // `AsyncStream` is `Unpin`, so the pin is unwrapped with `get_mut`. The
    // closure mutably borrows only the `inner` field while `stream` and `sess`
    // are borrowed shared (disjoint field capture), so no per-poll `Session`
    // clone is needed.
    impl<S> AsyncRead for AsyncStream<S>
    where
        S: AsyncSessionStream + Send + Sync + 'static,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<Result<(), IoError>> {
            let this = self.get_mut();
            this.stream.poll_read_with(
                cx,
                || {
                    // Read into the unfilled tail of `buf`, then mark exactly the
                    // bytes ssh2 produced as filled.
                    this.inner
                        .read(buf.initialize_unfilled())
                        .map(|n| buf.advance(n))
                },
                &this.sess,
            )
        }
    }

    impl<S> AsyncWrite for AsyncStream<S>
    where
        S: AsyncSessionStream + Send + Sync + 'static,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, IoError>> {
            let this = self.get_mut();
            this.stream
                .poll_write_with(cx, || this.inner.write(buf), &this.sess)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
            let this = self.get_mut();
            this.stream
                .poll_write_with(cx, || this.inner.flush(), &this.sess)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
            // Shutdown only flushes; it does not send EOF or close the channel.
            self.poll_flush(cx)
        }
    }
}