use std::{
collections::HashMap,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{
Context,
Poll,
},
};
use async_trait::async_trait;
use futures::Stream;
#[cfg(feature = "hyper")]
use hyper::server::accept::Accept;
use muxado::{
typed::TypedStream,
Error as MuxadoError,
};
use thiserror::Error;
use tokio::{
io::{
AsyncRead,
AsyncWrite,
},
sync::mpsc::Receiver,
};
use crate::{
config::{
HttpTunnelBuilder,
LabeledTunnelBuilder,
TcpTunnelBuilder,
TlsTunnelBuilder,
},
internals::raw_session::RpcError,
session::ConnectError,
Session,
};
/// Error yielded by a tunnel's connection stream when a connection cannot be
/// accepted.
#[derive(Error, Debug, Clone)]
#[non_exhaustive]
pub enum AcceptError {
    /// The underlying muxado transport reported an error.
    #[error("transport error")]
    Transport(#[from] MuxadoError),
    /// A reconnect attempt failed; wraps the shared [`ConnectError`].
    #[error("reconnect error")]
    Reconnect(#[from] Arc<ConnectError>),
}
/// State shared by every public tunnel type generated by `make_tunnel_type!`,
/// each of which wraps one of these in its `inner` field.
///
/// Dropping a `TunnelInner` asks its session to close the tunnel (see the
/// `Drop` impl below).
// NOTE: fields are dropped in declaration order after `Drop::drop` runs, so
// `incoming` is released before `session`.
pub(crate) struct TunnelInner {
    // Identifier passed to `Session::close_tunnel`.
    pub(crate) id: String,
    // Protocol string, exposed through `ProtoTunnel::proto`.
    pub(crate) proto: String,
    // URL string, exposed through `UrlTunnel::url`.
    pub(crate) url: String,
    // Label map, exposed through `LabelsTunnel::labels`.
    pub(crate) labels: HashMap<String, String>,
    // Exposed through `Tunnel::forwards_to`.
    pub(crate) forwards_to: String,
    // Exposed through `Tunnel::metadata`.
    pub(crate) metadata: String,
    // Channel that accepted connections (or accept errors) arrive on; polled
    // by the `Stream` impl and closed by `close()`.
    pub(crate) incoming: Receiver<Result<Conn, AcceptError>>,
    // Session that owns this tunnel; used to close it remotely.
    pub(crate) session: Session,
}
impl Drop for TunnelInner {
    /// Schedules a best-effort `close_tunnel` call on the session's runtime.
    ///
    /// The spawned task is detached: its result is never awaited or inspected.
    /// This also runs after an explicit `close()`, so the tunnel may be closed
    /// a second time.
    fn drop(&mut self) {
        let tunnel_id = self.id.clone();
        let session = self.session.clone();
        let handle = session.runtime();
        handle.spawn(async move { session.close_tunnel(&tunnel_id).await });
    }
}
// Defines the public `Tunnel` trait. `$hyper_bound` lets the invocation splice
// extra supertrait bounds (the `hyper` feature adds `+ Accept<...>`) into the
// supertrait list, where a `#[cfg]` attribute cannot be placed directly.
macro_rules! tunnel_trait {
    ($($hyper_bound:tt)*) => {
        /// A tunnel that yields incoming [`Conn`]s as a [`Stream`].
        ///
        /// Each stream item is either an accepted connection or an
        /// [`AcceptError`].
        #[async_trait]
        pub trait Tunnel:
        Stream<Item = Result<Conn, AcceptError>>
        $($hyper_bound)*
        + Unpin
        + Send
        + 'static
        {
            /// Returns this tunnel's id.
            fn id(&self) -> &str;
            /// Returns this tunnel's `forwards_to` value.
            fn forwards_to(&self) -> &str;
            /// Returns this tunnel's metadata string.
            fn metadata(&self) -> &str;
            /// Closes the tunnel.
            ///
            /// # Errors
            ///
            /// Returns an [`RpcError`] if the close request fails.
            async fn close(&mut self) -> Result<(), RpcError>;
        }
    }
}
// With the `hyper` feature, every `Tunnel` must also be usable as a hyper
// `Accept` source yielding the same `Conn` / `AcceptError` pair.
#[cfg(feature = "hyper")]
tunnel_trait!(+ Accept<Conn = Conn, Error = AcceptError>);
// Without it, `Tunnel` carries only its base bounds.
#[cfg(not(feature = "hyper"))]
tunnel_trait!();
/// A [`Tunnel`] that exposes a URL.
pub trait UrlTunnel: Tunnel {
    /// Returns this tunnel's URL.
    fn url(&self) -> &str;
}
/// A [`Tunnel`] that exposes its protocol.
pub trait ProtoTunnel: Tunnel {
    /// Returns this tunnel's protocol string.
    fn proto(&self) -> &str;
}
/// A [`Tunnel`] that exposes a set of labels.
pub trait LabelsTunnel: Tunnel {
    /// Returns this tunnel's labels as a key/value map.
    fn labels(&self) -> &HashMap<String, String>;
}
/// A single connection accepted over a tunnel.
///
/// Reads and writes are delegated to the underlying muxado stream via this
/// type's [`AsyncRead`] and [`AsyncWrite`] impls.
pub struct Conn {
    // Remote address for this connection; returned by `remote_addr()`.
    pub(crate) remote_addr: SocketAddr,
    // Stream carrying the connection's bytes.
    pub(crate) stream: TypedStream,
}
impl Stream for TunnelInner {
    type Item = Result<Conn, AcceptError>;

    /// Yields the next item received on the `incoming` channel, or `None`
    /// once that channel is closed and drained.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.incoming.poll_recv(cx)
    }
}
#[cfg(feature = "hyper")]
impl Accept for TunnelInner {
    type Conn = Conn;
    type Error = AcceptError;

