async_ssh2_tokio/
client.rs

1use russh::client::KeyboardInteractiveAuthResponse;
2use russh::{
3    Channel,
4    client::{Config, Handle, Handler, Msg},
5};
6use russh_sftp::{client::SftpSession, protocol::OpenFlags};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::{fmt::Debug, path::Path};
10use std::{io, path::PathBuf};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12
13use crate::ToSocketAddrsWithHostname;
14
/// An authentication method.
///
/// Used when creating a [`Client`] to authenticate the user.
/// Supports password, private key, public key, SSH agent, and keyboard-interactive authentication.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AuthMethod {
    /// Password authentication with the given password.
    Password(String),
    /// Public-key authentication with a private key held in memory.
    PrivateKey {
        /// entire contents of private key file
        key_data: String,
        /// Optional passphrase passed to `russh::keys::decode_secret_key` when decoding `key_data`.
        key_pass: Option<String>,
    },
    /// Public-key authentication with a private key loaded from disk.
    PrivateKeyFile {
        /// Path of the private key file, loaded with `russh::keys::load_secret_key`.
        key_file_path: PathBuf,
        /// Optional passphrase used when loading the key file.
        key_pass: Option<String>,
    },
    /// Public-key authentication where signing is delegated to the running SSH agent.
    ///
    /// The public key at `key_file_path` must match one of the identities the agent
    /// holds; otherwise authentication fails. Not available on Windows.
    #[cfg(not(target_os = "windows"))]
    PublicKeyFile {
        /// Path of the public key file identifying which agent identity to use.
        key_file_path: PathBuf,
    },
    /// Try every identity offered by the SSH agent (located via `SSH_AUTH_SOCK`)
    /// until one is accepted. Not available on Windows.
    #[cfg(not(target_os = "windows"))]
    Agent,
    /// Keyboard-interactive authentication answering server prompts with
    /// pre-configured responses.
    KeyboardInteractive(AuthKeyboardInteractive),
}
40
/// One message produced by [`Client::execute_streaming`].
///
/// Messages are sent in the order output arrives from the server, so stdout and
/// stderr chunks are interleaved as received. `ExitStatus` is sent last, after the
/// remote channel has closed and an exit status was reported.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SteamingOutput {
    /// Raw bytes the remote command wrote to stdout.
    Stdout(Vec<u8>),
    /// Raw bytes the remote command wrote to stderr (SSH extended data type 1).
    Stderr(Vec<u8>),
    /// The exit status reported by the remote command.
    ExitStatus(u32),
}

