use std::collections::{HashMap, VecDeque};
use std::convert::Infallible;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Instant;
use anyhow::{Context, Result, anyhow};
use axum::{
Router,
extract::{
Json, Query, State,
ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade},
},
http::StatusCode,
response::{IntoResponse, sse::Event, sse::Sse},
routing::{get, post},
};
use base64::{Engine as _, engine::general_purpose};
use dashmap::{DashMap, DashSet};
use futures_util::{SinkExt, StreamExt};
use pushwire_core::{BinaryEnvelope, ChannelKind, Frame, SystemOp};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, warn};
use uuid::Uuid;
/// Resume cursor assumed for a channel when no cursor is known for it.
const DEFAULT_RESUME_CURSOR: u64 = 0;
/// Capacity of every per-connection outbound queue (WebSocket queues and the SSE queue).
const OUTBOUND_BUFFER: usize = 64;
/// Maximum number of frames kept per client and channel for replay.
const REPLAY_BUFFER: usize = 256;
/// Largest binary payload, in bytes, embedded inline as base64 (256 KiB).
const BINARY_INLINE_LIMIT: usize = 256 << 10;
/// MIME types accepted by `send_binary` (compared ASCII case-insensitively).
const ALLOWED_BINARY_MIME: &[&str] = &["image/png", "image/jpeg", "image/webp", "image/gif"];
/// Queue depth above which a "send queue depth high" debug event is emitted.
const QUEUE_WARN_THRESHOLD: usize = OUTBOUND_BUFFER / 2;
// Outbound priorities; the writer task polls the high queue first.
const PRIORITY_HIGH: u8 = 0;
const PRIORITY_NORMAL: u8 = 1;
const PRIORITY_LOW: u8 = 2;
/// Callback for inbound frames on a channel, registered with
/// `PushServer::register_handler`. It receives the sending client's id, the frame and
/// the server (so it can reply, e.g. via `PushServer::send`).
pub type ChannelHandler<C> = Arc<dyn Fn(Uuid, Frame<C>, &PushServer<C>) + Send + Sync>;
/// Authorization hook called with the client id, the optional token and the requested
/// channels. Returning `Err` rejects the WebSocket auth frame (close code 1008) or the
/// SSE request (HTTP 401).
pub type AuthValidator<C> =
Arc<dyn Fn(Uuid, Option<&str>, &[C]) -> Result<(), AuthError> + Send + Sync>;
/// Server-side state for one authenticated WebSocket connection, stored in
/// `PushServer::connections` under the client's id.
#[derive(Debug)]
// Some fields (e.g. `capabilities`, `created_at`) are never read in this file.
#[allow(dead_code)]
struct ConnectionHandle<C: ChannelKind> {
/// Unprioritized outbound queue. The writer task polls it after the three priority
/// queues; it carries replayed frames and pong replies.
sender: mpsc::Sender<Outbound<C>>,
/// Queue for `PRIORITY_HIGH` items (see `Outbound::priority`).
queue_high: mpsc::Sender<Outbound<C>>,
/// Queue for `PRIORITY_NORMAL` items and any priority that is neither high nor low.
queue_normal: mpsc::Sender<Outbound<C>>,
/// Queue for `PRIORITY_LOW` items; these are dropped (not reported) when it is full.
queue_low: mpsc::Sender<Outbound<C>>,
/// Item counts of the three priority queues: incremented by `enqueue_outbound`,
/// decremented by the writer task when it dequeues.
depth_high: Arc<AtomicUsize>,
depth_normal: Arc<AtomicUsize>,
depth_low: Arc<AtomicUsize>,
/// Channels listed in the client's auth frame, as passed to the auth validator.
capabilities: Vec<C>,
/// Token from the client's auth frame, if any.
token: Option<String>,
/// When this connection was registered.
created_at: Instant,
/// This client's replay state, shared with `PushServer::client_replay`.
replay: Arc<ClientReplay<C>>,
/// Channels the client may send non-system frames on. Starts as `capabilities` and
/// is changed by `Subscribe` / `Unsubscribe` system ops.
allowed_channels: DashSet<C>,
}
/// Registration of a client's Server-Sent Events stream, stored in
/// `PushServer::sse_connections` under the client's id.
#[derive(Debug)]
// `replay` is not read anywhere in this file.
#[allow(dead_code)]
struct SseHandle<C: ChannelKind> {
/// Feeds the frames that are turned into the client's SSE events.
sender: mpsc::Sender<Frame<C>>,
/// Only frames on these channels are forwarded to the stream.
allowed_channels: DashSet<C>,
/// The client's replay state, shared with `PushServer::client_replay`.
replay: Arc<ClientReplay<C>>,
}
/// Item queued for a WebSocket connection's writer task.
#[derive(Debug)]
// `Priority` is never constructed in this file.
#[allow(dead_code)]
enum Outbound<C: ChannelKind> {
/// A channel frame, written as JSON text.
Frame(Frame<C>),
/// A system operation, written as JSON text.
System(SystemOp<C>),
/// A ready-made WebSocket message written unchanged (used for pong replies).
Raw(Message),
/// Wraps another item to override the priority queue it is routed to.
Priority {
priority: u8,
inner: Box<Outbound<C>>,
},
}
impl<C: ChannelKind> Outbound<C> {
    /// Converts the item into the WebSocket message to write.
    ///
    /// Frames and system ops become JSON text, raw messages pass through unchanged and
    /// priority wrappers are unwrapped recursively.
    fn into_message(self) -> serde_json::Result<Message> {
        let json = match self {
            Outbound::Raw(message) => return Ok(message),
            Outbound::Priority { inner, .. } => return inner.into_message(),
            Outbound::Frame(frame) => serde_json::to_string(&frame),
            Outbound::System(op) => serde_json::to_string(&op),
        };
        json.map(Message::Text)
    }

