async_ssh2_tokio/
client.rs

1use russh::client::KeyboardInteractiveAuthResponse;
2use russh::{
3    Channel,
4    client::{Config, Handle, Handler, Msg},
5};
6use russh_sftp::{client::SftpSession, protocol::OpenFlags};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Instant;
10use std::{fmt::Debug, path::Path};
11use std::{io, path::PathBuf};
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::sync::mpsc;
14
15use crate::ToSocketAddrsWithHostname;
16
/// An authentication token.
///
/// Used when creating a [`Client`] for authentication.
/// Supports password, private key, public key, SSH agent, and keyboard interactive authentication.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AuthMethod {
    /// Authenticate with a password.
    Password(String),
    /// Authenticate with a private key whose contents are held in memory.
    PrivateKey {
        /// Entire contents of the private key file.
        key_data: String,
        /// Passphrase handed to the key decoder, if any.
        key_pass: Option<String>,
    },
    /// Authenticate with a private key that is loaded from a file.
    PrivateKeyFile {
        /// Path of the private key file.
        key_file_path: PathBuf,
        /// Passphrase handed to the key loader, if any.
        key_pass: Option<String>,
    },
    /// Authenticate with a public key file; the SSH agent must hold the matching identity
    /// and performs the signing.
    ///
    /// Not available on Windows.
    #[cfg(not(target_os = "windows"))]
    PublicKeyFile {
        /// Path of the public key file.
        key_file_path: PathBuf,
    },
    /// Authenticate with the identities offered by the SSH agent.
    ///
    /// Not available on Windows.
    #[cfg(not(target_os = "windows"))]
    Agent,
    /// Authenticate via keyboard-interactive prompts, answered from pre-configured responses.
    KeyboardInteractive(AuthKeyboardInteractive),
}
42
/// A single event sent by [`Client::execute_streaming`] to its output channel.
///
/// NOTE(review): the name is a misspelling of "StreamingOutput"; the deprecation note on
/// `execute_streaming` announces a future rename.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SteamingOutput {
    /// A chunk of bytes received on the command's stdout.
    Stdout(Vec<u8>),
    /// A chunk of bytes received on the command's stderr.
    Stderr(Vec<u8>),
    /// The exit status of the command; sent last.
    ExitStatus(u32),
}
49
/// A canned answer for one keyboard-interactive prompt.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PromptResponse {
    /// `true`: the received prompt must equal `prompt`;
    /// `false`: the received prompt only has to contain `prompt`.
    exact: bool,
    /// The prompt text (or fragment) this response belongs to.
    prompt: String,
    /// The answer to send back when the prompt matches.
    response: String,
}
56
/// Configuration for keyboard-interactive authentication.
///
/// Built with [`AuthKeyboardInteractive::new`] and the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub struct AuthKeyboardInteractive {
    /// Hints to the server the preferred methods to be used for authentication.
    submethods: Option<String>,
    /// Pre-configured answers; each one is removed from the list once it has been used.
    responses: Vec<PromptResponse>,
}
64
/// Selects how the server's host key is checked when connecting.
///
/// NOTE(review): the verification logic lives in the client handler, which is not part of
/// this section; the per-variant notes below are inferred from the names and constructors
/// and should be confirmed against that handler.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ServerCheckMethod {
    /// Presumably skips host key verification entirely — confirm before relying on it.
    NoCheck,
    /// base64 encoded key without the type prefix or hostname suffix (type is already encoded)
    PublicKey(String),
    /// Name of a file holding the expected server public key (see `with_public_key_file`).
    PublicKeyFile(String),
    /// Presumably checks against the user's default known-hosts file — TODO confirm location.
    DefaultKnownHostsFile,
    /// Name of a known-hosts file to check against (see `with_known_hosts_file`).
    KnownHostsFile(String),
}
75
76impl AuthMethod {
77    /// Convenience method to create a [`AuthMethod`] from a string literal.
78    pub fn with_password(password: &str) -> Self {
79        Self::Password(password.to_string())
80    }
81
82    pub fn with_key(key: &str, passphrase: Option<&str>) -> Self {
83        Self::PrivateKey {
84            key_data: key.to_string(),
85            key_pass: passphrase.map(str::to_string),
86        }
87    }
88
89    pub fn with_key_file<T: AsRef<Path>>(key_file_path: T, passphrase: Option<&str>) -> Self {
90        Self::PrivateKeyFile {
91            key_file_path: key_file_path.as_ref().to_path_buf(),
92            key_pass: passphrase.map(str::to_string),
93        }
94    }
95
96    #[cfg(not(target_os = "windows"))]
97    pub fn with_public_key_file<T: AsRef<Path>>(key_file_path: T) -> Self {
98        Self::PublicKeyFile {
99            key_file_path: key_file_path.as_ref().to_path_buf(),
100        }
101    }
102
103    /// Creates a new SSH agent authentication method.
104    ///
105    /// This will attempt to authenticate using all identities available in the SSH agent.
106    /// The SSH agent must be running and the SSH_AUTH_SOCK environment variable must be set.
107    ///
108    /// # Example
109    /// ```no_run
110    /// use async_ssh2_tokio::client::AuthMethod;
111    ///
112    /// let auth = AuthMethod::with_agent();
113    /// ```
114    ///
115    /// # Platform Support
116    /// This method is only available on Unix-like systems (Linux, macOS, etc.).
117    /// It is not available on Windows.
118    #[cfg(not(target_os = "windows"))]
119    pub fn with_agent() -> Self {
120        Self::Agent
121    }
122
123    pub const fn with_keyboard_interactive(auth: AuthKeyboardInteractive) -> Self {
124        Self::KeyboardInteractive(auth)
125    }
126}
127
128impl AuthKeyboardInteractive {
129    pub fn new() -> Self {
130        Default::default()
131    }
132
133    /// Hnts to the server the preferred methods to be used for authentication.
134    pub fn with_submethods(mut self, submethods: impl Into<String>) -> Self {
135        self.submethods = Some(submethods.into());
136        self
137    }
138
139    /// Adds a response to the list of responses for a given prompt.
140    ///
141    /// The comparison for the prompt is done using a "contains".
142    pub fn with_response(mut self, prompt: impl Into<String>, response: impl Into<String>) -> Self {
143        self.responses.push(PromptResponse {
144            exact: false,
145            prompt: prompt.into(),
146            response: response.into(),
147        });
148
149        self
150    }
151
152    /// Adds a response to the list of responses for a given exact prompt.
153    pub fn with_response_exact(
154        mut self,
155        prompt: impl Into<String>,
156        response: impl Into<String>,
157    ) -> Self {
158        self.responses.push(PromptResponse {
159            exact: true,
160            prompt: prompt.into(),
161            response: response.into(),
162        });
163
164        self
165    }
166}
167
168impl PromptResponse {
169    fn matches(&self, received_prompt: &str) -> bool {
170        if self.exact {
171            self.prompt.eq(received_prompt)
172        } else {
173            received_prompt.contains(&self.prompt)
174        }
175    }
176}
177
178impl From<AuthKeyboardInteractive> for AuthMethod {
179    fn from(value: AuthKeyboardInteractive) -> Self {
180        Self::with_keyboard_interactive(value)
181    }
182}
183
184impl ServerCheckMethod {
185    /// Convenience method to create a [`ServerCheckMethod`] from a string literal.
186    pub fn with_public_key(key: &str) -> Self {
187        Self::PublicKey(key.to_string())
188    }
189
190    /// Convenience method to create a [`ServerCheckMethod`] from a string literal.
191    pub fn with_public_key_file(key_file_name: &str) -> Self {
192        Self::PublicKeyFile(key_file_name.to_string())
193    }
194
195    /// Convenience method to create a [`ServerCheckMethod`] from a string literal.
196    pub fn with_known_hosts_file(known_hosts_file: &str) -> Self {
197        Self::KnownHostsFile(known_hosts_file.to_string())
198    }
199}
200
/// An SSH connection to a remote server.
///
/// After creating a `Client` by [`connect`]ing to a remote host,
/// use [`execute`] to send commands and receive results through the connection.
///
/// [`connect`]: Client::connect
/// [`execute`]: Client::execute
///
/// # Examples
///
/// ```no_run
/// use async_ssh2_tokio::{Client, AuthMethod, ServerCheckMethod};
/// #[tokio::main]
/// async fn main() -> Result<(), async_ssh2_tokio::Error> {
///     let mut client = Client::connect(
///         ("10.10.10.2", 22),
///         "root",
///         AuthMethod::with_password("root"),
///         ServerCheckMethod::NoCheck,
///     ).await?;
///
///     let result = client.execute("echo Hello SSH").await?;
///     assert_eq!(result.stdout, "Hello SSH\n");
///     assert_eq!(result.exit_status, 0);
///
///     Ok(())
/// }
/// ```
#[derive(Clone)]
pub struct Client {
    /// Handle of the underlying russh connection; held in an `Arc`, so clones of the
    /// `Client` share the same connection.
    connection_handle: Arc<Handle<ClientHandler>>,
    /// User name that was used for authentication.
    username: String,
    /// The resolved socket address the connection was established to.
    address: SocketAddr,
}
234
235impl Client {
236    /// Open a ssh connection to a remote host.
237    ///
238    /// `addr` is an address of the remote host. Anything which implements
239    /// [`ToSocketAddrsWithHostname`] trait can be supplied for the address;
240    /// ToSocketAddrsWithHostname reimplements all of [`ToSocketAddrs`];
241    /// see this trait's documentation for concrete examples.
242    ///
243    /// If `addr` yields multiple addresses, `connect` will be attempted with
244    /// each of the addresses until a connection is successful.
245    /// Authentification is tried on the first successful connection and the whole
246    /// process aborted if this fails.
247    pub async fn connect(
248        addr: impl ToSocketAddrsWithHostname,
249        username: &str,
250        auth: AuthMethod,
251        server_check: ServerCheckMethod,
252    ) -> Result<Self, crate::Error> {
253        Self::connect_with_config(addr, username, auth, server_check, Config::default()).await
254    }
255
256    /// Same as `connect`, but with the option to specify a non default
257    /// [`russh::client::Config`].
258    pub async fn connect_with_config(
259        addr: impl ToSocketAddrsWithHostname,
260        username: &str,
261        auth: AuthMethod,
262        server_check: ServerCheckMethod,
263        config: Config,
264    ) -> Result<Self, crate::Error> {
265        let config = Arc::new(config);
266
267        // Connection code inspired from std::net::TcpStream::connect and std::net::each_addr
268        let socket_addrs = addr
269            .to_socket_addrs()
270            .map_err(crate::Error::AddressInvalid)?;
271        let mut connect_res = Err(crate::Error::AddressInvalid(io::Error::new(
272            io::ErrorKind::InvalidInput,
273            "could not resolve to any addresses",
274        )));
275        for socket_addr in socket_addrs {
276            let handler = ClientHandler {
277                hostname: addr.hostname(),
278                host: socket_addr,
279                server_check: server_check.clone(),
280            };
281            match russh::client::connect(config.clone(), socket_addr, handler).await {
282                Ok(h) => {
283                    connect_res = Ok((socket_addr, h));
284                    break;
285                }
286                Err(e) => connect_res = Err(e),
287            }
288        }
289        let (address, mut handle) = connect_res?;
290        let username = username.to_string();
291
292        Self::authenticate(&mut handle, &username, auth).await?;
293
294        Ok(Self {
295            connection_handle: Arc::new(handle),
296            username,
297            address,
298        })
299    }
300
301    /// This takes a handle and performs authentification with the given method.
302    async fn authenticate(
303        handle: &mut Handle<ClientHandler>,
304        username: &String,
305        auth: AuthMethod,
306    ) -> Result<(), crate::Error> {
307        match auth {
308            AuthMethod::Password(password) => {
309                let is_authentificated = handle.authenticate_password(username, password).await?;
310                if !is_authentificated.success() {
311                    return Err(crate::Error::PasswordWrong);
312                }
313            }
314            AuthMethod::PrivateKey { key_data, key_pass } => {
315                let cprivk = russh::keys::decode_secret_key(key_data.as_str(), key_pass.as_deref())
316                    .map_err(crate::Error::KeyInvalid)?;
317                let is_authentificated = handle
318                    .authenticate_publickey(
319                        username,
320                        russh::keys::PrivateKeyWithHashAlg::new(
321                            Arc::new(cprivk),
322                            handle.best_supported_rsa_hash().await?.flatten(),
323                        ),
324                    )
325                    .await?;
326                if !is_authentificated.success() {
327                    return Err(crate::Error::KeyAuthFailed);
328                }
329            }
330            AuthMethod::PrivateKeyFile {
331                key_file_path,
332                key_pass,
333            } => {
334                let cprivk = russh::keys::load_secret_key(key_file_path, key_pass.as_deref())
335                    .map_err(crate::Error::KeyInvalid)?;
336                let is_authentificated = handle
337                    .authenticate_publickey(
338                        username,
339                        russh::keys::PrivateKeyWithHashAlg::new(
340                            Arc::new(cprivk),
341                            handle.best_supported_rsa_hash().await?.flatten(),
342                        ),
343                    )
344                    .await?;
345                if !is_authentificated.success() {
346                    return Err(crate::Error::KeyAuthFailed);
347                }
348            }
349            #[cfg(not(target_os = "windows"))]
350            AuthMethod::PublicKeyFile { key_file_path } => {
351                let cpubk = russh::keys::load_public_key(key_file_path)
352                    .map_err(crate::Error::KeyInvalid)?;
353                let mut agent = russh::keys::agent::client::AgentClient::connect_env()
354                    .await
355                    .unwrap();
356                let mut auth_identity: Option<russh::keys::PublicKey> = None;
357                for identity in agent
358                    .request_identities()
359                    .await
360                    .map_err(crate::Error::KeyInvalid)?
361                {
362                    if identity == cpubk {
363                        auth_identity = Some(identity.clone());
364                        break;
365                    }
366                }
367
368                if auth_identity.is_none() {
369                    return Err(crate::Error::KeyAuthFailed);
370                }
371
372                let is_authentificated = handle
373                    .authenticate_publickey_with(
374                        username,
375                        cpubk,
376                        handle.best_supported_rsa_hash().await?.flatten(),
377                        &mut agent,
378                    )
379                    .await?;
380                if !is_authentificated.success() {
381                    return Err(crate::Error::KeyAuthFailed);
382                }
383            }
384            #[cfg(not(target_os = "windows"))]
385            AuthMethod::Agent => {
386                let mut agent = russh::keys::agent::client::AgentClient::connect_env()
387                    .await
388                    .map_err(|_| crate::Error::AgentConnectionFailed)?;
389
390                let identities = agent
391                    .request_identities()
392                    .await
393                    .map_err(|_| crate::Error::AgentRequestIdentitiesFailed)?;
394
395                if identities.is_empty() {
396                    return Err(crate::Error::AgentNoIdentities);
397                }
398
399                let mut auth_success = false;
400                for identity in identities {
401                    let result = handle
402                        .authenticate_publickey_with(
403                            username,
404                            identity.clone(),
405                            handle.best_supported_rsa_hash().await?.flatten(),
406                            &mut agent,
407                        )
408                        .await;
409
410                    if let Ok(auth_result) = result
411                        && auth_result.success()
412                    {
413                        auth_success = true;
414                        break;
415                    }
416                }
417
418                if !auth_success {
419                    return Err(crate::Error::AgentAuthenticationFailed);
420                }
421            }
422            AuthMethod::KeyboardInteractive(mut kbd) => {
423                let mut res = handle
424                    .authenticate_keyboard_interactive_start(username, kbd.submethods)
425                    .await?;
426                loop {
427                    let prompts = match res {
428                        KeyboardInteractiveAuthResponse::Success => break,
429                        KeyboardInteractiveAuthResponse::Failure { .. } => {
430                            return Err(crate::Error::KeyboardInteractiveAuthFailed);
431                        }
432                        KeyboardInteractiveAuthResponse::InfoRequest { prompts, .. } => prompts,
433                    };
434
435                    let mut responses = vec![];
436                    for prompt in prompts {
437                        let Some(pos) = kbd
438                            .responses
439                            .iter()
440                            .position(|pr| pr.matches(&prompt.prompt))
441                        else {
442                            return Err(crate::Error::KeyboardInteractiveNoResponseForPrompt(
443                                prompt.prompt,
444                            ));
445                        };
446                        let pr = kbd.responses.remove(pos);
447                        responses.push(pr.response);
448                    }
449
450                    res = handle
451                        .authenticate_keyboard_interactive_respond(responses)
452                        .await?;
453                }
454            }
455        };
456        Ok(())
457    }
458
459    pub async fn get_channel(&self) -> Result<Channel<Msg>, crate::Error> {
460        self.connection_handle
461            .channel_open_session()
462            .await
463            .map_err(crate::Error::SshError)
464    }
465
466    /// Open a TCP/IP forwarding channel.
467    ///
468    /// This opens a `direct-tcpip` channel to the given target.
469    pub async fn open_direct_tcpip_channel<
470        T: ToSocketAddrsWithHostname,
471        S: Into<Option<SocketAddr>>,
472    >(
473        &self,
474        target: T,
475        src: S,
476    ) -> Result<Channel<Msg>, crate::Error> {
477        let targets = target
478            .to_socket_addrs()
479            .map_err(crate::Error::AddressInvalid)?;
480        let src = src
481            .into()
482            .map(|src| (src.ip().to_string(), src.port().into()))
483            .unwrap_or_else(|| ("127.0.0.1".to_string(), 22));
484
485        let mut connect_err = crate::Error::AddressInvalid(io::Error::new(
486            io::ErrorKind::InvalidInput,
487            "could not resolve to any addresses",
488        ));
489        for target in targets {
490            match self
491                .connection_handle
492                .channel_open_direct_tcpip(
493                    target.ip().to_string(),
494                    target.port().into(),
495                    src.0.clone(),
496                    src.1,
497                )
498                .await
499            {
500                Ok(channel) => return Ok(channel),
501                Err(err) => connect_err = crate::Error::SshError(err),
502            }
503        }
504
505        Err(connect_err)
506    }
507
508    /// Upload a file with sftp to the remote server.
509    ///
510    /// `src_file_path` is the path to the file on the local machine.
511    /// `dest_file_path` is the path to the file on the remote machine.
512    /// 'timeout_seconds' is the timeout, in seconds, for the operation, passed on to the sftp session.
513    /// If not specified it will default to the underlying value in the sftp code, which as of this writing is 10 seconds.
514    /// 'buffer_size_in_bytes' is the value this function will buffer the file through.  it defaults to 4KB.
515    /// 'show_progress' if true, logs will be emitted every 5% of the file upload, measured in bytes
516    /// Some sshd_config does not enable sftp by default, so make sure it is enabled.
517    /// A config line like a `Subsystem sftp internal-sftp` or
518    /// `Subsystem sftp /usr/lib/openssh/sftp-server` is needed in the sshd_config in remote machine.
519    pub async fn upload_file<T, U>(
520        &self,
521        src_file_path: T,
522        //fa993: This cannot be AsRef<Path> because of underlying lib constraints as described here
523        //https://github.com/AspectUnk/russh-sftp/issues/7#issuecomment-1738355245
524        dest_file_path: U,
525        timeout_seconds: Option<u64>,
526        buffer_size_in_bytes: Option<usize>,
527        show_progress: bool,
528    ) -> Result<(), crate::Error>
529    where
530        T: AsRef<Path> + std::fmt::Display,
531        U: Into<String>,
532    {
533        // start sftp session
534        let channel = self.get_channel().await?;
535        channel.request_subsystem(true, "sftp").await?;
536        let sftp = SftpSession::new_opts(channel.into_stream(), timeout_seconds).await?;
537
538        let file_size = tokio::fs::metadata(&src_file_path).await?.len();
539        // read file contents locally
540        let local_file = tokio::fs::File::open(&src_file_path)
541            .await
542            .map_err(crate::Error::IoError)?;
543        let mut local_file_buffered = tokio::io::BufReader::new(local_file);
544
545        let dest_file_path = dest_file_path.into();
546        let mut remote_file = sftp
547            .open_with_flags(
548                dest_file_path.clone(),
549                OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ,
550            )
551            .await?;
552
553        let buffer_size_in_bytes = buffer_size_in_bytes.unwrap_or(4096);
554        let mut buffer = vec![0; buffer_size_in_bytes];
555
556        let mut total_bytes_copied = 0;
557        let mut next_progress_marker = 5.0;
558
559        let start_time = Instant::now();
560        if show_progress {
561            log::info!(
562                "Starting file upload from {src_file_path} to {dest_file_path}, total bytes to be transferred: {}",
563                file_size
564            );
565        }
566        loop {
567            let n = local_file_buffered.read(&mut buffer).await?;
568            if n == 0 {
569                break;
570            }
571            remote_file
572                .write_all(&buffer[..n])
573                .await
574                .map_err(crate::Error::IoError)?;
575            if show_progress {
576                total_bytes_copied += n as u64;
577                let progress = (total_bytes_copied as f64 / file_size as f64) * 100.0;
578                if progress >= next_progress_marker {
579                    log::info!(
580                        "Progress of upload from {src_file_path} to {dest_file_path}: {:.0}% in elapsed time: {}s",
581                        next_progress_marker,
582                        start_time.elapsed().as_secs_f64()
583                    );
584                    next_progress_marker += 5.0;
585                }
586            }
587        }
588
589        if show_progress {
590            log::info!(
591                "file upload comprising {file_size} bytes from {src_file_path} to {dest_file_path} completed successfully in {}s",
592                start_time.elapsed().as_secs_f64()
593            );
594        }
595        remote_file
596            .shutdown()
597            .await
598            .map_err(crate::Error::IoError)?;
599
600        Ok(())
601    }
602
603    /// Download a file from the remote server using sftp.
604    ///
605    /// `remote_file_path` is the path to the file on the remote machine.
606    /// `local_file_path` is the path to the file on the local machine.
607    /// Some sshd_config does not enable sftp by default, so make sure it is enabled.
608    /// A config line like a `Subsystem sftp internal-sftp` or
609    /// `Subsystem sftp /usr/lib/openssh/sftp-server` is needed in the sshd_config in remote machine.
610    pub async fn download_file<T: AsRef<Path>, U: Into<String>>(
611        &self,
612        remote_file_path: U,
613        local_file_path: T,
614    ) -> Result<(), crate::Error> {
615        // start sftp session
616        let channel = self.get_channel().await?;
617        channel.request_subsystem(true, "sftp").await?;
618        let sftp = SftpSession::new(channel.into_stream()).await?;
619
620        // open remote file for reading
621        let mut remote_file = sftp
622            .open_with_flags(remote_file_path, OpenFlags::READ)
623            .await?;
624
625        // read remote file contents
626        let mut contents = Vec::new();
627        remote_file.read_to_end(contents.as_mut()).await?;
628
629        // write contents to local file
630        let mut local_file = tokio::fs::File::create(local_file_path.as_ref())
631            .await
632            .map_err(crate::Error::IoError)?;
633
634        local_file
635            .write_all(&contents)
636            .await
637            .map_err(crate::Error::IoError)?;
638        local_file.flush().await.map_err(crate::Error::IoError)?;
639
640        Ok(())
641    }
642
643    /// Execute a remote command via the ssh connection.
644    ///
645    /// Returns stdout, stderr and the exit code of the command,
646    /// packaged in a [`CommandExecutedResult`] struct.
647    /// If you need the stderr output interleaved within stdout, you should postfix the command with a redirection,
648    /// e.g. `echo foo 2>&1`.
649    /// If you dont want any output at all, use something like `echo foo >/dev/null 2>&1`.
650    ///
651    /// Make sure your commands don't read from stdin and exit after bounded time.
652    ///
653    /// Can be called multiple times, but every invocation is a new shell context.
654    /// Thus `cd`, setting variables and alike have no effect on future invocations.
655    pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, crate::Error> {
656        let mut stdout_buffer = vec![];
657        let mut stderr_buffer = vec![];
658        let mut channel = self.connection_handle.channel_open_session().await?;
659        channel.exec(true, command).await?;
660
661        let mut result: Option<u32> = None;
662
663        // While the channel has messages...
664        while let Some(msg) = channel.wait().await {
665            //dbg!(&msg);
666            match msg {
667                // If we get data, add it to the buffer
668                russh::ChannelMsg::Data { ref data } => {
669                    stdout_buffer.write_all(data).await.unwrap()
670                }
671                russh::ChannelMsg::ExtendedData { ref data, ext } => {
672                    if ext == 1 {
673                        stderr_buffer.write_all(data).await.unwrap()
674                    }
675                }
676
677                // If we get an exit code report, store it, but crucially don't
678                // assume this message means end of communications. The data might
679                // not be finished yet!
680                russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
681
682                // We SHOULD get this EOF messagge, but 4254 sec 5.3 also permits
683                // the channel to close without it being sent. And sometimes this
684                // message can even precede the Data message, so don't handle it
685                // russh::ChannelMsg::Eof => break,
686                _ => {}
687            }
688        }
689
690        // If we received an exit code, report it back
691        if let Some(result) = result {
692            Ok(CommandExecutedResult {
693                stdout: String::from_utf8_lossy(&stdout_buffer).to_string(),
694                stderr: String::from_utf8_lossy(&stderr_buffer).to_string(),
695                exit_status: result,
696            })
697
698        // Otherwise, report an error
699        } else {
700            Err(crate::Error::CommandDidntExit)
701        }
702    }
703
    /// Execute a remote command via the ssh connection.
    ///
    /// Command output is streamed to the provided channel. Returns the exit code.
    /// The channel receives `SteamingOutput` enum variants to distinguish stdout,
    /// stderr and exit code, so messages arrive interleaved and in the order
    /// they are received. The exit code is sent as the last message.
    /// See `execute` for more details.
    ///
    /// # Panics
    ///
    /// Panics if sending on `ch` fails (every send result is unwrapped).
    #[deprecated(
        since = "0.11.0",
        note = "Use execute_io with channels directly for more flexibility.\n\
              This method will be removed or introduced breaking changes in future versions.\n\
              At minimum, SteamingOutput will be renamed to StreamingOutput"
    )]
    pub async fn execute_streaming(
        &self,
        command: &str,
        ch: tokio::sync::mpsc::Sender<SteamingOutput>,
    ) -> Result<u32, crate::Error> {
        // Internal channels of capacity 1 that carry the raw output of `execute_io`.
        let (stdout_tx, mut stdout_rx) = tokio::sync::mpsc::channel(1);
        let (stderr_tx, mut stderr_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1);

        let exec_future = self.execute_io(command, stdout_tx, Some(stderr_tx), None, false, None);
        tokio::pin!(exec_future);
        // Drive the command and forward output as it arrives, until the command future
        // resolves; its error (if any) is propagated by the trailing `?`.
        let result = loop {
            tokio::select! {
                result = &mut exec_future => break result,
                Some(stdout) = stdout_rx.recv() => {
                    ch.send(SteamingOutput::Stdout(stdout)).await.unwrap();
                },
                Some(stderr) = stderr_rx.recv() => {
                    ch.send(SteamingOutput::Stderr(stderr)).await.unwrap();
                },
            };
        }?;
        // see if any output is left in the channels
        // (a single receive per channel, matching the channel capacity of 1)
        if let Some(stdout) = stdout_rx.recv().await {
            ch.send(SteamingOutput::Stdout(stdout)).await.unwrap();
        }
        if let Some(stderr) = stderr_rx.recv().await {
            ch.send(SteamingOutput::Stderr(stderr)).await.unwrap();
        }
        ch.send(SteamingOutput::ExitStatus(result)).await.unwrap();
        Ok(result)
    }
