use std::net::{IpAddr, SocketAddr, TcpListener, TcpStream};
use async_io::Async;
use async_ssh2_lite::{AsyncChannel, AsyncSession, SessionConfiguration};
use async_std_resolver::resolver_from_system_conf;
use futures::{AsyncReadExt, AsyncWriteExt, FutureExt};
use plain_path::PlainPathExt;
use ssh_jumper_model::{
AuthMethod, Error, HostAddress, HostSocketParams, JumpHostAuthParams, SshForwarderEnd,
SshTunnelParams,
};
use tokio::sync::{oneshot, oneshot::Receiver};
use crate::SshSession;
/// Entry point for opening SSH sessions to a jump host and forwarding a local
/// socket through them.
///
/// The type carries no state; everything is exposed as associated functions.
#[derive(Debug)]
pub struct SshJumper;
impl SshJumper {
/// Opens an SSH session to the jump host and forwards a local socket through
/// it to the target socket.
///
/// This is [`Self::open_ssh_session_with_port`] followed by
/// [`Self::open_direct_channel`].
///
/// # Returns
///
/// The address the local listener is bound to, and a receiver that yields the
/// reason the forwarder stopped.
///
/// # Errors
///
/// Propagates any error from establishing the session or opening the channel.
pub async fn open_tunnel(
    ssh_tunnel_params: &SshTunnelParams<'_>,
) -> Result<(SocketAddr, Receiver<SshForwarderEnd>), Error> {
    let session = Self::open_ssh_session_with_port(
        &ssh_tunnel_params.jump_host,
        &ssh_tunnel_params.jump_host_auth_params,
    )
    .await?;

    Self::open_direct_channel(
        &session,
        ssh_tunnel_params.local_socket,
        &ssh_tunnel_params.target_socket,
    )
    .await
}
pub async fn open_ssh_session(
jump_host_addr: &HostAddress<'_>,
jump_host_auth_params: &JumpHostAuthParams<'_>,
) -> Result<SshSession, Error> {
SshJumper::open_ssh_session_with_port(
&HostSocketParams {
address: jump_host_addr.clone(),
port: 22,
},
jump_host_auth_params,
)
.await
}
/// Opens an authenticated SSH session to the jump host at the given address
/// and port.
///
/// Steps:
///
/// 1. Resolve the jump host to an IP address (host names go through
///    `Self::resolve_ip`).
/// 2. Connect a TCP stream to that IP and port.
/// 3. Create an SSH session with compression enabled and perform the
///    handshake.
/// 4. Authenticate with `jump_host_auth_params`.
///
/// # Errors
///
/// * [`Error::JumpHostConnectFail`] if the TCP connection fails.
/// * [`Error::AsyncSessionInitialize`] if the session cannot be created.
/// * [`Error::SshHandshakeFail`] if the handshake fails.
/// * [`Error::SshUserAuthError`] / [`Error::SshUserAuthUnknownError`] if the
///   session reports it is not authenticated after the authentication call
///   returned successfully.
/// * Any error from DNS resolution or the authentication call itself.
pub async fn open_ssh_session_with_port(
    jump_host_addr: &HostSocketParams<'_>,
    jump_host_auth_params: &JumpHostAuthParams<'_>,
) -> Result<SshSession, Error> {
    // Borrow the address rather than cloning the whole `HostSocketParams`:
    // only the IP (which is `Copy`) or a `&str` view of the host name is needed.
    let jump_host_ip = match &jump_host_addr.address {
        HostAddress::IpAddr(ip_addr) => *ip_addr,
        HostAddress::HostName(host_name) => Self::resolve_ip(host_name).await?,
    };

    let stream =
        Async::<TcpStream>::connect(SocketAddr::from((jump_host_ip, jump_host_addr.port)))
            .await
            .map_err(|io_error| Error::JumpHostConnectFail {
                jump_host_addr: jump_host_addr.into_static(),
                io_error,
            })?;

    let mut session_configuration = SessionConfiguration::new();
    session_configuration.set_compress(true);

    let mut session = AsyncSession::new(stream, Some(session_configuration))
        .map_err(Error::AsyncSessionInitialize)?;
    session.handshake().await.map_err(Error::SshHandshakeFail)?;

    Self::ssh_session_authenticate(jump_host_auth_params, &mut session).await?;

    // The authentication call can return `Ok` without the session ending up
    // authenticated, so check the session state explicitly.
    if !session.authenticated() {
        return Err(session
            .last_error()
            .map(Error::SshUserAuthError)
            .unwrap_or(Error::SshUserAuthUnknownError));
    }

    Ok(SshSession(session))
}
async fn ssh_session_authenticate(
jump_host_auth_params: &JumpHostAuthParams<'_>,
session: &mut AsyncSession<TcpStream>,
) -> Result<(), Error> {
let jump_host_user_name = &jump_host_auth_params.user_name;
match &jump_host_auth_params.auth_method {
AuthMethod::KeyPair {
private_key,
passphrase,
} => {
let jump_host_public_key = None;
let jump_host_private_key =
private_key.plain().map_err(Error::PrivateKeyPlainPath)?;
let jump_host_private_key_passphrase = passphrase.as_deref();
session
.userauth_pubkey_file(
jump_host_user_name,
jump_host_public_key,
&jump_host_private_key,
jump_host_private_key_passphrase,
)
.await
.map_err(Error::SshUserAuthFail)?;
}
AuthMethod::Password { password } => {
session
.userauth_password(jump_host_user_name, password)
.await
.map_err(Error::SshUserAuthFail)?;
}
}
Ok(())
}
/// Opens a `direct-tcpip` channel to `target_socket` over an existing SSH
/// session and starts forwarding `local_socket` through it.
///
/// No source address is passed when opening the channel.
///
/// # Returns
///
/// The address the local listener is bound to, and a receiver that yields the
/// reason the forwarder stopped.
///
/// # Errors
///
/// * [`Error::SshTunnelOpenFail`] if the channel cannot be opened.
/// * Any error from setting up the local listener and forwarding task.
pub async fn open_direct_channel(
    ssh_session: &SshSession,
    local_socket: SocketAddr,
    target_socket: &HostSocketParams<'_>,
) -> Result<(SocketAddr, Receiver<SshForwarderEnd>), Error> {
    let target_host = target_socket.address.to_string();

    let channel = ssh_session
        .channel_direct_tcpip(target_host.as_str(), target_socket.port, None)
        .await
        .map_err(Error::SshTunnelOpenFail)?;

    Self::spawn_channel_streamers(local_socket, channel).await
}
async fn spawn_channel_streamers<'tunnel>(
local_socket: SocketAddr,
mut jump_host_channel: AsyncChannel<TcpStream>,
) -> Result<(SocketAddr, Receiver<SshForwarderEnd>), Error> {
let local_socket_addr = TcpListener::bind(local_socket)
.map_err(|io_error| Error::LocalSocketBind {
local_socket,
io_error,
})?
.local_addr()
.map_err(|io_error| Error::LocalSocketAddr {
local_socket,
io_error,
})?;
let local_socket_listener = Async::<TcpListener>::bind(local_socket_addr)
.map_err(Error::SshTunnelListenerCreate)?;
let (ssh_forwarder_tx, ssh_forwarder_rx) = oneshot::channel::<SshForwarderEnd>();
let spawn_join_handle = tokio::task::spawn(async move {
let _detached_task = tokio::task::spawn(async move {
let mut buf_jump_host_channel = vec![0; 2048];
let mut buf_forward_stream_r = vec![0; 2048];
match local_socket_listener.accept().await {
Ok((mut forward_stream_r, _)) => loop {
futures::select! {
ret_forward_stream_r = forward_stream_r.read(&mut buf_forward_stream_r).fuse() => match ret_forward_stream_r {
Ok(0) => {
let _send_result = ssh_forwarder_tx.send(SshForwarderEnd::LocalReadEof);
break;
},
Ok(n) => {
if let Err(e) = jump_host_channel.write(&buf_forward_stream_r[..n]).await.map(|_| ()).map_err(|err| {
err
}) {
let _send_result = ssh_forwarder_tx.send(ssh_jumper_model::SshForwarderEnd::LocalToChannelWriteErr(e));
break;
}
},
Err(e) => {
let _send_result = ssh_forwarder_tx.send(ssh_jumper_model::SshForwarderEnd::LocalReadErr(e));
break;
}
},
ret_jump_host_channel = jump_host_channel.read(&mut buf_jump_host_channel).fuse() => match ret_jump_host_channel {
Ok(0) => {
let _send_result = ssh_forwarder_tx.send(SshForwarderEnd::ChannelReadEof);
break;
},
Ok(n) => {
if let Err(e) = forward_stream_r.write(&buf_jump_host_channel[..n]).await.map(|_| ()).map_err(|err| {
err
}) {
let _send_result = ssh_forwarder_tx.send(ssh_jumper_model::SshForwarderEnd::ChannelToLocalWriteErr(e));
break;
}
},
Err(e) => {
let _send_result = ssh_forwarder_tx.send(ssh_jumper_model::SshForwarderEnd::ChannelReadErr(e));
break;
}
},
}
},
Err(e) => {
let _send_result =
ssh_forwarder_tx.send(SshForwarderEnd::LocalConnectFail(e));
}
}
});
});
spawn_join_handle
.await
.map_err(Error::SshStreamerSpawnFail)?;
Ok((local_socket_addr, ssh_forwarder_rx))
}
async fn resolve_ip<'tunnel>(jump_host_addr: &str) -> Result<IpAddr, Error> {
let resolver = resolver_from_system_conf()
.await
.map_err(Error::DnsResolverCreate)?;
let mut lookup_addr = String::with_capacity(jump_host_addr.len() + 1);
lookup_addr.push_str(jump_host_addr);
lookup_addr.push('.');
let response = resolver
.lookup_ip(lookup_addr)
.await
.map_err(Error::DnsResolverLookup)?;
if let Some(host_ip) = response.iter().next() {
Ok(host_ip)
} else {
Err(Error::JumpHostIpResolutionFail {
jump_host_addr: jump_host_addr.to_string(),
})
}
}
}