pub mod handler;
pub mod handlers;
pub mod op_sink;
pub mod operations;
pub mod protocol;
pub mod query;
pub mod router;
pub mod sink;
pub use handler::{WsContext, WsError, WsMethod, WsRequest, WsResult};
pub use op_sink::{WsOpSink, WsOpSinkError};
pub use operations::{OperationRegistry, OperationStatus};
pub use protocol::{
ErrorCode, ErrorData, ProgressStage, RequestEnvelope, ResponseEnvelope, ResponseType,
SystemInfo,
};
pub use router::{Dispatcher, Router};
pub use sink::{WsSink, WsSinkError};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::Response,
routing::get,
Router as AxumRouter,
};
use futures::StreamExt;
use std::sync::Arc;
use tokio::time::{Duration, Instant};
use tower_http::cors::{Any, CorsLayer};
/// Runtime configuration for the WebSocket server.
///
/// Start from [`ServerConfig::new`] (or [`Default`]) and refine it with the
/// `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// Host name or IP literal to listen on. IPv6 literals may be given with
    /// or without surrounding brackets.
    pub address: String,
    /// TCP port to listen on.
    pub port: u16,
    /// Upper bound on concurrently running operations, handed to the
    /// operation registry.
    pub max_concurrent_ops: usize,
    /// Largest inbound WebSocket message accepted, in bytes.
    pub max_message_size: usize,
    /// Seconds without any inbound traffic after which a connection is
    /// closed (the connection handler treats values below 1 as 1).
    pub connection_timeout_secs: u64,
    /// Seconds between server-initiated ping frames (the connection handler
    /// treats values below 1 as 1).
    pub ping_interval_secs: u64,
}

impl Default for ServerConfig {
    /// Loopback-only listener on port 8080 with a 1 MiB message limit, a
    /// 300 s idle timeout, 30 s pings and at most 10 concurrent operations.
    fn default() -> Self {
        Self {
            address: "127.0.0.1".to_string(),
            port: 8080,
            max_concurrent_ops: 10,
            max_message_size: 1024 * 1024,
            connection_timeout_secs: 300,
            ping_interval_secs: 30,
        }
    }
}

impl ServerConfig {
    /// Creates a configuration with the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the host name or IP literal to listen on.
    #[must_use]
    pub fn with_address(mut self, address: impl Into<String>) -> Self {
        self.address = address.into();
        self
    }

    /// Sets the TCP port to listen on.
    #[must_use]
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the maximum number of concurrently running operations.
    #[must_use]
    pub fn with_max_concurrent_ops(mut self, max_concurrent_ops: usize) -> Self {
        self.max_concurrent_ops = max_concurrent_ops;
        self
    }

    /// Sets the largest accepted inbound message size, in bytes.
    #[must_use]
    pub fn with_max_message_size(mut self, max_message_size: usize) -> Self {
        self.max_message_size = max_message_size;
        self
    }

    /// Sets the idle timeout, in seconds.
    #[must_use]
    pub fn with_connection_timeout_secs(mut self, connection_timeout_secs: u64) -> Self {
        self.connection_timeout_secs = connection_timeout_secs;
        self
    }

    /// Sets the interval between server pings, in seconds.
    #[must_use]
    pub fn with_ping_interval_secs(mut self, ping_interval_secs: u64) -> Self {
        self.ping_interval_secs = ping_interval_secs;
        self
    }