    /// Queue priority (lower is more urgent).
    ///
    /// - Wrapped items use their explicit priority.
    /// - System ops are high.
    /// - Frames use their channel's priority.
    /// - Raw messages are normal.
    fn priority(&self) -> u8 {
        match self {
            Outbound::System(_) => PRIORITY_HIGH,
            Outbound::Raw(_) => PRIORITY_NORMAL,
            Outbound::Frame(frame) => frame.channel.priority(),
            Outbound::Priority { priority, .. } => *priority,
        }
    }
}
/// Delivery cursors for one client on one channel.
///
/// The methods here only ever raise the stored values (via `fetch_max`).
#[derive(Debug, Default)]
struct ChannelCursorState {
    last_sent: AtomicU64,
    last_acked: AtomicU64,
    buffer_floor: AtomicU64,
}
impl ChannelCursorState {
    /// Raises `last_sent` to `cursor` and keeps `buffer_floor` at or above `last_acked`.
    fn mark_sent(&self, cursor: u64) {
        self.last_sent.fetch_max(cursor, Ordering::SeqCst);
        let acked = self.last_acked();
        self.buffer_floor.fetch_max(acked, Ordering::SeqCst);
    }

    /// Raises both `last_acked` and `buffer_floor` to `cursor`.
    ///
    /// A cursor lower than the current value leaves it unchanged.
    fn mark_acked(&self, cursor: u64) {
        self.last_acked.fetch_max(cursor, Ordering::SeqCst);
        self.buffer_floor.fetch_max(cursor, Ordering::SeqCst);
    }

    /// Highest cursor recorded as sent.
    fn last_sent(&self) -> u64 {
        self.last_sent.load(Ordering::SeqCst)
    }

    /// Highest cursor acknowledged by the client.
    fn last_acked(&self) -> u64 {
        self.last_acked.load(Ordering::SeqCst)
    }

    /// Lowest cursor a resume may start from without being reported as a gap.
    fn buffer_floor(&self) -> u64 {
        self.buffer_floor.load(Ordering::SeqCst)
    }
}
/// Replay buffer and cursor bookkeeping for one client on one channel.
#[derive(Debug)]
struct ChannelReplay<C: ChannelKind> {
    state: Arc<ChannelCursorState>,
    buffer: Mutex<VecDeque<Frame<C>>>,
}
impl<C: ChannelKind> ChannelReplay<C> {
    /// Creates an empty buffer with every cursor at 0.
    fn new() -> Self {
        Self {
            state: Arc::new(ChannelCursorState::default()),
            buffer: Mutex::new(VecDeque::new()),
        }
    }

    /// Shared handle to this channel's cursor state.
    fn state(&self) -> Arc<ChannelCursorState> {
        self.state.clone()
    }

    /// Appends a copy of `frame`, evicting the oldest frames while more than `limit`
    /// are buffered.
    ///
    /// The buffer floor is raised to the cursor of every evicted frame, so a later
    /// resume from before it is reported as a gap.
    fn push(&self, frame: &Frame<C>, limit: usize) {
        let mut buffer = self.buffer.lock().unwrap();
        buffer.push_back(frame.clone());
        while buffer.len() > limit {
            if let Some(cursor) = buffer.pop_front().and_then(|dropped| dropped.cursor) {
                // `fetch_max`, not `store`: concurrent senders can push cursors out of
                // order, and the floor must never move backwards (e.g. below an ack).
                let _ = self.state.buffer_floor.fetch_max(cursor, Ordering::SeqCst);
            }
        }
    }

    /// Records an acknowledgement of `cursor` and releases every buffered frame whose
    /// cursor is at or below it. Frames without a cursor are kept.
    fn ack(&self, cursor: u64) {
        self.state.mark_acked(cursor);
        let mut buffer = self.buffer.lock().unwrap();
        // `retain` instead of popping from the front: an out-of-order (higher) cursor at
        // the front would otherwise pin acknowledged frames queued behind it.
        buffer.retain(|f| f.cursor.is_none_or(|c| c > cursor));
        let _ = self.state.buffer_floor.fetch_max(cursor, Ordering::SeqCst);
    }

    /// Returns the frames a client that has seen up to `from` still needs.
    ///
    /// These are the buffered frames whose cursor exceeds both `from` and the last
    /// acknowledged cursor. Returns `Gap` when `from` is below the buffer floor.
    fn replay_from(&self, from: u64) -> ReplayOutcome<C> {
        let floor = self.state.buffer_floor();
        if from < floor {
            return ReplayOutcome::Gap {
                buffer_floor: floor,
            };
        }
        let min_cursor = self.state.last_acked().max(from);
        let buffer = self.buffer.lock().unwrap();
        let frames: Vec<Frame<C>> = buffer
            .iter()
            .filter(|f| f.cursor.map(|c| c > min_cursor).unwrap_or(false))
            .cloned()
            .collect();
        ReplayOutcome::Frames(frames)
    }
}
/// Result of `ChannelReplay::replay_from`.
#[derive(Debug)]
enum ReplayOutcome<C: ChannelKind> {
/// Buffered frames whose cursor is greater than both the requested cursor and the last
/// acknowledged cursor, in buffer order.
Frames(Vec<Frame<C>>),
/// The requested cursor is below `buffer_floor`; frames after it may already have been
/// evicted or released by an ack.
Gap { buffer_floor: u64 },
}
/// Replay state for one client: one [`ChannelReplay`] per channel, created on demand.
#[derive(Debug)]
struct ClientReplay<C: ChannelKind> {
    channels: DashMap<C, Arc<ChannelReplay<C>>>,
}
impl<C: ChannelKind> Default for ClientReplay<C> {
    fn default() -> Self {
        let channels = DashMap::new();
        ClientReplay { channels }
    }
}
impl<C: ChannelKind> ClientReplay<C> {
    /// Returns the replay state for `channel`, creating an empty one on first use.
    fn channel(&self, channel: C) -> Arc<ChannelReplay<C>> {
        let slot = self
            .channels
            .entry(channel)
            .or_insert_with(|| Arc::new(ChannelReplay::new()));
        Arc::clone(slot.value())
    }

