1use std::collections::HashMap;
2use std::future::Future;
3use std::panic::AssertUnwindSafe;
4use std::sync::Arc;
5use std::time::Duration;
6
7use futures_util::{FutureExt, SinkExt, StreamExt, future::BoxFuture};
8use serde::de::DeserializeOwned;
9use serde_json::{Value, json};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, mpsc};
12use tokio::time::{MissedTickBehavior, interval, timeout};
13use tokio_tungstenite::{accept_async, tungstenite::Message};
14use uuid::Uuid;
15use validator::Validate;
16use wscall_protocol::{
17 EncryptionKind, ErrorPayload, FileAttachment, FrameCodec, PacketBody, PacketEnvelope,
18};
19
20use crate::server_types::{
21 ApiContext, ApiError, EventContext, ExceptionContext, ServerError, ServerHandle,
22 ServerOutbound, ServerState,
23};
24
/// Period between server-initiated WebSocket pings on each connection.
const SERVER_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
/// How long the read loop waits for the next frame before failing the connection
/// with an idle-timeout error (currently three heartbeat periods).
const SERVER_IDLE_TIMEOUT: Duration = Duration::from_secs(3 * 15);
/// Capacity of each connection's bounded outbound frame queue.
const SERVER_OUTBOUND_QUEUE_CAPACITY: usize = 256;
28
29type ApiHandler =
30 Arc<dyn Fn(ApiContext) -> BoxFuture<'static, Result<Value, ApiError>> + Send + Sync>;
31type Filter =
32 Arc<dyn Fn(ApiContext) -> BoxFuture<'static, Result<ApiContext, ApiError>> + Send + Sync>;
33type EventHandler =
34 Arc<dyn Fn(EventContext) -> BoxFuture<'static, Result<Value, ApiError>> + Send + Sync>;
35type ExceptionHandler =
36 Arc<dyn Fn(ExceptionContext) -> BoxFuture<'static, ErrorPayload> + Send + Sync>;
37
38struct ApiRequestInput {
39 request_id: String,
40 route: String,
41 params: Value,
42 attachments: Vec<FileAttachment>,
43 metadata: Value,
44}
45
46struct EventEmitInput {
47 event_id: String,
48 name: String,
49 data: Value,
50 attachments: Vec<FileAttachment>,
51 metadata: Value,
52}
53
54impl ServerHandle {
55 pub async fn broadcast_event(
56 &self,
57 name: impl Into<String>,
58 data: Value,
59 attachments: Vec<FileAttachment>,
60 ) -> Result<(), ApiError> {
61 let packet = PacketEnvelope::with_encryption(
62 PacketBody::EventEmit {
63 event_id: Uuid::new_v4().to_string(),
64 name: name.into(),
65 data,
66 attachments,
67 metadata: json!({ "source": "server" }),
68 expect_ack: true,
69 },
70 self.default_encryption,
71 );
72
73 let clients = self.state.clients.read().await;
74 let senders = clients.values().cloned().collect::<Vec<_>>();
75 drop(clients);
76
77 for sender in senders {
78 sender
79 .try_send(ServerOutbound::Packet(packet.clone()))
80 .map_err(|_| ApiError::internal("failed to queue broadcast event"))?;
81 }
82 Ok(())
83 }
84
85 pub async fn send_event_to(
86 &self,
87 connection_id: &str,
88 name: impl Into<String>,
89 data: Value,
90 attachments: Vec<FileAttachment>,
91 ) -> Result<(), ApiError> {
92 let packet = PacketEnvelope::with_encryption(
93 PacketBody::EventEmit {
94 event_id: Uuid::new_v4().to_string(),
95 name: name.into(),
96 data,
97 attachments,
98 metadata: json!({ "source": "server" }),
99 expect_ack: true,
100 },
101 self.default_encryption,
102 );
103
104 let clients = self.state.clients.read().await;
105 let sender = clients
106 .get(connection_id)
107 .cloned()
108 .ok_or_else(|| ApiError::not_found("target connection not found"))?;
109 drop(clients);
110 sender
111 .try_send(ServerOutbound::Packet(packet))
112 .map_err(|_| ApiError::internal("failed to queue direct event"))
113 }
114
115 pub async fn connection_count(&self) -> usize {
116 self.state.clients.read().await.len()
117 }
118}
119
120pub struct WscallServer {
121 state: Arc<ServerState>,
122 routes: HashMap<String, ApiHandler>,
123 filters: Vec<Filter>,
124 event_handlers: HashMap<String, EventHandler>,
125 exception_handler: Option<ExceptionHandler>,
126 codec: FrameCodec,
127 default_encryption: EncryptionKind,
128}
129
130impl Default for WscallServer {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136impl WscallServer {
137 pub fn new() -> Self {
138 Self {
139 state: Arc::new(ServerState {
140 clients: RwLock::new(HashMap::new()),
141 }),
142 routes: HashMap::new(),
143 filters: Vec::new(),
144 event_handlers: HashMap::new(),
145 exception_handler: None,
146 codec: FrameCodec::plaintext(),
147 default_encryption: EncryptionKind::None,
148 }
149 }
150
151 pub fn with_chacha20_key(mut self, key: [u8; 32]) -> Self {
152 self.codec = self.codec.clone().with_chacha20_key(key);
153 self.default_encryption = EncryptionKind::ChaCha20;
154 self
155 }
156
157 pub fn with_aes256_key(mut self, key: [u8; 32]) -> Self {
158 self.codec = self.codec.clone().with_aes256_key(key);
159 self.default_encryption = EncryptionKind::Aes256;
160 self
161 }
162
163 pub fn handle(&self) -> ServerHandle {
164 ServerHandle {
165 state: Arc::clone(&self.state),
166 default_encryption: self.default_encryption,
167 }
168 }
169
170 pub fn route<F, Fut>(&mut self, route: impl Into<String>, handler: F)
171 where
172 F: Fn(ApiContext) -> Fut + Send + Sync + 'static,
173 Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
174 {
175 let handler = Arc::new(move |ctx: ApiContext| {
176 Box::pin(handler(ctx)) as BoxFuture<'static, Result<Value, ApiError>>
177 });
178 self.routes.insert(route.into(), handler);
179 }
180
181 pub fn typed_route<T, F, Fut>(&mut self, route: impl Into<String>, handler: F)
182 where
183 T: DeserializeOwned + Send + 'static,
184 F: Fn(ApiContext, T) -> Fut + Send + Sync + 'static,
185 Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
186 {
187 let handler = Arc::new(handler);
188 self.route(route, move |ctx| {
189 let handler = Arc::clone(&handler);
190 let params = ctx.bind::<T>();
191 async move {
192 let params = params?;
193 handler(ctx, params).await
194 }
195 });
196 }
197
198 pub fn validated_route<T, F, Fut>(&mut self, route: impl Into<String>, handler: F)
199 where
200 T: DeserializeOwned + Validate + Send + 'static,
201 F: Fn(ApiContext, T) -> Fut + Send + Sync + 'static,
202 Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
203 {
204 let handler = Arc::new(handler);
205 self.route(route, move |ctx| {
206 let handler = Arc::clone(&handler);
207 let params = ctx.bind_validated::<T>();
208 async move {
209 let params = params?;
210 handler(ctx, params).await
211 }
212 });
213 }
214
215 pub fn filter<F, Fut>(&mut self, filter: F)
216 where
217 F: Fn(ApiContext) -> Fut + Send + Sync + 'static,
218 Fut: Future<Output = Result<ApiContext, ApiError>> + Send + 'static,
219 {
220 let filter = Arc::new(move |ctx: ApiContext| {
221 Box::pin(filter(ctx)) as BoxFuture<'static, Result<ApiContext, ApiError>>
222 });
223 self.filters.push(filter);
224 }
225
226 pub fn event_handler<F, Fut>(&mut self, name: impl Into<String>, handler: F)
227 where
228 F: Fn(EventContext) -> Fut + Send + Sync + 'static,
229 Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
230 {
231 let handler = Arc::new(move |ctx: EventContext| {
232 Box::pin(handler(ctx)) as BoxFuture<'static, Result<Value, ApiError>>
233 });
234 self.event_handlers.insert(name.into(), handler);
235 }
236
237 pub fn exception_handler<F, Fut>(&mut self, handler: F)
238 where
239 F: Fn(ExceptionContext) -> Fut + Send + Sync + 'static,
240 Fut: Future<Output = ErrorPayload> + Send + 'static,
241 {
242 self.exception_handler = Some(Arc::new(move |ctx: ExceptionContext| {
243 Box::pin(handler(ctx)) as BoxFuture<'static, ErrorPayload>
244 }));
245 }
246
247 pub async fn listen(self, address: &str) -> Result<(), ServerError> {
248 let listener = TcpListener::bind(address).await?;
249 println!("WSCALL server listening on ws://{address}/socket");
250
251 let shared = Arc::new(self);
252 loop {
253 let (stream, peer) = listener.accept().await?;
254 let server = Arc::clone(&shared);
255 tokio::spawn(async move {
256 if let Err(error) = server.serve_connection(stream, peer).await {
257 eprintln!("connection {peer:?} failed: {error}");
258 }
259 });
260 }
261 }
262
263 async fn serve_connection(
264 self: Arc<Self>,
265 stream: TcpStream,
266 peer: std::net::SocketAddr,
267 ) -> Result<(), ServerError> {
268 let websocket = accept_async(stream).await?;
269 let connection_id = Uuid::new_v4().to_string();
270 let (mut sink, mut stream) = websocket.split();
271 let (tx, mut rx) = mpsc::channel::<ServerOutbound>(SERVER_OUTBOUND_QUEUE_CAPACITY);
272
273 self.state
274 .clients
275 .write()
276 .await
277 .insert(connection_id.clone(), tx.clone());
278
279 let codec = self.codec.clone();
280 let writer = tokio::spawn(async move {
281 while let Some(outbound) = rx.recv().await {
282 match outbound {
283 ServerOutbound::Packet(packet) => {
284 let bytes = codec.encode(&packet)?;
285 sink.send(Message::Binary(bytes)).await?;
286 }
287 ServerOutbound::Ping(payload) => {
288 sink.send(Message::Ping(payload)).await?;
289 }
290 ServerOutbound::Pong(payload) => {
291 sink.send(Message::Pong(payload)).await?;
292 }
293 ServerOutbound::Close => {
294 let _ = sink.send(Message::Close(None)).await;
295 break;
296 }
297 }
298 }
299 Ok::<(), ServerError>(())
300 });
301
302 let heartbeat_tx = tx.clone();
303 let heartbeat = tokio::spawn(async move {
304 let mut ticker = interval(SERVER_HEARTBEAT_INTERVAL);
305 ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
306 loop {
307 ticker.tick().await;
308 if heartbeat_tx
309 .send(ServerOutbound::Ping(Vec::new()))
310 .await
311 .is_err()
312 {
313 break;
314 }
315 }
316 });
317
318 self.handle()
319 .send_event_to(
320 &connection_id,
321 "system.notice",
322 json!({ "message": "connected", "connection_id": connection_id }),
323 Vec::new(),
324 )
325 .await
326 .map_err(ServerError::Api)?;
327
328 let result = loop {
329 let next_message = timeout(SERVER_IDLE_TIMEOUT, stream.next()).await;
330 let Some(message) =
331 next_message.map_err(|_| ServerError::IdleTimeout(connection_id.clone()))?
332 else {
333 break Ok(());
334 };
335
336 match message? {
337 Message::Binary(bytes) => {
338 let packet = self.codec.decode(&bytes)?;
339 self.process_packet(&connection_id, Some(peer), packet)
340 .await?;
341 }
342 Message::Close(_) => break Ok(()),
343 Message::Ping(payload) => {
344 if tx
345 .send(ServerOutbound::Pong(payload.to_vec()))
346 .await
347 .is_err()
348 {
349 break Ok(());
350 }
351 }
352 Message::Pong(_) => {}
353 Message::Text(_) => {}
354 Message::Frame(_) => {}
355 }
356 };
357
358 self.state.clients.write().await.remove(&connection_id);
359 let _ = tx.send(ServerOutbound::Close).await;
360 heartbeat.abort();
361 writer.abort();
362 result
363 }
364
365 async fn process_packet(
366 &self,
367 connection_id: &str,
368 peer_addr: Option<std::net::SocketAddr>,
369 packet: PacketEnvelope,
370 ) -> Result<(), ServerError> {
371 match packet.body {
372 PacketBody::ApiRequest {
373 request_id,
374 route,
375 params,
376 attachments,
377 metadata,
378 } => {
379 let response = self
380 .run_api_request(
381 connection_id,
382 peer_addr,
383 ApiRequestInput {
384 request_id: request_id.clone(),
385 route,
386 params,
387 attachments,
388 metadata,
389 },
390 )
391 .await;
392 self.queue_for(connection_id, response).await?;
393 }
394 PacketBody::EventEmit {
395 event_id,
396 name,
397 data,
398 attachments,
399 metadata,
400 ..
401 } => {
402 let ack = self
403 .run_event(
404 connection_id,
405 peer_addr,
406 EventEmitInput {
407 event_id: event_id.clone(),
408 name,
409 data,
410 attachments,
411 metadata,
412 },
413 )
414 .await;
415 self.queue_for(connection_id, ack).await?;
416 }
417 PacketBody::EventAck {
418 event_id,
419 ok,
420 receipt,
421 error,
422 } => {
423 println!(
424 "received event ack from {} for {}: ok={}, receipt={}, error={:?}",
425 connection_id, event_id, ok, receipt, error
426 );
427 }
428 PacketBody::ApiResponse { .. } => {}
429 }
430 Ok(())
431 }
432
433 async fn queue_for(
434 &self,
435 connection_id: &str,
436 packet: PacketEnvelope,
437 ) -> Result<(), ServerError> {
438 let clients = self.state.clients.read().await;
439 let sender = clients
440 .get(connection_id)
441 .cloned()
442 .ok_or_else(|| ServerError::Api(ApiError::not_found("connection is closed")))?;
443 drop(clients);
444 sender
445 .try_send(ServerOutbound::Packet(packet))
446 .map_err(|error| match error {
447 tokio::sync::mpsc::error::TrySendError::Full(_) => {
448 ServerError::OutboundQueueFull(connection_id.to_string())
449 }
450 tokio::sync::mpsc::error::TrySendError::Closed(_) => {
451 ServerError::Api(ApiError::internal("failed to queue outbound packet"))
452 }
453 })
454 }
455
456 async fn run_api_request(
457 &self,
458 connection_id: &str,
459 peer_addr: Option<std::net::SocketAddr>,
460 request: ApiRequestInput,
461 ) -> PacketEnvelope {
462 let ApiRequestInput {
463 request_id,
464 route,
465 params,
466 attachments,
467 metadata,
468 } = request;
469
470 let mut ctx = ApiContext {
471 connection_id: connection_id.to_string(),
472 peer_addr,
473 request_id: request_id.clone(),
474 route: route.clone(),
475 params,
476 attachments,
477 metadata,
478 server: self.handle(),
479 };
480
481 for filter in &self.filters {
482 match filter(ctx).await {
483 Ok(next_ctx) => ctx = next_ctx,
484 Err(error) => {
485 return self
486 .api_error_packet(connection_id, Some(request_id), route, error)
487 .await;
488 }
489 }
490 }
491
492 let Some(handler) = self.routes.get(&ctx.route) else {
493 return self
494 .api_error_packet(
495 connection_id,
496 Some(request_id),
497 route,
498 ApiError::not_found("route not found"),
499 )
500 .await;
501 };
502
503 match AssertUnwindSafe(handler(ctx)).catch_unwind().await {
504 Ok(Ok(data)) => PacketEnvelope::with_encryption(
505 PacketBody::ApiResponse {
506 request_id,
507 ok: true,
508 status: 200,
509 data,
510 error: None,
511 metadata: json!({}),
512 },
513 self.default_encryption,
514 ),
515 Ok(Err(error)) => {
516 self.api_error_packet(connection_id, Some(request_id), route, error)
517 .await
518 }
519 Err(_) => {
520 self.api_error_packet(
521 connection_id,
522 Some(request_id),
523 route,
524 ApiError::internal("handler panicked"),
525 )
526 .await
527 }
528 }
529 }
530
531 async fn run_event(
532 &self,
533 connection_id: &str,
534 peer_addr: Option<std::net::SocketAddr>,
535 event: EventEmitInput,
536 ) -> PacketEnvelope {
537 let EventEmitInput {
538 event_id,
539 name,
540 data,
541 attachments,
542 metadata,
543 } = event;
544
545 let ctx = EventContext {
546 connection_id: connection_id.to_string(),
547 peer_addr,
548 event_id: event_id.clone(),
549 name: name.clone(),
550 data,
551 attachments,
552 metadata,
553 server: self.handle(),
554 };
555
556 let Some(handler) = self.event_handlers.get(&name) else {
557 return PacketEnvelope::with_encryption(
558 PacketBody::EventAck {
559 event_id,
560 ok: false,
561 receipt: json!({}),
562 error: Some(ApiError::not_found("event handler not found").into_payload()),
563 },
564 self.default_encryption,
565 );
566 };
567
568 match AssertUnwindSafe(handler(ctx)).catch_unwind().await {
569 Ok(Ok(receipt)) => PacketEnvelope::with_encryption(
570 PacketBody::EventAck {
571 event_id,
572 ok: true,
573 receipt,
574 error: None,
575 },
576 self.default_encryption,
577 ),
578 Ok(Err(error)) => PacketEnvelope::with_encryption(
579 PacketBody::EventAck {
580 event_id: event_id.clone(),
581 ok: false,
582 receipt: json!({}),
583 error: Some(
584 self.map_exception(ExceptionContext {
585 connection_id: connection_id.to_string(),
586 request_id: Some(event_id.clone()),
587 target: name,
588 message_kind: "event",
589 error,
590 })
591 .await,
592 ),
593 },
594 self.default_encryption,
595 ),
596 Err(_) => PacketEnvelope::with_encryption(
597 PacketBody::EventAck {
598 event_id: event_id.clone(),
599 ok: false,
600 receipt: json!({}),
601 error: Some(
602 self.map_exception(ExceptionContext {
603 connection_id: connection_id.to_string(),
604 request_id: Some(event_id.clone()),
605 target: name,
606 message_kind: "event",
607 error: ApiError::internal("event handler panicked"),
608 })
609 .await,
610 ),
611 },
612 self.default_encryption,
613 ),
614 }
615 }
616
617 async fn api_error_packet(
618 &self,
619 connection_id: &str,
620 request_id: Option<String>,
621 route: String,
622 error: ApiError,
623 ) -> PacketEnvelope {
624 let request_id = request_id.unwrap_or_else(|| Uuid::new_v4().to_string());
625 let status = error.status;
626 let payload = self
627 .map_exception(ExceptionContext {
628 connection_id: connection_id.to_string(),
629 request_id: Some(request_id.clone()),
630 target: route,
631 message_kind: "api",
632 error,
633 })
634 .await;
635
636 PacketEnvelope::with_encryption(
637 PacketBody::ApiResponse {
638 request_id,
639 ok: false,
640 status,
641 data: json!({}),
642 error: Some(payload),
643 metadata: json!({}),
644 },
645 self.default_encryption,
646 )
647 }
648
649 async fn map_exception(&self, context: ExceptionContext) -> ErrorPayload {
650 match &self.exception_handler {
651 Some(handler) => handler(context).await,
652 None => context.error.into_payload(),
653 }
654 }
655}