use bytes::{Bytes, Buf as _};
use futures_core::ready;
use futures_core::future::BoxFuture;
use std::cmp::min;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncBufRead, AsyncWrite};
use crate::codec::PacketEncode;
use crate::error::Result;
use super::channel::{Channel, ChannelReceiver, ChannelEvent, ChannelConfig, DATA_STANDARD};
use super::client::Client;
/// Sending half of a TCP/IP tunnel carried over a `direct-tcpip` SSH channel.
///
/// Cloning a `Tunnel` clones its [`Channel`]; `TunnelWriter` relies on this to
/// move a copy into each send future.
#[derive(Clone)]
pub struct Tunnel {
    channel: Channel,
}

// Crate-internal constructors.
impl Tunnel {
    /// Opens a `direct-tcpip` channel on `client` and wraps it as a tunnel.
    ///
    /// `connect_addr` is the host/port the other side is asked to connect to,
    /// and `originator_addr` is the host/port reported as the origin of the
    /// connection. Both ports are written as 32-bit integers in the
    /// channel-open payload.
    ///
    /// # Errors
    /// Propagates any error returned while opening the channel.
    pub(super) async fn connect(
        client: &Client,
        config: ChannelConfig,
        connect_addr: (String, u16),
        originator_addr: (String, u16),
    ) -> Result<(Tunnel, TunnelReceiver)> {
        let (connect_host, connect_port) = connect_addr;
        let (originator_host, originator_port) = originator_addr;

        // Payload layout: target host, target port, originator host, originator port.
        let mut payload = PacketEncode::new();
        payload.put_str(&connect_host);
        payload.put_u32(u32::from(connect_port));
        payload.put_str(&originator_host);
        payload.put_u32(u32::from(originator_port));

        let (channel, channel_rx, _) = client
            .open_channel("direct-tcpip".into(), config, payload.finish())
            .await?;
        Self::accept(channel, channel_rx)
    }

    /// Wraps an existing channel and its receiver as a tunnel pair.
    ///
    /// The current implementation never returns an error.
    pub(super) fn accept(channel: Channel, channel_rx: ChannelReceiver) -> Result<(Tunnel, TunnelReceiver)> {
        let tunnel = Tunnel { channel };
        let tunnel_rx = TunnelReceiver { channel_rx };
        Ok((tunnel, tunnel_rx))
    }
}

// Public sending API.
impl Tunnel {
    /// Sends `data` to the peer on the standard data stream ([`DATA_STANDARD`]).
    ///
    /// # Errors
    /// Propagates the error returned by the channel's `send_data`.
    pub async fn send_data(&self, data: Bytes) -> Result<()> {
        self.channel.send_data(data, DATA_STANDARD).await
    }

    /// Sends EOF to the peer, signalling that this side will send no more data.
    ///
    /// # Errors
    /// Propagates the error returned by the channel's `send_eof`.
    pub async fn send_eof(&self) -> Result<()> {
        self.channel.send_eof().await
    }
}
/// Receiving half of a tunnel: yields data and EOF notifications from the
/// underlying channel.
#[derive(Debug)]
pub struct TunnelReceiver {
    channel_rx: ChannelReceiver,
}

/// An event received from the peer of a tunnel.
#[derive(Debug)]
#[non_exhaustive]
pub enum TunnelEvent {
    /// Bytes received on the standard data stream.
    Data(Bytes),
    /// The peer signalled EOF: it will send no more data.
    Eof,
}

impl TunnelReceiver {
    /// Waits for the next tunnel event.
    ///
    /// This is the `async` counterpart of [`poll_recv`](Self::poll_recv). It
    /// returns `Ok(None)` once the underlying channel receiver yields `None`.
    pub async fn recv(&mut self) -> Result<Option<TunnelEvent>> {
        /// Future that forwards every poll to `TunnelReceiver::poll_recv`.
        struct RecvFuture<'r> {
            receiver: &'r mut TunnelReceiver,
        }

        impl Future for RecvFuture<'_> {
            type Output = Result<Option<TunnelEvent>>;

            fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
                // Only a `&mut` is stored, so the future is `Unpin` and can be
                // unwrapped from the pin freely.
                self.get_mut().receiver.poll_recv(cx)
            }
        }

        RecvFuture { receiver: self }.await
    }

    /// Polls for the next tunnel event.
    ///
    /// Standard data and EOF are passed through. Data on any stream other than
    /// [`DATA_STANDARD`] and channel requests are skipped. `Ready(Ok(None))` means
    /// the channel receiver is exhausted. This function itself never produces
    /// an `Err`.
    pub fn poll_recv(&mut self, cx: &mut Context) -> Poll<Result<Option<TunnelEvent>>> {
        loop {
            let event = match ready!(self.channel_rx.poll_recv(cx)) {
                Some(event) => event,
                None => return Poll::Ready(Ok(None)),
            };
            let tunnel_event = match event {
                ChannelEvent::Data(data, DATA_STANDARD) => TunnelEvent::Data(data),
                ChannelEvent::Eof => TunnelEvent::Eof,
                // Not meaningful for a byte tunnel: wait for the next event.
                ChannelEvent::Data(_, _) | ChannelEvent::Request(_) => continue,
            };
            return Poll::Ready(Ok(Some(tunnel_event)));
        }
    }
}
/// Adapter that implements [`AsyncRead`] and [`AsyncBufRead`] on top of a
/// [`TunnelReceiver`].
///
/// Received data is buffered one chunk at a time. EOF from the peer, or the
/// receiver closing, is reported as end of stream.
#[derive(Debug)]
pub struct TunnelReader {
    tunnel_rx: TunnelReceiver,
    /// Unconsumed remainder of the most recently received data chunk.
    read_buf: Bytes,
    /// Set once EOF was received or the receiver closed; no more data follows.
    read_eof: bool,
}

impl TunnelReader {
    /// Wraps `tunnel_rx`, starting with an empty buffer.
    pub fn new(tunnel_rx: TunnelReceiver) -> Self {
        Self {
            tunnel_rx,
            read_buf: Bytes::new(),
            read_eof: false,
        }
    }

    /// Returns the underlying receiver.
    ///
    /// Any data still held in [`buffer`](Self::buffer) is dropped, so read it
    /// first if it matters.
    pub fn into_inner(self) -> TunnelReceiver {
        self.tunnel_rx
    }

    /// Returns the data that has been received but not yet consumed.
    pub fn buffer(&self) -> &Bytes {
        &self.read_buf
    }
}
/// Reads are served from the `AsyncBufRead` buffer: fill it if needed, copy as
/// much as fits into the caller's buffer, then consume exactly that much.
impl AsyncRead for TunnelReader {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let available = match self.as_mut().poll_fill_buf(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(chunk)) => chunk,
        };
        // `poll_fill_buf` only yields an empty chunk at end of stream. Copying
        // zero bytes then signals EOF to the caller.
        let copied = min(available.len(), buf.remaining());
        buf.put_slice(&available[..copied]);
        self.consume(copied);
        Poll::Ready(Ok(()))
    }
}
/// Buffers one received data chunk at a time.
impl AsyncBufRead for TunnelReader {
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<&[u8]>> {
        let reader = self.get_mut();
        // Pull events until there is data to hand out or the stream has ended.
        // An empty data chunk leaves the buffer empty, so it is just skipped.
        while reader.read_buf.is_empty() && !reader.read_eof {
            match ready!(reader.tunnel_rx.poll_recv(cx))? {
                Some(TunnelEvent::Data(data)) => reader.read_buf = data,
                // EOF and a closed receiver both end the stream for good.
                Some(TunnelEvent::Eof) | None => reader.read_eof = true,
            }
        }
        // Non-empty data, or an empty slice once EOF has been reached.
        Poll::Ready(Ok(&reader.read_buf[..]))
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        let reader = self.get_mut();
        // `Bytes::advance` panics if `amt` exceeds the buffered length.
        reader.read_buf.advance(amt);
    }
}
/// Adapter that implements [`AsyncWrite`] on top of a [`Tunnel`].
///
/// Writes are write-behind. `poll_write` starts sending a copy of the buffer
/// and reports the whole buffer as written even if the send is still in
/// flight. The next `poll_write`, `poll_flush` or `poll_shutdown` drives that
/// send to completion and reports any error it produced.
pub struct TunnelWriter {
    tunnel: Tunnel,
    /// Data send started by `poll_write` that has not completed yet.
    pending_write_fut: Option<BoxFuture<'static, Result<()>>>,
    /// EOF send started by `poll_shutdown` that has not completed yet.
    pending_shutdown_fut: Option<BoxFuture<'static, Result<()>>>,
    /// Set once EOF has been sent successfully, so repeated `poll_shutdown`
    /// calls do not send another EOF.
    shutdown_done: bool,
}

