1use std::collections::{HashMap, VecDeque};
2use std::convert::Infallible;
3use std::sync::Arc;
4use std::sync::Mutex;
5use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
6use std::time::Instant;
7
8use anyhow::{Context, Result, anyhow};
9use axum::{
10 Router,
11 extract::{
12 Json, Query, State,
13 ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade},
14 },
15 http::StatusCode,
16 response::{IntoResponse, sse::Event, sse::Sse},
17 routing::{get, post},
18};
19use base64::{Engine as _, engine::general_purpose};
20use dashmap::{DashMap, DashSet};
21use futures_util::{SinkExt, StreamExt};
22use pushwire_core::{BinaryEnvelope, ChannelKind, Frame, SystemOp};
23use serde::Deserialize;
24use sha2::{Digest, Sha256};
25use tokio::sync::mpsc;
26use tokio_stream::wrappers::ReceiverStream;
27use tracing::{debug, warn};
28use uuid::Uuid;
29
/// Cursor used when no acknowledged / requested cursor is available for a channel.
const DEFAULT_RESUME_CURSOR: u64 = 0;
/// Capacity of each per-connection outbound mpsc queue (websocket queues and the SSE queue).
const OUTBOUND_BUFFER: usize = 64;
/// Maximum number of frames kept per client per channel in the replay buffer.
const REPLAY_BUFFER: usize = 256;
/// Binary payloads up to this many bytes are embedded as base64; larger ones need a pointer URL.
const BINARY_INLINE_LIMIT: usize = 256 * 1024;
/// MIME types accepted by `PushServer::send_binary` (compared ASCII case-insensitively).
const ALLOWED_BINARY_MIME: &[&str] = &["image/png", "image/jpeg", "image/webp", "image/gif"];
/// Queue depth above which a debug-level "send queue depth high" event is emitted.
const QUEUE_WARN_THRESHOLD: usize = OUTBOUND_BUFFER / 2;

// Routing priorities for outbound items. Each maps to its own queue; the websocket writer
// polls the high queue first, then normal, then low (`biased` select).
const PRIORITY_HIGH: u8 = 0;
const PRIORITY_NORMAL: u8 = 1;
const PRIORITY_LOW: u8 = 2;
44
/// Callback for inbound non-system frames on one channel.
///
/// Called with the sending client's id, the frame, and a reference to the server.
pub type ChannelHandler<C> = Arc<dyn Fn(Uuid, Frame<C>, &PushServer<C>) + Send + Sync>;

/// Decides whether a client may connect, given its id, optional token and requested channels.
///
/// An `Err` rejects the client: websocket sessions are closed with code 1008 and SSE
/// requests are answered with `401 Unauthorized`.
pub type AuthValidator<C> =
    Arc<dyn Fn(Uuid, Option<&str>, &[C]) -> Result<(), AuthError> + Send + Sync>;
58
/// Per-client state for an authenticated websocket connection.
#[derive(Debug)]
// Not every field is read by the server code (e.g. `capabilities`, `token`, `created_at`).
#[allow(dead_code)]
struct ConnectionHandle<C: ChannelKind> {
    /// Unprioritized queue; the writer polls it after the three priority queues.
    sender: mpsc::Sender<Outbound<C>>,
    /// Queue for items whose `Outbound::priority()` is `PRIORITY_HIGH`.
    queue_high: mpsc::Sender<Outbound<C>>,
    /// Queue for `PRIORITY_NORMAL` (and any unrecognised priority) items.
    queue_normal: mpsc::Sender<Outbound<C>>,
    /// Queue for `PRIORITY_LOW` items; these are dropped rather than failing when full.
    queue_low: mpsc::Sender<Outbound<C>>,
    /// Items enqueued into `queue_high` and not yet taken by the writer.
    depth_high: Arc<AtomicUsize>,
    /// Items enqueued into `queue_normal` and not yet taken by the writer.
    depth_normal: Arc<AtomicUsize>,
    /// Items enqueued into `queue_low` and not yet taken by the writer.
    depth_low: Arc<AtomicUsize>,
    /// Channels the client requested in its auth frame.
    capabilities: Vec<C>,
    /// Auth token supplied by the client, if any.
    token: Option<String>,
    /// When this connection was registered.
    created_at: Instant,
    /// Replay state; the same `Arc` stored in `PushServer::client_replay` for this client.
    replay: Arc<ClientReplay<C>>,
    /// Channels the client may send non-system frames on; starts as `capabilities` and is
    /// changed by subscribe / unsubscribe system ops.
    allowed_channels: DashSet<C>,
}
79
/// Per-client state for a server-sent-events stream.
#[derive(Debug)]
// `replay` is stored but not read by the server code.
#[allow(dead_code)]
struct SseHandle<C: ChannelKind> {
    /// Feeds frames into the client's SSE response stream.
    sender: mpsc::Sender<Frame<C>>,
    /// Only frames on these channels are forwarded to the stream.
    allowed_channels: DashSet<C>,
    /// Replay state; the same `Arc` stored in `PushServer::client_replay` for this client.
    replay: Arc<ClientReplay<C>>,
}
87
/// An item queued for delivery on a websocket connection.
#[derive(Debug)]
// Not every variant is constructed in this file (e.g. `Priority`).
#[allow(dead_code)]
enum Outbound<C: ChannelKind> {
    /// A data frame, sent as JSON text.
    Frame(Frame<C>),
    /// A system operation, sent as JSON text.
    System(SystemOp<C>),
    /// A ready-made websocket message sent unchanged (e.g. a pong).
    Raw(Message),
    /// Wraps another item to override the priority it would otherwise be routed with.
    Priority {
        priority: u8,
        inner: Box<Outbound<C>>,
    },
}
99
100impl<C: ChannelKind> Outbound<C> {
101 fn into_message(self) -> serde_json::Result<Message> {
102 match self {
103 Outbound::Frame(frame) => serde_json::to_string(&frame).map(Message::Text),
104 Outbound::System(op) => serde_json::to_string(&op).map(Message::Text),
105 Outbound::Raw(message) => Ok(message),
106 Outbound::Priority { inner, .. } => inner.into_message(),
107 }
108 }
109
110 fn priority(&self) -> u8 {
111 match self {
112 Outbound::Priority { priority, .. } => *priority,
113 Outbound::System(_) => PRIORITY_HIGH,
114 Outbound::Frame(frame) => frame.channel.priority(),
115 Outbound::Raw(_) => PRIORITY_NORMAL,
116 }
117 }
118}
119
/// Atomic cursor bookkeeping for one channel of one client.
///
/// Every update made through these methods uses `fetch_max`, so the methods never move a
/// counter backwards.
#[derive(Debug, Default)]
struct ChannelCursorState {
    /// Highest cursor stamped onto a frame for this client/channel.
    last_sent: AtomicU64,
    /// Highest cursor the client has acknowledged.
    last_acked: AtomicU64,
    /// Lower bound for resumable cursors; raised on ack (and by the replay buffer on eviction).
    buffer_floor: AtomicU64,
}