    /// Snapshot of the last acknowledged cursor of every channel seen so far.
    fn resume_state(&self) -> HashMap<C, u64> {
        let mut acked = HashMap::new();
        for entry in self.channels.iter() {
            acked.insert(*entry.key(), entry.value().state.last_acked());
        }
        acked
    }
}
/// Errors returned by `PushServer::send` and the methods built on it.
#[derive(Debug, thiserror::Error)]
pub enum SendError {
/// The client has no live connection, or its connection has gone away.
#[error("client {0} not connected")]
NotConnected(Uuid),
/// The client's outbound queue is full.
#[error("send buffer full for client {0}")]
Backpressure(Uuid),
/// The payload failed validation (e.g. disallowed MIME type or oversized binary).
#[error("payload rejected: {0}")]
Rejected(String),
/// The payload could not be serialized to JSON.
#[error("payload serialization error: {0}")]
Serialization(String),
}
/// Reasons an [`AuthValidator`] can give for rejecting a client.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
/// The supplied token was rejected.
#[error("invalid token")]
InvalidToken,
/// The client asked for channels it is not permitted to use.
#[error("capabilities not permitted")]
Forbidden,
/// Any other rejection; on WebSocket the message is sent as the close reason.
#[error("{0}")]
Other(String),
}
/// Push server delivering channel frames to clients over WebSocket (`/rps`) and
/// Server-Sent Events (`/rps/sse`), with per-client replay buffers for resumption.
pub struct PushServer<C: ChannelKind> {
/// Live WebSocket connections by client id.
connections: DashMap<Uuid, ConnectionHandle<C>>,
/// Live SSE streams by client id.
sse_connections: DashMap<Uuid, SseHandle<C>>,
/// Per-channel cursor counters, shared by all clients.
channel_cursors: DashMap<C, Arc<AtomicU64>>,
/// Replay state by client id; not removed when a connection closes.
client_replay: DashMap<Uuid, Arc<ClientReplay<C>>>,
/// Handlers for inbound frames, by channel.
channel_handlers: DashMap<C, ChannelHandler<C>>,
/// Authorization hook applied to WebSocket auth frames and SSE requests.
auth_validator: AuthValidator<C>,
}
impl<C: ChannelKind> Default for PushServer<C> {
fn default() -> Self {
Self::new()
}
}
impl<C: ChannelKind> PushServer<C> {
    /// Creates a server whose auth validator accepts every client, token and channel.
    pub fn new() -> Self {
        Self::with_auth_validator(Arc::new(|_, _, _| Ok(())))
    }

    /// Creates a server that runs `auth_validator` on every WebSocket auth frame, every
    /// SSE connection request and every WebSocket `Subscribe` request.
    ///
    /// A cursor counter starting at 0 is pre-created for each channel in `C::all()`.
    pub fn with_auth_validator(auth_validator: AuthValidator<C>) -> Self {
        let counters = DashMap::new();
        for channel in C::all() {
            counters.insert(*channel, Arc::new(AtomicU64::new(0)));
        }
        Self {
            connections: DashMap::new(),
            sse_connections: DashMap::new(),
            channel_cursors: counters,
            client_replay: DashMap::new(),
            channel_handlers: DashMap::new(),
            auth_validator,
        }
    }

    /// Returns the lowercase hexadecimal SHA-256 digest of `bytes`.
    fn sha256_hex(bytes: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(bytes);
        let digest = hasher.finalize();
        digest.iter().map(|b| format!("{:02x}", b)).collect()
    }

    /// Builds the router exposing the WebSocket (`/rps`), SSE (`/rps/sse`) and HTTP
    /// ack (`/rps/ack`) endpoints, with `self` as shared state.
    pub fn router(self: Arc<Self>) -> Router<Arc<Self>> {
        Router::new()
            .route("/rps", get(ws_upgrade::<C>))
            .route("/rps/sse", get(sse_upgrade::<C>))
            .route("/rps/ack", post(http_ack::<C>))
            .with_state(self)
    }

    /// Number of clients with a registered WebSocket connection (SSE-only clients are
    /// not counted).
    pub fn connected_clients(&self) -> usize {
        self.connections.len()
    }

    /// Ids of clients with a registered WebSocket connection.
    pub fn connected_client_ids(&self) -> Vec<Uuid> {
        self.connections.iter().map(|entry| *entry.key()).collect()
    }

    /// Registers the handler for inbound frames on `channel`, replacing any previous
    /// handler for that channel.
    pub fn register_handler<F>(&self, channel: C, handler: F)
    where
        F: Fn(Uuid, Frame<C>, &PushServer<C>) + Send + Sync + 'static,
    {
        self.channel_handlers.insert(channel, Arc::new(handler));
    }

    /// Assigns `frame` its cursor (see `next_cursor`) and records that cursor as the
    /// client's last-sent cursor on the frame's channel.
    fn stamp_frame(&self, replay: &ClientReplay<C>, frame: Frame<C>) -> Frame<C> {
        let cursor = self.next_cursor(frame.channel, frame.cursor);
        replay.channel(frame.channel).state().mark_sent(cursor);
        frame.with_cursor(cursor)
    }

    /// Returns the cursor for a frame on `channel`.
    ///
    /// This is `existing` when the frame already carries a cursor, otherwise the next
    /// value of the channel's counter. The counter is raised to at least the returned
    /// cursor.
    fn next_cursor(&self, channel: C, existing: Option<u64>) -> u64 {
        let counter = self
            .channel_cursors
            .get(&channel)
            .map(|entry| Arc::clone(entry.value()))
            .unwrap_or_else(|| {
                // `entry` locks the shard across lookup and insert, so two concurrent
                // first sends on an unseen channel cannot install competing counters
                // (which previously could hand out duplicate cursors).
                Arc::clone(
                    self.channel_cursors
                        .entry(channel)
                        .or_insert_with(|| Arc::new(AtomicU64::new(0)))
                        .value(),
                )
            });
        let cursor = existing.unwrap_or_else(|| counter.fetch_add(1, Ordering::SeqCst) + 1);
        let _ = counter.fetch_max(cursor, Ordering::SeqCst);
        cursor
    }