    /// Returns the `host:port` string the listener binds to.
    ///
    /// A bare IPv6 literal is wrapped in brackets (`::1` becomes
    /// `[::1]:8080`). Otherwise `::1:8080` would be produced, which is itself
    /// a valid IPv6 address rather than a socket address and only binds by
    /// falling back to host:port name resolution. IPv4 literals, host names
    /// and already-bracketed IPv6 literals are used unchanged.
    pub fn bind_address(&self) -> String {
        if self.address.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{}]:{}", self.address, self.port)
        } else {
            format!("{}:{}", self.address, self.port)
        }
    }
}
pub fn create_router() -> Router {
use handlers::*;
let mut router = Router::new();
router.register::<SystemInfoHandler>();
router.register::<TimeParseHandler>();
router.register::<CountryLookupHandler>();
router.register::<IpLookupHandler>();
router.register::<IpPublicHandler>();
router.register::<RpkiValidateHandler>();
router.register::<RpkiRoasHandler>();
router.register::<RpkiAspasHandler>();
router.register::<As2relSearchHandler>();
router.register::<As2relRelationshipHandler>();
router.register::<As2relUpdateHandler>();
router.register::<Pfx2asLookupHandler>();
router.register::<DatabaseStatusHandler>();
router.register::<DatabaseRefreshHandler>();
router.register::<InspectQueryHandler>();
router.register::<InspectRefreshHandler>();
router
}
/// Shared state handed to the axum handlers via `with_state`.
///
/// Cloning is cheap: both fields are `Arc`s.
#[derive(Clone)]
pub struct ServerState {
    /// Dispatches each inbound text payload to the registered method handlers.
    pub dispatcher: Arc<Dispatcher>,
    /// Server configuration; the connection handler reads the message-size,
    /// ping-interval and idle-timeout settings from it.
    pub config: Arc<ServerConfig>,
}
/// Builds the HTTP application serving `GET /ws` (WebSocket upgrade) and
/// `GET /health` (liveness probe), wrapped in a CORS layer that allows any
/// origin, method and header, with `state` attached for the handlers.
pub fn create_axum_router(state: ServerState) -> AxumRouter {
    let permissive_cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    let routes = AxumRouter::new()
        .route("/ws", get(ws_handler))
        .route("/health", get(health_handler));

    routes.layer(permissive_cors).with_state(state)
}
/// Liveness probe for `GET /health`; always responds with the body `OK`.
async fn health_handler() -> &'static str {
    const HEALTHY_BODY: &str = "OK";
    HEALTHY_BODY
}
/// Upgrades `GET /ws` to a WebSocket and hands the socket to `handle_socket`.
///
/// The configured `max_message_size` is applied to the upgrade itself. That
/// way oversized messages are rejected by the protocol layer while they are
/// being read, instead of first being buffered in full (up to the library's
/// much larger default limit) and only then checked by `handle_socket`.
/// A rejected message surfaces as a read error, which ends the connection.
async fn ws_handler(ws: WebSocketUpgrade, State(state): State<ServerState>) -> Response {
    let limit = state.config.max_message_size;
    ws.max_message_size(limit)
        // A single frame can never legitimately be larger than a whole message.
        .max_frame_size(limit)
        .on_upgrade(move |socket| handle_socket(socket, state))
}
/// Returns `true` (after logging a warning) when an inbound `kind` message of
/// `len` bytes is larger than `limit`.
fn exceeds_message_limit(kind: &str, len: usize, limit: usize) -> bool {
    if len <= limit {
        return false;
    }
    tracing::warn!(
        "Closing connection: {} message too large ({} > {} bytes)",
        kind,
        len,
        limit
    );
    true
}

