Skip to main content

monocle/server/
handler.rs

1//! Handler trait and context module for WebSocket methods
2//!
3//! This module defines the `WsMethod` trait which all WebSocket method handlers
4//! must implement, along with the `WsContext` which provides access to shared
5//! resources like database handles and configuration.
6
7use crate::server::op_sink::WsOpSink;
8use crate::server::protocol::{ErrorCode, ErrorData, RequestEnvelope};
9use async_trait::async_trait;
10use serde::de::DeserializeOwned;
11use serde_json::Value;
12use std::sync::Arc;
13
14// =============================================================================
15// Context
16// =============================================================================
17
18use crate::config::MonocleConfig;
19
20/// WebSocket context providing access to shared resources
21///
22/// This context is passed to all handlers and provides access to:
23/// - Database handles (MonocleDatabase)
24/// - Configuration settings
25/// - Operation registry for cancellation
26/// - Rate limiting state
27#[derive(Clone)]
28pub struct WsContext {
29    /// Monocle configuration (includes data_dir and cache TTLs)
30    pub config: MonocleConfig,
31}
32
33impl WsContext {
34    /// Create a new WebSocket context from MonocleConfig
35    pub fn from_config(config: MonocleConfig) -> Self {
36        Self { config }
37    }
38
39    /// Get the data directory path
40    pub fn data_dir(&self) -> &str {
41        &self.config.data_dir
42    }
43}
44
45impl Default for WsContext {
46    fn default() -> Self {
47        Self::from_config(MonocleConfig::default())
48    }
49}
50
51// =============================================================================
52// Request
53// =============================================================================
54
55/// Processed WebSocket request with guaranteed ID
56#[derive(Debug, Clone)]
57pub struct WsRequest {
58    /// Request correlation ID (client-provided or server-generated)
59    pub id: String,
60
61    /// Server-generated operation identifier (present for streaming/long operations)
62    pub op_id: Option<String>,
63
64    /// Method name
65    pub method: String,
66
67    /// Raw parameters
68    pub params: Value,
69}
70
71impl WsRequest {
72    /// Create a new request from an envelope, generating an ID if not provided.
73    ///
74    /// Note: `op_id` is assigned by the dispatcher/router for streaming/long operations.
75    pub fn from_envelope(envelope: RequestEnvelope) -> Self {
76        let id = envelope
77            .id
78            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
79        Self {
80            id,
81            op_id: None,
82            method: envelope.method,
83            params: envelope.params,
84        }
85    }
86}
87
88// =============================================================================
89// Handler Trait
90// =============================================================================
91
92/// Result type for WebSocket handlers
93pub type WsResult<T> = Result<T, WsError>;
94
95/// Error type for WebSocket handlers
96#[derive(Debug, Clone)]
97pub struct WsError {
98    /// Error code
99    pub code: ErrorCode,
100    /// Error message
101    pub message: String,
102    /// Optional details
103    pub details: Option<Value>,
104}
105
106impl WsError {
107    /// Create a new error
108    pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
109        Self {
110            code,
111            message: message.into(),
112            details: None,
113        }
114    }
115
116    /// Create an error with details
117    pub fn with_details(code: ErrorCode, message: impl Into<String>, details: Value) -> Self {
118        Self {
119            code,
120            message: message.into(),
121            details: Some(details),
122        }
123    }
124
125    /// Create an invalid params error
126    pub fn invalid_params(message: impl Into<String>) -> Self {
127        Self::new(ErrorCode::InvalidParams, message)
128    }
129
130    /// Create an operation failed error
131    pub fn operation_failed(message: impl Into<String>) -> Self {
132        Self::new(ErrorCode::OperationFailed, message)
133    }
134
135    /// Create a not initialized error
136    pub fn not_initialized(resource: &str) -> Self {
137        Self::new(
138            ErrorCode::NotInitialized,
139            format!("{} data not initialized", resource),
140        )
141    }
142
143    /// Create an internal error
144    pub fn internal(message: impl Into<String>) -> Self {
145        Self::new(ErrorCode::InternalError, message)
146    }
147
148    /// Convert to ErrorData
149    pub fn to_error_data(&self) -> ErrorData {
150        match &self.details {
151            Some(details) => {
152                ErrorData::with_details(self.code, self.message.clone(), details.clone())
153            }
154            None => ErrorData::new(self.code, self.message.clone()),
155        }
156    }
157}
158
159impl std::fmt::Display for WsError {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        write!(f, "{:?}: {}", self.code, self.message)
162    }
163}
164
165impl std::error::Error for WsError {}
166
167impl From<anyhow::Error> for WsError {
168    fn from(err: anyhow::Error) -> Self {
169        Self::operation_failed(err.to_string())
170    }
171}
172
173impl From<serde_json::Error> for WsError {
174    fn from(err: serde_json::Error) -> Self {
175        Self::invalid_params(err.to_string())
176    }
177}
178
179/// Trait for WebSocket method handlers
180///
181/// Each method handler implements this trait to define:
182/// - The method name (e.g., "rpki.validate")
183/// - Whether it's a streaming method
184/// - How to parse and validate parameters
185/// - How to execute the method
186#[async_trait]
187pub trait WsMethod: Send + Sync + 'static {
188    /// Fully qualified method name, e.g., "rpki.validate"
189    const METHOD: &'static str;
190
191    /// Whether this method is streaming (returns progress/stream messages)
192    const IS_STREAMING: bool = false;
193
194    /// Parameter type for this method
195    type Params: DeserializeOwned + Send;
196
197    /// Validate parameters after parsing
198    ///
199    /// Override this to perform additional validation beyond JSON deserialization.
200    fn validate(_params: &Self::Params) -> WsResult<()> {
201        Ok(())
202    }
203
204    /// Execute the method
205    ///
206    /// For non-streaming methods, this should send a single result via the sink.
207    /// For streaming methods, this may send progress/stream messages followed by a result.
208    async fn handle(
209        ctx: Arc<WsContext>,
210        req: WsRequest,
211        params: Self::Params,
212        sink: WsOpSink,
213    ) -> WsResult<()>;
214}
215
216// =============================================================================
217// Handler Registration
218// =============================================================================
219
220/// Type-erased handler function
221pub type DynHandler = Box<
222    dyn Fn(Arc<WsContext>, WsRequest, WsOpSink) -> futures::future::BoxFuture<'static, WsResult<()>>
223        + Send
224        + Sync,
225>;
226
227/// Create a type-erased handler from a WsMethod implementation
228pub fn make_handler<M: WsMethod>() -> DynHandler {
229    Box::new(move |ctx, req, sink| {
230        Box::pin(async move {
231            // Parse parameters
232            let params: M::Params = serde_json::from_value(req.params.clone())?;
233
234            // Validate parameters
235            M::validate(&params)?;
236
237            // Execute handler
238            M::handle(ctx, req, params, sink).await
239        })
240    })
241}
242
243// =============================================================================
244// Tests
245// =============================================================================
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `time.parse` envelope with empty params and the given optional ID.
    fn time_parse_envelope(id: Option<&str>) -> RequestEnvelope {
        RequestEnvelope {
            id: id.map(String::from),
            method: "time.parse".to_string(),
            params: serde_json::json!({}),
        }
    }

    #[test]
    fn test_ws_context_default() {
        let ctx = WsContext::default();
        let dir = ctx.data_dir();
        assert!(dir.contains("monocle"), "unexpected default data_dir: {}", dir);
    }

    #[test]
    fn test_ws_context_from_config() {
        let config = MonocleConfig::default();
        let expected = config.data_dir.clone();
        let ctx = WsContext::from_config(config);
        assert_eq!(ctx.data_dir(), &expected);
    }

    #[test]
    fn test_ws_request_from_envelope() {
        // An explicit ID is kept verbatim and no op_id is assigned yet.
        let with_id = WsRequest::from_envelope(time_parse_envelope(Some("test-id")));
        assert_eq!(with_id.id, "test-id");
        assert_eq!(with_id.op_id, None);
        assert_eq!(with_id.method, "time.parse");

        // A missing ID is replaced by a freshly generated one.
        let generated = WsRequest::from_envelope(time_parse_envelope(None));
        assert!(!generated.id.is_empty());
        assert_ne!(generated.id, "test-id");
    }

    #[test]
    fn test_ws_error_conversion() {
        let err = WsError::invalid_params("missing field");
        assert_eq!(err.code, ErrorCode::InvalidParams);
        assert!(err.message.contains("missing field"));
        assert_eq!(err.to_error_data().code, ErrorCode::InvalidParams);
    }

    #[test]
    fn test_ws_error_from_anyhow() {
        let ws_err = WsError::from(anyhow::anyhow!("something went wrong"));
        assert_eq!(ws_err.code, ErrorCode::OperationFailed);
        assert!(ws_err.message.contains("something went wrong"));
    }
}