    /// Accepts a WebSocket upgrade and drives the connection until it closes, logging
    /// any error.
    pub async fn upgrade(self: Arc<Self>, ws: WebSocketUpgrade) -> impl IntoResponse {
        ws.on_upgrade(move |socket| async move {
            if let Err(err) = self.handle_socket(socket).await {
                warn!(?err, "RPS websocket closed with error");
            }
        })
    }

    /// Sends `frame` to `client_id` on every live transport.
    ///
    /// The frame is stamped with the next cursor for its channel, tagged with the client
    /// id and appended to the client's replay buffer. It is then queued on the client's
    /// WebSocket (if connected) and forwarded to its SSE stream (if connected and
    /// subscribed to the frame's channel).
    ///
    /// # Errors
    /// - [`SendError::NotConnected`] if the client has neither a WebSocket nor an SSE
    ///   stream.
    /// - With a WebSocket: the result of queuing on it. That is
    ///   [`SendError::Backpressure`] when its queue is full, or
    ///   [`SendError::NotConnected`] when its writer is gone.
    /// - SSE only: [`SendError::Backpressure`] when the SSE buffer is full, or
    ///   [`SendError::NotConnected`] when the stream has closed.
    pub fn send(&self, client_id: Uuid, frame: Frame<C>) -> Result<(), SendError> {
        let has_ws = self.connections.contains_key(&client_id);
        let has_sse = self.sse_connections.contains_key(&client_id);
        if !has_ws && !has_sse {
            return Err(SendError::NotConnected(client_id));
        }
        let replay = self
            .client_replay
            .entry(client_id)
            .or_insert_with(|| Arc::new(ClientReplay::default()))
            .clone();
        let stamped = self.stamp_frame(&replay, frame).with_client(client_id);
        replay
            .channel(stamped.channel)
            .push(&stamped, REPLAY_BUFFER);
        let ws_result = self.connections.get(&client_id).map(|conn| {
            self.enqueue_outbound(conn.value(), Outbound::Frame(stamped.clone()), client_id)
        });
        // SSE is attempted even when the WebSocket is present; its own failures are
        // logged inside `deliver_sse`.
        let sse_result = self.deliver_sse(client_id, &stamped);
        match ws_result {
            Some(result) => result,
            None => sse_result,
        }
    }

    /// Forwards `frame` to the client's SSE stream if one is registered and subscribed
    /// to the frame's channel.
    ///
    /// Returns `Ok(())` when the frame was queued or the channel is not subscribed.
    /// Returns `NotConnected` when there is no stream or it has closed (the closed
    /// handle is removed). Returns `Backpressure` when the stream's buffer is full; the
    /// frame stays in the replay buffer.
    fn deliver_sse(&self, client_id: Uuid, frame: &Frame<C>) -> Result<(), SendError> {
        // Finish with the map guard before any removal below: DashMap deadlocks when an
        // entry is removed while a reference into the same shard is still alive.
        let attempt = match self.sse_connections.get(&client_id) {
            None => return Err(SendError::NotConnected(client_id)),
            Some(sse) if !sse.allowed_channels.contains(&frame.channel) => return Ok(()),
            Some(sse) => sse.sender.try_send(frame.clone()),
        };
        match attempt {
            Ok(()) => Ok(()),
            Err(mpsc::error::TrySendError::Full(_)) => {
                // A full buffer means a slow reader, not a dead one; keep the handle.
                warn!(?client_id, "dropping SSE frame (buffer full)");
                Err(SendError::Backpressure(client_id))
            }
            Err(mpsc::error::TrySendError::Closed(_)) => {
                debug!(?client_id, "SSE stream closed; removing handle");
                // Only remove a handle that is itself closed, not a fresh reconnect.
                self.sse_connections
                    .remove_if(&client_id, |_, handle| handle.sender.is_closed());
                Err(SendError::NotConnected(client_id))
            }
        }
    }

    /// Sends binary content to a client as a [`BinaryEnvelope`] JSON payload on
    /// `channel`.
    ///
    /// Payloads up to `BINARY_INLINE_LIMIT` bytes are embedded as base64. Larger ones
    /// are sent as a pointer to `pointer_url`. Both forms carry the SHA-256 and size of
    /// `bytes`.
    ///
    /// # Errors
    /// - [`SendError::Rejected`] if `mime` is not in `ALLOWED_BINARY_MIME`, or the
    ///   payload exceeds the inline limit and no `pointer_url` is given.
    /// - [`SendError::Serialization`] if the envelope cannot be serialized.
    /// - Otherwise, as [`PushServer::send`].
    pub fn send_binary(
        &self,
        client_id: Uuid,
        channel: C,
        bytes: &[u8],
        mime: &str,
        name: Option<&str>,
        pointer_url: Option<&str>,
    ) -> Result<(), SendError> {
        if !ALLOWED_BINARY_MIME
            .iter()
            .any(|m| m.eq_ignore_ascii_case(mime))
        {
            return Err(SendError::Rejected(format!(
                "mime type {mime} not permitted"
            )));
        }
        let sha256 = Self::sha256_hex(bytes);
        let size = bytes.len() as u64;
        let envelope = if bytes.len() <= BINARY_INLINE_LIMIT {
            BinaryEnvelope::Inline {
                mime: mime.to_string(),
                sha256,
                size,
                data_base64: general_purpose::STANDARD.encode(bytes),
                name: name.map(|s| s.to_string()),
            }
        } else if let Some(url) = pointer_url {
            BinaryEnvelope::Pointer {
                mime: mime.to_string(),
                sha256,
                size,
                url: url.to_string(),
                name: name.map(|s| s.to_string()),
            }
        } else {
            return Err(SendError::Rejected(format!(
                "payload size {} exceeds inline limit {} and no pointer_url provided",
                bytes.len(),
                BINARY_INLINE_LIMIT
            )));
        };
        let payload =
            serde_json::to_value(envelope).map_err(|e| SendError::Serialization(e.to_string()))?;
        self.send(client_id, Frame::new(channel, payload))
    }

