use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::Router;
use axum::extract::State;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use futures::{SinkExt, StreamExt};
use tokio::sync::{RwLock, Semaphore};
use tokio_util::sync::CancellationToken;
use tower_http::cors::{AllowOrigin, CorsLayer};
use zeroize::Zeroizing;
use scp_identity::document::DidDocument;
use scp_platform::traits::Storage;
use scp_transport::native::server::RelayConfig as TransportRelayConfig;
use scp_transport::native::storage::BlobStorageBackend;
use scp_transport::relay::rate_limit::{ConnectionTracker, PublishRateLimiter};
use scp_transport::relay::subscription::SubscriptionRegistry;
use crate::tls;
use crate::projection::ProjectedContext;
use crate::well_known::well_known_handler;
use crate::{ApplicationNode, NodeError};
/// A broadcast context tracked by the node: an identifier plus an optional
/// name.
// NOTE(review): presumably `id` matches the key used in
// `NodeState::broadcast_contexts`; confirm against the code that inserts it.
#[derive(Debug, Clone)]
pub struct BroadcastContext {
/// Identifier of this broadcast context.
pub id: String,
/// Optional name for the context; `None` when no name was supplied.
pub name: Option<String>,
}
/// Shared, process-wide state of an application node.
///
/// Wrapped in an `Arc` and handed to the routers built in this module
/// (well-known, relay bridge, projection, dev API) and to `serve`.
pub struct NodeState {
/// The node's DID string (e.g. `did:dht:...` in the tests).
pub(crate) did: String,
/// Public relay URL advertised by the node (e.g. `wss://host/scp/v1` in the tests).
pub(crate) relay_url: String,
/// Known broadcast contexts; presumably keyed by context id — confirm.
pub(crate) broadcast_contexts: RwLock<HashMap<String, BroadcastContext>>,
/// Address of the internal relay; the WebSocket bridge dials `ws://{relay_addr}/`.
pub(crate) relay_addr: SocketAddr,
/// Secret sent hex-encoded as `Authorization: Bearer ...` when the bridge
/// connects to the internal relay. Zeroized on drop.
pub(crate) bridge_secret: Zeroizing<[u8; 32]>,
/// When `Some`, the dev API router is built with this token.
pub(crate) dev_token: Option<String>,
/// Listener address for the dev API; the dev API is only served when both
/// this and `dev_token` are set.
pub(crate) dev_bind_addr: Option<SocketAddr>,
/// Projected contexts keyed by a 32-byte identifier (meaning of the key
/// not visible here — TODO confirm).
pub(crate) projected_contexts: RwLock<HashMap<[u8; 32], ProjectedContext>>,
/// Blob storage backend shared with other modules.
pub(crate) blob_storage: Arc<BlobStorageBackend>,
/// Relay configuration; `max_total_connections` sizes the bridge semaphore.
pub(crate) relay_config: TransportRelayConfig,
/// Instant the state was created (not read in this module).
pub(crate) start_time: Instant,
/// Address the main HTTP(S) listener binds to in `serve`.
pub(crate) http_bind_addr: SocketAddr,
/// Cancelled to trigger graceful shutdown of every server and background task.
pub(crate) shutdown_token: CancellationToken,
/// Allowed CORS origins; `None` allows any origin.
pub(crate) cors_origins: Option<Vec<String>>,
/// Rate limiter for projection endpoints; `serve` spawns its cleanup loop.
pub(crate) projection_rate_limiter: PublishRateLimiter,
/// When `Some`, the main server is served over TLS.
pub(crate) tls_config: Option<Arc<rustls::ServerConfig>>,
/// Certificate resolver (not read in this module).
pub(crate) cert_resolver: Option<Arc<crate::tls::CertResolver>>,
/// The node's DID document (not read in this module).
pub(crate) did_document: DidDocument,
/// Connection tracker from the transport relay (not read in this module).
pub(crate) connection_tracker: ConnectionTracker,
/// Subscription registry from the transport relay (not read in this module).
pub(crate) subscription_registry: SubscriptionRegistry,
/// When `Some`, an ACME challenge router is merged into the main router.
pub(crate) acme_challenges: Option<Arc<RwLock<HashMap<String, String>>>>,
/// State backing the bridge router.
pub(crate) bridge_state: Arc<crate::bridge_handlers::BridgeState>,
}
/// Builds the CORS layer applied to the public read-only routes
/// (`/.well-known/scp` and, in `serve`, the broadcast projection endpoints).
///
/// * `None` allows any origin (`Access-Control-Allow-Origin: *`).
/// * `Some(list)` allows only the listed origins. Entries that are not valid
///   HTTP header values are skipped with a warning; skipping an entry can only
///   make the policy stricter. A literal `"*"` entry is treated as "allow any
///   origin", because [`AllowOrigin::list`] panics when handed the wildcard.
///
/// Only `GET` and `OPTIONS` requests and the `Content-Type` /
/// `If-None-Match` request headers are allowed cross-origin.
pub fn build_cors_layer(origins: &Option<Vec<String>>) -> CorsLayer {
    let allow_origin = match origins {
        None => AllowOrigin::any(),
        // `AllowOrigin::list` panics on "*", so honour the operator's intent
        // explicitly instead of crashing at startup.
        Some(list) if list.iter().any(|o| o == "*") => {
            tracing::warn!("CORS origin list contains \"*\"; allowing any origin");
            AllowOrigin::any()
        }
        Some(list) => {
            let parsed: Vec<axum::http::HeaderValue> = list
                .iter()
                .filter_map(|o| match o.parse::<axum::http::HeaderValue>() {
                    Ok(value) => Some(value),
                    Err(_) => {
                        tracing::warn!(
                            origin = %o,
                            "ignoring invalid CORS origin; \
                             requests from it will not receive CORS headers"
                        );
                        None
                    }
                })
                .collect();
            if parsed.is_empty() {
                // An empty allow-list matches nothing: every cross-origin
                // request is denied. Make that visible to the operator.
                tracing::warn!(
                    "no valid CORS origins configured; cross-origin requests will not receive CORS headers"
                );
            }
            AllowOrigin::list(parsed)
        }
    };
    CorsLayer::new()
        .allow_origin(allow_origin)
        .allow_methods([axum::http::Method::GET, axum::http::Method::OPTIONS])
        .allow_headers([
            axum::http::header::CONTENT_TYPE,
            axum::http::header::IF_NONE_MATCH,
        ])
}
/// How long a bridged WebSocket may go without a text or binary frame in
/// either direction before the bridge is torn down (five minutes).
const BRIDGE_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Builds a router serving `GET /.well-known/scp` via `well_known_handler`,
/// with `state` attached as the handler state.
///
/// No CORS layer is applied here; callers add one if needed.
pub fn well_known_router(state: Arc<NodeState>) -> Router {
    let discovery = get(well_known_handler);
    Router::new()
        .route("/.well-known/scp", discovery)
        .with_state(state)
}
/// Builds a router exposing the WebSocket relay bridge at `/scp/v1`.
///
/// Concurrent bridges are capped by a semaphore sized from
/// `relay_config.max_total_connections`. The semaphore is created per call,
/// so each returned router enforces its own independent limit. Note that
/// `Semaphore::new` panics if the count exceeds `Semaphore::MAX_PERMITS`.
pub fn relay_router(state: Arc<NodeState>) -> Router {
    let capacity = state.relay_config.max_total_connections;
    let bridge_slots = Arc::new(Semaphore::new(capacity));
    let shared = (state, bridge_slots);
    Router::new()
        .route("/scp/v1", get(ws_upgrade_handler))
        .with_state(shared)
}
/// Upgrades an incoming request to a WebSocket and bridges it to the internal
/// relay.
///
/// A permit is taken from the shared semaphore without waiting. When none is
/// available the request is rejected with `503 Service Unavailable`.
/// Otherwise the permit is moved into the upgrade callback and released only
/// when the bridge finishes.
async fn ws_upgrade_handler(
    ws: WebSocketUpgrade,
    State((state, sem)): State<(Arc<NodeState>, Arc<Semaphore>)>,
) -> impl IntoResponse {
    match sem.try_acquire_owned() {
        Err(_) => StatusCode::SERVICE_UNAVAILABLE.into_response(),
        Ok(slot) => {
            let relay_addr = state.relay_addr;
            let bridge_secret = state.bridge_secret.clone();
            ws.on_upgrade(move |socket| async move {
                // Keep the permit alive for exactly as long as the bridge runs.
                let _slot = slot;
                relay_bridge(socket, relay_addr, bridge_secret).await;
            })
            .into_response()
        }
    }
}
async fn relay_bridge(
axum_ws: WebSocket,
relay_addr: SocketAddr,
bridge_secret: Zeroizing<[u8; 32]>,
) {
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
let token_hex = scp_transport::native::server::hex_encode_32(&bridge_secret);
let url = format!("ws://{relay_addr}/");
let mut request = match url.into_client_request() {
Ok(r) => r,
Err(e) => {
tracing::error!(
addr = %relay_addr,
error = %e,
"failed to build WebSocket request for internal relay bridge"
);
return;
}
};
let Ok(header_value) = format!("Bearer {token_hex}").parse() else {
tracing::error!("bridge token produced invalid HTTP header value");
return;
};
request.headers_mut().insert("Authorization", header_value);
let relay_conn = tokio_tungstenite::connect_async(request).await;
let Ok((relay_ws, _)) = relay_conn else {
tracing::error!(
addr = %relay_addr,
"failed to connect to internal relay for WebSocket bridge"
);
return;
};
let (mut relay_sink, mut relay_source) = relay_ws.split();
let (mut axum_sink, mut axum_source) = axum_ws.split();
let idle_timeout = tokio::time::sleep(BRIDGE_IDLE_TIMEOUT);
tokio::pin!(idle_timeout);
loop {
tokio::select! {
msg = StreamExt::next(&mut axum_source) => {
match msg {
Some(Ok(Message::Close(_)) | Err(_)) | None => break,
Some(Ok(msg)) => {
let relay_msg = match msg {
Message::Text(t) => {
idle_timeout.as_mut().reset(tokio::time::Instant::now() + BRIDGE_IDLE_TIMEOUT);
tokio_tungstenite::tungstenite::Message::Text(t.to_string())
}
Message::Binary(b) => {
idle_timeout.as_mut().reset(tokio::time::Instant::now() + BRIDGE_IDLE_TIMEOUT);
tokio_tungstenite::tungstenite::Message::Binary(b.to_vec())
}
Message::Ping(p) => tokio_tungstenite::tungstenite::Message::Ping(p.to_vec()),
Message::Pong(p) => tokio_tungstenite::tungstenite::Message::Pong(p.to_vec()),
Message::Close(_) => break,
};
if let Err(e) = SinkExt::send(&mut relay_sink, relay_msg).await {
tracing::debug!(
direction = "client->relay",
error = %e,
"bridge forwarding failed"
);
break;
}
}
}
}
msg = StreamExt::next(&mut relay_source) => {
match msg {
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_)) | Err(_)) | None => break,
Some(Ok(msg)) => {
let axum_msg = match msg {
tokio_tungstenite::tungstenite::Message::Text(t) => {
idle_timeout.as_mut().reset(tokio::time::Instant::now() + BRIDGE_IDLE_TIMEOUT);
Message::Text(t.into())
}
tokio_tungstenite::tungstenite::Message::Binary(b) => {
idle_timeout.as_mut().reset(tokio::time::Instant::now() + BRIDGE_IDLE_TIMEOUT);
Message::Binary(b.into())
}
tokio_tungstenite::tungstenite::Message::Ping(p) => Message::Ping(p.into()),
tokio_tungstenite::tungstenite::Message::Pong(p) => Message::Pong(p.into()),
tokio_tungstenite::tungstenite::Message::Close(_) => break,
tokio_tungstenite::tungstenite::Message::Frame(_) => continue,
};
if let Err(e) = SinkExt::send(&mut axum_sink, axum_msg).await {
tracing::debug!(
direction = "relay->client",
error = %e,
"bridge forwarding failed"
);
break;
}
}
}
}
() = &mut idle_timeout => {
tracing::debug!("bridge connection idle timeout reached");
break;
}
}
}
let _ = SinkExt::close(&mut relay_sink).await;
let _ = SinkExt::close(&mut axum_sink).await;
}
impl<S: Storage + Send + Sync + 'static> ApplicationNode<S> {
    /// Returns a router serving `GET /.well-known/scp` (without a CORS layer).
    #[must_use = "returns the well-known router, which must be mounted into an axum application"]
    pub fn well_known_router(&self) -> Router {
        well_known_router(Arc::clone(&self.state))
    }

    /// Returns a router exposing the WebSocket relay bridge at `/scp/v1`.
    ///
    /// Each call creates its own connection-limiting semaphore.
    #[must_use = "returns the relay router, which must be mounted into an axum application"]
    pub fn relay_router(&self) -> Router {
        relay_router(Arc::clone(&self.state))
    }

    /// Returns the broadcast projection router (without a CORS layer).
    #[must_use = "returns the projection router, which must be mounted into an axum application"]
    pub fn broadcast_projection_router(&self) -> Router {
        crate::projection::broadcast_projection_router(Arc::clone(&self.state))
    }

    /// Returns the bridge router backed by the node's bridge state.
    #[must_use = "returns the bridge router, which must be mounted into an axum application"]
    pub fn bridge_router(&self) -> Router {
        crate::bridge_handlers::bridge_router(Arc::clone(&self.state.bridge_state))
    }

    /// Returns the dev API router, or `None` when no dev token is configured.
    #[must_use = "returns the dev API router, which must be served on a separate listener"]
    pub fn dev_router(&self) -> Option<Router> {
        let token = self.state.dev_token.clone()?;
        Some(crate::dev_api::dev_router(Arc::clone(&self.state), token))
    }

    /// Runs the node until `shutdown` resolves or a server exits.
    ///
    /// What this starts:
    /// * the projection rate-limit cleanup task;
    /// * the main HTTP(S) listener on `http_bind_addr`, serving `app_router`
    ///   merged with the well-known, relay, projection, bridge and (optional)
    ///   ACME routers;
    /// * the dev API listener, when configured;
    /// * the HTTP/3 listener, when that feature is enabled and configured.
    ///
    /// On exit the shutdown token is cancelled and the internal relay is shut
    /// down. This also happens when the main listener cannot be bound.
    ///
    /// # Errors
    ///
    /// Returns `NodeError::Serve` in these cases:
    /// * the main listener cannot be bound;
    /// * the main server fails;
    /// * the dev API task fails or panics.
    pub async fn serve(
        self,
        app_router: Router,
        shutdown: impl std::future::Future<Output = ()> + Send + 'static,
    ) -> Result<(), NodeError> {
        spawn_projection_rate_limit_cleanup(
            self.state.projection_rate_limiter.clone(),
            self.state.shutdown_token.clone(),
        );
        let cors = build_cors_layer(&self.state.cors_origins);
        let well_known = well_known_router(Arc::clone(&self.state)).layer(cors.clone());
        let relay_rt = relay_router(Arc::clone(&self.state));
        let projection =
            crate::projection::broadcast_projection_router(Arc::clone(&self.state)).layer(cors);
        let dev_router = self
            .state
            .dev_token
            .clone()
            .map(|t| crate::dev_api::dev_router(Arc::clone(&self.state), t));
        let dev_bind_addr = self.state.dev_bind_addr;
        let tls_config = self.state.tls_config.clone();
        #[cfg(feature = "http3")]
        let http3_config = self.http3_config;
        let bridge = crate::bridge_handlers::bridge_router(Arc::clone(&self.state.bridge_state));
        let relay = self.relay;
        let state = self.state;
        let merged = build_merged_router(
            app_router,
            well_known,
            relay_rt,
            projection,
            bridge,
            state.acme_challenges.as_ref(),
        );

        // Bind the main listener before spawning any other listener, so a bind
        // failure leaves nothing orphaned. Tear down what is already running
        // (cleanup task via the token, the internal relay) before reporting it.
        let (listener, local_addr) = match bind_main_listener(state.http_bind_addr).await {
            Ok(bound) => bound,
            Err(e) => {
                state.shutdown_token.cancel();
                relay.shutdown_handle.shutdown();
                return Err(e);
            }
        };

        let dev_api_handle = spawn_dev_api(dev_router, dev_bind_addr, state.shutdown_token.clone());
        #[cfg(feature = "http3")]
        if let Some(http3_config) = http3_config {
            spawn_http3_listener(http3_config, &state);
        }

        let shutdown_token = state.shutdown_token.clone();
        let token = shutdown_token.clone();
        // Forward the caller's shutdown signal to the token. Also exit if the
        // token is cancelled for another reason, so this task never outlives
        // the node.
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown => token.cancel(),
                () = token.cancelled() => {}
            }
        });

        let mut main_server = build_main_server(
            listener,
            merged,
            tls_config,
            shutdown_token.clone(),
            local_addr,
        );
        let result = match dev_api_handle {
            Some(mut dev_handle) => {
                tokio::select! {
                    main_result = &mut main_server => {
                        shutdown_token.cancel();
                        dev_handle.abort();
                        main_result
                    }
                    joined = &mut dev_handle => {
                        shutdown_token.cancel();
                        let dev_result = joined.unwrap_or_else(|join_err| {
                            Err(NodeError::Serve(format!("dev API task failed: {join_err}")))
                        });
                        // Let the main server finish its graceful drain instead of
                        // dropping it mid-shutdown; a dev API error takes precedence.
                        let main_result = main_server.await;
                        dev_result.and(main_result)
                    }
                }
            }
            None => main_server.await,
        };
        shutdown_token.cancel();
        relay.shutdown_handle.shutdown();
        tracing::info!("application node shut down");
        result
    }
}

