use hyper::Version;
use hyper_rustls::ConfigBuilderExt;
use serde::Serialize;
use std::fmt::Debug;
use std::net::IpAddr;
use std::{
net::SocketAddr,
sync::Arc,
};
use crate::configuration::BackendFilter;
use crate::global_state::GlobalState;
use crate::tcp_proxy::Peekable;
use crate::types::proc_info::ProcId;
use tokio::net::TcpStream;
use tracing::*;
use super::{ManagedStream, GenericManagedStream};
/// A site that incoming connections can be proxied to through a raw TCP tunnel.
///
/// A backend can only be selected for a tunnel when exactly one of `remote_target_config`
/// and `hosted_target_config` is set; any other combination yields no backend.
#[derive(Debug, Eq, PartialEq, Hash, Clone, Serialize)]
pub struct ReverseTcpProxyTarget {
    /// Remote-site configuration; used to pick the next backend when set.
    pub remote_target_config: Option<crate::configuration::RemoteSiteConfig>,
    /// In-process site configuration; used to pick the next backend when set.
    pub hosted_target_config: Option<crate::configuration::InProcessSiteConfig>,
    /// Backends of this site. If any of them has `https` enabled, incoming TLS may be passed
    /// through untouched instead of being terminated by the proxy.
    pub backends: Vec<crate::configuration::Backend>,
    /// Host name of the site; also the key for per-host connection statistics.
    pub host_name: String,
    /// Whether the site is hosted (presumably a process managed by this proxy — confirm).
    pub is_hosted: bool,
    /// NOTE(review): assumed from the name to mean subdomains of `host_name` route here — confirm.
    pub capture_subdomains: bool,
    /// NOTE(review): semantics not visible in this file — confirm against the configuration docs.
    pub forward_wildcard: bool,
    /// Subdomain associated with this target, if any (presumably the requested one — confirm).
    pub sub_domain: Option<String>,
    /// When set, incoming TLS is always terminated by the proxy rather than passed through.
    pub disable_tcp_tunnel_mode: bool,
    /// Id of the process serving this site, if any (assumed from the type — confirm).
    pub proc_id: Option<ProcId>,
}
/// Whether the bytes arriving from a client form a TLS stream or plain cleartext.
#[derive(Debug, Eq, PartialEq)]
pub enum DataType {
    /// The incoming byte stream is TLS.
    TLS,
    /// The incoming byte stream is not TLS.
    ClearText,
}
/// What is known about a client connection from peeking at its first bytes.
#[derive(Debug)]
pub struct PeekResult {
    /// Whether the peeked byte stream is TLS or cleartext.
    pub typ: DataType,
    /// HTTP version detected for the connection, if any.
    // NOTE(review): this field is read when a peek result is mapped to a backend filter, so
    // this lint allowance may be stale — confirm before removing it.
    #[allow(dead_code)]
    pub http_version: Option<Version>,
    /// Requested host name (Host header or SNI), if known.
    pub target_host: Option<String>,
    /// Whether the client asked for an h2c upgrade.
    pub is_h2c_upgrade: bool,
}
/// Ways in which peeking at a client stream can fail.
// Some variants may not be constructed anywhere yet; the allowance silences the unused lint.
#[allow(dead_code)]
#[derive(Debug)]
pub enum PeekError {
    /// The stream was closed.
    StreamIsClosed,
    /// Any other failure, with a description.
    Unknown(String),
    /// The connection uses HTTP/2 prior knowledge and must be terminated rather than tunnelled
    /// (meaning assumed from the name — confirm).
    H2PriorKnowledgeNeedsToBeTerminated,
}
impl ReverseTcpProxyTarget {
    /// Returns `true` when `target` is an acceptable DNS name or parses as an IP address.
    // Allowed because the function may currently have no callers.
    #[allow(dead_code)]
    fn is_valid_ip_or_dns(target: &str) -> bool {
        let looks_like_dns_name = webpki::DnsNameRef::try_from_ascii_str(target).is_ok();
        // Only attempt the IP parse when the DNS check failed (same short-circuit as before).
        looks_like_dns_name || target.parse::<IpAddr>().is_ok()
    }
}
/// Namespace for the raw TCP tunnelling entry points.
// Its associated functions take no `self`, so the type itself may never be constructed;
// the allowance silences the resulting unused lint.
#[allow(dead_code)]
pub struct ReverseTcpProxy {
    /// Socket address associated with the proxy (presumably its listen address — confirm).
    pub socket_addr: SocketAddr,
}
/// Failures reported by [`ReverseTcpProxy::tunnel`].
#[derive(Debug)]
pub enum TunnelError {
    /// No backend can take the connection. The client stream is handed back so the caller can
    /// presumably handle it some other way — confirm at the call sites.
    NoUsableBackendFound(GenericManagedStream),
    /// Any other failure, with a description.
    Unknown(String),
}

