Skip to main content

monocle/server/
mod.rs

1//! WebSocket server module for Monocle
2//!
3//! This module provides a WebSocket API server for Monocle, enabling real-time
4//! communication with clients for BGP data operations.
5//!
6//! # Architecture
7//!
8//! The server is organized into several submodules:
9//!
10//! - `protocol` - Protocol types (request/response envelopes, error codes)
11//! - `query` - Non-core protocol helper types (pagination/filters) used by query/streaming methods
12//! - `handler` - Handler trait and context for method implementations
13//! - `sink` - WebSocket sink abstraction for typed envelope writing (transport-level)
14//! - `op_sink` - Operation-scoped sink enforcing streaming terminal semantics (protocol-level)
15//! - `router` - Registry-based method routing
16//! - `operations` - Operation registry for streaming operations and cancellation
17//! - `handlers` - Individual method handler implementations
18//!
19//! # Connection lifecycle
20//!
21//! The WebSocket connection loop enforces:
22//! - max message size (`ServerConfig.max_message_size`)
23//! - periodic ping keepalive (`ServerConfig.ping_interval_secs`)
24//! - idle timeout (`ServerConfig.connection_timeout_secs`)
25//!
26//! # Usage
27//!
28//! ```rust,ignore
29//! use monocle::server::{create_router, start_server, ServerConfig, WsContext};
30//! use monocle::config::MonocleConfig;
31//!
32//! // Create the router with all handlers registered
33//! let router = create_router();
34//!
35//! // Create context from config
36//! let config = MonocleConfig::new(&None)?;
37//! let context = WsContext::from_config(config);
38//!
39//! // Start the server
40//! let server_config = ServerConfig::default();
41//! start_server(router, context, server_config).await?;
42//! ```
43
44pub mod handler;
45pub mod handlers;
46pub mod op_sink;
47pub mod operations;
48pub mod protocol;
49pub mod query;
50pub mod router;
51pub mod sink;
52
53// Re-export commonly used types
54pub use handler::{WsContext, WsError, WsMethod, WsRequest, WsResult};
55pub use op_sink::{WsOpSink, WsOpSinkError};
56pub use operations::{OperationRegistry, OperationStatus};
57pub use protocol::{
58    ErrorCode, ErrorData, ProgressStage, RequestEnvelope, ResponseEnvelope, ResponseType,
59    SystemInfo,
60};
61pub use router::{Dispatcher, Router};
62pub use sink::{WsSink, WsSinkError};
63
64use axum::{
65    extract::{
66        ws::{Message, WebSocket, WebSocketUpgrade},
67        State,
68    },
69    response::Response,
70    routing::get,
71    Router as AxumRouter,
72};
73use futures::StreamExt;
74use std::sync::Arc;
75use tokio::time::{Duration, Instant};
76use tower_http::cors::{Any, CorsLayer};
77
78// =============================================================================
79// Server Configuration
80// =============================================================================
81
/// Runtime settings for the Monocle WebSocket server.
///
/// [`ServerConfig::default`] produces a loopback-only configuration on port
/// 8080; the `with_*` methods override individual fields builder-style.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host name or IP address to bind to. IPv6 literals may be given with or
    /// without surrounding brackets (see [`ServerConfig::bind_address`]).
    pub address: String,

    /// Port to listen on
    pub port: u16,

    /// Maximum concurrent operations per connection
    pub max_concurrent_ops: usize,

    /// Maximum message size in bytes
    pub max_message_size: usize,

    /// Connection timeout in seconds
    pub connection_timeout_secs: u64,

    /// Ping interval in seconds
    pub ping_interval_secs: u64,
}

impl Default for ServerConfig {
    /// Loopback address, port 8080, 10 concurrent operations, 1 MiB messages,
    /// a 5 minute connection timeout and a 30 second ping interval.
    fn default() -> Self {
        Self {
            address: "127.0.0.1".to_string(),
            port: 8080,
            max_concurrent_ops: 10,
            max_message_size: 1024 * 1024, // 1MB
            connection_timeout_secs: 300,  // 5 minutes
            ping_interval_secs: 30,
        }
    }
}

impl ServerConfig {
    /// Create a new server configuration (identical to [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the address to bind to and return the updated configuration.
    pub fn with_address(mut self, address: impl Into<String>) -> Self {
        self.address = address.into();
        self
    }

