use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Semaphore, watch};
use crate::error::NetworkError;
use crate::proto::http::{HttpRequest, HttpResponse};
/// Writes a `400 Bad Request` response carrying `Connection: close`, then calls `shutdown()` on
/// `stream`.
///
/// When `handler` is given and reports an encrypted session, the serialized response is passed
/// through [`ConnectionHandler::encrypt_outgoing`] before being written. Write and shutdown
/// errors are deliberately ignored: the connection is being abandoned either way.
async fn write_bad_request_and_close(
    stream: &mut TcpStream,
    handler: Option<&mut dyn ConnectionHandler>,
) {
    let mut reply = HttpResponse::new("RTSP/1.0", 400, "Bad Request");
    reply.add_header("Connection", "close");
    reply.finish(None);

    let plain = reply.get_data();
    // Only an encrypting handler transforms the bytes; otherwise send them as-is.
    let bytes: Vec<u8> = if let Some(h) = handler.filter(|h| h.is_encrypted()) {
        h.encrypt_outgoing(plain)
    } else {
        plain.to_vec()
    };

    let _ = stream.write_all(&bytes).await;
    let _ = stream.shutdown().await;
}
/// Where the server listens: the addresses to bind and the port to start from.
///
/// Build one with [`BindConfig::new`] and its builder methods, or set the public fields directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BindConfig {
    /// Addresses to bind. When empty, the server binds the IPv4 and IPv6 unspecified addresses
    /// (`0.0.0.0` and `::`).
    pub bind_addrs: Vec<IpAddr>,
    /// Port to bind, used unless `HttpServer::start` is given a non-zero port.
    pub port: u16,
    /// When `true`, higher ports (up to 100 above `port`) are tried if `port` is unavailable.
    /// Only the first address searches; the remaining addresses must bind the port it obtained.
    pub auto_port: bool,
}
impl Default for BindConfig {
    /// No explicit addresses (the server then binds all interfaces), starting at port 5000,
    /// with automatic fallback to higher ports enabled.
    fn default() -> Self {
        let (port, auto_port) = (5000, true);
        Self {
            bind_addrs: Vec::default(),
            port,
            auto_port,
        }
    }
}
impl BindConfig {
    /// Equivalent to [`BindConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the address list with `addrs`; an empty list means "all interfaces".
    pub fn addrs(self, addrs: impl IntoIterator<Item = IpAddr>) -> Self {
        let bind_addrs: Vec<IpAddr> = addrs.into_iter().collect();
        Self { bind_addrs, ..self }
    }

    /// Sets the port to bind (or to start searching from when `auto_port` is enabled).
    pub fn port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Disables the automatic port search: binding fails if the configured port is unavailable.
    pub fn exact_port(self) -> Self {
        Self {
            auto_port: false,
            ..self
        }
    }
}
/// Server-wide callbacks, shared by every listener and connection.
pub trait HttpdCallbacks: Send + Sync + 'static {
    /// Called once per accepted connection with its local and remote socket addresses.
    ///
    /// Returning `None` refuses the connection: it is dropped without reading from it.
    /// Returning a handler makes that handler serve every request on the connection.
    fn conn_init(
        &self,
        local: SocketAddr,
        remote: SocketAddr,
    ) -> Option<Box<dyn ConnectionHandler>>;
}
/// Per-connection request handler, optionally applying transport encryption.
///
/// The server calls [`ConnectionHandler::conn_request`] for each complete request and writes the
/// returned response. While [`ConnectionHandler::is_encrypted`] returns `true`, received bytes go
/// through `decrypt_incoming` and outgoing responses through `encrypt_outgoing`. The default
/// methods describe a plain, unencrypted connection.
pub trait ConnectionHandler: Send {
    /// Produces the response for one complete request.
    fn conn_request(&mut self, request: &HttpRequest) -> HttpResponse;

    /// Decrypts buffered ciphertext.
    ///
    /// Returns the plaintext and how many bytes of `data` were consumed; unconsumed bytes are
    /// offered again, with newly read bytes appended, on the next read. `None` reports a
    /// decryption failure, after which the connection is dropped. The default passes `data`
    /// through unchanged and consumes all of it.
    fn decrypt_incoming(&mut self, data: &[u8]) -> Option<(Vec<u8>, usize)> {
        let consumed = data.len();
        Some((data.to_owned(), consumed))
    }

    /// Encrypts a serialized response before it is written. The default returns a copy of `data`.
    fn encrypt_outgoing(&mut self, data: &[u8]) -> Vec<u8> {
        Vec::from(data)
    }

    /// Whether traffic is currently encrypted; consulted on every read and before every write.
    fn is_encrypted(&self) -> bool {
        false
    }