    /// Queues a system op on the client's WebSocket (best effort; see `enqueue_system`).
    pub fn send_system(&self, client_id: Uuid, op: SystemOp<C>) {
        self.enqueue_system(client_id, op);
    }

    /// Drives one WebSocket connection from authentication to close.
    ///
    /// The first message must be a JSON `SystemOp::Auth` (text or binary). If the
    /// validator rejects it, a close frame with code 1008 is sent and an error is
    /// returned.
    ///
    /// Otherwise the connection is registered, `AuthOk` is sent and the replay backlog
    /// is delivered. A writer task and a reader task then run until either one
    /// finishes; the other is aborted and this connection's handle is unregistered.
    async fn handle_socket(self: Arc<Self>, socket: WebSocket) -> Result<()> {
        let (mut ws_tx, mut ws_rx) = socket.split();
        let first = futures_util::StreamExt::next(&mut ws_rx)
            .await
            .ok_or_else(|| anyhow!("connection closed before auth"))?;
        let first = first.context("failed to read first RPS frame")?;
        let auth: SystemOp<C> = match first {
            Message::Text(text) => {
                serde_json::from_str(&text).context("failed to parse auth frame")?
            }
            Message::Binary(bytes) => {
                serde_json::from_slice(&bytes).context("failed to parse binary auth frame")?
            }
            other => anyhow::bail!("expected auth frame as text, got {other:?}"),
        };
        let (client_id, capabilities, token, resume_cursor, resume_cursors) = match auth {
            SystemOp::Auth {
                client_id,
                capabilities,
                token,
                resume_cursor,
                resume_cursors,
                ..
            } => (
                client_id,
                capabilities,
                token,
                resume_cursor,
                resume_cursors,
            ),
            other => anyhow::bail!("first RPS frame must be auth, got {other:?}"),
        };
        if let Err(err) = (self.auth_validator)(client_id, token.as_deref(), &capabilities) {
            let reason = match &err {
                AuthError::InvalidToken => "invalid token",
                AuthError::Forbidden => "capabilities not permitted",
                AuthError::Other(msg) => msg.as_str(),
            }
            .to_string();
            // 1008 is the WebSocket "policy violation" close code.
            let _ = ws_tx
                .send(Message::Close(Some(CloseFrame {
                    code: 1008,
                    reason: reason.into(),
                })))
                .await;
            return Err(anyhow!(err));
        }

        // `tx` is the unprioritized queue; `enqueue_outbound` picks one of the three
        // priority queues from `Outbound::priority`.
        let (tx, mut rx) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
        let (q_high, mut rx_high) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
        let (q_norm, mut rx_norm) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
        let (q_low, mut rx_low) = mpsc::channel::<Outbound<C>>(OUTBOUND_BUFFER);
        let depth_high = Arc::new(AtomicUsize::new(0));
        let depth_normal = Arc::new(AtomicUsize::new(0));
        let depth_low = Arc::new(AtomicUsize::new(0));
        let replay = self
            .client_replay
            .entry(client_id)
            .or_insert_with(|| Arc::new(ClientReplay::default()))
            .clone();
        let allowed_init = DashSet::new();
        for ch in &capabilities {
            allowed_init.insert(*ch);
        }
        let handle = ConnectionHandle {
            sender: tx.clone(),
            queue_high: q_high,
            queue_normal: q_norm,
            queue_low: q_low,
            depth_high: depth_high.clone(),
            depth_normal: depth_normal.clone(),
            depth_low: depth_low.clone(),
            capabilities,
            token,
            created_at: Instant::now(),
            replay: replay.clone(),
            allowed_channels: allowed_init,
        };
        if self.connections.insert(client_id, handle).is_some() {
            warn!(?client_id, "replacing existing RPS connection for client");
        }

        let resume_snapshot = replay.resume_state();
        let resume_cursor_reply = resume_snapshot
            .values()
            .copied()
            .max()
            .unwrap_or(DEFAULT_RESUME_CURSOR);
        ws_tx
            .send(Message::Text(
                serde_json::to_string(&SystemOp::<C>::AuthOk {
                    resume_cursor: resume_cursor_reply,
                    resume_cursors: resume_snapshot.clone(),
                })
                .context("serialize auth_ok")?,
            ))
            .await
            .map_err(anyhow::Error::new)?;

        // Per-channel cursors from the client; a global `resume_cursor` fills in any
        // channel the client did not list.
        let mut requested = resume_cursors;
        if let Some(global) = resume_cursor {
            for channel in C::all() {
                requested.entry(*channel).or_insert(global);
            }
        }
        let backlog = self.collect_backlog(client_id, &replay, |channel| {
            requested
                .get(&channel)
                .copied()
                .unwrap_or(DEFAULT_RESUME_CURSOR)
        });

        let mut writer = tokio::spawn(async move {
            loop {
                tokio::select! {
                    biased;
                    Some(item) = rx_high.recv() => {
                        depth_high.fetch_sub(1, Ordering::SeqCst);
                        let message = item.into_message().context("serialize prio-high RPS")?;
                        ws_tx.send(message).await.map_err(anyhow::Error::new)?;
                    }
                    Some(item) = rx_norm.recv() => {
                        depth_normal.fetch_sub(1, Ordering::SeqCst);
                        let message = item.into_message().context("serialize prio-norm RPS")?;
                        ws_tx.send(message).await.map_err(anyhow::Error::new)?;
                    }
                    Some(item) = rx_low.recv() => {
                        depth_low.fetch_sub(1, Ordering::SeqCst);
                        let message = item.into_message().context("serialize prio-low RPS")?;
                        ws_tx.send(message).await.map_err(anyhow::Error::new)?;
                    }
                    result = rx.recv() => {
                        match result {
                            Some(outbound) => {
                                let message = outbound
                                    .into_message()
                                    .context("serialize outbound RPS message")?;
                                ws_tx.send(message).await.map_err(anyhow::Error::new)?;
                            }
                            None => break,
                        }
                    }
                }
            }
            Ok::<(), anyhow::Error>(())
        });

        // The writer is draining now, so awaiting applies real backpressure. A
        // `try_send` before the writer existed silently dropped everything past
        // OUTBOUND_BUFFER frames.
        for frame in backlog {
            if tx.send(Outbound::Frame(frame)).await.is_err() {
                warn!(?client_id, "writer stopped while replaying backlog");
                break;
            }
        }

        let mut reader = {
            let server = self.clone();
            let tx = tx.clone();
            tokio::spawn(async move {
                while let Some(incoming) = futures_util::StreamExt::next(&mut ws_rx).await {
                    match incoming {
                        Ok(Message::Text(text)) => match serde_json::from_str::<Frame<C>>(&text) {
                            Ok(frame) => server.handle_incoming(client_id, frame).await,
                            Err(err) => {
                                warn!(?err, "invalid RPS frame from client");
                                server.enqueue_system(
                                    client_id,
                                    SystemOp::Error {
                                        message: "invalid frame schema".into(),
                                    },
                                );
                            }
                        },
                        Ok(Message::Binary(_)) => {
                            warn!("ignoring binary RPS frame");
                        }
                        Ok(Message::Ping(payload)) => {
                            let _ = tx.send(Outbound::Raw(Message::Pong(payload))).await;
                        }
                        Ok(Message::Pong(_)) => {}
                        Ok(Message::Close(_)) => break,
                        Err(err) => return Err(anyhow::Error::new(err)),
                    }
                }
                Ok::<(), anyhow::Error>(())
            })
        };

        // Whichever task finishes first ends the session. On its own the writer never
        // stops, because the registered handle keeps its queues open. Waiting for both
        // tasks therefore never returned, and the connection was never unregistered.
        let finished = tokio::select! {
            res = &mut reader => {
                writer.abort();
                res
            }
            res = &mut writer => {
                reader.abort();
                res
            }
        };
        // Unregister only our own handle: a reconnect with the same client id may
        // already have replaced it.
        self.connections
            .remove_if(&client_id, |_, handle| handle.sender.same_channel(&tx));
        finished.map_err(anyhow::Error::new)?
    }

