monocle/server/
handler.rs

//! Handler trait and context for WebSocket methods.
//!
//! This module defines the `WsMethod` trait, which every WebSocket method handler
//! must implement, together with `WsContext`, the shared state passed to each handler
//! (currently the path to the monocle data directory), the `WsError`/`WsResult` error
//! types, and `make_handler`, which type-erases a `WsMethod` into a `DynHandler`.
6
7use crate::server::op_sink::WsOpSink;
8use crate::server::protocol::{ErrorCode, ErrorData, RequestEnvelope};
9use async_trait::async_trait;
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12use std::sync::Arc;
13
// =============================================================================
// Context
// =============================================================================

/// Shared state handed to every WebSocket method handler (as `Arc<WsContext>`).
///
/// At present the only shared item is the location of the monocle data
/// directory. Transport policy (message size, timeouts, concurrency limits) is
/// intentionally not stored here; it is owned by `ServerConfig` and enforced by
/// the connection loop / dispatcher layer.
#[derive(Clone, Debug)]
pub struct WsContext {
    /// Path to the monocle data directory
    pub data_dir: String,
}
30
31impl WsContext {
32    /// Create a new WebSocket context
33    ///
34    /// Note: transport policy (message size, timeouts, concurrency limits) is owned by `ServerConfig`
35    /// and enforced in the connection loop / dispatcher layer.
36    pub fn new(data_dir: String) -> Self {
37        Self { data_dir }
38    }
39}
40
41impl Default for WsContext {
42    fn default() -> Self {
43        let home_dir = dirs::home_dir()
44            .map(|h| h.to_string_lossy().to_string())
45            .unwrap_or_else(|| ".".to_string());
46
47        Self::new(format!("{}/.monocle", home_dir))
48    }
49}
50
51// =============================================================================
52// Request
53// =============================================================================
54
55/// Processed WebSocket request with guaranteed ID
56#[derive(Debug, Clone)]
57pub struct WsRequest {
58    /// Request correlation ID (client-provided or server-generated)
59    pub id: String,
60
61    /// Server-generated operation identifier (present for streaming/long operations)
62    pub op_id: Option<String>,
63
64    /// Method name
65    pub method: String,
66
67    /// Raw parameters
68    pub params: Value,
69}
70
71impl WsRequest {
72    /// Create a new request from an envelope, generating an ID if not provided.
73    ///
74    /// Note: `op_id` is assigned by the dispatcher/router for streaming/long operations.
75    pub fn from_envelope(envelope: RequestEnvelope) -> Self {
76        let id = envelope
77            .id
78            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
79        Self {
80            id,
81            op_id: None,
82            method: envelope.method,
83            params: envelope.params,
84        }
85    }
86}
87
88// =============================================================================
89// Handler Trait
90// =============================================================================
91
92/// Result type for WebSocket handlers
93pub type WsResult<T> = Result<T, WsError>;
94
95/// Error type for WebSocket handlers
96#[derive(Debug, Clone)]
97pub struct WsError {
98    /// Error code
99    pub code: ErrorCode,
100    /// Error message
101    pub message: String,
102    /// Optional details
103    pub details: Option<Value>,
104}
105
106impl WsError {
107    /// Create a new error
108    pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
109        Self {
110            code,
111            message: message.into(),
112            details: None,
113        }
114    }
115
116    /// Create an error with details
117    pub fn with_details(code: ErrorCode, message: impl Into<String>, details: Value) -> Self {
118        Self {
119            code,
120            message: message.into(),
121            details: Some(details),
122        }
123    }
124
125    /// Create an invalid params error
126    pub fn invalid_params(message: impl Into<String>) -> Self {
127        Self::new(ErrorCode::InvalidParams, message)
128    }
129
130    /// Create an operation failed error
131    pub fn operation_failed(message: impl Into<String>) -> Self {
132        Self::new(ErrorCode::OperationFailed, message)
133    }
134
135    /// Create a not initialized error
136    pub fn not_initialized(resource: &str) -> Self {
137        Self::new(
138            ErrorCode::NotInitialized,
139            format!("{} data not initialized", resource),
140        )
141    }
142
143    /// Create an internal error
144    pub fn internal(message: impl Into<String>) -> Self {
145        Self::new(ErrorCode::InternalError, message)
146    }
147
148    /// Convert to ErrorData
149    pub fn to_error_data(&self) -> ErrorData {
150        match &self.details {
151            Some(details) => {
152                ErrorData::with_details(self.code, self.message.clone(), details.clone())
153            }
154            None => ErrorData::new(self.code, self.message.clone()),
155        }
156    }
157}
158
159impl std::fmt::Display for WsError {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        write!(f, "{:?}: {}", self.code, self.message)
162    }
163}
164
165impl std::error::Error for WsError {}
166
167impl From<anyhow::Error> for WsError {
168    fn from(err: anyhow::Error) -> Self {
169        Self::operation_failed(err.to_string())
170    }
171}
172
173impl From<serde_json::Error> for WsError {
174    fn from(err: serde_json::Error) -> Self {
175        Self::invalid_params(err.to_string())
176    }
177}
178
179/// Trait for WebSocket method handlers
180///
181/// Each method handler implements this trait to define:
182/// - The method name (e.g., "rpki.validate")
183/// - Whether it's a streaming method
184/// - How to parse and validate parameters
185/// - How to execute the method
186#[async_trait]
187pub trait WsMethod: Send + Sync + 'static {
188    /// Fully qualified method name, e.g., "rpki.validate"
189    const METHOD: &'static str;
190
191    /// Whether this method is streaming (returns progress/stream messages)
192    const IS_STREAMING: bool = false;
193
194    /// Parameter type for this method
195    type Params: DeserializeOwned + Send;
196
197    /// Validate parameters after parsing
198    ///
199    /// Override this to perform additional validation beyond JSON deserialization.
200    fn validate(_params: &Self::Params) -> WsResult<()> {
201        Ok(())
202    }
203
204    /// Execute the method
205    ///
206    /// For non-streaming methods, this should send a single result via the sink.
207    /// For streaming methods, this may send progress/stream messages followed by a result.
208    async fn handle(
209        ctx: Arc<WsContext>,
210        req: WsRequest,
211        params: Self::Params,
212        sink: WsOpSink,
213    ) -> WsResult<()>;
214}
215
216// =============================================================================
217// Handler Registration
218// =============================================================================
219
220/// Type-erased handler function
221pub type DynHandler = Box<
222    dyn Fn(Arc<WsContext>, WsRequest, WsOpSink) -> futures::future::BoxFuture<'static, WsResult<()>>
223        + Send
224        + Sync,
225>;
226
227/// Create a type-erased handler from a WsMethod implementation
228pub fn make_handler<M: WsMethod>() -> DynHandler {
229    Box::new(move |ctx, req, sink| {
230        Box::pin(async move {
231            // Parse parameters
232            let params: M::Params = serde_json::from_value(req.params.clone())?;
233
234            // Validate parameters
235            M::validate(&params)?;
236
237            // Execute handler
238            M::handle(ctx, req, params, sink).await
239        })
240    })
241}
242
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `time.parse` envelope with empty params and the given optional ID.
    fn envelope(id: Option<&str>) -> RequestEnvelope {
        RequestEnvelope {
            id: id.map(str::to_string),
            method: "time.parse".to_string(),
            params: serde_json::json!({}),
        }
    }

    #[test]
    fn test_ws_context_default() {
        let data_dir = WsContext::default().data_dir;
        assert!(data_dir.contains(".monocle"));
    }

    #[test]
    fn test_ws_context_new() {
        let ctx = WsContext::new(String::from("/tmp/test"));
        assert_eq!(ctx.data_dir.as_str(), "/tmp/test");
    }

    #[test]
    fn test_ws_request_from_envelope() {
        // A client-provided ID is preserved verbatim.
        let with_id = WsRequest::from_envelope(envelope(Some("test-id")));
        assert_eq!(with_id.id, "test-id");
        assert!(with_id.op_id.is_none());
        assert_eq!(with_id.method, "time.parse");

        // A missing ID is replaced by a freshly generated one.
        let generated = WsRequest::from_envelope(envelope(None));
        assert!(!generated.id.is_empty());
        assert_ne!(generated.id, "test-id");
    }

    #[test]
    fn test_ws_error_conversion() {
        let err = WsError::invalid_params("missing field");
        assert_eq!(err.code, ErrorCode::InvalidParams);
        assert!(err.message.contains("missing field"));
        assert_eq!(err.to_error_data().code, ErrorCode::InvalidParams);
    }

    #[test]
    fn test_ws_error_from_anyhow() {
        let ws_err = WsError::from(anyhow::anyhow!("something went wrong"));
        assert_eq!(ws_err.code, ErrorCode::OperationFailed);
        assert!(ws_err.message.contains("something went wrong"));
    }
}