    /// Hook invoked after each response has been written successfully.
    fn after_response(&mut self) {
        // Nothing to do by default.
    }
}
/// Server for HTTP-style (RTSP) requests. Each accepted connection is handed to
/// [`HttpdCallbacks::conn_init`], which supplies its [`ConnectionHandler`].
pub struct HttpServer {
/// Shared callbacks; a clone goes to every accept loop.
callbacks: Arc<dyn HttpdCallbacks>,
/// Maximum number of concurrently served connections (permits of the connection semaphore).
max_connections: usize,
/// Shutdown signal sender while running; `stop` takes it and sends `true`.
shutdown_tx: Option<watch::Sender<bool>>,
/// Port bound by the last successful `start` (0 until then); not reset by `stop`.
port: u16,
/// `true` between a successful `start` and the next `stop`.
running: bool,
/// Addresses and port that `start` binds.
bind_config: BindConfig,
}
impl HttpServer {
    /// Creates a stopped server that hands connections to `callbacks`.
    ///
    /// At most `max_connections` connections are served concurrently; connections accepted
    /// beyond that are dropped immediately. The default [`BindConfig`] applies until
    /// [`HttpServer::set_bind_config`] is called.
    pub fn new(callbacks: Arc<dyn HttpdCallbacks>, max_connections: usize) -> Self {
        Self {
            callbacks,
            max_connections,
            shutdown_tx: None,
            port: 0,
            running: false,
            bind_config: BindConfig::default(),
        }
    }

    /// Replaces the bind configuration used by subsequent calls to [`HttpServer::start`].
    pub fn set_bind_config(&mut self, config: BindConfig) {
        self.bind_config = config;
    }

    /// Binds the configured addresses and starts accepting connections.
    ///
    /// A non-zero `port` overrides [`BindConfig::port`]. The first address is bound (searching
    /// upward if [`BindConfig::auto_port`] is set). Every other address is then bound on exactly
    /// the port obtained; failures there are only logged. If the server is already running, its
    /// current port is returned and nothing is rebound.
    ///
    /// Returns the port actually bound.
    ///
    /// # Errors
    /// Fails if the first address cannot be bound or its local address cannot be read.
    pub async fn start(&mut self, port: u16) -> Result<u16, NetworkError> {
        if self.running {
            return Ok(self.port);
        }
        let bind_port = if port > 0 { port } else { self.bind_config.port };
        let auto_port = self.bind_config.auto_port;
        // No explicit addresses: listen on every IPv4 and IPv6 interface.
        let addrs: Vec<IpAddr> = if self.bind_config.bind_addrs.is_empty() {
            vec![IpAddr::V4(Ipv4Addr::UNSPECIFIED), IpAddr::V6(Ipv6Addr::UNSPECIFIED)]
        } else {
            self.bind_config.bind_addrs.clone()
        };

        // `addrs` is never empty here. The first address decides the port; the rest share it.
        let first = bind_listener(addrs[0], bind_port, auto_port).await?;
        let actual_port = first.local_addr()?.port();
        let mut listeners = vec![first];
        for &addr in &addrs[1..] {
            match bind_listener(addr, actual_port, false).await {
                Ok(l) => listeners.push(l),
                // Best effort: e.g. on some platforms `[::]` is dual-stack and collides with
                // `0.0.0.0` on the same port.
                Err(e) => tracing::warn!(%addr, "Failed to bind additional listener: {e}"),
            }
        }

        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        self.shutdown_tx = Some(shutdown_tx);
        self.port = actual_port;
        self.running = true;

        let callbacks = self.callbacks.clone();
        // A single semaphore shared by all listeners makes the connection limit global.
        let semaphore = Arc::new(Semaphore::new(self.max_connections));
        for listener in listeners {
            // Purely informational: failing to query the address must not panic or abort start.
            if let Ok(addr) = listener.local_addr() {
                tracing::debug!(%addr, "Listener bound");
            }
            spawn_accept_loop(listener, callbacks.clone(), semaphore.clone(), shutdown_rx.clone());
        }
        Ok(actual_port)
    }

    /// Returns `true` between a successful [`HttpServer::start`] and the next
    /// [`HttpServer::stop`].
    pub fn is_running(&self) -> bool {
        self.running
    }

