async_ssh2_tokio/
client.rs

1use russh::client::KeyboardInteractiveAuthResponse;
2use russh::{
3    Channel,
4    client::{Config, Handle, Handler, Msg},
5};
6use russh_sftp::{client::SftpSession, protocol::OpenFlags};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::{fmt::Debug, path::Path};
10use std::{io, path::PathBuf};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12
13use crate::ToSocketAddrsWithHostname;
14
15/// An authentification token.
16///
17/// Used when creating a [`Client`] for authentification.
18/// Supports password, private key, public key, SSH agent, and keyboard interactive authentication.
19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
20#[non_exhaustive]
21pub enum AuthMethod {
22    Password(String),
23    PrivateKey {
24        /// entire contents of private key file
25        key_data: String,
26        key_pass: Option<String>,
27    },
28    PrivateKeyFile {
29        key_file_path: PathBuf,
30        key_pass: Option<String>,
31    },
32    #[cfg(not(target_os = "windows"))]
33    PublicKeyFile {
34        key_file_path: PathBuf,
35    },
36    #[cfg(not(target_os = "windows"))]
37    Agent,
38    KeyboardInteractive(AuthKeyboardInteractive),
39}
40
/// One configured answer for a keyboard-interactive prompt.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PromptResponse {
    /// `true`: the received prompt must equal `prompt` exactly;
    /// `false`: the received prompt only has to contain `prompt`.
    exact: bool,
    /// Text matched against the prompt sent by the server.
    prompt: String,
    /// Answer sent back to the server when this entry matches.
    response: String,
}
47
48#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
49#[non_exhaustive]
50pub struct AuthKeyboardInteractive {
51    /// Hnts to the server the preferred methods to be used for authentication.
52    submethods: Option<String>,
53    responses: Vec<PromptResponse>,
54}
55
/// How the server's host key is verified while connecting.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ServerCheckMethod {
    /// Accept any host key without verification.
    NoCheck,
    /// base64 encoded key without the type prefix or hostname suffix (type is already encoded)
    PublicKey(String),
    /// Path to a public key file; the server key must equal the key it contains.
    PublicKeyFile(String),
    /// Check against the default known-hosts file used by `russh::keys::check_known_hosts`.
    DefaultKnownHostsFile,
    /// Path to a known-hosts file to check the server key against.
    KnownHostsFile(String),
}
66
67impl AuthMethod {
68    /// Convenience method to create a [`AuthMethod`] from a string literal.
69    pub fn with_password(password: &str) -> Self {
70        Self::Password(password.to_string())
71    }
72
73    pub fn with_key(key: &str, passphrase: Option<&str>) -> Self {
74        Self::PrivateKey {
75            key_data: key.to_string(),
76            key_pass: passphrase.map(str::to_string),
77        }
78    }
79
80    pub fn with_key_file<T: AsRef<Path>>(key_file_path: T, passphrase: Option<&str>) -> Self {
81        Self::PrivateKeyFile {
82            key_file_path: key_file_path.as_ref().to_path_buf(),
83            key_pass: passphrase.map(str::to_string),
84        }
85    }
86
87    #[cfg(not(target_os = "windows"))]
88    pub fn with_public_key_file<T: AsRef<Path>>(key_file_path: T) -> Self {
89        Self::PublicKeyFile {
90            key_file_path: key_file_path.as_ref().to_path_buf(),
91        }
92    }
93
94    /// Creates a new SSH agent authentication method.
95    ///
96    /// This will attempt to authenticate using all identities available in the SSH agent.
97    /// The SSH agent must be running and the SSH_AUTH_SOCK environment variable must be set.
98    ///
99    /// # Example
100    /// ```no_run
101    /// use async_ssh2_tokio::client::AuthMethod;
102    ///
103    /// let auth = AuthMethod::with_agent();
104    /// ```
105    ///
106    /// # Platform Support
107    /// This method is only available on Unix-like systems (Linux, macOS, etc.).
108    /// It is not available on Windows.
109    #[cfg(not(target_os = "windows"))]
110    pub fn with_agent() -> Self {
111        Self::Agent
112    }
113
114    pub const fn with_keyboard_interactive(auth: AuthKeyboardInteractive) -> Self {
115        Self::KeyboardInteractive(auth)
116    }
117}
118
119impl AuthKeyboardInteractive {
120    pub fn new() -> Self {
121        Default::default()
122    }
123
124    /// Hnts to the server the preferred methods to be used for authentication.
125    pub fn with_submethods(mut self, submethods: impl Into<String>) -> Self {
126        self.submethods = Some(submethods.into());
127        self
128    }
129
130    /// Adds a response to the list of responses for a given prompt.
131    ///
132    /// The comparison for the prompt is done using a "contains".
133    pub fn with_response(mut self, prompt: impl Into<String>, response: impl Into<String>) -> Self {
134        self.responses.push(PromptResponse {
135            exact: false,
136            prompt: prompt.into(),
137            response: response.into(),
138        });
139
140        self
141    }
142
143    /// Adds a response to the list of responses for a given exact prompt.
144    pub fn with_response_exact(
145        mut self,
146        prompt: impl Into<String>,
147        response: impl Into<String>,
148    ) -> Self {
149        self.responses.push(PromptResponse {
150            exact: true,
151            prompt: prompt.into(),
152            response: response.into(),
153        });
154
155        self
156    }
157}
158
159impl PromptResponse {
160    fn matches(&self, received_prompt: &str) -> bool {
161        if self.exact {
162            self.prompt.eq(received_prompt)
163        } else {
164            received_prompt.contains(&self.prompt)
165        }
166    }
167}
168
169impl From<AuthKeyboardInteractive> for AuthMethod {
170    fn from(value: AuthKeyboardInteractive) -> Self {
171        Self::with_keyboard_interactive(value)
172    }
173}
174
175impl ServerCheckMethod {
176    /// Convenience method to create a [`ServerCheckMethod`] from a string literal.
177    pub fn with_public_key(key: &str) -> Self {
178        Self::PublicKey(key.to_string())
179    }
180
181    /// Convenience method to create a [`ServerCheckMethod`] from a string literal.
182    pub fn with_public_key_file(key_file_name: &str) -> Self {
183        Self::PublicKeyFile(key_file_name.to_string())
184    }
185
186    /// Convenience method to create a [`ServerCheckMethod`] from a string literal.
187    pub fn with_known_hosts_file(known_hosts_file: &str) -> Self {
188        Self::KnownHostsFile(known_hosts_file.to_string())
189    }
190}
191
192/// A ssh connection to a remote server.
193///
194/// After creating a `Client` by [`connect`]ing to a remote host,
195/// use [`execute`] to send commands and receive results through the connections.
196///
197/// [`connect`]: Client::connect
198/// [`execute`]: Client::execute
199///
200/// # Examples
201///
202/// ```no_run
203/// use async_ssh2_tokio::{Client, AuthMethod, ServerCheckMethod};
204/// #[tokio::main]
205/// async fn main() -> Result<(), async_ssh2_tokio::Error> {
206///     let mut client = Client::connect(
207///         ("10.10.10.2", 22),
208///         "root",
209///         AuthMethod::with_password("root"),
210///         ServerCheckMethod::NoCheck,
211///     ).await?;
212///
213///     let result = client.execute("echo Hello SSH").await?;
214///     assert_eq!(result.stdout, "Hello SSH\n");
215///     assert_eq!(result.exit_status, 0);
216///
217///     Ok(())
218/// }
219#[derive(Clone)]
220pub struct Client {
221    connection_handle: Arc<Handle<ClientHandler>>,
222    username: String,
223    address: SocketAddr,
224}
225
226impl Client {
227    /// Open a ssh connection to a remote host.
228    ///
229    /// `addr` is an address of the remote host. Anything which implements
230    /// [`ToSocketAddrsWithHostname`] trait can be supplied for the address;
231    /// ToSocketAddrsWithHostname reimplements all of [`ToSocketAddrs`];
232    /// see this trait's documentation for concrete examples.
233    ///
234    /// If `addr` yields multiple addresses, `connect` will be attempted with
235    /// each of the addresses until a connection is successful.
236    /// Authentification is tried on the first successful connection and the whole
237    /// process aborted if this fails.
238    pub async fn connect(
239        addr: impl ToSocketAddrsWithHostname,
240        username: &str,
241        auth: AuthMethod,
242        server_check: ServerCheckMethod,
243    ) -> Result<Self, crate::Error> {
244        Self::connect_with_config(addr, username, auth, server_check, Config::default()).await
245    }
246
247    /// Same as `connect`, but with the option to specify a non default
248    /// [`russh::client::Config`].
249    pub async fn connect_with_config(
250        addr: impl ToSocketAddrsWithHostname,
251        username: &str,
252        auth: AuthMethod,
253        server_check: ServerCheckMethod,
254        config: Config,
255    ) -> Result<Self, crate::Error> {
256        let config = Arc::new(config);
257
258        // Connection code inspired from std::net::TcpStream::connect and std::net::each_addr
259        let socket_addrs = addr
260            .to_socket_addrs()
261            .map_err(crate::Error::AddressInvalid)?;
262        let mut connect_res = Err(crate::Error::AddressInvalid(io::Error::new(
263            io::ErrorKind::InvalidInput,
264            "could not resolve to any addresses",
265        )));
266        for socket_addr in socket_addrs {
267            let handler = ClientHandler {
268                hostname: addr.hostname(),
269                host: socket_addr,
270                server_check: server_check.clone(),
271            };
272            match russh::client::connect(config.clone(), socket_addr, handler).await {
273                Ok(h) => {
274                    connect_res = Ok((socket_addr, h));
275                    break;
276                }
277                Err(e) => connect_res = Err(e),
278            }
279        }
280        let (address, mut handle) = connect_res?;
281        let username = username.to_string();
282
283        Self::authenticate(&mut handle, &username, auth).await?;
284
285        Ok(Self {
286            connection_handle: Arc::new(handle),
287            username,
288            address,
289        })
290    }
291
292    /// This takes a handle and performs authentification with the given method.
293    async fn authenticate(
294        handle: &mut Handle<ClientHandler>,
295        username: &String,
296        auth: AuthMethod,
297    ) -> Result<(), crate::Error> {
298        match auth {
299            AuthMethod::Password(password) => {
300                let is_authentificated = handle.authenticate_password(username, password).await?;
301                if !is_authentificated.success() {
302                    return Err(crate::Error::PasswordWrong);
303                }
304            }
305            AuthMethod::PrivateKey { key_data, key_pass } => {
306                let cprivk = russh::keys::decode_secret_key(key_data.as_str(), key_pass.as_deref())
307                    .map_err(crate::Error::KeyInvalid)?;
308                let is_authentificated = handle
309                    .authenticate_publickey(
310                        username,
311                        russh::keys::PrivateKeyWithHashAlg::new(
312                            Arc::new(cprivk),
313                            handle.best_supported_rsa_hash().await?.flatten(),
314                        ),
315                    )
316                    .await?;
317                if !is_authentificated.success() {
318                    return Err(crate::Error::KeyAuthFailed);
319                }
320            }
321            AuthMethod::PrivateKeyFile {
322                key_file_path,
323                key_pass,
324            } => {
325                let cprivk = russh::keys::load_secret_key(key_file_path, key_pass.as_deref())
326                    .map_err(crate::Error::KeyInvalid)?;
327                let is_authentificated = handle
328                    .authenticate_publickey(
329                        username,
330                        russh::keys::PrivateKeyWithHashAlg::new(
331                            Arc::new(cprivk),
332                            handle.best_supported_rsa_hash().await?.flatten(),
333                        ),
334                    )
335                    .await?;
336                if !is_authentificated.success() {
337                    return Err(crate::Error::KeyAuthFailed);
338                }
339            }
340            #[cfg(not(target_os = "windows"))]
341            AuthMethod::PublicKeyFile { key_file_path } => {
342                let cpubk = russh::keys::load_public_key(key_file_path)
343                    .map_err(crate::Error::KeyInvalid)?;
344                let mut agent = russh::keys::agent::client::AgentClient::connect_env()
345                    .await
346                    .unwrap();
347                let mut auth_identity: Option<russh::keys::PublicKey> = None;
348                for identity in agent
349                    .request_identities()
350                    .await
351                    .map_err(crate::Error::KeyInvalid)?
352                {
353                    if identity == cpubk {
354                        auth_identity = Some(identity.clone());
355                        break;
356                    }
357                }
358
359                if auth_identity.is_none() {
360                    return Err(crate::Error::KeyAuthFailed);
361                }
362
363                let is_authentificated = handle
364                    .authenticate_publickey_with(
365                        username,
366                        cpubk,
367                        handle.best_supported_rsa_hash().await?.flatten(),
368                        &mut agent,
369                    )
370                    .await?;
371                if !is_authentificated.success() {
372                    return Err(crate::Error::KeyAuthFailed);
373                }
374            }
375            #[cfg(not(target_os = "windows"))]
376            AuthMethod::Agent => {
377                let mut agent = russh::keys::agent::client::AgentClient::connect_env()
378                    .await
379                    .map_err(|_| crate::Error::AgentConnectionFailed)?;
380
381                let identities = agent
382                    .request_identities()
383                    .await
384                    .map_err(|_| crate::Error::AgentRequestIdentitiesFailed)?;
385
386                if identities.is_empty() {
387                    return Err(crate::Error::AgentNoIdentities);
388                }
389
390                let mut auth_success = false;
391                for identity in identities {
392                    let result = handle
393                        .authenticate_publickey_with(
394                            username,
395                            identity.clone(),
396                            handle.best_supported_rsa_hash().await?.flatten(),
397                            &mut agent,
398                        )
399                        .await;
400
401                    if let Ok(auth_result) = result {
402                        if auth_result.success() {
403                            auth_success = true;
404                            break;
405                        }
406                    }
407                }
408
409                if !auth_success {
410                    return Err(crate::Error::AgentAuthenticationFailed);
411                }
412            }
413            AuthMethod::KeyboardInteractive(mut kbd) => {
414                let mut res = handle
415                    .authenticate_keyboard_interactive_start(username, kbd.submethods)
416                    .await?;
417                loop {
418                    let prompts = match res {
419                        KeyboardInteractiveAuthResponse::Success => break,
420                        KeyboardInteractiveAuthResponse::Failure { .. } => {
421                            return Err(crate::Error::KeyboardInteractiveAuthFailed);
422                        }
423                        KeyboardInteractiveAuthResponse::InfoRequest { prompts, .. } => prompts,
424                    };
425
426                    let mut responses = vec![];
427                    for prompt in prompts {
428                        let Some(pos) = kbd
429                            .responses
430                            .iter()
431                            .position(|pr| pr.matches(&prompt.prompt))
432                        else {
433                            return Err(crate::Error::KeyboardInteractiveNoResponseForPrompt(
434                                prompt.prompt,
435                            ));
436                        };
437                        let pr = kbd.responses.remove(pos);
438                        responses.push(pr.response);
439                    }
440
441                    res = handle
442                        .authenticate_keyboard_interactive_respond(responses)
443                        .await?;
444                }
445            }
446        };
447        Ok(())
448    }
449
450    pub async fn get_channel(&self) -> Result<Channel<Msg>, crate::Error> {
451        self.connection_handle
452            .channel_open_session()
453            .await
454            .map_err(crate::Error::SshError)
455    }
456
457    /// Open a TCP/IP forwarding channel.
458    ///
459    /// This opens a `direct-tcpip` channel to the given target.
460    pub async fn open_direct_tcpip_channel<
461        T: ToSocketAddrsWithHostname,
462        S: Into<Option<SocketAddr>>,
463    >(
464        &self,
465        target: T,
466        src: S,
467    ) -> Result<Channel<Msg>, crate::Error> {
468        let targets = target
469            .to_socket_addrs()
470            .map_err(crate::Error::AddressInvalid)?;
471        let src = src
472            .into()
473            .map(|src| (src.ip().to_string(), src.port().into()))
474            .unwrap_or_else(|| ("127.0.0.1".to_string(), 22));
475
476        let mut connect_err = crate::Error::AddressInvalid(io::Error::new(
477            io::ErrorKind::InvalidInput,
478            "could not resolve to any addresses",
479        ));
480        for target in targets {
481            match self
482                .connection_handle
483                .channel_open_direct_tcpip(
484                    target.ip().to_string(),
485                    target.port().into(),
486                    src.0.clone(),
487                    src.1,
488                )
489                .await
490            {
491                Ok(channel) => return Ok(channel),
492                Err(err) => connect_err = crate::Error::SshError(err),
493            }
494        }
495
496        Err(connect_err)
497    }
498
499    /// Upload a file with sftp to the remote server.
500    ///
501    /// `src_file_path` is the path to the file on the local machine.
502    /// `dest_file_path` is the path to the file on the remote machine.
503    /// Some sshd_config does not enable sftp by default, so make sure it is enabled.
504    /// A config line like a `Subsystem sftp internal-sftp` or
505    /// `Subsystem sftp /usr/lib/openssh/sftp-server` is needed in the sshd_config in remote machine.
506    pub async fn upload_file<T: AsRef<Path>, U: Into<String>>(
507        &self,
508        src_file_path: T,
509        //fa993: This cannot be AsRef<Path> because of underlying lib constraints as described here
510        //https://github.com/AspectUnk/russh-sftp/issues/7#issuecomment-1738355245
511        dest_file_path: U,
512    ) -> Result<(), crate::Error> {
513        // start sftp session
514        let channel = self.get_channel().await?;
515        channel.request_subsystem(true, "sftp").await?;
516        let sftp = SftpSession::new(channel.into_stream()).await?;
517
518        // read file contents locally
519        let file_contents = tokio::fs::read(src_file_path)
520            .await
521            .map_err(crate::Error::IoError)?;
522
523        // interaction with i/o
524        let mut file = sftp
525            .open_with_flags(
526                dest_file_path,
527                OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ,
528            )
529            .await?;
530        file.write_all(&file_contents)
531            .await
532            .map_err(crate::Error::IoError)?;
533        file.flush().await.map_err(crate::Error::IoError)?;
534        file.shutdown().await.map_err(crate::Error::IoError)?;
535
536        Ok(())
537    }
538
539    /// Download a file from the remote server using sftp.
540    ///
541    /// `remote_file_path` is the path to the file on the remote machine.
542    /// `local_file_path` is the path to the file on the local machine.
543    /// Some sshd_config does not enable sftp by default, so make sure it is enabled.
544    /// A config line like a `Subsystem sftp internal-sftp` or
545    /// `Subsystem sftp /usr/lib/openssh/sftp-server` is needed in the sshd_config in remote machine.
546    pub async fn download_file<T: AsRef<Path>, U: Into<String>>(
547        &self,
548        remote_file_path: U,
549        local_file_path: T,
550    ) -> Result<(), crate::Error> {
551        // start sftp session
552        let channel = self.get_channel().await?;
553        channel.request_subsystem(true, "sftp").await?;
554        let sftp = SftpSession::new(channel.into_stream()).await?;
555
556        // open remote file for reading
557        let mut remote_file = sftp
558            .open_with_flags(remote_file_path, OpenFlags::READ)
559            .await?;
560
561        // read remote file contents
562        let mut contents = Vec::new();
563        remote_file.read_to_end(contents.as_mut()).await?;
564
565        // write contents to local file
566        let mut local_file = tokio::fs::File::create(local_file_path.as_ref())
567            .await
568            .map_err(crate::Error::IoError)?;
569
570        local_file
571            .write_all(&contents)
572            .await
573            .map_err(crate::Error::IoError)?;
574        local_file.flush().await.map_err(crate::Error::IoError)?;
575
576        Ok(())
577    }
578
579    /// Execute a remote command via the ssh connection.
580    ///
581    /// Returns stdout, stderr and the exit code of the command,
582    /// packaged in a [`CommandExecutedResult`] struct.
583    /// If you need the stderr output interleaved within stdout, you should postfix the command with a redirection,
584    /// e.g. `echo foo 2>&1`.
585    /// If you dont want any output at all, use something like `echo foo >/dev/null 2>&1`.
586    ///
587    /// Make sure your commands don't read from stdin and exit after bounded time.
588    ///
589    /// Can be called multiple times, but every invocation is a new shell context.
590    /// Thus `cd`, setting variables and alike have no effect on future invocations.
591    pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, crate::Error> {
592        let mut stdout_buffer = vec![];
593        let mut stderr_buffer = vec![];
594        let mut channel = self.connection_handle.channel_open_session().await?;
595        channel.exec(true, command).await?;
596
597        let mut result: Option<u32> = None;
598
599        // While the channel has messages...
600        while let Some(msg) = channel.wait().await {
601            //dbg!(&msg);
602            match msg {
603                // If we get data, add it to the buffer
604                russh::ChannelMsg::Data { ref data } => {
605                    stdout_buffer.write_all(data).await.unwrap()
606                }
607                russh::ChannelMsg::ExtendedData { ref data, ext } => {
608                    if ext == 1 {
609                        stderr_buffer.write_all(data).await.unwrap()
610                    }
611                }
612
613                // If we get an exit code report, store it, but crucially don't
614                // assume this message means end of communications. The data might
615                // not be finished yet!
616                russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
617
618                // We SHOULD get this EOF messagge, but 4254 sec 5.3 also permits
619                // the channel to close without it being sent. And sometimes this
620                // message can even precede the Data message, so don't handle it
621                // russh::ChannelMsg::Eof => break,
622                _ => {}
623            }
624        }
625
626        // If we received an exit code, report it back
627        if let Some(result) = result {
628            Ok(CommandExecutedResult {
629                stdout: String::from_utf8_lossy(&stdout_buffer).to_string(),
630                stderr: String::from_utf8_lossy(&stderr_buffer).to_string(),
631                exit_status: result,
632            })
633
634        // Otherwise, report an error
635        } else {
636            Err(crate::Error::CommandDidntExit)
637        }
638    }
639
640    /// A debugging function to get the username this client is connected as.
641    pub fn get_connection_username(&self) -> &String {
642        &self.username
643    }
644
645    /// A debugging function to get the address this client is connected to.
646    pub fn get_connection_address(&self) -> &SocketAddr {
647        &self.address
648    }
649
650    pub async fn disconnect(&self) -> Result<(), crate::Error> {
651        self.connection_handle
652            .disconnect(russh::Disconnect::ByApplication, "", "")
653            .await
654            .map_err(crate::Error::SshError)
655    }
656
657    pub fn is_closed(&self) -> bool {
658        self.connection_handle.is_closed()
659    }
660}
661
662impl Debug for Client {
663    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
664        f.debug_struct("Client")
665            .field("username", &self.username)
666            .field("address", &self.address)
667            .field("connection_handle", &"Handle<ClientHandler>")
668            .finish()
669    }
670}
671
/// Output of a remote command run with [`Client::execute`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CommandExecutedResult {
    /// The stdout output of the command (decoded lossily as UTF-8).
    pub stdout: String,
    /// The stderr output of the command (decoded lossily as UTF-8).
    pub stderr: String,
    /// The unix exit status (`$?` in bash).
    pub exit_status: u32,
}
681
/// russh client [`Handler`] used by [`Client`]; it verifies the server's host key
/// according to `server_check`.
#[derive(Debug, Clone)]
struct ClientHandler {
    /// Host name used for known-hosts lookups.
    hostname: String,
    /// Address being connected to; its port is used for known-hosts lookups.
    host: SocketAddr,
    /// Host key verification strategy.
    server_check: ServerCheckMethod,
}
688
689impl Handler for ClientHandler {
690    type Error = crate::Error;
691
692    async fn check_server_key(
693        &mut self,
694        server_public_key: &russh::keys::PublicKey,
695    ) -> Result<bool, Self::Error> {
696        match &self.server_check {
697            ServerCheckMethod::NoCheck => Ok(true),
698            ServerCheckMethod::PublicKey(key) => {
699                let pk = russh::keys::parse_public_key_base64(key)
700                    .map_err(|_| crate::Error::ServerCheckFailed)?;
701
702                Ok(pk == *server_public_key)
703            }
704            ServerCheckMethod::PublicKeyFile(key_file_name) => {
705                let pk = russh::keys::load_public_key(key_file_name)
706                    .map_err(|_| crate::Error::ServerCheckFailed)?;
707
708                Ok(pk == *server_public_key)
709            }
710            ServerCheckMethod::KnownHostsFile(known_hosts_path) => {
711                let result = russh::keys::check_known_hosts_path(
712                    &self.hostname,
713                    self.host.port(),
714                    server_public_key,
715                    known_hosts_path,
716                )
717                .map_err(|_| crate::Error::ServerCheckFailed)?;
718
719                Ok(result)
720            }
721            ServerCheckMethod::DefaultKnownHostsFile => {
722                let result = russh::keys::check_known_hosts(
723                    &self.hostname,
724                    self.host.port(),
725                    server_public_key,
726                )
727                .map_err(|_| crate::Error::ServerCheckFailed)?;
728
729                Ok(result)
730            }
731        }
732    }
733}
734
735#[cfg(test)]
736mod tests {
737    use crate::client::*;
738    use core::time;
739    use dotenv::dotenv;
740    use std::path::Path;
741    use std::sync::Once;
742    use tokio::io::AsyncReadExt;
743    static INIT: Once = Once::new();
744
745    fn initialize() {
746        // Perform your initialization tasks here
747        println!("Running initialization code before tests...");
748        // Example: load .env file if we are using non-docker environment
749        if is_running_in_docker() {
750            println!("Running inside Docker.");
751        } else {
752            println!("Not running inside Docker. Load env from file");
753            dotenv().ok();
754        }
755    }
756    fn is_running_in_docker() -> bool {
757        Path::new("/.dockerenv").exists() || check_cgroup()
758    }
759
760    fn check_cgroup() -> bool {
761        match std::fs::read_to_string("/proc/1/cgroup") {
762            Ok(contents) => contents.contains("docker"),
763            Err(_) => false,
764        }
765    }
766
767    fn env(name: &str) -> String {
768        INIT.call_once(|| {
769            initialize();
770        });
771        std::env::var(name).unwrap_or_else(|_| {
772            panic!(
773                "Failed to get env var needed for test, make sure to set the following env var: {name}",
774            )
775        })
776    }
777
778    fn test_address() -> SocketAddr {
779        format!(
780            "{}:{}",
781            env("ASYNC_SSH2_TEST_HOST_IP"),
782            env("ASYNC_SSH2_TEST_HOST_PORT")
783        )
784        .parse()
785        .unwrap()
786    }
787
788    fn test_hostname() -> impl ToSocketAddrsWithHostname {
789        (
790            env("ASYNC_SSH2_TEST_HOST_NAME"),
791            env("ASYNC_SSH2_TEST_HOST_PORT").parse().unwrap(),
792        )
793    }
794
795    async fn establish_test_host_connection() -> Client {
796        Client::connect(
797            (
798                env("ASYNC_SSH2_TEST_HOST_IP"),
799                env("ASYNC_SSH2_TEST_HOST_PORT").parse().unwrap(),
800            ),
801            &env("ASYNC_SSH2_TEST_HOST_USER"),
802            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
803            ServerCheckMethod::NoCheck,
804        )
805        .await
806        .expect("Connection/Authentification failed")
807    }
808
809    #[tokio::test]
810    async fn connect_with_password() {
811        let client = establish_test_host_connection().await;
812        assert_eq!(
813            &env("ASYNC_SSH2_TEST_HOST_USER"),
814            client.get_connection_username(),
815        );
816        assert_eq!(test_address(), *client.get_connection_address(),);
817    }
818
819    #[tokio::test]
820    async fn execute_command_result() {
821        let client = establish_test_host_connection().await;
822        let output = client.execute("echo test!!!").await.unwrap();
823        assert_eq!("test!!!\n", output.stdout);
824        assert_eq!("", output.stderr);
825        assert_eq!(0, output.exit_status);
826    }
827
828    #[tokio::test]
829    async fn execute_command_result_stderr() {
830        let client = establish_test_host_connection().await;
831        let output = client.execute("echo test!!! 1>&2").await.unwrap();
832        assert_eq!("", output.stdout);
833        assert_eq!("test!!!\n", output.stderr);
834        assert_eq!(0, output.exit_status);
835    }
836
837    #[tokio::test]
838    async fn unicode_output() {
839        let client = establish_test_host_connection().await;
840        let output = client.execute("echo To thḙ moon! 🚀").await.unwrap();
841        assert_eq!("To thḙ moon! 🚀\n", output.stdout);
842        assert_eq!(0, output.exit_status);
843    }
844
845    #[tokio::test]
846    async fn execute_command_status() {
847        let client = establish_test_host_connection().await;
848        let output = client.execute("exit 42").await.unwrap();
849        assert_eq!(42, output.exit_status);
850    }
851
852    #[tokio::test]
853    async fn execute_multiple_commands() {
854        let client = establish_test_host_connection().await;
855        let output = client.execute("echo test!!!").await.unwrap().stdout;
856        assert_eq!("test!!!\n", output);
857
858        let output = client.execute("echo Hello World").await.unwrap().stdout;
859        assert_eq!("Hello World\n", output);
860    }
861
862    #[tokio::test]
863    async fn direct_tcpip_channel() {
864        let client = establish_test_host_connection().await;
865        let channel = client
866            .open_direct_tcpip_channel(
867                format!(
868                    "{}:{}",
869                    env("ASYNC_SSH2_TEST_HTTP_SERVER_IP"),
870                    env("ASYNC_SSH2_TEST_HTTP_SERVER_PORT"),
871                ),
872                None,
873            )
874            .await
875            .unwrap();
876
877        let mut stream = channel.into_stream();
878        stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await.unwrap();
879
880        let mut response = String::new();
881        stream.read_to_string(&mut response).await.unwrap();
882
883        let body = response.split_once("\r\n\r\n").unwrap().1;
884        assert_eq!("Hello", body);
885    }
886
887    #[tokio::test]
888    async fn stderr_redirection() {
889        let client = establish_test_host_connection().await;
890
891        let output = client.execute("echo foo >/dev/null").await.unwrap();
892        assert_eq!("", output.stdout);
893
894        let output = client.execute("echo foo >>/dev/stderr").await.unwrap();
895        assert_eq!("", output.stdout);
896
897        let output = client.execute("2>&1 echo foo >>/dev/stderr").await.unwrap();
898        assert_eq!("foo\n", output.stdout);
899    }
900
901    #[tokio::test]
902    async fn sequential_commands() {
903        let client = establish_test_host_connection().await;
904
905        for i in 0..100 {
906            std::thread::sleep(time::Duration::from_millis(100));
907            let res = client
908                .execute(&format!("echo {i}"))
909                .await
910                .unwrap_or_else(|_| panic!("Execution failed in iteration {i}"));
911            assert_eq!(format!("{i}\n"), res.stdout);
912        }
913    }
914
915    #[tokio::test]
916    async fn execute_multiple_context() {
917        // This is maybe not expected behaviour, thus documenting this via a test is important.
918        let client = establish_test_host_connection().await;
919        let output = client
920            .execute("export VARIABLE=42; echo $VARIABLE")
921            .await
922            .unwrap()
923            .stdout;
924        assert_eq!("42\n", output);
925
926        let output = client.execute("echo $VARIABLE").await.unwrap().stdout;
927        assert_eq!("\n", output);
928    }
929
930    #[tokio::test]
931    async fn connect_second_address() {
932        let client = Client::connect(
933            &[SocketAddr::from(([127, 0, 0, 1], 23)), test_address()][..],
934            &env("ASYNC_SSH2_TEST_HOST_USER"),
935            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
936            ServerCheckMethod::NoCheck,
937        )
938        .await
939        .expect("Resolution to second address failed");
940
941        assert_eq!(test_address(), *client.get_connection_address(),);
942    }
943
944    #[tokio::test]
945    async fn connect_with_wrong_password() {
946        let error = Client::connect(
947            test_address(),
948            &env("ASYNC_SSH2_TEST_HOST_USER"),
949            AuthMethod::with_password("hopefully the wrong password"),
950            ServerCheckMethod::NoCheck,
951        )
952        .await
953        .expect_err("Client connected with wrong password");
954
955        match error {
956            crate::Error::PasswordWrong => {}
957            _ => panic!("Wrong error type"),
958        }
959    }
960
961    #[tokio::test]
962    async fn invalid_address() {
963        let no_client = Client::connect(
964            "this is definitely not an address",
965            &env("ASYNC_SSH2_TEST_HOST_USER"),
966            AuthMethod::with_password("hopefully the wrong password"),
967            ServerCheckMethod::NoCheck,
968        )
969        .await;
970        assert!(no_client.is_err());
971    }
972
973    #[tokio::test]
974    async fn connect_to_wrong_port() {
975        let no_client = Client::connect(
976            (env("ASYNC_SSH2_TEST_HOST_IP"), 23),
977            &env("ASYNC_SSH2_TEST_HOST_USER"),
978            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
979            ServerCheckMethod::NoCheck,
980        )
981        .await;
982        assert!(no_client.is_err());
983    }
984
985    #[tokio::test]
986    #[ignore = "This times out only after 20 seconds"]
987    async fn connect_to_wrong_host() {
988        let no_client = Client::connect(
989            "172.16.0.6:22",
990            "xxx",
991            AuthMethod::with_password("xxx"),
992            ServerCheckMethod::NoCheck,
993        )
994        .await;
995        assert!(no_client.is_err());
996    }
997
998    #[tokio::test]
999    async fn auth_key_file() {
1000        let client = Client::connect(
1001            test_address(),
1002            &env("ASYNC_SSH2_TEST_HOST_USER"),
1003            AuthMethod::with_key_file(env("ASYNC_SSH2_TEST_CLIENT_PRIV"), None),
1004            ServerCheckMethod::NoCheck,
1005        )
1006        .await;
1007        assert!(client.is_ok());
1008    }
1009
1010    #[tokio::test]
1011    #[cfg(not(target_os = "windows"))]
1012    async fn auth_with_agent() {
1013        // This test requires SSH agent to be running with the test key loaded
1014        // In Docker environment, the agent is always properly configured
1015        let client = Client::connect(
1016            test_address(),
1017            &env("ASYNC_SSH2_TEST_HOST_USER"),
1018            AuthMethod::with_agent(),
1019            ServerCheckMethod::NoCheck,
1020        )
1021        .await
1022        .expect("Agent authentication should succeed with correct key loaded");
1023
1024        // Verify we can execute a command
1025        let output = client.execute("echo test").await.unwrap();
1026        assert_eq!("test\n", output.stdout);
1027    }
1028
1029    #[tokio::test]
1030    #[cfg(not(target_os = "windows"))]
1031    async fn auth_with_agent_wrong_user() {
1032        // This test verifies that agent auth fails with wrong username
1033        let result = Client::connect(
1034            test_address(),
1035            "wrong_user_that_does_not_exist",
1036            AuthMethod::with_agent(),
1037            ServerCheckMethod::NoCheck,
1038        )
1039        .await;
1040
1041        // Should fail with authentication error
1042        assert!(matches!(
1043            result,
1044            Err(crate::Error::AgentAuthenticationFailed)
1045        ));
1046    }
1047
1048    #[tokio::test]
1049    #[cfg(not(target_os = "windows"))]
1050    async fn auth_with_agent_no_sock() {
1051        // Test behavior when SSH_AUTH_SOCK is not set
1052        // Temporarily unset SSH_AUTH_SOCK for this test
1053        let original_sock = std::env::var("SSH_AUTH_SOCK").ok();
1054        unsafe {
1055            std::env::remove_var("SSH_AUTH_SOCK");
1056        }
1057
1058        let result = Client::connect(
1059            test_address(),
1060            &env("ASYNC_SSH2_TEST_HOST_USER"),
1061            AuthMethod::with_agent(),
1062            ServerCheckMethod::NoCheck,
1063        )
1064        .await;
1065
1066        // Restore original SSH_AUTH_SOCK if it was set
1067        if let Some(sock) = original_sock {
1068            unsafe {
1069                std::env::set_var("SSH_AUTH_SOCK", sock);
1070            }
1071        }
1072
1073        // Should fail with connection error
1074        assert!(matches!(result, Err(crate::Error::AgentConnectionFailed)));
1075    }
1076
1077    #[tokio::test]
1078    async fn auth_key_file_with_passphrase() {
1079        let client = Client::connect(
1080            test_address(),
1081            &env("ASYNC_SSH2_TEST_HOST_USER"),
1082            AuthMethod::with_key_file(
1083                env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV"),
1084                Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS")),
1085            ),
1086            ServerCheckMethod::NoCheck,
1087        )
1088        .await;
1089        if client.is_err() {
1090            println!("{:?}", client.err());
1091            panic!();
1092        }
1093        assert!(client.is_ok());
1094    }
1095
1096    #[tokio::test]
1097    async fn auth_key_str() {
1098        let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PRIV")).unwrap();
1099
1100        let client = Client::connect(
1101            test_address(),
1102            &env("ASYNC_SSH2_TEST_HOST_USER"),
1103            AuthMethod::with_key(key.as_str(), None),
1104            ServerCheckMethod::NoCheck,
1105        )
1106        .await;
1107        assert!(client.is_ok());
1108    }
1109
1110    #[tokio::test]
1111    async fn auth_key_str_with_passphrase() {
1112        let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV")).unwrap();
1113
1114        let client = Client::connect(
1115            test_address(),
1116            &env("ASYNC_SSH2_TEST_HOST_USER"),
1117            AuthMethod::with_key(key.as_str(), Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS"))),
1118            ServerCheckMethod::NoCheck,
1119        )
1120        .await;
1121        assert!(client.is_ok());
1122    }
1123
1124    #[tokio::test]
1125    async fn auth_keyboard_interactive() {
1126        let client = Client::connect(
1127            test_address(),
1128            &env("ASYNC_SSH2_TEST_HOST_USER"),
1129            AuthKeyboardInteractive::new()
1130                .with_response("Password", env("ASYNC_SSH2_TEST_HOST_PW"))
1131                .into(),
1132            ServerCheckMethod::NoCheck,
1133        )
1134        .await;
1135        assert!(client.is_ok());
1136    }
1137
1138    #[tokio::test]
1139    async fn auth_keyboard_interactive_exact() {
1140        let client = Client::connect(
1141            test_address(),
1142            &env("ASYNC_SSH2_TEST_HOST_USER"),
1143            AuthKeyboardInteractive::new()
1144                .with_response_exact("Password: ", env("ASYNC_SSH2_TEST_HOST_PW"))
1145                .into(),
1146            ServerCheckMethod::NoCheck,
1147        )
1148        .await;
1149        assert!(client.is_ok());
1150    }
1151
1152    #[tokio::test]
1153    async fn auth_keyboard_interactive_wrong_response() {
1154        let client = Client::connect(
1155            test_address(),
1156            &env("ASYNC_SSH2_TEST_HOST_USER"),
1157            AuthKeyboardInteractive::new()
1158                .with_response_exact("Password: ", "wrong password")
1159                .into(),
1160            ServerCheckMethod::NoCheck,
1161        )
1162        .await;
1163        match client {
1164            Err(crate::error::Error::KeyboardInteractiveAuthFailed) => {}
1165            Err(e) => {
1166                panic!("Expected KeyboardInteractiveAuthFailed error. Got error: {e:?}")
1167            }
1168            Ok(_) => panic!("Expected KeyboardInteractiveAuthFailed error."),
1169        }
1170    }
1171
1172    #[tokio::test]
1173    async fn auth_keyboard_interactive_no_response() {
1174        let client = Client::connect(
1175            test_address(),
1176            &env("ASYNC_SSH2_TEST_HOST_USER"),
1177            AuthKeyboardInteractive::new()
1178                .with_response_exact("Password:", "123")
1179                .into(),
1180            ServerCheckMethod::NoCheck,
1181        )
1182        .await;
1183        match client {
1184            Err(crate::error::Error::KeyboardInteractiveNoResponseForPrompt(prompt)) => {
1185                assert_eq!(prompt, "Password: ");
1186            }
1187            Err(e) => {
1188                panic!("Expected KeyboardInteractiveNoResponseForPrompt error. Got error: {e:?}")
1189            }
1190            Ok(_) => panic!("Expected KeyboardInteractiveNoResponseForPrompt error."),
1191        }
1192    }
1193
1194    #[tokio::test]
1195    async fn server_check_file() {
1196        let client = Client::connect(
1197            test_address(),
1198            &env("ASYNC_SSH2_TEST_HOST_USER"),
1199            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1200            ServerCheckMethod::with_public_key_file(&env("ASYNC_SSH2_TEST_SERVER_PUB")),
1201        )
1202        .await;
1203        assert!(client.is_ok());
1204    }
1205
1206    #[tokio::test]
1207    async fn server_check_str() {
1208        let line = std::fs::read_to_string(env("ASYNC_SSH2_TEST_SERVER_PUB")).unwrap();
1209        let mut split = line.split_whitespace();
1210        let key = match (split.next(), split.next()) {
1211            (Some(_), Some(k)) => k,
1212            (Some(k), None) => k,
1213            _ => panic!("Failed to parse pub key file"),
1214        };
1215
1216        let client = Client::connect(
1217            test_address(),
1218            &env("ASYNC_SSH2_TEST_HOST_USER"),
1219            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1220            ServerCheckMethod::with_public_key(key),
1221        )
1222        .await;
1223        assert!(client.is_ok());
1224    }
1225
1226    #[tokio::test]
1227    async fn server_check_by_known_hosts_for_ip() {
1228        let client = Client::connect(
1229            test_address(),
1230            &env("ASYNC_SSH2_TEST_HOST_USER"),
1231            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1232            ServerCheckMethod::with_known_hosts_file(&env("ASYNC_SSH2_TEST_KNOWN_HOSTS")),
1233        )
1234        .await;
1235        assert!(client.is_ok());
1236    }
1237
1238    #[tokio::test]
1239    async fn server_check_by_known_hosts_for_hostname() {
1240        let client = Client::connect(
1241            test_hostname(),
1242            &env("ASYNC_SSH2_TEST_HOST_USER"),
1243            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1244            ServerCheckMethod::with_known_hosts_file(&env("ASYNC_SSH2_TEST_KNOWN_HOSTS")),
1245        )
1246        .await;
1247        if is_running_in_docker() {
1248            assert!(client.is_ok());
1249        } else {
1250            assert!(client.is_err()); // DNS can't find the docker hostname if the rust running without docker container
1251        }
1252    }
1253
1254    #[tokio::test]
1255    async fn client_can_be_cloned() {
1256        let client = establish_test_host_connection().await;
1257        let client2 = client.clone();
1258
1259        let result1 = client.execute("echo test clone").await.unwrap();
1260        let result2 = client2.execute("echo test clone2").await.unwrap();
1261
1262        assert_eq!(result1.stdout, "test clone\n");
1263        assert_eq!(result2.stdout, "test clone2\n");
1264    }
1265
1266    #[tokio::test]
1267    async fn client_can_upload_file() {
1268        let client = establish_test_host_connection().await;
1269        client
1270            .upload_file(&env("ASYNC_SSH2_TEST_UPLOAD_FILE"), "/tmp/uploaded")
1271            .await
1272            .unwrap();
1273        let result = client.execute("cat /tmp/uploaded").await.unwrap();
1274        assert_eq!(result.stdout, "this is a test file\n");
1275    }
1276
1277    #[tokio::test]
1278    async fn client_can_download_file() {
1279        let client = establish_test_host_connection().await;
1280
1281        client
1282            .execute("echo 'this is a downloaded test file' > /tmp/test_download")
1283            .await
1284            .unwrap();
1285
1286        let local_path = std::env::temp_dir().join("downloaded_test_file");
1287        client
1288            .download_file("/tmp/test_download", &local_path)
1289            .await
1290            .unwrap();
1291
1292        let contents = tokio::fs::read_to_string(&local_path).await.unwrap();
1293        assert_eq!(contents, "this is a downloaded test file\n");
1294    }
1295}