use crate::mitm::{certificate_authority::CertificateAuthority, HttpContext, HttpHandler, RequestOrResponse, Rewind, rustls, WebSocketContext, WebSocketHandler};
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, FutureExt};
use hyper_util::{
client::legacy::{
connect::{Connect, HttpConnector},
Client,
},
rt::TokioExecutor,
server
};
use product_os_http::{
HeaderValue, StatusCode, Request, Response, Method, Uri,
uri::{Authority, Scheme}
};
use product_os_http_body::{
Body, BodyExt, BodyBytes, Full, BodyError, StreamBody
};
use product_os_request::{ProductOSClient, ProductOSRequest, ProductOSRequestClient, ProductOSRequester, ProductOSRequestError, ProductOSResponse};
use product_os_server::RequestParts;
use std::{
future::Future, net::SocketAddr, sync::Arc
};
use std::fmt::{Debug, Formatter};
use std::io::Read;
use std::time::Duration;
use bstr::ByteSlice;
use bytes::Bytes;
use hyper::body::{Frame, Incoming};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use hyper::service::service_fn;
use hyper_tungstenite::{ HyperWebsocket, HyperWebsocketStream };
use hyper_tungstenite::tungstenite::Error;
use hyper_tungstenite::tungstenite::error::ProtocolError;
use hyper_tungstenite::tungstenite::protocol::{Role, WebSocketConfig};
use hyper_util::server::conn::auto::Builder;
use product_os_utilities::ProductOSError;
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite},
task::JoinHandle,
net::TcpStream
};
use tokio_rustls::TlsAcceptor;
use tokio_tungstenite::{
tungstenite::{self, Message},
Connector, WebSocketStream,
};
#[cfg(feature = "tor")]
use arti_client::TorClient;
use hyper::http::uri::Port;
use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
use tracing::{Instrument, Span };
use crate::mitm::rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use crate::mitm::rustls::{DigitallySignedStruct, SignatureScheme};
/// Spawns `fut` on the Tokio runtime, instrumented with `span` so events emitted
/// while the task runs are attached to that span.
///
/// The output only needs `Send + 'static`, which is exactly what `tokio::spawn`
/// requires; the returned handle may be awaited or dropped (detaching the task).
fn spawn_with_trace<T: Send + 'static>(fut: impl Future<Output = T> + Send + 'static, span: Span) -> JoinHandle<T> {
    tokio::spawn(fut.instrument(span))
}
pub struct InternalProxy<C, CA, H, W> {
pub client_addr: SocketAddr,
pub ca: Arc<CA>,
pub client: Client<C, BodyBytes>,
pub websocket_connector: Option<Connector>,
pub http_handler: H,
pub websocket_handler: W,
pub server: Builder<TokioExecutor>,
pub custom_requester: Option<product_os_request::ProductOSRequestClient>,
pub compression: product_os_configuration::NetworkProxyCompression,
#[cfg(feature = "tor")]
pub tor_client: Option<arti_client::TorClient<tor_rtcompat::PreferredRuntime>>,
#[cfg(feature = "vpn")]
pub vpn_client: Option<product_os_vpn::ProductOSVPN>,
}
/// Field-wise clone; the certificate authority is shared through its `Arc`
/// instead of being copied.
impl<C, CA, H, W> Clone for InternalProxy<C, CA, H, W>
where
    C: Connect + Clone + Send + Sync + 'static,
    CA: CertificateAuthority,
    H: HttpHandler + Clone,
    W: WebSocketHandler + Clone
{
    fn clone(&self) -> Self {
        let Self {
            client_addr,
            ca,
            client,
            websocket_connector,
            http_handler,
            websocket_handler,
            custom_requester,
            compression,
            server,
            ..
        } = self;
        Self {
            client_addr: *client_addr,
            ca: Arc::clone(ca),
            client: client.clone(),
            websocket_connector: websocket_connector.clone(),
            http_handler: http_handler.clone(),
            websocket_handler: websocket_handler.clone(),
            custom_requester: custom_requester.clone(),
            compression: compression.clone(),
            server: server.clone(),
            #[cfg(feature = "tor")]
            tor_client: self.tor_client.clone(),
            #[cfg(feature = "vpn")]
            vpn_client: self.vpn_client.clone(),
        }
    }
}
impl<C, CA, H, W> InternalProxy<C, CA, H, W>
where
C: Connect + Clone + Send + Sync + 'static,
CA: CertificateAuthority,
H: HttpHandler + Clone,
W: WebSocketHandler + Clone
{
fn context(&self) -> HttpContext {
HttpContext {
client_addr: self.client_addr,
}
}
pub async fn proxy(mut self, req: Request<Incoming>) -> Result<Response<BodyBytes>, hyper::Error> {
let ctx = HttpContext {
client_addr: self.client_addr,
};
let mut req = match self
.http_handler
.handle_request(&ctx, req.map(BodyBytes::from))
.instrument(tracing::trace_span!("handle_request"))
.await
{
RequestOrResponse::Request(req) => req,
RequestOrResponse::Response(res) => return Ok(res),
};
if hyper_tungstenite::is_upgrade_request(&req) {
let res = self.upgrade_websocket(req);
Ok(res)
}
else if req.method() == Method::CONNECT {
match req.uri().authority().cloned() {
Some(authority) => {
self.process_connect(req, authority);
Ok(Response::new(BodyBytes::empty()))
}
None => Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(BodyBytes::empty()).unwrap())
}
}
else {
let (parts, body) = req.into_parts();
let body_bytes = body.map_frame(|b| {b}).collect().await.unwrap().to_bytes();
let request_original = Request::from_parts(parts.clone(), BodyBytes::from(Full::new(body_bytes.clone())));
let mut request = Request::from_parts(parts, BodyBytes::from(Full::new(body_bytes)));
request.headers_mut().remove("accept-encoding");
match self.compression {
product_os_configuration::NetworkProxyCompression::None => {
match HeaderValue::from_str("identity") {
Ok(value) => { request.headers_mut().insert("accept-encoding", value); }
Err(_) => {}
};
}
product_os_configuration::NetworkProxyCompression::Gzip => {
match HeaderValue::from_str("gzip") {
Ok(value) => { request.headers_mut().insert("accept-encoding", value); }
Err(_) => {}
};
}
product_os_configuration::NetworkProxyCompression::Brotli => {
match HeaderValue::from_str("br") {
Ok(value) => { request.headers_mut().insert("accept-encoding", value); }
Err(_) => {}
};
}
}
let mut res = Response::new(BodyBytes::empty());
#[cfg(not(feature = "tor"))]
{
tracing::debug!("Attempting request: {:?}", request);
match self.custom_requester {
None => {
res = match self.client
.request(normalize_request(request))
.await {
Ok(mut response) => {
tracing::debug!("Response being made: {:?}", response);
let (parts, body) = response.into_parts();
let body_bytes = body.map_frame(|b| {b}).collect().await.unwrap().to_bytes();
Response::from_parts(parts, BodyBytes::from(Full::new(body_bytes)))
}
Err(e) => {
tracing::error!("Error attempting request: {:?}", e);
Response::new(BodyBytes::empty())
}
};
}
Some(custom_requester) => {
tracing::info!("Request being made: {:?}", request);
res = match custom_requester.request_raw(request).await {
Ok(res) => {
let status = res.status();
let headers = res.get_headers();
match custom_requester.bytes(res).await {
Ok(bytes) => {
let mut response = Response::builder()
.status(status);
for (key, value) in headers {
response = response.header(key, value);
}
match response.body(BodyBytes::from(Full::new(bytes))) {
Ok(response) => {
tracing::info!("Response being made: {:?}", response);
response
},
Err(e) => {
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(BodyBytes::empty()).unwrap()
}
}
}
Err(e) => Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(BodyBytes::empty()).unwrap()
}
}
Err(e) => Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(BodyBytes::empty()).unwrap()
};
}
}
}
#[cfg(feature = "tor")]
{
res = match self.tor_client {
Some(tor_client) => {
let uri = request.uri();
let host = match uri.host() {
None => "unknown",
Some(host) => host
};
let port = match uri.scheme() {
None => {
match uri.port() {
None => 80,
Some(port) => {
match port.as_str() {
"443" => 443,
_ => 80
}
}
}
}
Some(scheme) => {
match scheme.as_str().to_uppercase().as_str() {
"HTTPS" => 443,
"HTTP" | _ => 80,
}
}
};
match tor_client.connect((host, port)).await {
Ok(stream) => {
if port == 443 {
tracing::info!("Attempting TLS connect: {:?}, port: {:?}", host, port);
let mut tls_connector_builder = rustls::ClientConfig::builder();
let tls_connector_config = tls_connector_builder
.dangerous()
.with_custom_certificate_verifier(Arc::new(CustomCertVerifier::new()))
.with_no_client_auth();
let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(tls_connector_config));
match tls_connector.connect(ServerName::try_from(host.to_owned()).unwrap(), stream).await {
Ok(stream) => {
tracing::info!("TLS connect success: {:?}, address: {:?}", host, port);
make_stream_request(normalize_request(request), stream).await
}
Err(e) => {
tracing::error!("Error attempting request via Tor: {:?}", e);
Response::new(BodyBytes::empty())
}
}
}
else {
make_stream_request(normalize_request(request), stream).await
}
}
Err(e) => {
tracing::error!("Error attempting to create connection via Tor: {:?}", e);
Response::new(BodyBytes::empty())
}
}
}
None => {
tracing::debug!("Attempting request via standard channel: {:?}", request);
match self.custom_requester {
None => {
match self.client
.request(normalize_request(request))
.await {
Ok(res) => {
let (parts, body) = res.into_parts();
let bytes = body.collect().await.unwrap().to_bytes();
let body_bytes = BodyBytes::new(bytes);
Response::from_parts(parts, body_bytes)
}
Err(e) => {
tracing::error!("Error attempting request via Vpn: {:?}", e);
Response::new(BodyBytes::empty())
}
}
}
Some(custom_requester) => {
match custom_requester.request_raw(request).await {
Ok(res) => {
let status = res.status();
let headers = res.get_headers();
match custom_requester.bytes(res).await {
Ok(bytes) => {
let mut response = Response::builder()
.status(status);
for (key, value) in headers {
response = response.header(key, value);
}
match response.body(BodyBytes::from(Full::new(bytes))) {
Ok(response) => response,
Err(e) => {
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(BodyBytes::empty()).unwrap()
}
}
}
Err(e) => Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(BodyBytes::empty()).unwrap()
}
}
Err(e) => Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(BodyBytes::empty()).unwrap()
}
}
}
}
}
}
Ok(self
.http_handler
.handle_response(&ctx, request_original, res)
.instrument(tracing::trace_span!("handle_response"))
.await)
}
}
async fn upgrade_connect(&self, req: &mut Request<BodyBytes>) -> Result<(Rewind<TokioIo<Upgraded>>, [u8; 1024], usize, [u8; 4]), ProductOSError> {
let uri = req.uri().to_owned();
let path = uri.path();
let authority = match uri.authority() {
None => Authority::from_static(""),
Some(a) => a.to_owned()
};
tracing::debug!("Connect request {:?}", req);
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
let mut upgraded = TokioIo::new(upgraded);
let mut buffer = [0; 1024];
let mut bytes_read = match upgraded.read(&mut buffer).await {
Ok(bytes_read) => bytes_read,
Err(e) => {
tracing::error!("Failed to read from upgraded connection: {}", e);
return Err(ProductOSError::GenericError(e.to_string()))
}
};
tracing::debug!("Connect data {:02X?}", &buffer[..bytes_read]);
let first_bytes = [buffer[0], buffer[1], buffer[2], buffer[3]];
if buffer[..1] == *b"\x81" || buffer[..1] == *b"\x82" || buffer[..1] == *b"\x88" {
}
let upgraded = Rewind::new(
upgraded,
Bytes::copy_from_slice(buffer[..bytes_read].as_ref()),
);
Ok((upgraded, buffer, bytes_read, first_bytes))
}
Err(e) => {
tracing::error!("Upgrade error: {}", e);
Err(ProductOSError::GenericError(e.to_string()))
}
}
}
fn process_connect(mut self, mut req: Request<BodyBytes>, authority: Authority) {
let span = tracing::trace_span!("process_connect");
let fut = async move {
match self.upgrade_connect(&mut req).await {
Ok((mut upgraded, buffer, data_len, first_bytes)) => {
if self
.http_handler
.should_intercept(&self.context(), &req)
.await {
if first_bytes[..4] == *b"GET " {
match self.serve_stream(TokioIo::new(upgraded), Scheme::HTTP, authority).await {
Ok(_) => {}
Err(e) => {
tracing::error!("WebSocket connect error: {}", e);
}
}
return;
}
else if first_bytes[..2] == *b"\x16\x03" {
let server_config = self
.ca
.gen_server_config(&authority)
.instrument(tracing::info_span!("gen_server_config"))
.await;
let stream = match TlsAcceptor::from(server_config)
.accept(upgraded)
.await
{
Ok(stream) => TokioIo::new(stream),
Err(e) => {
tracing::error!("Failed to establish TLS connection: {}", e);
return;
}
};
if let Err(e) =
self.serve_stream(stream, Scheme::HTTPS, authority).await
{
if !e
.to_string()
.starts_with("error shutting down connection")
{
tracing::error!("HTTPS connect error: {}", e);
}
}
return;
}
else if first_bytes[..1] == *b"\x81" || buffer[..1] == *b"\x82" || buffer[..1] == *b"\x88" {
tracing::warn!(
"Websocket request: {:?} don't know yet what to do - request: {:02X?}",
req,
&buffer[..data_len]
);
return;
}
else {
tracing::warn!(
"Unknown protocol, read '{:02X?}' from upgraded connection",
&buffer[..data_len]
);
}
}
let mut server = match TcpStream::connect(authority.as_ref()).await {
Ok(server) => server,
Err(e) => {
tracing::error!("Failed to connect to {}: {}", authority, e);
return;
}
};
if let Err(e) = tokio::io::copy_bidirectional(&mut upgraded, &mut server).await {
tracing::error!("Failed to tunnel to {}: {}", authority, e);
}
}
Err(e) => {
tracing::error!("Problem upgrading request: {:?}", e);
return;
}
}
};
spawn_with_trace(fut, span);
}
fn upgrade_websocket(self, req: Request<BodyBytes>) -> Response<BodyBytes> {
let mut req = {
let (mut parts, _) = req.into_parts();
parts.uri = {
let mut parts = parts.uri.into_parts();
parts.scheme = if parts.scheme.unwrap_or(Scheme::HTTP) == Scheme::HTTP {
Some("ws".try_into().expect("Failed to convert scheme"))
} else {
Some("wss".try_into().expect("Failed to convert scheme"))
};
match Uri::from_parts(parts) {
Ok(uri) => uri,
Err(_) => {
return Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(BodyBytes::new(Bytes::new()))
.expect("Failed to build response")
}
}
};
Request::from_parts(parts, ())
};
match hyper_tungstenite::upgrade(&mut req, None) {
Ok((res, websocket)) => {
let span = tracing::info_span!("websocket");
let fut = async move {
match websocket.await {
Ok(ws) => {
if let Err(e) = self.handle_websocket(ws, req).await {
tracing::error!("Failed to handle WebSocket: {}", e);
}
}
Err(e) => {
tracing::error!("Failed to upgrade to WebSocket {:?}: {}", req.uri(), e);
}
}
};
spawn_with_trace(fut, span);
let (parts, body) = res.into_parts();
Response::from_parts(parts, BodyBytes::from(body))
}
Err(e) => Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(BodyBytes::new(Bytes::from(e.to_string())))
.expect("Failed to build response")
}
}
async fn handle_websocket(self, client_socket: WebSocketStream<TokioIo<Upgraded>>, req: Request<()>) -> Result<(), tungstenite::Error> {
let uri = req.uri().clone();
#[cfg(any(feature = "rustls_client"))]
let (server_socket, _) = tokio_tungstenite::connect_async_tls_with_config(
req,
None,
false,
self.websocket_connector,
)
.await?;
#[cfg(not(any(feature = "rustls_client")))]
let (server_socket, _) = tokio_tungstenite::connect_async(req).await?;
let (server_sink, server_stream) = server_socket.split();
let (client_sink, client_stream) = client_socket.split();
let InternalProxy {
websocket_handler, ..
} = self;
spawn_message_forwarder(
server_stream,
client_sink,
websocket_handler.clone(),
WebSocketContext::ServerToClient {
src: uri.clone(),
dst: self.client_addr,
},
);
spawn_message_forwarder(
client_stream,
server_sink,
websocket_handler,
WebSocketContext::ClientToServer {
src: self.client_addr,
dst: uri,
},
);
Ok(())
}
async fn serve_stream<I>(self, stream: I, scheme: Scheme, authority: Authority) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
{
let service = service_fn(|req | {
self.clone().proxy(req)
});
self.server.serve_connection_with_upgrades(stream, service).await
}
}
fn spawn_message_forwarder(
mut stream: impl Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send + 'static,
mut sink: impl Sink<Message, Error = tungstenite::Error> + Unpin + Send + 'static,
mut handler: impl WebSocketHandler,
ctx: WebSocketContext) {
let span = tracing::trace_span!("message_forwarder", context = ?ctx);
let fut = async move {
while let Some(message) = stream.next().await {
match message {
Ok(message) => {
let message = match handler.handle_message(&ctx, message).await {
Some(message) => message,
None => continue,
};
match sink.send(message).await {
Err(tungstenite::Error::ConnectionClosed) => (),
Err(e) => tracing::error!("Websocket send error: {:?}", e),
_ => (),
}
}
Err(e) => {
tracing::error!("Websocket message error: {:?}", e);
match sink.send(Message::Close(None)).await {
Err(tungstenite::Error::ConnectionClosed) => (),
Err(e) => tracing::error!("Websocket close error: {:?}", e),
_ => (),
};
break;
}
}
}
};
spawn_with_trace(fut, span);
}
/// Hook for adjusting a request just before it is sent upstream.
///
/// Currently a pass-through: the request is returned unchanged.
fn normalize_request<T>(req: Request<T>) -> Request<T> {
    req
}
/// TLS server-certificate verifier used for connections made through Tor.
///
/// SECURITY: its `ServerCertVerifier` implementation (built with the `tor`
/// feature) accepts any server certificate without validation.
#[derive(Default)]
pub struct CustomCertVerifier {
}