748
749    /// Execute a remote command via the ssh connection and perform i/o via channels.
750    ///
751    /// `execute_io` does the same as `execute`, but ties stdin and stdout/stderr to channels.
752    /// Giving a stdin channel is optional. If there is only a stdout channel, stderr will be
753    /// sent to the stdout channel. Sending an empty string to the stdin channel will send an
754    /// EOF to the remote side.
755    /// If `request_pty` is true, a pseudo terminal is requested for the session. This is
756    /// sometime necessary for example to enter a password, which is not request via stdin
757    /// but directly from the terminal. NOTE: A pty has no stderr, so stderr output is
758    /// sent to the stdout channel.
759    /// The exit code of the command is returned as a result. If the remote ssh server
760    /// does not report an exit code, a default exit code can be passed, otherwise an error
761    /// is returned.
762    ///
763    /// Example:
764    ///
765    /// ```no_run
766    /// use async_ssh2_tokio::{Client, AuthMethod, ServerCheckMethod};
767    /// use tokio::sync::mpsc;
768    ///
769    /// #[tokio::main]
770    /// async fn main() -> Result<(), async_ssh2_tokio::Error> {
771    ///     let mut client = Client::connect(
772    ///         ("10.10.10.2", 22),
773    ///         "root",
774    ///         AuthMethod::with_password("root"),
775    ///         ServerCheckMethod::NoCheck,
776    ///     ).await?;
777    ///     let mut result_stdout = vec![];
778    ///     let mut result_stderr = vec![];
779    ///
780    ///     let (stdout_tx, mut stdout_rx) = mpsc::channel(10);
781    ///     let (stderr_tx, mut stderr_rx) = mpsc::channel(10);
782    ///     let cmd = "date";
783    ///     let exec_future = client.execute_io(&cmd, stdout_tx, Some(stderr_tx), None, false, None);
784    ///     tokio::pin!(exec_future);
785    ///     let result = loop {
786    ///         tokio::select! {
787    ///             result = &mut exec_future => break result,
788    ///             Some(stdout) = stdout_rx.recv() => {
789    ///                 println!("ssh stdout: {}", String::from_utf8_lossy(&stdout));
790    ///                 result_stdout.push(stdout);
791    ///             },
792    ///             Some(stderr) = stderr_rx.recv() => {
793    ///                 println!("ssh stderr: {}", String::from_utf8_lossy(&stderr));
794    ///                 result_stderr.push(stderr);
795    ///             },
796    ///         };
797    ///     }?;
798    ///
    ///     // drain any output that is still buffered in the channels
    ///     while let Some(stdout) = stdout_rx.recv().await {
    ///         println!("ssh stdout: {}", String::from_utf8_lossy(&stdout));
    ///         result_stdout.push(stdout);
    ///     }
    ///     while let Some(stderr) = stderr_rx.recv().await {
    ///         println!("ssh stderr: {}", String::from_utf8_lossy(&stderr));
    ///         result_stderr.push(stderr);
    ///     }
