Skip to main content

pushwire_server/
server.rs

1use std::collections::{HashMap, VecDeque};
2use std::convert::Infallible;
3use std::sync::Arc;
4use std::sync::Mutex;
5use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
6use std::time::Instant;
7
8use anyhow::{Context, Result, anyhow};
9use axum::{
10    Router,
11    extract::{
12        Json, Query, State,
13        ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade},
14    },
15    http::StatusCode,
16    response::{IntoResponse, sse::Event, sse::Sse},
17    routing::{get, post},
18};
19use base64::{Engine as _, engine::general_purpose};
20use dashmap::{DashMap, DashSet};
21use futures_util::{SinkExt, StreamExt};
22use pushwire_core::{BinaryEnvelope, ChannelKind, Frame, SystemOp};
23use serde::Deserialize;
24use sha2::{Digest, Sha256};
25use tokio::sync::mpsc;
26use tokio_stream::wrappers::ReceiverStream;
27use tracing::{debug, warn};
28use uuid::Uuid;
29
/// Resume cursor used when a client gives none for a channel, and reported in
/// `AuthOk` when the client has no per-channel replay state yet.
const DEFAULT_RESUME_CURSOR: u64 = 0;
/// Capacity of each of a connection's outbound queues.
const OUTBOUND_BUFFER: usize = 64;
/// Maximum number of frames kept for replay per client and channel.
const REPLAY_BUFFER: usize = 256;
/// Largest binary payload, in bytes, that `send_binary` embeds inline.
const BINARY_INLINE_LIMIT: usize = 256 * 1024;
/// MIME types `send_binary` accepts (compared ASCII case-insensitively).
const ALLOWED_BINARY_MIME: &[&str] = &[
    "image/png",
    "image/jpeg",
    "image/webp",
    "image/gif",
];
/// Queue depth above which `enqueue_outbound` logs a "send queue depth high" message.
const QUEUE_WARN_THRESHOLD: usize = OUTBOUND_BUFFER / 2;
36
37// ---------------------------------------------------------------------------
38// Priority constants
39// ---------------------------------------------------------------------------
40
/// Most urgent queue priority (system operations); the writer drains it first.
const PRIORITY_HIGH: u8 = 0;
/// Default queue priority, also used for raw messages and unknown priorities.
const PRIORITY_NORMAL: u8 = 1;
/// Least urgent queue priority; drained only when the other queues are idle.
const PRIORITY_LOW: u8 = 2;
44
45// ---------------------------------------------------------------------------
46// Handler type — closures registered per-channel by the consumer.
47// ---------------------------------------------------------------------------
48
/// Callback invoked for each inbound frame on a non-system channel.
///
/// The handler receives the sender's client ID, the frame, and a reference to
/// the server so it can call [`PushServer::send`] or inspect connection state.
///
/// Handlers are installed with [`PushServer::register_handler`] and are called
/// synchronously while the connection's reader task processes the frame, so
/// the next frame from the same client is not read until the handler returns.
pub type ChannelHandler<C> = Arc<dyn Fn(Uuid, Frame<C>, &PushServer<C>) + Send + Sync>;

/// Auth validation callback invoked during the initial handshake.
///
/// Called with the client ID, the optional token, and the capabilities the
/// client presented. Returning an error rejects the connection: the socket is
/// closed with code 1008 and the error's text as the close reason.
/// [`PushServer::new`] installs a validator that accepts every handshake.
pub type AuthValidator<C> =
    Arc<dyn Fn(Uuid, Option<&str>, &[C]) -> Result<(), AuthError> + Send + Sync>;