impl CustomCertVerifier {
    /// Creates a verifier; the type carries no state.
    pub fn new() -> Self {
        Self {}
    }
}

impl Debug for CustomCertVerifier {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Formatting `self` with `{:?}` here would call this method again and
        // recurse until the stack overflows; print the type name instead.
        f.debug_struct("CustomCertVerifier").finish()
    }
}
/// Verifier that skips all server authentication for Tor TLS connections.
///
/// SECURITY: certificates and handshake signatures are all accepted without
/// any check, so the TLS peer is not authenticated. Signature checks alone
/// would add nothing, since any certificate is accepted. The previous
/// `todo!()` bodies panicked on every handshake because rustls calls
/// `supported_verify_schemes` and the signature checks while connecting.
#[cfg(feature = "tor")]
impl ServerCertVerifier for CustomCertVerifier {
    fn verify_server_cert(&self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>], _server_name: &ServerName<'_>, _ocsp_response: &[u8], _now: UnixTime) -> Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(&self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(&self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Signature schemes advertised to the server; covers the common RSA,
    /// ECDSA and EdDSA schemes so ordinary servers can complete the handshake.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        vec![
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::ECDSA_NISTP521_SHA512,
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::ED25519,
            SignatureScheme::ED448,
        ]
    }
}
/// Sends `request` over an already-connected Tor `stream` and buffers the reply.
///
/// NOTE(review): this always uses hyper's HTTP/2 client handshake. On plain
/// (non-TLS) streams that means HTTP/2 prior knowledge, which HTTP/1.1-only
/// servers will reject — confirm this is intended.
///
/// Handshake, send and body-read failures are reported as `502 Bad Gateway`
/// with the error text as the body. Non-data frames (trailers) are skipped.
#[cfg(feature = "tor")]
async fn make_stream_request(
    request: Request<BodyBytes>,
    stream: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
) -> Response<BodyBytes> {
    /// Wraps an error description in a `502 Bad Gateway` response.
    fn bad_gateway(message: String) -> Response<BodyBytes> {
        let mut response = Response::new(BodyBytes::new(Bytes::from(message)));
        *response.status_mut() = StatusCode::BAD_GATEWAY;
        response
    }

    let (mut request_sender, connection) =
        match hyper::client::conn::http2::handshake(TokioExecutor::new(), TokioIo::new(stream)).await {
            Ok(handshake) => handshake,
            Err(e) => {
                tracing::error!("Error generating tor request: {:?}", e);
                return bad_gateway(e.to_string());
            }
        };

    // The connection future drives the HTTP/2 session; it must be polled
    // concurrently with `send_request`.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            tracing::error!("Error while polling tor request: {:?}", e);
        }
    });

    let response = match request_sender
        .send_request(normalize_request(request))
        .await {
        Ok(response) => response,
        Err(e) => {
            tracing::error!("Error making tor request: {:?}", e);
            return bad_gateway(e.to_string());
        }
    };

    let (parts, mut response_body) = response.into_parts();
    let mut full_bytes = Vec::new();
    while let Some(frame) = response_body.frame().await {
        match frame {
            Ok(frame) => {
                if let Ok(data) = frame.into_data() {
                    full_bytes.extend_from_slice(&data);
                }
            }
            Err(e) => {
                // A truncated body must not be passed off as a complete response.
                tracing::error!("Problem getting data frame from tor: {:?}", e);
                return bad_gateway(e.to_string());
            }
        }
    }
    Response::from_parts(parts, BodyBytes::new(Bytes::from(full_bytes)))
}