1use core::time::Duration;
2use std::{path::Path, sync::Arc};
3
4#[cfg(unix)]
5use std::os::unix::io::AsRawFd;
6#[cfg(windows)]
7use std::os::windows::io::{AsRawSocket, BorrowedSocket};
8
9use ssh2::{
10 BlockDirections, DisconnectCode, Error as Ssh2Error, HashType, HostKeyType,
11 KeyboardInteractivePrompt, KnownHosts, MethodType, PublicKey, ScpFileStat, Session, TraceFlags,
12};
13
14use crate::{
15 agent::AsyncAgent, channel::AsyncChannel, error::Error, listener::AsyncListener,
16 session_stream::AsyncSessionStream, sftp::AsyncSftp,
17};
18
/// Async wrapper around an `ssh2::Session` bound to the transport stream `S`.
///
/// The stream is held in an `Arc` so that clones of this value, and the agents,
/// channels, listeners and SFTP handles created from it, all share one stream.
pub struct AsyncSession<S> {
    // The libssh2 session; every constructor in this file builds it through
    // `get_session`, which switches it to non-blocking mode.
    inner: Session,
    // Declared after `inner`, so Rust drops it after `inner` when this value is
    // dropped. NOTE(review): the Windows constructor lends the raw socket to
    // `inner`, so keep this field order unless that is re-audited.
    stream: Arc<S>,
}
24
25impl<S> Clone for AsyncSession<S> {
26 fn clone(&self) -> Self {
27 Self {
28 inner: self.inner.clone(),
29 stream: self.stream.clone(),
30 }
31 }
32}
33
34#[cfg(unix)]
35impl<S> AsyncSession<S>
36where
37 S: AsRawFd + 'static,
38{
39 pub fn new(
40 stream: S,
41 configuration: impl Into<Option<SessionConfiguration>>,
42 ) -> Result<Self, Error> {
43 let mut session = get_session(configuration)?;
44 session.set_tcp_stream(stream.as_raw_fd());
45
46 let stream = Arc::new(stream);
47
48 Ok(Self {
49 inner: session,
50 stream,
51 })
52 }
53}
54
55#[cfg(windows)]
56impl<S> AsyncSession<S>
57where
58 S: AsRawSocket + 'static,
59{
60 pub fn new(
61 stream: S,
62 configuration: impl Into<Option<SessionConfiguration>>,
63 ) -> Result<Self, Error> {
64 let mut session = get_session(configuration)?;
65 session.set_tcp_stream(unsafe { BorrowedSocket::borrow_raw(stream.as_raw_socket()) });
66
67 let stream = Arc::new(stream);
68
69 Ok(Self {
70 inner: session,
71 stream,
72 })
73 }
74}
75
#[cfg(feature = "async-io")]
impl AsyncSession<crate::AsyncIoTcpStream> {
    /// Opens a TCP connection to `addr` with `async-io` and wraps it in a new
    /// session via [`AsyncSession::new`].
    ///
    /// The SSH handshake is not performed; call `handshake` afterwards.
    pub async fn connect<A: Into<std::net::SocketAddr>>(
        addr: A,
        configuration: impl Into<Option<SessionConfiguration>>,
    ) -> Result<Self, Error> {
        Self::new(
            crate::AsyncIoTcpStream::connect(addr).await?,
            configuration,
        )
    }
}
87
#[cfg(all(unix, feature = "async-io"))]
impl AsyncSession<crate::AsyncIoUnixStream> {
    /// Connects to the Unix-domain socket at `path` with `async-io` and wraps the
    /// connection in a new session via [`AsyncSession::new`].
    ///
    /// The SSH handshake is not performed; call `handshake` afterwards.
    pub async fn connect<P: AsRef<Path>>(
        path: P,
        configuration: impl Into<Option<SessionConfiguration>>,
    ) -> Result<Self, Error> {
        let unix_stream = crate::AsyncIoUnixStream::connect(path).await?;
        Self::new(unix_stream, configuration)
    }
}
100
#[cfg(feature = "tokio")]
impl AsyncSession<crate::TokioTcpStream> {
    /// Opens a TCP connection to `addr` with tokio and wraps it in a new session
    /// via [`AsyncSession::new`].
    ///
    /// The SSH handshake is not performed; call `handshake` afterwards.
    pub async fn connect<A: Into<std::net::SocketAddr>>(
        addr: A,
        configuration: impl Into<Option<SessionConfiguration>>,
    ) -> Result<Self, Error> {
        let socket_addr: std::net::SocketAddr = addr.into();
        let tcp = crate::TokioTcpStream::connect(socket_addr).await?;
        Self::new(tcp, configuration)
    }
}
112
113#[cfg(all(unix, feature = "tokio"))]
114impl AsyncSession<crate::TokioUnixStream> {
115 #[cfg(unix)]
116 pub async fn connect<P: AsRef<Path>>(
117 path: P,
118 configuration: impl Into<Option<SessionConfiguration>>,
119 ) -> Result<Self, Error> {
120 let stream = crate::TokioUnixStream::connect(path).await?;
121
122 Self::new(stream, configuration)
123 }
124}
125
126impl<S> AsyncSession<S> {
127 pub fn is_blocking(&self) -> bool {
128 self.inner.is_blocking()
129 }
130
131 pub fn banner(&self) -> Option<&str> {
132 self.inner.banner()
133 }
134
135 pub fn banner_bytes(&self) -> Option<&[u8]> {
136 self.inner.banner_bytes()
137 }
138
139 pub fn timeout(&self) -> u32 {
140 self.inner.timeout()
141 }
142
143 pub fn trace(&self, bitmask: TraceFlags) {
144 self.inner.trace(bitmask)
145 }
146}
147
148impl<S> AsyncSession<S>
149where
150 S: AsyncSessionStream + Send + Sync + 'static,
151{
152 pub async fn handshake(&mut self) -> Result<(), Error> {
153 let sess = self.inner.clone();
154 self.stream.rw_with(|| self.inner.handshake(), &sess).await
155 }
156
157 pub async fn userauth_password(&self, username: &str, password: &str) -> Result<(), Error> {
158 self.stream
159 .rw_with(
160 || self.inner.userauth_password(username, password),
161 &self.inner,
162 )
163 .await
164 }
165
166 #[allow(unknown_lints)]
167 #[allow(clippy::needless_pass_by_ref_mut)]
168 pub async fn userauth_keyboard_interactive<P: KeyboardInteractivePrompt + Send>(
169 &self,
170 username: &str,
171 prompter: &mut P,
172 ) -> Result<(), Error> {
173 self.stream
174 .rw_with(
175 || self.inner.userauth_keyboard_interactive(username, prompter),
176 &self.inner,
177 )
178 .await
179 }
180
181 pub async fn userauth_agent(&self, username: &str) -> Result<(), Error> {
182 let mut agent = self.agent()?;
183 agent.connect().await?;
184 agent.list_identities().await?;
185 let identities = agent.identities()?;
186 let identity = match identities.first() {
187 Some(identity) => identity,
188 None => return Err(Error::Other("no identities found in the ssh agent".into())),
189 };
190 agent.userauth(username, identity).await
191 }
192
193 pub async fn userauth_pubkey_file(
194 &self,
195 username: &str,
196 pubkey: Option<&Path>,
197 privatekey: &Path,
198 passphrase: Option<&str>,
199 ) -> Result<(), Error> {
200 self.stream
201 .rw_with(
202 || {
203 self.inner
204 .userauth_pubkey_file(username, pubkey, privatekey, passphrase)
205 },
206 &self.inner,
207 )
208 .await
209 }
210
211 #[cfg(any(unix, feature = "vendored-openssl", feature = "openssl-on-win32"))]
212 pub async fn userauth_pubkey_memory(
213 &self,
214 username: &str,
215 pubkeydata: Option<&str>,
216 privatekeydata: &str,
217 passphrase: Option<&str>,
218 ) -> Result<(), Error> {
219 self.stream
220 .rw_with(
221 || {
222 self.inner.userauth_pubkey_memory(
223 username,
224 pubkeydata,
225 privatekeydata,
226 passphrase,
227 )
228 },
229 &self.inner,
230 )
231 .await
232 }
233
234 pub async fn userauth_hostbased_file(
235 &self,
236 username: &str,
237 publickey: &Path,
238 privatekey: &Path,
239 passphrase: Option<&str>,
240 hostname: &str,
241 local_username: Option<&str>,
242 ) -> Result<(), Error> {
243 self.stream
244 .rw_with(
245 || {
246 self.inner.userauth_hostbased_file(
247 username,
248 publickey,
249 privatekey,
250 passphrase,
251 hostname,
252 local_username,
253 )
254 },
255 &self.inner,
256 )
257 .await
258 }
259
260 pub fn authenticated(&self) -> bool {
261 self.inner.authenticated()
262 }
263
264 pub async fn auth_methods<'a>(&'a self, username: &'a str) -> Result<&str, Error> {
265 self.stream
266 .rw_with(|| self.inner.auth_methods(username), &self.inner)
267 .await
268 }
269
270 pub async fn method_pref(&self, method_type: MethodType, prefs: &str) -> Result<(), Error> {
271 self.stream
272 .rw_with(|| self.inner.method_pref(method_type, prefs), &self.inner)
273 .await
274 }
275
276 pub fn methods(&self, method_type: MethodType) -> Option<&str> {
277 self.inner.methods(method_type)
278 }
279
280 pub async fn supported_algs(
281 &self,
282 method_type: MethodType,
283 ) -> Result<Vec<&'static str>, Error> {
284 self.stream
285 .rw_with(|| self.inner.supported_algs(method_type), &self.inner)
286 .await
287 }
288
289 pub fn agent(&self) -> Result<AsyncAgent<S>, Error> {
290 let agent = self.inner.agent()?;
291
292 Ok(AsyncAgent::from_parts(
293 agent,
294 self.inner.clone(),
295 self.stream.clone(),
296 ))
297 }
298
299 pub fn known_hosts(&self) -> Result<KnownHosts, Error> {
300 self.inner.known_hosts().map_err(Into::into)
301 }
302
303 pub async fn channel_session(&self) -> Result<AsyncChannel<S>, Error> {
304 let channel = self
305 .stream
306 .rw_with(|| self.inner.channel_session(), &self.inner)
307 .await?;
308
309 Ok(AsyncChannel::from_parts(
310 channel,
311 self.inner.clone(),
312 self.stream.clone(),
313 ))
314 }
315
316 pub async fn channel_direct_tcpip(
317 &self,
318 host: &str,
319 port: u16,
320 src: Option<(&str, u16)>,
321 ) -> Result<AsyncChannel<S>, Error> {
322 let channel = self
323 .stream
324 .rw_with(
325 || self.inner.channel_direct_tcpip(host, port, src),
326 &self.inner,
327 )
328 .await?;
329
330 Ok(AsyncChannel::from_parts(
331 channel,
332 self.inner.clone(),
333 self.stream.clone(),
334 ))
335 }
336
337 pub async fn channel_forward_listen(
338 &self,
339 remote_port: u16,
340 host: Option<&str>,
341 queue_maxsize: Option<u32>,
342 ) -> Result<(AsyncListener<S>, u16), Error> {
343 let (listener, port) = self
344 .stream
345 .rw_with(
346 || {
347 self.inner
348 .channel_forward_listen(remote_port, host, queue_maxsize)
349 },
350 &self.inner,
351 )
352 .await?;
353
354 Ok((
355 AsyncListener::from_parts(listener, self.inner.clone(), self.stream.clone()),
356 port,
357 ))
358 }
359
360 pub async fn scp_recv(&self, path: &Path) -> Result<(AsyncChannel<S>, ScpFileStat), Error> {
361 let (channel, scp_file_stat) = self
362 .stream
363 .rw_with(|| self.inner.scp_recv(path), &self.inner)
364 .await?;
365
366 Ok((
367 AsyncChannel::from_parts(channel, self.inner.clone(), self.stream.clone()),
368 scp_file_stat,
369 ))
370 }
371
372 pub async fn scp_send(
373 &self,
374 remote_path: &Path,
375 mode: i32,
376 size: u64,
377 times: Option<(u64, u64)>,
378 ) -> Result<AsyncChannel<S>, Error> {
379 let channel = self
380 .stream
381 .rw_with(
382 || self.inner.scp_send(remote_path, mode, size, times),
383 &self.inner,
384 )
385 .await?;
386
387 Ok(AsyncChannel::from_parts(
388 channel,
389 self.inner.clone(),
390 self.stream.clone(),
391 ))
392 }
393
394 pub async fn sftp(&self) -> Result<AsyncSftp<S>, Error> {
395 let sftp = self
396 .stream
397 .rw_with(|| self.inner.sftp(), &self.inner)
398 .await?;
399
400 Ok(AsyncSftp::from_parts(
401 sftp,
402 self.inner.clone(),
403 self.stream.clone(),
404 ))
405 }
406
407 pub async fn channel_open(
408 &self,
409 channel_type: &str,
410 window_size: u32,
411 packet_size: u32,
412 message: Option<&str>,
413 ) -> Result<AsyncChannel<S>, Error> {
414 let channel = self
415 .stream
416 .rw_with(
417 || {
418 self.inner
419 .channel_open(channel_type, window_size, packet_size, message)
420 },
421 &self.inner,
422 )
423 .await?;
424
425 Ok(AsyncChannel::from_parts(
426 channel,
427 self.inner.clone(),
428 self.stream.clone(),
429 ))
430 }
431
432 pub fn host_key(&self) -> Option<(&[u8], HostKeyType)> {
433 self.inner.host_key()
434 }
435
436 pub fn host_key_hash(&self, hash: HashType) -> Option<&[u8]> {
437 self.inner.host_key_hash(hash)
438 }
439
440 pub async fn keepalive_send(&self) -> Result<u32, Error> {
441 self.stream
442 .rw_with(|| self.inner.keepalive_send(), &self.inner)
443 .await
444 }
445
446 pub async fn disconnect(
447 &self,
448 reason: Option<DisconnectCode>,
449 description: &str,
450 lang: Option<&str>,
451 ) -> Result<(), Error> {
452 self.stream
453 .rw_with(
454 || self.inner.disconnect(reason, description, lang),
455 &self.inner,
456 )
457 .await
458 }
459
460 pub fn block_directions(&self) -> BlockDirections {
461 self.inner.block_directions()
462 }
463}
464
465#[cfg(feature = "tokio")]
466impl<S> AsyncSession<S>
467where
468 S: AsyncSessionStream + Send + Sync + 'static,
469{
470 pub async fn remote_port_forwarding(
471 &self,
472 remote_port: u16,
473 host: Option<&str>,
474 queue_maxsize: Option<u32>,
475 local: crate::util::ConnectInfo,
476 ) -> Result<(), Error> {
477 use std::io::Error as IoError;
478
479 use futures_util::{select, FutureExt as _};
480 use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
481
482 #[cfg(unix)]
483 use crate::TokioUnixStream;
484 use crate::{util::ConnectInfo, TokioTcpStream};
485
486 match local {
487 ConnectInfo::Tcp(addr) => {
488 let (mut listener, _remote_port) = self
489 .channel_forward_listen(remote_port, host, queue_maxsize)
490 .await?;
491
492 loop {
495 match listener.accept().await {
496 Ok(mut channel) => {
497 let join_handle: tokio::task::JoinHandle<Result<(), IoError>> =
498 tokio::task::spawn(async move {
499 let mut stream = TokioTcpStream::connect(addr).await?;
500
501 let mut buf_channel = vec![0; 2048];
502 let mut buf_stream = vec![0; 2048];
503
504 loop {
505 select! {
506 ret_channel_read = futures_util::AsyncReadExt::read(&mut channel, &mut buf_channel).fuse() => match ret_channel_read {
507 Ok(0) => {
508 break
509 },
510 Ok(n) => {
511 #[allow(clippy::map_identity)]
512 stream.write(&buf_channel[..n]).await.map(|_| ()).map_err(|err| {
513 err
515 })?
516 },
517 Err(err) => {
518 return Err(err);
519 }
520 },
521 ret_stream_read = stream.read(&mut buf_stream).fuse() => match ret_stream_read {
522 Ok(0) => {
523 break
524 },
525 Ok(n) => {
526 #[allow(clippy::map_identity)]
527 futures_util::AsyncWriteExt::write(&mut channel,&buf_stream[..n]).await.map(|_| ()).map_err(|err| {
528 err
530 })?
531 },
532 Err(err) => {
533 return Err(err);
534 }
535 },
536 }
537 }
538
539 Result::<_, IoError>::Ok(())
540 });
541 match join_handle.await {
542 Ok(_) => {}
543 Err(err) => {
544 eprintln!("join_handle failed, err:{err:?}");
545 }
546 }
547 }
548 Err(err) => {
549 eprintln!("listener.accept failed, err:{err:?}");
550 }
551 }
552 }
553 }
554 #[cfg(unix)]
555 ConnectInfo::Unix(path) => {
556 let (mut listener, _remote_port) = self
557 .channel_forward_listen(remote_port, host, queue_maxsize)
558 .await?;
559
560 loop {
563 match listener.accept().await {
564 Ok(mut channel) => {
565 let path = path.clone();
566 let join_handle: tokio::task::JoinHandle<Result<(), IoError>> =
567 tokio::task::spawn(async move {
568 let mut stream = TokioUnixStream::connect(path).await?;
569
570 let mut buf_channel = vec![0; 2048];
571 let mut buf_stream = vec![0; 2048];
572
573 loop {
574 select! {
575 ret_channel_read = futures_util::AsyncReadExt::read(&mut channel, &mut buf_channel).fuse() => match ret_channel_read {
576 Ok(0) => {
577 break
578 },
579 Ok(n) => {
580 #[allow(clippy::map_identity)]
581 stream.write(&buf_channel[..n]).await.map(|_| ()).map_err(|err| {
582 err
584 })?
585 },
586 Err(err) => {
587 return Err(err);
588 }
589 },
590 ret_stream_read = stream.read(&mut buf_stream).fuse() => match ret_stream_read {
591 Ok(0) => {
592 break
593 },
594 Ok(n) => {
595 #[allow(clippy::map_identity)]
596 futures_util::AsyncWriteExt::write(&mut channel,&buf_stream[..n]).await.map(|_| ()).map_err(|err| {
597 err
599 })?
600 },
601 Err(err) => {
602 return Err(err);
603 }
604 },
605 }
606 }
607
608 Result::<_, IoError>::Ok(())
609 });
610 match join_handle.await {
611 Ok(_) => {}
612 Err(err) => {
613 eprintln!("join_handle failed, err:{err:?}");
614 }
615 }
616 }
617 Err(err) => {
618 eprintln!("listener.accept failed, err:{err:?}");
619 }
620 }
621 }
622 }
623 }
624 }
625}
626
627impl<S> AsyncSession<S> {
631 pub fn last_error(&self) -> Option<Ssh2Error> {
632 Ssh2Error::last_session_error(&self.inner)
633 }
634}
635
636impl<S> AsyncSession<S>
637where
638 S: AsyncSessionStream + Send + Sync + 'static,
639{
640 pub async fn userauth_agent_with_try_next(&self, username: &str) -> Result<(), Error> {
641 self.userauth_agent_with_try_next_with_callback(username, |identities| identities)
642 .await
643 }
644
645 pub async fn userauth_agent_with_try_next_with_callback<CB>(
646 &self,
647 username: &str,
648 mut cb: CB,
649 ) -> Result<(), Error>
650 where
651 CB: FnMut(Vec<PublicKey>) -> Vec<PublicKey>,
652 {
653 let mut agent = self.agent()?;
654 agent.connect().await?;
655 agent.list_identities().await?;
656 let identities = agent.identities()?;
657
658 if identities.is_empty() {
659 return Err(Error::Other("no identities found in the ssh agent".into()));
660 }
661
662 let identities = cb(identities);
663
664 for identity in identities {
665 match agent.userauth(username, &identity).await {
666 Ok(_) => {
667 if self.authenticated() {
668 return Ok(());
669 }
670 }
671 Err(_err) => {
672 continue;
673 }
674 }
675 }
676
677 Err(Error::Other("all identities cannot authenticated".into()))
678 }
679}
680
681#[derive(Debug, Clone, Default)]
685pub struct SessionConfiguration {
686 banner: Option<String>,
687 allow_sigpipe: Option<bool>,
688 compress: Option<bool>,
689 timeout: Option<Duration>,
690 keepalive: Option<SessionKeepaliveConfiguration>,
691}
692impl SessionConfiguration {
693 pub fn new() -> Self {
694 Default::default()
695 }
696
697 pub fn set_banner(&mut self, banner: &str) {
698 self.banner = Some(banner.to_owned());
699 }
700
701 pub fn set_allow_sigpipe(&mut self, block: bool) {
702 self.allow_sigpipe = Some(block);
703 }
704
705 pub fn set_compress(&mut self, compress: bool) {
706 self.compress = Some(compress);
707 }
708
709 pub fn set_timeout(&mut self, timeout_ms: u32) {
710 self.timeout = Some(Duration::from_millis(timeout_ms as u64));
711 }
712
713 pub fn set_keepalive(&mut self, want_reply: bool, interval: u32) {
714 self.keepalive = Some(SessionKeepaliveConfiguration {
715 want_reply,
716 interval,
717 });
718 }
719}
720
721#[derive(Debug, Clone)]
722struct SessionKeepaliveConfiguration {
723 want_reply: bool,
724 interval: u32,
725}
726
727pub(crate) fn get_session(
728 configuration: impl Into<Option<SessionConfiguration>>,
729) -> Result<Session, Error> {
730 let session = Session::new()?;
731 session.set_blocking(false);
732
733 if let Some(configuration) = configuration.into() {
734 if let Some(banner) = configuration.banner {
735 session.set_banner(banner.as_ref())?;
736 }
737 if let Some(allow_sigpipe) = configuration.allow_sigpipe {
738 session.set_allow_sigpipe(allow_sigpipe);
739 }
740 if let Some(compress) = configuration.compress {
741 session.set_compress(compress);
742 }
743 if let Some(timeout) = configuration.timeout {
744 session.set_timeout(timeout.as_millis() as u32);
745 }
746 if let Some(keepalive) = configuration.keepalive {
747 session.set_keepalive(keepalive.want_reply, keepalive.interval);
748 }
749 }
750
751 Ok(session)
752}