use snafu::{ResultExt, Snafu};
use std::net::SocketAddr;
use std::sync::mpsc::Sender;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use uuid::Uuid;
use crate::session::events::*;
use crate::session::*;
use crate::{CADetails, ConnectionOptions};
mod connect;
mod demux;
mod http2;
mod stream;
mod tls;
/// Low-level causes of a configuration failure.
///
/// Values of this type are carried as the `source` of
/// [`Error::ConfigurationError`].
//
// NOTE(review): the per-variant notes below are deliberately plain `//`
// comments. snafu can use a variant's `///` doc comment as its `Display` text
// when no `#[snafu(display(...))]` attribute is present, so converting these
// to doc comments could change the rendered error messages.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
#[allow(clippy::enum_variant_names)]
pub enum ConfigurationErrorKind
{
// Wraps the rustls error for an invalid DNS name.
DNSError
{
source: rustls::client::InvalidDnsNameError,
},
// Wraps the `http` crate error for an invalid URI.
UriError
{
source: http::uri::InvalidUri,
},
// Wraps the `http` crate error for invalid URI parts.
UriPartsError
{
source: http::uri::InvalidUriParts,
},
// Carries no underlying error.
// NOTE(review): presumably used when a configuration problem has no
// lower-level cause to attach - confirm against the call sites.
NoSource {},
}
/// Failure while talking to one endpoint of a proxied connection.
///
/// Carried as the `source` of [`Error::ServerError`] and
/// [`Error::ClientError`], which record which side and which scenario failed.
//
// NOTE(review): the per-variant notes below are deliberately plain `//`
// comments. snafu can use a variant's `///` doc comment as its `Display` text
// when no `#[snafu(display(...))]` attribute is present, so converting these
// to doc comments could change the rendered error messages.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
#[allow(clippy::enum_variant_names)]
pub enum EndpointError
{
// Wraps a standard I/O error (e.g. from stream demuxing or TCP connect in
// `connect_phase`).
IoError
{
source: std::io::Error
},
// Wraps an `httparse` parsing error.
ConnectError
{
source: httparse::Error
},
// Wraps an `h2` (HTTP/2) error.
H2Error
{
source: h2::Error
},
// Wraps a rustls error.
TlsError
{
source: rustls::Error
},
// Error raised by the proxy itself rather than by an underlying library;
// displays `reason` verbatim. `connect_phase` uses it to reject CONNECT or
// direct connections that the options do not allow.
#[snafu(display("{}", reason))]
ProxideError
{
reason: &'static str
},
}
/// Top-level error type for connection handling; the default error of this
/// module's [`Result`] alias.
//
// The per-variant notes are plain `//` comments so that they can never
// interact with the `Display` text snafu derives for the variants.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
#[allow(clippy::enum_variant_names)]
pub enum Error
{
// Configuration problem. `reason` is the human-readable description shown
// in the message; `source` holds the underlying cause.
#[snafu(display("Configuration error: {}", reason))]
ConfigurationError
{
reason: &'static str,
source: ConfigurationErrorKind,
},
// Failure on the server side of the connection. `scenario` names what was
// being attempted (e.g. "connecting").
#[snafu(display("Error occurred with the server in {}: {}", scenario, source))]
ServerError
{
scenario: &'static str,
source: EndpointError,
},
// Failure on the client side of the connection. `scenario` names what was
// being attempted (e.g. "demuxing stream").
#[snafu(display("Error occurred with the client in {}: {}", scenario, source))]
ClientError
{
scenario: &'static str,
source: EndpointError,
},
}
/// Result alias for this module; the error type defaults to [`Error`].
pub type Result<S, E = Error> = std::result::Result<S, E>;
/// Metadata describing a single proxied connection.
///
/// Created by `run` and passed by value through the handling phases.
pub struct ConnectionDetails
{
/// Identifier of the connection. `run` assigns a fresh `Uuid::new_v4()`, and
/// the log lines in this file are prefixed with it.
pub uuid: Uuid,
/// Protocol layers recorded for the connection so far. Starts empty;
/// `connect_phase` pushes `Protocol::Connect` once a CONNECT request is
/// accepted.
// NOTE(review): `tls::handle` receives `&mut ConnectionDetails` and
// presumably records further layers - confirm in that module.
pub protocol_stack: Vec<Protocol>,
/// Set by `connect_phase` to the configured target server when a
/// non-CONNECT connection is forwarded directly; left as `None` by the
/// CONNECT path in this file.
pub opaque_redirect: Option<String>,
}
/// The two ends of a proxied connection, bundled so they travel together.
///
/// The type parameters carry no bounds here; the functions that consume a
/// `Streams` value state whatever I/O traits they need.
pub struct Streams<TClient, TServer>
{
    /// Stream facing the client.
    pub client: TClient,
    /// Stream facing the server.
    pub server: TServer,
}