/// Binds the main TCP listener and resolves the address it actually bound to
/// (relevant when `addr` uses port 0).
async fn bind_main_listener(
    addr: SocketAddr,
) -> Result<(tokio::net::TcpListener, SocketAddr), NodeError> {
    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .map_err(|e| NodeError::Serve(e.to_string()))?;
    let local_addr = listener
        .local_addr()
        .map_err(|e| NodeError::Serve(e.to_string()))?;
    Ok((listener, local_addr))
}
/// Merges the application router with every built-in sub-router.
///
/// The merge order is: well-known, relay, projection, bridge. When
/// `acme_challenges` is present, the ACME challenge router is merged last.
/// `Router::merge` panics on overlapping routes, so the sub-routers must not
/// collide with `app_router`.
fn build_merged_router(
    app_router: Router,
    well_known: Router,
    relay_rt: Router,
    projection: Router,
    bridge: Router,
    acme_challenges: Option<&Arc<RwLock<HashMap<String, String>>>>,
) -> Router {
    let combined = [well_known, relay_rt, projection, bridge]
        .into_iter()
        .fold(app_router, |acc, sub| acc.merge(sub));
    match acme_challenges {
        Some(challenges) => combined.merge(tls::acme_challenge_router(Arc::clone(challenges))),
        None => combined,
    }
}
fn spawn_dev_api(
dev_router: Option<Router>,
dev_bind_addr: Option<SocketAddr>,
shutdown_token: CancellationToken,
) -> Option<tokio::task::JoinHandle<Result<(), NodeError>>> {
let (Some(dev_router), Some(dev_addr)) = (dev_router, dev_bind_addr) else {
return None;
};
Some(tokio::spawn(async move {
let dev_listener = tokio::net::TcpListener::bind(dev_addr).await.map_err(|e| {
NodeError::Serve(format!("failed to bind dev API server on {dev_addr}: {e}"))
})?;
let local_addr = dev_listener.local_addr().unwrap_or(dev_addr);
tracing::info!(addr = %local_addr, "dev API server started");
axum::serve(dev_listener, dev_router)
.with_graceful_shutdown(shutdown_token.cancelled_owned())
.await
.map_err(|e| NodeError::Serve(format!("dev API server error: {e}")))
}))
}
/// Builds (without starting) the main server future.
///
/// With `tls_config` present it delegates to `tls::serve_tls`; otherwise it
/// serves plain HTTP via `axum::serve`. Either way it shuts down gracefully
/// when `shutdown_token` is cancelled. The startup log line is emitted
/// immediately, before the returned future is polled.
fn build_main_server(
    listener: tokio::net::TcpListener,
    merged: Router,
    tls_config: Option<Arc<rustls::ServerConfig>>,
    shutdown_token: CancellationToken,
    local_addr: SocketAddr,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), NodeError>> + Send>> {
    let Some(tls_cfg) = tls_config else {
        tracing::info!(
            addr = %local_addr, scheme = "HTTP",
            "application node server started (plain HTTP, broadcast projection endpoints active)"
        );
        let plain = async move {
            axum::serve(listener, merged)
                .with_graceful_shutdown(shutdown_token.cancelled_owned())
                .await
                .map_err(|e| NodeError::Serve(e.to_string()))
        };
        return Box::pin(plain);
    };
    tracing::info!(
        addr = %local_addr, scheme = "HTTPS",
        "application node server started (TLS active)"
    );
    Box::pin(tls::serve_tls(listener, tls_cfg, merged, shutdown_token))
}
/// Spawns a detached background task running `limiter.cleanup_loop`.
///
/// The task is handed `shutdown_token`, presumably so the loop exits on
/// cancellation — confirm against `PublishRateLimiter::cleanup_loop`.
fn spawn_projection_rate_limit_cleanup(
    limiter: PublishRateLimiter,
    shutdown_token: CancellationToken,
) {
    // NOTE(review): both durations are passed positionally. They look like a
    // sweep interval followed by an entry expiry age — verify the
    // `cleanup_loop` signature before relying on that reading.
    let interval = Duration::from_secs(60);
    let max_age = Duration::from_secs(300);
    let cleanup = async move {
        limiter.cleanup_loop(interval, max_age, shutdown_token).await;
    };
    tokio::spawn(cleanup);
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use std::collections::HashMap;
    use std::net::SocketAddr;
    use std::time::Instant;
    use axum::body::Body;
    use axum::http::{Method, Request, StatusCode};
    use tokio::sync::RwLock;
    use tower::ServiceExt;
    use scp_transport::native::storage::BlobStorageBackend;
    use super::*;

    /// Builds a minimal `NodeState` whose only meaningful knob is the CORS
    /// origin list; every other field gets an inert placeholder value.
    fn test_state(cors_origins: Option<Vec<String>>) -> Arc<NodeState> {
        Arc::new(NodeState {
            did: "did:dht:cors_test".to_owned(),
            relay_url: "wss://localhost/scp/v1".to_owned(),
            broadcast_contexts: RwLock::new(HashMap::new()),
            relay_addr: "127.0.0.1:9000".parse::<SocketAddr>().unwrap(),
            bridge_secret: Zeroizing::new([0u8; 32]),
            dev_token: None,
            dev_bind_addr: None,
            projected_contexts: RwLock::new(HashMap::new()),
            blob_storage: Arc::new(BlobStorageBackend::default()),
            relay_config: scp_transport::native::server::RelayConfig::default(),
            start_time: Instant::now(),
            http_bind_addr: SocketAddr::from(([0, 0, 0, 0], 8443)),
            shutdown_token: CancellationToken::new(),
            cors_origins,
            projection_rate_limiter: scp_transport::relay::rate_limit::PublishRateLimiter::new(
                1000,
            ),
            tls_config: None,
            cert_resolver: None,
            did_document: scp_identity::document::DidDocument {
                context: vec!["https://www.w3.org/ns/did/v1".to_owned()],
                id: "did:dht:cors_test".to_owned(),
                verification_method: vec![],
                authentication: vec![],
                assertion_method: vec![],
                also_known_as: vec![],
                service: vec![],
            },
            connection_tracker: scp_transport::relay::rate_limit::new_connection_tracker(),
            subscription_registry: scp_transport::relay::subscription::new_registry(),
            acme_challenges: None,
            bridge_state: Arc::new(crate::bridge_handlers::BridgeState::new()),
        })
    }

    /// Sends `req` through a freshly built well-known router wrapped in the
    /// CORS layer derived from `origins`.
    async fn send(origins: Option<Vec<String>>, req: Request<Body>) -> axum::response::Response {
        let state = test_state(origins);
        let cors = build_cors_layer(&state.cors_origins);
        well_known_router(state).layer(cors).oneshot(req).await.unwrap()
    }

    /// A plain `GET /.well-known/scp` carrying the given `Origin` header.
    fn get_from(origin: &str) -> Request<Body> {
        Request::builder()
            .uri("/.well-known/scp")
            .header("Origin", origin)
            .body(Body::empty())
            .unwrap()
    }

    /// Reads a response header as a string slice, if present.
    fn header<'a>(resp: &'a axum::response::Response, name: &str) -> Option<&'a str> {
        resp.headers().get(name).map(|v| v.to_str().unwrap())
    }

    #[tokio::test]
    async fn cors_permissive_well_known_returns_wildcard_origin() {
        let resp = send(None, get_from("https://example.com")).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let acao = header(&resp, "access-control-allow-origin").expect("should have ACAO header");
        assert_eq!(acao, "*", "permissive mode should return wildcard origin");
    }

    #[tokio::test]
    async fn cors_restricted_well_known_allows_matching_origin() {
        let origins = Some(vec!["https://allowed.example".to_owned()]);
        let resp = send(origins, get_from("https://allowed.example")).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let acao = header(&resp, "access-control-allow-origin")
            .expect("should have ACAO header for allowed origin");
        assert_eq!(acao, "https://allowed.example");
    }

    #[tokio::test]
    async fn cors_restricted_well_known_rejects_non_matching_origin() {
        let origins = Some(vec!["https://allowed.example".to_owned()]);
        let resp = send(origins, get_from("https://evil.example")).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(
            !resp.headers().contains_key("access-control-allow-origin"),
            "non-matching origin should NOT receive ACAO header"
        );
    }

    #[tokio::test]
    async fn cors_preflight_options_returns_200() {
        let preflight = Request::builder()
            .method(Method::OPTIONS)
            .uri("/.well-known/scp")
            .header("Origin", "https://example.com")
            .header("Access-Control-Request-Method", "GET")
            .body(Body::empty())
            .unwrap();
        let resp = send(None, preflight).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let acao = header(&resp, "access-control-allow-origin")
            .expect("preflight should include ACAO");
        assert_eq!(acao, "*");
        let methods = header(&resp, "access-control-allow-methods")
            .expect("preflight should include allow-methods");
        assert!(methods.contains("GET"), "should allow GET method");
    }
}
#[cfg(feature = "http3")]
fn spawn_http3_listener(http3_config: scp_transport::http3::Http3Config, state: &Arc<NodeState>) {
    use scp_transport::http3::Http3Server;
    use scp_transport::http3::adapter::RequestHandler;

    /// HTTP/3 request handler serving only `GET /.well-known/scp`; every
    /// other method/path combination receives a 404.
    struct H3RequestHandler {
        state: Arc<NodeState>,
        /// Runtime handle captured at spawn time, used to drive the async
        /// document builder from the synchronous `handle` callback.
        rt: tokio::runtime::Handle,
    }

    impl RequestHandler for H3RequestHandler {
        fn handle(
            &self,
            method: &str,
            uri: &str,
            _headers: &[(String, String)],
        ) -> axum::http::Response<Vec<u8>> {
            match (method, uri) {
                ("GET", "/.well-known/scp") => {
                    // NOTE(review): `Handle::block_on` panics when called from an
                    // async execution context. This assumes the HTTP/3 adapter
                    // invokes `handle` off the async executor — confirm.
                    let doc = self
                        .rt
                        .block_on(crate::well_known::build_well_known_scp(&self.state));
                    let body_bytes = serde_json::to_vec(&doc).unwrap_or_default();
                    axum::http::Response::builder()
                        .status(200)
                        .header("content-type", "application/json")
                        .body(body_bytes)
                        .unwrap_or_else(|_| axum::http::Response::new(b"internal error".to_vec()))
                }
                _ => axum::http::Response::builder()
                    .status(404)
                    .body(b"not found".to_vec())
                    .unwrap_or_else(|_| axum::http::Response::new(Vec::new())),
            }
        }
    }

    let handler: Arc<dyn RequestHandler> = Arc::new(H3RequestHandler {
        state: Arc::clone(state),
        rt: tokio::runtime::Handle::current(),
    });
    tokio::spawn(async move {
        let mut server = Http3Server::new(http3_config, handler);
        let addr = match server.bind() {
            Ok(addr) => addr,
            Err(e) => {
                tracing::error!(error = %e, "failed to bind HTTP/3 server");
                return;
            }
        };
        tracing::info!(addr = %addr, "HTTP/3 server started");
        if let Err(e) = server.serve().await {
            tracing::error!(error = %e, "HTTP/3 server exited with error");
        }
    });
}