808    ///     Ok(())
809    /// }
810    /// ```
811    pub async fn execute_io(
812        &self,
813        command: &str,
814        stdout_channel: mpsc::Sender<Vec<u8>>,
815        stderr_channel: Option<mpsc::Sender<Vec<u8>>>,
816        mut stdin_channel: Option<mpsc::Receiver<Vec<u8>>>,
817        request_pty: bool,
818        default_exit_code: Option<u32>,
819    ) -> Result<u32, crate::Error> {
820        let mut channel = self.connection_handle.channel_open_session().await?;
821
822        let mut result: Option<u32> = None;
823        if request_pty {
824            channel
825                .request_pty(false, "xterm", 80_u32, 24_u32, 0, 0, &[])
826                .await?;
827        }
828
829        channel.exec(true, command).await?;
830
831        // While the channel has messages...
832        loop {
833            let recv_stdin = async {
834                if let Some(ch) = stdin_channel.as_mut() {
835                    Some(ch.recv().await)
836                } else {
837                    None
838                }
839            };
840            tokio::select! {
841                Some(input) = recv_stdin => {
842                    if let Some(input) = input {
843                        if input.is_empty() {
844                            channel.eof().await? ;
845                        } else {
846                            channel.data(&input as &[u8]).await?;
847                        }
848                    }
849                },
850                msg = channel.wait() => {
851                    //dbg!(&msg);
852                    match msg {
853                        // If we get data, add it to the buffer
854                        Some(russh::ChannelMsg::Data { ref data }) => {
855                            //dbg!("sending stdout");
856                            stdout_channel
857                                .send(data.to_vec())
858                                .await
859                                .map_err(crate::Error::ChannelSendError)?;
860                        }
861                        Some (russh::ChannelMsg::ExtendedData { ref data, ext }) => {
862                            if ext == 1 {
863                                if let Some(stderr_channel) = &stderr_channel {
864                                    //dbg!("sending stderr");
865                                    stderr_channel
866                                        .send(data.to_vec())
867                                        .await
868                                        .map_err(crate::Error::ChannelSendError)?;
869                                } else {
870                                    //dbg!("sending stderr to stdout");
871                                    stdout_channel
872                                        .send(data.to_vec())
873                                        .await
874                                        .map_err(crate::Error::ChannelSendError)?;
875                                }
876                            }
877                        }
878
879                        // If we get an exit code report, store it, but crucially don't
880                        // assume this message means end of communications. The data might
881                        // not be finished yet!
882                        Some (russh::ChannelMsg::ExitStatus { exit_status }) => result = Some(exit_status),
883
884                        // We SHOULD get this EOF messagge, but 4254 sec 5.3 also permits
885                        // the channel to close without it being sent. And sometimes this
886                        // message can even precede the Data message, so don't handle it
887                        // russh::ChannelMsg::Eof => break,
888                        Some (_) => {},
889                        None => break,
890                    }
891                }
892            }
893        }
894
895        // If we received an exit code, report it back
896        if let Some(result) = result {
897            Ok(result)
898        // If we have an default exit code, report it back
899        } else if let Some(default_exit_code) = default_exit_code {
900            Ok(default_exit_code)
901        // Otherwise, report an error
902        } else {
903            Err(crate::Error::CommandDidntExit)
904        }
905    }
906
907    /// A debugging function to get the username this client is connected as.
908    pub fn get_connection_username(&self) -> &String {
909        &self.username
910    }
911
912    /// A debugging function to get the address this client is connected to.
913    pub fn get_connection_address(&self) -> &SocketAddr {
914        &self.address
915    }
916
917    pub async fn disconnect(&self) -> Result<(), crate::Error> {
918        self.connection_handle
919            .disconnect(russh::Disconnect::ByApplication, "", "")
920            .await
921            .map_err(crate::Error::SshError)
922    }
923
924    pub fn is_closed(&self) -> bool {
925        self.connection_handle.is_closed()
926    }
927}
928
929impl Debug for Client {
930    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
931        f.debug_struct("Client")
932            .field("username", &self.username)
933            .field("address", &self.address)
934            .field("connection_handle", &"Handle<ClientHandler>")
935            .finish()
936    }
937}
938
/// Captured outcome of a remotely executed command: both output streams as
/// strings plus the exit status.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CommandExecutedResult {
    /// Everything the command wrote to stdout.
    pub stdout: String,
    /// Everything the command wrote to stderr.
    pub stderr: String,
    /// The unix exit status of the command (`$?` in bash).
    pub exit_status: u32,
}
948
/// State for this crate's russh `Handler` implementation, which decides whether
/// a server's public key is accepted.
#[derive(Debug, Clone)]
struct ClientHandler {
    /// Host name passed to the known-hosts lookups in `check_server_key`.
    hostname: String,
    /// Socket address of the server; `check_server_key` reads only its port.
    host: SocketAddr,
    /// Strategy `check_server_key` matches on to verify the server key.
    server_check: ServerCheckMethod,
}
955
956impl Handler for ClientHandler {
957    type Error = crate::Error;
958
959    async fn check_server_key(
960        &mut self,
961        server_public_key: &russh::keys::PublicKey,
962    ) -> Result<bool, Self::Error> {
963        match &self.server_check {
964            ServerCheckMethod::NoCheck => Ok(true),
965            ServerCheckMethod::PublicKey(key) => {
966                let pk = russh::keys::parse_public_key_base64(key)
967                    .map_err(|_| crate::Error::ServerCheckFailed)?;
968
969                Ok(pk == *server_public_key)
970            }
971            ServerCheckMethod::PublicKeyFile(key_file_name) => {
972                let pk = russh::keys::load_public_key(key_file_name)
973                    .map_err(|_| crate::Error::ServerCheckFailed)?;
974
975                Ok(pk == *server_public_key)
976            }
977            ServerCheckMethod::KnownHostsFile(known_hosts_path) => {
978                let result = russh::keys::check_known_hosts_path(
979                    &self.hostname,
980                    self.host.port(),
981                    server_public_key,
982                    known_hosts_path,
983                )
984                .map_err(|_| crate::Error::ServerCheckFailed)?;
985
986                Ok(result)
987            }
988            ServerCheckMethod::DefaultKnownHostsFile => {
989                let result = russh::keys::check_known_hosts(
990                    &self.hostname,
991                    self.host.port(),
992                    server_public_key,
993                )
994                .map_err(|_| crate::Error::ServerCheckFailed)?;
995
996                Ok(result)
997            }
998        }
999    }
1000}
1001
/// Integration tests. They need a reachable SSH test server and the
/// `ASYNC_SSH2_TEST_*` environment variables (loaded from `.env` outside Docker).
#[cfg(test)]
mod tests {
    #![allow(deprecated, clippy::useless_vec)]

