1use async_trait::async_trait;
4use axum::extract::ws::{Message, WebSocket};
5use futures_util::{SinkExt, StreamExt};
6use pondsocket::contexts::IncomingConnection;
7use pondsocket::errors::{Result, internal};
8use pondsocket::transport::Transport;
9use pondsocket::types::{Event, TransportType};
10use pondsocket::wire::{event_to_json, parse_inbound_text};
11use pondsocket::{Endpoint, PondSocket};
12use pondsocket_common::PondAssigns;
13use serde_json::Value;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::{RwLock, mpsc};
17
18pub struct AxumWebSocketTransport {
19 id: String,
20 assigns: RwLock<PondAssigns>,
21 active: RwLock<bool>,
22 tx: mpsc::Sender<Message>,
23}
24
25impl AxumWebSocketTransport {
26 pub fn new(id: impl Into<String>, assigns: PondAssigns, tx: mpsc::Sender<Message>) -> Self {
27 Self {
28 id: id.into(),
29 assigns: RwLock::new(assigns),
30 active: RwLock::new(true),
31 tx,
32 }
33 }
34}
35
36#[async_trait]
37impl Transport for AxumWebSocketTransport {
38 fn id(&self) -> &str {
39 &self.id
40 }
41
42 async fn send_event(&self, event: Event) -> Result<()> {
43 let text = event_to_json(&event).map_err(|e| internal("", e.to_string()))?;
44 self.tx
45 .send(Message::Text(text.into()))
46 .await
47 .map_err(|_| internal("", "websocket writer closed"))
48 }
49
50 async fn close(&self) -> Result<()> {
51 *self.active.write().await = false;
52 let _ = self.tx.send(Message::Close(None)).await;
53 Ok(())
54 }
55
56 fn transport_type(&self) -> TransportType {
57 TransportType::WebSocket
58 }
59
60 async fn is_active(&self) -> bool {
61 *self.active.read().await
62 }
63
64 async fn get_assign(&self, key: &str) -> Option<Value> {
65 self.assigns.read().await.get(key).cloned()
66 }
67
68 async fn set_assign(&self, key: &str, value: Value) {
69 self.assigns.write().await.insert(key.to_owned(), value);
70 }
71
72 async fn clone_assigns(&self) -> PondAssigns {
73 self.assigns.read().await.clone()
74 }
75}
76
/// Plain-data snapshot of the HTTP upgrade request handed to `handle_socket`.
///
/// All fields default to empty values via `Default`.
#[derive(Debug, Clone, Default)]
pub struct RequestParts {
    /// Request path, used to look up the matching endpoint.
    pub path: String,
    /// Request headers as name/value pairs.
    pub headers: HashMap<String, String>,
    /// Request cookies as name/value pairs.
    pub cookies: HashMap<String, String>,
    /// Query-string parameters as name/value pairs.
    pub query: HashMap<String, String>,
    /// Address string forwarded to the connection context.
    pub address: String,
}
85
86pub async fn handle_socket(pond: Arc<PondSocket>, socket: WebSocket, request: RequestParts) {
87 let Some(matched) = pond.match_endpoint(&request.path).await else {
88 let mut socket = socket;
89 let _ = socket.send(Message::Close(None)).await;
90 return;
91 };
92 let endpoint = matched.endpoint;
93 let incoming = IncomingConnection {
94 id: String::new(),
95 headers: request.headers,
96 cookies: request.cookies,
97 query: request.query,
98 params: matched.route.params.clone(),
99 address: request.address,
100 };
101 let ctx = endpoint
102 .request_connection(incoming, matched.route, None)
103 .await;
104 if ctx.is_declined() {
105 let mut socket = socket;
106 let _ = socket.send(Message::Close(None)).await;
107 return;
108 }
109
110 let (mut sender, mut receiver) = socket.split();
111 let (tx, mut rx) = mpsc::channel::<Message>(1024);
112 let transport = Arc::new(AxumWebSocketTransport::new(
113 ctx.user_id.clone(),
114 ctx.assigns(),
115 tx,
116 ));
117 if endpoint
118 .register_transport(transport.clone())
119 .await
120 .is_err()
121 {
122 let _ = sender.send(Message::Close(None)).await;
123 return;
124 }
125 send_pending_reply(&ctx, transport.clone()).await;
126
127 let writer = tokio::spawn(async move {
128 while let Some(message) = rx.recv().await {
129 if sender.send(message).await.is_err() {
130 break;
131 }
132 }
133 });
134
135 read_loop(endpoint.clone(), transport.clone(), &mut receiver).await;
136 let _ = endpoint.unregister_transport(transport.id()).await;
137 let _ = transport.close().await;
138 writer.abort();
139}
140
141async fn read_loop(
142 endpoint: Arc<Endpoint>,
143 transport: Arc<AxumWebSocketTransport>,
144 receiver: &mut futures_util::stream::SplitStream<WebSocket>,
145) {
146 while let Some(next) = receiver.next().await {
147 let Ok(message) = next else {
148 break;
149 };
150 match message {
151 Message::Text(text) => {
152 if text.len() > endpoint.max_message_size() {
153 break;
154 }
155 let Ok(event) = parse_inbound_text(&text) else {
156 continue;
157 };
158 let _ = endpoint.handle_message(event, transport.clone()).await;
159 }
160 Message::Binary(bytes) => {
161 if bytes.len() > endpoint.max_message_size() {
162 break;
163 }
164 let Ok(text) = String::from_utf8(bytes.to_vec()) else {
165 continue;
166 };
167 let Ok(event) = parse_inbound_text(&text) else {
168 continue;
169 };
170 let _ = endpoint.handle_message(event, transport.clone()).await;
171 }
172 Message::Close(_) => break,
173 _ => {}
174 }
175 }
176}
177
178async fn send_pending_reply(
179 ctx: &pondsocket::ConnectionContext,
180 transport: Arc<AxumWebSocketTransport>,
181) {
182 if let Some((event, payload)) = ctx.pending_reply() {
183 let ev = Event::new(
184 "SYSTEM",
185 "GATEWAY",
186 pondsocket_common::uuid(),
187 event,
188 payload,
189 );
190 let _ = transport.send_event(ev).await;
191 }
192}