use bytes::Bytes;
use derivative::Derivative;
use parking_lot::Mutex;
use pin_project::pin_project;
use rand::rngs::OsRng;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::{mpsc, oneshot};
use crate::cipher::{self, CipherAlgo};
use crate::codec::{PacketDecode, PacketEncode};
use crate::error::{Error, Result, DisconnectError};
use crate::kex::{self, KexAlgo};
use crate::mac::{self, MacAlgo};
use crate::pubkey::{self, PubkeyAlgo, Pubkey, Privkey};
use super::{auth, negotiate};
use super::auth_method::none::{AuthNone, AuthNoneResult};
use super::auth_method::password::{AuthPassword, AuthPasswordResult};
use super::auth_method::pubkey::{AuthPubkey, AuthPubkeyResult, CheckPubkey};
use super::channel::{Channel, ChannelReceiver, ChannelConfig};
use super::client_event::ClientReceiver;
use super::client_state::{self, ClientState};
use super::conn::{self, OpenChannel};
use super::session::{Session, SessionReceiver};
use super::tunnel::{Tunnel, TunnelReceiver};
/// Cloneable handle to an SSH client.
///
/// The handle keeps only a weak reference to the shared client state. The strong reference is
/// owned by the [`ClientFuture`] returned from [`Client::open`]. Once that future is dropped,
/// the reference can no longer be upgraded and state-dependent methods fail with
/// [`Error::ClientClosed`].
#[derive(Clone)]
pub struct Client {
    /// Weak reference to the state owned by the client future.
    pub(super) client_st: Weak<Mutex<ClientState>>,
}
impl Client {
    /// Creates a new SSH client that runs over `stream`.
    ///
    /// Returns three values:
    ///
    /// - a [`Client`] handle for issuing requests (it only wraps a weak reference, so it is
    ///   cheap to clone);
    /// - a [`ClientReceiver`] that yields events produced by the client;
    /// - a [`ClientFuture`] that owns `stream` and the shared client state, and drives the
    ///   connection when polled.
    ///
    /// The future holds the only strong reference to the shared state. The handle and the
    /// receiver hold weak references, so after the future is dropped, the handle methods that
    /// need the state fail with [`Error::ClientClosed`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `client_state::new_client`, which receives `config`.
    pub fn open<IO>(stream: IO, config: ClientConfig) -> Result<(Client, ClientReceiver, ClientFuture<IO>)>
    where IO: AsyncRead + AsyncWrite
    {
        let rng = Box::new(OsRng);
        // Bounded event channel with capacity 1. The sending half is handed to the client
        // state and the receiving half goes into the returned `ClientReceiver`.
        let (event_tx, event_rx) = mpsc::channel(1);
        let client_st = client_state::new_client(config, rng, event_tx)?;
        let client_st = Arc::new(Mutex::new(client_st));
        let client = Client { client_st: Arc::downgrade(&client_st) };
        let client_rx = ClientReceiver {
            client_st: Arc::downgrade(&client_st),
            event_rx,
            specialize_channels: true,
        };
        let client_fut = ClientFuture { client_st, stream };
        Ok((client, client_rx, client_fut))
    }

    /// Upgrades the weak reference to the shared state.
    ///
    /// Fails with [`Error::ClientClosed`] once no strong reference remains. `open` gives the
    /// only strong reference to the `ClientFuture`.
    fn upgrade(&self) -> Result<Arc<Mutex<ClientState>>> {
        self.client_st.upgrade().ok_or(Error::ClientClosed)
    }

    /// Tries to authenticate as `username` using the `none` method.
    ///
    /// # Errors
    ///
    /// - [`Error::ClientClosed`] if the client has been closed.
    /// - Any error from starting the method.
    /// - [`Error::AuthAborted`] if the method is dropped before it delivers a result.
    pub async fn auth_none(&self, username: String) -> Result<AuthNoneResult> {
        let (result_tx, result_rx) = oneshot::channel();
        let method = AuthNone::new(username, result_tx);
        // The upgraded `Arc` and the lock guard are temporaries, so they are released at the
        // end of this statement. The lock is therefore never held across the `.await` below.
        auth::start_method(&mut self.upgrade()?.lock(), Box::new(method))?;
        result_rx.await.map_err(|_| Error::AuthAborted)
    }