/// Correctly spelled alias for [`SteamingOutput`].
///
/// The misspelled name is kept for backward compatibility; both names refer to
/// the same type and can be used interchangeably, including for variant paths.
pub type StreamingOutput = SteamingOutput;
47
/// One configured answer for a keyboard-interactive prompt.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PromptResponse {
    /// `true`: the server prompt must equal `prompt`; `false`: it only has to contain it.
    exact: bool,
    /// Text matched against the prompt sent by the server.
    prompt: String,
    /// Answer sent back to the server when `prompt` matches.
    response: String,
}
54
/// Configuration for keyboard-interactive authentication.
///
/// Build it with [`AuthKeyboardInteractive::new`] and the `with_*` methods, then pass it
/// to [`AuthMethod::with_keyboard_interactive`] (or convert it into an [`AuthMethod`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub struct AuthKeyboardInteractive {
    /// Hints to the server the preferred submethods to be used for authentication.
    submethods: Option<String>,
    /// Prompt/response pairs; during authentication each one is consumed by the
    /// first server prompt it matches.
    responses: Vec<PromptResponse>,
}
62
/// How the server's host key is verified when connecting.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ServerCheckMethod {
    /// Accept any host key without verification.
    NoCheck,
    /// base64 encoded key without the type prefix or hostname suffix (type is already encoded)
    PublicKey(String),
    /// Path of a public key file; the server's key must equal the key loaded from it.
    PublicKeyFile(String),
    /// Check the key against russh's default known_hosts file
    /// (presumably `~/.ssh/known_hosts` — see `russh::keys::check_known_hosts`).
    DefaultKnownHostsFile,
    /// Check the key against the known_hosts file at the given path.
    KnownHostsFile(String),
}
73
74impl AuthMethod {
75    /// Convenience method to create a [`AuthMethod`] from a string literal.
76    pub fn with_password(password: &str) -> Self {
77        Self::Password(password.to_string())
78    }
79
80    pub fn with_key(key: &str, passphrase: Option<&str>) -> Self {
81        Self::PrivateKey {
82            key_data: key.to_string(),
83            key_pass: passphrase.map(str::to_string),
84        }
85    }
86
87    pub fn with_key_file<T: AsRef<Path>>(key_file_path: T, passphrase: Option<&str>) -> Self {
88        Self::PrivateKeyFile {
89            key_file_path: key_file_path.as_ref().to_path_buf(),
90            key_pass: passphrase.map(str::to_string),
91        }
92    }
93
94    #[cfg(not(target_os = "windows"))]
95    pub fn with_public_key_file<T: AsRef<Path>>(key_file_path: T) -> Self {
96        Self::PublicKeyFile {
97            key_file_path: key_file_path.as_ref().to_path_buf(),
98        }
99    }
100
101    /// Creates a new SSH agent authentication method.
102    ///
103    /// This will attempt to authenticate using all identities available in the SSH agent.
104    /// The SSH agent must be running and the SSH_AUTH_SOCK environment variable must be set.
105    ///
106    /// # Example
107    /// ```no_run
108    /// use async_ssh2_tokio::client::AuthMethod;
109    ///
110    /// let auth = AuthMethod::with_agent();
111    /// ```
112    ///
113    /// # Platform Support
114    /// This method is only available on Unix-like systems (Linux, macOS, etc.).
115    /// It is not available on Windows.
116    #[cfg(not(target_os = "windows"))]
117    pub fn with_agent() -> Self {
118        Self::Agent
119    }
120
121    pub const fn with_keyboard_interactive(auth: AuthKeyboardInteractive) -> Self {
122        Self::KeyboardInteractive(auth)
123    }
124}
125
126impl AuthKeyboardInteractive {
127    pub fn new() -> Self {
128        Default::default()
129    }
130
131    /// Hnts to the server the preferred methods to be used for authentication.
132    pub fn with_submethods(mut self, submethods: impl Into<String>) -> Self {
133        self.submethods = Some(submethods.into());
134        self
135    }
136
137    /// Adds a response to the list of responses for a given prompt.
138    ///
139    /// The comparison for the prompt is done using a "contains".
140    pub fn with_response(mut self, prompt: impl Into<String>, response: impl Into<String>) -> Self {
141        self.responses.push(PromptResponse {
142            exact: false,
143            prompt: prompt.into(),
144            response: response.into(),
145        });
146
147        self
148    }
149
150    /// Adds a response to the list of responses for a given exact prompt.
151    pub fn with_response_exact(
152        mut self,
153        prompt: impl Into<String>,
154        response: impl Into<String>,
155    ) -> Self {
156        self.responses.push(PromptResponse {
157            exact: true,
158            prompt: prompt.into(),
159            response: response.into(),
160        });
161
162        self
163    }
164}
165
166impl PromptResponse {
167    fn matches(&self, received_prompt: &str) -> bool {
168        if self.exact {
169            self.prompt.eq(received_prompt)
170        } else {
171            received_prompt.contains(&self.prompt)
172        }
173    }
174}
175
176impl From<AuthKeyboardInteractive> for AuthMethod {
177    fn from(value: AuthKeyboardInteractive) -> Self {
178        Self::with_keyboard_interactive(value)
179    }
180}
181
182impl ServerCheckMethod {
183    /// Convenience method to create a [`ServerCheckMethod`] from a string literal.
184    pub fn with_public_key(key: &str) -> Self {
185        Self::PublicKey(key.to_string())
186    }
187
188    /// Convenience method to create a [`ServerCheckMethod`] from a string literal.
189    pub fn with_public_key_file(key_file_name: &str) -> Self {
190        Self::PublicKeyFile(key_file_name.to_string())
191    }
192
193    /// Convenience method to create a [`ServerCheckMethod`] from a string literal.
194    pub fn with_known_hosts_file(known_hosts_file: &str) -> Self {
195        Self::KnownHostsFile(known_hosts_file.to_string())
196    }
197}
198
/// An SSH connection to a remote server.
///
/// After creating a `Client` by [`connect`]ing to a remote host,
/// use [`execute`] to send commands and receive results through the connection.
///
/// `Client` is `Clone`; clones share the same underlying connection handle.
///
/// [`connect`]: Client::connect
/// [`execute`]: Client::execute
///
/// # Examples
///
/// ```no_run
/// use async_ssh2_tokio::{Client, AuthMethod, ServerCheckMethod};
/// #[tokio::main]
/// async fn main() -> Result<(), async_ssh2_tokio::Error> {
///     let client = Client::connect(
///         ("10.10.10.2", 22),
///         "root",
///         AuthMethod::with_password("root"),
///         ServerCheckMethod::NoCheck,
///     ).await?;
///
///     let result = client.execute("echo Hello SSH").await?;
///     assert_eq!(result.stdout, "Hello SSH\n");
///     assert_eq!(result.exit_status, 0);
///
///     Ok(())
/// }
/// ```
#[derive(Clone)]
pub struct Client {
    /// russh connection handle, behind an `Arc` so clones of `Client` share it.
    connection_handle: Arc<Handle<ClientHandler>>,
    /// Username that was authenticated on this connection.
    username: String,
    /// Socket address the connection was actually established to.
    address: SocketAddr,
}
232
233impl Client {
234    /// Open a ssh connection to a remote host.
235    ///
236    /// `addr` is an address of the remote host. Anything which implements
237    /// [`ToSocketAddrsWithHostname`] trait can be supplied for the address;
238    /// ToSocketAddrsWithHostname reimplements all of [`ToSocketAddrs`];
239    /// see this trait's documentation for concrete examples.
240    ///
241    /// If `addr` yields multiple addresses, `connect` will be attempted with
242    /// each of the addresses until a connection is successful.
243    /// Authentification is tried on the first successful connection and the whole
244    /// process aborted if this fails.
245    pub async fn connect(
246        addr: impl ToSocketAddrsWithHostname,
247        username: &str,
248        auth: AuthMethod,
249        server_check: ServerCheckMethod,
250    ) -> Result<Self, crate::Error> {
251        Self::connect_with_config(addr, username, auth, server_check, Config::default()).await
252    }
253
254    /// Same as `connect`, but with the option to specify a non default
255    /// [`russh::client::Config`].
256    pub async fn connect_with_config(
257        addr: impl ToSocketAddrsWithHostname,
258        username: &str,
259        auth: AuthMethod,
260        server_check: ServerCheckMethod,
261        config: Config,
262    ) -> Result<Self, crate::Error> {
263        let config = Arc::new(config);
264
265        // Connection code inspired from std::net::TcpStream::connect and std::net::each_addr
266        let socket_addrs = addr
267            .to_socket_addrs()
268            .map_err(crate::Error::AddressInvalid)?;
269        let mut connect_res = Err(crate::Error::AddressInvalid(io::Error::new(
270            io::ErrorKind::InvalidInput,
271            "could not resolve to any addresses",
272        )));
273        for socket_addr in socket_addrs {
274            let handler = ClientHandler {
275                hostname: addr.hostname(),
276                host: socket_addr,
277                server_check: server_check.clone(),
278            };
279            match russh::client::connect(config.clone(), socket_addr, handler).await {
280                Ok(h) => {
281                    connect_res = Ok((socket_addr, h));
282                    break;
283                }
284                Err(e) => connect_res = Err(e),
285            }
286        }
287        let (address, mut handle) = connect_res?;
288        let username = username.to_string();
289
290        Self::authenticate(&mut handle, &username, auth).await?;
291
292        Ok(Self {
293            connection_handle: Arc::new(handle),
294            username,
295            address,
296        })
297    }
298
299    /// This takes a handle and performs authentification with the given method.
300    async fn authenticate(
301        handle: &mut Handle<ClientHandler>,
302        username: &String,
303        auth: AuthMethod,
304    ) -> Result<(), crate::Error> {
305        match auth {
306            AuthMethod::Password(password) => {
307                let is_authentificated = handle.authenticate_password(username, password).await?;
308                if !is_authentificated.success() {
309                    return Err(crate::Error::PasswordWrong);
310                }
311            }
312            AuthMethod::PrivateKey { key_data, key_pass } => {
313                let cprivk = russh::keys::decode_secret_key(key_data.as_str(), key_pass.as_deref())
314                    .map_err(crate::Error::KeyInvalid)?;
315                let is_authentificated = handle
316                    .authenticate_publickey(
317                        username,
318                        russh::keys::PrivateKeyWithHashAlg::new(
319                            Arc::new(cprivk),
320                            handle.best_supported_rsa_hash().await?.flatten(),
321                        ),
322                    )
323                    .await?;
324                if !is_authentificated.success() {
325                    return Err(crate::Error::KeyAuthFailed);
326                }
327            }
328            AuthMethod::PrivateKeyFile {
329                key_file_path,
330                key_pass,
331            } => {
332                let cprivk = russh::keys::load_secret_key(key_file_path, key_pass.as_deref())
333                    .map_err(crate::Error::KeyInvalid)?;
334                let is_authentificated = handle
335                    .authenticate_publickey(
336                        username,
337                        russh::keys::PrivateKeyWithHashAlg::new(
338                            Arc::new(cprivk),
339                            handle.best_supported_rsa_hash().await?.flatten(),
340                        ),
341                    )
342                    .await?;
343                if !is_authentificated.success() {
344                    return Err(crate::Error::KeyAuthFailed);
345                }
346            }
347            #[cfg(not(target_os = "windows"))]
348            AuthMethod::PublicKeyFile { key_file_path } => {
349                let cpubk = russh::keys::load_public_key(key_file_path)
350                    .map_err(crate::Error::KeyInvalid)?;
351                let mut agent = russh::keys::agent::client::AgentClient::connect_env()
352                    .await
353                    .unwrap();
354                let mut auth_identity: Option<russh::keys::PublicKey> = None;
355                for identity in agent
356                    .request_identities()
357                    .await
358                    .map_err(crate::Error::KeyInvalid)?
359                {
360                    if identity == cpubk {
361                        auth_identity = Some(identity.clone());
362                        break;
363                    }
364                }
365
366                if auth_identity.is_none() {
367                    return Err(crate::Error::KeyAuthFailed);
368                }
369
370                let is_authentificated = handle
371                    .authenticate_publickey_with(
372                        username,
373                        cpubk,
374                        handle.best_supported_rsa_hash().await?.flatten(),
375                        &mut agent,
376                    )
377                    .await?;
378                if !is_authentificated.success() {
379                    return Err(crate::Error::KeyAuthFailed);
380                }
381            }
382            #[cfg(not(target_os = "windows"))]
383            AuthMethod::Agent => {
384                let mut agent = russh::keys::agent::client::AgentClient::connect_env()
385                    .await
386                    .map_err(|_| crate::Error::AgentConnectionFailed)?;
387
388                let identities = agent
389                    .request_identities()
390                    .await
391                    .map_err(|_| crate::Error::AgentRequestIdentitiesFailed)?;
392
393                if identities.is_empty() {
394                    return Err(crate::Error::AgentNoIdentities);
395                }
396
397                let mut auth_success = false;
398                for identity in identities {
399                    let result = handle
400                        .authenticate_publickey_with(
401                            username,
402                            identity.clone(),
403                            handle.best_supported_rsa_hash().await?.flatten(),
404                            &mut agent,
405                        )
406                        .await;
407
408                    if let Ok(auth_result) = result
409                        && auth_result.success()
410                    {
411                        auth_success = true;
412                        break;
413                    }
414                }
415
416                if !auth_success {
417                    return Err(crate::Error::AgentAuthenticationFailed);
418                }
419            }
420            AuthMethod::KeyboardInteractive(mut kbd) => {
421                let mut res = handle
422                    .authenticate_keyboard_interactive_start(username, kbd.submethods)
423                    .await?;
424                loop {
425                    let prompts = match res {
426                        KeyboardInteractiveAuthResponse::Success => break,
427                        KeyboardInteractiveAuthResponse::Failure { .. } => {
428                            return Err(crate::Error::KeyboardInteractiveAuthFailed);
429                        }
430                        KeyboardInteractiveAuthResponse::InfoRequest { prompts, .. } => prompts,
431                    };
432
433                    let mut responses = vec![];
434                    for prompt in prompts {
435                        let Some(pos) = kbd
436                            .responses
437                            .iter()
438                            .position(|pr| pr.matches(&prompt.prompt))
439                        else {
440                            return Err(crate::Error::KeyboardInteractiveNoResponseForPrompt(
441                                prompt.prompt,
442                            ));
443                        };
444                        let pr = kbd.responses.remove(pos);
445                        responses.push(pr.response);
446                    }
447
448                    res = handle
449                        .authenticate_keyboard_interactive_respond(responses)
450                        .await?;
451                }
452            }
453        };
454        Ok(())
455    }
456
457    pub async fn get_channel(&self) -> Result<Channel<Msg>, crate::Error> {
458        self.connection_handle
459            .channel_open_session()
460            .await
461            .map_err(crate::Error::SshError)
462    }
463
464    /// Open a TCP/IP forwarding channel.
465    ///
466    /// This opens a `direct-tcpip` channel to the given target.
467    pub async fn open_direct_tcpip_channel<
468        T: ToSocketAddrsWithHostname,
469        S: Into<Option<SocketAddr>>,
470    >(
471        &self,
472        target: T,
473        src: S,
474    ) -> Result<Channel<Msg>, crate::Error> {
475        let targets = target
476            .to_socket_addrs()
477            .map_err(crate::Error::AddressInvalid)?;
478        let src = src
479            .into()
480            .map(|src| (src.ip().to_string(), src.port().into()))
481            .unwrap_or_else(|| ("127.0.0.1".to_string(), 22));
482
483        let mut connect_err = crate::Error::AddressInvalid(io::Error::new(
484            io::ErrorKind::InvalidInput,
485            "could not resolve to any addresses",
486        ));
487        for target in targets {
488            match self
489                .connection_handle
490                .channel_open_direct_tcpip(
491                    target.ip().to_string(),
492                    target.port().into(),
493                    src.0.clone(),
494                    src.1,
495                )
496                .await
497            {
498                Ok(channel) => return Ok(channel),
499                Err(err) => connect_err = crate::Error::SshError(err),
500            }
501        }
502
503        Err(connect_err)
504    }
505
506    /// Upload a file with sftp to the remote server.
507    ///
508    /// `src_file_path` is the path to the file on the local machine.
509    /// `dest_file_path` is the path to the file on the remote machine.
510    /// Some sshd_config does not enable sftp by default, so make sure it is enabled.
511    /// A config line like a `Subsystem sftp internal-sftp` or
512    /// `Subsystem sftp /usr/lib/openssh/sftp-server` is needed in the sshd_config in remote machine.
513    pub async fn upload_file<T: AsRef<Path>, U: Into<String>>(
514        &self,
515        src_file_path: T,
516        //fa993: This cannot be AsRef<Path> because of underlying lib constraints as described here
517        //https://github.com/AspectUnk/russh-sftp/issues/7#issuecomment-1738355245
518        dest_file_path: U,
519    ) -> Result<(), crate::Error> {
520        // start sftp session
521        let channel = self.get_channel().await?;
522        channel.request_subsystem(true, "sftp").await?;
523        let sftp = SftpSession::new(channel.into_stream()).await?;
524
525        // read file contents locally
526        let file_contents = tokio::fs::read(src_file_path)
527            .await
528            .map_err(crate::Error::IoError)?;
529
530        // interaction with i/o
531        let mut file = sftp
532            .open_with_flags(
533                dest_file_path,
534                OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ,
535            )
536            .await?;
537        file.write_all(&file_contents)
538            .await
539            .map_err(crate::Error::IoError)?;
540        file.flush().await.map_err(crate::Error::IoError)?;
541        file.shutdown().await.map_err(crate::Error::IoError)?;
542
543        Ok(())
544    }
545
546    /// Download a file from the remote server using sftp.
547    ///
548    /// `remote_file_path` is the path to the file on the remote machine.
549    /// `local_file_path` is the path to the file on the local machine.
550    /// Some sshd_config does not enable sftp by default, so make sure it is enabled.
551    /// A config line like a `Subsystem sftp internal-sftp` or
552    /// `Subsystem sftp /usr/lib/openssh/sftp-server` is needed in the sshd_config in remote machine.
553    pub async fn download_file<T: AsRef<Path>, U: Into<String>>(
554        &self,
555        remote_file_path: U,
556        local_file_path: T,
557    ) -> Result<(), crate::Error> {
558        // start sftp session
559        let channel = self.get_channel().await?;
560        channel.request_subsystem(true, "sftp").await?;
561        let sftp = SftpSession::new(channel.into_stream()).await?;
562
563        // open remote file for reading
564        let mut remote_file = sftp
565            .open_with_flags(remote_file_path, OpenFlags::READ)
566            .await?;
567
568        // read remote file contents
569        let mut contents = Vec::new();
570        remote_file.read_to_end(contents.as_mut()).await?;
571
572        // write contents to local file
573        let mut local_file = tokio::fs::File::create(local_file_path.as_ref())
574            .await
575            .map_err(crate::Error::IoError)?;
576
577        local_file
578            .write_all(&contents)
579            .await
580            .map_err(crate::Error::IoError)?;
581        local_file.flush().await.map_err(crate::Error::IoError)?;
582
583        Ok(())
584    }
585
586    /// Execute a remote command via the ssh connection.
587    ///
588    /// Returns stdout, stderr and the exit code of the command,
589    /// packaged in a [`CommandExecutedResult`] struct.
590    /// If you need the stderr output interleaved within stdout, you should postfix the command with a redirection,
591    /// e.g. `echo foo 2>&1`.
592    /// If you dont want any output at all, use something like `echo foo >/dev/null 2>&1`.
593    ///
594    /// Make sure your commands don't read from stdin and exit after bounded time.
595    ///
596    /// Can be called multiple times, but every invocation is a new shell context.
597    /// Thus `cd`, setting variables and alike have no effect on future invocations.
598    pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, crate::Error> {
599        let mut stdout_buffer = vec![];
600        let mut stderr_buffer = vec![];
601        let mut channel = self.connection_handle.channel_open_session().await?;
602        channel.exec(true, command).await?;
603
604        let mut result: Option<u32> = None;
605
606        // While the channel has messages...
607        while let Some(msg) = channel.wait().await {
608            //dbg!(&msg);
609            match msg {
610                // If we get data, add it to the buffer
611                russh::ChannelMsg::Data { ref data } => {
612                    stdout_buffer.write_all(data).await.unwrap()
613                }
614                russh::ChannelMsg::ExtendedData { ref data, ext } => {
615                    if ext == 1 {
616                        stderr_buffer.write_all(data).await.unwrap()
617                    }
618                }
619
620                // If we get an exit code report, store it, but crucially don't
621                // assume this message means end of communications. The data might
622                // not be finished yet!
623                russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
624
625                // We SHOULD get this EOF messagge, but 4254 sec 5.3 also permits
626                // the channel to close without it being sent. And sometimes this
627                // message can even precede the Data message, so don't handle it
628                // russh::ChannelMsg::Eof => break,
629                _ => {}
630            }
631        }
632
633        // If we received an exit code, report it back
634        if let Some(result) = result {
635            Ok(CommandExecutedResult {
636                stdout: String::from_utf8_lossy(&stdout_buffer).to_string(),
637                stderr: String::from_utf8_lossy(&stderr_buffer).to_string(),
638                exit_status: result,
639            })
640
641        // Otherwise, report an error
642        } else {
643            Err(crate::Error::CommandDidntExit)
644        }
645    }
646
647    /// Execute a remote command via the ssh connection.
648    ///
649    /// Command output is stream to the provided channel. Returns the exit code.
650    /// The channel sends `SteamingOutput` enum variants to distinguish stdout,
651    /// stderr and exit code so message arrive interleaved and in the order
652    /// they are received. See `execute` for more details.
653    ///
654    pub async fn execute_streaming(
655        &self,
656        command: &str,
657        ch: tokio::sync::mpsc::Sender<SteamingOutput>,
658    ) -> Result<u32, crate::Error> {
659        let mut channel = self.connection_handle.channel_open_session().await?;
660        channel.exec(true, command).await?;
661
662        let mut result: Option<u32> = None;
663
664        // While the channel has messages...
665        while let Some(msg) = channel.wait().await {
666            //dbg!(&msg);
667            match msg {
668                // If we get data, add it to the buffer
669                russh::ChannelMsg::Data { ref data } => {
670                    ch.send(SteamingOutput::Stdout(Vec::from(data.as_ref())))
671                        .await
672                        .unwrap();
673                }
674                russh::ChannelMsg::ExtendedData { ref data, ext } => {
675                    if ext == 1 {
676                        ch.send(SteamingOutput::Stderr(Vec::from(data.as_ref())))
677                            .await
678                            .unwrap();
679                    }
680                }
681
682                // If we get an exit code report, store it, but crucially don't
683                // assume this message means end of communications. The data might
684                // not be finished yet!
685                russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
686
687                // We SHOULD get this EOF messagge, but 4254 sec 5.3 also permits
688                // the channel to close without it being sent. And sometimes this
689                // message can even precede the Data message, so don't handle it
690                // russh::ChannelMsg::Eof => break,
691                _ => {}
692            }
693        }
694
695        if let Some(exit_code) = result {
696            ch.send(SteamingOutput::ExitStatus(exit_code))
697                .await
698                .unwrap();
699            Ok(exit_code)
700        } else {
701            Err(crate::Error::CommandDidntExit)
702        }
703    }
704
705    /// A debugging function to get the username this client is connected as.
706    pub fn get_connection_username(&self) -> &String {
707        &self.username
708    }
709
710    /// A debugging function to get the address this client is connected to.
711    pub fn get_connection_address(&self) -> &SocketAddr {
712        &self.address
713    }
714
715    pub async fn disconnect(&self) -> Result<(), crate::Error> {
716        self.connection_handle
717            .disconnect(russh::Disconnect::ByApplication, "", "")
718            .await
719            .map_err(crate::Error::SshError)
720    }
721
722    pub fn is_closed(&self) -> bool {
723        self.connection_handle.is_closed()
724    }
725}
726
727impl Debug for Client {
728    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
729        f.debug_struct("Client")
730            .field("username", &self.username)
731            .field("address", &self.address)
732            .field("connection_handle", &"Handle<ClientHandler>")
733            .finish()
734    }
735}
736
/// The outcome of [`Client::execute`]: captured output and the exit status.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CommandExecutedResult {
    /// The stdout output of the command (decoded as UTF-8, invalid sequences replaced).
    pub stdout: String,
    /// The stderr output of the command (decoded as UTF-8, invalid sequences replaced).
    pub stderr: String,
    /// The unix exit status (`$?` in bash).
    pub exit_status: u32,
}
746
/// russh client handler created for each connection attempt; it verifies the
/// server's host key according to `server_check`.
#[derive(Debug, Clone)]
struct ClientHandler {
    /// Hostname supplied by the caller; used for known_hosts lookups.
    hostname: String,
    /// Resolved socket address being connected to; its port is used for known_hosts lookups.
    host: SocketAddr,
    /// Host-key verification policy.
    server_check: ServerCheckMethod,
}
753
754impl Handler for ClientHandler {
755    type Error = crate::Error;
756
757    async fn check_server_key(
758        &mut self,
759        server_public_key: &russh::keys::PublicKey,
760    ) -> Result<bool, Self::Error> {
761        match &self.server_check {
762            ServerCheckMethod::NoCheck => Ok(true),
763            ServerCheckMethod::PublicKey(key) => {
764                let pk = russh::keys::parse_public_key_base64(key)
765                    .map_err(|_| crate::Error::ServerCheckFailed)?;
766
767                Ok(pk == *server_public_key)
768            }
769            ServerCheckMethod::PublicKeyFile(key_file_name) => {
770                let pk = russh::keys::load_public_key(key_file_name)
771                    .map_err(|_| crate::Error::ServerCheckFailed)?;
772
773                Ok(pk == *server_public_key)
774            }
775            ServerCheckMethod::KnownHostsFile(known_hosts_path) => {
776                let result = russh::keys::check_known_hosts_path(
777                    &self.hostname,
778                    self.host.port(),
779                    server_public_key,
780                    known_hosts_path,
781                )
782                .map_err(|_| crate::Error::ServerCheckFailed)?;
783
784                Ok(result)
785            }
786            ServerCheckMethod::DefaultKnownHostsFile => {
787                let result = russh::keys::check_known_hosts(
788                    &self.hostname,
789                    self.host.port(),
790                    server_public_key,
791                )
792                .map_err(|_| crate::Error::ServerCheckFailed)?;
793
794                Ok(result)
795            }
796        }
797    }
798}
799
#[cfg(test)]
mod tests {
    // Integration tests for `Client`. They need a reachable SSH test server (and, for some
    // tests, an HTTP server and a running SSH agent) described by the `ASYNC_SSH2_TEST_*`
    // environment variables. Outside Docker those variables are loaded from a `.env` file.
    use crate::client::*;
    use core::time;
    use dotenv::dotenv;
    use std::path::Path;
    use std::sync::Once;
    use tokio::io::AsyncReadExt;

