1use std::sync::Arc;
2
3use ssh2::{Channel, ExitSignal, ExtendedData, PtyModes, ReadWindow, Session, Stream, WriteWindow};
4
5use crate::{error::Error, session_stream::AsyncSessionStream};
6
/// Asynchronous wrapper around an [`ssh2::Channel`].
///
/// Fallible channel operations are not called on the channel directly; they are
/// executed through the shared session stream `S` (via its `rw_with` helper).
pub struct AsyncChannel<S> {
    /// The wrapped `ssh2` channel.
    inner: Channel,
    /// The `ssh2` session handed, alongside each channel operation, to the
    /// stream's `rw_with` / `poll_*_with` helpers.
    sess: Session,
    /// Shared transport; the same `Arc` is cloned into every [`AsyncStream`]
    /// created by `stderr()` / `stream()`.
    stream: Arc<S>,
}
13
14impl<S> AsyncChannel<S> {
15 pub(crate) fn from_parts(inner: Channel, sess: Session, stream: Arc<S>) -> Self {
16 Self {
17 inner,
18 sess,
19 stream,
20 }
21 }
22}
23
24impl<S> AsyncChannel<S>
25where
26 S: AsyncSessionStream + Send + Sync + 'static,
27{
28 pub async fn setenv(&mut self, var: &str, val: &str) -> Result<(), Error> {
29 self.stream
30 .rw_with(|| self.inner.setenv(var, val), &self.sess)
31 .await
32 }
33
34 pub async fn request_pty(
35 &mut self,
36 term: &str,
37 mode: Option<PtyModes>,
38 dim: Option<(u32, u32, u32, u32)>,
39 ) -> Result<(), Error> {
40 self.stream
41 .rw_with(
42 || self.inner.request_pty(term, mode.clone(), dim),
43 &self.sess,
44 )
45 .await
46 }
47
48 pub async fn request_pty_size(
49 &mut self,
50 width: u32,
51 height: u32,
52 width_px: Option<u32>,
53 height_px: Option<u32>,
54 ) -> Result<(), Error> {
55 self.stream
56 .rw_with(
57 || {
58 self.inner
59 .request_pty_size(width, height, width_px, height_px)
60 },
61 &self.sess,
62 )
63 .await
64 }
65
66 pub async fn request_auth_agent_forwarding(&mut self) -> Result<(), Error> {
67 self.stream
68 .rw_with(|| self.inner.request_auth_agent_forwarding(), &self.sess)
69 .await
70 }
71
72 pub async fn exec(&mut self, command: &str) -> Result<(), Error> {
73 self.stream
74 .rw_with(|| self.inner.exec(command), &self.sess)
75 .await
76 }
77
78 pub async fn shell(&mut self) -> Result<(), Error> {
79 self.stream.rw_with(|| self.inner.shell(), &self.sess).await
80 }
81
82 pub async fn subsystem(&mut self, system: &str) -> Result<(), Error> {
83 self.stream
84 .rw_with(|| self.inner.subsystem(system), &self.sess)
85 .await
86 }
87
88 pub async fn process_startup(
89 &mut self,
90 request: &str,
91 message: Option<&str>,
92 ) -> Result<(), Error> {
93 self.stream
94 .rw_with(|| self.inner.process_startup(request, message), &self.sess)
95 .await
96 }
97
98 pub fn stderr(&self) -> AsyncStream<S> {
99 AsyncStream::from_parts(self.inner.stderr(), self.sess.clone(), self.stream.clone())
100 }
101
102 pub fn stream(&self, stream_id: i32) -> AsyncStream<S> {
103 AsyncStream::from_parts(
104 self.inner.stream(stream_id),
105 self.sess.clone(),
106 self.stream.clone(),
107 )
108 }
109
110 pub async fn handle_extended_data(&mut self, mode: ExtendedData) -> Result<(), Error> {
111 self.stream
112 .rw_with(|| self.inner.handle_extended_data(mode), &self.sess)
113 .await
114 }
115
116 pub fn exit_status(&self) -> Result<i32, Error> {
117 self.inner.exit_status().map_err(Into::into)
118 }
119
120 pub async fn exit_signal(&self) -> Result<ExitSignal, Error> {
121 self.inner.exit_signal().map_err(Into::into)
122 }
123
124 pub fn read_window(&self) -> ReadWindow {
125 self.inner.read_window()
126 }
127 pub fn write_window(&self) -> WriteWindow {
128 self.inner.write_window()
129 }
130
131 pub async fn adjust_receive_window(&mut self, adjust: u64, force: bool) -> Result<u64, Error> {
132 self.stream
133 .rw_with(
134 || self.inner.adjust_receive_window(adjust, force),
135 &self.sess,
136 )
137 .await
138 }
139
140 pub fn eof(&self) -> bool {
141 self.inner.eof()
142 }
143
144 pub async fn send_eof(&mut self) -> Result<(), Error> {
145 self.stream
146 .rw_with(|| self.inner.send_eof(), &self.sess)
147 .await
148 }
149
150 pub async fn wait_eof(&mut self) -> Result<(), Error> {
151 self.stream
152 .rw_with(|| self.inner.wait_eof(), &self.sess)
153 .await
154 }
155
156 pub async fn close(&mut self) -> Result<(), Error> {
157 self.stream.rw_with(|| self.inner.close(), &self.sess).await
158 }
159
160 pub async fn wait_close(&mut self) -> Result<(), Error> {
161 self.stream
162 .rw_with(|| self.inner.wait_close(), &self.sess)
163 .await
164 }
165}
166
/// Asynchronous wrapper around a single [`ssh2::Stream`] of a channel.
///
/// In this file, instances are created by `AsyncChannel::stderr` and
/// `AsyncChannel::stream`.
pub struct AsyncStream<S> {
    /// The wrapped `ssh2` stream handle.
    inner: Stream,
    /// The `ssh2` session passed to the shared stream's `poll_*_with` helpers.
    sess: Session,
    /// Shared transport, cloned from the parent `AsyncChannel`.
    stream: Arc<S>,
}
173
174impl<S> AsyncStream<S> {
175 pub(crate) fn from_parts(inner: Stream, sess: Session, stream: Arc<S>) -> Self {
176 Self {
177 inner,
178 sess,
179 stream,
180 }
181 }
182}
183
184mod impl_futures_util {
185 use core::{
186 pin::Pin,
187 task::{Context, Poll},
188 };
189 use std::io::{Error as IoError, Read as _, Write as _};
190
191 use futures_util::io::{AsyncRead, AsyncWrite};
192
193 use super::{AsyncChannel, AsyncStream};
194 use crate::session_stream::AsyncSessionStream;
195
196 impl<S> AsyncRead for AsyncChannel<S>
198 where
199 S: AsyncSessionStream + Send + Sync + 'static,
200 {
201 fn poll_read(
202 self: Pin<&mut Self>,
203 cx: &mut Context<'_>,
204 buf: &mut [u8],
205 ) -> Poll<Result<usize, IoError>> {
206 Pin::new(&mut self.stream(0)).poll_read(cx, buf)
207 }
208 }
209
210 impl<S> AsyncWrite for AsyncChannel<S>
211 where
212 S: AsyncSessionStream + Send + Sync + 'static,
213 {
214 fn poll_write(
215 self: Pin<&mut Self>,
216 cx: &mut Context,
217 buf: &[u8],
218 ) -> Poll<Result<usize, IoError>> {
219 Pin::new(&mut self.stream(0)).poll_write(cx, buf)
220 }
221
222 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
223 Pin::new(&mut self.stream(0)).poll_flush(cx)
224 }
225
226 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
227 Pin::new(&mut self.stream(0)).poll_close(cx)
228 }
229 }
230
231 impl<S> AsyncRead for AsyncStream<S>
233 where
234 S: AsyncSessionStream + Send + Sync + 'static,
235 {
236 fn poll_read(
237 self: Pin<&mut Self>,
238 cx: &mut Context<'_>,
239 buf: &mut [u8],
240 ) -> Poll<Result<usize, IoError>> {
241 let this = self.get_mut();
242 let sess = this.sess.clone();
243 let inner = &mut this.inner;
244
245 this.stream.poll_read_with(cx, || inner.read(buf), &sess)
246 }
247 }
248
249 impl<S> AsyncWrite for AsyncStream<S>
250 where
251 S: AsyncSessionStream + Send + Sync + 'static,
252 {
253 fn poll_write(
254 self: Pin<&mut Self>,
255 cx: &mut Context,
256 buf: &[u8],
257 ) -> Poll<Result<usize, IoError>> {
258 let this = self.get_mut();
259 let sess = this.sess.clone();
260 let inner = &mut this.inner;
261
262 this.stream.poll_write_with(cx, || inner.write(buf), &sess)
263 }
264
265 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
266 let this = self.get_mut();
267 let sess = this.sess.clone();
268 let inner = &mut this.inner;
269
270 this.stream.poll_write_with(cx, || inner.flush(), &sess)
271 }
272
273 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
274 self.poll_flush(cx)
275 }
276 }
277}
278
#[cfg(feature = "tokio")]
mod impl_tokio {
    // `tokio::io` `AsyncRead` / `AsyncWrite` implementations for the channel and
    // stream wrappers, compiled only with the `tokio` feature.
    use core::{
        pin::Pin,
        task::{Context, Poll},
    };
    use std::io::{Error as IoError, Read as _, Write as _};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::{AsyncChannel, AsyncStream};
    use crate::session_stream::AsyncSessionStream;