impl<TClient, TServer> Streams<TClient, TServer>
{
    /// Bundles a client-side and a server-side stream into one value.
    pub fn new(client: TClient, server: TServer) -> Self
    {
        Streams { client, server }
    }
}
/// Starts handling a client connection.
///
/// Builds a fresh [`ConnectionDetails`] (random v4 UUID, empty protocol
/// stack, no redirect) and delegates everything else to [`connect_phase`].
///
/// # Errors
///
/// Returns whatever error [`connect_phase`] reports.
pub async fn run(
    client: TcpStream,
    src_addr: SocketAddr,
    options: Arc<ConnectionOptions>,
    ui: Sender<SessionEvent>,
) -> Result<()>
{
    connect_phase(
        ConnectionDetails {
            uuid: Uuid::new_v4(),
            protocol_stack: Vec::new(),
            opaque_redirect: None,
        },
        client,
        src_addr,
        options,
        ui,
    )
    .await
}
pub async fn connect_phase(
mut details: ConnectionDetails,
client: TcpStream,
src_addr: SocketAddr,
options: Arc<ConnectionOptions>,
ui: Sender<SessionEvent>,
) -> Result<()>
{
log::info!("{} - New connection from {:?}", details.uuid, src_addr);
let (protocol, client) =
demux::recognize(client)
.await
.context(IoError {})
.context(ClientError {
scenario: "demuxing stream",
})?;
log::debug!("{} - Top level protocol: {:?}", details.uuid, protocol);
if protocol == demux::Protocol::Connect {
let connect_filter = match &options.proxy {
Some(f) => f,
None => {
return Err(EndpointError::ProxideError {
reason: "CONNECT proxy requests are not allowed",
})
.context(ClientError {
scenario: "setting up server connection",
})
}
};
details.protocol_stack.push(Protocol::Connect);
let connect_data = connect::handle_connect(client).await?;
if connect::check_filter(connect_filter, &connect_data.target_server) {
log::info!("{} - Intercepting CONNECT", details.uuid);
let (protocol, client_stream) = demux::recognize(connect_data.client_stream)
.await
.context(IoError {})
.context(ClientError {
scenario: "demuxing stream",
})?;
log::debug!("{} - Next protocol: {:?}", details.uuid, protocol);
handle_protocol(
details,
protocol,
Streams::new(client_stream, connect_data.server_stream),
src_addr,
connect_data.target_server,
options,
ui,
)
.await
} else {
log::info!("{} - Proxying CONNECT without decoding", details.uuid);
let (server_read, server_write) = connect_data.server_stream.into_split();
let (client_read, client_write) = connect_data.client_stream.into_split();
pipe_stream(client_read, server_write);
pipe_stream(server_read, client_write);
Ok(())
}
} else {
let target_server = match &options.target_server {
Some(t) => t,
None => {
return Err(EndpointError::ProxideError {
reason: "Direct connections are not allowed",
})
.context(ClientError {
scenario: "setting up server connection",
})
}
};
details.opaque_redirect = Some(target_server.to_string());
log::trace!("Connecting directly to {}", target_server);
let server = TcpStream::connect(target_server)
.await
.context(IoError {})
.context(ServerError {
scenario: "connecting",
})?;
handle_protocol(
details,
protocol,
Streams::new(client, server),
src_addr,
target_server.to_string(),
options,
ui,
)
.await
}
}
/// Runs the protocol handlers for an established client/server stream pair.
///
/// When `protocol` is TLS, the streams are first passed through
/// `tls::handle` (together with `options` and `target`) and the streams it
/// returns are given to `http2::handle`. For any other protocol the streams
/// go to `http2::handle` unchanged, and `target` and `options` are not used.
///
/// # Parameters
///
/// * `details` - connection metadata; `tls::handle` may update it in place.
/// * `protocol` - protocol detected on the client stream.
/// * `streams` - the client and server ends of the connection.
/// * `src_addr` - address of the connecting client.
/// * `target` - server the client wants to reach.
/// * `options` - shared proxy configuration.
/// * `ui` - channel for reporting session events.
///
/// # Errors
///
/// Propagates any error returned by `tls::handle` or `http2::handle`.
pub async fn handle_protocol<TClient, TServer>(
    mut details: ConnectionDetails,
    protocol: demux::Protocol,
    streams: Streams<TClient, TServer>,
    src_addr: SocketAddr,
    target: String,
    options: Arc<ConnectionOptions>,
    ui: Sender<SessionEvent>,
) -> Result<()>
where
    TClient: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    TServer: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    // `options` and `ui` are owned and not needed after these calls, so they
    // are moved into the handlers rather than cloned.
    //
    // The two `http2::handle` calls stay separate because the stream type
    // differs between the branches.
    if protocol == demux::Protocol::Tls {
        let streams = tls::handle(&mut details, streams, options, target).await?;
        http2::handle(details, src_addr, streams, ui).await?;
    } else {
        http2::handle(details, src_addr, streams, ui).await?;
    }
    Ok(())
}
/// Spawns a detached task that copies bytes from `read` to `write`.
///
/// The task runs until `read` reports end-of-stream (a zero-length read) or
/// until a read or write fails. Failures are logged and end the copy; they
/// are not reported to the caller, since nothing awaits the spawned task.
/// Both halves are dropped when the task finishes.
fn pipe_stream<TRead, TWrite>(mut read: TRead, mut write: TWrite)
where
    TRead: AsyncRead + Unpin + Send + 'static,
    TWrite: AsyncWrite + Unpin + Send + 'static,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    tokio::spawn(async move {
        let mut b = [0_u8; 1024];
        log::info!("Enter");
        loop {
            let count = match read.read(&mut b).await {
                Err(e) => {
                    log::error!("Error reading data: {}", e);
                    break;
                }
                // A zero-length read means the peer closed its side.
                Ok(0) => break,
                Ok(c) => c,
            };

            // A single `write` may accept only part of the buffer. Use
            // `write_all` so no bytes are dropped when the writer is
            // temporarily unable to take the whole chunk.
            if let Err(e) = write.write_all(&b[..count]).await {
                log::error!("Error writing data: {}", e);
                break;
            }
        }
        log::info!("Exit");
    });
}