    use crate::client::*;
    use core::time;
    use dotenv::dotenv;
    use std::path::Path;
    use std::sync::Once;
    use tokio::io::AsyncReadExt;
    static INIT: Once = Once::new();

    /// One-time test setup: outside Docker, load environment variables from a `.env` file.
    fn initialize() {
        println!("Running initialization code before tests...");
        if is_running_in_docker() {
            println!("Running inside Docker.");
        } else {
            println!("Not running inside Docker. Load env from file");
            dotenv().ok();
        }
    }

    /// Heuristic Docker detection: `/.dockerenv` exists or the cgroup file mentions docker.
    fn is_running_in_docker() -> bool {
        Path::new("/.dockerenv").exists() || check_cgroup()
    }

    /// Returns true when `/proc/1/cgroup` is readable and contains `docker`.
    fn check_cgroup() -> bool {
        std::fs::read_to_string("/proc/1/cgroup").is_ok_and(|contents| contents.contains("docker"))
    }

    /// Reads a required environment variable, running [`initialize`] once beforehand.
    ///
    /// Panics with a descriptive message when the variable is not set.
    fn env(name: &str) -> String {
        INIT.call_once(|| {
            initialize();
        });
        std::env::var(name).unwrap_or_else(|_| {
            panic!(
                "Failed to get env var needed for test, make sure to set the following env var: {name}",
            )
        })
    }