58
59// ---------------------------------------------------------------------------
60// Internal connection state
61// ---------------------------------------------------------------------------
62
/// Server-side state for one registered WebSocket connection.
#[derive(Debug)]
// NOTE(review): presumably allowed because some fields (e.g. `token`,
// `created_at`) are stored but never read — confirm before removing.
#[allow(dead_code)]
struct ConnectionHandle<C: ChannelKind> {
    /// Unprioritised outbound queue. The writer task polls it only after the
    /// three priority queues; replayed frames and pongs are sent through it.
    sender: mpsc::Sender<Outbound<C>>,
    /// Queue for `PRIORITY_HIGH` messages (chosen by `enqueue_outbound`).
    queue_high: mpsc::Sender<Outbound<C>>,
    /// Queue for `PRIORITY_NORMAL` messages and any unrecognised priority.
    queue_normal: mpsc::Sender<Outbound<C>>,
    /// Queue for `PRIORITY_LOW` messages.
    queue_low: mpsc::Sender<Outbound<C>>,
    /// Approximate number of messages waiting in `queue_high`: incremented by
    /// `enqueue_outbound`, decremented by the writer task.
    depth_high: Arc<AtomicUsize>,
    /// Same as `depth_high`, for `queue_normal`.
    depth_normal: Arc<AtomicUsize>,
    /// Same as `depth_high`, for `queue_low`.
    depth_low: Arc<AtomicUsize>,
    /// Capabilities the client presented in its auth handshake.
    capabilities: Vec<C>,
    /// Token from the auth handshake, if any.
    token: Option<String>,
    /// When this connection was registered.
    created_at: Instant,
    /// The client's replay state (the same `Arc` as its `client_replay` entry).
    replay: Arc<ClientReplay<C>>,
    /// Channels this client may send frames on. Seeded from `capabilities`,
    /// updated by `Subscribe` / `Unsubscribe` system ops.
    allowed_channels: DashSet<C>,
}
79
/// An SSE stream attached to a client.
///
/// `PushServer::send` forwards a frame here, in addition to the WebSocket,
/// when the frame's channel is in `allowed_channels`. If the stream cannot
/// accept the frame, the handle is removed.
#[derive(Debug)]
// NOTE(review): presumably allowed because `replay` is not read anywhere —
// confirm against `sse_upgrade`.
#[allow(dead_code)]
struct SseHandle<C: ChannelKind> {
    /// Frames queued for this SSE stream.
    sender: mpsc::Sender<Frame<C>>,
    /// Channels whose frames are forwarded to this stream.
    allowed_channels: DashSet<C>,
    /// Replay state for the client — presumably shared with
    /// `PushServer::client_replay`; verify in `sse_upgrade`.
    replay: Arc<ClientReplay<C>>,
}
87
/// A message queued for a client's WebSocket writer task.
#[derive(Debug)]
// NOTE(review): presumably allowed for `Priority`, which is not constructed
// anywhere visible — confirm.
#[allow(dead_code)]
enum Outbound<C: ChannelKind> {
    /// An application frame, serialised as a JSON text message.
    Frame(Frame<C>),
    /// A system operation, serialised as a JSON text message.
    System(SystemOp<C>),
    /// A WebSocket message sent as-is (e.g. a pong).
    Raw(Message),
    /// Wraps another message to override the queue priority it would
    /// otherwise get from `Outbound::priority`.
    Priority {
        /// One of the `PRIORITY_*` values; any other value is treated as normal.
        priority: u8,
        inner: Box<Outbound<C>>,
    },
}
99
100impl<C: ChannelKind> Outbound<C> {
101    fn into_message(self) -> serde_json::Result<Message> {
102        match self {
103            Outbound::Frame(frame) => serde_json::to_string(&frame).map(Message::Text),
104            Outbound::System(op) => serde_json::to_string(&op).map(Message::Text),
105            Outbound::Raw(message) => Ok(message),
106            Outbound::Priority { inner, .. } => inner.into_message(),
107        }
108    }
109
110    fn priority(&self) -> u8 {
111        match self {
112            Outbound::Priority { priority, .. } => *priority,
113            Outbound::System(_) => PRIORITY_HIGH,
114            Outbound::Frame(frame) => frame.channel.priority(),
115            Outbound::Raw(_) => PRIORITY_NORMAL,
116        }
117    }
118}
119
120// ---------------------------------------------------------------------------
121// Cursor tracking & replay
122// ---------------------------------------------------------------------------
123
/// Cursor bookkeeping for one client on one channel.
///
/// Each value is an atomic, and the methods here only ever raise them.
#[derive(Debug, Default)]
struct ChannelCursorState {
    /// Highest cursor recorded as sent to the client.
    last_sent: AtomicU64,
    /// Highest cursor the client has acknowledged.
    last_acked: AtomicU64,
    /// Resuming from below this cursor is reported as a gap by
    /// `ChannelReplay::replay_from`.
    buffer_floor: AtomicU64,
}

impl ChannelCursorState {
    /// Record that a frame with `cursor` was sent; `last_sent` never decreases.
    ///
    /// Also lifts the buffer floor up to the current acknowledged cursor.
    fn mark_sent(&self, cursor: u64) {
        self.last_sent.fetch_max(cursor, Ordering::SeqCst);
        let acked = self.last_acked();
        self.buffer_floor.fetch_max(acked, Ordering::SeqCst);
    }

    /// Record an acknowledgement up to `cursor`.
    ///
    /// Both `last_acked` and the buffer floor are raised to `cursor` when it
    /// is larger; an older ack leaves them unchanged.
    fn mark_acked(&self, cursor: u64) {
        // `fetch_max` is the atomic "raise to at least" operation.
        self.last_acked.fetch_max(cursor, Ordering::SeqCst);
        self.buffer_floor.fetch_max(cursor, Ordering::SeqCst);
    }

    /// Highest cursor recorded as sent.
    fn last_sent(&self) -> u64 {
        Self::read(&self.last_sent)
    }

    /// Highest cursor acknowledged by the client.
    fn last_acked(&self) -> u64 {
        Self::read(&self.last_acked)
    }

    /// Lowest cursor a client can still resume from without a gap.
    fn buffer_floor(&self) -> u64 {
        Self::read(&self.buffer_floor)
    }