impl TunnelWriter {
    /// Creates a writer that sends data through `tunnel`.
    pub fn new(tunnel: Tunnel) -> Self {
        Self {
            tunnel,
            pending_write_fut: None,
            pending_shutdown_fut: None,
            shutdown_done: false,
        }
    }
}

impl AsyncWrite for TunnelWriter {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        // Only one send may be in flight. Finish the previous one first, which
        // also surfaces its error, if any.
        ready!(self.as_mut().poll_flush(cx))?;
        // Nothing to send: don't emit an empty data message. `buf.len()` would
        // be 0 anyway, and `Ok(0)` is a valid result for an empty buffer.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let this = self.get_mut();
        debug_assert!(this.pending_write_fut.is_none());
        let data = Bytes::copy_from_slice(buf);
        let tunnel = this.tunnel.clone();
        let mut write_fut = Box::pin(async move {
            tunnel.send_data(data).await
        });
        // Poll once right away so the send can complete without being parked.
        match write_fut.as_mut().poll(cx) {
            Poll::Ready(Ok(())) => {}
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
            // The future owns the copied data. Report it as written and finish
            // the send on the next write/flush/shutdown.
            Poll::Pending => this.pending_write_fut = Some(write_fut),
        }
        Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        if let Some(write_fut) = this.pending_write_fut.as_mut() {
            let res = ready!(write_fut.as_mut().poll(cx));
            // The future has completed and must not be polled again.
            this.pending_write_fut = None;
            res?;
        }
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Shutdown implies flush: pending data must go out before the EOF.
        ready!(self.as_mut().poll_flush(cx))?;
        let this = self.get_mut();
        if this.shutdown_done {
            // EOF was already sent; further shutdowns are no-ops.
            return Poll::Ready(Ok(()));
        }
        let tunnel = &this.tunnel;
        let shutdown_fut = this.pending_shutdown_fut.get_or_insert_with(|| {
            let tunnel = tunnel.clone();
            Box::pin(async move {
                tunnel.send_eof().await
            })
        });
        let res = ready!(shutdown_fut.as_mut().poll(cx));
        this.pending_shutdown_fut = None;
        // On error the flag stays unset, so a retry starts a new EOF send.
        res?;
        this.shutdown_done = true;
        Poll::Ready(Ok(()))
    }
}
/// A [`TunnelReader`] and [`TunnelWriter`] bundled together, so one value
/// implements [`AsyncRead`], [`AsyncBufRead`] and [`AsyncWrite`].
///
/// Both halves are public fields and can be borrowed or moved out separately.
pub struct TunnelStream {
    pub reader: TunnelReader,
    pub writer: TunnelWriter,
}

impl TunnelStream {
    /// Builds a stream from a tunnel's sending and receiving halves.
    pub fn new(tunnel: Tunnel, tunnel_rx: TunnelReceiver) -> Self {
        Self {
            reader: TunnelReader::new(tunnel_rx),
            writer: TunnelWriter::new(tunnel),
        }
    }
}
/// Reading from the stream reads from its `reader` half.
impl AsyncRead for TunnelStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let reader = &mut Pin::into_inner(self).reader;
        Pin::new(reader).poll_read(cx, buf)
    }
}
/// Buffered reading is delegated to the `reader` half.
impl AsyncBufRead for TunnelStream {
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<&[u8]>> {
        let reader = &mut Pin::into_inner(self).reader;
        Pin::new(reader).poll_fill_buf(cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        let reader = &mut Pin::into_inner(self).reader;
        Pin::new(reader).consume(amt)
    }
}
/// Writing, flushing and shutting down are delegated to the `writer` half.
impl AsyncWrite for TunnelStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let writer = &mut Pin::into_inner(self).writer;
        Pin::new(writer).poll_write(cx, buf)
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let writer = &mut Pin::into_inner(self).writer;
        Pin::new(writer).poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let writer = &mut Pin::into_inner(self).writer;
        Pin::new(writer).poll_shutdown(cx)
    }
}