    /// Socket address of the test host, built from the IP and port variables.
    fn test_address() -> SocketAddr {
        format!(
            "{}:{}",
            env("ASYNC_SSH2_TEST_HOST_IP"),
            env("ASYNC_SSH2_TEST_HOST_PORT")
        )
        .parse()
        .unwrap()
    }

    /// Host name / port pair of the test host.
    fn test_hostname() -> impl ToSocketAddrsWithHostname {
        (
            env("ASYNC_SSH2_TEST_HOST_NAME"),
            env("ASYNC_SSH2_TEST_HOST_PORT").parse().unwrap(),
        )
    }

    /// Connects to the test host with password authentication and no server check.
    async fn establish_test_host_connection() -> Client {
        Client::connect(
            (
                env("ASYNC_SSH2_TEST_HOST_IP"),
                env("ASYNC_SSH2_TEST_HOST_PORT").parse().unwrap(),
            ),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Connection/Authentification failed")
    }

    #[tokio::test]
    async fn connect_with_password() {
        let client = establish_test_host_connection().await;
        assert_eq!(
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            client.get_connection_username(),
        );
        assert_eq!(test_address(), *client.get_connection_address(),);
    }

    #[tokio::test]
    async fn execute_command_result() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo test!!!").await.unwrap();
        assert_eq!("test!!!\n", output.stdout);
        assert_eq!("", output.stderr);
        assert_eq!(0, output.exit_status);
    }

