use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_rustls::TlsAcceptor;
use siphon_protocol::ServerMessage;
use crate::router::Router;
use crate::state::ResponseRegistry;
/// HTTP(S) listener that forwards requests arriving at `<subdomain>.<base_domain>`
/// over the matching tunnel and relays the tunnel's reply back to the client.
pub struct HttpPlane {
    /// Resolves a subdomain to the sender of its tunnel (`get_sender`).
    router: Arc<Router>,
    /// Domain suffix the `Host` header must end with, separated from the
    /// subdomain by a `.`.
    base_domain: String,
    /// Source of per-request stream ids; initialised to 1 in `new`.
    stream_id_counter: AtomicU64,
    /// Pending requests keyed by stream id; each entry holds the oneshot sender
    /// through which that request's response is delivered.
    response_registry: ResponseRegistry,
    /// When `Some`, every accepted connection is TLS-wrapped before HTTP is served.
    tls_acceptor: Option<TlsAcceptor>,
}
impl HttpPlane {
pub fn new(
router: Arc<Router>,
base_domain: String,
response_registry: ResponseRegistry,
tls_acceptor: Option<TlsAcceptor>,
) -> Arc<Self> {
Arc::new(Self {
router,
base_domain,
stream_id_counter: AtomicU64::new(1),
response_registry,
tls_acceptor,
})
}
fn next_stream_id(&self) -> u64 {
self.stream_id_counter.fetch_add(1, Ordering::Relaxed)
}
async fn serve_connection<S>(self: Arc<Self>, stream: S, peer_addr: SocketAddr)
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let io = TokioIo::new(stream);
let service = service_fn(move |req| {
let this = self.clone();
async move { this.handle_request(req).await }
});
if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
tracing::debug!("HTTP connection error from {}: {}", peer_addr, e);
}
}
pub async fn run(self: Arc<Self>, addr: SocketAddr) -> Result<()> {
let listener = TcpListener::bind(addr).await?;
if self.tls_acceptor.is_some() {
tracing::info!("HTTPS plane listening on {}", addr);
} else {
tracing::info!("HTTP plane listening on {}", addr);
}
loop {
let (stream, peer_addr) = listener.accept().await?;
tracing::debug!("HTTP connection from {}", peer_addr);
let this = self.clone();
tokio::spawn(async move {
if let Some(ref acceptor) = this.tls_acceptor {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
this.serve_connection(tls_stream, peer_addr).await;
}
Err(e) => {
tracing::warn!("TLS handshake failed from {}: {}", peer_addr, e);
}
}
} else {
this.serve_connection(stream, peer_addr).await;
}
});
}
}
async fn handle_request(
self: Arc<Self>,
req: Request<Incoming>,
) -> Result<Response<Full<Bytes>>, Infallible> {
tracing::debug!(
"HTTP request: {} {} (Host: {:?})",
req.method(),
req.uri(),
req.headers().get("host")
);
let subdomain = match self.extract_subdomain(&req) {
Some(s) => s,
None => {
tracing::warn!("Request without valid subdomain");
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Invalid or missing subdomain")))
.unwrap());
}
};
tracing::debug!("Forwarding to tunnel: {}", subdomain);
let sender = match self.router.get_sender(&subdomain) {
Some(s) => s,
None => {
tracing::warn!("No tunnel for subdomain: {}", subdomain);
return Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from(format!(
"Tunnel not found for: {}",
subdomain
))))
.unwrap());
}
};
let stream_id = self.next_stream_id();
let method = req.method().to_string();
let uri = req.uri().to_string();
let headers: Vec<(String, String)> = req
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let body = match req.into_body().collect().await {
Ok(collected) => collected.to_bytes().to_vec(),
Err(e) => {
tracing::error!("Failed to read request body: {}", e);
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Full::new(Bytes::from("Failed to read request body")))
.unwrap());
}
};
let (response_tx, response_rx) = oneshot::channel();
self.response_registry.insert(stream_id, response_tx);
let msg = ServerMessage::HttpRequest {
stream_id,
method,
uri,
headers,
body,
};
if let Err(e) = sender.send(msg).await {
tracing::error!("Failed to send request to tunnel: {}", e);
self.response_registry.remove(&stream_id);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::new(Bytes::from("Tunnel connection lost")))
.unwrap());
}
let timeout = Duration::from_secs(30);
match tokio::time::timeout(timeout, response_rx).await {
Ok(Ok(response_data)) => {
let mut builder = Response::builder().status(response_data.status);
for (name, value) in response_data.headers {
builder = builder.header(name, value);
}
Ok(builder
.body(Full::new(Bytes::from(response_data.body)))
.unwrap())
}
Ok(Err(_)) => {
tracing::error!("Tunnel disconnected while waiting for response");
Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::new(Bytes::from("Tunnel disconnected")))
.unwrap())
}
Err(_) => {
tracing::error!("Timeout waiting for tunnel response");
self.response_registry.remove(&stream_id);
Ok(Response::builder()
.status(StatusCode::GATEWAY_TIMEOUT)
.body(Full::new(Bytes::from("Tunnel response timeout")))
.unwrap())
}
}
}
fn extract_subdomain(&self, req: &Request<Incoming>) -> Option<String> {
let host = req.headers().get("host")?.to_str().ok()?;
let host = host.split(':').next()?;
if !host.ends_with(&self.base_domain) {
return None;
}
let subdomain_part = host.strip_suffix(&format!(".{}", self.base_domain))?;
Some(subdomain_part.split('.').next()?.to_string())
}
}