    /// Tries to authenticate as `username` using the `password` method.
    ///
    /// # Errors
    ///
    /// - [`Error::ClientClosed`] if the client has been closed.
    /// - Any error from starting the method.
    /// - [`Error::AuthAborted`] if the method is dropped before it delivers a result.
    pub async fn auth_password(&self, username: String, password: String) -> Result<AuthPasswordResult> {
        let (result_tx, result_rx) = oneshot::channel();
        let method = AuthPassword::new(username, password, result_tx);
        auth::start_method(&mut self.upgrade()?.lock(), Box::new(method))?;
        result_rx.await.map_err(|_| Error::AuthAborted)
    }

    /// Tries to authenticate as `username` with the public-key method, using `privkey` and
    /// the public key algorithm `pubkey_algo`.
    ///
    /// # Errors
    ///
    /// - [`Error::ClientClosed`] if the client has been closed.
    /// - Any error from starting the method.
    /// - [`Error::AuthAborted`] if the method is dropped before it delivers a result.
    /// - Any error the method itself reports.
    pub async fn auth_pubkey(
        &self,
        username: String,
        privkey: Privkey,
        pubkey_algo: &'static PubkeyAlgo,
    ) -> Result<AuthPubkeyResult> {
        let (result_tx, result_rx) = oneshot::channel();
        let method = AuthPubkey::new(username, privkey, pubkey_algo, result_tx);
        auth::start_method(&mut self.upgrade()?.lock(), Box::new(method))?;
        // This method sends a `Result`, so both the channel error and the method's own
        // error are propagated.
        result_rx.await.map_err(|_| Error::AuthAborted)?
    }

    /// Checks whether the server would accept `pubkey` with `pubkey_algo` for `username`.
    ///
    /// This presumably corresponds to the RFC 4252 §7 query that needs no signature. Confirm
    /// against `CheckPubkey`.
    ///
    /// # Errors
    ///
    /// - [`Error::ClientClosed`] if the client has been closed.
    /// - Any error from starting the method.
    /// - [`Error::AuthAborted`] if the method is dropped before it delivers a result.
    pub async fn check_pubkey(
        &self,
        username: String,
        pubkey: &Pubkey,
        pubkey_algo: &'static PubkeyAlgo,
    ) -> Result<bool> {
        let (result_tx, result_rx) = oneshot::channel();
        let method = CheckPubkey::new(username, pubkey, pubkey_algo, result_tx);
        auth::start_method(&mut self.upgrade()?.lock(), Box::new(method))?;
        result_rx.await.map_err(|_| Error::AuthAborted)
    }

    /// Returns a copy of the public key algorithm names stored in the peer's extension info.
    ///
    /// This is presumably populated from the RFC 8308 `server-sig-algs` extension, with
    /// `None` when the peer did not send it. TODO confirm where `their_ext_info` is filled in.
    pub fn auth_pubkey_algo_names(&self) -> Result<Option<Vec<String>>> {
        Ok(self.upgrade()?.lock().their_ext_info.auth_pubkey_algo_names.clone())
    }

    /// Returns whether the client is authenticated, as reported by `auth::is_authenticated`.
    pub fn is_authenticated(&self) -> Result<bool> {
        Ok(auth::is_authenticated(&self.upgrade()?.lock()))
    }

    /// Opens a session channel; see [`Session::open`].
    pub async fn open_session(&self, config: ChannelConfig) -> Result<(Session, SessionReceiver)> {
        Session::open(self, config).await
    }

    /// Opens a tunnel to `connect_addr` that reports `originator_addr` as its origin.
    ///
    /// See [`Tunnel::connect`].
    pub async fn connect_tunnel(
        &self,
        config: ChannelConfig,
        connect_addr: (String, u16),
        originator_addr: (String, u16),
    ) -> Result<(Tunnel, TunnelReceiver)> {
        Tunnel::connect(self, config, connect_addr, originator_addr).await
    }

