async_ssh2_lite/
session.rs

1use core::time::Duration;
2use std::{path::Path, sync::Arc};
3
4#[cfg(unix)]
5use std::os::unix::io::AsRawFd;
6#[cfg(windows)]
7use std::os::windows::io::{AsRawSocket, BorrowedSocket};
8
9use ssh2::{
10    BlockDirections, DisconnectCode, Error as Ssh2Error, HashType, HostKeyType,
11    KeyboardInteractivePrompt, KnownHosts, MethodType, PublicKey, ScpFileStat, Session, TraceFlags,
12};
13
14use crate::{
15    agent::AsyncAgent, channel::AsyncChannel, error::Error, listener::AsyncListener,
16    session_stream::AsyncSessionStream, sftp::AsyncSftp,
17};
18
19//
20pub struct AsyncSession<S> {
21    inner: Session,
22    stream: Arc<S>,
23}
24
25impl<S> Clone for AsyncSession<S> {
26    fn clone(&self) -> Self {
27        Self {
28            inner: self.inner.clone(),
29            stream: self.stream.clone(),
30        }
31    }
32}
33
34#[cfg(unix)]
35impl<S> AsyncSession<S>
36where
37    S: AsRawFd + 'static,
38{
39    pub fn new(
40        stream: S,
41        configuration: impl Into<Option<SessionConfiguration>>,
42    ) -> Result<Self, Error> {
43        let mut session = get_session(configuration)?;
44        session.set_tcp_stream(stream.as_raw_fd());
45
46        let stream = Arc::new(stream);
47
48        Ok(Self {
49            inner: session,
50            stream,
51        })
52    }
53}
54
55#[cfg(windows)]
56impl<S> AsyncSession<S>
57where
58    S: AsRawSocket + 'static,
59{
60    pub fn new(
61        stream: S,
62        configuration: impl Into<Option<SessionConfiguration>>,
63    ) -> Result<Self, Error> {
64        let mut session = get_session(configuration)?;
65        session.set_tcp_stream(unsafe { BorrowedSocket::borrow_raw(stream.as_raw_socket()) });
66
67        let stream = Arc::new(stream);
68
69        Ok(Self {
70            inner: session,
71            stream,
72        })
73    }
74}
75
76#[cfg(feature = "async-io")]
77impl AsyncSession<crate::AsyncIoTcpStream> {
78    pub async fn connect<A: Into<std::net::SocketAddr>>(
79        addr: A,
80        configuration: impl Into<Option<SessionConfiguration>>,
81    ) -> Result<Self, Error> {
82        let stream = crate::AsyncIoTcpStream::connect(addr).await?;
83
84        Self::new(stream, configuration)
85    }
86}
87
88#[cfg(all(unix, feature = "async-io"))]
89impl AsyncSession<crate::AsyncIoUnixStream> {
90    #[cfg(unix)]
91    pub async fn connect<P: AsRef<Path>>(
92        path: P,
93        configuration: impl Into<Option<SessionConfiguration>>,
94    ) -> Result<Self, Error> {
95        let stream = crate::AsyncIoUnixStream::connect(path).await?;
96
97        Self::new(stream, configuration)
98    }
99}
100
101#[cfg(feature = "tokio")]
102impl AsyncSession<crate::TokioTcpStream> {
103    pub async fn connect<A: Into<std::net::SocketAddr>>(
104        addr: A,
105        configuration: impl Into<Option<SessionConfiguration>>,
106    ) -> Result<Self, Error> {
107        let stream = crate::TokioTcpStream::connect(addr.into()).await?;
108
109        Self::new(stream, configuration)
110    }
111}
112
113#[cfg(all(unix, feature = "tokio"))]
114impl AsyncSession<crate::TokioUnixStream> {
115    #[cfg(unix)]
116    pub async fn connect<P: AsRef<Path>>(
117        path: P,
118        configuration: impl Into<Option<SessionConfiguration>>,
119    ) -> Result<Self, Error> {
120        let stream = crate::TokioUnixStream::connect(path).await?;
121
122        Self::new(stream, configuration)
123    }
124}
125
126impl<S> AsyncSession<S> {
127    pub fn is_blocking(&self) -> bool {
128        self.inner.is_blocking()
129    }
130
131    pub fn banner(&self) -> Option<&str> {
132        self.inner.banner()
133    }
134
135    pub fn banner_bytes(&self) -> Option<&[u8]> {
136        self.inner.banner_bytes()
137    }
138
139    pub fn timeout(&self) -> u32 {
140        self.inner.timeout()
141    }
142
143    pub fn trace(&self, bitmask: TraceFlags) {
144        self.inner.trace(bitmask)
145    }
146}
147
148impl<S> AsyncSession<S>
149where
150    S: AsyncSessionStream + Send + Sync + 'static,
151{
152    pub async fn handshake(&mut self) -> Result<(), Error> {
153        let sess = self.inner.clone();
154        self.stream.rw_with(|| self.inner.handshake(), &sess).await
155    }
156
157    pub async fn userauth_password(&self, username: &str, password: &str) -> Result<(), Error> {
158        self.stream
159            .rw_with(
160                || self.inner.userauth_password(username, password),
161                &self.inner,
162            )
163            .await
164    }
165
166    #[allow(unknown_lints)]
167    #[allow(clippy::needless_pass_by_ref_mut)]
168    pub async fn userauth_keyboard_interactive<P: KeyboardInteractivePrompt + Send>(
169        &self,
170        username: &str,
171        prompter: &mut P,
172    ) -> Result<(), Error> {
173        self.stream
174            .rw_with(
175                || self.inner.userauth_keyboard_interactive(username, prompter),
176                &self.inner,
177            )
178            .await
179    }
180
181    pub async fn userauth_agent(&self, username: &str) -> Result<(), Error> {
182        let mut agent = self.agent()?;
183        agent.connect().await?;
184        agent.list_identities().await?;
185        let identities = agent.identities()?;
186        let identity = match identities.first() {
187            Some(identity) => identity,
188            None => return Err(Error::Other("no identities found in the ssh agent".into())),
189        };
190        agent.userauth(username, identity).await
191    }
192
193    pub async fn userauth_pubkey_file(
194        &self,
195        username: &str,
196        pubkey: Option<&Path>,
197        privatekey: &Path,
198        passphrase: Option<&str>,
199    ) -> Result<(), Error> {
200        self.stream
201            .rw_with(
202                || {
203                    self.inner
204                        .userauth_pubkey_file(username, pubkey, privatekey, passphrase)
205                },
206                &self.inner,
207            )
208            .await
209    }
210
211    #[cfg(any(unix, feature = "vendored-openssl", feature = "openssl-on-win32"))]
212    pub async fn userauth_pubkey_memory(
213        &self,
214        username: &str,
215        pubkeydata: Option<&str>,
216        privatekeydata: &str,
217        passphrase: Option<&str>,
218    ) -> Result<(), Error> {
219        self.stream
220            .rw_with(
221                || {
222                    self.inner.userauth_pubkey_memory(
223                        username,
224                        pubkeydata,
225                        privatekeydata,
226                        passphrase,
227                    )
228                },
229                &self.inner,
230            )
231            .await
232    }
233
234    pub async fn userauth_hostbased_file(
235        &self,
236        username: &str,
237        publickey: &Path,
238        privatekey: &Path,
239        passphrase: Option<&str>,
240        hostname: &str,
241        local_username: Option<&str>,
242    ) -> Result<(), Error> {
243        self.stream
244            .rw_with(
245                || {
246                    self.inner.userauth_hostbased_file(
247                        username,
248                        publickey,
249                        privatekey,
250                        passphrase,
251                        hostname,
252                        local_username,
253                    )
254                },
255                &self.inner,
256            )
257            .await
258    }
259
260    pub fn authenticated(&self) -> bool {
261        self.inner.authenticated()
262    }
263
264    pub async fn auth_methods<'a>(&'a self, username: &'a str) -> Result<&str, Error> {
265        self.stream
266            .rw_with(|| self.inner.auth_methods(username), &self.inner)
267            .await
268    }
269
270    pub async fn method_pref(&self, method_type: MethodType, prefs: &str) -> Result<(), Error> {
271        self.stream
272            .rw_with(|| self.inner.method_pref(method_type, prefs), &self.inner)
273            .await
274    }
275
276    pub fn methods(&self, method_type: MethodType) -> Option<&str> {
277        self.inner.methods(method_type)
278    }
279
280    pub async fn supported_algs(
281        &self,
282        method_type: MethodType,
283    ) -> Result<Vec<&'static str>, Error> {
284        self.stream
285            .rw_with(|| self.inner.supported_algs(method_type), &self.inner)
286            .await
287    }
288
289    pub fn agent(&self) -> Result<AsyncAgent<S>, Error> {
290        let agent = self.inner.agent()?;
291
292        Ok(AsyncAgent::from_parts(
293            agent,
294            self.inner.clone(),
295            self.stream.clone(),
296        ))
297    }
298
299    pub fn known_hosts(&self) -> Result<KnownHosts, Error> {
300        self.inner.known_hosts().map_err(Into::into)
301    }
302
303    pub async fn channel_session(&self) -> Result<AsyncChannel<S>, Error> {
304        let channel = self
305            .stream
306            .rw_with(|| self.inner.channel_session(), &self.inner)
307            .await?;
308
309        Ok(AsyncChannel::from_parts(
310            channel,
311            self.inner.clone(),
312            self.stream.clone(),
313        ))
314    }
315
316    pub async fn channel_direct_tcpip(
317        &self,
318        host: &str,
319        port: u16,
320        src: Option<(&str, u16)>,
321    ) -> Result<AsyncChannel<S>, Error> {
322        let channel = self
323            .stream
324            .rw_with(
325                || self.inner.channel_direct_tcpip(host, port, src),
326                &self.inner,
327            )
328            .await?;
329
330        Ok(AsyncChannel::from_parts(
331            channel,
332            self.inner.clone(),
333            self.stream.clone(),
334        ))
335    }
336
337    pub async fn channel_forward_listen(
338        &self,
339        remote_port: u16,
340        host: Option<&str>,
341        queue_maxsize: Option<u32>,
342    ) -> Result<(AsyncListener<S>, u16), Error> {
343        let (listener, port) = self
344            .stream
345            .rw_with(
346                || {
347                    self.inner
348                        .channel_forward_listen(remote_port, host, queue_maxsize)
349                },
350                &self.inner,
351            )
352            .await?;
353
354        Ok((
355            AsyncListener::from_parts(listener, self.inner.clone(), self.stream.clone()),
356            port,
357        ))
358    }
359
360    pub async fn scp_recv(&self, path: &Path) -> Result<(AsyncChannel<S>, ScpFileStat), Error> {
361        let (channel, scp_file_stat) = self
362            .stream
363            .rw_with(|| self.inner.scp_recv(path), &self.inner)
364            .await?;
365
366        Ok((
367            AsyncChannel::from_parts(channel, self.inner.clone(), self.stream.clone()),
368            scp_file_stat,
369        ))
370    }
371
372    pub async fn scp_send(
373        &self,
374        remote_path: &Path,
375        mode: i32,
376        size: u64,
377        times: Option<(u64, u64)>,
378    ) -> Result<AsyncChannel<S>, Error> {
379        let channel = self
380            .stream
381            .rw_with(
382                || self.inner.scp_send(remote_path, mode, size, times),
383                &self.inner,
384            )
385            .await?;
386
387        Ok(AsyncChannel::from_parts(
388            channel,
389            self.inner.clone(),
390            self.stream.clone(),
391        ))
392    }
393
394    pub async fn sftp(&self) -> Result<AsyncSftp<S>, Error> {
395        let sftp = self
396            .stream
397            .rw_with(|| self.inner.sftp(), &self.inner)
398            .await?;
399
400        Ok(AsyncSftp::from_parts(
401            sftp,
402            self.inner.clone(),
403            self.stream.clone(),
404        ))
405    }
406
407    pub async fn channel_open(
408        &self,
409        channel_type: &str,
410        window_size: u32,
411        packet_size: u32,
412        message: Option<&str>,
413    ) -> Result<AsyncChannel<S>, Error> {
414        let channel = self
415            .stream
416            .rw_with(
417                || {
418                    self.inner
419                        .channel_open(channel_type, window_size, packet_size, message)
420                },
421                &self.inner,
422            )
423            .await?;
424
425        Ok(AsyncChannel::from_parts(
426            channel,
427            self.inner.clone(),
428            self.stream.clone(),
429        ))
430    }
431
432    pub fn host_key(&self) -> Option<(&[u8], HostKeyType)> {
433        self.inner.host_key()
434    }
435
436    pub fn host_key_hash(&self, hash: HashType) -> Option<&[u8]> {
437        self.inner.host_key_hash(hash)
438    }
439
440    pub async fn keepalive_send(&self) -> Result<u32, Error> {
441        self.stream
442            .rw_with(|| self.inner.keepalive_send(), &self.inner)
443            .await
444    }
445
446    pub async fn disconnect(
447        &self,
448        reason: Option<DisconnectCode>,
449        description: &str,
450        lang: Option<&str>,
451    ) -> Result<(), Error> {
452        self.stream
453            .rw_with(
454                || self.inner.disconnect(reason, description, lang),
455                &self.inner,
456            )
457            .await
458    }
459
460    pub fn block_directions(&self) -> BlockDirections {
461        self.inner.block_directions()
462    }
463}
464
465#[cfg(feature = "tokio")]
466impl<S> AsyncSession<S>
467where
468    S: AsyncSessionStream + Send + Sync + 'static,
469{
470    pub async fn remote_port_forwarding(
471        &self,
472        remote_port: u16,
473        host: Option<&str>,
474        queue_maxsize: Option<u32>,
475        local: crate::util::ConnectInfo,
476    ) -> Result<(), Error> {
477        use std::io::Error as IoError;
478
479        use futures_util::{select, FutureExt as _};
480        use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
481
482        #[cfg(unix)]
483        use crate::TokioUnixStream;
484        use crate::{util::ConnectInfo, TokioTcpStream};
485
486        match local {
487            ConnectInfo::Tcp(addr) => {
488                let (mut listener, _remote_port) = self
489                    .channel_forward_listen(remote_port, host, queue_maxsize)
490                    .await?;
491
492                // TODO, tokio::io::copy_bidirectional not working
493
494                loop {
495                    match listener.accept().await {
496                        Ok(mut channel) => {
497                            let join_handle: tokio::task::JoinHandle<Result<(), IoError>> =
498                                tokio::task::spawn(async move {
499                                    let mut stream = TokioTcpStream::connect(addr).await?;
500
501                                    let mut buf_channel = vec![0; 2048];
502                                    let mut buf_stream = vec![0; 2048];
503
504                                    loop {
505                                        select! {
506                                            ret_channel_read = futures_util::AsyncReadExt::read(&mut channel, &mut buf_channel).fuse() => match ret_channel_read {
507                                                Ok(0)  => {
508                                                    break
509                                                },
510                                                Ok(n) => {
511                                                    #[allow(clippy::map_identity)]
512                                                    stream.write(&buf_channel[..n]).await.map(|_| ()).map_err(|err| {
513                                                        // TODO, log
514                                                        err
515                                                    })?
516                                                },
517                                                Err(err) =>  {
518                                                    return Err(err);
519                                                }
520                                            },
521                                            ret_stream_read = stream.read(&mut buf_stream).fuse() => match ret_stream_read {
522                                                Ok(0)  => {
523                                                    break
524                                                },
525                                                Ok(n) => {
526                                                    #[allow(clippy::map_identity)]
527                                                    futures_util::AsyncWriteExt::write(&mut channel,&buf_stream[..n]).await.map(|_| ()).map_err(|err| {
528                                                        // TODO, log
529                                                        err
530                                                    })?
531                                                },
532                                                Err(err) => {
533                                                    return Err(err);
534                                                }
535                                            },
536                                        }
537                                    }
538
539                                    Result::<_, IoError>::Ok(())
540                                });
541                            match join_handle.await {
542                                Ok(_) => {}
543                                Err(err) => {
544                                    eprintln!("join_handle failed, err:{err:?}");
545                                }
546                            }
547                        }
548                        Err(err) => {
549                            eprintln!("listener.accept failed, err:{err:?}");
550                        }
551                    }
552                }
553            }
554            #[cfg(unix)]
555            ConnectInfo::Unix(path) => {
556                let (mut listener, _remote_port) = self
557                    .channel_forward_listen(remote_port, host, queue_maxsize)
558                    .await?;
559
560                // TODO, tokio::io::copy_bidirectional not working
561
562                loop {
563                    match listener.accept().await {
564                        Ok(mut channel) => {
565                            let path = path.clone();
566                            let join_handle: tokio::task::JoinHandle<Result<(), IoError>> =
567                                tokio::task::spawn(async move {
568                                    let mut stream = TokioUnixStream::connect(path).await?;
569
570                                    let mut buf_channel = vec![0; 2048];
571                                    let mut buf_stream = vec![0; 2048];
572
573                                    loop {
574                                        select! {
575                                            ret_channel_read = futures_util::AsyncReadExt::read(&mut channel, &mut buf_channel).fuse() => match ret_channel_read {
576                                                Ok(0)  => {
577                                                    break
578                                                },
579                                                Ok(n) => {
580                                                    #[allow(clippy::map_identity)]
581                                                    stream.write(&buf_channel[..n]).await.map(|_| ()).map_err(|err| {
582                                                        // TODO, log
583                                                        err
584                                                    })?
585                                                },
586                                                Err(err) =>  {
587                                                    return Err(err);
588                                                }
589                                            },
590                                            ret_stream_read = stream.read(&mut buf_stream).fuse() => match ret_stream_read {
591                                                Ok(0)  => {
592                                                    break
593                                                },
594                                                Ok(n) => {
595                                                    #[allow(clippy::map_identity)]
596                                                    futures_util::AsyncWriteExt::write(&mut channel,&buf_stream[..n]).await.map(|_| ()).map_err(|err| {
597                                                        // TODO, log
598                                                        err
599                                                    })?
600                                                },
601                                                Err(err) => {
602                                                    return Err(err);
603                                                }
604                                            },
605                                        }
606                                    }
607
608                                    Result::<_, IoError>::Ok(())
609                                });
610                            match join_handle.await {
611                                Ok(_) => {}
612                                Err(err) => {
613                                    eprintln!("join_handle failed, err:{err:?}");
614                                }
615                            }
616                        }
617                        Err(err) => {
618                            eprintln!("listener.accept failed, err:{err:?}");
619                        }
620                    }
621                }
622            }
623        }
624    }
625}
626
627//
628// extension
629//
630impl<S> AsyncSession<S> {
631    pub fn last_error(&self) -> Option<Ssh2Error> {
632        Ssh2Error::last_session_error(&self.inner)
633    }
634}
635
636impl<S> AsyncSession<S>
637where
638    S: AsyncSessionStream + Send + Sync + 'static,
639{
640    pub async fn userauth_agent_with_try_next(&self, username: &str) -> Result<(), Error> {
641        self.userauth_agent_with_try_next_with_callback(username, |identities| identities)
642            .await
643    }
644
645    pub async fn userauth_agent_with_try_next_with_callback<CB>(
646        &self,
647        username: &str,
648        mut cb: CB,
649    ) -> Result<(), Error>
650    where
651        CB: FnMut(Vec<PublicKey>) -> Vec<PublicKey>,
652    {
653        let mut agent = self.agent()?;
654        agent.connect().await?;
655        agent.list_identities().await?;
656        let identities = agent.identities()?;
657
658        if identities.is_empty() {
659            return Err(Error::Other("no identities found in the ssh agent".into()));
660        }
661
662        let identities = cb(identities);
663
664        for identity in identities {
665            match agent.userauth(username, &identity).await {
666                Ok(_) => {
667                    if self.authenticated() {
668                        return Ok(());
669                    }
670                }
671                Err(_err) => {
672                    continue;
673                }
674            }
675        }
676
677        Err(Error::Other("all identities cannot authenticated".into()))
678    }
679}
680
681//
682//
683//
684#[derive(Debug, Clone, Default)]
685pub struct SessionConfiguration {
686    banner: Option<String>,
687    allow_sigpipe: Option<bool>,
688    compress: Option<bool>,
689    timeout: Option<Duration>,
690    keepalive: Option<SessionKeepaliveConfiguration>,
691}
692impl SessionConfiguration {
693    pub fn new() -> Self {
694        Default::default()
695    }
696
697    pub fn set_banner(&mut self, banner: &str) {
698        self.banner = Some(banner.to_owned());
699    }
700
701    pub fn set_allow_sigpipe(&mut self, block: bool) {
702        self.allow_sigpipe = Some(block);
703    }
704
705    pub fn set_compress(&mut self, compress: bool) {
706        self.compress = Some(compress);
707    }
708
709    pub fn set_timeout(&mut self, timeout_ms: u32) {
710        self.timeout = Some(Duration::from_millis(timeout_ms as u64));
711    }
712
713    pub fn set_keepalive(&mut self, want_reply: bool, interval: u32) {
714        self.keepalive = Some(SessionKeepaliveConfiguration {
715            want_reply,
716            interval,
717        });
718    }
719}
720
721#[derive(Debug, Clone)]
722struct SessionKeepaliveConfiguration {
723    want_reply: bool,
724    interval: u32,
725}
726
727pub(crate) fn get_session(
728    configuration: impl Into<Option<SessionConfiguration>>,
729) -> Result<Session, Error> {
730    let session = Session::new()?;
731    session.set_blocking(false);
732
733    if let Some(configuration) = configuration.into() {
734        if let Some(banner) = configuration.banner {
735            session.set_banner(banner.as_ref())?;
736        }
737        if let Some(allow_sigpipe) = configuration.allow_sigpipe {
738            session.set_allow_sigpipe(allow_sigpipe);
739        }
740        if let Some(compress) = configuration.compress {
741            session.set_compress(compress);
742        }
743        if let Some(timeout) = configuration.timeout {
744            session.set_timeout(timeout.as_millis() as u32);
745        }
746        if let Some(keepalive) = configuration.keepalive {
747            session.set_keepalive(keepalive.want_reply, keepalive.interval);
748        }
749    }
750
751    Ok(session)
752}