    /// Guards the one-time environment setup performed by [`initialize`].
    static INIT: Once = Once::new();

    /// Serializes the tests that depend on (or temporarily remove) `SSH_AUTH_SOCK`.
    ///
    /// `cargo test` runs tests on parallel threads; without this lock,
    /// `auth_with_agent_no_sock` could unset the variable while another agent test is
    /// connecting, making that test fail spuriously. A `tokio` mutex is used because the
    /// guard is held across `.await` points.
    #[cfg(not(target_os = "windows"))]
    static AGENT_ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// One-time test setup: outside Docker, load the test configuration from a `.env` file.
    fn initialize() {
        println!("Running initialization code before tests...");
        if is_running_in_docker() {
            println!("Running inside Docker.");
        } else {
            println!("Not running inside Docker. Load env from file");
            // A missing `.env` file is not an error here; `env()` reports missing variables.
            dotenv().ok();
        }
    }

    /// Heuristic Docker detection: the `/.dockerenv` marker file or a Docker cgroup entry.
    fn is_running_in_docker() -> bool {
        Path::new("/.dockerenv").exists() || check_cgroup()
    }

    /// Returns `true` when PID 1's cgroup listing mentions Docker, `false` if it cannot be read.
    fn check_cgroup() -> bool {
        std::fs::read_to_string("/proc/1/cgroup")
            .is_ok_and(|contents| contents.contains("docker"))
    }