    /// hyper's accept hook: identical to polling this value as a [`Stream`].
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        <Self as Stream>::poll_next(self, cx)
    }
}
impl TunnelInner {
    /// Returns the tunnel id.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Returns the tunnel URL.
    pub fn url(&self) -> &str {
        self.url.as_str()
    }

    /// Asks the session to close this tunnel, then closes the local
    /// `incoming` channel.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Session::close_tunnel`. In that case the
    /// local channel is left open.
    pub async fn close(&mut self) -> Result<(), RpcError> {
        // Remote close first; the `?` skips the local close on failure.
        self.session.close_tunnel(&self.id).await?;
        self.incoming.close();
        Ok(())
    }

    /// Returns the tunnel protocol string.
    pub fn proto(&self) -> &str {
        self.proto.as_str()
    }

    /// Returns the tunnel labels.
    pub fn labels(&self) -> &HashMap<String, String> {
        &self.labels
    }

    /// Returns the tunnel's `forwards_to` value.
    pub fn forwards_to(&self) -> &str {
        self.forwards_to.as_str()
    }

    /// Returns the tunnel metadata string.
    pub fn metadata(&self) -> &str {
        self.metadata.as_str()
    }
}
impl Conn {
pub fn remote_addr(&self) -> SocketAddr {
self.remote_addr
}
}
impl AsyncRead for Conn {
    /// Reads by delegating to the stream behind `TypedStream`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let conn = self.get_mut();
        let inner = &mut *conn.stream;
        Pin::new(inner).poll_read(cx, buf)
    }
}
impl AsyncWrite for Conn {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut *self.stream).poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut *self.stream).poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut *self.stream).poll_shutdown(cx)
}
}
#[cfg(feature = "axum")]
use axum::extract::connect_info::Connected;
/// Lets axum derive a [`SocketAddr`] connection-info value from a [`Conn`],
/// using the connection's remote address.
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
#[cfg(feature = "axum")]
impl Connected<&Conn> for SocketAddr {
    fn connect_info(target: &Conn) -> Self {
        target.remote_addr()
    }
}
// Generates a public tunnel wrapper type around `TunnelInner`.
//
// Main arm: `make_tunnel_type!(<attrs> Wrapper, Builder, cap, cap, ...)`
//   - declares `pub struct Wrapper { inner: TunnelInner }` carrying `<attrs>`
//     (so `///` docs at the call site land on the generated struct);
//   - implements `Tunnel`, `Stream`, and (with the `hyper` feature) `Accept`
//     by delegating to `inner`;
//   - adds `Wrapper::builder(session)`, implemented as `Builder::from(session)`;
//   - re-invokes itself once per trailing capability (`url`, `proto`,
//     `labels`) to implement the matching capability trait.
macro_rules! make_tunnel_type {
    ($(#[$outer:meta])* $wrapper:ident, $builder:tt, $($m:tt),*) => {
        $(#[$outer])*
        pub struct $wrapper {
            pub(crate) inner: TunnelInner,
        }
        #[async_trait]
        impl Tunnel for $wrapper {
            fn id(&self) -> &str {
                self.inner.id()
            }
            async fn close(&mut self) -> Result<(), RpcError> {
                self.inner.close().await
            }
            fn forwards_to(&self) -> &str {
                self.inner.forwards_to()
            }
            fn metadata(&self) -> &str {
                self.inner.metadata()
            }
        }
        impl $wrapper {
            /// Creates a builder for this tunnel type from the given [`Session`].
            pub fn builder(session: Session) -> $builder {
                $builder::from(session)
            }
        }
        // One capability-trait impl per trailing identifier (arms below).
        $(
            make_tunnel_type!($m; $wrapper);
        )*
        impl Stream for $wrapper {
            type Item = Result<Conn, AcceptError>;
            fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                Pin::new(&mut self.inner).poll_next(cx)
            }
        }
        #[cfg(feature = "hyper")]
        #[cfg_attr(all(feature = "hyper", docsrs), doc(cfg(feature = "hyper")))]
        impl Accept for $wrapper {
            type Conn = Conn;
            type Error = AcceptError;
            fn poll_accept(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
                Pin::new(&mut self.inner).poll_accept(cx)
            }
        }
    };
    // Capability arm: `UrlTunnel`, delegating to `inner.url()`.
    (url; $wrapper:ty) => {
        impl UrlTunnel for $wrapper {
            fn url(&self) -> &str {
                self.inner.url()
            }
        }
    };
    // Capability arm: `ProtoTunnel`, delegating to `inner.proto()`.
    (proto; $wrapper:ty) => {
        impl ProtoTunnel for $wrapper {
            fn proto(&self) -> &str {
                self.inner.proto()
            }
        }
    };
    // Capability arm: `LabelsTunnel`, delegating to `inner.labels()`.
    (labels; $wrapper:ty) => {
        impl LabelsTunnel for $wrapper {
            fn labels(&self) -> &HashMap<String, String> {
                self.inner.labels()
            }
        }
    };
}
make_tunnel_type! {
    /// A tunnel created via [`HttpTunnelBuilder`].
    ///
    /// Yields incoming [`Conn`]s as a [`Stream`] and exposes its URL and
    /// protocol through [`UrlTunnel`] and [`ProtoTunnel`].
    HttpTunnel, HttpTunnelBuilder, url, proto
}
make_tunnel_type! {
    /// A tunnel created via [`TcpTunnelBuilder`].
    ///
    /// Yields incoming [`Conn`]s as a [`Stream`] and exposes its URL and
    /// protocol through [`UrlTunnel`] and [`ProtoTunnel`].
    TcpTunnel, TcpTunnelBuilder, url, proto
}
make_tunnel_type! {
    /// A tunnel created via [`TlsTunnelBuilder`].
    ///
    /// Yields incoming [`Conn`]s as a [`Stream`] and exposes its URL and
    /// protocol through [`UrlTunnel`] and [`ProtoTunnel`].
    TlsTunnel, TlsTunnelBuilder, url, proto
}
make_tunnel_type! {
    /// A tunnel created via [`LabeledTunnelBuilder`].
    ///
    /// Yields incoming [`Conn`]s as a [`Stream`] and exposes its labels
    /// through [`LabelsTunnel`].
    LabeledTunnel, LabeledTunnelBuilder, labels
}