1use russh::client::KeyboardInteractiveAuthResponse;
2use russh::{
3 Channel,
4 client::{Config, Handle, Handler, Msg},
5};
6use russh_sftp::{client::SftpSession, protocol::OpenFlags};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Instant;
10use std::{fmt::Debug, path::Path};
11use std::{io, path::PathBuf};
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::sync::mpsc;
14
15use crate::ToSocketAddrsWithHostname;
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22#[non_exhaustive]
23pub enum AuthMethod {
24 Password(String),
25 PrivateKey {
26 key_data: String,
28 key_pass: Option<String>,
29 },
30 PrivateKeyFile {
31 key_file_path: PathBuf,
32 key_pass: Option<String>,
33 },
34 #[cfg(not(target_os = "windows"))]
35 PublicKeyFile {
36 key_file_path: PathBuf,
37 },
38 #[cfg(not(target_os = "windows"))]
39 Agent,
40 KeyboardInteractive(AuthKeyboardInteractive),
41}
42
/// A single event forwarded by [`Client::execute_streaming`].
///
/// The type name is misspelled; the deprecation note on `execute_streaming`
/// says it will be renamed to `StreamingOutput`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SteamingOutput {
    /// A chunk of bytes the command wrote to stdout.
    Stdout(Vec<u8>),
    /// A chunk of bytes the command wrote to stderr.
    Stderr(Vec<u8>),
    /// The command's exit status; emitted after all output has been forwarded.
    ExitStatus(u32),
}
49
/// A canned answer for one keyboard-interactive prompt.
///
/// `Debug` is implemented by hand so that `response` (which may be a password
/// or one-time code) is never printed, including through the derived `Debug`
/// of [`AuthKeyboardInteractive`].
#[derive(Clone, PartialEq, Eq, Hash)]
struct PromptResponse {
    /// When `true`, `prompt` must equal the server's prompt exactly;
    /// otherwise it only has to occur somewhere inside it.
    exact: bool,
    /// Text used to recognise the server's prompt.
    prompt: String,
    /// Answer sent back when the prompt matches.
    response: String,
}