    /// Gathers the frames to replay to a freshly connected WebSocket client.
    ///
    /// For every channel with replay state, `resume_from(channel)` is the cursor the
    /// client reports having seen. Frames after it are returned, grouped per channel.
    /// Channels whose buffer no longer reaches back that far get a `ResumeRequired`
    /// notice queued instead.
    fn collect_backlog(
        &self,
        client_id: Uuid,
        replay: &ClientReplay<C>,
        resume_from: impl Fn(C) -> u64,
    ) -> Vec<Frame<C>> {
        let mut backlog = Vec::new();
        for entry in replay.channels.iter() {
            let channel = *entry.key();
            match entry.value().replay_from(resume_from(channel)) {
                ReplayOutcome::Frames(frames) => backlog.extend(frames),
                ReplayOutcome::Gap { buffer_floor } => {
                    self.enqueue_system(
                        client_id,
                        SystemOp::ResumeRequired {
                            channel,
                            from_cursor: buffer_floor,
                        },
                    );
                }
            }
        }
        backlog
    }

    /// Processes one frame received from a WebSocket client.
    ///
    /// Frames with a null payload are answered with a system `Error`. System-channel
    /// frames are interpreted as [`SystemOp`]s; all other frames go to the channel's
    /// registered handler.
    async fn handle_incoming(&self, client_id: Uuid, frame: Frame<C>) {
        if let Err(msg) = validate_frame(&frame) {
            self.enqueue_system(
                client_id,
                SystemOp::Error {
                    message: msg.to_string(),
                },
            );
            return;
        }
        // Inbound frames are deliberately NOT recorded in the client's replay buffer:
        // that buffer holds server->client frames and is replayed back to the client on
        // reconnect.
        if frame.channel.is_system() {
            self.handle_system_frame(client_id, &frame);
        } else {
            self.dispatch_to_handler(client_id, frame);
        }
    }

    /// Interprets a system-channel frame from a client as a [`SystemOp`] and acts on it.
    fn handle_system_frame(&self, client_id: Uuid, frame: &Frame<C>) {
        match serde_json::from_value::<SystemOp<C>>(frame.payload.clone()) {
            Ok(SystemOp::Ping) => self.enqueue_system(client_id, SystemOp::Pong),
            Ok(SystemOp::Slow { window }) => {
                debug!(?client_id, ?window, "client reported backpressure window");
            }
            Ok(SystemOp::Ack { channel, cursor }) => {
                let replay = self
                    .client_replay
                    .get(&client_id)
                    .map(|entry| Arc::clone(entry.value()));
                self.handle_ack(client_id, channel, cursor, replay.as_deref());
            }
            Ok(SystemOp::ResumeRequired { .. }) => {
                debug!(?client_id, "client reported resume_required; ignoring");
            }
            Ok(SystemOp::Subscribe { channels }) => {
                self.apply_subscribe(client_id, channels.into_iter().collect());
            }
            Ok(SystemOp::Unsubscribe { channels }) => {
                if let Some(conn) = self.connections.get(&client_id) {
                    for ch in channels {
                        conn.allowed_channels.remove(&ch);
                    }
                }
            }
            Ok(SystemOp::Health { status, detail }) => {
                debug!(?client_id, ?status, ?detail, "client reported health");
            }
            Ok(SystemOp::Features {
                supported,
                requested,
            }) => {
                debug!(?client_id, ?supported, ?requested, "client features");
            }
            Ok(SystemOp::Goodbye { reason }) => {
                debug!(?client_id, ?reason, "client goodbye");
                self.enqueue_system(client_id, SystemOp::Goodbye { reason });
            }
            Ok(other) => {
                debug!(?client_id, ?other, "received system message");
            }
            Err(err) => {
                warn!(?err, "invalid system payload");
                self.enqueue_system(
                    client_id,
                    SystemOp::Error {
                        message: "invalid system payload".into(),
                    },
                );
            }
        }
    }

