use std::convert::TryFrom;
use std::future::Future;
use std::pin::Pin;
use std::sync::Weak;
use std::task::{Context, Poll};

use bytes::Bytes;
use derivative::Derivative;
use futures_core::ready;
use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot};

use crate::codec::PacketDecode;
use crate::error::{ChannelOpenError, Error, Result};
use crate::pubkey::Pubkey;

use super::channel::{Channel, ChannelConfig, ChannelReceiver};
use super::client_state::ClientState;
use super::conn::AcceptedChannel;
use super::tunnel::{Tunnel, TunnelReceiver};
pub struct ClientReceiver {
pub(super) client_st: Weak<Mutex<ClientState>>,
pub(super) event_rx: mpsc::Receiver<ClientEvent>,
pub(super) specialize_channels: bool,
}
impl ClientReceiver {
pub async fn recv(&mut self) -> Result<Option<ClientEvent>> {
struct Recv<'a> { rx: &'a mut ClientReceiver }
impl<'a> Future for Recv<'a> {
type Output = Result<Option<ClientEvent>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.rx.poll_recv(cx)
}
}
Recv { rx: self }.await
}
pub fn poll_recv(&mut self, cx: &mut Context) -> Poll<Result<Option<ClientEvent>>> {
match ready!(self.event_rx.poll_recv(cx)) {
Some(ClientEvent::Channel(mut accept)) => {
accept.client_st = Some(self.client_st.clone());
if accept.channel_type == "forwarded-tcpip" && self.specialize_channels {
let accept = AcceptTunnel::decode(accept)?;
Poll::Ready(Ok(Some(ClientEvent::Tunnel(accept))))
} else {
Poll::Ready(Ok(Some(ClientEvent::Channel(accept))))
}
},
event => Poll::Ready(Ok(event)),
}
}
pub fn specialize_channels(&mut self, enable: bool) {
self.specialize_channels = enable;
}
}
/// Event produced by a [`ClientReceiver`].
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a wildcard arm when
/// matching on it.
#[derive(Debug)]
#[non_exhaustive]
pub enum ClientEvent {
    /// Public key of the server, together with a handle to accept or reject it.
    ServerPubkey(Pubkey, AcceptPubkey),
    /// Debug message received from the peer.
    DebugMsg(DebugMsg),
    /// Authentication banner received from the peer.
    AuthBanner(AuthBanner),
    /// `"forwarded-tcpip"` channel open decoded into a tunnel; only produced while the
    /// receiver has channel specialization enabled.
    Tunnel(AcceptTunnel),
    /// Channel open request that was not turned into a [`ClientEvent::Tunnel`].
    Channel(AcceptChannel),
}
/// Debug message sent by the peer (presumably `SSH_MSG_DEBUG` from RFC 4253 §11.3 — confirm
/// against the decoder).
#[derive(Debug)]
pub struct DebugMsg {
    /// The `always_display` flag exactly as sent by the peer.
    pub always_display: bool,
    /// Text of the message.
    pub message: String,
    /// Language tag accompanying `message`.
    pub message_lang: String,
}
/// Authentication banner sent by the peer (presumably `SSH_MSG_USERAUTH_BANNER` from
/// RFC 4252 §5.4 — confirm against the decoder).
#[derive(Debug)]
pub struct AuthBanner {
    /// Text of the banner.
    pub message: String,
    /// Language tag accompanying `message`.
    pub message_lang: String,
}
#[derive(Debug)]
pub struct AcceptPubkey {
pub(super) accepted_tx: oneshot::Sender<Result<PubkeyAccepted>>,
}
#[derive(Debug)]
pub(super) struct PubkeyAccepted(());
impl AcceptPubkey {
pub fn accept(self) {
let _: Result<_, _> = self.accepted_tx.send(Ok(PubkeyAccepted(())));
}
pub fn reject<E: std::error::Error + Send + Sync + 'static>(self, err: E) {
let _: Result<_, _> = self.accepted_tx.send(Err(Error::PubkeyAccept(Box::new(err))));
}
}
#[derive(Derivative)]
#[derivative(Debug)]
pub struct AcceptChannel {
#[derivative(Debug = "ignore")]
pub(super) client_st: Option<Weak<Mutex<ClientState>>>,
pub channel_type: String,
pub open_payload: Bytes,
#[derivative(Debug = "ignore")]
pub(super) accepted_tx: oneshot::Sender<Result<AcceptedChannel, ChannelOpenError>>,
}
impl AcceptChannel {
pub async fn accept(self, config: ChannelConfig, confirm_payload: Bytes)
-> Result<(Channel, ChannelReceiver)>
{
let (result_tx, result_rx) = oneshot::channel();
let accepted = AcceptedChannel {
recv_window_max: config.recv_window_max(),
recv_packet_len_max: config.recv_packet_len_max(),
confirm_payload,
result_tx,
};
let _: Result<_, _> = self.accepted_tx.send(Ok(accepted));
let result = result_rx.await.map_err(|_| Error::ClientClosed)?;
let channel = Channel {
client_st: self.client_st.unwrap(),
channel_st: result.channel_st,
};
let channel_rx = ChannelReceiver { event_rx: result.event_rx };
Ok((channel, channel_rx))
}
pub fn reject(self, error: ChannelOpenError) {
let _: Result<_, _> = self.accepted_tx.send(Err(error));
}
pub fn reject_prohibited(self) {}
}
#[derive(Debug)]
pub struct AcceptTunnel {
accept: AcceptChannel,
pub connected_addr: (String, u16),
pub originator_addr: (String, u16),
}
impl AcceptTunnel {
fn decode(accept: AcceptChannel) -> Result<AcceptTunnel> {
let mut payload = PacketDecode::new(accept.open_payload.clone());
let connected_host = payload.get_string()?;
let connected_port = payload.get_u32()? as u16;
let connected_addr = (connected_host, connected_port);
let originator_host = payload.get_string()?;
let originator_port = payload.get_u32()? as u16;
let originator_addr = (originator_host, originator_port);
Ok(AcceptTunnel { accept, connected_addr, originator_addr })
}
pub async fn accept(self, config: ChannelConfig) -> Result<(Tunnel, TunnelReceiver)> {
let (channel, channel_rx) = self.accept.accept(config, Bytes::new()).await?;
Tunnel::accept(channel, channel_rx)
}
pub fn reject(self, error: ChannelOpenError) {
self.accept.reject(error);
}
pub fn reject_prohibited(self) {
self.accept.reject_prohibited();
}
}