Skip to main content

pondsocket_axum/
lib.rs

1//! Axum adapter for PondSocket.
2
use async_trait::async_trait;
use axum::extract::ws::{Message, WebSocket};
use futures_util::{SinkExt, StreamExt};
use pondsocket::contexts::IncomingConnection;
use pondsocket::errors::{Result, internal};
use pondsocket::transport::Transport;
use pondsocket::types::{Event, TransportType};
use pondsocket::wire::{event_to_json, parse_inbound_text};
use pondsocket::{Endpoint, PondSocket};
use pondsocket_common::PondAssigns;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::{RwLock, mpsc};
17
18pub struct AxumWebSocketTransport {
19    id: String,
20    assigns: RwLock<PondAssigns>,
21    active: RwLock<bool>,
22    tx: mpsc::Sender<Message>,
23}
24
25impl AxumWebSocketTransport {
26    pub fn new(id: impl Into<String>, assigns: PondAssigns, tx: mpsc::Sender<Message>) -> Self {
27        Self {
28            id: id.into(),
29            assigns: RwLock::new(assigns),
30            active: RwLock::new(true),
31            tx,
32        }
33    }
34}
35
36#[async_trait]
37impl Transport for AxumWebSocketTransport {
38    fn id(&self) -> &str {
39        &self.id
40    }
41
42    async fn send_event(&self, event: Event) -> Result<()> {
43        let text = event_to_json(&event).map_err(|e| internal("", e.to_string()))?;
44        self.tx
45            .send(Message::Text(text.into()))
46            .await
47            .map_err(|_| internal("", "websocket writer closed"))
48    }
49
50    async fn close(&self) -> Result<()> {
51        *self.active.write().await = false;
52        let _ = self.tx.send(Message::Close(None)).await;
53        Ok(())
54    }
55
56    fn transport_type(&self) -> TransportType {
57        TransportType::WebSocket
58    }
59
60    async fn is_active(&self) -> bool {
61        *self.active.read().await
62    }
63
64    async fn get_assign(&self, key: &str) -> Option<Value> {
65        self.assigns.read().await.get(key).cloned()
66    }
67
68    async fn set_assign(&self, key: &str, value: Value) {
69        self.assigns.write().await.insert(key.to_owned(), value);
70    }
71
72    async fn clone_assigns(&self) -> PondAssigns {
73        self.assigns.read().await.clone()
74    }
75}
76
/// Connection metadata captured from the HTTP upgrade request.
///
/// `handle_socket` uses `path` to select an endpoint and copies the remaining
/// fields into the `IncomingConnection` handed to that endpoint.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RequestParts {
    /// Request path matched against the registered endpoints.
    pub path: String,
    /// Request headers, passed through as `IncomingConnection::headers`.
    pub headers: HashMap<String, String>,
    /// Request cookies, passed through as `IncomingConnection::cookies`.
    pub cookies: HashMap<String, String>,
    /// Query-string values, passed through as `IncomingConnection::query`.
    pub query: HashMap<String, String>,
    /// Client address, passed through as `IncomingConnection::address`; its
    /// format is whatever the caller supplies (presumably the peer address).
    pub address: String,
}
85
86pub async fn handle_socket(pond: Arc<PondSocket>, socket: WebSocket, request: RequestParts) {
87    let Some(matched) = pond.match_endpoint(&request.path).await else {
88        let mut socket = socket;
89        let _ = socket.send(Message::Close(None)).await;
90        return;
91    };
92    let endpoint = matched.endpoint;
93    let incoming = IncomingConnection {
94        id: String::new(),
95        headers: request.headers,
96        cookies: request.cookies,
97        query: request.query,
98        params: matched.route.params.clone(),
99        address: request.address,
100    };
101    let ctx = endpoint
102        .request_connection(incoming, matched.route, None)
103        .await;
104    if ctx.is_declined() {
105        let mut socket = socket;
106        let _ = socket.send(Message::Close(None)).await;
107        return;
108    }
109
110    let (mut sender, mut receiver) = socket.split();
111    let (tx, mut rx) = mpsc::channel::<Message>(1024);
112    let transport = Arc::new(AxumWebSocketTransport::new(
113        ctx.user_id.clone(),
114        ctx.assigns(),
115        tx,
116    ));
117    if endpoint
118        .register_transport(transport.clone())
119        .await
120        .is_err()
121    {
122        let _ = sender.send(Message::Close(None)).await;
123        return;
124    }
125    send_pending_reply(&ctx, transport.clone()).await;
126
127    let writer = tokio::spawn(async move {
128        while let Some(message) = rx.recv().await {
129            if sender.send(message).await.is_err() {
130                break;
131            }
132        }
133    });
134
135    read_loop(endpoint.clone(), transport.clone(), &mut receiver).await;
136    let _ = endpoint.unregister_transport(transport.id()).await;
137    let _ = transport.close().await;
138    writer.abort();
139}
140
141async fn read_loop(
142    endpoint: Arc<Endpoint>,
143    transport: Arc<AxumWebSocketTransport>,
144    receiver: &mut futures_util::stream::SplitStream<WebSocket>,
145) {
146    while let Some(next) = receiver.next().await {
147        let Ok(message) = next else {
148            break;
149        };
150        match message {
151            Message::Text(text) => {
152                if text.len() > endpoint.max_message_size() {
153                    break;
154                }
155                let Ok(event) = parse_inbound_text(&text) else {
156                    continue;
157                };
158                let _ = endpoint.handle_message(event, transport.clone()).await;
159            }
160            Message::Binary(bytes) => {
161                if bytes.len() > endpoint.max_message_size() {
162                    break;
163                }
164                let Ok(text) = String::from_utf8(bytes.to_vec()) else {
165                    continue;
166                };
167                let Ok(event) = parse_inbound_text(&text) else {
168                    continue;
169                };
170                let _ = endpoint.handle_message(event, transport.clone()).await;
171            }
172            Message::Close(_) => break,
173            _ => {}
174        }
175    }
176}
177
178async fn send_pending_reply(
179    ctx: &pondsocket::ConnectionContext,
180    transport: Arc<AxumWebSocketTransport>,
181) {
182    if let Some((event, payload)) = ctx.pending_reply() {
183        let ev = Event::new(
184            "SYSTEM",
185            "GATEWAY",
186            pondsocket_common::uuid(),
187            event,
188            payload,
189        );
190        let _ = transport.send_event(ev).await;
191    }
192}