    #[tokio::test]
    async fn execute_streaming_command_result() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(10);
        let client = establish_test_host_connection().await;
        let result = client.execute_streaming("echo test!!!", tx).await.unwrap();
        let mut output = Vec::new();
        while let Some(msg) = rx.recv().await {
            output.push(msg);
        }
        assert_eq!(0, result);
        assert_eq!(
            &[
                SteamingOutput::Stdout(b"test!!!\n".to_vec()),
                SteamingOutput::ExitStatus(0),
            ],
            output.as_slice(),
        );
    }

    #[tokio::test]
    async fn execute_command_result_stderr() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo test!!! 1>&2").await.unwrap();
        assert_eq!("", output.stdout);
        assert_eq!("test!!!\n", output.stderr);
        assert_eq!(0, output.exit_status);
    }

    #[tokio::test]
    async fn execute_streaming_command_result_stderr() {
        let client = establish_test_host_connection().await;
        let (tx, mut rx) = tokio::sync::mpsc::channel(10);
        let result = client
            .execute_streaming("echo test!!! 1>&2", tx)
            .await
            .unwrap();
        let mut output = Vec::new();
        while let Some(msg) = rx.recv().await {
            output.push(msg);
        }
        assert_eq!(0, result);
        assert_eq!(
            &[
                SteamingOutput::Stderr(b"test!!!\n".to_vec()),
                SteamingOutput::ExitStatus(0),
            ],
            output.as_slice()
        );
    }

    #[tokio::test]
    async fn unicode_output() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo To thḙ moon! 🚀").await.unwrap();
        assert_eq!("To thḙ moon! 🚀\n", output.stdout);
        assert_eq!(0, output.exit_status);
    }

    #[tokio::test]
    async fn execute_command_status() {
        let client = establish_test_host_connection().await;
        let output = client.execute("exit 42").await.unwrap();
        assert_eq!(42, output.exit_status);
    }

    #[tokio::test]
    async fn execute_streaming_command_status() {
        let client = establish_test_host_connection().await;
        let (tx, mut rx) = tokio::sync::mpsc::channel(10);
        let result = client.execute_streaming("exit 42", tx).await.unwrap();
        let mut output = Vec::new();
        while let Some(msg) = rx.recv().await {
            output.push(msg);
        }
        assert_eq!(42, result);
        assert_eq!(&[SteamingOutput::ExitStatus(42),], output.as_slice());
    }

    #[tokio::test]
    async fn execute_io_command() {
        let client = establish_test_host_connection().await;
        let (stdout_tx, mut stdout_rx) = tokio::sync::mpsc::channel(10);
        let (stderr_tx, mut stderr_rx) = tokio::sync::mpsc::channel(10);
        let cmd = "echo out1; echo err1 1>&2; echo out2; echo err2 1>&2; exit 7";
        let exec_future = client.execute_io(cmd, stdout_tx, Some(stderr_tx), None, false, None);
        tokio::pin!(exec_future);
        let mut stdout_output = vec![];
        let mut stderr_output = vec![];
        let result = loop {
            tokio::select! {
                result_inner = &mut exec_future => break result_inner.unwrap(),
                Some(stdout) = stdout_rx.recv() => stdout_output.push(stdout),
                Some(stderr) = stderr_rx.recv() => stderr_output.push(stderr),
            };
        };
        // The command future can resolve while chunks are still buffered in the
        // channels; collect them so the assertions below see the complete output.
        while let Ok(stdout) = stdout_rx.try_recv() {
            stdout_output.push(stdout);
        }
        while let Ok(stderr) = stderr_rx.try_recv() {
            stderr_output.push(stderr);
        }
        assert_eq!(7, result);
        assert_eq!(
            vec![b"out1\n".to_vec(), b"out2\n".to_vec()].concat(),
            stdout_output.concat()
        );
        assert_eq!(
            vec![b"err1\n".to_vec(), b"err2\n".to_vec()].concat(),
            stderr_output.concat()
        );
    }

    #[tokio::test]
    async fn execute_multiple_commands() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo test!!!").await.unwrap().stdout;
        assert_eq!("test!!!\n", output);

        let output = client.execute("echo Hello World").await.unwrap().stdout;
        assert_eq!("Hello World\n", output);
    }

    #[tokio::test]
    async fn direct_tcpip_channel() {
        let client = establish_test_host_connection().await;
        let channel = client
            .open_direct_tcpip_channel(
                format!(
                    "{}:{}",
                    env("ASYNC_SSH2_TEST_HTTP_SERVER_IP"),
                    env("ASYNC_SSH2_TEST_HTTP_SERVER_PORT"),
                ),
                None,
            )
            .await
            .unwrap();

        let mut stream = channel.into_stream();
        stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await.unwrap();

        let mut response = String::new();
        stream.read_to_string(&mut response).await.unwrap();

        let body = response.split_once("\r\n\r\n").unwrap().1;
        assert_eq!("Hello", body);
    }

    #[tokio::test]
    async fn stderr_redirection() {
        let client = establish_test_host_connection().await;

        let output = client.execute("echo foo >/dev/null").await.unwrap();
        assert_eq!("", output.stdout);

        let output = client.execute("echo foo >>/dev/stderr").await.unwrap();
        assert_eq!("", output.stdout);

        let output = client.execute("2>&1 echo foo >>/dev/stderr").await.unwrap();
        assert_eq!("foo\n", output.stdout);
    }

    #[tokio::test]
    async fn sequential_commands() {
        let client = establish_test_host_connection().await;

        for i in 0..100 {
            // NOTE(review): this is a blocking sleep inside an async test; consider
            // `tokio::time::sleep` if tokio's `time` feature is available to the tests.
            std::thread::sleep(time::Duration::from_millis(100));
            let res = client
                .execute(&format!("echo {i}"))
                .await
                .unwrap_or_else(|_| panic!("Execution failed in iteration {i}"));
            assert_eq!(format!("{i}\n"), res.stdout);
        }
    }

    #[tokio::test]
    async fn execute_multiple_context() {
        // This is maybe not expected behaviour, thus documenting this via a test is important.
        let client = establish_test_host_connection().await;
        let output = client
            .execute("export VARIABLE=42; echo $VARIABLE")
            .await
            .unwrap()
            .stdout;
        assert_eq!("42\n", output);

        let output = client.execute("echo $VARIABLE").await.unwrap().stdout;
        assert_eq!("\n", output);
    }

    #[tokio::test]
    async fn connect_second_address() {
        let client = Client::connect(
            &[SocketAddr::from(([127, 0, 0, 1], 23)), test_address()][..],
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Resolution to second address failed");

        assert_eq!(test_address(), *client.get_connection_address(),);
    }

    #[tokio::test]
    async fn connect_with_wrong_password() {
        let error = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password("hopefully the wrong password"),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect_err("Client connected with wrong password");

        assert!(
            matches!(error, crate::Error::PasswordWrong),
            "Wrong error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn invalid_address() {
        let no_client = Client::connect(
            "this is definitely not an address",
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password("hopefully the wrong password"),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(no_client.is_err());
    }

    #[tokio::test]
    async fn connect_to_wrong_port() {
        let no_client = Client::connect(
            (env("ASYNC_SSH2_TEST_HOST_IP"), 23),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(no_client.is_err());
    }

    #[tokio::test]
    #[ignore = "This times out only after 20 seconds"]
    async fn connect_to_wrong_host() {
        let no_client = Client::connect(
            "172.16.0.6:22",
            "xxx",
            AuthMethod::with_password("xxx"),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(no_client.is_err());
    }

    #[tokio::test]
    async fn auth_key_file() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key_file(env("ASYNC_SSH2_TEST_CLIENT_PRIV"), None),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    #[cfg(not(target_os = "windows"))]
    async fn auth_with_agent() {
        // This test requires SSH agent to be running with the test key loaded
        // In Docker environment, the agent is always properly configured
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_agent(),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Agent authentication should succeed with correct key loaded");

        // Verify we can execute a command
        let output = client.execute("echo test").await.unwrap();
        assert_eq!("test\n", output.stdout);
    }

    #[tokio::test]
    #[cfg(not(target_os = "windows"))]
    async fn auth_with_agent_wrong_user() {
        // This test verifies that agent auth fails with wrong username
        let result = Client::connect(
            test_address(),
            "wrong_user_that_does_not_exist",
            AuthMethod::with_agent(),
            ServerCheckMethod::NoCheck,
        )
        .await;

        // Should fail with authentication error
        assert!(matches!(
            result,
            Err(crate::Error::AgentAuthenticationFailed)
        ));
    }

    #[tokio::test]
    #[cfg(not(target_os = "windows"))]
    async fn auth_with_agent_no_sock() {
        // Test behavior when SSH_AUTH_SOCK is not set
        // Temporarily unset SSH_AUTH_SOCK for this test
        let original_sock = std::env::var("SSH_AUTH_SOCK").ok();
        // SAFETY: NOTE(review) mutating the process environment is only sound while no
        // other thread reads or writes it. Other tests in this module read environment
        // variables and may run in parallel, so run this test single-threaded
        // (`--test-threads=1`) to uphold that.
        unsafe {
            std::env::remove_var("SSH_AUTH_SOCK");
        }

        let result = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_agent(),
            ServerCheckMethod::NoCheck,
        )
        .await;

        // Restore original SSH_AUTH_SOCK if it was set
        if let Some(sock) = original_sock {
            // SAFETY: same single-threaded requirement as for the removal above.
            unsafe {
                std::env::set_var("SSH_AUTH_SOCK", sock);
            }
        }

        // Should fail with connection error
        assert!(matches!(result, Err(crate::Error::AgentConnectionFailed)));
    }

    #[tokio::test]
    async fn auth_key_file_with_passphrase() {
        // `expect` reports the underlying error when authentication fails.
        Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key_file(
                env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV"),
                Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS")),
            ),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Authentication with passphrase-protected key file failed");
    }

    #[tokio::test]
    async fn auth_key_str() {
        let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PRIV")).unwrap();

        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key(key.as_str(), None),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn auth_key_str_with_passphrase() {
        let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV")).unwrap();

        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key(key.as_str(), Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS"))),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn auth_keyboard_interactive() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthKeyboardInteractive::new()
                .with_response("Password", env("ASYNC_SSH2_TEST_HOST_PW"))
                .into(),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn auth_keyboard_interactive_exact() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthKeyboardInteractive::new()
                .with_response_exact("Password: ", env("ASYNC_SSH2_TEST_HOST_PW"))
                .into(),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn auth_keyboard_interactive_wrong_response() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthKeyboardInteractive::new()
                .with_response_exact("Password: ", "wrong password")
                .into(),
            ServerCheckMethod::NoCheck,
        )
        .await;
        match client {
            Err(crate::error::Error::KeyboardInteractiveAuthFailed) => {}
            Err(e) => {
                panic!("Expected KeyboardInteractiveAuthFailed error. Got error: {e:?}")
            }
            Ok(_) => panic!("Expected KeyboardInteractiveAuthFailed error."),
        }
    }

    #[tokio::test]
    async fn auth_keyboard_interactive_no_response() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthKeyboardInteractive::new()
                .with_response_exact("Password:", "123")
                .into(),
            ServerCheckMethod::NoCheck,
        )
        .await;
        match client {
            Err(crate::error::Error::KeyboardInteractiveNoResponseForPrompt(prompt)) => {
                assert_eq!(prompt, "Password: ");
            }
            Err(e) => {
                panic!("Expected KeyboardInteractiveNoResponseForPrompt error. Got error: {e:?}")
            }
            Ok(_) => panic!("Expected KeyboardInteractiveNoResponseForPrompt error."),
        }
    }

    #[tokio::test]
    async fn server_check_file() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_public_key_file(&env("ASYNC_SSH2_TEST_SERVER_PUB")),
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn server_check_str() {
        let line = std::fs::read_to_string(env("ASYNC_SSH2_TEST_SERVER_PUB")).unwrap();
        let mut split = line.split_whitespace();
        // Accept both "<type> <key> [comment]" and a bare "<key>".
        let key = match (split.next(), split.next()) {
            (Some(_), Some(k)) => k,
            (Some(k), None) => k,
            _ => panic!("Failed to parse pub key file"),
        };

        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_public_key(key),
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn server_check_by_known_hosts_for_ip() {
        let client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_known_hosts_file(&env("ASYNC_SSH2_TEST_KNOWN_HOSTS")),
        )
        .await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn server_check_by_known_hosts_for_hostname() {
        let client = Client::connect(
            test_hostname(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_known_hosts_file(&env("ASYNC_SSH2_TEST_KNOWN_HOSTS")),
        )
        .await;
        if is_running_in_docker() {
            assert!(client.is_ok());
        } else {
            // Outside the Docker network, DNS cannot resolve the Docker host name.
            assert!(client.is_err());
        }
    }

    #[tokio::test]
    async fn client_can_be_cloned() {
        let client = establish_test_host_connection().await;
        let client2 = client.clone();

        let result1 = client.execute("echo test clone").await.unwrap();
        let result2 = client2.execute("echo test clone2").await.unwrap();

        assert_eq!(result1.stdout, "test clone\n");
        assert_eq!(result2.stdout, "test clone2\n");
    }

    #[tokio::test]
    async fn client_can_upload_file() {
        let client = establish_test_host_connection().await;
        client
            .upload_file(
                &env("ASYNC_SSH2_TEST_UPLOAD_FILE"),
                "/tmp/uploaded",
                None,
                None,
                false,
            )
            .await
            .unwrap();
        let result = client.execute("cat /tmp/uploaded").await.unwrap();
        assert_eq!(result.stdout, "this is a test file\n");
    }

    #[tokio::test]
    async fn client_can_download_file() {
        let client = establish_test_host_connection().await;

        client
            .execute("echo 'this is a downloaded test file' > /tmp/test_download")
            .await
            .unwrap();

        let local_path = std::env::temp_dir().join("downloaded_test_file");
        client
            .download_file("/tmp/test_download", &local_path)
            .await
            .unwrap();

        let contents = tokio::fs::read_to_string(&local_path).await.unwrap();
        assert_eq!(contents, "this is a downloaded test file\n");
    }
}