    /// Set the port to listen on and return the updated configuration.
    pub fn with_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Get the full bind address as `host:port`.
    ///
    /// A bare IPv6 literal (e.g. `::1`) is wrapped in brackets, giving
    /// `[::1]:8080`. IPv4 addresses, host names and already-bracketed input
    /// are joined with the port unchanged.
    pub fn bind_address(&self) -> String {
        // An IPv6 literal contains colons itself, so a plain `addr:port` join
        // would not be a valid socket address.
        if self.address.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{}]:{}", self.address, self.port)
        } else {
            format!("{}:{}", self.address, self.port)
        }
    }
}
140
141// =============================================================================
142// Router Creation
143// =============================================================================
144
145/// Create a router with all handlers registered
146pub fn create_router() -> Router {
147    use handlers::*;
148
149    let mut router = Router::new();
150
151    // System handlers
152    router.register::<SystemInfoHandler>();
153
154    // Time handlers
155    router.register::<TimeParseHandler>();
156
157    // Country handlers
158    router.register::<CountryLookupHandler>();
159
160    // IP handlers
161    router.register::<IpLookupHandler>();
162    router.register::<IpPublicHandler>();
163
164    // RPKI handlers
165    router.register::<RpkiValidateHandler>();
166    router.register::<RpkiRoasHandler>();
167    router.register::<RpkiAspasHandler>();
168
169    // AS2Rel handlers
170    router.register::<As2relSearchHandler>();
171    router.register::<As2relRelationshipHandler>();
172    router.register::<As2relUpdateHandler>();
173
174    // Pfx2as handlers
175    router.register::<Pfx2asLookupHandler>();
176
177    // Database handlers
178    router.register::<DatabaseStatusHandler>();
179    router.register::<DatabaseRefreshHandler>();
180
181    // Inspect handlers
182    router.register::<InspectQueryHandler>();
183    router.register::<InspectRefreshHandler>();
184
185    router
186}
187
188// =============================================================================
189// Server State
190// =============================================================================
191
192/// Shared server state
193#[derive(Clone)]
194pub struct ServerState {
195    /// Dispatcher for routing messages
196    pub dispatcher: Arc<Dispatcher>,
197
198    /// Server configuration
199    pub config: Arc<ServerConfig>,
200}
201
202// =============================================================================
203// Axum Router Creation
204// =============================================================================
205
206/// Create the Axum router for the WebSocket server
207pub fn create_axum_router(state: ServerState) -> AxumRouter {
208    // Configure CORS
209    let cors = CorsLayer::new()
210        .allow_origin(Any)
211        .allow_methods(Any)
212        .allow_headers(Any);
213
214    AxumRouter::new()
215        .route("/ws", get(ws_handler))
216        .route("/health", get(health_handler))
217        .layer(cors)
218        .with_state(state)
219}
220
/// Liveness probe: always replies with the static body `OK`.
async fn health_handler() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
225
226/// WebSocket upgrade handler
227async fn ws_handler(ws: WebSocketUpgrade, State(state): State<ServerState>) -> Response {
228    ws.on_upgrade(move |socket| handle_socket(socket, state))
229}
230
231/// Handle a WebSocket connection
232async fn handle_socket(socket: WebSocket, state: ServerState) {
233    let (sender, mut receiver) = socket.split();
234    let sink = WsSink::new(sender);
235
236    tracing::info!("WebSocket connection established");
237
238    let max_message_size = state.config.max_message_size;
239    let ping_interval = Duration::from_secs(state.config.ping_interval_secs.max(1));
240    let idle_timeout = Duration::from_secs(state.config.connection_timeout_secs.max(1));
241
242    let mut last_activity = Instant::now();
243    let mut next_ping = Instant::now() + ping_interval;
244
245    // Connection loop: enforce max message size, periodic ping keepalive, and idle timeout.
246    loop {
247        tokio::select! {
248            maybe_msg = receiver.next() => {
249                let Some(msg) = maybe_msg else {
250                    break;
251                };
252
253                match msg {
254                    Ok(Message::Text(text)) => {
255                        if text.len() > max_message_size {
256                            tracing::warn!(
257                                "Closing connection: text message too large ({} > {} bytes)",
258                                text.len(),
259                                max_message_size
260                            );
261                            let _ = sink.send_message_raw(Message::Close(None)).await;
262                            break;
263                        }
264                        last_activity = Instant::now();
265                        tracing::debug!("Received message: {}", text);
266                        state.dispatcher.dispatch(&text, sink.clone()).await;
267                    }
268                    Ok(Message::Binary(data)) => {
269                        if data.len() > max_message_size {
270                            tracing::warn!(
271                                "Closing connection: binary message too large ({} > {} bytes)",
272                                data.len(),
273                                max_message_size
274                            );
275                            let _ = sink.send_message_raw(Message::Close(None)).await;
276                            break;
277                        }
278                        last_activity = Instant::now();
279
280                        // Try to parse binary as UTF-8 text
281                        match String::from_utf8(data) {
282                            Ok(text) => {
283                                tracing::debug!("Received binary message as text: {}", text);
284                                state.dispatcher.dispatch(&text, sink.clone()).await;
285                            }
286                            Err(_) => {
287                                tracing::warn!("Received non-UTF8 binary message, ignoring");
288                            }
289                        }
290                    }
291                    Ok(Message::Ping(data)) => {
292                        last_activity = Instant::now();
293                        // Respond with pong
294                        if let Err(e) = sink.send_message_raw(Message::Pong(data)).await {
295                            tracing::warn!("Failed to send pong: {}", e);
296                            break;
297                        }
298                    }
299                    Ok(Message::Pong(_)) => {
300                        last_activity = Instant::now();
301                        // Ignore pong responses
302                    }
303                    Ok(Message::Close(_)) => {
304                        tracing::info!("WebSocket connection closed by client");
305                        break;
306                    }
307                    Err(e) => {
308                        tracing::error!("WebSocket error: {}", e);
309                        break;
310                    }
311                }
312            }
313
314            _ = tokio::time::sleep_until(next_ping) => {
315                // Idle timeout check
316                if last_activity.elapsed() > idle_timeout {
317                    tracing::info!(
318                        "Closing connection due to idle timeout (>{}s)",
319                        idle_timeout.as_secs()
320                    );
321                    let _ = sink.send_message_raw(Message::Close(None)).await;
322                    break;
323                }
324
325                // Periodic ping keepalive
326                if let Err(e) = sink.send_message_raw(Message::Ping(Vec::new())).await {
327                    tracing::warn!("Failed to send ping: {}", e);
328                    break;
329                }
330
331                next_ping = Instant::now() + ping_interval;
332            }
333        }
334    }
335
336    tracing::info!("WebSocket connection closed");
337}
338
339// =============================================================================
340// Server Startup
341// =============================================================================
342
343/// Start the WebSocket server
344pub async fn start_server(
345    router: Router,
346    context: WsContext,
347    config: ServerConfig,
348) -> anyhow::Result<()> {
349    let operations = OperationRegistry::with_max_concurrent(config.max_concurrent_ops);
350    let dispatcher = Dispatcher::new(router, context, operations);
351
352    let state = ServerState {
353        dispatcher: Arc::new(dispatcher),
354        config: Arc::new(config.clone()),
355    };
356
357    let app = create_axum_router(state);
358
359    let bind_address = config.bind_address();
360    tracing::info!("Starting WebSocket server on {}", bind_address);
361
362    let listener = tokio::net::TcpListener::bind(&bind_address).await?;
363    axum::serve(listener, app).await?;
364
365    Ok(())
366}
367
368// =============================================================================
369// Tests
370// =============================================================================
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Method names that `create_router` is expected to register.
    const REGISTERED_METHODS: &[&str] = &[
        "system.info",
        "time.parse",
        "country.lookup",
        "ip.lookup",
        "ip.public",
        "rpki.validate",
        "rpki.roas",
        "rpki.aspas",
        "as2rel.search",
        "as2rel.relationship",
        "as2rel.update",
        "pfx2as.lookup",
        "database.status",
        "database.refresh",
        "inspect.query",
        "inspect.refresh",
    ];

    /// Methods known to be request/response (non-streaming).
    const NON_STREAMING_METHODS: &[&str] = &["system.info", "time.parse", "rpki.validate"];

    /// Every default is pinned, so an accidental change to any field is caught.
    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.address, "127.0.0.1");
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_concurrent_ops, 10);
        assert_eq!(config.max_message_size, 1024 * 1024);
        assert_eq!(config.connection_timeout_secs, 300);
        assert_eq!(config.ping_interval_secs, 30);
    }

    /// Builder methods override only the field they name.
    #[test]
    fn test_server_config_builder() {
        let config = ServerConfig::new().with_address("0.0.0.0").with_port(9000);

        assert_eq!(config.address, "0.0.0.0");
        assert_eq!(config.port, 9000);
        assert_eq!(config.bind_address(), "0.0.0.0:9000");
        // Untouched fields keep their defaults.
        assert_eq!(config.max_concurrent_ops, 10);
    }

    #[test]
    fn test_create_router() {
        let router = create_router();

        // Check that key methods are registered
        for &method in REGISTERED_METHODS {
            assert!(
                router.has_method(method),
                "expected method `{}` to be registered",
                method
            );
        }

        // Check that unknown methods return false
        assert!(!router.has_method("unknown.method"));
    }

    #[test]
    fn test_router_streaming_flags() {
        let router = create_router();

        // Non-streaming methods
        for &method in NON_STREAMING_METHODS {
            assert!(
                !router.is_streaming(method),
                "expected method `{}` to be non-streaming",
                method
            );
        }

        // Unknown methods should return false
        assert!(!router.is_streaming("unknown.method"));
    }
}