    /// Sequentially consistent load, matching every other access in this type.
    fn read(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::SeqCst)
    }
}
167
/// Replay state for one client on one channel: cursor bookkeeping plus the
/// frames recorded for replay on resume.
#[derive(Debug)]
struct ChannelReplay<C: ChannelKind> {
    /// Cursor bookkeeping; shared out via `ChannelReplay::state`.
    state: Arc<ChannelCursorState>,
    /// Recorded frames, oldest at the front, bounded by the `limit` given to `push`.
    buffer: Mutex<VecDeque<Frame<C>>>,
}
173
174impl<C: ChannelKind> ChannelReplay<C> {
175    fn new() -> Self {
176        Self {
177            state: Arc::new(ChannelCursorState::default()),
178            buffer: Mutex::new(VecDeque::new()),
179        }
180    }
181
182    fn state(&self) -> Arc<ChannelCursorState> {
183        self.state.clone()
184    }
185
186    fn push(&self, frame: &Frame<C>, limit: usize) {
187        let mut buffer = self.buffer.lock().unwrap();
188        buffer.push_back(frame.clone());
189        while buffer.len() > limit {
190            if let Some(dropped) = buffer.pop_front()
191                && let Some(cursor) = dropped.cursor
192            {
193                self.state.buffer_floor.store(cursor, Ordering::SeqCst);
194            }
195        }
196    }
197
198    fn ack(&self, cursor: u64) {
199        self.state.mark_acked(cursor);
200        let mut buffer = self.buffer.lock().unwrap();
201        while buffer
202            .front()
203            .and_then(|f| f.cursor)
204            .map(|c| c <= cursor)
205            .unwrap_or(false)
206        {
207            buffer.pop_front();
208        }
209        let _ = self.state.buffer_floor.fetch_max(cursor, Ordering::SeqCst);
210    }
211
212    fn replay_from(&self, from: u64) -> ReplayOutcome<C> {
213        let floor = self.state.buffer_floor();
214        if from < floor {
215            return ReplayOutcome::Gap {
216                buffer_floor: floor,
217            };
218        }
219
220        let min_cursor = self.state.last_acked().max(from);
221        let buffer = self.buffer.lock().unwrap();
222        let frames: Vec<Frame<C>> = buffer
223            .iter()
224            .filter(|f| f.cursor.map(|c| c > min_cursor).unwrap_or(false))
225            .cloned()
226            .collect();
227        ReplayOutcome::Frames(frames)
228    }
229}
230
/// Result of asking a channel's replay buffer to resume from a cursor.
#[derive(Debug)]
enum ReplayOutcome<C: ChannelKind> {
    /// Buffered frames after the resume point, in buffer order.
    Frames(Vec<Frame<C>>),
    /// The resume point is below `buffer_floor`, so some frames are no longer
    /// buffered; `handle_socket` answers this with `ResumeRequired`.
    Gap { buffer_floor: u64 },
}
236
237#[derive(Debug)]
238struct ClientReplay<C: ChannelKind> {
239    channels: DashMap<C, Arc<ChannelReplay<C>>>,
240}
241
242impl<C: ChannelKind> Default for ClientReplay<C> {
243    fn default() -> Self {
244        Self {
245            channels: DashMap::new(),
246        }
247    }
248}
249
250impl<C: ChannelKind> ClientReplay<C> {
251    fn channel(&self, channel: C) -> Arc<ChannelReplay<C>> {
252        self.channels
253            .entry(channel)
254            .or_insert_with(|| Arc::new(ChannelReplay::new()))
255            .clone()
256    }
257
258    fn resume_state(&self) -> HashMap<C, u64> {
259        self.channels
260            .iter()
261            .map(|entry| (*entry.key(), entry.value().state.last_acked()))
262            .collect()
263    }
264}
265
266// ---------------------------------------------------------------------------
267// Errors
268// ---------------------------------------------------------------------------
269
/// Errors that can occur when pushing frames to a connected client.
#[derive(Debug, thiserror::Error)]
pub enum SendError {
    /// No WebSocket connection is registered for the client.
    #[error("client {0} not connected")]
    NotConnected(Uuid),
    /// The client's outbound queue could not take the message — presumably
    /// returned by `enqueue_outbound` when the queue is full; confirm there.
    #[error("send buffer full for client {0}")]
    Backpressure(Uuid),
    /// The payload was refused before sending (e.g. by `send_binary` for a
    /// disallowed MIME type, or an oversized payload without a pointer URL).
    #[error("payload rejected: {0}")]
    Rejected(String),
    /// The payload could not be converted to JSON.
    #[error("payload serialization error: {0}")]
    Serialization(String),
}
282
/// Reasons an [`AuthValidator`] can reject a handshake.
///
/// On rejection the client's socket is closed with code 1008 and a reason
/// derived from the variant.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// The token was not accepted.
    #[error("invalid token")]
    InvalidToken,
    /// The client presented capabilities it may not use.
    #[error("capabilities not permitted")]
    Forbidden,
    /// Any other rejection; the text is used verbatim as the close reason.
    #[error("{0}")]
    Other(String),
}
292
293// ---------------------------------------------------------------------------
294// PushServer<C>
295// ---------------------------------------------------------------------------
296
/// Generic multiplexed push server parameterized by channel type.
///
/// `C` implements [`ChannelKind`] and defines the channel taxonomy.
/// Channel-specific handling is registered by the consumer via
/// [`register_handler`](PushServer::register_handler).
pub struct PushServer<C: ChannelKind> {
    /// Live WebSocket connections keyed by client ID.
    connections: DashMap<Uuid, ConnectionHandle<C>>,
    /// SSE streams keyed by client ID; `send` also forwards frames to these.
    sse_connections: DashMap<Uuid, SseHandle<C>>,
    /// Per-channel cursor counters shared by all clients (see `next_cursor`).
    channel_cursors: DashMap<C, Arc<AtomicU64>>,
    /// Replay state per client ID. `handle_socket` reuses an existing entry on
    /// reconnect, so buffered frames can be replayed to a resuming client.
    client_replay: DashMap<Uuid, Arc<ClientReplay<C>>>,
    /// Callbacks for inbound frames on non-system channels.
    channel_handlers: DashMap<C, ChannelHandler<C>>,
    /// Checks each WebSocket handshake.
    auth_validator: AuthValidator<C>,
}
310
311impl<C: ChannelKind> Default for PushServer<C> {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
317impl<C: ChannelKind> PushServer<C> {
318    pub fn new() -> Self {
319        Self::with_auth_validator(Arc::new(|_, _, _| Ok(())))
320    }
321
322    pub fn with_auth_validator(auth_validator: AuthValidator<C>) -> Self {
323        let counters = DashMap::new();
324        for channel in C::all() {
325            counters.insert(*channel, Arc::new(AtomicU64::new(0)));
326        }
327
328        Self {
329            connections: DashMap::new(),
330            sse_connections: DashMap::new(),
331            channel_cursors: counters,
332            client_replay: DashMap::new(),
333            channel_handlers: DashMap::new(),
334            auth_validator,
335        }
336    }
337
338    fn sha256_hex(bytes: &[u8]) -> String {
339        let mut hasher = Sha256::new();
340        hasher.update(bytes);
341        let digest = hasher.finalize();
342        digest.iter().map(|b| format!("{:02x}", b)).collect()
343    }
344
345    /// Axum router exposing `/rps` (WebSocket), `/rps/sse`, and `/rps/ack`.
346    pub fn router(self: Arc<Self>) -> Router<Arc<Self>> {
347        Router::new()
348            .route("/rps", get(ws_upgrade::<C>))
349            .route("/rps/sse", get(sse_upgrade::<C>))
350            .route("/rps/ack", post(http_ack::<C>))
351            .with_state(self)
352    }
353
354    /// Number of currently connected WebSocket clients.
355    pub fn connected_clients(&self) -> usize {
356        self.connections.len()
357    }
358
359    /// All currently connected client IDs.
360    pub fn connected_client_ids(&self) -> Vec<Uuid> {
361        self.connections.iter().map(|entry| *entry.key()).collect()
362    }
363
364    /// Register a handler for a specific channel. When a frame arrives on this
365    /// channel, the handler is invoked with the client ID, the frame, and a
366    /// reference to the server.
367    pub fn register_handler<F>(&self, channel: C, handler: F)
368    where
369        F: Fn(Uuid, Frame<C>, &PushServer<C>) + Send + Sync + 'static,
370    {
371        self.channel_handlers.insert(channel, Arc::new(handler));
372    }
373
374    fn stamp_frame(&self, replay: &ClientReplay<C>, frame: Frame<C>) -> Frame<C> {
375        let cursor = self.next_cursor(frame.channel, frame.cursor);
376        replay.channel(frame.channel).state().mark_sent(cursor);
377        frame.with_cursor(cursor)
378    }
379
380    fn next_cursor(&self, channel: C, existing: Option<u64>) -> u64 {
381        let counter = self
382            .channel_cursors
383            .get(&channel)
384            .map(|c| c.clone())
385            .unwrap_or_else(|| {
386                let fresh = Arc::new(AtomicU64::new(0));
387                self.channel_cursors.insert(channel, fresh.clone());
388                fresh
389            });
390
391        let cursor = existing.unwrap_or_else(|| counter.fetch_add(1, Ordering::SeqCst) + 1);
392        let _ = counter.fetch_max(cursor, Ordering::SeqCst);
393        cursor
394    }
395
396    /// Upgrade an HTTP request to an RPS WebSocket connection.
397    pub async fn upgrade(self: Arc<Self>, ws: WebSocketUpgrade) -> impl IntoResponse {
398        ws.on_upgrade(move |socket| async move {
399            if let Err(err) = self.handle_socket(socket).await {
400                warn!(?err, "RPS websocket closed with error");
401            }
402        })
403    }
404
405    /// Push a frame to a connected client.
406    pub fn send(&self, client_id: Uuid, frame: Frame<C>) -> Result<(), SendError> {
407        let replay = self
408            .client_replay
409            .entry(client_id)
410            .or_insert_with(|| Arc::new(ClientReplay::default()))
411            .clone();
412        match self.connections.get(&client_id) {
413            Some(conn) => {
414                let stamped = self.stamp_frame(&replay, frame).with_client(client_id);
415                replay
416                    .channel(stamped.channel)
417                    .push(&stamped, REPLAY_BUFFER);
418                self.enqueue_outbound(conn.value(), Outbound::Frame(stamped.clone()), client_id)?;
419
420                if let Some(sse) = self.sse_connections.get(&client_id)
421                    && sse.allowed_channels.contains(&stamped.channel)
422                    && let Err(err) = sse.sender.try_send(stamped)
423                {
424                    warn!(?client_id, ?err, "dropping SSE frame (buffer full?)");
425                    self.sse_connections.remove(&client_id);
426                }
427
428                Ok(())
429            }
430            None => Err(SendError::NotConnected(client_id)),
431        }
432    }
433
434    /// Send a binary asset using inline-or-pointer logic.
435    pub fn send_binary(
436        &self,
437        client_id: Uuid,
438        channel: C,
439        bytes: &[u8],
440        mime: &str,
441        name: Option<&str>,
442        pointer_url: Option<&str>,
443    ) -> Result<(), SendError> {
444        if !ALLOWED_BINARY_MIME
445            .iter()
446            .any(|m| m.eq_ignore_ascii_case(mime))
447        {
448            return Err(SendError::Rejected(format!(
449                "mime type {mime} not permitted"
450            )));
451        }
452
453        let sha256 = Self::sha256_hex(bytes);
454        let size = bytes.len() as u64;
455        let envelope = if bytes.len() <= BINARY_INLINE_LIMIT {
456            BinaryEnvelope::Inline {
457                mime: mime.to_string(),
458                sha256,
459                size,
460                data_base64: general_purpose::STANDARD.encode(bytes),
461                name: name.map(|s| s.to_string()),
462            }
463        } else if let Some(url) = pointer_url {
464            BinaryEnvelope::Pointer {
465                mime: mime.to_string(),
466                sha256,
467                size,
468                url: url.to_string(),
469                name: name.map(|s| s.to_string()),
470            }
471        } else {
472            return Err(SendError::Rejected(format!(
473                "payload size {} exceeds inline limit {} and no pointer_url provided",
474                bytes.len(),
475                BINARY_INLINE_LIMIT
476            )));
477        };
478
479        let payload =
480            serde_json::to_value(envelope).map_err(|e| SendError::Serialization(e.to_string()))?;
481        self.send(client_id, Frame::new(channel, payload))
482    }
483
484    /// Send a system-level message to a client.
485    pub fn send_system(&self, client_id: Uuid, op: SystemOp<C>) {
486        self.enqueue_system(client_id, op);
487    }
488
489    // -----------------------------------------------------------------------
490    // WebSocket connection lifecycle
491    // -----------------------------------------------------------------------
492
493    async fn handle_socket(self: Arc<Self>, socket: WebSocket) -> Result<()> {
494        let (mut ws_tx, mut ws_rx) = socket.split();
495
496        let first = futures_util::StreamExt::next(&mut ws_rx)
497            .await
498            .ok_or_else(|| anyhow!("connection closed before auth"))?;
499        let first = first.context("failed to read first RPS frame")?;
500        let auth: SystemOp<C> = match first {
501            Message::Text(text) => {
502                serde_json::from_str(&text).context("failed to parse auth frame")?
503            }
504            Message::Binary(bytes) => {
505                serde_json::from_slice(&bytes).context("failed to parse binary auth frame")?
506            }
507            other => anyhow::bail!("expected auth frame as text, got {other:?}"),
508        };
509
510        let (client_id, capabilities, token, resume_cursor, resume_cursors) = match auth {
511            SystemOp::Auth {
512                client_id,
513                capabilities,
514                token,
515                resume_cursor,
516                resume_cursors,
517                ..
518            } => (
519                client_id,
520                capabilities,
521                token,
522                resume_cursor,
523                resume_cursors,
524            ),
525            other => anyhow::bail!("first RPS frame must be auth, got {other:?}"),
526        };
527
528        if let Err(err) = (self.auth_validator)(client_id, token.as_deref(), &capabilities) {
529            let reason = match &err {
530                AuthError::InvalidToken => "invalid token",
531                AuthError::Forbidden => "capabilities not permitted",
532                AuthError::Other(msg) => msg.as_str(),
533            }
534            .to_string();
535            let _ = ws_tx
536                .send(Message::Close(Some(CloseFrame {
537                    code: 1008,
538                    reason: reason.into(),
539                })))
540                .await;
541            return Err(anyhow!(err));
542        }
543
544        let (tx, mut rx) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
545        let (q_high, mut rx_high) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
546        let (q_norm, mut rx_norm) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
547        let (q_low, mut rx_low) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
548        let depth_high = Arc::new(AtomicUsize::new(0));
549        let depth_normal = Arc::new(AtomicUsize::new(0));
550        let depth_low = Arc::new(AtomicUsize::new(0));
551
552        let replay = self
553            .client_replay
554            .entry(client_id)
555            .or_insert_with(|| Arc::new(ClientReplay::default()))
556            .clone();
557
558        let allowed_init = {
559            let set = DashSet::new();
560            for ch in &capabilities {
561                set.insert(*ch);
562            }
563            set
564        };
565
566        if let Some(_old) = self.connections.insert(
567            client_id,
568            ConnectionHandle {
569                sender: tx.clone(),
570                queue_high: q_high.clone(),
571                queue_normal: q_norm.clone(),
572                queue_low: q_low.clone(),
573                depth_high: depth_high.clone(),
574                depth_normal: depth_normal.clone(),
575                depth_low: depth_low.clone(),
576                capabilities,
577                token,
578                created_at: Instant::now(),
579                replay: replay.clone(),
580                allowed_channels: allowed_init,
581            },
582        ) {
583            warn!(?client_id, "replacing existing RPS connection for client");
584        }
585
586        let resume_snapshot = replay.resume_state();
587        let resume_cursor_reply = resume_snapshot
588            .values()
589            .copied()
590            .max()
591            .unwrap_or(DEFAULT_RESUME_CURSOR);
592        ws_tx
593            .send(Message::Text(
594                serde_json::to_string(&SystemOp::<C>::AuthOk {
595                    resume_cursor: resume_cursor_reply,
596                    resume_cursors: resume_snapshot.clone(),
597                })
598                .context("serialize auth_ok")?,
599            ))
600            .await
601            .map_err(anyhow::Error::new)?;
602
603        // Resume: use per-channel cursors if provided, otherwise fall back to global.
604        let mut requested = resume_cursors;
605        if let Some(global) = resume_cursor {
606            for channel in C::all() {
607                requested.entry(*channel).or_insert(global);
608            }
609        }
610
611        for entry in replay.channels.iter() {
612            let channel = *entry.key();
613            let channel_replay = entry.value();
614            let from = requested
615                .get(&channel)
616                .copied()
617                .unwrap_or(DEFAULT_RESUME_CURSOR);
618            match channel_replay.replay_from(from) {
619                ReplayOutcome::Frames(frames) => {
620                    for frame in frames {
621                        if let Err(err) = tx.try_send(Outbound::Frame(frame)) {
622                            warn!(?client_id, ?err, "failed to enqueue replay frame");
623                            break;
624                        }
625                    }
626                }
627                ReplayOutcome::Gap { buffer_floor } => {
628                    self.enqueue_system(
629                        client_id,
630                        SystemOp::ResumeRequired {
631                            channel,
632                            from_cursor: buffer_floor,
633                        },
634                    );
635                }
636            }
637        }
638
639        let writer = tokio::spawn(async move {
640            // Use biased select! to maintain priority ordering while waking
641            // on ANY channel. This fixes a wake-up issue where the writer
642            // would block on rx.recv() and never check priority queues.
643            loop {
644                tokio::select! {
645                    biased;
646                    Some(item) = rx_high.recv() => {
647                        depth_high.fetch_sub(1, Ordering::SeqCst);
648                        let message = item.into_message().context("serialize prio-high RPS")?;
649                        ws_tx.send(message).await.map_err(anyhow::Error::new)?;
650                    }
651                    Some(item) = rx_norm.recv() => {
652                        depth_normal.fetch_sub(1, Ordering::SeqCst);
653                        let message = item.into_message().context("serialize prio-norm RPS")?;
654                        ws_tx.send(message).await.map_err(anyhow::Error::new)?;
655                    }
656                    Some(item) = rx_low.recv() => {
657                        depth_low.fetch_sub(1, Ordering::SeqCst);
658                        let message = item.into_message().context("serialize prio-low RPS")?;
659                        ws_tx.send(message).await.map_err(anyhow::Error::new)?;
660                    }
661                    result = rx.recv() => {
662                        match result {
663                            Some(outbound) => {
664                                let message = outbound
665                                    .into_message()
666                                    .context("serialize outbound RPS message")?;
667                                ws_tx.send(message).await.map_err(anyhow::Error::new)?;
668                            }
669                            None => break,
670                        }
671                    }
672                }
673            }
674
675            Ok::<(), anyhow::Error>(())
676        });
677
678        let reader = {
679            let server = self.clone();
680            let tx = tx.clone();
681
682            tokio::spawn(async move {
683                while let Some(incoming) = futures_util::StreamExt::next(&mut ws_rx).await {
684                    match incoming {
685                        Ok(Message::Text(text)) => match serde_json::from_str::<Frame<C>>(&text) {
686                            Ok(frame) => server.handle_incoming(client_id, frame).await,
687                            Err(err) => {
688                                warn!(?err, "invalid RPS frame from client");
689                                server.enqueue_system(
690                                    client_id,
691                                    SystemOp::Error {
692                                        message: "invalid frame schema".into(),
693                                    },
694                                );
695                            }
696                        },
697                        Ok(Message::Binary(_)) => {
698                            warn!("ignoring binary RPS frame");
699                        }
700                        Ok(Message::Ping(payload)) => {
701                            let _ = tx.send(Outbound::Raw(Message::Pong(payload))).await;
702                        }
703                        Ok(Message::Pong(_)) => {}
704                        Ok(Message::Close(_)) => break,
705                        Err(err) => return Err(anyhow::Error::new(err)),
706                    }
707                }
708
709                Ok::<(), anyhow::Error>(())
710            })
711        };
712
713        let result = tokio::try_join!(writer, reader);
714
715        self.connections.remove(&client_id);
716
717        result.map(|_| ()).map_err(anyhow::Error::new)
718    }
719
720    // -----------------------------------------------------------------------
721    // Inbound frame handling
722    // -----------------------------------------------------------------------
723
724    async fn handle_incoming(&self, client_id: Uuid, frame: Frame<C>) {
725        if let Err(msg) = validate_frame(&frame) {
726            self.enqueue_system(
727                client_id,
728                SystemOp::Error {
729                    message: msg.to_string(),
730                },
731            );
732            return;
733        }
734
735        if frame.channel.is_system()
736            && let Some(conn) = self.connections.get(&client_id)
737        {
738            conn.replay
739                .channel(frame.channel)
740                .push(&frame, REPLAY_BUFFER);
741        }
742
743        let replay = self.client_replay.get(&client_id).map(|c| c.clone());
744
745        if frame.channel.is_system() {
746            match serde_json::from_value::<SystemOp<C>>(frame.payload.clone()) {
747                Ok(SystemOp::Ping) => self.enqueue_system(client_id, SystemOp::Pong),
748                Ok(SystemOp::Slow { window }) => {
749                    debug!(?client_id, ?window, "client reported backpressure window");
750                }
751                Ok(SystemOp::Ack { channel, cursor }) => {
752                    self.handle_ack(client_id, channel, cursor, replay.as_deref());
753                }
754                Ok(SystemOp::ResumeRequired { .. }) => {
755                    debug!(?client_id, "client reported resume_required; ignoring");
756                }
757                Ok(SystemOp::Subscribe { channels }) => {
758                    if let Some(conn) = self.connections.get(&client_id) {
759                        for ch in channels {
760                            conn.allowed_channels.insert(ch);
761                        }
762                    }
763                }
764                Ok(SystemOp::Unsubscribe { channels }) => {
765                    if let Some(conn) = self.connections.get(&client_id) {
766                        for ch in channels {
767                            conn.allowed_channels.remove(&ch);
768                        }
769                    }
770                }
771                Ok(SystemOp::Health { status, detail }) => {
772                    debug!(?client_id, ?status, ?detail, "client reported health");
773                }
774                Ok(SystemOp::Features {
775                    supported,
776                    requested,
777                }) => {
778                    debug!(?client_id, ?supported, ?requested, "client features");
779                }
780                Ok(SystemOp::Goodbye { reason }) => {
781                    debug!(?client_id, ?reason, "client goodbye");
782                    self.enqueue_system(client_id, SystemOp::Goodbye { reason });
783                }
784                Ok(other) => {
785                    debug!(?client_id, ?other, "received system message");
786                }
787                Err(err) => {
788                    warn!(?err, "invalid system payload");
789                    self.enqueue_system(
790                        client_id,
791                        SystemOp::Error {
792                            message: "invalid system payload".into(),
793                        },
794                    );
795                }
796            }
797        } else {
798            let channel = frame.channel;
799
800            // Capability/subscription enforcement
801            if let Some(conn) = self.connections.get(&client_id)
802                && !conn.allowed_channels.contains(&channel)
803            {
804                self.enqueue_system(
805                    client_id,
806                    SystemOp::Error {
807                        message: format!("channel {} not subscribed", channel.name()),
808                    },
809                );
810                return;
811            }
812
813            // Dispatch to registered handler
814            if let Some(handler) = self.channel_handlers.get(&channel) {
815                (handler.value())(client_id, frame.clone(), self);
816            } else {
817                self.enqueue_system(
818                    client_id,
819                    SystemOp::Error {
820                        message: format!("no handler for channel {}", channel.name()),
821                    },
822                );
823            }
824
825            debug!(
826                ?client_id,
827                channel = channel.name(),
828                cursor = ?frame.cursor,
829                "received RPS frame"
830            );
831        }
832    }
833
834    fn handle_ack(
835        &self,
836        client_id: Uuid,
837        channel: C,
838        cursor: u64,
839        replay: Option<&ClientReplay<C>>,
840    ) {
841        let Some(replay) = replay else {
842            warn!(?client_id, "ack from unknown client");
843            return;
844        };
845
846        let channel_replay = replay.channel(channel);
847        let state = channel_replay.state();
848        let last_sent = state.last_sent();
849        let buffer_floor = state.buffer_floor();
850
851        if cursor < buffer_floor {
852            self.enqueue_system(
853                client_id,
854                SystemOp::ResumeRequired {
855                    channel,
856                    from_cursor: buffer_floor,
857                },
858            );
859            return;
860        }
861
862        if cursor > last_sent {
863            self.enqueue_system(
864                client_id,
865                SystemOp::ResumeRequired {
866                    channel,
867                    from_cursor: last_sent,
868                },
869            );
870            return;
871        }
872
873        channel_replay.ack(cursor);
874    }
875
876    fn enqueue_system(&self, client_id: Uuid, op: SystemOp<C>) {
877        if let Some(conn) = self.connections.get(&client_id) {
878            let _ = self.enqueue_outbound(conn.value(), Outbound::System(op), client_id);
879        } else {
880            warn!(?client_id, "ignoring system send for unknown client");
881        }
882    }
883
884    fn enqueue_outbound(
885        &self,
886        conn: &ConnectionHandle<C>,
887        outbound: Outbound<C>,
888        client_id: Uuid,
889    ) -> Result<(), SendError> {
890        let prio = outbound.priority();
891        let (target, depth) = match prio {
892            PRIORITY_HIGH => (&conn.queue_high, &conn.depth_high),
893            PRIORITY_LOW => (&conn.queue_low, &conn.depth_low),
894            _ => (&conn.queue_normal, &conn.depth_normal),
895        };
896
897        let depth_now = depth.fetch_add(1, Ordering::SeqCst) + 1;
898        if depth_now > QUEUE_WARN_THRESHOLD {
899            debug!(
900                ?client_id,
901                ?prio,
902                depth = depth_now,
903                "send queue depth high"
904            );
905        }
906
907        if depth_now > OUTBOUND_BUFFER {
908            depth.fetch_sub(1, Ordering::SeqCst);
909            if prio == PRIORITY_LOW {
910                warn!(
911                    ?client_id,
912                    ?prio,
913                    "dropping low-priority frame (queue full)"
914                );
915                return Ok(());
916            } else {
917                warn!(
918                    ?client_id,
919                    ?prio,
920                    "send queue overflow; treating as backpressure"
921                );
922                return Err(SendError::Backpressure(client_id));
923            }
924        }
925
926        match target.try_send(outbound) {
927            Ok(_) => Ok(()),
928            Err(mpsc::error::TrySendError::Full(_)) => {
929                depth.fetch_sub(1, Ordering::SeqCst);
930                if prio == PRIORITY_LOW {
931                    warn!(
932                        ?client_id,
933                        ?prio,
934                        "dropping low-priority frame (queue full)"
935                    );
936                    Ok(())
937                } else {
938                    Err(SendError::Backpressure(client_id))
939                }
940            }
941            Err(mpsc::error::TrySendError::Closed(_)) => {
942                depth.fetch_sub(1, Ordering::SeqCst);
943                Err(SendError::NotConnected(client_id))
944            }
945        }
946    }
947}
948
949// ---------------------------------------------------------------------------
950// Axum route handlers
951// ---------------------------------------------------------------------------
952
953async fn ws_upgrade<C: ChannelKind>(
954    State(server): State<Arc<PushServer<C>>>,
955    ws: WebSocketUpgrade,
956) -> impl IntoResponse {
957    server.upgrade(ws).await
958}
959
/// Query parameters accepted by the SSE endpoint handler (`sse_upgrade`).
#[derive(Debug, Deserialize)]
struct SseParams {
    /// Identifier of the connecting client (required).
    client_id: Uuid,
    /// Optional auth token, forwarded to the server's auth validator.
    #[serde(default)]
    token: Option<String>,
    /// Optional comma-separated channel names, passed to the auth validator. If
    /// `channels` is empty, these also form the subscription set.
    #[serde(default)]
    capabilities: Option<String>,
    /// Optional comma-separated channel names to subscribe to.
    #[serde(default)]
    channels: Option<String>,
    /// Optional cursor; when present, buffered frames are replayed from it on connect.
    #[serde(default)]
    resume_cursor: Option<u64>,
}
972
973async fn sse_upgrade<C: ChannelKind>(
974    State(server): State<Arc<PushServer<C>>>,
975    Query(params): Query<SseParams>,
976) -> Result<impl IntoResponse, StatusCode> {
977    let client_id = params.client_id;
978    let capabilities = parse_channels::<C>(params.capabilities.as_deref());
979    let subscribe = parse_channels::<C>(params.channels.as_deref());
980
981    if let Err(_err) = (server.auth_validator)(client_id, params.token.as_deref(), &capabilities) {
982        return Err(StatusCode::UNAUTHORIZED);
983    }
984
985    let replay = server
986        .client_replay
987        .entry(client_id)
988        .or_insert_with(|| Arc::new(ClientReplay::default()))
989        .clone();
990
991    let allowed = {
992        let set = DashSet::new();
993        if !subscribe.is_empty() {
994            for ch in subscribe {
995                set.insert(ch);
996            }
997        } else if !capabilities.is_empty() {
998            for ch in capabilities.clone() {
999                set.insert(ch);
1000            }
1001        } else {
1002            for ch in C::all() {
1003                set.insert(*ch);
1004            }
1005        }
1006        set
1007    };
1008
1009    let (tx, rx) = mpsc::channel::<Frame<C>>(OUTBOUND_BUFFER);
1010
1011    server.sse_connections.insert(
1012        client_id,
1013        SseHandle {
1014            sender: tx.clone(),
1015            allowed_channels: allowed.clone(),
1016            replay: replay.clone(),
1017        },
1018    );
1019
1020    // Send auth_ok snapshot.
1021    let snapshot = replay.resume_state();
1022    let resume_cursor = snapshot
1023        .values()
1024        .copied()
1025        .max()
1026        .unwrap_or(DEFAULT_RESUME_CURSOR);
1027
1028    let system_channel = C::all()
1029        .iter()
1030        .find(|c| c.is_system())
1031        .copied()
1032        .expect("ChannelKind must have a system channel");
1033
1034    let auth_ok = Frame::new(
1035        system_channel,
1036        serde_json::to_value(SystemOp::<C>::AuthOk {
1037            resume_cursor,
1038            resume_cursors: snapshot.clone(),
1039        })
1040        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
1041    )
1042    .with_client(client_id);
1043    let _ = tx.try_send(auth_ok);
1044
1045    // Replay buffered frames if requested.
1046    if let Some(from) = params.resume_cursor {
1047        for entry in replay.channels.iter() {
1048            let channel = *entry.key();
1049            if !allowed.contains(&channel) {
1050                continue;
1051            }
1052            match entry.value().replay_from(from) {
1053                ReplayOutcome::Frames(frames) => {
1054                    for frame in frames {
1055                        let _ = tx.try_send(frame);
1056                    }
1057                }
1058                ReplayOutcome::Gap { buffer_floor } => {
1059                    let _ = tx.try_send(
1060                        Frame::new(
1061                            system_channel,
1062                            serde_json::to_value(SystemOp::<C>::ResumeRequired {
1063                                channel,
1064                                from_cursor: buffer_floor,
1065                            })
1066                            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
1067                        )
1068                        .with_client(client_id),
1069                    );
1070                }
1071            }
1072        }
1073    }
1074
1075    let stream = futures_util::StreamExt::map(
1076        ReceiverStream::new(rx),
1077        |frame| -> Result<Event, Infallible> {
1078            let id = frame
1079                .cursor
1080                .map(|c| c.to_string())
1081                .unwrap_or_else(|| "0".into());
1082            let event = match serde_json::to_string(&frame) {
1083                Ok(json) => Event::default().event("frame").id(id).data(json),
1084                Err(err) => {
1085                    warn!(?err, "failed to serialize SSE frame");
1086                    Event::default().event("error").data("serialize_failed")
1087                }
1088            };
1089            Ok(event)
1090        },
1091    );
1092
1093    Ok(Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
1094}
1095
/// JSON body accepted by the HTTP ack endpoint handler (`http_ack`).
#[derive(Debug, Deserialize)]
#[serde(bound(deserialize = "C: ChannelKind"))]
struct AckBody<C: ChannelKind> {
    /// Client sending the acknowledgement.
    client_id: Uuid,
    /// Channel the acknowledged cursor belongs to.
    channel: C,
    /// Cursor being acknowledged.
    cursor: u64,
}
1103
1104async fn http_ack<C: ChannelKind>(
1105    State(server): State<Arc<PushServer<C>>>,
1106    Json(body): Json<AckBody<C>>,
1107) -> impl IntoResponse {
1108    server.handle_ack(body.client_id, body.channel, body.cursor, None);
1109    axum::http::StatusCode::NO_CONTENT
1110}
1111
1112// ---------------------------------------------------------------------------
1113// Utilities
1114// ---------------------------------------------------------------------------
1115
1116fn parse_channels<C: ChannelKind>(raw: Option<&str>) -> Vec<C> {
1117    raw.map(|list| {
1118        list.split(',')
1119            .filter_map(|s| C::from_name(s.trim()))
1120            .collect()
1121    })
1122    .unwrap_or_default()
1123}
1124
1125fn validate_frame<C: ChannelKind>(frame: &Frame<C>) -> Result<(), &'static str> {
1126    if frame.payload.is_null() {
1127        return Err("payload required");
1128    }
1129    Ok(())
1130}