    // Channel I/O is delegated to a temporary `AsyncStream` for stream id 0,
    // created on every poll through `AsyncChannel::stream(0)`.
    impl<S> AsyncRead for AsyncChannel<S>
    where
        S: AsyncSessionStream + Send + Sync + 'static,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<Result<(), IoError>> {
            Pin::new(&mut self.stream(0)).poll_read(cx, buf)
        }
    }

    impl<S> AsyncWrite for AsyncChannel<S>
    where
        S: AsyncSessionStream + Send + Sync + 'static,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, IoError>> {
            Pin::new(&mut self.stream(0)).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
            Pin::new(&mut self.stream(0)).poll_flush(cx)
        }

        /// Delegates to `AsyncStream::poll_shutdown`, which only flushes.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
            Pin::new(&mut self.stream(0)).poll_shutdown(cx)
        }
    }

    impl<S> AsyncRead for AsyncStream<S>
    where
        S: AsyncSessionStream + Send + Sync + 'static,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<Result<(), IoError>> {
            let this = self.get_mut();
            // Disjoint field borrows: `inner` is lent mutably to the closure while
            // `stream` and `sess` are only read, so no per-poll `Session` clone is
            // needed.
            let inner = &mut this.inner;

            this.stream.poll_read_with(
                cx,
                || {
                    // `initialize_unfilled` zero-fills the unfilled tail so it can be
                    // passed to the synchronous `Read::read` as a plain `&mut [u8]`;
                    // on success the filled cursor is advanced by the bytes read.
                    inner
                        .read(buf.initialize_unfilled())
                        .map(|n| buf.advance(n))
                },
                &this.sess,
            )
        }
    }

    impl<S> AsyncWrite for AsyncStream<S>
    where
        S: AsyncSessionStream + Send + Sync + 'static,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, IoError>> {
            let this = self.get_mut();
            let inner = &mut this.inner;
            this.stream.poll_write_with(cx, || inner.write(buf), &this.sess)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
            let this = self.get_mut();
            let inner = &mut this.inner;
            this.stream.poll_write_with(cx, || inner.flush(), &this.sess)
        }

        /// Flushes pending data; this does not send EOF to the remote side.
        ///
        /// NOTE(review): callers that need the peer to observe EOF must call
        /// `AsyncChannel::send_eof` explicitly.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
            self.poll_flush(cx)
        }
    }
}