Skip to main content

nestforge_websockets/
lib.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::Arc,
5};
6
7use anyhow::Result;
8use axum::{
9    extract::Extension,
10    http::HeaderMap,
11    routing::get,
12    Router,
13};
14use nestforge_core::{AuthIdentity, Container, RequestId};
15use nestforge_microservices::{
16    EventEnvelope, MessageEnvelope, MicroserviceContext, MicroserviceRegistry, TransportMetadata,
17};
18use serde_json::Value;
19
20pub use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket, WebSocketUpgrade};
21
/// Boxed, `Send` future returned by [`WebSocketGateway::on_connect`].
///
/// Exported so gateway implementations outside this crate can name the return type
/// instead of spelling out `Pin<Box<dyn Future<Output = ()> + Send>>` by hand.
pub type WebSocketFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
23
24pub trait WebSocketGateway: Send + Sync + 'static {
25    fn on_connect(&self, ctx: WebSocketContext, socket: WebSocket) -> WebSocketFuture;
26}
27
28#[derive(Debug, Clone)]
29pub struct WebSocketConfig {
30    pub endpoint: String,
31}
32
33impl Default for WebSocketConfig {
34    fn default() -> Self {
35        Self {
36            endpoint: "/ws".to_string(),
37        }
38    }
39}
40
41impl WebSocketConfig {
42    pub fn new(endpoint: impl Into<String>) -> Self {
43        Self {
44            endpoint: normalize_path(endpoint.into(), "/ws"),
45        }
46    }
47}
48
49#[derive(Clone)]
50pub struct WebSocketContext {
51    container: Container,
52    request_id: Option<RequestId>,
53    auth_identity: Option<AuthIdentity>,
54    headers: HeaderMap,
55}
56
57impl WebSocketContext {
58    pub fn new(
59        container: Container,
60        request_id: Option<RequestId>,
61        auth_identity: Option<AuthIdentity>,
62        headers: HeaderMap,
63    ) -> Self {
64        Self {
65            container,
66            request_id,
67            auth_identity,
68            headers,
69        }
70    }
71
72    pub fn container(&self) -> &Container {
73        &self.container
74    }
75
76    pub fn request_id(&self) -> Option<&RequestId> {
77        self.request_id.as_ref()
78    }
79
80    pub fn auth_identity(&self) -> Option<&AuthIdentity> {
81        self.auth_identity.as_ref()
82    }
83
84    pub fn headers(&self) -> &HeaderMap {
85        &self.headers
86    }
87
88    pub fn resolve<T>(&self) -> Result<Arc<T>>
89    where
90        T: Send + Sync + 'static,
91    {
92        Ok(self.container.resolve::<T>()?)
93    }
94
95    pub fn is_authenticated(&self) -> bool {
96        self.auth_identity.is_some()
97    }
98
99    pub fn has_role(&self, role: &str) -> bool {
100        self.auth_identity
101            .as_ref()
102            .map(|identity| identity.roles.iter().any(|value| value == role))
103            .unwrap_or(false)
104    }
105
106    pub fn microservice_context(
107        &self,
108        pattern: impl Into<String>,
109        metadata: TransportMetadata,
110    ) -> MicroserviceContext {
111        MicroserviceContext::new(self.container.clone(), "websocket", pattern, metadata)
112    }
113}
114
115#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
116#[serde(rename_all = "snake_case")]
117pub enum WebSocketMicroserviceKind {
118    Message,
119    Event,
120}
121
122#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
123pub struct WebSocketMicroserviceFrame {
124    pub kind: WebSocketMicroserviceKind,
125    pub pattern: String,
126    pub payload: Value,
127    #[serde(default)]
128    pub metadata: TransportMetadata,
129}
130
131#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
132pub struct WebSocketMicroserviceResponse {
133    pub pattern: String,
134    pub payload: Value,
135    #[serde(default)]
136    pub metadata: TransportMetadata,
137}
138
139pub async fn handle_websocket_microservice_message(
140    ctx: &WebSocketContext,
141    registry: &MicroserviceRegistry,
142    message: Message,
143) -> Result<Option<Message>> {
144    let frame = parse_websocket_microservice_frame(message)?;
145    let microservice_ctx = ctx.microservice_context(frame.pattern.clone(), frame.metadata.clone());
146
147    match frame.kind {
148        WebSocketMicroserviceKind::Message => {
149            let payload = registry
150                .dispatch_message(
151                    MessageEnvelope {
152                        pattern: frame.pattern.clone(),
153                        payload: frame.payload,
154                        metadata: frame.metadata.clone(),
155                    },
156                    microservice_ctx,
157                )
158                .await?;
159            let response = WebSocketMicroserviceResponse {
160                pattern: frame.pattern,
161                payload,
162                metadata: frame.metadata,
163            };
164            Ok(Some(Message::Text(serde_json::to_string(&response)?.into())))
165        }
166        WebSocketMicroserviceKind::Event => {
167            registry
168                .dispatch_event(
169                    EventEnvelope {
170                        pattern: frame.pattern,
171                        payload: frame.payload,
172                        metadata: frame.metadata,
173                    },
174                    microservice_ctx,
175                )
176                .await?;
177            Ok(None)
178        }
179    }
180}
181
182fn parse_websocket_microservice_frame(message: Message) -> Result<WebSocketMicroserviceFrame> {
183    match message {
184        Message::Text(text) => Ok(serde_json::from_str(&text)?),
185        Message::Binary(bytes) => Ok(serde_json::from_slice(&bytes)?),
186        other => anyhow::bail!("Unsupported websocket microservice message: {other:?}"),
187    }
188}
189
190pub fn websocket_gateway_router<G>(gateway: G) -> Router<Container>
191where
192    G: WebSocketGateway,
193{
194    websocket_gateway_router_with_config(gateway, WebSocketConfig::default())
195}
196
197pub fn websocket_gateway_router_with_config<G>(
198    gateway: G,
199    config: WebSocketConfig,
200) -> Router<Container>
201where
202    G: WebSocketGateway,
203{
204    let gateway = Arc::new(gateway);
205    Router::new().route(
206        &config.endpoint,
207        get(move |ws: WebSocketUpgrade,
208                  Extension(container): Extension<Container>,
209                  headers: HeaderMap,
210                  Extension(request_id): Extension<RequestId>| {
211                let gateway = Arc::clone(&gateway);
212                let auth_identity = container
213                    .resolve::<AuthIdentity>()
214                    .ok()
215                    .map(|value| (*value).clone());
216                async move {
217                    let context =
218                        WebSocketContext::new(container, Some(request_id), auth_identity, headers);
219                    ws.on_upgrade(move |socket| async move {
220                        gateway.on_connect(context, socket).await;
221                    })
222                }
223            },
224        ),
225    )
226}
227
228pub fn websocket_router<F, Fut>(handler: F) -> Router<Container>
229where
230    F: Fn(WebSocketContext, WebSocket) -> Fut + Clone + Send + Sync + 'static,
231    Fut: Future<Output = ()> + Send + 'static,
232{
233    websocket_router_with_config(handler, WebSocketConfig::default())
234}
235
236pub fn websocket_router_with_config<F, Fut>(
237    handler: F,
238    config: WebSocketConfig,
239) -> Router<Container>
240where
241    F: Fn(WebSocketContext, WebSocket) -> Fut + Clone + Send + Sync + 'static,
242    Fut: Future<Output = ()> + Send + 'static,
243{
244    Router::new().route(
245        &config.endpoint,
246        get(move |ws: WebSocketUpgrade,
247                  Extension(container): Extension<Container>,
248                  headers: HeaderMap,
249                  Extension(request_id): Extension<RequestId>| {
250                let handler = handler.clone();
251                let auth_identity = container
252                    .resolve::<AuthIdentity>()
253                    .ok()
254                    .map(|value| (*value).clone());
255                async move {
256                    let context =
257                        WebSocketContext::new(container, Some(request_id), auth_identity, headers);
258                    ws.on_upgrade(move |socket| handler(context, socket))
259                }
260            }),
261    )
262}
263
/// Normalizes a route path.
///
/// The path is trimmed. An empty or `/`-only result is replaced by `fallback`, and any
/// other value is returned with a guaranteed leading `/`.
fn normalize_path(path: String, fallback: &str) -> String {
    match path.trim() {
        "" | "/" => fallback.to_string(),
        absolute if absolute.starts_with('/') => absolute.to_string(),
        relative => format!("/{relative}"),
    }
}
276
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use nestforge_microservices::MicroserviceRegistry;

    // `HeaderMap` is reached through the parent module's `axum::http` import; every test
    // below builds a `WebSocketContext` with `HeaderMap::new()`, so it must be in scope.
    use super::{
        handle_websocket_microservice_message, HeaderMap, Message, TransportMetadata,
        WebSocketConfig, WebSocketContext, WebSocketMicroserviceFrame, WebSocketMicroserviceKind,
    };

    #[test]
    fn config_normalizes_relative_paths() {
        assert_eq!(WebSocketConfig::new("socket").endpoint, "/socket");
        assert_eq!(WebSocketConfig::new("/socket").endpoint, "/socket");
        assert_eq!(WebSocketConfig::new("").endpoint, "/ws");
        // Whitespace is trimmed before the root-only check, so this also falls back.
        assert_eq!(WebSocketConfig::new(" / ").endpoint, "/ws");
    }

    #[tokio::test]
    async fn websocket_microservice_adapter_returns_message_responses() {
        let container = nestforge_core::Container::new();
        container
            .register(Arc::new(AtomicUsize::new(7)))
            .expect("counter should register");
        let ctx = WebSocketContext::new(container, None, None, HeaderMap::new());
        let registry = MicroserviceRegistry::builder()
            .message("counter.read", |_payload: (), ctx| async move {
                let counter = ctx.resolve::<Arc<AtomicUsize>>()?;
                Ok(counter.load(Ordering::Relaxed))
            })
            .build();
        let frame = WebSocketMicroserviceFrame {
            kind: WebSocketMicroserviceKind::Message,
            pattern: "counter.read".to_string(),
            payload: serde_json::json!(null),
            metadata: TransportMetadata::default(),
        };

        let response = handle_websocket_microservice_message(
            &ctx,
            &registry,
            Message::Text(serde_json::to_string(&frame).expect("frame").into()),
        )
        .await
        .expect("message should dispatch");

        match response {
            Some(Message::Text(payload)) => {
                let json: serde_json::Value =
                    serde_json::from_str(&payload).expect("response should be json");
                assert_eq!(json["payload"], serde_json::json!(7));
            }
            other => panic!("unexpected websocket response: {other:?}"),
        }
    }

    #[tokio::test]
    async fn websocket_microservice_adapter_dispatches_events_without_response() {
        let counter = Arc::new(AtomicUsize::new(0));
        let ctx = WebSocketContext::new(
            nestforge_core::Container::new(),
            None,
            None,
            HeaderMap::new(),
        );
        let registry = MicroserviceRegistry::builder()
            .event("counter.bump", {
                let counter = Arc::clone(&counter);
                move |_payload: (), _ctx| {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::Relaxed);
                        Ok(())
                    }
                }
            })
            .build();
        let frame = WebSocketMicroserviceFrame {
            kind: WebSocketMicroserviceKind::Event,
            pattern: "counter.bump".to_string(),
            payload: serde_json::json!(null),
            metadata: TransportMetadata::default(),
        };

        let response = handle_websocket_microservice_message(
            &ctx,
            &registry,
            Message::Text(serde_json::to_string(&frame).expect("frame").into()),
        )
        .await
        .expect("event should dispatch");

        assert!(response.is_none());
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn websocket_microservice_adapter_rejects_non_frame_messages() {
        let ctx = WebSocketContext::new(
            nestforge_core::Container::new(),
            None,
            None,
            HeaderMap::new(),
        );
        let registry = MicroserviceRegistry::builder().build();

        // Control frames are not JSON frames and must surface as an error.
        let control = handle_websocket_microservice_message(
            &ctx,
            &registry,
            Message::Ping(Default::default()),
        )
        .await;
        assert!(control.is_err(), "ping frames should be rejected");

        // Text that is not a valid frame must also surface as an error.
        let malformed =
            handle_websocket_microservice_message(&ctx, &registry, Message::Text("not json".into()))
                .await;
        assert!(malformed.is_err(), "malformed json should be rejected");
    }
}