impl Debug for PromptResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PromptResponse")
            .field("exact", &self.exact)
            .field("prompt", &self.prompt)
            .field("response", &format_args!("<redacted>"))
            .finish()
    }
}
56
/// Configuration for keyboard-interactive authentication.
///
/// Build it with [`AuthKeyboardInteractive::new`] and the `with_*` methods.
/// During authentication, each prompt sent by the server is answered by the
/// first remaining response whose prompt matches it. A response is consumed
/// once it has been used.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub struct AuthKeyboardInteractive {
    /// Optional submethods hint passed to the server when the exchange starts.
    submethods: Option<String>,
    /// Candidate answers, removed from this list as prompts are matched.
    responses: Vec<PromptResponse>,
}
64
/// Policy for verifying the server's host key during the handshake.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ServerCheckMethod {
    /// Accept any host key without verification.
    ///
    /// NOTE(review): this gives no protection against a spoofed server; it is
    /// only appropriate for tests or fully trusted networks.
    NoCheck,
    /// Accept only this key, given as a base64 string and parsed with
    /// `russh::keys::parse_public_key_base64`.
    PublicKey(String),
    /// Accept only the public key stored in the file at this path.
    PublicKeyFile(String),
    /// Check the key against the default known_hosts location used by
    /// `russh::keys::check_known_hosts`.
    DefaultKnownHostsFile,
    /// Check the key against the known_hosts file at this path.
    KnownHostsFile(String),
}
75
76impl AuthMethod {
77 pub fn with_password(password: &str) -> Self {
79 Self::Password(password.to_string())
80 }
81
82 pub fn with_key(key: &str, passphrase: Option<&str>) -> Self {
83 Self::PrivateKey {
84 key_data: key.to_string(),
85 key_pass: passphrase.map(str::to_string),
86 }
87 }
88
89 pub fn with_key_file<T: AsRef<Path>>(key_file_path: T, passphrase: Option<&str>) -> Self {
90 Self::PrivateKeyFile {
91 key_file_path: key_file_path.as_ref().to_path_buf(),
92 key_pass: passphrase.map(str::to_string),
93 }
94 }
95
96 #[cfg(not(target_os = "windows"))]
97 pub fn with_public_key_file<T: AsRef<Path>>(key_file_path: T) -> Self {
98 Self::PublicKeyFile {
99 key_file_path: key_file_path.as_ref().to_path_buf(),
100 }
101 }
102
103 #[cfg(not(target_os = "windows"))]
119 pub fn with_agent() -> Self {
120 Self::Agent
121 }
122
123 pub const fn with_keyboard_interactive(auth: AuthKeyboardInteractive) -> Self {
124 Self::KeyboardInteractive(auth)
125 }
126}
127
128impl AuthKeyboardInteractive {
129 pub fn new() -> Self {
130 Default::default()
131 }
132
133 pub fn with_submethods(mut self, submethods: impl Into<String>) -> Self {
135 self.submethods = Some(submethods.into());
136 self
137 }
138
139 pub fn with_response(mut self, prompt: impl Into<String>, response: impl Into<String>) -> Self {
143 self.responses.push(PromptResponse {
144 exact: false,
145 prompt: prompt.into(),
146 response: response.into(),
147 });
148
149 self
150 }
151
152 pub fn with_response_exact(
154 mut self,
155 prompt: impl Into<String>,
156 response: impl Into<String>,
157 ) -> Self {
158 self.responses.push(PromptResponse {
159 exact: true,
160 prompt: prompt.into(),
161 response: response.into(),
162 });
163
164 self
165 }
166}
167
168impl PromptResponse {
169 fn matches(&self, received_prompt: &str) -> bool {
170 if self.exact {
171 self.prompt.eq(received_prompt)
172 } else {
173 received_prompt.contains(&self.prompt)
174 }
175 }
176}
177
178impl From<AuthKeyboardInteractive> for AuthMethod {
179 fn from(value: AuthKeyboardInteractive) -> Self {
180 Self::with_keyboard_interactive(value)
181 }
182}
183
184impl ServerCheckMethod {
185 pub fn with_public_key(key: &str) -> Self {
187 Self::PublicKey(key.to_string())
188 }
189
190 pub fn with_public_key_file(key_file_name: &str) -> Self {
192 Self::PublicKeyFile(key_file_name.to_string())
193 }
194
195 pub fn with_known_hosts_file(known_hosts_file: &str) -> Self {
197 Self::KnownHostsFile(known_hosts_file.to_string())
198 }
199}
200
/// An authenticated SSH connection.
///
/// Create one with [`Client::connect`] or [`Client::connect_with_config`].
/// Cloning is cheap: clones share the same russh connection handle, which is
/// held in an `Arc`.
#[derive(Clone)]
pub struct Client {
    /// Shared russh handle used to open channels on this connection.
    connection_handle: Arc<Handle<ClientHandler>>,
    /// User name the connection was authenticated as.
    username: String,
    /// Socket address the connection was established to.
    address: SocketAddr,
}
234
235impl Client {
236 pub async fn connect(
248 addr: impl ToSocketAddrsWithHostname,
249 username: &str,
250 auth: AuthMethod,
251 server_check: ServerCheckMethod,
252 ) -> Result<Self, crate::Error> {
253 Self::connect_with_config(addr, username, auth, server_check, Config::default()).await
254 }
255
256 pub async fn connect_with_config(
259 addr: impl ToSocketAddrsWithHostname,
260 username: &str,
261 auth: AuthMethod,
262 server_check: ServerCheckMethod,
263 config: Config,
264 ) -> Result<Self, crate::Error> {
265 let config = Arc::new(config);
266
267 let socket_addrs = addr
269 .to_socket_addrs()
270 .map_err(crate::Error::AddressInvalid)?;
271 let mut connect_res = Err(crate::Error::AddressInvalid(io::Error::new(
272 io::ErrorKind::InvalidInput,
273 "could not resolve to any addresses",
274 )));
275 for socket_addr in socket_addrs {
276 let handler = ClientHandler {
277 hostname: addr.hostname(),
278 host: socket_addr,
279 server_check: server_check.clone(),
280 };
281 match russh::client::connect(config.clone(), socket_addr, handler).await {
282 Ok(h) => {
283 connect_res = Ok((socket_addr, h));
284 break;
285 }
286 Err(e) => connect_res = Err(e),
287 }
288 }
289 let (address, mut handle) = connect_res?;
290 let username = username.to_string();
291
292 Self::authenticate(&mut handle, &username, auth).await?;
293
294 Ok(Self {
295 connection_handle: Arc::new(handle),
296 username,
297 address,
298 })
299 }
300
301 async fn authenticate(
303 handle: &mut Handle<ClientHandler>,
304 username: &String,
305 auth: AuthMethod,
306 ) -> Result<(), crate::Error> {
307 match auth {
308 AuthMethod::Password(password) => {
309 let is_authentificated = handle.authenticate_password(username, password).await?;
310 if !is_authentificated.success() {
311 return Err(crate::Error::PasswordWrong);
312 }
313 }
314 AuthMethod::PrivateKey { key_data, key_pass } => {
315 let cprivk = russh::keys::decode_secret_key(key_data.as_str(), key_pass.as_deref())
316 .map_err(crate::Error::KeyInvalid)?;
317 let is_authentificated = handle
318 .authenticate_publickey(
319 username,
320 russh::keys::PrivateKeyWithHashAlg::new(
321 Arc::new(cprivk),
322 handle.best_supported_rsa_hash().await?.flatten(),
323 ),
324 )
325 .await?;
326 if !is_authentificated.success() {
327 return Err(crate::Error::KeyAuthFailed);
328 }
329 }
330 AuthMethod::PrivateKeyFile {
331 key_file_path,
332 key_pass,
333 } => {
334 let cprivk = russh::keys::load_secret_key(key_file_path, key_pass.as_deref())
335 .map_err(crate::Error::KeyInvalid)?;
336 let is_authentificated = handle
337 .authenticate_publickey(
338 username,
339 russh::keys::PrivateKeyWithHashAlg::new(
340 Arc::new(cprivk),
341 handle.best_supported_rsa_hash().await?.flatten(),
342 ),
343 )
344 .await?;
345 if !is_authentificated.success() {
346 return Err(crate::Error::KeyAuthFailed);
347 }
348 }
349 #[cfg(not(target_os = "windows"))]
350 AuthMethod::PublicKeyFile { key_file_path } => {
351 let cpubk = russh::keys::load_public_key(key_file_path)
352 .map_err(crate::Error::KeyInvalid)?;
353 let mut agent = russh::keys::agent::client::AgentClient::connect_env()
354 .await
355 .unwrap();
356 let mut auth_identity: Option<russh::keys::PublicKey> = None;
357 for identity in agent
358 .request_identities()
359 .await
360 .map_err(crate::Error::KeyInvalid)?
361 {
362 if identity == cpubk {
363 auth_identity = Some(identity.clone());
364 break;
365 }
366 }
367
368 if auth_identity.is_none() {
369 return Err(crate::Error::KeyAuthFailed);
370 }
371
372 let is_authentificated = handle
373 .authenticate_publickey_with(
374 username,
375 cpubk,
376 handle.best_supported_rsa_hash().await?.flatten(),
377 &mut agent,
378 )
379 .await?;
380 if !is_authentificated.success() {
381 return Err(crate::Error::KeyAuthFailed);
382 }
383 }
384 #[cfg(not(target_os = "windows"))]
385 AuthMethod::Agent => {
386 let mut agent = russh::keys::agent::client::AgentClient::connect_env()
387 .await
388 .map_err(|_| crate::Error::AgentConnectionFailed)?;
389
390 let identities = agent
391 .request_identities()
392 .await
393 .map_err(|_| crate::Error::AgentRequestIdentitiesFailed)?;
394
395 if identities.is_empty() {
396 return Err(crate::Error::AgentNoIdentities);
397 }
398
399 let mut auth_success = false;
400 for identity in identities {
401 let result = handle
402 .authenticate_publickey_with(
403 username,
404 identity.clone(),
405 handle.best_supported_rsa_hash().await?.flatten(),
406 &mut agent,
407 )
408 .await;
409
410 if let Ok(auth_result) = result
411 && auth_result.success()
412 {
413 auth_success = true;
414 break;
415 }
416 }
417
418 if !auth_success {
419 return Err(crate::Error::AgentAuthenticationFailed);
420 }
421 }
422 AuthMethod::KeyboardInteractive(mut kbd) => {
423 let mut res = handle
424 .authenticate_keyboard_interactive_start(username, kbd.submethods)
425 .await?;
426 loop {
427 let prompts = match res {
428 KeyboardInteractiveAuthResponse::Success => break,
429 KeyboardInteractiveAuthResponse::Failure { .. } => {
430 return Err(crate::Error::KeyboardInteractiveAuthFailed);
431 }
432 KeyboardInteractiveAuthResponse::InfoRequest { prompts, .. } => prompts,
433 };
434
435 let mut responses = vec![];
436 for prompt in prompts {
437 let Some(pos) = kbd
438 .responses
439 .iter()
440 .position(|pr| pr.matches(&prompt.prompt))
441 else {
442 return Err(crate::Error::KeyboardInteractiveNoResponseForPrompt(
443 prompt.prompt,
444 ));
445 };
446 let pr = kbd.responses.remove(pos);
447 responses.push(pr.response);
448 }
449
450 res = handle
451 .authenticate_keyboard_interactive_respond(responses)
452 .await?;
453 }
454 }
455 };
456 Ok(())
457 }
458
459 pub async fn get_channel(&self) -> Result<Channel<Msg>, crate::Error> {
460 self.connection_handle
461 .channel_open_session()
462 .await
463 .map_err(crate::Error::SshError)
464 }
465
466 pub async fn open_direct_tcpip_channel<
470 T: ToSocketAddrsWithHostname,
471 S: Into<Option<SocketAddr>>,
472 >(
473 &self,
474 target: T,
475 src: S,
476 ) -> Result<Channel<Msg>, crate::Error> {
477 let targets = target
478 .to_socket_addrs()
479 .map_err(crate::Error::AddressInvalid)?;
480 let src = src
481 .into()
482 .map(|src| (src.ip().to_string(), src.port().into()))
483 .unwrap_or_else(|| ("127.0.0.1".to_string(), 22));
484
485 let mut connect_err = crate::Error::AddressInvalid(io::Error::new(
486 io::ErrorKind::InvalidInput,
487 "could not resolve to any addresses",
488 ));
489 for target in targets {
490 match self
491 .connection_handle
492 .channel_open_direct_tcpip(
493 target.ip().to_string(),
494 target.port().into(),
495 src.0.clone(),
496 src.1,
497 )
498 .await
499 {
500 Ok(channel) => return Ok(channel),
501 Err(err) => connect_err = crate::Error::SshError(err),
502 }
503 }
504
505 Err(connect_err)
506 }
507
508 pub async fn upload_file<T, U>(
520 &self,
521 src_file_path: T,
522 dest_file_path: U,
525 timeout_seconds: Option<u64>,
526 buffer_size_in_bytes: Option<usize>,
527 show_progress: bool,
528 ) -> Result<(), crate::Error>
529 where
530 T: AsRef<Path> + std::fmt::Display,
531 U: Into<String>,
532 {
533 let channel = self.get_channel().await?;
535 channel.request_subsystem(true, "sftp").await?;
536 let sftp = SftpSession::new_opts(channel.into_stream(), timeout_seconds).await?;
537
538 let file_size = tokio::fs::metadata(&src_file_path).await?.len();
539 let local_file = tokio::fs::File::open(&src_file_path)
541 .await
542 .map_err(crate::Error::IoError)?;
543 let mut local_file_buffered = tokio::io::BufReader::new(local_file);
544
545 let dest_file_path = dest_file_path.into();
546 let mut remote_file = sftp
547 .open_with_flags(
548 dest_file_path.clone(),
549 OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ,
550 )
551 .await?;
552
553 let buffer_size_in_bytes = buffer_size_in_bytes.unwrap_or(4096);
554 let mut buffer = vec![0; buffer_size_in_bytes];
555
556 let mut total_bytes_copied = 0;
557 let mut next_progress_marker = 5.0;
558
559 let start_time = Instant::now();
560 if show_progress {
561 log::info!(
562 "Starting file upload from {src_file_path} to {dest_file_path}, total bytes to be transferred: {}",
563 file_size
564 );
565 }
566 loop {
567 let n = local_file_buffered.read(&mut buffer).await?;
568 if n == 0 {
569 break;
570 }
571 remote_file
572 .write_all(&buffer[..n])
573 .await
574 .map_err(crate::Error::IoError)?;
575 if show_progress {
576 total_bytes_copied += n as u64;
577 let progress = (total_bytes_copied as f64 / file_size as f64) * 100.0;
578 if progress >= next_progress_marker {
579 log::info!(
580 "Progress of upload from {src_file_path} to {dest_file_path}: {:.0}% in elapsed time: {}s",
581 next_progress_marker,
582 start_time.elapsed().as_secs_f64()
583 );
584 next_progress_marker += 5.0;
585 }
586 }
587 }
588
589 if show_progress {
590 log::info!(
591 "file upload comprising {file_size} bytes from {src_file_path} to {dest_file_path} completed successfully in {}s",
592 start_time.elapsed().as_secs_f64()
593 );
594 }
595 remote_file
596 .shutdown()
597 .await
598 .map_err(crate::Error::IoError)?;
599
600 Ok(())
601 }
602
603 pub async fn download_file<T: AsRef<Path>, U: Into<String>>(
611 &self,
612 remote_file_path: U,
613 local_file_path: T,
614 ) -> Result<(), crate::Error> {
615 let channel = self.get_channel().await?;
617 channel.request_subsystem(true, "sftp").await?;
618 let sftp = SftpSession::new(channel.into_stream()).await?;
619
620 let mut remote_file = sftp
622 .open_with_flags(remote_file_path, OpenFlags::READ)
623 .await?;
624
625 let mut contents = Vec::new();
627 remote_file.read_to_end(contents.as_mut()).await?;
628
629 let mut local_file = tokio::fs::File::create(local_file_path.as_ref())
631 .await
632 .map_err(crate::Error::IoError)?;
633
634 local_file
635 .write_all(&contents)
636 .await
637 .map_err(crate::Error::IoError)?;
638 local_file.flush().await.map_err(crate::Error::IoError)?;
639
640 Ok(())
641 }
642
643 pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, crate::Error> {
656 let mut stdout_buffer = vec![];
657 let mut stderr_buffer = vec![];
658 let mut channel = self.connection_handle.channel_open_session().await?;
659 channel.exec(true, command).await?;
660
661 let mut result: Option<u32> = None;
662
663 while let Some(msg) = channel.wait().await {
665 match msg {
667 russh::ChannelMsg::Data { ref data } => {
669 stdout_buffer.write_all(data).await.unwrap()
670 }
671 russh::ChannelMsg::ExtendedData { ref data, ext } => {
672 if ext == 1 {
673 stderr_buffer.write_all(data).await.unwrap()
674 }
675 }
676
677 russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
681
682 _ => {}
687 }
688 }
689
690 if let Some(result) = result {
692 Ok(CommandExecutedResult {
693 stdout: String::from_utf8_lossy(&stdout_buffer).to_string(),
694 stderr: String::from_utf8_lossy(&stderr_buffer).to_string(),
695 exit_status: result,
696 })
697
698 } else {
700 Err(crate::Error::CommandDidntExit)
701 }
702 }
703
704 #[deprecated(
712 since = "0.11.0",
713 note = "Use execute_io with channels directly for more flexibility.\n\
714 This method will be removed or introduced breaking changes in future versions.\n\
715 At minimum, SteamingOutput will be renamed to StreamingOutput"
716 )]
717 pub async fn execute_streaming(
718 &self,
719 command: &str,
720 ch: tokio::sync::mpsc::Sender<SteamingOutput>,
721 ) -> Result<u32, crate::Error> {
722 let (stdout_tx, mut stdout_rx) = tokio::sync::mpsc::channel(1);
723 let (stderr_tx, mut stderr_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1);
724
725 let exec_future = self.execute_io(command, stdout_tx, Some(stderr_tx), None, false, None);
726 tokio::pin!(exec_future);
727 let result = loop {
728 tokio::select! {
729 result = &mut exec_future => break result,
730 Some(stdout) = stdout_rx.recv() => {
731 ch.send(SteamingOutput::Stdout(stdout)).await.unwrap();
732 },
733 Some(stderr) = stderr_rx.recv() => {
734 ch.send(SteamingOutput::Stderr(stderr)).await.unwrap();
735 },
736 };
737 }?;
738 if let Some(stdout) = stdout_rx.recv().await {
740 ch.send(SteamingOutput::Stdout(stdout)).await.unwrap();
741 }
742 if let Some(stderr) = stderr_rx.recv().await {
743 ch.send(SteamingOutput::Stderr(stderr)).await.unwrap();
744 }
745 ch.send(SteamingOutput::ExitStatus(result)).await.unwrap();
746 Ok(result)
747 }
748
749 pub async fn execute_io(
812 &self,
813 command: &str,
814 stdout_channel: mpsc::Sender<Vec<u8>>,
815 stderr_channel: Option<mpsc::Sender<Vec<u8>>>,
816 mut stdin_channel: Option<mpsc::Receiver<Vec<u8>>>,
817 request_pty: bool,
818 default_exit_code: Option<u32>,
819 ) -> Result<u32, crate::Error> {
820 let mut channel = self.connection_handle.channel_open_session().await?;
821
822 let mut result: Option<u32> = None;
823 if request_pty {
824 channel
825 .request_pty(false, "xterm", 80_u32, 24_u32, 0, 0, &[])
826 .await?;
827 }
828
829 channel.exec(true, command).await?;
830
831 loop {
833 let recv_stdin = async {
834 if let Some(ch) = stdin_channel.as_mut() {
835 Some(ch.recv().await)
836 } else {
837 None
838 }
839 };
840 tokio::select! {
841 Some(input) = recv_stdin => {
842 if let Some(input) = input {
843 if input.is_empty() {
844 channel.eof().await? ;
845 } else {
846 channel.data(&input as &[u8]).await?;
847 }
848 }
849 },
850 msg = channel.wait() => {
851 match msg {
853 Some(russh::ChannelMsg::Data { ref data }) => {
855 stdout_channel
857 .send(data.to_vec())
858 .await
859 .map_err(crate::Error::ChannelSendError)?;
860 }
861 Some (russh::ChannelMsg::ExtendedData { ref data, ext }) => {
862 if ext == 1 {
863 if let Some(stderr_channel) = &stderr_channel {
864 stderr_channel
866 .send(data.to_vec())
867 .await
868 .map_err(crate::Error::ChannelSendError)?;
869 } else {
870 stdout_channel
872 .send(data.to_vec())
873 .await
874 .map_err(crate::Error::ChannelSendError)?;
875 }
876 }
877 }
878
879 Some (russh::ChannelMsg::ExitStatus { exit_status }) => result = Some(exit_status),
883
884 Some (_) => {},
889 None => break,
890 }
891 }
892 }
893 }
894
895 if let Some(result) = result {
897 Ok(result)
898 } else if let Some(default_exit_code) = default_exit_code {
900 Ok(default_exit_code)
901 } else {
903 Err(crate::Error::CommandDidntExit)
904 }
905 }
906
907 pub fn get_connection_username(&self) -> &String {
909 &self.username
910 }
911
912 pub fn get_connection_address(&self) -> &SocketAddr {
914 &self.address
915 }
916
917 pub async fn disconnect(&self) -> Result<(), crate::Error> {
918 self.connection_handle
919 .disconnect(russh::Disconnect::ByApplication, "", "")
920 .await
921 .map_err(crate::Error::SshError)
922 }
923
924 pub fn is_closed(&self) -> bool {
925 self.connection_handle.is_closed()
926 }
927}
928
929impl Debug for Client {
930 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
931 f.debug_struct("Client")
932 .field("username", &self.username)
933 .field("address", &self.address)
934 .field("connection_handle", &"Handle<ClientHandler>")
935 .finish()
936 }
937}
938
/// Collected result of [`Client::execute`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CommandExecutedResult {
    /// All stdout data from the command, decoded lossily as UTF-8.
    pub stdout: String,
    /// All stderr data from the command, decoded lossily as UTF-8.
    pub stderr: String,
    /// Exit status reported by the remote command.
    pub exit_status: u32,
}
948
/// russh client event handler; decides whether to trust the server's host key
/// according to `server_check`.
#[derive(Debug, Clone)]
struct ClientHandler {
    /// Hostname the caller connected to; used for known_hosts lookups.
    hostname: String,
    /// Resolved address being connected to; its port is used for known_hosts lookups.
    host: SocketAddr,
    /// Host key verification policy.
    server_check: ServerCheckMethod,
}
955
956impl Handler for ClientHandler {
957 type Error = crate::Error;
958
959 async fn check_server_key(
960 &mut self,
961 server_public_key: &russh::keys::PublicKey,
962 ) -> Result<bool, Self::Error> {
963 match &self.server_check {
964 ServerCheckMethod::NoCheck => Ok(true),
965 ServerCheckMethod::PublicKey(key) => {
966 let pk = russh::keys::parse_public_key_base64(key)
967 .map_err(|_| crate::Error::ServerCheckFailed)?;
968
969 Ok(pk == *server_public_key)
970 }
971 ServerCheckMethod::PublicKeyFile(key_file_name) => {
972 let pk = russh::keys::load_public_key(key_file_name)
973 .map_err(|_| crate::Error::ServerCheckFailed)?;
974
975 Ok(pk == *server_public_key)
976 }
977 ServerCheckMethod::KnownHostsFile(known_hosts_path) => {
978 let result = russh::keys::check_known_hosts_path(
979 &self.hostname,
980 self.host.port(),
981 server_public_key,
982 known_hosts_path,
983 )
984 .map_err(|_| crate::Error::ServerCheckFailed)?;
985
986 Ok(result)
987 }
988 ServerCheckMethod::DefaultKnownHostsFile => {
989 let result = russh::keys::check_known_hosts(
990 &self.hostname,
991 self.host.port(),
992 server_public_key,
993 )
994 .map_err(|_| crate::Error::ServerCheckFailed)?;
995
996 Ok(result)
997 }
998 }
999 }
1000}
1001
1002#[cfg(test)]
1003mod tests {
1004 #![allow(deprecated, clippy::useless_vec)]
1005
1006 use crate::client::*;
1007 use core::time;
1008 use dotenv::dotenv;
1009 use std::path::Path;
1010 use std::sync::Once;
1011 use tokio::io::AsyncReadExt;
1012 static INIT: Once = Once::new();
1013
1014 fn initialize() {
1015 println!("Running initialization code before tests...");
1017 if is_running_in_docker() {
1019 println!("Running inside Docker.");
1020 } else {
1021 println!("Not running inside Docker. Load env from file");
1022 dotenv().ok();
1023 }
1024 }
1025 fn is_running_in_docker() -> bool {
1026 Path::new("/.dockerenv").exists() || check_cgroup()
1027 }
1028
1029 fn check_cgroup() -> bool {
1030 match std::fs::read_to_string("/proc/1/cgroup") {
1031 Ok(contents) => contents.contains("docker"),
1032 Err(_) => false,
1033 }
1034 }
1035
1036 fn env(name: &str) -> String {
1037 INIT.call_once(|| {
1038 initialize();
1039 });
1040 std::env::var(name).unwrap_or_else(|_| {
1041 panic!(
1042 "Failed to get env var needed for test, make sure to set the following env var: {name}",
1043 )
1044 })
1045 }
1046
1047 fn test_address() -> SocketAddr {
1048 format!(
1049 "{}:{}",
1050 env("ASYNC_SSH2_TEST_HOST_IP"),
1051 env("ASYNC_SSH2_TEST_HOST_PORT")
1052 )
1053 .parse()
1054 .unwrap()
1055 }
1056
1057 fn test_hostname() -> impl ToSocketAddrsWithHostname {
1058 (
1059 env("ASYNC_SSH2_TEST_HOST_NAME"),
1060 env("ASYNC_SSH2_TEST_HOST_PORT").parse().unwrap(),
1061 )
1062 }
1063
1064 async fn establish_test_host_connection() -> Client {
1065 Client::connect(
1066 (
1067 env("ASYNC_SSH2_TEST_HOST_IP"),
1068 env("ASYNC_SSH2_TEST_HOST_PORT").parse().unwrap(),
1069 ),
1070 &env("ASYNC_SSH2_TEST_HOST_USER"),
1071 AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1072 ServerCheckMethod::NoCheck,
1073 )
1074 .await
1075 .expect("Connection/Authentification failed")
1076 }
1077
1078 #[tokio::test]
1079 async fn connect_with_password() {
1080 let client = establish_test_host_connection().await;
1081 assert_eq!(
1082 &env("ASYNC_SSH2_TEST_HOST_USER"),
1083 client.get_connection_username(),
1084 );
1085 assert_eq!(test_address(), *client.get_connection_address(),);
1086 }
1087
1088 #[tokio::test]
1089 async fn execute_command_result() {
1090 let client = establish_test_host_connection().await;
1091 let output = client.execute("echo test!!!").await.unwrap();
1092 assert_eq!("test!!!\n", output.stdout);
1093 assert_eq!("", output.stderr);
1094 assert_eq!(0, output.exit_status);
1095 }
1096
1097 #[tokio::test]
1098 async fn execute_streaming_command_result() {
1099 let (tx, mut rx) = tokio::sync::mpsc::channel(10);
1100 let client = establish_test_host_connection().await;
1101 let result = client.execute_streaming("echo test!!!", tx).await.unwrap();
1102 let mut output = Vec::new();
1103 while let Some(msg) = rx.recv().await {
1104 output.push(msg);
1105 }
1106 assert_eq!(0, result);
1107 assert_eq!(
1108 &[
1109 SteamingOutput::Stdout(b"test!!!\n".to_vec()),
1110 SteamingOutput::ExitStatus(0),
1111 ],
1112 output.as_slice(),
1113 );
1114 }
1115
1116 #[tokio::test]
1117 async fn execute_command_result_stderr() {
1118 let client = establish_test_host_connection().await;
1119 let output = client.execute("echo test!!! 1>&2").await.unwrap();
1120 assert_eq!("", output.stdout);
1121 assert_eq!("test!!!\n", output.stderr);
1122 assert_eq!(0, output.exit_status);
1123 }
1124
1125 #[tokio::test]
1126 async fn execute_streaming_command_result_stderr() {
1127 let client = establish_test_host_connection().await;
1128 let (tx, mut rx) = tokio::sync::mpsc::channel(10);
1129 let result = client
1130 .execute_streaming("echo test!!! 1>&2", tx)
1131 .await
1132 .unwrap();
1133 let mut output = Vec::new();
1134 while let Some(msg) = rx.recv().await {
1135 output.push(msg);
1136 }
1137 assert_eq!(0, result);
1138 assert_eq!(
1139 &[
1140 SteamingOutput::Stderr(b"test!!!\n".to_vec()),
1141 SteamingOutput::ExitStatus(0),
1142 ],
1143 output.as_slice()
1144 );
1145 }
1146
1147 #[tokio::test]
1148 async fn unicode_output() {
1149 let client = establish_test_host_connection().await;
1150 let output = client.execute("echo To thḙ moon! 🚀").await.unwrap();
1151 assert_eq!("To thḙ moon! 🚀\n", output.stdout);
1152 assert_eq!(0, output.exit_status);
1153 }
1154
1155 #[tokio::test]
1156 async fn execute_command_status() {
1157 let client = establish_test_host_connection().await;
1158 let output = client.execute("exit 42").await.unwrap();
1159 assert_eq!(42, output.exit_status);
1160 }
1161
1162 #[tokio::test]
1163 async fn execute_streaming_command_status() {
1164 let client = establish_test_host_connection().await;
1165 let (tx, mut rx) = tokio::sync::mpsc::channel(10);
1166 let result = client.execute_streaming("exit 42", tx).await.unwrap();
1167 let mut output = Vec::new();
1168 while let Some(msg) = rx.recv().await {
1169 output.push(msg);
1170 }
1171 assert_eq!(42, result);
1172 assert_eq!(&[SteamingOutput::ExitStatus(42),], output.as_slice());
1173 }
1174
1175 #[tokio::test]
1176 async fn execute_io_command() {
1177 let client = establish_test_host_connection().await;
1178 let (stdout_tx, mut stdout_rx) = tokio::sync::mpsc::channel(10);
1179 let (stderr_tx, mut stderr_rx) = tokio::sync::mpsc::channel(10);
1180 let cmd = "echo out1; echo err1 1>&2; echo out2; echo err2 1>&2; exit 7";
1181 let exec_future = client.execute_io(cmd, stdout_tx, Some(stderr_tx), None, false, None);
1182 tokio::pin!(exec_future);
1183 let mut result: Option<u32> = None;
1184 let mut stdout_output = vec![];
1185 let mut stderr_output = vec![];
1186 loop {
1187 tokio::select! {
1188 result_inner = &mut exec_future => {
1189 result = Some(result_inner.unwrap());
1190 },
1191 Some(stdout) = stdout_rx.recv() => {
1192 stdout_output.push(stdout);
1193 },
1194 Some(stderr) = stderr_rx.recv() => {
1195 stderr_output.push(stderr);
1196 },
1197 };
1198 if result.is_some() {
1199 break;
1200 }
1201 }
1202 assert_eq!(Some(7), result);
1203 assert_eq!(
1204 vec![b"out1\n".to_vec(), b"out2\n".to_vec()].concat(),
1205 stdout_output.concat()
1206 );
1207 assert_eq!(
1208 vec![b"err1\n".to_vec(), b"err2\n".to_vec()].concat(),
1209 stderr_output.concat()
1210 );
1211 }
1212
1213 #[tokio::test]
1214 async fn execute_multiple_commands() {
1215 let client = establish_test_host_connection().await;
1216 let output = client.execute("echo test!!!").await.unwrap().stdout;
1217 assert_eq!("test!!!\n", output);
1218
1219 let output = client.execute("echo Hello World").await.unwrap().stdout;
1220 assert_eq!("Hello World\n", output);
1221 }
1222
1223 #[tokio::test]
1224 async fn direct_tcpip_channel() {
1225 let client = establish_test_host_connection().await;
1226 let channel = client
1227 .open_direct_tcpip_channel(
1228 format!(
1229 "{}:{}",
1230 env("ASYNC_SSH2_TEST_HTTP_SERVER_IP"),
1231 env("ASYNC_SSH2_TEST_HTTP_SERVER_PORT"),
1232 ),
1233 None,
1234 )
1235 .await
1236 .unwrap();
1237
1238 let mut stream = channel.into_stream();
1239 stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await.unwrap();
1240
1241 let mut response = String::new();
1242 stream.read_to_string(&mut response).await.unwrap();
1243
1244 let body = response.split_once("\r\n\r\n").unwrap().1;
1245 assert_eq!("Hello", body);
1246 }
1247
1248 #[tokio::test]
1249 async fn stderr_redirection() {
1250 let client = establish_test_host_connection().await;
1251
1252 let output = client.execute("echo foo >/dev/null").await.unwrap();
1253 assert_eq!("", output.stdout);
1254
1255 let output = client.execute("echo foo >>/dev/stderr").await.unwrap();
1256 assert_eq!("", output.stdout);
1257
1258 let output = client.execute("2>&1 echo foo >>/dev/stderr").await.unwrap();
1259 assert_eq!("foo\n", output.stdout);
1260 }
1261
1262 #[tokio::test]
1263 async fn sequential_commands() {
1264 let client = establish_test_host_connection().await;
1265
1266 for i in 0..100 {
1267 std::thread::sleep(time::Duration::from_millis(100));
1268 let res = client
1269 .execute(&format!("echo {i}"))
1270 .await
1271 .unwrap_or_else(|_| panic!("Execution failed in iteration {i}"));
1272 assert_eq!(format!("{i}\n"), res.stdout);
1273 }
1274 }
1275
1276 #[tokio::test]
1277 async fn execute_multiple_context() {
1278 let client = establish_test_host_connection().await;
1280 let output = client
1281 .execute("export VARIABLE=42; echo $VARIABLE")
1282 .await
1283 .unwrap()
1284 .stdout;
1285 assert_eq!("42\n", output);
1286
1287 let output = client.execute("echo $VARIABLE").await.unwrap().stdout;
1288 assert_eq!("\n", output);
1289 }
1290
1291 #[tokio::test]
1292 async fn connect_second_address() {
1293 let client = Client::connect(
1294 &[SocketAddr::from(([127, 0, 0, 1], 23)), test_address()][..],
1295 &env("ASYNC_SSH2_TEST_HOST_USER"),
1296 AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1297 ServerCheckMethod::NoCheck,
1298 )
1299 .await
1300 .expect("Resolution to second address failed");
1301
1302 assert_eq!(test_address(), *client.get_connection_address(),);
1303 }
1304
1305 #[tokio::test]
1306 async fn connect_with_wrong_password() {
1307 let error = Client::connect(
1308 test_address(),
1309 &env("ASYNC_SSH2_TEST_HOST_USER"),
1310 AuthMethod::with_password("hopefully the wrong password"),
1311 ServerCheckMethod::NoCheck,
1312 )
1313 .await
1314 .expect_err("Client connected with wrong password");
1315
1316 match error {
1317 crate::Error::PasswordWrong => {}
1318 _ => panic!("Wrong error type"),
1319 }
1320 }
1321
1322 #[tokio::test]
1323 async fn invalid_address() {
1324 let no_client = Client::connect(
1325 "this is definitely not an address",
1326 &env("ASYNC_SSH2_TEST_HOST_USER"),
1327 AuthMethod::with_password("hopefully the wrong password"),
1328 ServerCheckMethod::NoCheck,
1329 )
1330 .await;
1331 assert!(no_client.is_err());
1332 }
1333
1334 #[tokio::test]
1335 async fn connect_to_wrong_port() {
1336 let no_client = Client::connect(
1337 (env("ASYNC_SSH2_TEST_HOST_IP"), 23),
1338 &env("ASYNC_SSH2_TEST_HOST_USER"),
1339 AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1340 ServerCheckMethod::NoCheck,
1341 )
1342 .await;
1343 assert!(no_client.is_err());
1344 }
1345
1346 #[tokio::test]
1347 #[ignore = "This times out only after 20 seconds"]
1348 async fn connect_to_wrong_host() {
1349 let no_client = Client::connect(
1350 "172.16.0.6:22",
1351 "xxx",
1352 AuthMethod::with_password("xxx"),
1353 ServerCheckMethod::NoCheck,
1354 )
1355 .await;
1356 assert!(no_client.is_err());
1357 }
1358
1359 #[tokio::test]
1360 async fn auth_key_file() {
1361 let client = Client::connect(
1362 test_address(),
1363 &env("ASYNC_SSH2_TEST_HOST_USER"),
1364 AuthMethod::with_key_file(env("ASYNC_SSH2_TEST_CLIENT_PRIV"), None),
1365 ServerCheckMethod::NoCheck,
1366 )
1367 .await;
1368 assert!(client.is_ok());
1369 }
1370
1371 #[tokio::test]
1372 #[cfg(not(target_os = "windows"))]
1373 async fn auth_with_agent() {
1374 let client = Client::connect(
1377 test_address(),
1378 &env("ASYNC_SSH2_TEST_HOST_USER"),
1379 AuthMethod::with_agent(),
1380 ServerCheckMethod::NoCheck,
1381 )
1382 .await
1383 .expect("Agent authentication should succeed with correct key loaded");
1384
1385 let output = client.execute("echo test").await.unwrap();
1387 assert_eq!("test\n", output.stdout);
1388 }
1389
1390 #[tokio::test]
1391 #[cfg(not(target_os = "windows"))]
1392 async fn auth_with_agent_wrong_user() {
1393 let result = Client::connect(
1395 test_address(),
1396 "wrong_user_that_does_not_exist",
1397 AuthMethod::with_agent(),
1398 ServerCheckMethod::NoCheck,
1399 )
1400 .await;
1401
1402 assert!(matches!(
1404 result,
1405 Err(crate::Error::AgentAuthenticationFailed)
1406 ));
1407 }
1408
1409 #[tokio::test]
1410 #[cfg(not(target_os = "windows"))]
1411 async fn auth_with_agent_no_sock() {
1412 let original_sock = std::env::var("SSH_AUTH_SOCK").ok();
1415 unsafe {
1416 std::env::remove_var("SSH_AUTH_SOCK");
1417 }
1418
1419 let result = Client::connect(
1420 test_address(),
1421 &env("ASYNC_SSH2_TEST_HOST_USER"),
1422 AuthMethod::with_agent(),
1423 ServerCheckMethod::NoCheck,
1424 )
1425 .await;
1426
1427 if let Some(sock) = original_sock {
1429 unsafe {
1430 std::env::set_var("SSH_AUTH_SOCK", sock);
1431 }
1432 }
1433
1434 assert!(matches!(result, Err(crate::Error::AgentConnectionFailed)));
1436 }
1437
1438 #[tokio::test]
1439 async fn auth_key_file_with_passphrase() {
1440 let client = Client::connect(
1441 test_address(),
1442 &env("ASYNC_SSH2_TEST_HOST_USER"),
1443 AuthMethod::with_key_file(
1444 env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV"),
1445 Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS")),
1446 ),
1447 ServerCheckMethod::NoCheck,
1448 )
1449 .await;
1450 if client.is_err() {
1451 println!("{:?}", client.err());
1452 panic!();
1453 }
1454 assert!(client.is_ok());
1455 }
1456
1457 #[tokio::test]
1458 async fn auth_key_str() {
1459 let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PRIV")).unwrap();
1460
1461 let client = Client::connect(
1462 test_address(),
1463 &env("ASYNC_SSH2_TEST_HOST_USER"),
1464 AuthMethod::with_key(key.as_str(), None),
1465 ServerCheckMethod::NoCheck,
1466 )
1467 .await;
1468 assert!(client.is_ok());
1469 }
1470
1471 #[tokio::test]
1472 async fn auth_key_str_with_passphrase() {
1473 let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV")).unwrap();
1474
1475 let client = Client::connect(
1476 test_address(),
1477 &env("ASYNC_SSH2_TEST_HOST_USER"),
1478 AuthMethod::with_key(key.as_str(), Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS"))),
1479 ServerCheckMethod::NoCheck,
1480 )
1481 .await;
1482 assert!(client.is_ok());
1483 }
1484
1485 #[tokio::test]
1486 async fn auth_keyboard_interactive() {
1487 let client = Client::connect(
1488 test_address(),
1489 &env("ASYNC_SSH2_TEST_HOST_USER"),
1490 AuthKeyboardInteractive::new()
1491 .with_response("Password", env("ASYNC_SSH2_TEST_HOST_PW"))
1492 .into(),
1493 ServerCheckMethod::NoCheck,
1494 )
1495 .await;
1496 assert!(client.is_ok());
1497 }
1498
1499 #[tokio::test]
1500 async fn auth_keyboard_interactive_exact() {
1501 let client = Client::connect(
1502 test_address(),
1503 &env("ASYNC_SSH2_TEST_HOST_USER"),
1504 AuthKeyboardInteractive::new()
1505 .with_response_exact("Password: ", env("ASYNC_SSH2_TEST_HOST_PW"))
1506 .into(),
1507 ServerCheckMethod::NoCheck,
1508 )
1509 .await;
1510 assert!(client.is_ok());
1511 }
1512
1513 #[tokio::test]
1514 async fn auth_keyboard_interactive_wrong_response() {
1515 let client = Client::connect(
1516 test_address(),
1517 &env("ASYNC_SSH2_TEST_HOST_USER"),
1518 AuthKeyboardInteractive::new()
1519 .with_response_exact("Password: ", "wrong password")
1520 .into(),
1521 ServerCheckMethod::NoCheck,
1522 )
1523 .await;
1524 match client {
1525 Err(crate::error::Error::KeyboardInteractiveAuthFailed) => {}
1526 Err(e) => {
1527 panic!("Expected KeyboardInteractiveAuthFailed error. Got error: {e:?}")
1528 }
1529 Ok(_) => panic!("Expected KeyboardInteractiveAuthFailed error."),
1530 }
1531 }
1532
1533 #[tokio::test]
1534 async fn auth_keyboard_interactive_no_response() {
1535 let client = Client::connect(
1536 test_address(),
1537 &env("ASYNC_SSH2_TEST_HOST_USER"),
1538 AuthKeyboardInteractive::new()
1539 .with_response_exact("Password:", "123")
1540 .into(),
1541 ServerCheckMethod::NoCheck,
1542 )
1543 .await;
1544 match client {
1545 Err(crate::error::Error::KeyboardInteractiveNoResponseForPrompt(prompt)) => {
1546 assert_eq!(prompt, "Password: ");
1547 }
1548 Err(e) => {
1549 panic!("Expected KeyboardInteractiveNoResponseForPrompt error. Got error: {e:?}")
1550 }
1551 Ok(_) => panic!("Expected KeyboardInteractiveNoResponseForPrompt error."),
1552 }
1553 }
1554
1555 #[tokio::test]
1556 async fn server_check_file() {
1557 let client = Client::connect(
1558 test_address(),
1559 &env("ASYNC_SSH2_TEST_HOST_USER"),
1560 AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1561 ServerCheckMethod::with_public_key_file(&env("ASYNC_SSH2_TEST_SERVER_PUB")),
1562 )
1563 .await;
1564 assert!(client.is_ok());
1565 }
1566
1567 #[tokio::test]
1568 async fn server_check_str() {
1569 let line = std::fs::read_to_string(env("ASYNC_SSH2_TEST_SERVER_PUB")).unwrap();
1570 let mut split = line.split_whitespace();
1571 let key = match (split.next(), split.next()) {
1572 (Some(_), Some(k)) => k,
1573 (Some(k), None) => k,
1574 _ => panic!("Failed to parse pub key file"),
1575 };
1576
1577 let client = Client::connect(
1578 test_address(),
1579 &env("ASYNC_SSH2_TEST_HOST_USER"),
1580 AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1581 ServerCheckMethod::with_public_key(key),
1582 )
1583 .await;
1584 assert!(client.is_ok());
1585 }
1586
1587 #[tokio::test]
1588 async fn server_check_by_known_hosts_for_ip() {
1589 let client = Client::connect(
1590 test_address(),
1591 &env("ASYNC_SSH2_TEST_HOST_USER"),
1592 AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1593 ServerCheckMethod::with_known_hosts_file(&env("ASYNC_SSH2_TEST_KNOWN_HOSTS")),
1594 )
1595 .await;
1596 assert!(client.is_ok());
1597 }
1598
1599 #[tokio::test]
1600 async fn server_check_by_known_hosts_for_hostname() {
1601 let client = Client::connect(
1602 test_hostname(),
1603 &env("ASYNC_SSH2_TEST_HOST_USER"),
1604 AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
1605 ServerCheckMethod::with_known_hosts_file(&env("ASYNC_SSH2_TEST_KNOWN_HOSTS")),
1606 )
1607 .await;
1608 if is_running_in_docker() {
1609 assert!(client.is_ok());
1610 } else {
1611 assert!(client.is_err()); }
1613 }
1614
1615 #[tokio::test]
1616 async fn client_can_be_cloned() {
1617 let client = establish_test_host_connection().await;
1618 let client2 = client.clone();
1619
1620 let result1 = client.execute("echo test clone").await.unwrap();
1621 let result2 = client2.execute("echo test clone2").await.unwrap();
1622
1623 assert_eq!(result1.stdout, "test clone\n");
1624 assert_eq!(result2.stdout, "test clone2\n");
1625 }
1626
1627 #[tokio::test]
1628 async fn client_can_upload_file() {
1629 let client = establish_test_host_connection().await;
1630 client
1631 .upload_file(
1632 &env("ASYNC_SSH2_TEST_UPLOAD_FILE"),
1633 "/tmp/uploaded",
1634 None,
1635 None,
1636 false,
1637 )
1638 .await
1639 .unwrap();
1640 let result = client.execute("cat /tmp/uploaded").await.unwrap();
1641 assert_eq!(result.stdout, "this is a test file\n");
1642 }
1643
1644 #[tokio::test]
1645 async fn client_can_download_file() {
1646 let client = establish_test_host_connection().await;
1647
1648 client
1649 .execute("echo 'this is a downloaded test file' > /tmp/test_download")
1650 .await
1651 .unwrap();
1652
1653 let local_path = std::env::temp_dir().join("downloaded_test_file");
1654 client
1655 .download_file("/tmp/test_download", &local_path)
1656 .await
1657 .unwrap();
1658
1659 let contents = tokio::fs::read_to_string(&local_path).await.unwrap();
1660 assert_eq!(contents, "this is a downloaded test file\n");
1661 }
1662}