    /// Reads a required test environment variable, running [`initialize`] once beforehand.
    ///
    /// # Panics
    /// Panics with a message naming the variable when it is not set.
    fn env(name: &str) -> String {
        INIT.call_once(initialize);
        std::env::var(name).unwrap_or_else(|_| {
            panic!(
                "Failed to get env var needed for test, make sure to set the following env var: {name}",
            )
        })
    }

    /// Socket address of the SSH test server, built from its IP and port variables.
    fn test_address() -> SocketAddr {
        format!(
            "{}:{}",
            env("ASYNC_SSH2_TEST_HOST_IP"),
            env("ASYNC_SSH2_TEST_HOST_PORT")
        )
        .parse()
        .unwrap()
    }

    /// The SSH test server addressed by hostname (requires DNS for that name).
    fn test_hostname() -> impl ToSocketAddrsWithHostname {
        (
            env("ASYNC_SSH2_TEST_HOST_NAME"),
            env("ASYNC_SSH2_TEST_HOST_PORT").parse().unwrap(),
        )
    }

    /// Connects to the test server with password authentication and no host key check.
    async fn establish_test_host_connection() -> Client {
        Client::connect(
            (
                env("ASYNC_SSH2_TEST_HOST_IP"),
                env("ASYNC_SSH2_TEST_HOST_PORT").parse().unwrap(),
            ),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Connection/Authentification failed")
    }

    #[tokio::test]
    async fn connect_with_password() {
        let client = establish_test_host_connection().await;
        assert_eq!(
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            client.get_connection_username(),
        );
        assert_eq!(test_address(), *client.get_connection_address(),);
    }

    #[tokio::test]
    async fn execute_command_result() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo test!!!").await.unwrap();
        assert_eq!("test!!!\n", output.stdout);
        assert_eq!("", output.stderr);
        assert_eq!(0, output.exit_status);
    }