impl std::error::Error for TunnelError {}

impl std::fmt::Display for TunnelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoUsableBackendFound(_) => {
                f.write_str("No usable backend found for incoming traffic")
            }
            Self::Unknown(reason) => write!(f, "Unknown error: {reason}"),
        }
    }
}
impl ReverseTcpProxy {
pub fn get_subdomain(requested_hostname: &str, backend_hostname: &str) -> Option<String> {
if requested_hostname == backend_hostname { return None };
if requested_hostname.to_uppercase().ends_with(&backend_hostname.to_uppercase()) {
let part_to_remove_len = backend_hostname.len();
let start_index = requested_hostname.len() - part_to_remove_len;
if start_index == 0 || requested_hostname.as_bytes()[start_index - 1] == b'.' {
return Some(requested_hostname[..start_index].trim_end_matches('.').to_string());
}
}
None
}
pub async fn tunnel(
client_tcp_stream: GenericManagedStream,
target:Arc<ReverseTcpProxyTarget>,
incoming_traffic_is_tls:bool,
state: Arc<GlobalState>,
client_address: SocketAddr,
rustls_config : Option<Arc<rustls::ServerConfig>>,
incoming_host_header_or_sni: String,
http_version: Option<Version>,
is_h2c_upgrade_request: bool
) -> anyhow::Result<(),TunnelError> {
let terminate_incoming = if target.disable_tcp_tunnel_mode {
incoming_traffic_is_tls
} else {
if incoming_traffic_is_tls {
let at_least_one_backend_is_tls= target.backends.iter().any(|x|x.https.unwrap_or_default());
if at_least_one_backend_is_tls {
false
} else {
true
}
} else {
false
}
};
let (client_tls_is_terminated,possibly_terminated_stream,backend_filter) = if terminate_incoming {
let tls_cfg = if let Some(cfg) = rustls_config {
cfg
} else {
return Err(TunnelError::Unknown("TLS termination is required but no rustls config provided. Stream cannot be processed further.".into()))
};
match client_tcp_stream {
GenericManagedStream::TCP(peekable_tcp_stream) => {
let tls_acceptor = TlsAcceptor::from(tls_cfg.clone());
match tls_acceptor.accept(peekable_tcp_stream).await {
Ok(mut tls_stream) => {
tracing::trace!("Terminated TLS connection established!");
tls_stream.get_mut().0.is_tls_terminated = true;
tls_stream.get_mut().0.events.push("Terminated TLS prior to running bidirectional TCP tunnel".into());
let mut gen_stream = GenericManagedStream::from_terminated_tls_stream(ManagedStream::from_tls_stream(tls_stream));
let peek_result = gen_stream.peek_managed_stream(client_address).await;
gen_stream.seal();
match peek_result {
Ok(r) => {
let backend_filter = peekresult_to_backend_filter(r,true,is_h2c_upgrade_request);
(true,gen_stream,backend_filter)
},
Err(e) => {
return Err(TunnelError::Unknown(format!("error peeking stream {e:?}")));
},
}
},
Err(e) => {
tracing::warn!("Accept_tcp_stream_via_tls_terminating_proxy_service failed with error: {e:?}");
return Ok(())
}
}
},
GenericManagedStream::TerminatedTLS(_managed_stream) => {
tracing::warn!("Wormhole was already spawned.. this is a bug.");
return Ok(())
},
}
} else {
(false,client_tcp_stream,peekresult_to_backend_filter(
PeekResult {
typ: if incoming_traffic_is_tls { DataType::TLS } else { DataType::ClearText },
http_version,
target_host: Some(incoming_host_header_or_sni.clone()),
is_h2c_upgrade: is_h2c_upgrade_request
},false,is_h2c_upgrade_request
))
};
let backend_filter = if let Some(f) = backend_filter {
f
} else {
tracing::warn!("failed to generate a backend filter.. falling back to http termination");
return Err(TunnelError::NoUsableBackendFound(possibly_terminated_stream))
};
let backend =
match (&target.remote_target_config,&target.hosted_target_config) {
(Some(rem_conf),None) => {
rem_conf.next_backend(&state, backend_filter.clone()).await
},
(None,Some(proc_conf)) => {
proc_conf.next_backend(&state, backend_filter.clone()).await
},
_ => None
};
let backend = if backend == None {
tracing::warn!("No backend found for target {} using filter {backend_filter:?}.. falling back to http termination",target.host_name);
return Err(TunnelError::NoUsableBackendFound(possibly_terminated_stream))
} else {
backend.unwrap()
};
let resolved_address = format!("{}:{}",backend.address,backend.port);
let server_name_for_tls = Some(backend.address.clone());
let backend_is_tls = backend.https.unwrap_or_default();
let erect_tls_tunnel_to_backend = {
match (incoming_traffic_is_tls,backend_is_tls,client_tls_is_terminated) {
(_,false,_) => false, (true,true,true) => true, (true,true,false) => false, (false,true,_) => true, }
};
match TcpStream::connect(resolved_address.clone()).await {
Ok(rem_stream) => {
match state.app_state.statistics.connections_per_hostname.get_mut(&target.host_name) {
Some(mut guard) => {
let (_k,v) = guard.pair_mut();
v.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
},
None => {
state.app_state.statistics.connections_per_hostname
.insert(target.host_name.clone(), std::sync::atomic::AtomicUsize::new(1));
}
};
if let Ok(_target_addr_socket) = rem_stream.peer_addr() {
possibly_terminated_stream.update_tracked_info(|x|{
x.backend = Some(backend);
x.target = Some(target.as_ref().to_owned());
x.incoming_connection_uses_tls = incoming_traffic_is_tls;
x.outgoing_connection_is_tls = backend_is_tls;
});
match run_managed_bidirectional_tunnel(
possibly_terminated_stream,
rem_stream,
backend_is_tls,
server_name_for_tls,
erect_tls_tunnel_to_backend,
incoming_traffic_is_tls
).await {
Ok(_) => {
},
Err(e) => {
tracing::warn!("Tunnel failed with error: {:?}",e);
}
}
} else {
tracing::warn!("failed to read socket peer address..");
}
},
Err(e) => warn!("failed to connect to target {host} (using addr: {resolved_address}) --> {e:?}",host=target.host_name),
}
Ok(())
}
}
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
use rustls::pki_types::ServerName;
/// Copies bytes in both directions between a client stream and a connected backend socket
/// until the copy ends. A TLS client session can optionally be opened on the backend socket
/// first.
///
/// * `original_client_stream`: the client side, either a raw TCP stream or one whose TLS
///   was terminated by the proxy.
/// * `stream_connected_to_some_backend`: the TCP connection to the chosen backend.
/// * `backend_is_tls` / `erect_tls_tunnel`: a TLS session towards the backend is started only
///   when both are `true`. Otherwise bytes go to the raw backend socket unchanged.
/// * `server_name`: the name used for the backend TLS handshake. It is required when a
///   session is started.
/// * `incoming_traffic_is_tls`: only affects trace logging.
///
/// I/O errors while copying are logged, not returned. Afterwards the client stream's
/// `inspect()` is awaited.
///
/// # Errors
/// Returns an error when a backend TLS session is required and either `server_name` is
/// missing or invalid, or the TLS handshake with the backend fails.
///
/// # Panics
/// Panics if the shared backend TLS client configuration cannot be built (native roots
/// unavailable) the first time it is needed.
async fn run_managed_bidirectional_tunnel(
    mut original_client_stream: GenericManagedStream,
    mut stream_connected_to_some_backend: TcpStream,
    backend_is_tls: bool,
    server_name: Option<String>,
    erect_tls_tunnel: bool,
    incoming_traffic_is_tls: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    if backend_is_tls && erect_tls_tunnel {
        let Some(server_name) = server_name else {
            return Err("no server name provided for tls connection".into());
        };
        let connector = TlsConnector::from(backend_tls_client_config());
        let server_name = match ServerName::try_from(server_name.clone()) {
            Ok(n) => n,
            Err(_) => {
                return Err(format!("failed to create server name from {}", server_name).into())
            }
        };
        let mut backend_tls_stream = connector
            .connect(server_name, stream_connected_to_some_backend)
            .await?;
        // A successful handshake is routine, so log at trace like the other tunnel events.
        tracing::trace!("New TLS connection established towards the backend");
        match &mut original_client_stream {
            GenericManagedStream::TerminatedTLS(peekable_tls_stream) => {
                copy_bidirectional_logging_errors(peekable_tls_stream, &mut backend_tls_stream)
                    .await;
                peekable_tls_stream.inspect().await;
            }
            GenericManagedStream::TCP(peekable_tcp_stream) => {
                tracing::trace!("Tunneling from cleartext to tls");
                copy_bidirectional_logging_errors(peekable_tcp_stream, &mut backend_tls_stream)
                    .await;
                peekable_tcp_stream.inspect().await;
            }
        }
    } else {
        match &mut original_client_stream {
            GenericManagedStream::TerminatedTLS(peekable_tls_stream) => {
                if backend_is_tls {
                    tracing::trace!("Unwrapped TLS tunnel established, forwarding inner byte stream to tls backend");
                } else {
                    tracing::trace!("Unwrapped TLS tunnel established, forwarding inner byte stream to cleartext backend");
                }
                copy_bidirectional_logging_errors(
                    peekable_tls_stream,
                    &mut stream_connected_to_some_backend,
                )
                .await;
                peekable_tls_stream.inspect().await;
            }
            GenericManagedStream::TCP(peekable_tcp_stream) => {
                if incoming_traffic_is_tls {
                    tracing::trace!("Raw TCP tunnel established: tls");
                } else {
                    tracing::trace!("Raw TCP tunnel established: cleartext");
                }
                copy_bidirectional_logging_errors(
                    peekable_tcp_stream,
                    &mut stream_connected_to_some_backend,
                )
                .await;
                peekable_tcp_stream.inspect().await;
            }
        }
    }
    Ok(())
}

