pylon_runtime/
shard_ws.rs1use std::net::{TcpListener, TcpStream};
17use std::sync::{Arc, Mutex};
18use std::thread;
19use std::time::Duration;
20
21use pylon_auth::SessionStore;
22use pylon_realtime::{DynShardRegistry, ShardAuth, ShardError, SubscriberId};
23use tungstenite::{accept_hdr, handshake::server::Request, Message};
24
25use crate::ip_limit::IpConnCounter;
26
27pub fn start_shard_ws_server(
35 registry: Arc<dyn DynShardRegistry>,
36 sessions: Arc<SessionStore>,
37 port: u16,
38) {
39 let listener = match TcpListener::bind(format!("0.0.0.0:{port}")) {
40 Ok(l) => l,
41 Err(e) => {
42 tracing::warn!("[shard-ws] failed to bind port {port}: {e}");
43 return;
44 }
45 };
46 tracing::warn!("[shard-ws] listening on ws://0.0.0.0:{port}");
47
48 let ip_counter = Arc::new(IpConnCounter::default());
52
53 for stream in listener.incoming() {
54 let stream = match stream {
55 Ok(s) => s,
56 Err(_) => continue,
57 };
58 let ip = match stream.peer_addr() {
59 Ok(addr) => addr.ip(),
60 Err(_) => continue,
61 };
62 let guard = match ip_counter.acquire(ip) {
63 Some(g) => g,
64 None => continue,
65 };
66 let registry = Arc::clone(®istry);
67 let sessions = Arc::clone(&sessions);
68 thread::spawn(move || {
69 let _guard = guard;
72 if let Err(e) = handle_connection(stream, registry, sessions) {
73 tracing::warn!("[shard-ws] connection error: {e}");
74 }
75 });
76 }
77}
78
79fn handle_connection(
84 stream: TcpStream,
85 registry: Arc<dyn DynShardRegistry>,
86 sessions: Arc<SessionStore>,
87) -> Result<(), String> {
88 let params = std::sync::Arc::new(Mutex::new(HandshakeParams::default()));
90 let params_clone = Arc::clone(¶ms);
91
92 use tungstenite::handshake::server::{ErrorResponse, Response};
93 let ws = accept_hdr(
94 stream,
95 |req: &Request, mut resp: Response| -> Result<Response, ErrorResponse> {
96 let uri = req.uri().to_string();
97 let mut p = params_clone.lock().unwrap();
98 p.uri = uri;
99 let mut selected_protocol: Option<String> = None;
100 for (name, value) in req.headers() {
101 let lower = name.as_str().to_ascii_lowercase();
102 if lower == "authorization" {
103 if let Ok(v) = value.to_str() {
104 p.auth_header = Some(v.to_string());
105 }
106 } else if lower == "sec-websocket-protocol" {
107 if let Ok(v) = value.to_str() {
116 for proto in v.split(',').map(str::trim) {
117 if let Some(encoded) = proto.strip_prefix("bearer.") {
118 if let Ok(decoded) = urldecode_strict(encoded) {
119 p.bearer_from_subprotocol = Some(decoded);
120 selected_protocol = Some(proto.to_string());
121 break;
122 }
123 }
124 }
125 }
126 }
127 }
128 if let Some(chosen) = selected_protocol {
129 if let Ok(hv) = tungstenite::http::HeaderValue::from_str(&chosen) {
130 resp.headers_mut().insert("Sec-WebSocket-Protocol", hv);
131 }
132 }
133 Ok(resp)
134 },
135 )
136 .map_err(|e| format!("handshake: {e}"))?;
137
138 let params = params.lock().unwrap().clone();
139 let query = params
140 .uri
141 .split_once('?')
142 .map(|(_, q)| q.to_string())
143 .unwrap_or_default();
144
145 let shard_id = query_param(&query, "shard").ok_or("missing ?shard= parameter")?;
146 let sid = query_param(&query, "sid").unwrap_or_else(|| "anon".to_string());
147
148 let token = params
156 .auth_header
157 .as_deref()
158 .and_then(|h| h.strip_prefix("Bearer "))
159 .map(|t| t.to_string())
160 .or_else(|| params.bearer_from_subprotocol.clone());
161 let auth_ctx = sessions.resolve(token.as_deref());
162 let shard_auth = ShardAuth {
163 user_id: auth_ctx.user_id.clone(),
164 is_admin: auth_ctx.is_admin,
165 };
166
167 let shard = registry
168 .get(&shard_id)
169 .ok_or_else(|| format!("shard \"{shard_id}\" not found"))?;
170
171 let ws = Arc::new(Mutex::new(ws));
172 let subscriber_id = SubscriberId::new(sid.clone());
173
174 let ws_for_sink = Arc::clone(&ws);
176 let sink: pylon_realtime::SnapshotSink = Box::new(move |tick, bytes| {
177 let mut payload = Vec::with_capacity(8 + bytes.len() + 2);
178 payload.extend_from_slice(&tick.to_be_bytes());
179 payload.extend_from_slice(bytes);
180 if let Ok(mut s) = ws_for_sink.lock() {
181 let _ = s.send(Message::Binary(payload.into()));
182 }
183 });
184
185 match shard.add_subscriber(subscriber_id.clone(), sink, &shard_auth) {
187 Ok(()) => {}
188 Err(ShardError::Unauthorized(reason)) => {
189 let _ = ws
190 .lock()
191 .unwrap()
192 .close(Some(tungstenite::protocol::CloseFrame {
193 code: tungstenite::protocol::frame::coding::CloseCode::Policy,
194 reason: format!("unauthorized: {reason}").into(),
195 }));
196 return Ok(());
197 }
198 Err(e) => {
199 let _ = ws
200 .lock()
201 .unwrap()
202 .close(Some(tungstenite::protocol::CloseFrame {
203 code: tungstenite::protocol::frame::coding::CloseCode::Again,
204 reason: e.to_string().into(),
205 }));
206 return Ok(());
207 }
208 }
209
210 let read_result = loop {
213 let msg = {
214 let mut s = match ws.lock() {
215 Ok(s) => s,
216 Err(_) => break Err("ws lock poisoned".to_string()),
217 };
218 match s.read() {
219 Ok(m) => m,
220 Err(tungstenite::Error::ConnectionClosed) => break Ok(()),
221 Err(tungstenite::Error::AlreadyClosed) => break Ok(()),
222 Err(e) => break Err(format!("ws read: {e}")),
223 }
224 };
225
226 match msg {
227 Message::Text(text) => {
228 process_input(&shard, &subscriber_id, &shard_auth, text.as_str());
229 }
230 Message::Binary(bytes) => {
231 let text = String::from_utf8_lossy(&bytes).to_string();
232 process_input(&shard, &subscriber_id, &shard_auth, &text);
233 }
234 Message::Ping(payload) => {
235 let _ = ws.lock().unwrap().send(Message::Pong(payload));
236 }
237 Message::Close(_) => break Ok(()),
238 _ => {}
239 }
240 };
241
242 shard.remove_subscriber(&subscriber_id);
244 if let Err(e) = read_result {
245 Err(e)
246 } else {
247 Ok(())
248 }
249}
250
251fn process_input(
252 shard: &Arc<dyn pylon_realtime::DynShard>,
253 subscriber_id: &SubscriberId,
254 shard_auth: &ShardAuth,
255 text: &str,
256) {
257 let envelope: serde_json::Value = match serde_json::from_str(text) {
259 Ok(v) => v,
260 Err(_) => return,
261 };
262 let input = envelope
263 .get("input")
264 .cloned()
265 .unwrap_or(serde_json::Value::Null);
266 let client_seq = envelope.get("client_seq").and_then(|v| v.as_u64());
267 let input_str = serde_json::to_string(&input).unwrap_or_else(|_| "null".into());
268
269 let _ = shard.push_input_json(subscriber_id.clone(), &input_str, client_seq, shard_auth);
270}
271
/// Values captured from the WebSocket upgrade request inside the `accept_hdr`
/// callback. They are copied out once the handshake has completed.
///
/// Note: the two credential fields hold raw bearer tokens. Think twice before
/// adding a `Debug` derive that could leak them into logs.
#[derive(Default, Clone)]
struct HandshakeParams {
    /// Raw `Authorization` header value (scheme included), when it could be
    /// read as a string.
    auth_header: Option<String>,
    /// Percent-decoded token taken from a `bearer.<token>` subprotocol entry.
    bearer_from_subprotocol: Option<String>,
    /// String form of the request URI, including any `?query`.
    uri: String,
}
282
/// Strictly percent-decodes `s` (`%XX` escapes) into a UTF-8 string.
///
/// This is used for bearer tokens passed in `Sec-WebSocket-Protocol` as
/// `bearer.<percent-encoded token>`. Unlike the lenient `url_decode`, it
/// rejects malformed input instead of passing it through.
///
/// `+` is kept literally. It is a legal character in RFC 6750 bearer tokens
/// (and in HTTP `token`s), so translating it to a space would corrupt valid
/// credentials. A client that really means a space must send `%20`.
///
/// # Errors
/// Returns `Err` if:
/// - a `%` is not followed by two more bytes;
/// - those bytes are not hex digits;
/// - the decoded bytes are not valid UTF-8.
fn urldecode_strict(s: &str) -> Result<String, String> {
    let hex = |b: u8| (b as char).to_digit(16).ok_or("bad hex in percent-encoding");

    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            let pair = bytes
                .get(i + 1..i + 3)
                .ok_or("truncated percent-encoding")?;
            let hi = hex(pair[0])?;
            let lo = hex(pair[1])?;
            // Two hex digits always fit in a byte.
            out.push(((hi << 4) | lo) as u8);
            i += 3;
        } else {
            out.push(bytes[i]);
            i += 1;
        }
    }
    String::from_utf8(out).map_err(|_| "percent-encoded token is not valid UTF-8".into())
}
312
313fn query_param(query: &str, key: &str) -> Option<String> {
314 for pair in query.split('&') {
315 let mut it = pair.splitn(2, '=');
316 let k = it.next()?;
317 let v = it.next().unwrap_or("");
318 if k == key {
319 return Some(url_decode(v));
320 }
321 }
322 None
323}
324
/// Leniently decodes a form-encoded query component.
///
/// Decoding rules:
/// - `+` becomes a space.
/// - `%XX` becomes the byte `0xXX`.
/// - A `%` not followed by two hex digits is kept literally, so this never
///   fails.
///
/// The decoded bytes are then read as UTF-8, with invalid sequences replaced
/// by U+FFFD. Decoding byte-wise first, rather than pushing each byte as a
/// `char`, keeps multi-byte characters intact: `%C3%A9` becomes `é`, not
/// `Ã©`.
fn url_decode(s: &str) -> String {
    // `to_digit` rejects sign characters, unlike `u8::from_str_radix`, which
    // would accept `%+1`.
    let hex = |b: u8| (b as char).to_digit(16);

    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            b'%' => {
                let decoded = bytes.get(i + 1..i + 3).and_then(|pair| {
                    let hi = hex(pair[0])?;
                    let lo = hex(pair[1])?;
                    Some(((hi << 4) | lo) as u8)
                });
                match decoded {
                    Some(byte) => {
                        out.push(byte);
                        i += 3;
                    }
                    // Malformed or truncated escape: keep the '%' as-is.
                    None => {
                        out.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}
354
/// Best-effort: sets `dur` as the read timeout on `stream` and discards any
/// error.
///
/// std rejects a zero `Duration` with an error, so passing zero silently has
/// no effect here.
#[allow(dead_code)] // not called from anywhere in this file; kept as a utility
fn apply_read_timeout(stream: &TcpStream, dur: Duration) {
    let timeout = Some(dur);
    let _ = stream.set_read_timeout(timeout);
}
360
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn query_param_parses_basic() {
        assert_eq!(
            query_param("shard=match1&sid=p1", "shard"),
            Some("match1".to_string())
        );
        assert_eq!(
            query_param("shard=match1&sid=p1", "sid"),
            Some("p1".to_string())
        );
        assert_eq!(query_param("shard=match1", "missing"), None);
    }

    #[test]
    fn query_param_edge_cases() {
        // A key without `=` is present with an empty value.
        assert_eq!(query_param("shard", "shard"), Some(String::new()));
        // The first occurrence of a repeated key wins.
        assert_eq!(query_param("sid=a&sid=b", "sid"), Some("a".to_string()));
        // Only the first `=` separates the key from the value.
        assert_eq!(query_param("t=a=b", "t"), Some("a=b".to_string()));
        assert_eq!(query_param("", "shard"), None);
    }

    #[test]
    fn query_param_url_decodes() {
        assert_eq!(
            query_param("name=hello%20world", "name"),
            Some("hello world".to_string())
        );
        assert_eq!(query_param("name=a+b", "name"), Some("a b".to_string()));
    }

    #[test]
    fn urldecode_strict_decodes_and_rejects_malformed_input() {
        assert_eq!(urldecode_strict("abc%41"), Ok("abcA".to_string()));
        assert!(urldecode_strict("%4").is_err(), "truncated escape");
        assert!(urldecode_strict("%zz").is_err(), "non-hex escape");
        assert!(urldecode_strict("%ff").is_err(), "invalid UTF-8");
    }
}