1use std::{
2 future::Future,
3 pin::Pin,
4 sync::Arc,
5};
6
7use anyhow::Result;
8use axum::{
9 extract::Extension,
10 http::HeaderMap,
11 routing::get,
12 Router,
13};
14use nestforge_core::{AuthIdentity, Container, RequestId};
15use nestforge_microservices::{
16 EventEnvelope, MessageEnvelope, MicroserviceContext, MicroserviceRegistry, TransportMetadata,
17};
18use serde_json::Value;
19
20pub use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket, WebSocketUpgrade};
21
22type WebSocketFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
23
24pub trait WebSocketGateway: Send + Sync + 'static {
25 fn on_connect(&self, ctx: WebSocketContext, socket: WebSocket) -> WebSocketFuture;
26}
27
28#[derive(Debug, Clone)]
29pub struct WebSocketConfig {
30 pub endpoint: String,
31}
32
33impl Default for WebSocketConfig {
34 fn default() -> Self {
35 Self {
36 endpoint: "/ws".to_string(),
37 }
38 }
39}
40
41impl WebSocketConfig {
42 pub fn new(endpoint: impl Into<String>) -> Self {
43 Self {
44 endpoint: normalize_path(endpoint.into(), "/ws"),
45 }
46 }
47}
48
49#[derive(Clone)]
50pub struct WebSocketContext {
51 container: Container,
52 request_id: Option<RequestId>,
53 auth_identity: Option<AuthIdentity>,
54 headers: HeaderMap,
55}
56
57impl WebSocketContext {
58 pub fn new(
59 container: Container,
60 request_id: Option<RequestId>,
61 auth_identity: Option<AuthIdentity>,
62 headers: HeaderMap,
63 ) -> Self {
64 Self {
65 container,
66 request_id,
67 auth_identity,
68 headers,
69 }
70 }
71
72 pub fn container(&self) -> &Container {
73 &self.container
74 }
75
76 pub fn request_id(&self) -> Option<&RequestId> {
77 self.request_id.as_ref()
78 }
79
80 pub fn auth_identity(&self) -> Option<&AuthIdentity> {
81 self.auth_identity.as_ref()
82 }
83
84 pub fn headers(&self) -> &HeaderMap {
85 &self.headers
86 }
87
88 pub fn resolve<T>(&self) -> Result<Arc<T>>
89 where
90 T: Send + Sync + 'static,
91 {
92 Ok(self.container.resolve::<T>()?)
93 }
94
95 pub fn is_authenticated(&self) -> bool {
96 self.auth_identity.is_some()
97 }
98
99 pub fn has_role(&self, role: &str) -> bool {
100 self.auth_identity
101 .as_ref()
102 .map(|identity| identity.roles.iter().any(|value| value == role))
103 .unwrap_or(false)
104 }
105
106 pub fn microservice_context(
107 &self,
108 pattern: impl Into<String>,
109 metadata: TransportMetadata,
110 ) -> MicroserviceContext {
111 MicroserviceContext::new(self.container.clone(), "websocket", pattern, metadata)
112 }
113}
114
115#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
116#[serde(rename_all = "snake_case")]
117pub enum WebSocketMicroserviceKind {
118 Message,
119 Event,
120}
121
122#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
123pub struct WebSocketMicroserviceFrame {
124 pub kind: WebSocketMicroserviceKind,
125 pub pattern: String,
126 pub payload: Value,
127 #[serde(default)]
128 pub metadata: TransportMetadata,
129}
130
131#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
132pub struct WebSocketMicroserviceResponse {
133 pub pattern: String,
134 pub payload: Value,
135 #[serde(default)]
136 pub metadata: TransportMetadata,
137}
138
139pub async fn handle_websocket_microservice_message(
140 ctx: &WebSocketContext,
141 registry: &MicroserviceRegistry,
142 message: Message,
143) -> Result<Option<Message>> {
144 let frame = parse_websocket_microservice_frame(message)?;
145 let microservice_ctx = ctx.microservice_context(frame.pattern.clone(), frame.metadata.clone());
146
147 match frame.kind {
148 WebSocketMicroserviceKind::Message => {
149 let payload = registry
150 .dispatch_message(
151 MessageEnvelope {
152 pattern: frame.pattern.clone(),
153 payload: frame.payload,
154 metadata: frame.metadata.clone(),
155 },
156 microservice_ctx,
157 )
158 .await?;
159 let response = WebSocketMicroserviceResponse {
160 pattern: frame.pattern,
161 payload,
162 metadata: frame.metadata,
163 };
164 Ok(Some(Message::Text(serde_json::to_string(&response)?.into())))
165 }
166 WebSocketMicroserviceKind::Event => {
167 registry
168 .dispatch_event(
169 EventEnvelope {
170 pattern: frame.pattern,
171 payload: frame.payload,
172 metadata: frame.metadata,
173 },
174 microservice_ctx,
175 )
176 .await?;
177 Ok(None)
178 }
179 }
180}
181
182fn parse_websocket_microservice_frame(message: Message) -> Result<WebSocketMicroserviceFrame> {
183 match message {
184 Message::Text(text) => Ok(serde_json::from_str(&text)?),
185 Message::Binary(bytes) => Ok(serde_json::from_slice(&bytes)?),
186 other => anyhow::bail!("Unsupported websocket microservice message: {other:?}"),
187 }
188}
189
190pub fn websocket_gateway_router<G>(gateway: G) -> Router<Container>
191where
192 G: WebSocketGateway,
193{
194 websocket_gateway_router_with_config(gateway, WebSocketConfig::default())
195}
196
197pub fn websocket_gateway_router_with_config<G>(
198 gateway: G,
199 config: WebSocketConfig,
200) -> Router<Container>
201where
202 G: WebSocketGateway,
203{
204 let gateway = Arc::new(gateway);
205 Router::new().route(
206 &config.endpoint,
207 get(move |ws: WebSocketUpgrade,
208 Extension(container): Extension<Container>,
209 headers: HeaderMap,
210 Extension(request_id): Extension<RequestId>| {
211 let gateway = Arc::clone(&gateway);
212 let auth_identity = container
213 .resolve::<AuthIdentity>()
214 .ok()
215 .map(|value| (*value).clone());
216 async move {
217 let context =
218 WebSocketContext::new(container, Some(request_id), auth_identity, headers);
219 ws.on_upgrade(move |socket| async move {
220 gateway.on_connect(context, socket).await;
221 })
222 }
223 },
224 ),
225 )
226}
227
228pub fn websocket_router<F, Fut>(handler: F) -> Router<Container>
229where
230 F: Fn(WebSocketContext, WebSocket) -> Fut + Clone + Send + Sync + 'static,
231 Fut: Future<Output = ()> + Send + 'static,
232{
233 websocket_router_with_config(handler, WebSocketConfig::default())
234}
235
236pub fn websocket_router_with_config<F, Fut>(
237 handler: F,
238 config: WebSocketConfig,
239) -> Router<Container>
240where
241 F: Fn(WebSocketContext, WebSocket) -> Fut + Clone + Send + Sync + 'static,
242 Fut: Future<Output = ()> + Send + 'static,
243{
244 Router::new().route(
245 &config.endpoint,
246 get(move |ws: WebSocketUpgrade,
247 Extension(container): Extension<Container>,
248 headers: HeaderMap,
249 Extension(request_id): Extension<RequestId>| {
250 let handler = handler.clone();
251 let auth_identity = container
252 .resolve::<AuthIdentity>()
253 .ok()
254 .map(|value| (*value).clone());
255 async move {
256 let context =
257 WebSocketContext::new(container, Some(request_id), auth_identity, headers);
258 ws.on_upgrade(move |socket| handler(context, socket))
259 }
260 }),
261 )
262}
263
/// Trims `path` and ensures it starts with `/`.
///
/// An empty (or whitespace-only) path, or a bare `/`, is replaced by `fallback`, which is
/// returned verbatim.
fn normalize_path(path: String, fallback: &str) -> String {
    match path.trim() {
        "" | "/" => fallback.to_owned(),
        rooted if rooted.starts_with('/') => rooted.to_owned(),
        relative => format!("/{relative}"),
    }
}
276
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use nestforge_microservices::MicroserviceRegistry;

    use super::{
        handle_websocket_microservice_message, HeaderMap, Message, TransportMetadata,
        WebSocketConfig, WebSocketContext, WebSocketMicroserviceFrame, WebSocketMicroserviceKind,
    };

    #[test]
    fn config_normalizes_relative_paths() {
        assert_eq!(WebSocketConfig::new("socket").endpoint, "/socket");
        assert_eq!(WebSocketConfig::new("/socket").endpoint, "/socket");
        assert_eq!(WebSocketConfig::new("").endpoint, "/ws");
        // A bare root and surrounding whitespace are normalized too.
        assert_eq!(WebSocketConfig::new("  /  ").endpoint, "/ws");
        assert_eq!(WebSocketConfig::new("  chat  ").endpoint, "/chat");
        assert_eq!(WebSocketConfig::default().endpoint, "/ws");
    }

    #[tokio::test]
    async fn websocket_microservice_adapter_returns_message_responses() {
        let container = nestforge_core::Container::new();
        container
            .register(Arc::new(AtomicUsize::new(7)))
            .expect("counter should register");
        let ctx = WebSocketContext::new(container, None, None, HeaderMap::new());
        let registry = MicroserviceRegistry::builder()
            .message("counter.read", |_payload: (), ctx| async move {
                let counter = ctx.resolve::<Arc<AtomicUsize>>()?;
                Ok(counter.load(Ordering::Relaxed))
            })
            .build();
        let frame = WebSocketMicroserviceFrame {
            kind: WebSocketMicroserviceKind::Message,
            pattern: "counter.read".to_string(),
            payload: serde_json::json!(null),
            metadata: TransportMetadata::default(),
        };

        let response = handle_websocket_microservice_message(
            &ctx,
            &registry,
            Message::Text(serde_json::to_string(&frame).expect("frame").into()),
        )
        .await
        .expect("message should dispatch");

        match response {
            Some(Message::Text(payload)) => {
                let json: serde_json::Value =
                    serde_json::from_str(&payload).expect("response should be json");
                assert_eq!(json["payload"], serde_json::json!(7));
            }
            other => panic!("unexpected websocket response: {other:?}"),
        }
    }

    #[tokio::test]
    async fn websocket_microservice_adapter_dispatches_events_without_response() {
        let counter = Arc::new(AtomicUsize::new(0));
        let ctx = WebSocketContext::new(
            nestforge_core::Container::new(),
            None,
            None,
            HeaderMap::new(),
        );
        let registry = MicroserviceRegistry::builder()
            .event("counter.bump", {
                let counter = Arc::clone(&counter);
                move |_payload: (), _ctx| {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::Relaxed);
                        Ok(())
                    }
                }
            })
            .build();
        let frame = WebSocketMicroserviceFrame {
            kind: WebSocketMicroserviceKind::Event,
            pattern: "counter.bump".to_string(),
            payload: serde_json::json!(null),
            metadata: TransportMetadata::default(),
        };

        let response = handle_websocket_microservice_message(
            &ctx,
            &registry,
            Message::Text(serde_json::to_string(&frame).expect("frame").into()),
        )
        .await
        .expect("event should dispatch");

        assert!(response.is_none());
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }
}