/// Returns the TLS client configuration shared by all connections to TLS backends.
///
/// It is built once, on first use, with the platform's native roots (`with_native_roots`) and
/// reused afterwards. Previously it was rebuilt, and the root store reloaded, on every
/// connection.
fn backend_tls_client_config() -> Arc<tokio_rustls::rustls::ClientConfig> {
    static CONFIG: std::sync::OnceLock<Arc<tokio_rustls::rustls::ClientConfig>> =
        std::sync::OnceLock::new();
    CONFIG
        .get_or_init(|| {
            let config = tokio_rustls::rustls::ClientConfig::builder_with_protocol_versions(
                tokio_rustls::rustls::ALL_VERSIONS,
            )
            .with_native_roots()
            .expect("must be able to create tls configuration")
            .with_no_client_auth();
            Arc::new(config)
        })
        .clone()
}

/// Runs `tokio::io::copy_bidirectional` between two streams and logs, rather than returns,
/// any I/O error.
async fn copy_bidirectional_logging_errors<A, B>(a: &mut A, b: &mut B)
where
    A: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + ?Sized,
    B: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + ?Sized,
{
    if let Err(e) = tokio::io::copy_bidirectional(a, b).await {
        tracing::warn!("Stream failed with error: {:?}", e);
    }
}
fn peekresult_to_backend_filter(
info_about_incoming_data: PeekResult,
incoming_is_tls_terminated: bool,
is_h2c_upgrade_request: bool,
) -> Option<BackendFilter> {
use DataType::*;
match (
info_about_incoming_data.http_version,
info_about_incoming_data.typ,
incoming_is_tls_terminated,
is_h2c_upgrade_request
) {
(Some(Version::HTTP_2), ClearText, false, false) => Some(BackendFilter::H2CPriorKnowledge),
(Some(Version::HTTP_11), ClearText, false, true) => Some(BackendFilter::H2C),
(Some(Version::HTTP_2), TLS, false, false) => Some(BackendFilter::Http2),
(Some(Version::HTTP_10) | Some(Version::HTTP_11), _, _,false) => Some(BackendFilter::Http1),
(None,scheme,terminated,_) => {
let incoming_byte_stream_is_tls = scheme == DataType::TLS && !terminated;
if incoming_byte_stream_is_tls {
return Some(BackendFilter::AnyTLS)
}
tracing::warn!("Incoming data has no version info, but byte stream is not tls... something is fishy here..");
None
}
(Some(Version::HTTP_2), ClearText, true, false) => {
Some(BackendFilter::H2OrH2cpk)
},
(a,b,c,d) => {
tracing::warn!("Cannot determine backend filter for incoming data: HTTP Version: {:?}, Data Type: {:?}, TLS Terminated: {:?}, H2C Upgrade: {:?}",a,b,c,d);
None
},
}
}