    /// Adds `requested` to the client's allowed channels, if the auth validator accepts
    /// them for the client's token.
    ///
    /// Without this check, a client could `Subscribe` its way onto channels its auth
    /// frame was never authorized for.
    fn apply_subscribe(&self, client_id: Uuid, requested: Vec<C>) {
        let Some(token) = self
            .connections
            .get(&client_id)
            .map(|conn| conn.token.clone())
        else {
            return;
        };
        if let Err(err) = (self.auth_validator)(client_id, token.as_deref(), &requested) {
            warn!(?client_id, ?err, "rejected subscribe request");
            self.enqueue_system(
                client_id,
                SystemOp::Error {
                    message: format!("subscribe rejected: {err}"),
                },
            );
            return;
        }
        if let Some(conn) = self.connections.get(&client_id) {
            for ch in requested {
                conn.allowed_channels.insert(ch);
            }
        }
    }

    /// Routes a non-system frame to the handler registered for its channel.
    ///
    /// Replies with a system `Error` when the client is not subscribed to the channel
    /// or no handler is registered for it.
    fn dispatch_to_handler(&self, client_id: Uuid, frame: Frame<C>) {
        let channel = frame.channel;
        let cursor = frame.cursor;
        // Read the flag and drop the `connections` guard before `enqueue_system` looks
        // the client up again.
        let subscribed = self
            .connections
            .get(&client_id)
            .map(|conn| conn.allowed_channels.contains(&channel));
        if subscribed == Some(false) {
            self.enqueue_system(
                client_id,
                SystemOp::Error {
                    message: format!("channel {} not subscribed", channel.name()),
                },
            );
            return;
        }
        // Clone the handler out so no `channel_handlers` guard is held while user code
        // runs (a handler calling `register_handler` would otherwise deadlock).
        let handler = self
            .channel_handlers
            .get(&channel)
            .map(|entry| Arc::clone(entry.value()));
        match handler {
            Some(handler) => handler(client_id, frame, self),
            None => self.enqueue_system(
                client_id,
                SystemOp::Error {
                    message: format!("no handler for channel {}", channel.name()),
                },
            ),
        }
        debug!(
            ?client_id,
            channel = channel.name(),
            cursor = ?cursor,
            "received RPS frame"
        );
    }

    /// Applies an acknowledgement of `cursor` on `channel` for `client_id`.
    ///
    /// If the cursor is below the buffer floor, or beyond the last cursor sent, a
    /// `ResumeRequired` notice is queued instead of acking. A `None` replay is logged
    /// as an unknown client and ignored.
    fn handle_ack(
        &self,
        client_id: Uuid,
        channel: C,
        cursor: u64,
        replay: Option<&ClientReplay<C>>,
    ) {
        let Some(replay) = replay else {
            warn!(?client_id, "ack from unknown client");
            return;
        };
        let channel_replay = replay.channel(channel);
        let state = channel_replay.state();
        let last_sent = state.last_sent();
        let buffer_floor = state.buffer_floor();
        if cursor < buffer_floor {
            self.enqueue_system(
                client_id,
                SystemOp::ResumeRequired {
                    channel,
                    from_cursor: buffer_floor,
                },
            );
            return;
        }
        if cursor > last_sent {
            self.enqueue_system(
                client_id,
                SystemOp::ResumeRequired {
                    channel,
                    from_cursor: last_sent,
                },
            );
            return;
        }
        channel_replay.ack(cursor);
    }

    /// Queues a system op on the client's WebSocket.
    ///
    /// This is best effort: queue errors are ignored, and a missing connection is only
    /// logged.
    fn enqueue_system(&self, client_id: Uuid, op: SystemOp<C>) {
        if let Some(conn) = self.connections.get(&client_id) {
            let _ = self.enqueue_outbound(conn.value(), Outbound::System(op), client_id);
        } else {
            warn!(?client_id, "ignoring system send for unknown client");
        }
    }