    /// Port bound by the most recent successful start (0 if never started); `stop` keeps it.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Signals the accept loops to stop and marks the server as not running.
    ///
    /// Each listener is dropped when its accept loop exits. Connections already being served
    /// are not interrupted.
    pub async fn stop(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(true);
        }
        self.running = false;
    }
}
/// Binds a TCP listener on `addr`, starting at `start_port`.
///
/// With `auto_port`, a port that is already in use (or refused by the OS) makes the next port be
/// tried, up to `start_port + 100` (saturating at `u16::MAX`). Other errors, e.g. `addr` not
/// being local to this host (`AddrNotAvailable`), are returned immediately, since no other port
/// can fix them.
///
/// # Errors
/// Returns [`NetworkError::Io`] carrying the error of the last bind attempt.
async fn bind_listener(
    addr: IpAddr,
    start_port: u16,
    auto_port: bool,
) -> Result<TcpListener, NetworkError> {
    // Highest port the search may move to; saturating so `port += 1` can never overflow.
    let last_port = start_port.saturating_add(100);
    let mut port = start_port;
    loop {
        match TcpListener::bind(SocketAddr::new(addr, port)).await {
            Ok(listener) => return Ok(listener),
            // Only errors tied to this particular port are worth retrying on another port.
            Err(e)
                if auto_port
                    && port < last_port
                    && matches!(
                        e.kind(),
                        std::io::ErrorKind::AddrInUse | std::io::ErrorKind::PermissionDenied
                    ) =>
            {
                port += 1;
            }
            Err(e) => return Err(NetworkError::Io(e)),
        }
    }
}
fn spawn_accept_loop(
listener: TcpListener,
callbacks: Arc<dyn HttpdCallbacks>,
semaphore: Arc<Semaphore>,
mut shutdown_rx: watch::Receiver<bool>,
) {
tokio::spawn(async move {
loop {
tokio::select! {
result = listener.accept() => {
let (stream, remote) = match result {
Ok(v) => v,
Err(_) => continue,
};
tracing::info!(%remote, "New connection");
let local = match stream.local_addr() {
Ok(a) => a,
Err(_) => continue,
};
let permit = match semaphore.clone().try_acquire_owned() {
Ok(p) => p,
Err(_) => { tracing::warn!("Max connections reached"); continue; }
};
let cb = callbacks.clone();
tokio::spawn(async move {
let _permit = permit;
let mut handler = match cb.conn_init(local, remote) {
Some(h) => h,
None => return,
};
let mut stream = stream;
let mut buf = [0u8; 4096];
let mut request = HttpRequest::new();
let mut raw_buf = Vec::new();
loop {
while request.is_complete() {
let method = request.method().unwrap_or("?").to_string();
let url = request.url().unwrap_or("?").to_string();
tracing::debug!(%method, %url, body_len = request.data().map(|d| d.len()).unwrap_or(0), "RTSP request");
let response = handler.conn_request(&request);
let status = response.status_code();
tracing::debug!(%method, %url, status, "RTSP response");
let disconnect = response.get_disconnect();
let raw_out = response.get_data();
let wire_out = if handler.is_encrypted() {
handler.encrypt_outgoing(raw_out)
} else {
raw_out.to_vec()
};
if stream.write_all(&wire_out).await.is_err() {
return;
}
handler.after_response();
if disconnect {
let _ = stream.shutdown().await;
return;
}
let leftover = request.take_leftover();
request = HttpRequest::new();
if !leftover.is_empty() && request.add_data(&leftover).is_err() {
write_bad_request_and_close(&mut stream, Some(handler.as_mut())).await;
return;
}
}
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if handler.is_encrypted() {
raw_buf.extend_from_slice(&buf[..n]);
if raw_buf.len() > 1024 * 1024 {
tracing::warn!("Encrypted buffer exceeded 1 MB, dropping connection");
break;
}
tracing::trace!(encrypted = true, raw_len = raw_buf.len(), new_bytes = n, "Read");
match handler.decrypt_incoming(&raw_buf) {
Some((plain, consumed)) => {
tracing::trace!(plain_len = plain.len(), consumed, "Decrypt");
if consumed > 0 {
raw_buf.drain(..consumed);
}
if !plain.is_empty() {
tracing::trace!("Decrypted: {:?}", String::from_utf8_lossy(&plain[..plain.len().min(120)]));
if request.add_data(&plain).is_err() {
tracing::warn!("HTTP parse error on decrypted data");
write_bad_request_and_close(&mut stream, Some(handler.as_mut())).await;
break;
}
tracing::trace!(complete = request.is_complete(), headers_complete = request.headers_complete(), "After add_data");
}
}
None => {
tracing::warn!("Decryption failed, raw_buf first bytes: {:02x?}", &raw_buf[..raw_buf.len().min(16)]);
break;
}
}
} else {
tracing::trace!(encrypted = false, n, "Read (plaintext)");
if request.add_data(&buf[..n]).is_err() {
tracing::warn!("HTTP parse error, first bytes: {:02x?}", &buf[..n.min(32)]);
write_bad_request_and_close(&mut stream, Some(handler.as_mut())).await;
break;
}
}
}
tracing::info!(%remote, "Connection closed");
});
}
_ = shutdown_rx.changed() => {
break;
}
}
}
});
}