1pub mod handler;
45pub mod handlers;
46pub mod op_sink;
47pub mod operations;
48pub mod protocol;
49pub mod query;
50pub mod router;
51pub mod sink;
52
53pub use handler::{WsContext, WsError, WsMethod, WsRequest, WsResult};
55pub use op_sink::{WsOpSink, WsOpSinkError};
56pub use operations::{OperationRegistry, OperationStatus};
57pub use protocol::{
58 ErrorCode, ErrorData, ProgressStage, RequestEnvelope, ResponseEnvelope, ResponseType,
59 SystemInfo,
60};
61pub use router::{Dispatcher, Router};
62pub use sink::{WsSink, WsSinkError};
63
64use axum::{
65 extract::{
66 ws::{Message, WebSocket, WebSocketUpgrade},
67 State,
68 },
69 response::Response,
70 routing::get,
71 Router as AxumRouter,
72};
73use futures::StreamExt;
74use std::sync::Arc;
75use tokio::time::{Duration, Instant};
76use tower_http::cors::{Any, CorsLayer};
77
/// Configuration for the WebSocket server.
///
/// Obtain one from [`ServerConfig::default`] / [`ServerConfig::new`] and
/// adjust it with the `with_*` setters or struct-update syntax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// Host or IP address the TCP listener binds to (combined with `port`
    /// by `bind_address`).
    pub address: String,

    /// TCP port the listener binds to.
    pub port: u16,

    /// Maximum number of concurrent operations; passed to
    /// `OperationRegistry::with_max_concurrent` when the server starts.
    pub max_concurrent_ops: usize,

    /// Maximum size, in bytes, of a single incoming WebSocket message.
    /// A larger message causes the connection to be closed.
    pub max_message_size: usize,

    /// Idle timeout in seconds: a connection that receives nothing for this
    /// long is closed. Values below 1 are treated as 1.
    pub connection_timeout_secs: u64,

    /// Interval in seconds between server-sent WebSocket pings. Values below
    /// 1 are treated as 1.
    pub ping_interval_secs: u64,
}
103
104impl Default for ServerConfig {
105 fn default() -> Self {
106 Self {
107 address: "127.0.0.1".to_string(),
108 port: 8080,
109 max_concurrent_ops: 10,
110 max_message_size: 1024 * 1024, connection_timeout_secs: 300, ping_interval_secs: 30,
113 }
114 }
115}
116
117impl ServerConfig {
118 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn with_address(mut self, address: impl Into<String>) -> Self {
125 self.address = address.into();
126 self
127 }
128
129 pub fn with_port(mut self, port: u16) -> Self {
131 self.port = port;
132 self
133 }
134
135 pub fn bind_address(&self) -> String {
137 format!("{}:{}", self.address, self.port)
138 }
139}
140
141pub fn create_router() -> Router {
147 use handlers::*;
148
149 let mut router = Router::new();
150
151 router.register::<SystemInfoHandler>();
153
154 router.register::<TimeParseHandler>();
156
157 router.register::<CountryLookupHandler>();
159
160 router.register::<IpLookupHandler>();
162 router.register::<IpPublicHandler>();
163
164 router.register::<RpkiValidateHandler>();
166 router.register::<RpkiRoasHandler>();
167 router.register::<RpkiAspasHandler>();
168
169 router.register::<As2relSearchHandler>();
171 router.register::<As2relRelationshipHandler>();
172 router.register::<As2relUpdateHandler>();
173
174 router.register::<Pfx2asLookupHandler>();
176
177 router.register::<DatabaseStatusHandler>();
179 router.register::<DatabaseRefreshHandler>();
180
181 router.register::<InspectQueryHandler>();
183 router.register::<InspectRefreshHandler>();
184
185 router
186}
187
/// Shared state handed to the axum handlers through `State<ServerState>`.
///
/// Both fields are behind `Arc`, so cloning the state only bumps reference
/// counts rather than copying the dispatcher or configuration.
#[derive(Clone)]
pub struct ServerState {
    /// Dispatches incoming WebSocket payloads to the registered handlers.
    pub dispatcher: Arc<Dispatcher>,

