use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener};
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::{debug, error, info, trace};
use warp::Filter;
use super::router::{EventRouter, NotificationPayload};
pub struct CallbackServer {
port: u16,
base_url: String,
event_router: Arc<EventRouter>,
shutdown_tx: Option<mpsc::Sender<()>>,
server_handle: Option<tokio::task::JoinHandle<()>>,
}
impl CallbackServer {
pub async fn new(
port_range: (u16, u16),
event_sender: mpsc::UnboundedSender<NotificationPayload>,
) -> Result<Self, String> {
let port = Self::find_available_port(port_range.0, port_range.1).ok_or_else(|| {
format!(
"No available port found in range {}-{}",
port_range.0, port_range.1
)
})?;
let local_ip = Self::detect_local_ip()
.ok_or_else(|| "Failed to detect local IP address".to_string())?;
eprintln!("Detected local IP address: {local_ip}");
let base_url = format!("http://{local_ip}:{port}");
eprintln!("Callback server base URL: {base_url}");
let event_router = Arc::new(EventRouter::new(event_sender));
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
let (ready_tx, mut ready_rx) = mpsc::channel::<()>(1);
let server_handle = Self::start_server(port, event_router.clone(), shutdown_rx, ready_tx);
ready_rx
.recv()
.await
.ok_or_else(|| "Server failed to start".to_string())?;
Ok(Self {
port,
base_url,
event_router,
shutdown_tx: Some(shutdown_tx),
server_handle: Some(server_handle),
})
}
pub fn base_url(&self) -> &str {
&self.base_url
}
pub fn port(&self) -> u16 {
self.port
}
pub fn router(&self) -> &Arc<EventRouter> {
&self.event_router
}
pub async fn shutdown(mut self) -> Result<(), String> {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
if let Some(handle) = self.server_handle.take() {
let _ = handle.await;
}
Ok(())
}
fn find_available_port(start: u16, end: u16) -> Option<u16> {
(start..=end).find(|&port| Self::is_port_available(port))
}
fn is_port_available(port: u16) -> bool {
TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)).is_ok()
}
fn detect_local_ip() -> Option<IpAddr> {
let socket = std::net::UdpSocket::bind("0.0.0.0:0").ok()?;
socket.connect("8.8.8.8:80").ok()?;
let local_addr = socket.local_addr().ok()?;
Some(local_addr.ip())
}
fn start_server(
port: u16,
event_router: Arc<EventRouter>,
mut shutdown_rx: mpsc::Receiver<()>,
ready_tx: mpsc::Sender<()>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let notify_route = warp::method()
.and(warp::path::full())
.and(warp::header::optional::<String>("sid"))
.and(warp::header::optional::<String>("nt"))
.and(warp::header::optional::<String>("nts"))
.and(warp::body::bytes())
.and_then({
let router = event_router.clone();
move |method: warp::http::Method,
path: warp::path::FullPath,
sid: Option<String>,
nt: Option<String>,
nts: Option<String>,
body: bytes::Bytes| {
let router = router.clone();
async move {
if method != warp::http::Method::from_bytes(b"NOTIFY").unwrap() {
return Err(warp::reject::not_found());
}
debug!(
method = %method,
path = %path.as_str(),
body_size = body.len(),
sid = ?sid,
nt = ?nt,
nts = ?nts,
"Received UPnP NOTIFY event"
);
let event_xml = String::from_utf8_lossy(&body).to_string();
if event_xml.len() > 200 {
trace!(
event_xml_preview = %&event_xml[..200],
total_length = event_xml.len(),
"UPnP event XML content (truncated)"
);
} else {
trace!(
event_xml = %event_xml,
"UPnP event XML content (full)"
);
}
if !Self::validate_upnp_headers(&sid, &nt, &nts) {
error!(
sid = ?sid,
nt = ?nt,
nts = ?nts,
"Invalid UPnP headers in NOTIFY request"
);
return Err(warp::reject::custom(InvalidUpnpHeaders));
}
let sub_id = sid.ok_or_else(|| {
error!("Missing required SID header in UPnP NOTIFY request");
warp::reject::custom(InvalidUpnpHeaders)
})?;
router.route_event(sub_id.clone(), event_xml).await;
debug!(
subscription_id = %sub_id,
"UPnP event accepted"
);
Ok::<_, warp::Rejection>(warp::reply::with_status(
"",
warp::http::StatusCode::OK,
))
}
}
});
let routes = notify_route.recover(handle_rejection);
let (addr, server) = warp::serve(routes).bind_with_graceful_shutdown(
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port),
async move {
shutdown_rx.recv().await;
},
);
info!(
address = %addr,
"CallbackServer listening - ready to process UPnP events"
);
let _ = ready_tx.send(()).await;
server.await;
})
}
fn validate_upnp_headers(
sid: &Option<String>,
nt: &Option<String>,
nts: &Option<String>,
) -> bool {
if sid.is_none() {
return false;
}
if let (Some(nt_val), Some(nts_val)) = (nt, nts) {
if nt_val != "upnp:event" || nts_val != "upnp:propchange" {
return false;
}
}
true
}
}
/// Rejection returned when a `NOTIFY` request fails UPnP header validation (missing
/// `SID`, or unexpected `NT`/`NTS` values).
///
/// `handle_rejection` turns it into a `400 Bad Request`.
#[derive(Debug)]
struct InvalidUpnpHeaders;
// Marker impl so the type can be wrapped with `warp::reject::custom` and later
// recovered via `Rejection::find`.
impl warp::reject::Reject for InvalidUpnpHeaders {}
/// Converts a rejection from the NOTIFY route into a plain-text HTTP response.
///
/// - "not found" rejections (which the route also uses for non-`NOTIFY` methods) -> 404
/// - [`InvalidUpnpHeaders`] -> 400
/// - anything else -> 500
///
/// Never fails, so warp always sends one of these responses.
async fn handle_rejection(
    err: warp::Rejection,
) -> Result<impl warp::Reply, std::convert::Infallible> {
    use warp::http::StatusCode;

    let (message, code) = if err.is_not_found() {
        ("Subscription not found", StatusCode::NOT_FOUND)
    } else if err.find::<InvalidUpnpHeaders>().is_some() {
        ("Invalid UPnP headers", StatusCode::BAD_REQUEST)
    } else {
        ("Internal server error", StatusCode::INTERNAL_SERVER_ERROR)
    };
    Ok(warp::reply::with_status(message, code))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_port_available() {
        // Port 0 lets the OS choose any free port, so it is always bindable.
        assert!(CallbackServer::is_port_available(0));

        // While a listener holds a port, the probe must report it as taken.
        let listener = TcpListener::bind("0.0.0.0:0").unwrap();
        let taken_port = listener.local_addr().unwrap().port();
        assert!(!CallbackServer::is_port_available(taken_port));
        drop(listener);
    }

    #[test]
    fn test_find_available_port() {
        let found = CallbackServer::find_available_port(50000, 50100);
        assert!(found.is_some());
        let found = found.unwrap();
        assert!((50000..=50100).contains(&found));
    }

    #[test]
    fn test_detect_local_ip() {
        let detected = CallbackServer::detect_local_ip();
        assert!(detected.is_some());
        if let Some(IpAddr::V4(v4)) = detected {
            assert_ne!(v4, Ipv4Addr::LOCALHOST);
        }
    }

    #[test]
    fn test_validate_upnp_headers() {
        let some = |value: &str| Some(value.to_string());

        // Complete, well-formed header set.
        assert!(CallbackServer::validate_upnp_headers(
            &some("uuid:123"),
            &some("upnp:event"),
            &some("upnp:propchange"),
        ));
        // NT/NTS may be omitted.
        assert!(CallbackServer::validate_upnp_headers(
            &some("uuid:123"),
            &None,
            &None,
        ));
        // SID is mandatory.
        assert!(!CallbackServer::validate_upnp_headers(
            &None,
            &some("upnp:event"),
            &some("upnp:propchange"),
        ));
        // Wrong NT value.
        assert!(!CallbackServer::validate_upnp_headers(
            &some("uuid:123"),
            &some("wrong"),
            &some("upnp:propchange"),
        ));
        // Wrong NTS value.
        assert!(!CallbackServer::validate_upnp_headers(
            &some("uuid:123"),
            &some("upnp:event"),
            &some("wrong"),
        ));
    }

    #[tokio::test]
    async fn test_callback_server_creation() {
        // Keep the receiver alive for the whole test so the router's sender stays open.
        let (tx, _rx) = mpsc::unbounded_channel();
        let started = CallbackServer::new((50000, 50100), tx).await;
        assert!(started.is_ok());
        let server = started.unwrap();

        let port = server.port();
        assert!((50000..=50100).contains(&port));
        assert!(server.base_url().contains(&port.to_string()));

        server.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn test_callback_server_register_unregister() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let server = CallbackServer::new((51000, 51100), tx).await.unwrap();

        let subscription = String::from("test-sub-123");
        let router = server.router();
        router.register(subscription.clone()).await;
        router.unregister(&subscription).await;

        server.shutdown().await.unwrap();
    }
}