    /// Sends a port-forwarding global request of type `request_type` for `bind_addr`.
    ///
    /// The payload is the address string followed by the port as a `u32`, and a reply is
    /// requested. Returns the receiver on which the reply arrives.
    fn send_forward_request(
        &self,
        request_type: &str,
        bind_addr: &(String, u16),
    ) -> Result<oneshot::Receiver<GlobalReply>> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let mut payload = PacketEncode::new();
        payload.put_str(&bind_addr.0);
        payload.put_u32(u32::from(bind_addr.1));
        self.send_request(GlobalReq {
            request_type: request_type.into(),
            payload: payload.finish(),
            reply_tx: Some(reply_tx),
        })?;
        Ok(reply_rx)
    }

    /// Asks the server to listen on `bind_addr` and forward incoming connections
    /// (`tcpip-forward`, RFC 4254 §7.1).
    ///
    /// The returned response resolves to `Some(port)` when the server's reply carries a port,
    /// which happens when the server chose it. Otherwise it resolves to `None`.
    ///
    /// # Errors
    ///
    /// Returns an error immediately if the request cannot be sent. The response itself yields:
    ///
    /// - [`Error::GlobalReq`] if the server rejects the request, or if it reports a port that
    ///   does not fit into a `u16`.
    /// - Any decoding error from reading the port.
    pub fn bind_tunnel(&self, bind_addr: (String, u16)) -> Result<ClientResp<Option<u16>>> {
        let reply_rx = self.send_forward_request("tcpip-forward", &bind_addr)?;
        Ok(ClientResp::map(reply_rx, |payload| {
            // The reply carries a port only when the server allocated one. Otherwise there is
            // no response-specific data.
            if payload.remaining_len() < 4 {
                return Ok(None);
            }
            // The port comes from the peer as a `u32`. Reject values that do not fit, rather
            // than silently truncating them to an unrelated port.
            let port = payload.get_u32()?;
            if port > u32::from(u16::MAX) {
                return Err(Error::GlobalReq);
            }
            // Cannot truncate: the range was checked above.
            Ok(Some(port as u16))
        }))
    }

    /// Asks the server to stop forwarding connections from `bind_addr`
    /// (`cancel-tcpip-forward`, RFC 4254 §7.1).
    ///
    /// # Errors
    ///
    /// Returns an error immediately if the request cannot be sent. The response yields
    /// [`Error::GlobalReq`] if the server rejects the request.
    pub fn unbind_tunnel(&self, bind_addr: (String, u16)) -> Result<ClientResp<()>> {
        let reply_rx = self.send_forward_request("cancel-tcpip-forward", &bind_addr)?;
        Ok(ClientResp::map(reply_rx, |_payload| Ok(())))
    }

    /// Opens a channel of type `channel_type` with `open_payload` as type-specific data.
    ///
    /// On success, returns the channel handle, its event receiver, and the payload from the
    /// peer's confirmation.
    ///
    /// # Errors
    ///
    /// - [`Error::ClientClosed`] if the client has been closed.
    /// - [`Error::ChannelClosed`] if the open request is dropped without a result.
    /// - Any error reported for the open attempt.
    pub async fn open_channel(&self, channel_type: String, config: ChannelConfig, open_payload: Bytes)
        -> Result<(Channel, ChannelReceiver, Bytes)>
    {
        let (result_tx, result_rx) = oneshot::channel();
        let open = OpenChannel {
            channel_type,
            recv_window_max: config.recv_window_max(),
            recv_packet_len_max: config.recv_packet_len_max(),
            open_payload,
            result_tx,
        };
        // The lock guard is a temporary and is released before the `.await` below.
        conn::open_channel(&mut self.upgrade()?.lock(), open);
        // Outer `?`: the sender was dropped. Inner `?`: the open attempt itself failed.
        let result = result_rx.await.map_err(|_| Error::ChannelClosed)??;
        let channel = Channel {
            client_st: self.client_st.clone(),
            channel_st: result.channel_st,
        };
        let channel_rx = ChannelReceiver {
            event_rx: result.event_rx,
        };
        Ok((channel, channel_rx, result.confirm_payload))
    }

    /// Sends a `keepalive@openssh.com` global request.
    ///
    /// A reply sender is attached, but its receiver is dropped when this function returns, so
    /// any reply is discarded. Presumably the attached sender is what makes the request ask
    /// for a reply; confirm in `conn::send_request`.
    pub fn send_keepalive(&self) -> Result<()> {
        let (reply_tx, _reply_rx) = oneshot::channel();
        let req = GlobalReq {
            request_type: "keepalive@openssh.com".to_owned(),
            payload: Bytes::new(),
            reply_tx: Some(reply_tx),
        };
        self.send_request(req)
    }

    /// Sends an arbitrary global request.
    ///
    /// # Errors
    ///
    /// [`Error::ClientClosed`] if the client has been closed, or any error from
    /// `conn::send_request`.
    pub fn send_request(&self, req: GlobalReq) -> Result<()> {
        conn::send_request(&mut self.upgrade()?.lock(), req)
    }

    /// Starts a new key exchange and waits until it finishes.
    ///
    /// # Errors
    ///
    /// - [`Error::ClientClosed`] if the client has been closed.
    /// - [`Error::RekeyAborted`] if the completion sender is dropped.
    /// - Any error the key exchange reports through it.
    pub async fn rekey(&self) -> Result<()> {
        let (done_tx, done_rx) = oneshot::channel();
        negotiate::start_kex(&mut self.upgrade()?.lock(), Some(done_tx));
        done_rx.await.map_err(|_| Error::RekeyAborted)?
    }

    /// Disconnects from the server, reporting `error` as the reason.
    ///
    /// # Errors
    ///
    /// [`Error::ClientClosed`] if the client has been closed, or any error from
    /// `client_state::disconnect`.
    pub fn disconnect(&self, error: DisconnectError) -> Result<()> {
        client_state::disconnect(&mut self.upgrade()?.lock(), error)
    }
}
/// A global request to be sent on the connection.
///
/// Examples in this module are `"tcpip-forward"`, `"cancel-tcpip-forward"` and
/// `"keepalive@openssh.com"`.
#[derive(Debug)]
pub struct GlobalReq {
    /// Name of the request.
    pub request_type: String,
    /// Request-specific data, already encoded by the caller.
    pub payload: Bytes,
    /// Channel on which the reply is delivered.
    ///
    /// `None` presumably means no reply is requested; confirm in `conn::send_request`.
    pub reply_tx: Option<oneshot::Sender<GlobalReply>>,
}
/// The peer's answer to a [`GlobalReq`].
#[derive(Debug)]
pub enum GlobalReply {
    /// The request succeeded.
    ///
    /// Carries any response-specific data, which may be empty.
    Success(Bytes),
    /// The request was rejected.
    Failure,
}
/// Pending response to a global request issued through a [`Client`].
///
/// Either await it with `wait()` or explicitly discard it with `ignore()`.
#[derive(Derivative)]
#[derivative(Debug)]
#[must_use = "please use .wait().await to await the response, or .ignore() to ignore it"]
pub struct ClientResp<T> {
    /// Receives the raw reply from the connection.
    reply_rx: oneshot::Receiver<GlobalReply>,
    /// Decodes the payload of a successful reply into `T`.
    ///
    /// Closures have no useful `Debug`, so this field is excluded from it.
    #[derivative(Debug = "ignore")]
    map_fn: Box<dyn FnOnce(&mut PacketDecode) -> Result<T> + Send + Sync>,
}
impl<T> ClientResp<T> {
fn map<F>(reply_rx: oneshot::Receiver<GlobalReply>, map_fn: F) -> Self
where F: FnOnce(&mut PacketDecode) -> Result<T> + Send + Sync + 'static
{
Self { reply_rx, map_fn: Box::new(map_fn) }
}
pub async fn wait(self) -> Result<T> {
match self.reply_rx.await {
Ok(GlobalReply::Success(payload)) => (self.map_fn)(&mut PacketDecode::new(payload)),
Ok(GlobalReply::Failure) => Err(Error::GlobalReq),
Err(_) => Err(Error::ClientClosed),
}
}
pub fn ignore(self) {}
}
/// Future that drives the client connection over the stream `IO`.
///
/// It owns the strong reference to the shared client state that [`Client::open`] creates.
/// The [`Client`] handle only holds a weak reference, so once this future is dropped the
/// handle can no longer reach the state.
#[pin_project]
pub struct ClientFuture<IO> {
    /// Shared client state. It is locked for the duration of each poll.
    client_st: Arc<Mutex<client_state::ClientState>>,
    /// The underlying transport, structurally pinned.
    #[pin]
    stream: IO,
}
impl<IO> ClientFuture<IO> {
pub fn into_stream(self) -> IO {
self.stream
}
}
impl<IO> Future for ClientFuture<IO>
where
    IO: AsyncRead + AsyncWrite,
{
    type Output = Result<()>;

    /// Locks the shared state and lets `client_state::poll_client` make progress on the
    /// pinned stream. A terminal error is logged at debug level before it is returned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
        let projected = self.project();
        let mut state = projected.client_st.lock();
        let outcome = client_state::poll_client(&mut state, projected.stream, cx);
        if let Poll::Ready(Err(err)) = &outcome {
            log::debug!("client future returned error: {:#}", err);
        }
        outcome
    }
}
/// Configuration passed to [`Client::open`].
///
/// The roles below follow the field names. How the lists are used during negotiation (for
/// example, whether their order expresses preference) is decided outside this file; confirm
/// in `negotiate`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ClientConfig {
    /// Key exchange algorithms.
    pub kex_algos: Vec<&'static KexAlgo>,
    /// Public key algorithms for the server's host key.
    pub server_pubkey_algos: Vec<&'static PubkeyAlgo>,
    /// Encryption algorithms.
    pub cipher_algos: Vec<&'static CipherAlgo>,
    /// Message authentication code algorithms.
    pub mac_algos: Vec<&'static MacAlgo>,
    /// Byte-count threshold for rekeying.
    pub rekey_after_bytes: u64,
    /// Time threshold for rekeying.
    pub rekey_after_duration: Duration,
}
impl Default for ClientConfig {
    /// Builds the baseline configuration.
    ///
    /// Older algorithms are only added by `ClientConfig::default_compatible_less_secure`.
    fn default() -> Self {
        let kex_algos = vec![
            &kex::CURVE25519_SHA256,
            &kex::CURVE25519_SHA256_LIBSSH,
        ];
        let server_pubkey_algos = vec![
            &pubkey::SSH_ED25519,
            &pubkey::RSA_SHA2_256,
            &pubkey::RSA_SHA2_512,
        ];
        let cipher_algos = vec![
            &cipher::CHACHA20_POLY1305,
            &cipher::AES128_GCM,
            &cipher::AES256_GCM,
            &cipher::AES128_CTR,
            &cipher::AES192_CTR,
            &cipher::AES256_CTR,
        ];
        let mac_algos = vec![
            &mac::HMAC_SHA2_256_ETM,
            &mac::HMAC_SHA2_512_ETM,
            &mac::HMAC_SHA2_256,
            &mac::HMAC_SHA2_512,
        ];
        Self {
            kex_algos,
            server_pubkey_algos,
            cipher_algos,
            mac_algos,
            // 1 GiB.
            rekey_after_bytes: 1024 * 1024 * 1024,
            // One hour.
            rekey_after_duration: Duration::from_secs(3600),
        }
    }
}
impl ClientConfig {
    /// Returns `ClientConfig::default()` with additional, less secure algorithms for
    /// compatibility.
    ///
    /// The extra algorithms are appended after the default ones in each list.
    pub fn default_compatible_less_secure() -> ClientConfig {
        let mut config = Self::default();
        config.kex_algos.extend_from_slice(&[
            &kex::DIFFIE_HELLMAN_GROUP14_SHA256,
            &kex::DIFFIE_HELLMAN_GROUP16_SHA512,
            &kex::DIFFIE_HELLMAN_GROUP18_SHA512,
            &kex::DIFFIE_HELLMAN_GROUP14_SHA1,
        ]);
        config.server_pubkey_algos.extend_from_slice(&[
            &pubkey::ECDSA_SHA2_NISTP256,
            &pubkey::ECDSA_SHA2_NISTP384,
            &pubkey::SSH_RSA_SHA1,
        ]);
        config.cipher_algos.extend_from_slice(&[
            &cipher::AES128_CBC,
            &cipher::AES192_CBC,
            &cipher::AES256_CBC,
        ]);
        config.mac_algos.extend_from_slice(&[
            &mac::HMAC_SHA1_ETM,
            &mac::HMAC_SHA1,
        ]);
        config
    }

    /// Returns `default_compatible_less_secure()` with insecure algorithms appended.
    ///
    /// Only compiled with the `insecure-crypto` feature.
    #[cfg(feature = "insecure-crypto")]
    pub fn default_insecure() -> ClientConfig {
        let mut config = Self::default_compatible_less_secure();
        config.kex_algos.extend_from_slice(&[&kex::DIFFIE_HELLMAN_GROUP1_SHA1]);
        config.cipher_algos.extend_from_slice(&[&cipher::TDES_CBC]);
        config
    }

    /// Lets `f` mutate the configuration in place, then returns it (builder-style helper).
    pub fn with<F>(mut self, f: F) -> Self
    where
        F: FnOnce(&mut Self),
    {
        f(&mut self);
        self
    }
}