    #[tokio::test]
    async fn execute_streaming_command_result() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(10);
        let client = establish_test_host_connection().await;
        let result = client.execute_streaming("echo test!!!", tx).await.unwrap();
        let mut output = Vec::new();
        while let Some(msg) = rx.recv().await {
            output.push(msg);
        }
        assert_eq!(0, result);
        assert_eq!(
            &[
                SteamingOutput::Stdout(b"test!!!\n".to_vec()),
                SteamingOutput::ExitStatus(0),
            ],
            output.as_slice(),
        );
    }

    #[tokio::test]
    async fn execute_command_result_stderr() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo test!!! 1>&2").await.unwrap();
        assert_eq!("", output.stdout);
        assert_eq!("test!!!\n", output.stderr);
        assert_eq!(0, output.exit_status);
    }

    #[tokio::test]
    async fn execute_streaming_command_result_stderr() {
        let client = establish_test_host_connection().await;
        let (tx, mut rx) = tokio::sync::mpsc::channel(10);
        let result = client
            .execute_streaming("echo test!!! 1>&2", tx)
            .await
            .unwrap();
        let mut output = Vec::new();
        while let Some(msg) = rx.recv().await {
            output.push(msg);
        }
        assert_eq!(0, result);
        assert_eq!(
            &[
                SteamingOutput::Stderr(b"test!!!\n".to_vec()),
                SteamingOutput::ExitStatus(0),
            ],
            output.as_slice()
        );
    }

    #[tokio::test]
    async fn unicode_output() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo To thḙ moon! 🚀").await.unwrap();
        assert_eq!("To thḙ moon! 🚀\n", output.stdout);
        assert_eq!(0, output.exit_status);
    }

    #[tokio::test]
    async fn execute_command_status() {
        let client = establish_test_host_connection().await;
        let output = client.execute("exit 42").await.unwrap();
        assert_eq!(42, output.exit_status);
    }

    #[tokio::test]
    async fn execute_streaming_command_status() {
        let client = establish_test_host_connection().await;
        let (tx, mut rx) = tokio::sync::mpsc::channel(10);
        let result = client.execute_streaming("exit 42", tx).await.unwrap();
        let mut output = Vec::new();
        while let Some(msg) = rx.recv().await {
            output.push(msg);
        }
        assert_eq!(42, result);
        assert_eq!(&[SteamingOutput::ExitStatus(42),], output.as_slice());
    }

    #[tokio::test]
    async fn execute_multiple_commands() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo test!!!").await.unwrap().stdout;
        assert_eq!("test!!!\n", output);

        let output = client.execute("echo Hello World").await.unwrap().stdout;
        assert_eq!("Hello World\n", output);
    }

    #[tokio::test]
    async fn direct_tcpip_channel() {
        let client = establish_test_host_connection().await;
        let channel = client
            .open_direct_tcpip_channel(
                format!(
                    "{}:{}",
                    env("ASYNC_SSH2_TEST_HTTP_SERVER_IP"),
                    env("ASYNC_SSH2_TEST_HTTP_SERVER_PORT"),
                ),
                None,
            )
            .await
            .unwrap();

        let mut stream = channel.into_stream();
        stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await.unwrap();

        let mut response = String::new();
        stream.read_to_string(&mut response).await.unwrap();

        let (_headers, body) = response
            .split_once("\r\n\r\n")
            .expect("HTTP response has no header/body separator");
        assert_eq!("Hello", body);
    }

    #[tokio::test]
    async fn stderr_redirection() {
        let client = establish_test_host_connection().await;

        let output = client.execute("echo foo >/dev/null").await.unwrap();
        assert_eq!("", output.stdout);

        let output = client.execute("echo foo >>/dev/stderr").await.unwrap();
        assert_eq!("", output.stdout);

        let output = client.execute("2>&1 echo foo >>/dev/stderr").await.unwrap();
        assert_eq!("foo\n", output.stdout);
    }

    #[tokio::test]
    async fn sequential_commands() {
        let client = establish_test_host_connection().await;

        for i in 0..100 {
            // Sleep asynchronously: `std::thread::sleep` would block the test runtime's
            // thread and stall every other task scheduled on it.
            tokio::time::sleep(time::Duration::from_millis(100)).await;
            let res = client
                .execute(&format!("echo {i}"))
                .await
                .unwrap_or_else(|_| panic!("Execution failed in iteration {i}"));
            assert_eq!(format!("{i}\n"), res.stdout);
        }
    }

    #[tokio::test]
    async fn execute_multiple_context() {
        // Each `execute` call runs in a fresh context, so shell state (such as exported
        // variables) does not carry over. This may be surprising, hence this test.
        let client = establish_test_host_connection().await;
        let output = client
            .execute("export VARIABLE=42; echo $VARIABLE")
            .await
            .unwrap()
            .stdout;
        assert_eq!("42\n", output);

        let output = client.execute("echo $VARIABLE").await.unwrap().stdout;
        assert_eq!("\n", output);
    }

    #[tokio::test]
    async fn connect_second_address() {
        let client = Client::connect(
            &[SocketAddr::from(([127, 0, 0, 1], 23)), test_address()][..],
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Resolution to second address failed");

        assert_eq!(test_address(), *client.get_connection_address(),);
    }

    #[tokio::test]
    async fn connect_with_wrong_password() {
        let error = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password("hopefully the wrong password"),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect_err("Client connected with wrong password");

        assert!(
            matches!(error, crate::Error::PasswordWrong),
            "Wrong error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn invalid_address() {
        let no_client = Client::connect(
            "this is definitely not an address",
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password("hopefully the wrong password"),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(no_client.is_err());
    }

    #[tokio::test]
    async fn connect_to_wrong_port() {
        let no_client = Client::connect(
            (env("ASYNC_SSH2_TEST_HOST_IP"), 23),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(no_client.is_err());
    }

    #[tokio::test]
    #[ignore = "This times out only after 20 seconds"]
    async fn connect_to_wrong_host() {
        let no_client = Client::connect(
            "172.16.0.6:22",
            "xxx",
            AuthMethod::with_password("xxx"),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(no_client.is_err());
    }

    #[tokio::test]
    async fn auth_key_file() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key_file(env("ASYNC_SSH2_TEST_CLIENT_PRIV"), None),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    #[cfg(not(target_os = "windows"))]
    async fn auth_with_agent() {
        // This test requires SSH agent to be running with the test key loaded
        // In Docker environment, the agent is always properly configured
        // Hold the lock so `auth_with_agent_no_sock` cannot unset SSH_AUTH_SOCK meanwhile.
        let _env_guard = AGENT_ENV_LOCK.lock().await;
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_agent(),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Agent authentication should succeed with correct key loaded");

        // Verify we can execute a command
        let output = client.execute("echo test").await.unwrap();
        assert_eq!("test\n", output.stdout);
    }

    #[tokio::test]
    #[cfg(not(target_os = "windows"))]
    async fn auth_with_agent_wrong_user() {
        // This test verifies that agent auth fails with wrong username
        // Hold the lock so `auth_with_agent_no_sock` cannot unset SSH_AUTH_SOCK meanwhile.
        let _env_guard = AGENT_ENV_LOCK.lock().await;
        let result = Client::connect(
            test_address(),
            "wrong_user_that_does_not_exist",
            AuthMethod::with_agent(),
            ServerCheckMethod::NoCheck,
        )
        .await;

        // Should fail with authentication error
        assert!(matches!(
            result,
            Err(crate::Error::AgentAuthenticationFailed)
        ));
    }

    #[tokio::test]
    #[cfg(not(target_os = "windows"))]
    async fn auth_with_agent_no_sock() {
        // Test behavior when SSH_AUTH_SOCK is not set: temporarily unset it for this test.
        // The lock keeps the other agent tests from running while the variable is missing.
        let _env_guard = AGENT_ENV_LOCK.lock().await;
        // `var_os` keeps non-UTF-8 values too, so they can be restored faithfully.
        let original_sock = std::env::var_os("SSH_AUTH_SOCK");
        // SAFETY: mutating the environment is unsound if another thread reads it at the
        // same time. The tests that depend on SSH_AUTH_SOCK are serialized via
        // `AGENT_ENV_LOCK`. NOTE(review): unrelated tests on other threads may still read
        // the environment concurrently, so this is a best-effort mitigation.
        unsafe {
            std::env::remove_var("SSH_AUTH_SOCK");
        }

        let result = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_agent(),
            ServerCheckMethod::NoCheck,
        )
        .await;

        // Restore original SSH_AUTH_SOCK if it was set
        if let Some(sock) = original_sock {
            // SAFETY: same reasoning as the `remove_var` call above.
            unsafe {
                std::env::set_var("SSH_AUTH_SOCK", sock);
            }
        }

        // Should fail with connection error
        assert!(matches!(result, Err(crate::Error::AgentConnectionFailed)));
    }

    #[tokio::test]
    async fn auth_key_file_with_passphrase() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key_file(
                env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV"),
                Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS")),
            ),
            ServerCheckMethod::NoCheck,
        )
        .await;
        if let Err(e) = client {
            panic!("Authentication with passphrase-protected key file failed: {e:?}");
        }
    }

    #[tokio::test]
    async fn auth_key_str() {
        let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PRIV")).unwrap();

        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key(key.as_str(), None),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn auth_key_str_with_passphrase() {
        let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV")).unwrap();

        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key(key.as_str(), Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS"))),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn auth_keyboard_interactive() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthKeyboardInteractive::new()
                .with_response("Password", env("ASYNC_SSH2_TEST_HOST_PW"))
                .into(),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn auth_keyboard_interactive_exact() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthKeyboardInteractive::new()
                .with_response_exact("Password: ", env("ASYNC_SSH2_TEST_HOST_PW"))
                .into(),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn auth_keyboard_interactive_wrong_response() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthKeyboardInteractive::new()
                .with_response_exact("Password: ", "wrong password")
                .into(),
            ServerCheckMethod::NoCheck,
        )
        .await;
        match client {
            Err(crate::error::Error::KeyboardInteractiveAuthFailed) => {}
            Err(e) => {
                panic!("Expected KeyboardInteractiveAuthFailed error. Got error: {e:?}")
            }
            Ok(_) => panic!("Expected KeyboardInteractiveAuthFailed error."),
        }
    }

    #[tokio::test]
    async fn auth_keyboard_interactive_no_response() {
        // The exact-match prompt "Password:" lacks the server's trailing space, so no
        // configured response matches the prompt the server actually sends.
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthKeyboardInteractive::new()
                .with_response_exact("Password:", "123")
                .into(),
            ServerCheckMethod::NoCheck,
        )
        .await;
        match client {
            Err(crate::error::Error::KeyboardInteractiveNoResponseForPrompt(prompt)) => {
                assert_eq!(prompt, "Password: ");
            }
            Err(e) => {
                panic!("Expected KeyboardInteractiveNoResponseForPrompt error. Got error: {e:?}")
            }
            Ok(_) => panic!("Expected KeyboardInteractiveNoResponseForPrompt error."),
        }
    }

    #[tokio::test]
    async fn server_check_file() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_public_key_file(&env("ASYNC_SSH2_TEST_SERVER_PUB")),
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn server_check_str() {
        // Accept either "<type> <base64> [comment]" or a bare "<base64>" line and keep
        // only the base64 key material.
        let line = std::fs::read_to_string(env("ASYNC_SSH2_TEST_SERVER_PUB")).unwrap();
        let mut split = line.split_whitespace();
        let key = match (split.next(), split.next()) {
            (Some(_), Some(k)) => k,
            (Some(k), None) => k,
            _ => panic!("Failed to parse pub key file"),
        };

        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_public_key(key),
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn server_check_by_known_hosts_for_ip() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_known_hosts_file(&env("ASYNC_SSH2_TEST_KNOWN_HOSTS")),
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn server_check_by_known_hosts_for_hostname() {
        let client = Client::connect(
            test_hostname(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_known_hosts_file(&env("ASYNC_SSH2_TEST_KNOWN_HOSTS")),
        )
        .await;
        if is_running_in_docker() {
            assert!(client.is_ok());
        } else {
            // Outside the Docker network, DNS cannot resolve the test container's hostname.
            assert!(client.is_err());
        }
    }

    #[tokio::test]
    async fn client_can_be_cloned() {
        let client = establish_test_host_connection().await;
        let client2 = client.clone();

        let result1 = client.execute("echo test clone").await.unwrap();
        let result2 = client2.execute("echo test clone2").await.unwrap();

        assert_eq!(result1.stdout, "test clone\n");
        assert_eq!(result2.stdout, "test clone2\n");
    }

    #[tokio::test]
    async fn client_can_upload_file() {
        let client = establish_test_host_connection().await;
        client
            .upload_file(&env("ASYNC_SSH2_TEST_UPLOAD_FILE"), "/tmp/uploaded")
            .await
            .unwrap();
        let result = client.execute("cat /tmp/uploaded").await.unwrap();
        assert_eq!(result.stdout, "this is a test file\n");
    }

    #[tokio::test]
    async fn client_can_download_file() {
        let client = establish_test_host_connection().await;

        client
            .execute("echo 'this is a downloaded test file' > /tmp/test_download")
            .await
            .unwrap();

        let local_path = std::env::temp_dir().join("downloaded_test_file");
        client
            .download_file("/tmp/test_download", &local_path)
            .await
            .unwrap();

        let contents = tokio::fs::read_to_string(&local_path).await.unwrap();
        // Best-effort cleanup before asserting so the temp dir is tidy even on failure.
        tokio::fs::remove_file(&local_path).await.ok();
        assert_eq!(contents, "this is a downloaded test file\n");
    }
}