    /// Routes `outbound` to the connection's queue for its priority.
    ///
    /// When the queue is full, low-priority items are dropped with a warning and
    /// `Ok(())` is returned. Other priorities yield [`SendError::Backpressure`]. A
    /// closed queue yields [`SendError::NotConnected`].
    fn enqueue_outbound(
        &self,
        conn: &ConnectionHandle<C>,
        outbound: Outbound<C>,
        client_id: Uuid,
    ) -> Result<(), SendError> {
        let prio = outbound.priority();
        let (target, depth) = match prio {
            PRIORITY_HIGH => (&conn.queue_high, &conn.depth_high),
            PRIORITY_LOW => (&conn.queue_low, &conn.depth_low),
            _ => (&conn.queue_normal, &conn.depth_normal),
        };
        // Reserve a slot in the depth counter first; every failure path below undoes it.
        let depth_now = depth.fetch_add(1, Ordering::SeqCst) + 1;
        if depth_now > QUEUE_WARN_THRESHOLD {
            debug!(
                ?client_id,
                ?prio,
                depth = depth_now,
                "send queue depth high"
            );
        }
        if depth_now > OUTBOUND_BUFFER {
            depth.fetch_sub(1, Ordering::SeqCst);
            if prio == PRIORITY_LOW {
                warn!(
                    ?client_id,
                    ?prio,
                    "dropping low-priority frame (queue full)"
                );
                return Ok(());
            } else {
                warn!(
                    ?client_id,
                    ?prio,
                    "send queue overflow; treating as backpressure"
                );
                return Err(SendError::Backpressure(client_id));
            }
        }
        match target.try_send(outbound) {
            Ok(_) => Ok(()),
            Err(mpsc::error::TrySendError::Full(_)) => {
                depth.fetch_sub(1, Ordering::SeqCst);
                if prio == PRIORITY_LOW {
                    warn!(
                        ?client_id,
                        ?prio,
                        "dropping low-priority frame (queue full)"
                    );
                    Ok(())
                } else {
                    Err(SendError::Backpressure(client_id))
                }
            }
            Err(mpsc::error::TrySendError::Closed(_)) => {
                depth.fetch_sub(1, Ordering::SeqCst);
                Err(SendError::NotConnected(client_id))
            }
        }
    }
}
async fn ws_upgrade<C: ChannelKind>(
State(server): State<Arc<PushServer<C>>>,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
server.upgrade(ws).await
}
/// Query parameters accepted by the SSE endpoint (`/rps/sse`).
#[derive(Debug, Deserialize)]
struct SseParams {
/// Client the stream is opened for.
client_id: Uuid,
/// Optional token passed to the auth validator.
#[serde(default)]
token: Option<String>,
/// Comma-separated channel names declared as capabilities.
#[serde(default)]
capabilities: Option<String>,
/// Comma-separated channel names to receive; takes precedence over `capabilities`.
#[serde(default)]
channels: Option<String>,
/// If present, buffered frames after this cursor are replayed on every allowed channel.
#[serde(default)]
resume_cursor: Option<u64>,
}
async fn sse_upgrade<C: ChannelKind>(
State(server): State<Arc<PushServer<C>>>,
Query(params): Query<SseParams>,
) -> Result<impl IntoResponse, StatusCode> {
let client_id = params.client_id;
let capabilities = parse_channels::<C>(params.capabilities.as_deref());
let subscribe = parse_channels::<C>(params.channels.as_deref());
if let Err(_err) = (server.auth_validator)(client_id, params.token.as_deref(), &capabilities) {
return Err(StatusCode::UNAUTHORIZED);
}
let replay = server
.client_replay
.entry(client_id)
.or_insert_with(|| Arc::new(ClientReplay::default()))
.clone();
let allowed = {
let set = DashSet::new();
if !subscribe.is_empty() {
for ch in subscribe {
set.insert(ch);
}
} else if !capabilities.is_empty() {
for ch in capabilities.clone() {
set.insert(ch);
}
} else {
for ch in C::all() {
set.insert(*ch);
}
}
set
};
let (tx, rx) = mpsc::channel::<Frame<C>>(OUTBOUND_BUFFER);
server.sse_connections.insert(
client_id,
SseHandle {
sender: tx.clone(),
allowed_channels: allowed.clone(),
replay: replay.clone(),
},
);
let snapshot = replay.resume_state();
let resume_cursor = snapshot
.values()
.copied()
.max()
.unwrap_or(DEFAULT_RESUME_CURSOR);
let system_channel = C::all()
.iter()
.find(|c| c.is_system())
.copied()
.expect("ChannelKind must have a system channel");
let auth_ok = Frame::new(
system_channel,
serde_json::to_value(SystemOp::<C>::AuthOk {
resume_cursor,
resume_cursors: snapshot.clone(),
})
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
)
.with_client(client_id);
let _ = tx.try_send(auth_ok);
if let Some(from) = params.resume_cursor {
for entry in replay.channels.iter() {
let channel = *entry.key();
if !allowed.contains(&channel) {
continue;
}
match entry.value().replay_from(from) {
ReplayOutcome::Frames(frames) => {
for frame in frames {
let _ = tx.try_send(frame);
}
}
ReplayOutcome::Gap { buffer_floor } => {
let _ = tx.try_send(
Frame::new(
system_channel,
serde_json::to_value(SystemOp::<C>::ResumeRequired {
channel,
from_cursor: buffer_floor,
})
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
)
.with_client(client_id),
);
}
}
}
}
let stream = futures_util::StreamExt::map(
ReceiverStream::new(rx),
|frame| -> Result<Event, Infallible> {
let id = frame
.cursor
.map(|c| c.to_string())
.unwrap_or_else(|| "0".into());
let event = match serde_json::to_string(&frame) {
Ok(json) => Event::default().event("frame").id(id).data(json),
Err(err) => {
warn!(?err, "failed to serialize SSE frame");
Event::default().event("error").data("serialize_failed")
}
};
Ok(event)
},
);
Ok(Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
}
/// JSON body of `POST /rps/ack`.
#[derive(Debug, Deserialize)]
#[serde(bound(deserialize = "C: ChannelKind"))]
struct AckBody<C: ChannelKind> {
/// Client whose delivery is being acknowledged.
client_id: Uuid,
/// Channel the cursor belongs to.
channel: C,
/// Cursor being acknowledged.
cursor: u64,
}
/// Axum handler for `POST /rps/ack`: applies an acknowledgement sent over plain HTTP.
///
/// Always answers `204 No Content`. Acks for clients without replay state are logged
/// and ignored by `handle_ack`.
///
/// NOTE(review): this endpoint performs no authentication. Anyone who knows a client id
/// can ack, and thereby release, that client's buffered frames.
async fn http_ack<C: ChannelKind>(
    State(server): State<Arc<PushServer<C>>>,
    Json(body): Json<AckBody<C>>,
) -> impl IntoResponse {
    // `handle_ack` treats a `None` replay as an unknown client and drops the ack, so
    // the client's replay state has to be looked up here.
    let replay = server
        .client_replay
        .get(&body.client_id)
        .map(|entry| Arc::clone(entry.value()));
    server.handle_ack(body.client_id, body.channel, body.cursor, replay.as_deref());
    axum::http::StatusCode::NO_CONTENT
}
/// Parses a comma-separated list of channel names.
///
/// Whitespace around each name is trimmed, and names `C::from_name` does not recognise
/// are skipped. `None` yields an empty list.
fn parse_channels<C: ChannelKind>(raw: Option<&str>) -> Vec<C> {
    let Some(list) = raw else {
        return Vec::new();
    };
    list.split(',')
        .map(str::trim)
        .filter_map(|name| C::from_name(name))
        .collect()
}
fn validate_frame<C: ChannelKind>(frame: &Frame<C>) -> Result<(), &'static str> {
if frame.payload.is_null() {
return Err("payload required");
}
Ok(())
}