    /// Server configuration; read per connection for the message size
    /// limit, ping interval and idle timeout.
    pub config: Arc<ServerConfig>,
}
201
202pub fn create_axum_router(state: ServerState) -> AxumRouter {
208 let cors = CorsLayer::new()
210 .allow_origin(Any)
211 .allow_methods(Any)
212 .allow_headers(Any);
213
214 AxumRouter::new()
215 .route("/ws", get(ws_handler))
216 .route("/health", get(health_handler))
217 .layer(cors)
218 .with_state(state)
219}
220
/// Liveness probe for `GET /health`; always answers with the body `OK`.
async fn health_handler() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
225
226async fn ws_handler(ws: WebSocketUpgrade, State(state): State<ServerState>) -> Response {
228 ws.on_upgrade(move |socket| handle_socket(socket, state))
229}
230
231async fn handle_socket(socket: WebSocket, state: ServerState) {
233 let (sender, mut receiver) = socket.split();
234 let sink = WsSink::new(sender);
235
236 tracing::info!("WebSocket connection established");
237
238 let max_message_size = state.config.max_message_size;
239 let ping_interval = Duration::from_secs(state.config.ping_interval_secs.max(1));
240 let idle_timeout = Duration::from_secs(state.config.connection_timeout_secs.max(1));
241
242 let mut last_activity = Instant::now();
243 let mut next_ping = Instant::now() + ping_interval;
244
245 loop {
247 tokio::select! {
248 maybe_msg = receiver.next() => {
249 let Some(msg) = maybe_msg else {
250 break;
251 };
252
253 match msg {
254 Ok(Message::Text(text)) => {
255 if text.len() > max_message_size {
256 tracing::warn!(
257 "Closing connection: text message too large ({} > {} bytes)",
258 text.len(),
259 max_message_size
260 );
261 let _ = sink.send_message_raw(Message::Close(None)).await;
262 break;
263 }
264 last_activity = Instant::now();
265 tracing::debug!("Received message: {}", text);
266 state.dispatcher.dispatch(&text, sink.clone()).await;
267 }
268 Ok(Message::Binary(data)) => {
269 if data.len() > max_message_size {
270 tracing::warn!(
271 "Closing connection: binary message too large ({} > {} bytes)",
272 data.len(),
273 max_message_size
274 );
275 let _ = sink.send_message_raw(Message::Close(None)).await;
276 break;
277 }
278 last_activity = Instant::now();
279
280 match String::from_utf8(data) {
282 Ok(text) => {
283 tracing::debug!("Received binary message as text: {}", text);
284 state.dispatcher.dispatch(&text, sink.clone()).await;
285 }
286 Err(_) => {
287 tracing::warn!("Received non-UTF8 binary message, ignoring");
288 }
289 }
290 }
291 Ok(Message::Ping(data)) => {
292 last_activity = Instant::now();
293 if let Err(e) = sink.send_message_raw(Message::Pong(data)).await {
295 tracing::warn!("Failed to send pong: {}", e);
296 break;
297 }
298 }
299 Ok(Message::Pong(_)) => {
300 last_activity = Instant::now();
301 }
303 Ok(Message::Close(_)) => {
304 tracing::info!("WebSocket connection closed by client");
305 break;
306 }
307 Err(e) => {
308 tracing::error!("WebSocket error: {}", e);
309 break;
310 }
311 }
312 }
313
314 _ = tokio::time::sleep_until(next_ping) => {
315 if last_activity.elapsed() > idle_timeout {
317 tracing::info!(
318 "Closing connection due to idle timeout (>{}s)",
319 idle_timeout.as_secs()
320 );
321 let _ = sink.send_message_raw(Message::Close(None)).await;
322 break;
323 }
324
325 if let Err(e) = sink.send_message_raw(Message::Ping(Vec::new())).await {
327 tracing::warn!("Failed to send ping: {}", e);
328 break;
329 }
330
331 next_ping = Instant::now() + ping_interval;
332 }
333 }
334 }
335
336 tracing::info!("WebSocket connection closed");
337}
338
339pub async fn start_server(
345 router: Router,
346 context: WsContext,
347 config: ServerConfig,
348) -> anyhow::Result<()> {
349 let operations = OperationRegistry::with_max_concurrent(config.max_concurrent_ops);
350 let dispatcher = Dispatcher::new(router, context, operations);
351
352 let state = ServerState {
353 dispatcher: Arc::new(dispatcher),
354 config: Arc::new(config.clone()),
355 };
356
357 let app = create_axum_router(state);
358
359 let bind_address = config.bind_address();
360 tracing::info!("Starting WebSocket server on {}", bind_address);
361
362 let listener = tokio::net::TcpListener::bind(&bind_address).await?;
363 axum::serve(listener, app).await?;
364
365 Ok(())
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Every method name `create_router` is expected to register.
    const REGISTERED_METHODS: [&str; 16] = [
        "system.info",
        "time.parse",
        "country.lookup",
        "ip.lookup",
        "ip.public",
        "rpki.validate",
        "rpki.roas",
        "rpki.aspas",
        "as2rel.search",
        "as2rel.relationship",
        "as2rel.update",
        "pfx2as.lookup",
        "database.status",
        "database.refresh",
        "inspect.query",
        "inspect.refresh",
    ];

    /// Methods (including an unregistered one) that must not be streaming.
    const NON_STREAMING_METHODS: [&str; 4] =
        ["system.info", "time.parse", "rpki.validate", "unknown.method"];

    #[test]
    fn test_server_config_default() {
        let ServerConfig {
            address,
            port,
            max_concurrent_ops,
            ..
        } = ServerConfig::default();

        assert_eq!(address, "127.0.0.1");
        assert_eq!(port, 8080);
        assert_eq!(max_concurrent_ops, 10);
    }

    #[test]
    fn test_server_config_builder() {
        let built = ServerConfig::new().with_port(9000).with_address("0.0.0.0");

        assert_eq!(built.address, "0.0.0.0");
        assert_eq!(built.port, 9000);
        assert_eq!(built.bind_address(), "0.0.0.0:9000");
    }

    #[test]
    fn test_create_router() {
        let router = create_router();

        for method in REGISTERED_METHODS {
            assert!(router.has_method(method), "expected `{}` to be registered", method);
        }
        assert!(!router.has_method("unknown.method"));
    }

    #[test]
    fn test_router_streaming_flags() {
        let router = create_router();

        for method in NON_STREAMING_METHODS {
            assert!(!router.is_streaming(method), "`{}` should not be streaming", method);
        }
    }
}