/// Drives one WebSocket connection until it ends.
///
/// Text messages, and binary messages that decode as UTF-8, are forwarded to
/// the dispatcher together with a clone of the connection's sink. Non-UTF-8
/// binary messages are ignored. The loop ends when:
/// - the client closes the stream or sends a Close frame;
/// - a read error occurs;
/// - a message exceeds `max_message_size` (a Close frame is sent first);
/// - no message of any kind has arrived for longer than
///   `connection_timeout_secs` when the next ping is due (a Close frame is
///   sent first);
/// - sending a ping or pong fails.
///
/// A ping is sent every `ping_interval_secs`. Both intervals are clamped to
/// at least one second.
async fn handle_socket(socket: WebSocket, state: ServerState) {
    let (sender, mut receiver) = socket.split();
    let sink = WsSink::new(sender);
    tracing::info!("WebSocket connection established");

    let max_message_size = state.config.max_message_size;
    // `.max(1)` keeps a zero setting from producing a busy ping loop or an
    // immediate idle timeout.
    let ping_interval = Duration::from_secs(state.config.ping_interval_secs.max(1));
    let idle_timeout = Duration::from_secs(state.config.connection_timeout_secs.max(1));
    let mut last_activity = Instant::now();
    let mut next_ping = Instant::now() + ping_interval;

    loop {
        tokio::select! {
            maybe_msg = receiver.next() => {
                // `None` means the client side of the stream is gone.
                let Some(msg) = maybe_msg else {
                    break;
                };
                match msg {
                    Ok(Message::Text(text)) => {
                        if exceeds_message_limit("text", text.len(), max_message_size) {
                            let _ = sink.send_message_raw(Message::Close(None)).await;
                            break;
                        }
                        last_activity = Instant::now();
                        tracing::debug!("Received message: {}", text);
                        state.dispatcher.dispatch(&text, sink.clone()).await;
                    }
                    Ok(Message::Binary(data)) => {
                        if exceeds_message_limit("binary", data.len(), max_message_size) {
                            let _ = sink.send_message_raw(Message::Close(None)).await;
                            break;
                        }
                        // Counts as activity even if the payload is not UTF-8.
                        last_activity = Instant::now();
                        match String::from_utf8(data) {
                            Ok(text) => {
                                tracing::debug!("Received binary message as text: {}", text);
                                state.dispatcher.dispatch(&text, sink.clone()).await;
                            }
                            Err(_) => {
                                tracing::warn!("Received non-UTF8 binary message, ignoring");
                            }
                        }
                    }
                    Ok(Message::Ping(data)) => {
                        last_activity = Instant::now();
                        // NOTE(review): axum documents that pings are answered
                        // automatically; this explicit pong may duplicate that
                        // reply. Confirm before removing it.
                        if let Err(e) = sink.send_message_raw(Message::Pong(data)).await {
                            tracing::warn!("Failed to send pong: {}", e);
                            break;
                        }
                    }
                    Ok(Message::Pong(_)) => {
                        last_activity = Instant::now();
                    }
                    Ok(Message::Close(_)) => {
                        tracing::info!("WebSocket connection closed by client");
                        break;
                    }
                    Err(e) => {
                        tracing::error!("WebSocket error: {}", e);
                        break;
                    }
                }
            }
            _ = tokio::time::sleep_until(next_ping) => {
                // The idle check piggybacks on the ping timer, so a timeout
                // is detected at the first ping tick after it elapses.
                if last_activity.elapsed() > idle_timeout {
                    tracing::info!(
                        "Closing connection due to idle timeout (>{}s)",
                        idle_timeout.as_secs()
                    );
                    let _ = sink.send_message_raw(Message::Close(None)).await;
                    break;
                }
                if let Err(e) = sink.send_message_raw(Message::Ping(Vec::new())).await {
                    tracing::warn!("Failed to send ping: {}", e);
                    break;
                }
                next_ping = Instant::now() + ping_interval;
            }
        }
    }
    tracing::info!("WebSocket connection closed");
}
pub async fn start_server(
router: Router,
context: WsContext,
config: ServerConfig,
) -> anyhow::Result<()> {
let operations = OperationRegistry::with_max_concurrent(config.max_concurrent_ops);
let dispatcher = Dispatcher::new(router, context, operations);
let state = ServerState {
dispatcher: Arc::new(dispatcher),
config: Arc::new(config.clone()),
};
let app = create_axum_router(state);
let bind_address = config.bind_address();
tracing::info!("Starting WebSocket server on {}", bind_address);
let listener = tokio::net::TcpListener::bind(&bind_address).await?;
axum::serve(listener, app).await?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every JSON-RPC-style method name `create_router` must register.
    const EXPECTED_METHODS: &[&str] = &[
        "system.info",
        "time.parse",
        "country.lookup",
        "ip.lookup",
        "ip.public",
        "rpki.validate",
        "rpki.roas",
        "rpki.aspas",
        "as2rel.search",
        "as2rel.relationship",
        "as2rel.update",
        "pfx2as.lookup",
        "database.status",
        "database.refresh",
        "inspect.query",
        "inspect.refresh",
    ];

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.address, "127.0.0.1");
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_concurrent_ops, 10);
        assert_eq!(config.max_message_size, 1024 * 1024);
        assert_eq!(config.connection_timeout_secs, 300);
        assert_eq!(config.ping_interval_secs, 30);
    }

    #[test]
    fn test_server_config_builder() {
        let config = ServerConfig::new().with_address("0.0.0.0").with_port(9000);
        assert_eq!(config.address, "0.0.0.0");
        assert_eq!(config.port, 9000);
        assert_eq!(config.bind_address(), "0.0.0.0:9000");
    }

    #[test]
    fn test_create_router() {
        let router = create_router();
        for &method in EXPECTED_METHODS.iter() {
            assert!(router.has_method(method), "method not registered: {}", method);
        }
        assert!(!router.has_method("unknown.method"));
    }

    #[test]
    fn test_router_streaming_flags() {
        let router = create_router();
        assert!(!router.is_streaming("system.info"));
        assert!(!router.is_streaming("time.parse"));
        assert!(!router.is_streaming("rpki.validate"));
        assert!(!router.is_streaming("unknown.method"));
    }
}