impl ChannelCursorState {
    /// Records that a frame with `cursor` has been stamped for delivery.
    ///
    /// Also lifts `buffer_floor` up to the current acknowledged cursor.
    fn mark_sent(&self, cursor: u64) {
        self.last_sent.fetch_max(cursor, Ordering::SeqCst);
        self.buffer_floor
            .fetch_max(self.last_acked(), Ordering::SeqCst);
    }

    /// Records a client acknowledgement of `cursor`; acks below the current value are ignored.
    ///
    /// `buffer_floor` is raised to `cursor` as well.
    fn mark_acked(&self, cursor: u64) {
        // `fetch_max` is the atomic "raise only if higher" that a manual CAS loop would
        // re-implement.
        self.last_acked.fetch_max(cursor, Ordering::SeqCst);
        self.buffer_floor.fetch_max(cursor, Ordering::SeqCst);
    }

    /// Highest cursor recorded by `mark_sent`.
    fn last_sent(&self) -> u64 {
        self.last_sent.load(Ordering::SeqCst)
    }

    /// Highest cursor recorded by `mark_acked`.
    fn last_acked(&self) -> u64 {
        self.last_acked.load(Ordering::SeqCst)
    }

    /// Current lower bound for resumable cursors.
    fn buffer_floor(&self) -> u64 {
        self.buffer_floor.load(Ordering::SeqCst)
    }
}
167
/// Replay buffer plus cursor state for one channel of one client.
#[derive(Debug)]
struct ChannelReplay<C: ChannelKind> {
    /// Cursor bookkeeping; handed out as a shared `Arc` by `state()`.
    state: Arc<ChannelCursorState>,
    /// Buffered frames in insertion order (oldest at the front), bounded by the `limit`
    /// passed to `push`.
    buffer: Mutex<VecDeque<Frame<C>>>,
}
173
174impl<C: ChannelKind> ChannelReplay<C> {
175 fn new() -> Self {
176 Self {
177 state: Arc::new(ChannelCursorState::default()),
178 buffer: Mutex::new(VecDeque::new()),
179 }
180 }
181
182 fn state(&self) -> Arc<ChannelCursorState> {
183 self.state.clone()
184 }
185
186 fn push(&self, frame: &Frame<C>, limit: usize) {
187 let mut buffer = self.buffer.lock().unwrap();
188 buffer.push_back(frame.clone());
189 while buffer.len() > limit {
190 if let Some(dropped) = buffer.pop_front()
191 && let Some(cursor) = dropped.cursor
192 {
193 self.state.buffer_floor.store(cursor, Ordering::SeqCst);
194 }
195 }
196 }
197
198 fn ack(&self, cursor: u64) {
199 self.state.mark_acked(cursor);
200 let mut buffer = self.buffer.lock().unwrap();
201 while buffer
202 .front()
203 .and_then(|f| f.cursor)
204 .map(|c| c <= cursor)
205 .unwrap_or(false)
206 {
207 buffer.pop_front();
208 }
209 let _ = self.state.buffer_floor.fetch_max(cursor, Ordering::SeqCst);
210 }
211
212 fn replay_from(&self, from: u64) -> ReplayOutcome<C> {
213 let floor = self.state.buffer_floor();
214 if from < floor {
215 return ReplayOutcome::Gap {
216 buffer_floor: floor,
217 };
218 }
219
220 let min_cursor = self.state.last_acked().max(from);
221 let buffer = self.buffer.lock().unwrap();
222 let frames: Vec<Frame<C>> = buffer
223 .iter()
224 .filter(|f| f.cursor.map(|c| c > min_cursor).unwrap_or(false))
225 .cloned()
226 .collect();
227 ReplayOutcome::Frames(frames)
228 }
229}
230
/// Result of asking a `ChannelReplay` for the frames after a resume cursor.
#[derive(Debug)]
enum ReplayOutcome<C: ChannelKind> {
    /// Frames to resend, in buffer order (possibly empty).
    Frames(Vec<Frame<C>>),
    /// The requested cursor is below `buffer_floor`, so the range can no longer be replayed.
    Gap { buffer_floor: u64 },
}
236
/// All per-channel replay buffers for one client.
#[derive(Debug)]
struct ClientReplay<C: ChannelKind> {
    /// Channel -> replay buffer, created on first use by `ClientReplay::channel`.
    channels: DashMap<C, Arc<ChannelReplay<C>>>,
}
241
242impl<C: ChannelKind> Default for ClientReplay<C> {
243 fn default() -> Self {
244 Self {
245 channels: DashMap::new(),
246 }
247 }
248}
249
250impl<C: ChannelKind> ClientReplay<C> {
251 fn channel(&self, channel: C) -> Arc<ChannelReplay<C>> {
252 self.channels
253 .entry(channel)
254 .or_insert_with(|| Arc::new(ChannelReplay::new()))
255 .clone()
256 }
257
258 fn resume_state(&self) -> HashMap<C, u64> {
259 self.channels
260 .iter()
261 .map(|entry| (*entry.key(), entry.value().state.last_acked()))
262 .collect()
263 }
264}
265
/// Errors returned by `PushServer::send` and `PushServer::send_binary`.
#[derive(Debug, thiserror::Error)]
pub enum SendError {
    /// The client has no registered websocket connection.
    #[error("client {0} not connected")]
    NotConnected(Uuid),
    /// The client's outbound queue is full (low-priority items are dropped instead).
    #[error("send buffer full for client {0}")]
    Backpressure(Uuid),
    /// The payload was refused before sending (e.g. disallowed MIME type or oversize binary).
    #[error("payload rejected: {0}")]
    Rejected(String),
    /// The payload could not be converted to JSON.
    #[error("payload serialization error: {0}")]
    Serialization(String),
}
282
/// Reasons an `AuthValidator` may reject a client.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// The supplied token is not valid.
    #[error("invalid token")]
    InvalidToken,
    /// The client may not use the requested capabilities.
    #[error("capabilities not permitted")]
    Forbidden,
    /// Any other rejection; the message becomes the websocket close reason.
    #[error("{0}")]
    Other(String),
}
292
/// Push server multiplexing typed channels over websocket and SSE transports.
pub struct PushServer<C: ChannelKind> {
    /// Live websocket connections keyed by client id.
    connections: DashMap<Uuid, ConnectionHandle<C>>,
    /// Registered SSE streams keyed by client id.
    sse_connections: DashMap<Uuid, SseHandle<C>>,
    /// Per-channel cursor counters, shared by all clients; they only increase.
    channel_cursors: DashMap<C, Arc<AtomicU64>>,
    /// Per-client replay state, created on demand; entries are never removed in this file.
    client_replay: DashMap<Uuid, Arc<ClientReplay<C>>>,
    /// Handlers for inbound non-system frames, keyed by channel.
    channel_handlers: DashMap<C, ChannelHandler<C>>,
    /// Validator consulted on websocket auth and on SSE connect.
    auth_validator: AuthValidator<C>,
}
310
311impl<C: ChannelKind> Default for PushServer<C> {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
317impl<C: ChannelKind> PushServer<C> {
318 pub fn new() -> Self {
319 Self::with_auth_validator(Arc::new(|_, _, _| Ok(())))
320 }
321
322 pub fn with_auth_validator(auth_validator: AuthValidator<C>) -> Self {
323 let counters = DashMap::new();
324 for channel in C::all() {
325 counters.insert(*channel, Arc::new(AtomicU64::new(0)));
326 }
327
328 Self {
329 connections: DashMap::new(),
330 sse_connections: DashMap::new(),
331 channel_cursors: counters,
332 client_replay: DashMap::new(),
333 channel_handlers: DashMap::new(),
334 auth_validator,
335 }
336 }
337
338 fn sha256_hex(bytes: &[u8]) -> String {
339 let mut hasher = Sha256::new();
340 hasher.update(bytes);
341 let digest = hasher.finalize();
342 digest.iter().map(|b| format!("{:02x}", b)).collect()
343 }
344
345 pub fn router(self: Arc<Self>) -> Router<Arc<Self>> {
347 Router::new()
348 .route("/rps", get(ws_upgrade::<C>))
349 .route("/rps/sse", get(sse_upgrade::<C>))
350 .route("/rps/ack", post(http_ack::<C>))
351 .with_state(self)
352 }
353
354 pub fn connected_clients(&self) -> usize {
356 self.connections.len()
357 }
358
359 pub fn connected_client_ids(&self) -> Vec<Uuid> {
361 self.connections.iter().map(|entry| *entry.key()).collect()
362 }
363
364 pub fn register_handler<F>(&self, channel: C, handler: F)
368 where
369 F: Fn(Uuid, Frame<C>, &PushServer<C>) + Send + Sync + 'static,
370 {
371 self.channel_handlers.insert(channel, Arc::new(handler));
372 }
373
374 fn stamp_frame(&self, replay: &ClientReplay<C>, frame: Frame<C>) -> Frame<C> {
375 let cursor = self.next_cursor(frame.channel, frame.cursor);
376 replay.channel(frame.channel).state().mark_sent(cursor);
377 frame.with_cursor(cursor)
378 }
379
380 fn next_cursor(&self, channel: C, existing: Option<u64>) -> u64 {
381 let counter = self
382 .channel_cursors
383 .get(&channel)
384 .map(|c| c.clone())
385 .unwrap_or_else(|| {
386 let fresh = Arc::new(AtomicU64::new(0));
387 self.channel_cursors.insert(channel, fresh.clone());
388 fresh
389 });
390
391 let cursor = existing.unwrap_or_else(|| counter.fetch_add(1, Ordering::SeqCst) + 1);
392 let _ = counter.fetch_max(cursor, Ordering::SeqCst);
393 cursor
394 }
395
396 pub async fn upgrade(self: Arc<Self>, ws: WebSocketUpgrade) -> impl IntoResponse {
398 ws.on_upgrade(move |socket| async move {
399 if let Err(err) = self.handle_socket(socket).await {
400 warn!(?err, "RPS websocket closed with error");
401 }
402 })
403 }
404
405 pub fn send(&self, client_id: Uuid, frame: Frame<C>) -> Result<(), SendError> {
407 let replay = self
408 .client_replay
409 .entry(client_id)
410 .or_insert_with(|| Arc::new(ClientReplay::default()))
411 .clone();
412 match self.connections.get(&client_id) {
413 Some(conn) => {
414 let stamped = self.stamp_frame(&replay, frame).with_client(client_id);
415 replay
416 .channel(stamped.channel)
417 .push(&stamped, REPLAY_BUFFER);
418 self.enqueue_outbound(conn.value(), Outbound::Frame(stamped.clone()), client_id)?;
419
420 if let Some(sse) = self.sse_connections.get(&client_id)
421 && sse.allowed_channels.contains(&stamped.channel)
422 && let Err(err) = sse.sender.try_send(stamped)
423 {
424 warn!(?client_id, ?err, "dropping SSE frame (buffer full?)");
425 self.sse_connections.remove(&client_id);
426 }
427
428 Ok(())
429 }
430 None => Err(SendError::NotConnected(client_id)),
431 }
432 }
433
434 pub fn send_binary(
436 &self,
437 client_id: Uuid,
438 channel: C,
439 bytes: &[u8],
440 mime: &str,
441 name: Option<&str>,
442 pointer_url: Option<&str>,
443 ) -> Result<(), SendError> {
444 if !ALLOWED_BINARY_MIME
445 .iter()
446 .any(|m| m.eq_ignore_ascii_case(mime))
447 {
448 return Err(SendError::Rejected(format!(
449 "mime type {mime} not permitted"
450 )));
451 }
452
453 let sha256 = Self::sha256_hex(bytes);
454 let size = bytes.len() as u64;
455 let envelope = if bytes.len() <= BINARY_INLINE_LIMIT {
456 BinaryEnvelope::Inline {
457 mime: mime.to_string(),
458 sha256,
459 size,
460 data_base64: general_purpose::STANDARD.encode(bytes),
461 name: name.map(|s| s.to_string()),
462 }
463 } else if let Some(url) = pointer_url {
464 BinaryEnvelope::Pointer {
465 mime: mime.to_string(),
466 sha256,
467 size,
468 url: url.to_string(),
469 name: name.map(|s| s.to_string()),
470 }
471 } else {
472 return Err(SendError::Rejected(format!(
473 "payload size {} exceeds inline limit {} and no pointer_url provided",
474 bytes.len(),
475 BINARY_INLINE_LIMIT
476 )));
477 };
478
479 let payload =
480 serde_json::to_value(envelope).map_err(|e| SendError::Serialization(e.to_string()))?;
481 self.send(client_id, Frame::new(channel, payload))
482 }
483
484 pub fn send_system(&self, client_id: Uuid, op: SystemOp<C>) {
486 self.enqueue_system(client_id, op);
487 }
488
489 async fn handle_socket(self: Arc<Self>, socket: WebSocket) -> Result<()> {
494 let (mut ws_tx, mut ws_rx) = socket.split();
495
496 let first = futures_util::StreamExt::next(&mut ws_rx)
497 .await
498 .ok_or_else(|| anyhow!("connection closed before auth"))?;
499 let first = first.context("failed to read first RPS frame")?;
500 let auth: SystemOp<C> = match first {
501 Message::Text(text) => {
502 serde_json::from_str(&text).context("failed to parse auth frame")?
503 }
504 Message::Binary(bytes) => {
505 serde_json::from_slice(&bytes).context("failed to parse binary auth frame")?
506 }
507 other => anyhow::bail!("expected auth frame as text, got {other:?}"),
508 };
509
510 let (client_id, capabilities, token, resume_cursor, resume_cursors) = match auth {
511 SystemOp::Auth {
512 client_id,
513 capabilities,
514 token,
515 resume_cursor,
516 resume_cursors,
517 ..
518 } => (
519 client_id,
520 capabilities,
521 token,
522 resume_cursor,
523 resume_cursors,
524 ),
525 other => anyhow::bail!("first RPS frame must be auth, got {other:?}"),
526 };
527
528 if let Err(err) = (self.auth_validator)(client_id, token.as_deref(), &capabilities) {
529 let reason = match &err {
530 AuthError::InvalidToken => "invalid token",
531 AuthError::Forbidden => "capabilities not permitted",
532 AuthError::Other(msg) => msg.as_str(),
533 }
534 .to_string();
535 let _ = ws_tx
536 .send(Message::Close(Some(CloseFrame {
537 code: 1008,
538 reason: reason.into(),
539 })))
540 .await;
541 return Err(anyhow!(err));
542 }
543
544 let (tx, mut rx) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
545 let (q_high, mut rx_high) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
546 let (q_norm, mut rx_norm) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
547 let (q_low, mut rx_low) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
548 let depth_high = Arc::new(AtomicUsize::new(0));
549 let depth_normal = Arc::new(AtomicUsize::new(0));
550 let depth_low = Arc::new(AtomicUsize::new(0));
551
552 let replay = self
553 .client_replay
554 .entry(client_id)
555 .or_insert_with(|| Arc::new(ClientReplay::default()))
556 .clone();
557
558 let allowed_init = {
559 let set = DashSet::new();
560 for ch in &capabilities {
561 set.insert(*ch);
562 }
563 set
564 };
565
566 if let Some(_old) = self.connections.insert(
567 client_id,
568 ConnectionHandle {
569 sender: tx.clone(),
570 queue_high: q_high.clone(),
571 queue_normal: q_norm.clone(),
572 queue_low: q_low.clone(),
573 depth_high: depth_high.clone(),
574 depth_normal: depth_normal.clone(),
575 depth_low: depth_low.clone(),
576 capabilities,
577 token,
578 created_at: Instant::now(),
579 replay: replay.clone(),
580 allowed_channels: allowed_init,
581 },
582 ) {
583 warn!(?client_id, "replacing existing RPS connection for client");
584 }
585
586 let resume_snapshot = replay.resume_state();
587 let resume_cursor_reply = resume_snapshot
588 .values()
589 .copied()
590 .max()
591 .unwrap_or(DEFAULT_RESUME_CURSOR);
592 ws_tx
593 .send(Message::Text(
594 serde_json::to_string(&SystemOp::<C>::AuthOk {
595 resume_cursor: resume_cursor_reply,
596 resume_cursors: resume_snapshot.clone(),
597 })
598 .context("serialize auth_ok")?,
599 ))
600 .await
601 .map_err(anyhow::Error::new)?;
602
603 let mut requested = resume_cursors;
605 if let Some(global) = resume_cursor {
606 for channel in C::all() {
607 requested.entry(*channel).or_insert(global);
608 }
609 }
610
611 for entry in replay.channels.iter() {
612 let channel = *entry.key();
613 let channel_replay = entry.value();
614 let from = requested
615 .get(&channel)
616 .copied()
617 .unwrap_or(DEFAULT_RESUME_CURSOR);
618 match channel_replay.replay_from(from) {
619 ReplayOutcome::Frames(frames) => {
620 for frame in frames {
621 if let Err(err) = tx.try_send(Outbound::Frame(frame)) {
622 warn!(?client_id, ?err, "failed to enqueue replay frame");
623 break;
624 }
625 }
626 }
627 ReplayOutcome::Gap { buffer_floor } => {
628 self.enqueue_system(
629 client_id,
630 SystemOp::ResumeRequired {
631 channel,
632 from_cursor: buffer_floor,
633 },
634 );
635 }
636 }
637 }
638
639 let writer = tokio::spawn(async move {
640 loop {
644 tokio::select! {
645 biased;
646 Some(item) = rx_high.recv() => {
647 depth_high.fetch_sub(1, Ordering::SeqCst);
648 let message = item.into_message().context("serialize prio-high RPS")?;
649 ws_tx.send(message).await.map_err(anyhow::Error::new)?;
650 }
651 Some(item) = rx_norm.recv() => {
652 depth_normal.fetch_sub(1, Ordering::SeqCst);
653 let message = item.into_message().context("serialize prio-norm RPS")?;
654 ws_tx.send(message).await.map_err(anyhow::Error::new)?;
655 }
656 Some(item) = rx_low.recv() => {
657 depth_low.fetch_sub(1, Ordering::SeqCst);
658 let message = item.into_message().context("serialize prio-low RPS")?;
659 ws_tx.send(message).await.map_err(anyhow::Error::new)?;
660 }
661 result = rx.recv() => {
662 match result {
663 Some(outbound) => {
664 let message = outbound
665 .into_message()
666 .context("serialize outbound RPS message")?;
667 ws_tx.send(message).await.map_err(anyhow::Error::new)?;
668 }
669 None => break,
670 }
671 }
672 }
673 }
674
675 Ok::<(), anyhow::Error>(())
676 });
677
678 let reader = {
679 let server = self.clone();
680 let tx = tx.clone();
681
682 tokio::spawn(async move {
683 while let Some(incoming) = futures_util::StreamExt::next(&mut ws_rx).await {
684 match incoming {
685 Ok(Message::Text(text)) => match serde_json::from_str::<Frame<C>>(&text) {
686 Ok(frame) => server.handle_incoming(client_id, frame).await,
687 Err(err) => {
688 warn!(?err, "invalid RPS frame from client");
689 server.enqueue_system(
690 client_id,
691 SystemOp::Error {
692 message: "invalid frame schema".into(),
693 },
694 );
695 }
696 },
697 Ok(Message::Binary(_)) => {
698 warn!("ignoring binary RPS frame");
699 }
700 Ok(Message::Ping(payload)) => {
701 let _ = tx.send(Outbound::Raw(Message::Pong(payload))).await;
702 }
703 Ok(Message::Pong(_)) => {}
704 Ok(Message::Close(_)) => break,
705 Err(err) => return Err(anyhow::Error::new(err)),
706 }
707 }
708
709 Ok::<(), anyhow::Error>(())
710 })
711 };
712
713 let result = tokio::try_join!(writer, reader);
714
715 self.connections.remove(&client_id);
716
717 result.map(|_| ()).map_err(anyhow::Error::new)
718 }
719
720 async fn handle_incoming(&self, client_id: Uuid, frame: Frame<C>) {
725 if let Err(msg) = validate_frame(&frame) {
726 self.enqueue_system(
727 client_id,
728 SystemOp::Error {
729 message: msg.to_string(),
730 },
731 );
732 return;
733 }
734
735 if frame.channel.is_system()
736 && let Some(conn) = self.connections.get(&client_id)
737 {
738 conn.replay
739 .channel(frame.channel)
740 .push(&frame, REPLAY_BUFFER);
741 }
742
743 let replay = self.client_replay.get(&client_id).map(|c| c.clone());
744
745 if frame.channel.is_system() {
746 match serde_json::from_value::<SystemOp<C>>(frame.payload.clone()) {
747 Ok(SystemOp::Ping) => self.enqueue_system(client_id, SystemOp::Pong),
748 Ok(SystemOp::Slow { window }) => {
749 debug!(?client_id, ?window, "client reported backpressure window");
750 }
751 Ok(SystemOp::Ack { channel, cursor }) => {
752 self.handle_ack(client_id, channel, cursor, replay.as_deref());
753 }
754 Ok(SystemOp::ResumeRequired { .. }) => {
755 debug!(?client_id, "client reported resume_required; ignoring");
756 }
757 Ok(SystemOp::Subscribe { channels }) => {
758 if let Some(conn) = self.connections.get(&client_id) {
759 for ch in channels {
760 conn.allowed_channels.insert(ch);
761 }
762 }
763 }
764 Ok(SystemOp::Unsubscribe { channels }) => {
765 if let Some(conn) = self.connections.get(&client_id) {
766 for ch in channels {
767 conn.allowed_channels.remove(&ch);
768 }
769 }
770 }
771 Ok(SystemOp::Health { status, detail }) => {
772 debug!(?client_id, ?status, ?detail, "client reported health");
773 }
774 Ok(SystemOp::Features {
775 supported,
776 requested,
777 }) => {
778 debug!(?client_id, ?supported, ?requested, "client features");
779 }
780 Ok(SystemOp::Goodbye { reason }) => {
781 debug!(?client_id, ?reason, "client goodbye");
782 self.enqueue_system(client_id, SystemOp::Goodbye { reason });
783 }
784 Ok(other) => {
785 debug!(?client_id, ?other, "received system message");
786 }
787 Err(err) => {
788 warn!(?err, "invalid system payload");
789 self.enqueue_system(
790 client_id,
791 SystemOp::Error {
792 message: "invalid system payload".into(),
793 },
794 );
795 }
796 }
797 } else {
798 let channel = frame.channel;
799
800 if let Some(conn) = self.connections.get(&client_id)
802 && !conn.allowed_channels.contains(&channel)
803 {
804 self.enqueue_system(
805 client_id,
806 SystemOp::Error {
807 message: format!("channel {} not subscribed", channel.name()),
808 },
809 );
810 return;
811 }
812
813 if let Some(handler) = self.channel_handlers.get(&channel) {
815 (handler.value())(client_id, frame.clone(), self);
816 } else {
817 self.enqueue_system(
818 client_id,
819 SystemOp::Error {
820 message: format!("no handler for channel {}", channel.name()),
821 },
822 );
823 }
824
825 debug!(
826 ?client_id,
827 channel = channel.name(),
828 cursor = ?frame.cursor,
829 "received RPS frame"
830 );
831 }
832 }
833
834 fn handle_ack(
835 &self,
836 client_id: Uuid,
837 channel: C,
838 cursor: u64,
839 replay: Option<&ClientReplay<C>>,
840 ) {
841 let Some(replay) = replay else {
842 warn!(?client_id, "ack from unknown client");
843 return;
844 };
845
846 let channel_replay = replay.channel(channel);
847 let state = channel_replay.state();
848 let last_sent = state.last_sent();
849 let buffer_floor = state.buffer_floor();
850
851 if cursor < buffer_floor {
852 self.enqueue_system(
853 client_id,
854 SystemOp::ResumeRequired {
855 channel,
856 from_cursor: buffer_floor,
857 },
858 );
859 return;
860 }
861
862 if cursor > last_sent {
863 self.enqueue_system(
864 client_id,
865 SystemOp::ResumeRequired {
866 channel,
867 from_cursor: last_sent,
868 },
869 );
870 return;
871 }
872
873 channel_replay.ack(cursor);
874 }
875
876 fn enqueue_system(&self, client_id: Uuid, op: SystemOp<C>) {
877 if let Some(conn) = self.connections.get(&client_id) {
878 let _ = self.enqueue_outbound(conn.value(), Outbound::System(op), client_id);
879 } else {
880 warn!(?client_id, "ignoring system send for unknown client");
881 }
882 }
883
884 fn enqueue_outbound(
885 &self,
886 conn: &ConnectionHandle<C>,
887 outbound: Outbound<C>,
888 client_id: Uuid,
889 ) -> Result<(), SendError> {
890 let prio = outbound.priority();
891 let (target, depth) = match prio {
892 PRIORITY_HIGH => (&conn.queue_high, &conn.depth_high),
893 PRIORITY_LOW => (&conn.queue_low, &conn.depth_low),
894 _ => (&conn.queue_normal, &conn.depth_normal),
895 };
896
897 let depth_now = depth.fetch_add(1, Ordering::SeqCst) + 1;
898 if depth_now > QUEUE_WARN_THRESHOLD {
899 debug!(
900 ?client_id,
901 ?prio,
902 depth = depth_now,
903 "send queue depth high"
904 );
905 }
906
907 if depth_now > OUTBOUND_BUFFER {
908 depth.fetch_sub(1, Ordering::SeqCst);
909 if prio == PRIORITY_LOW {
910 warn!(
911 ?client_id,
912 ?prio,
913 "dropping low-priority frame (queue full)"
914 );
915 return Ok(());
916 } else {
917 warn!(
918 ?client_id,
919 ?prio,
920 "send queue overflow; treating as backpressure"
921 );
922 return Err(SendError::Backpressure(client_id));
923 }
924 }
925
926 match target.try_send(outbound) {
927 Ok(_) => Ok(()),
928 Err(mpsc::error::TrySendError::Full(_)) => {
929 depth.fetch_sub(1, Ordering::SeqCst);
930 if prio == PRIORITY_LOW {
931 warn!(
932 ?client_id,
933 ?prio,
934 "dropping low-priority frame (queue full)"
935 );
936 Ok(())
937 } else {
938 Err(SendError::Backpressure(client_id))
939 }
940 }
941 Err(mpsc::error::TrySendError::Closed(_)) => {
942 depth.fetch_sub(1, Ordering::SeqCst);
943 Err(SendError::NotConnected(client_id))
944 }
945 }
946 }
947}
948
949async fn ws_upgrade<C: ChannelKind>(
954 State(server): State<Arc<PushServer<C>>>,
955 ws: WebSocketUpgrade,
956) -> impl IntoResponse {
957 server.upgrade(ws).await
958}
959
/// Query parameters accepted by `GET /rps/sse`.
#[derive(Debug, Deserialize)]
struct SseParams {
    /// Client identity; also the key of the client's replay state.
    client_id: Uuid,
    /// Optional token passed to the auth validator.
    #[serde(default)]
    token: Option<String>,
    /// Comma-separated channel names passed to the auth validator.
    #[serde(default)]
    capabilities: Option<String>,
    /// Comma-separated channel names to stream; takes precedence over `capabilities` when
    /// it names at least one known channel.
    #[serde(default)]
    channels: Option<String>,
    /// If set, buffered frames after this cursor (and after the last ack) are replayed.
    #[serde(default)]
    resume_cursor: Option<u64>,
}
972
/// Axum handler for `GET /rps/sse`: opens a server-sent-events stream for `client_id`.
///
/// Steps:
/// 1. Authenticate with the server's auth validator.
/// 2. Register an `SseHandle`, replacing any previous one for this client.
/// 3. Queue an `AuthOk` system frame.
/// 4. Replay buffered frames when `resume_cursor` is given.
/// 5. Stream queued frames as SSE `frame` events. Each event id is the frame cursor, or
///    `"0"` when the frame has none.
///
/// # Errors
/// `401 Unauthorized` when the validator rejects the request; `500 Internal Server Error`
/// when a system payload cannot be serialized.
///
/// # Panics
/// If `C::all()` contains no system channel.
async fn sse_upgrade<C: ChannelKind>(
    State(server): State<Arc<PushServer<C>>>,
    Query(params): Query<SseParams>,
) -> Result<impl IntoResponse, StatusCode> {
    let client_id = params.client_id;
    let capabilities = parse_channels::<C>(params.capabilities.as_deref());
    let subscribe = parse_channels::<C>(params.channels.as_deref());

    // NOTE(review): the validator only sees `capabilities`, while the streamed set below may
    // come from `channels` or default to every channel. Confirm this is intended.
    if let Err(_err) = (server.auth_validator)(client_id, params.token.as_deref(), &capabilities) {
        return Err(StatusCode::UNAUTHORIZED);
    }

    let replay = server
        .client_replay
        .entry(client_id)
        .or_insert_with(|| Arc::new(ClientReplay::default()))
        .clone();

    // Channel filter: explicit `channels` first, then `capabilities`, else every channel.
    let allowed = {
        let set = DashSet::new();
        if !subscribe.is_empty() {
            for ch in subscribe {
                set.insert(ch);
            }
        } else if !capabilities.is_empty() {
            for ch in capabilities.clone() {
                set.insert(ch);
            }
        } else {
            for ch in C::all() {
                set.insert(*ch);
            }
        }
        set
    };

    let (tx, rx) = mpsc::channel::<Frame<C>>(OUTBOUND_BUFFER);

    // `PushServer::send` mirrors matching frames into this handle's sender.
    // NOTE(review): nothing here removes the handle when the stream is dropped; `send` only
    // removes it after a failed `try_send`.
    server.sse_connections.insert(
        client_id,
        SseHandle {
            sender: tx.clone(),
            allowed_channels: allowed.clone(),
            replay: replay.clone(),
        },
    );

    let snapshot = replay.resume_state();
    // Highest acknowledged cursor across channels, or the default when there are none.
    let resume_cursor = snapshot
        .values()
        .copied()
        .max()
        .unwrap_or(DEFAULT_RESUME_CURSOR);

    let system_channel = C::all()
        .iter()
        .find(|c| c.is_system())
        .copied()
        .expect("ChannelKind must have a system channel");

    let auth_ok = Frame::new(
        system_channel,
        serde_json::to_value(SystemOp::<C>::AuthOk {
            resume_cursor,
            resume_cursors: snapshot.clone(),
        })
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
    )
    .with_client(client_id);
    // The queue holds OUTBOUND_BUFFER items and is only drained once the returned stream is
    // polled. `try_send` failures here and in the replay below are silently ignored
    // (NOTE(review): long replays may be truncated without notice).
    let _ = tx.try_send(auth_ok);

    if let Some(from) = params.resume_cursor {
        for entry in replay.channels.iter() {
            let channel = *entry.key();
            if !allowed.contains(&channel) {
                continue;
            }
            match entry.value().replay_from(from) {
                ReplayOutcome::Frames(frames) => {
                    for frame in frames {
                        let _ = tx.try_send(frame);
                    }
                }
                ReplayOutcome::Gap { buffer_floor } => {
                    let _ = tx.try_send(
                        Frame::new(
                            system_channel,
                            serde_json::to_value(SystemOp::<C>::ResumeRequired {
                                channel,
                                from_cursor: buffer_floor,
                            })
                            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
                        )
                        .with_client(client_id),
                    );
                }
            }
        }
    }

    // Each frame becomes a `frame` event with its cursor as the event id; serialization
    // failures are sent as an `error` event instead of ending the stream.
    let stream = futures_util::StreamExt::map(
        ReceiverStream::new(rx),
        |frame| -> Result<Event, Infallible> {
            let id = frame
                .cursor
                .map(|c| c.to_string())
                .unwrap_or_else(|| "0".into());
            let event = match serde_json::to_string(&frame) {
                Ok(json) => Event::default().event("frame").id(id).data(json),
                Err(err) => {
                    warn!(?err, "failed to serialize SSE frame");
                    Event::default().event("error").data("serialize_failed")
                }
            };
            Ok(event)
        },
    );

    Ok(Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
}
1095
/// JSON body of `POST /rps/ack`.
#[derive(Debug, Deserialize)]
#[serde(bound(deserialize = "C: ChannelKind"))]
struct AckBody<C: ChannelKind> {
    /// Client whose delivery is being acknowledged.
    client_id: Uuid,
    /// Channel the cursor belongs to.
    channel: C,
    /// Cursor being acknowledged on `channel`.
    cursor: u64,
}
1103
1104async fn http_ack<C: ChannelKind>(
1105 State(server): State<Arc<PushServer<C>>>,
1106 Json(body): Json<AckBody<C>>,
1107) -> impl IntoResponse {
1108 server.handle_ack(body.client_id, body.channel, body.cursor, None);
1109 axum::http::StatusCode::NO_CONTENT
1110}
1111
1112fn parse_channels<C: ChannelKind>(raw: Option<&str>) -> Vec<C> {
1117 raw.map(|list| {
1118 list.split(',')
1119 .filter_map(|s| C::from_name(s.trim()))
1120 .collect()
1121 })
1122 .unwrap_or_default()
1123}
1124
1125fn validate_frame<C: ChannelKind>(frame: &Frame<C>) -> Result<(), &'static str> {
1126 if frame.payload.is_null() {
1127 return Err("payload required");
1128 }
1129 Ok(())
1130}