use std::net::{TcpListener, TcpStream};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use pylon_auth::SessionStore;
use pylon_realtime::{DynShardRegistry, ShardAuth, ShardError, SubscriberId};
use tungstenite::{accept_hdr, handshake::server::Request, Message};
use crate::ip_limit::IpConnCounter;
/// Runs the blocking accept loop for the shard WebSocket endpoint.
///
/// Binds `0.0.0.0:{port}` and serves each accepted connection on its own
/// thread via `handle_connection`. If the bind fails the error is logged and
/// the function returns; otherwise it keeps iterating over incoming
/// connections on the calling thread.
///
/// * `registry` - shard lookup shared with every connection thread.
/// * `sessions` - session store used to resolve bearer tokens.
/// * `port` - TCP port to listen on.
pub fn start_shard_ws_server(
    registry: Arc<dyn DynShardRegistry>,
    sessions: Arc<SessionStore>,
    port: u16,
) {
    let listener = match TcpListener::bind(format!("0.0.0.0:{port}")) {
        Ok(l) => l,
        Err(e) => {
            tracing::warn!("[shard-ws] failed to bind port {port}: {e}");
            return;
        }
    };
    tracing::warn!("[shard-ws] listening on ws://0.0.0.0:{port}");

    let ip_counter = Arc::new(IpConnCounter::default());
    for stream in listener.incoming() {
        // Failed accepts and sockets without a peer address are skipped.
        let stream = match stream {
            Ok(s) => s,
            Err(_) => continue,
        };
        let ip = match stream.peer_addr() {
            Ok(addr) => addr.ip(),
            Err(_) => continue,
        };
        // `acquire` returning `None` refuses the connection: `stream` is
        // dropped by the `continue`, which closes the socket.
        let guard = match ip_counter.acquire(ip) {
            Some(g) => g,
            None => continue,
        };

        let registry = Arc::clone(&registry);
        let sessions = Arc::clone(&sessions);
        thread::spawn(move || {
            // Keep the guard alive for the whole connection; presumably the
            // per-IP slot is released when it drops (see `IpConnCounter`).
            let _guard = guard;
            if let Err(e) = handle_connection(stream, registry, sessions) {
                tracing::warn!("[shard-ws] connection error: {e}");
            }
        });
    }
}
/// Serves a single accepted TCP connection as a shard subscriber.
///
/// Steps, in order:
/// 1. Perform the WebSocket handshake, capturing the request URI, the
///    `Authorization` header and any `bearer.<token>` entry offered in
///    `Sec-WebSocket-Protocol` (the accepted entry is echoed back).
/// 2. Read `?shard=` (required) and `?sid=` (defaults to `"anon"`) from the
///    query string.
/// 3. Resolve the bearer token through `sessions`; the `Authorization` header
///    takes precedence over the subprotocol token.
/// 4. Register a snapshot sink with the shard and forward incoming text and
///    binary messages to `process_input` until the peer disconnects.
///
/// # Errors
///
/// Returns a message string when the handshake fails, `?shard=` is missing,
/// the shard is not registered, a read fails, or the socket mutex is
/// poisoned. A subscription the shard rejects is not an error: the socket is
/// sent a close frame and `Ok(())` is returned.
fn handle_connection(
    stream: TcpStream,
    registry: Arc<dyn DynShardRegistry>,
    sessions: Arc<SessionStore>,
) -> Result<(), String> {
    use tungstenite::handshake::server::{ErrorResponse, Response};

    // The handshake callback only sees the request by reference, so the
    // values needed afterwards are copied out through a shared cell.
    let params = Arc::new(Mutex::new(HandshakeParams::default()));
    let params_clone = Arc::clone(&params);
    let ws = accept_hdr(
        stream,
        |req: &Request, mut resp: Response| -> Result<Response, ErrorResponse> {
            let uri = req.uri().to_string();
            let mut p = params_clone.lock().unwrap();
            p.uri = uri;
            let mut selected_protocol: Option<String> = None;
            for (name, value) in req.headers() {
                let lower = name.as_str().to_ascii_lowercase();
                if lower == "authorization" {
                    if let Ok(v) = value.to_str() {
                        p.auth_header = Some(v.to_string());
                    }
                } else if lower == "sec-websocket-protocol" {
                    if let Ok(v) = value.to_str() {
                        // Take the first offered `bearer.<token>` entry whose
                        // token percent-decodes cleanly.
                        for proto in v.split(',').map(str::trim) {
                            if let Some(encoded) = proto.strip_prefix("bearer.") {
                                if let Ok(decoded) = urldecode_strict(encoded) {
                                    p.bearer_from_subprotocol = Some(decoded);
                                    selected_protocol = Some(proto.to_string());
                                    break;
                                }
                            }
                        }
                    }
                }
            }
            // Echo the accepted subprotocol back in the handshake response.
            if let Some(chosen) = selected_protocol {
                if let Ok(hv) = tungstenite::http::HeaderValue::from_str(&chosen) {
                    resp.headers_mut().insert("Sec-WebSocket-Protocol", hv);
                }
            }
            Ok(resp)
        },
    )
    .map_err(|e| format!("handshake: {e}"))?;

    let params = params.lock().unwrap().clone();
    let query = params
        .uri
        .split_once('?')
        .map(|(_, q)| q.to_string())
        .unwrap_or_default();
    let shard_id = query_param(&query, "shard").ok_or("missing ?shard= parameter")?;
    let sid = query_param(&query, "sid").unwrap_or_else(|| "anon".to_string());

    // Prefer `Authorization: Bearer <token>`; fall back to the subprotocol.
    let token = params
        .auth_header
        .as_deref()
        .and_then(|h| h.strip_prefix("Bearer "))
        .map(|t| t.to_string())
        .or_else(|| params.bearer_from_subprotocol.clone());
    let auth_ctx = sessions.resolve(token.as_deref());
    let shard_auth = ShardAuth {
        user_id: auth_ctx.user_id.clone(),
        is_admin: auth_ctx.is_admin,
    };

    let shard = registry
        .get(&shard_id)
        .ok_or_else(|| format!("shard \"{shard_id}\" not found"))?;

    // The socket is shared between this read loop and the snapshot sink.
    let ws = Arc::new(Mutex::new(ws));
    let subscriber_id = SubscriberId::new(sid.clone());
    let ws_for_sink = Arc::clone(&ws);
    let sink: pylon_realtime::SnapshotSink = Box::new(move |tick, bytes| {
        // Frame layout: the tick in big-endian byte order, then the snapshot.
        let mut payload = Vec::with_capacity(8 + bytes.len() + 2);
        payload.extend_from_slice(&tick.to_be_bytes());
        payload.extend_from_slice(bytes);
        // Best-effort delivery: a poisoned lock or failed send is ignored.
        if let Ok(mut s) = ws_for_sink.lock() {
            let _ = s.send(Message::Binary(payload.into()));
        }
    });

    match shard.add_subscriber(subscriber_id.clone(), sink, &shard_auth) {
        Ok(()) => {}
        Err(ShardError::Unauthorized(reason)) => {
            let _ = ws
                .lock()
                .unwrap()
                .close(Some(tungstenite::protocol::CloseFrame {
                    code: tungstenite::protocol::frame::coding::CloseCode::Policy,
                    reason: format!("unauthorized: {reason}").into(),
                }));
            return Ok(());
        }
        Err(e) => {
            let _ = ws
                .lock()
                .unwrap()
                .close(Some(tungstenite::protocol::CloseFrame {
                    code: tungstenite::protocol::frame::coding::CloseCode::Again,
                    reason: e.to_string().into(),
                }));
            return Ok(());
        }
    }

    let read_result = loop {
        let msg = {
            // NOTE(review): the socket mutex is held for the duration of
            // `read()`, so the snapshot sink waits on the same lock while a
            // read is in progress. Confirm this is acceptable for slow peers.
            let mut s = match ws.lock() {
                Ok(s) => s,
                Err(_) => break Err("ws lock poisoned".to_string()),
            };
            match s.read() {
                Ok(m) => m,
                Err(tungstenite::Error::ConnectionClosed) => break Ok(()),
                Err(tungstenite::Error::AlreadyClosed) => break Ok(()),
                Err(e) => break Err(format!("ws read: {e}")),
            }
        };
        match msg {
            Message::Text(text) => {
                process_input(&shard, &subscriber_id, &shard_auth, text.as_str());
            }
            Message::Binary(bytes) => {
                // Binary frames are treated as (lossily decoded) UTF-8 JSON.
                let text = String::from_utf8_lossy(&bytes).to_string();
                process_input(&shard, &subscriber_id, &shard_auth, &text);
            }
            Message::Ping(payload) => {
                let _ = ws.lock().unwrap().send(Message::Pong(payload));
            }
            Message::Close(_) => break Ok(()),
            _ => {}
        }
    };

    // Always unregister, whether the loop ended cleanly or with an error.
    shard.remove_subscriber(&subscriber_id);
    read_result
}
/// Parses one client message as a JSON envelope and forwards its `input`
/// field to the shard.
///
/// Text that is not valid JSON is dropped silently. The `input` value is
/// re-serialized to a JSON string (`"null"` when the field is absent or
/// cannot be serialized). `client_seq` is forwarded only when it is present
/// and an unsigned integer. The result of `push_input_json` is discarded.
fn process_input(
    shard: &Arc<dyn pylon_realtime::DynShard>,
    subscriber_id: &SubscriberId,
    shard_auth: &ShardAuth,
    text: &str,
) {
    let envelope = match serde_json::from_str::<serde_json::Value>(text) {
        Ok(parsed) => parsed,
        Err(_) => return,
    };

    let client_seq = envelope
        .get("client_seq")
        .and_then(serde_json::Value::as_u64);

    // A missing `input` serializes exactly like an explicit JSON null.
    let input_str = match envelope.get("input") {
        Some(value) => serde_json::to_string(value).unwrap_or_else(|_| "null".into()),
        None => "null".to_string(),
    };

    let _ = shard.push_input_json(subscriber_id.clone(), &input_str, client_seq, shard_auth);
}
/// Values copied out of the WebSocket upgrade request by the handshake
/// callback in `handle_connection`, for use once the handshake has finished.
#[derive(Clone, Default)]
struct HandshakeParams {
    /// Request URI as sent by the client, including any query string.
    uri: String,
    /// Raw value of the `Authorization` header, when one was sent.
    auth_header: Option<String>,
    /// Percent-decoded token taken from a `bearer.<token>` subprotocol entry.
    bearer_from_subprotocol: Option<String>,
}
/// Percent-decodes `s`, rejecting anything malformed.
///
/// `%XX` becomes the byte `0xXX`, `+` becomes a space, and every other byte
/// is copied through unchanged.
///
/// # Errors
///
/// Returns an error message when a `%` is not followed by two more bytes,
/// when either of those bytes is not a hex digit, or when the decoded bytes
/// are not valid UTF-8.
fn urldecode_strict(s: &str) -> Result<String, String> {
    /// Numeric value of one hex digit byte.
    fn hex_nibble(byte: u8) -> Result<u8, String> {
        match (byte as char).to_digit(16) {
            Some(digit) => Ok(digit as u8),
            None => Err("bad hex in percent-encoding".into()),
        }
    }

    let raw = s.as_bytes();
    let mut decoded = Vec::with_capacity(raw.len());
    let mut pos = 0;
    while let Some(&byte) = raw.get(pos) {
        match byte {
            b'%' => {
                // Both digits must be present before either is validated.
                let (hi, lo) = match (raw.get(pos + 1), raw.get(pos + 2)) {
                    (Some(&hi), Some(&lo)) => (hi, lo),
                    _ => return Err("truncated percent-encoding".into()),
                };
                decoded.push((hex_nibble(hi)? << 4) | hex_nibble(lo)?);
                pos += 3;
            }
            b'+' => {
                decoded.push(b' ');
                pos += 1;
            }
            literal => {
                decoded.push(literal);
                pos += 1;
            }
        }
    }

    match String::from_utf8(decoded) {
        Ok(text) => Ok(text),
        Err(_) => Err("percent-encoded token is not valid UTF-8".into()),
    }
}
/// Returns the URL-decoded value of the first `key=value` pair in `query`
/// whose key equals `key` exactly, or `None` when no pair matches.
///
/// Pairs are separated by `&`. A pair without `=` is treated as a key with
/// an empty value. Keys are compared as written; only the value is decoded.
fn query_param(query: &str, key: &str) -> Option<String> {
    query
        .split('&')
        .map(|pair| pair.split_once('=').unwrap_or((pair, "")))
        .find(|(name, _)| *name == key)
        .map(|(_, raw_value)| url_decode(raw_value))
}
/// Leniently decodes a URL query component.
///
/// `+` becomes a space and `%XX` (two hex digits) becomes the byte `0xXX`.
/// A `%` that is not followed by two hex digits is kept as a literal `%`.
/// Decoding is done on bytes and the result is interpreted as UTF-8, so
/// multi-byte characters survive whether they arrive percent-encoded or raw.
/// Byte sequences that are not valid UTF-8 are replaced with U+FFFD.
///
/// This function never fails; use `urldecode_strict` when malformed input
/// must be rejected.
fn url_decode(s: &str) -> String {
    /// Numeric value of one hex digit byte, or `None` for anything else.
    /// Checked per digit so that signs such as `+` are never accepted.
    fn hex_value(byte: u8) -> Option<u8> {
        (byte as char).to_digit(16).map(|digit| digit as u8)
    }

    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            b'%' => {
                let escaped = match (bytes.get(i + 1), bytes.get(i + 2)) {
                    (Some(&hi), Some(&lo)) => hex_value(hi).zip(hex_value(lo)),
                    _ => None,
                };
                match escaped {
                    Some((hi, lo)) => {
                        out.push((hi << 4) | lo);
                        i += 3;
                    }
                    // Malformed escape: keep the `%` and carry on after it.
                    None => {
                        out.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    // Interpret the decoded bytes as UTF-8 rather than one char per byte.
    String::from_utf8_lossy(&out).into_owned()
}
/// Best-effort helper: sets `dur` as the read timeout on `stream` and ignores
/// any error the socket call reports.
// NOTE(review): no caller is visible in this file, hence the lint allowance;
// confirm the helper is still wanted.
#[allow(dead_code)]
fn apply_read_timeout(stream: &TcpStream, dur: Duration) {
    stream.set_read_timeout(Some(dur)).ok();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Present keys return their value; an absent key returns `None`.
    #[test]
    fn query_param_parses_basic() {
        let query = "shard=match1&sid=p1";
        assert_eq!(query_param(query, "shard").as_deref(), Some("match1"));
        assert_eq!(query_param(query, "sid").as_deref(), Some("p1"));
        assert!(query_param("shard=match1", "missing").is_none());
    }

    /// A percent-encoded space in a value is decoded.
    #[test]
    fn query_param_url_decodes() {
        let decoded = query_param("name=hello%20world", "name");
        assert_eq!(decoded.as_deref(), Some("hello world"));
    }
}