use std::error::Error as StdError;
use std::fmt::{self, Display, Formatter};
use std::hash::Hash;
use std::io::{Error as IoError, ErrorKind};
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU32, AtomicU8, Ordering};
use std::sync::Arc;
use dashmap::{mapref::entry::Entry, DashMap};
use futures::{
future::{FutureExt, LocalBoxFuture},
lock::Mutex,
};
use https::{
header::{
HeaderName, ACCEPT, CACHE_CONTROL, CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH,
CONTENT_TYPE, TRANSFER_ENCODING,
},
Method, StatusCode,
};
use pi_async_rt::rt::AsyncRuntime;
use tcp::{Socket, SocketHandle};
use wyhash::WyHasherBuilder;
use crate::gateway::GatewayContext;
use crate::middleware::{Middleware, MiddlewareResult};
use crate::request::HttpRequest;
use crate::response::{HttpResponse, ResponseHandler};
// Default `channel_size` handed to `HttpResponse::new` for an SSE stream.
const DEFAULT_CHANNEL_SIZE: usize = 16;
// Default upper bound, in bytes, on one encoded event (checked by `SseEvent::encode`).
const DEFAULT_MAX_EVENT_BYTES: usize = 64 * 1024;
// Default heartbeat interval in milliseconds; callers treat 0 as "no heartbeat".
const DEFAULT_HEARTBEAT_INTERVAL_MS: usize = 15_000;
// Header values written on every SSE response.
const SSE_CONTENT_TYPE: &str = "text/event-stream; charset=utf-8";
const SSE_CACHE_CONTROL: &str = "no-cache, no-transform";
const SSE_KEEP_ALIVE: &str = "keep-alive";
const SSE_TRANSFER_ENCODING: &str = "chunked";
// Header (and its value) emitted when proxy buffering should be disabled.
const SSE_ACCEL_BUFFERING_HEADER: &str = "x-accel-buffering";
const SSE_ACCEL_BUFFERING_DISABLED: &str = "no";
// Request header carrying the id of the last event the client saw.
const LAST_EVENT_ID_HEADER: &str = "last-event-id";
// Sender lifecycle states stored in `SseSenderInner::state`:
// OPEN accepts writes; FINISHING means a finish was started (writes are rejected,
// finishing may be retried); CLOSED is terminal.
const SENDER_STATE_OPEN: u8 = 0;
const SENDER_STATE_FINISHING: u8 = 1;
const SENDER_STATE_CLOSED: u8 = 2;
// Process-wide counter for connection ids. It starts at 1: `next_connection_id`
// skips 0 and `SseHub::register` rejects id 0.
static SSE_CONNECTION_ID_ALLOCATOR: AtomicU32 = AtomicU32::new(1);
/// Result alias used by every fallible SSE operation in this module.
pub type SseResult<T> = Result<T, SseError>;

/// Errors produced while configuring, encoding or delivering SSE events.
#[derive(Debug)]
pub enum SseError {
    /// A configuration value was rejected; the payload explains why.
    InvalidConfig(String),
    /// The request used a method other than GET; the payload is the method name.
    InvalidMethod(String),
    /// An event field or the `Last-Event-ID` header contained forbidden characters.
    InvalidEvent(String),
    /// The encoded event is `len` bytes, more than the configured `limit`.
    EventTooLarge {
        len: usize,
        limit: usize,
    },
    /// The response could not take more data right now (`WouldBlock`).
    QueueFull,
    /// The sender's gate is held by another send/finish (non-blocking paths).
    Busy,
    /// The connection is closed, finishing, or the peer went away.
    Closed,
    /// The connection id is 0 or is already registered in the hub.
    ConnectionIdExhausted,
    /// No connection with the requested id is registered.
    ConnectionNotFound,
    /// Any other I/O failure.
    Io(IoError),
}

impl Display for SseError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Variants with a payload are formatted; the rest are fixed messages.
        match self {
            SseError::InvalidConfig(why) => write!(f, "invalid SSE config: {}", why),
            SseError::InvalidMethod(name) => write!(f, "invalid SSE method: {}", name),
            SseError::InvalidEvent(why) => write!(f, "invalid SSE event: {}", why),
            SseError::EventTooLarge { len, limit } => {
                write!(f, "SSE event too large: {}, limit: {}", len, limit)
            }
            SseError::Io(inner) => write!(f, "SSE io error: {}", inner),
            SseError::QueueFull => f.write_str("SSE queue is full"),
            SseError::Busy => f.write_str("SSE sender is busy"),
            SseError::Closed => f.write_str("SSE connection is closed"),
            SseError::ConnectionIdExhausted => f.write_str("SSE connection id exhausted"),
            SseError::ConnectionNotFound => f.write_str("SSE connection not found"),
        }
    }
}

impl StdError for SseError {}

impl From<IoError> for SseError {
    /// Maps I/O errors onto SSE semantics: `WouldBlock` becomes `QueueFull`,
    /// disconnect-style kinds become `Closed`, and everything else is wrapped.
    fn from(error: IoError) -> Self {
        let kind = error.kind();
        if kind == ErrorKind::WouldBlock {
            return SseError::QueueFull;
        }
        let disconnected = matches!(
            kind,
            ErrorKind::BrokenPipe | ErrorKind::ConnectionAborted | ErrorKind::ConnectionReset
        );
        if disconnected {
            SseError::Closed
        } else {
            SseError::Io(error)
        }
    }
}
/// Tunables for one SSE stream.
#[derive(Clone, Debug)]
pub struct SseConfig {
    /// Channel size passed to `HttpResponse::new`; must be non-zero.
    channel_size: usize,
    /// Maximum encoded size of a single event, in bytes; must be non-zero.
    max_event_bytes: usize,
    /// Heartbeat interval in milliseconds; 0 disables heartbeats.
    heartbeat_interval_ms: usize,
    /// Whether a comment is sent right after the stream is opened.
    send_initial_comment: bool,
    /// Whether the `Cache-Control` header is set on the response.
    no_cache: bool,
    /// Whether the `x-accel-buffering: no` header is set on the response.
    disable_proxy_buffering: bool,
}

impl Default for SseConfig {
    fn default() -> Self {
        Self {
            send_initial_comment: true,
            no_cache: true,
            disable_proxy_buffering: true,
            channel_size: DEFAULT_CHANNEL_SIZE,
            max_event_bytes: DEFAULT_MAX_EVENT_BYTES,
            heartbeat_interval_ms: DEFAULT_HEARTBEAT_INTERVAL_MS,
        }
    }
}

impl SseConfig {
    /// Starts a validating builder seeded with the defaults.
    pub fn builder() -> SseConfigBuilder {
        SseConfigBuilder {
            config: Self::default(),
        }
    }

    /// Channel size used for the streaming response.
    pub fn channel_size(&self) -> usize {
        let SseConfig { channel_size, .. } = self;
        *channel_size
    }

    /// Maximum encoded size of one event, in bytes.
    pub fn max_event_bytes(&self) -> usize {
        let SseConfig { max_event_bytes, .. } = self;
        *max_event_bytes
    }

    /// Heartbeat interval in milliseconds (0 = disabled).
    pub fn heartbeat_interval_ms(&self) -> usize {
        let SseConfig {
            heartbeat_interval_ms,
            ..
        } = self;
        *heartbeat_interval_ms
    }
}

/// Builder for [`SseConfig`]; [`SseConfigBuilder::build`] validates the result.
#[derive(Clone, Debug)]
pub struct SseConfigBuilder {
    config: SseConfig,
}

impl SseConfigBuilder {
    /// Sets the response channel size (must be > 0).
    pub fn channel_size(self, channel_size: usize) -> Self {
        Self {
            config: SseConfig {
                channel_size,
                ..self.config
            },
        }
    }

    /// Sets the maximum encoded event size in bytes (must be > 0).
    pub fn max_event_bytes(self, max_event_bytes: usize) -> Self {
        Self {
            config: SseConfig {
                max_event_bytes,
                ..self.config
            },
        }
    }

    /// Sets the heartbeat interval in milliseconds (0 disables it).
    pub fn heartbeat_interval_ms(self, heartbeat_interval_ms: usize) -> Self {
        Self {
            config: SseConfig {
                heartbeat_interval_ms,
                ..self.config
            },
        }
    }

    /// Enables or disables the comment sent when the stream opens.
    pub fn send_initial_comment(self, send_initial_comment: bool) -> Self {
        Self {
            config: SseConfig {
                send_initial_comment,
                ..self.config
            },
        }
    }

    /// Enables or disables the `Cache-Control` header.
    pub fn no_cache(self, no_cache: bool) -> Self {
        Self {
            config: SseConfig {
                no_cache,
                ..self.config
            },
        }
    }

    /// Enables or disables the proxy-buffering opt-out header.
    pub fn disable_proxy_buffering(self, disable_proxy_buffering: bool) -> Self {
        Self {
            config: SseConfig {
                disable_proxy_buffering,
                ..self.config
            },
        }
    }

    /// Validates and returns the configuration.
    ///
    /// # Errors
    /// `SseError::InvalidConfig` when `channel_size` or `max_event_bytes` is 0.
    pub fn build(self) -> SseResult<SseConfig> {
        let problem = if self.config.channel_size == 0 {
            Some("channel_size must be greater than 0")
        } else if self.config.max_event_bytes == 0 {
            Some("max_event_bytes must be greater than 0")
        } else {
            None
        };
        match problem {
            Some(reason) => Err(SseError::InvalidConfig(reason.to_string())),
            None => Ok(self.config),
        }
    }
}
/// One Server-Sent Event. Every field is optional; `encode` writes the fields
/// that are set in the order: comment, `id`, `event`, `retry`, `data`.
#[derive(Clone, Debug, Default)]
pub struct SseEvent {
    /// Value of the `event:` field.
    event: Option<String>,
    /// One entry per `data` payload; each entry may span several lines.
    data: Vec<String>,
    /// Value of the `id:` field.
    id: Option<String>,
    /// Value of the `retry:` field, in milliseconds.
    retry: Option<u64>,
    /// Comment text, written as `:`-prefixed lines.
    comment: Option<String>,
}

impl SseEvent {
    /// Starts an [`SseEventBuilder`] from an empty event.
    pub fn builder() -> SseEventBuilder {
        SseEventBuilder {
            event: Self::default(),
        }
    }

    /// Creates an unnamed event carrying a single `data` payload.
    pub fn data(data: impl Into<String>) -> Self {
        let payload = data.into();
        Self {
            data: vec![payload],
            ..Self::default()
        }
    }

    /// Creates an event with an `event:` type and a single `data` payload.
    pub fn named(event: impl Into<String>, data: impl Into<String>) -> Self {
        let kind = event.into();
        let payload = data.into();
        Self {
            event: Some(kind),
            data: vec![payload],
            ..Self::default()
        }
    }

    /// Creates an event that carries only a comment.
    pub fn comment(comment: impl Into<String>) -> Self {
        let text = comment.into();
        Self {
            comment: Some(text),
            ..Self::default()
        }
    }

    /// Sets the `id:` field.
    ///
    /// # Errors
    /// `SseError::InvalidEvent` if the id contains CR, LF or NUL.
    pub fn id(self, id: impl Into<String>) -> SseResult<Self> {
        let id = id.into();
        validate_single_line_field("id", &id)?;
        Ok(Self { id: Some(id), ..self })
    }

    /// Sets the `retry:` field (milliseconds). This never fails.
    pub fn retry(self, millis: u64) -> SseResult<Self> {
        Ok(Self {
            retry: Some(millis),
            ..self
        })
    }

    /// Serialises the event into its wire format, terminated by a blank line.
    ///
    /// # Errors
    /// * `InvalidConfig` if `max_event_bytes` is 0.
    /// * `InvalidEvent` if `event`/`id` contain CR, LF or NUL, or if
    ///   `comment`/`data` contain NUL.
    /// * `EventTooLarge` if the encoded bytes exceed `max_event_bytes`.
    pub fn encode(&self, max_event_bytes: usize) -> SseResult<Vec<u8>> {
        // Writes `label`, `value` and a trailing LF.
        fn push_line(out: &mut Vec<u8>, label: &str, value: &str) {
            out.extend_from_slice(label.as_bytes());
            out.extend_from_slice(value.as_bytes());
            out.push(b'\n');
        }

        if max_event_bytes == 0 {
            return Err(SseError::InvalidConfig(
                "max_event_bytes must be greater than 0".to_string(),
            ));
        }

        // Validate every field before rendering anything.
        if let Some(kind) = self.event.as_deref() {
            validate_single_line_field("event", kind)?;
        }
        if let Some(id) = self.id.as_deref() {
            validate_single_line_field("id", id)?;
        }
        if let Some(text) = self.comment.as_deref() {
            validate_text_field("comment", text)?;
        }
        self.data
            .iter()
            .try_for_each(|chunk| validate_text_field("data", chunk))?;

        let mut out = Vec::new();
        if let Some(text) = self.comment.as_deref() {
            encode_multiline_field(&mut out, ":", text, false);
        }
        if let Some(id) = self.id.as_deref() {
            push_line(&mut out, "id: ", id);
        }
        if let Some(kind) = self.event.as_deref() {
            push_line(&mut out, "event: ", kind);
        }
        if let Some(millis) = self.retry {
            push_line(&mut out, "retry: ", &millis.to_string());
        }
        for chunk in &self.data {
            encode_multiline_field(&mut out, "data: ", chunk, true);
        }
        // The blank line terminates the event block.
        out.push(b'\n');

        let len = out.len();
        if len > max_event_bytes {
            Err(SseError::EventTooLarge {
                len,
                limit: max_event_bytes,
            })
        } else {
            Ok(out)
        }
    }
}

/// Fluent builder for [`SseEvent`]. Fields are not checked until
/// [`SseEventBuilder::build`].
#[derive(Clone, Debug, Default)]
pub struct SseEventBuilder {
    event: SseEvent,
}

impl SseEventBuilder {
    /// Sets the `event:` type.
    pub fn event(self, event: impl Into<String>) -> Self {
        Self {
            event: SseEvent {
                event: Some(event.into()),
                ..self.event
            },
        }
    }

    /// Appends another `data` payload.
    pub fn data(self, data: impl Into<String>) -> Self {
        let mut draft = self.event;
        draft.data.push(data.into());
        Self { event: draft }
    }

    /// Sets the `id:` field (checked in `build`).
    pub fn id(self, id: impl Into<String>) -> Self {
        Self {
            event: SseEvent {
                id: Some(id.into()),
                ..self.event
            },
        }
    }

    /// Sets the `retry:` field in milliseconds.
    pub fn retry(self, millis: u64) -> Self {
        Self {
            event: SseEvent {
                retry: Some(millis),
                ..self.event
            },
        }
    }

    /// Sets the comment text.
    pub fn comment(self, comment: impl Into<String>) -> Self {
        Self {
            event: SseEvent {
                comment: Some(comment.into()),
                ..self.event
            },
        }
    }

    /// Finishes the event. It runs the same field validation as
    /// `SseEvent::encode`, without a size limit.
    pub fn build(self) -> SseResult<SseEvent> {
        let event = self.event;
        event.encode(usize::MAX).map(|_| event)
    }
}
/// Entry point for turning a GET request into a streaming SSE response.
pub struct SseResponse;

impl SseResponse {
    /// Starts a builder for `req` with the default [`SseConfig`].
    pub fn builder<S: Socket>(req: &HttpRequest<S>) -> SseResponseBuilder<'_, S> {
        SseResponseBuilder {
            config: SseConfig::default(),
            req,
        }
    }
}

/// Builds the SSE response and the matching [`SseSender`] for one request.
pub struct SseResponseBuilder<'a, S: Socket> {
    req: &'a HttpRequest<S>,
    config: SseConfig,
}

impl<'a, S: Socket> SseResponseBuilder<'a, S> {
    /// Replaces the configuration (it is validated in `build`).
    pub fn config(self, config: SseConfig) -> Self {
        Self { config, ..self }
    }

    /// Produces a streaming response with SSE headers plus a sender bound to it.
    ///
    /// # Errors
    /// * `InvalidMethod` for non-GET requests.
    /// * `InvalidConfig` for an invalid config or a missing response handler.
    /// * `InvalidEvent` for a malformed `Last-Event-ID` header.
    /// * Any error from sending the initial comment.
    pub fn build(self) -> SseResult<(HttpResponse, SseSender<S>)> {
        let Self { req, config } = self;

        let method = req.method();
        if method != &Method::GET {
            return Err(SseError::InvalidMethod(method.as_str().to_string()));
        }
        // Reuse the config builder's validation; it hands the config back unchanged.
        let config = SseConfigBuilder { config }.build()?;

        let mut resp = HttpResponse::new(config.channel_size);
        resp.enable_stream();
        resp.insert_header(CONTENT_TYPE.as_str(), SSE_CONTENT_TYPE);
        resp.insert_header(CONNECTION.as_str(), SSE_KEEP_ALIVE);
        resp.insert_header(TRANSFER_ENCODING.as_str(), SSE_TRANSFER_ENCODING);
        resp.remove_header(CONTENT_LENGTH.as_str());
        resp.remove_header(CONTENT_ENCODING.as_str());
        if config.no_cache {
            resp.insert_header(CACHE_CONTROL.as_str(), SSE_CACHE_CONTROL);
        }
        if config.disable_proxy_buffering {
            resp.insert_header(SSE_ACCEL_BUFFERING_HEADER, SSE_ACCEL_BUFFERING_DISABLED);
        }

        let handler = match resp.get_response_handler() {
            Some(handler) => handler,
            None => {
                return Err(SseError::InvalidConfig(
                    "missing response handler".to_string(),
                ))
            }
        };
        let last_event_id = read_last_event_id(req)?;
        let sender = SseSender::new(
            next_connection_id(),
            req.get_handle().clone(),
            handler,
            config.max_event_bytes,
            last_event_id,
        );
        if config.send_initial_comment {
            sender.try_comment("pi_http sse connected")?;
        }
        Ok((resp, sender))
    }

    /// Same as [`build`](Self::build), then starts a heartbeat task on `runtime`
    /// when the configured interval is non-zero.
    pub fn build_with_heartbeat<R>(self, runtime: R) -> SseResult<(HttpResponse, SseSender<S>)>
    where
        R: AsyncRuntime<()>,
    {
        let interval_ms = self.config.heartbeat_interval_ms;
        self.build().and_then(|(resp, sender)| {
            if interval_ms != 0 {
                sender.start_heartbeat(runtime, interval_ms)?;
            }
            Ok((resp, sender))
        })
    }
}
/// Arguments handed to an SSE acceptor callback before the connection is
/// registered in the hub.
pub struct SseAccept<'a, S: Socket> {
    /// Gateway context of the request being handled.
    pub context: &'a GatewayContext,
    /// The incoming request; it has already passed the GET-method check.
    pub request: &'a HttpRequest<S>,
    /// Sender for the stream already built for this request. The middleware
    /// finishes it if the acceptor rejects or returns an error.
    pub sender: SseSender<S>,
}
/// Outcome of an acceptor callback: register under a key, or reject the request.
#[derive(Clone, Debug)]
pub enum SseAcceptDecision<K> {
    /// Register the connection under this key.
    Accept(K),
    /// Refuse the connection with the given status and plain-text message.
    Reject {
        status: StatusCode,
        message: String,
    },
}

impl<K> SseAcceptDecision<K> {
    /// Shorthand for [`SseAcceptDecision::Accept`].
    pub fn accept(key: K) -> Self {
        Self::Accept(key)
    }

    /// Shorthand for [`SseAcceptDecision::Reject`].
    pub fn reject(status: StatusCode, message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::Reject { status, message }
    }
}
/// Callback that decides whether (and under which key) an SSE request is accepted.
pub type SseAcceptHandler<K, S> =
    Arc<dyn for<'a> Fn(SseAccept<'a, S>) -> SseResult<SseAcceptDecision<K>> + Send + Sync>;
/// Callback invoked after a connection is registered. If it returns an error,
/// the middleware unregisters and finishes the connection.
pub type SseOpenHandler<K, S> = Arc<dyn Fn(SseOpen<K, S>) -> SseResult<()> + Send + Sync>;
// Type-erased "start heartbeat for this sender with this interval" hook; it is
// filled in by `SseMiddlewareBuilder::heartbeat_runtime`.
type SseHeartbeatStarter<S> = Arc<dyn Fn(SseSender<S>, usize) -> SseResult<()> + Send + Sync>;
/// Data passed to the on-open callback for a freshly registered connection.
#[derive(Clone)]
pub struct SseOpen<K, S: Socket> {
    /// Key returned by the acceptor.
    pub key: K,
    /// Id under which the connection was registered.
    pub id: SseConnectionId,
    /// Sender for the new stream.
    pub sender: SseSender<S>,
}
/// Gateway middleware that upgrades matching requests into SSE streams and
/// registers them in an [`SseHub`].
pub struct SseMiddleware<K, S: Socket> {
    /// Registry that accepted connections are added to.
    hub: SseHub<K, S>,
    /// Per-stream configuration applied to every accepted request.
    config: SseConfig,
    /// Decides whether to accept a request and under which key.
    acceptor: SseAcceptHandler<K, S>,
    /// Optional hook run after registration.
    on_open: Option<SseOpenHandler<K, S>>,
    /// Optional hook that starts heartbeats for a new sender.
    heartbeat_starter: Option<SseHeartbeatStarter<S>>,
    /// When true, requests must list `text/event-stream` in `Accept`.
    require_accept_header: bool,
}

impl<K, S: Socket> Clone for SseMiddleware<K, S> {
    fn clone(&self) -> Self {
        let Self {
            hub,
            config,
            acceptor,
            on_open,
            heartbeat_starter,
            require_accept_header,
        } = self;
        SseMiddleware {
            hub: hub.clone(),
            config: config.clone(),
            acceptor: Arc::clone(acceptor),
            on_open: on_open.clone(),
            heartbeat_starter: heartbeat_starter.clone(),
            require_accept_header: *require_accept_header,
        }
    }
}

impl<S: Socket> SseMiddleware<SseConnectionId, S> {
    /// Middleware keyed by connection id, with default config and an acceptor
    /// that accepts every request.
    ///
    /// # Panics
    /// Never in practice: the default config always validates.
    pub fn new(hub: SseHub<SseConnectionId, S>) -> Self {
        let builder = Self::builder(hub);
        builder
            .build()
            .expect("default SSE middleware config must be valid")
    }

    /// Builder keyed by connection id whose default acceptor accepts every
    /// request under the sender's own id.
    pub fn builder(hub: SseHub<SseConnectionId, S>) -> SseMiddlewareBuilder<SseConnectionId, S> {
        let acceptor: SseAcceptHandler<SseConnectionId, S> =
            Arc::new(|accept| Ok(SseAcceptDecision::Accept(accept.sender.id())));
        SseMiddlewareBuilder {
            require_accept_header: false,
            heartbeat_starter: None,
            on_open: None,
            acceptor,
            config: SseConfig::default(),
            hub,
        }
    }
}

impl<K, S> SseMiddleware<K, S>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    S: Socket,
{
    /// Builder with a custom acceptor that chooses the registration key.
    pub fn with_acceptor<F>(hub: SseHub<K, S>, acceptor: F) -> SseMiddlewareBuilder<K, S>
    where
        F: for<'a> Fn(SseAccept<'a, S>) -> SseResult<SseAcceptDecision<K>> + Send + Sync + 'static,
    {
        let acceptor: SseAcceptHandler<K, S> = Arc::new(acceptor);
        SseMiddlewareBuilder {
            require_accept_header: false,
            heartbeat_starter: None,
            on_open: None,
            acceptor,
            config: SseConfig::default(),
            hub,
        }
    }

    /// Returns a handle that shares this middleware's hub.
    pub fn hub(&self) -> SseHub<K, S> {
        SseHub::clone(&self.hub)
    }

    /// Configuration applied to every accepted stream.
    pub fn config(&self) -> &SseConfig {
        let Self { config, .. } = self;
        config
    }
}
impl<K, S> Middleware<S, GatewayContext> for SseMiddleware<K, S>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    S: Socket,
{
    /// Turns the request into an SSE stream.
    ///
    /// Steps: optional `Accept` check -> build response and sender -> acceptor
    /// -> hub registration -> optional on-open hook -> optional heartbeat.
    /// If any step fails, the sender is finished (best-effort; the result is
    /// discarded), the connection is unregistered when it was already
    /// registered, and a plain-text error response is returned via `Break`.
    /// On success it returns `Finish((req, resp))` with the streaming response.
    fn request<'a>(
        &'a self,
        context: &'a mut GatewayContext,
        req: HttpRequest<S>,
    ) -> LocalBoxFuture<'a, MiddlewareResult<S>> {
        async move {
            if self.require_accept_header && !accepts_event_stream(&req) {
                return MiddlewareResult::Break(sse_error_response(
                    StatusCode::NOT_ACCEPTABLE,
                    "SSE request must accept text/event-stream",
                ));
            }
            // Validates the method and config, and may already queue the initial comment.
            let (resp, sender) = match SseResponse::builder(&req)
                .config(self.config.clone())
                .build()
            {
                Ok(value) => value,
                Err(error) => return MiddlewareResult::Break(sse_error_to_response(error)),
            };
            let decision = match (self.acceptor)(SseAccept {
                context,
                request: &req,
                sender: sender.clone(),
            }) {
                Ok(decision) => decision,
                Err(error) => {
                    let _ = sender.try_finish();
                    return MiddlewareResult::Break(sse_error_to_response(error));
                }
            };
            let key = match decision {
                SseAcceptDecision::Accept(key) => key,
                SseAcceptDecision::Reject { status, message } => {
                    let _ = sender.try_finish();
                    return MiddlewareResult::Break(sse_error_response(status, message));
                }
            };
            let id = match self.hub.register(key.clone(), sender.clone()) {
                Ok(id) => id,
                Err(error) => {
                    let _ = sender.try_finish();
                    return MiddlewareResult::Break(sse_error_to_response(error));
                }
            };
            if let Some(on_open) = &self.on_open {
                let open = SseOpen {
                    key,
                    id,
                    sender: sender.clone(),
                };
                if let Err(error) = on_open(open) {
                    // Roll back the registration made above.
                    let _ = self.hub.unregister(id);
                    let _ = sender.try_finish();
                    return MiddlewareResult::Break(sse_error_to_response(error));
                }
            }
            if let Some(starter) = &self.heartbeat_starter {
                // An interval of 0 means heartbeats are disabled.
                if self.config.heartbeat_interval_ms > 0 {
                    if let Err(error) = starter(sender.clone(), self.config.heartbeat_interval_ms) {
                        let _ = self.hub.unregister(id);
                        let _ = sender.try_finish();
                        return MiddlewareResult::Break(sse_error_to_response(error));
                    }
                }
            }
            MiddlewareResult::Finish((req, resp))
        }
        .boxed_local()
    }
    /// Passes streaming responses through as `Break(resp)` and returns
    /// everything else unchanged as `Finish((req, resp))`.
    /// NOTE(review): presumably `Break` keeps later middleware from touching
    /// the SSE stream — confirm against the middleware chain semantics.
    fn response<'a>(
        &'a self,
        _context: &'a mut GatewayContext,
        req: HttpRequest<S>,
        resp: HttpResponse,
    ) -> LocalBoxFuture<'a, MiddlewareResult<S>> {
        async move {
            if resp.is_stream() {
                MiddlewareResult::Break(resp)
            } else {
                MiddlewareResult::Finish((req, resp))
            }
        }
        .boxed_local()
    }
}
pub struct SseMiddlewareBuilder<K, S: Socket> {
hub: SseHub<K, S>,
config: SseConfig,
acceptor: SseAcceptHandler<K, S>,
on_open: Option<SseOpenHandler<K, S>>,
heartbeat_starter: Option<SseHeartbeatStarter<S>>,
require_accept_header: bool,
}
impl<K, S> SseMiddlewareBuilder<K, S>
where
K: Eq + Hash + Clone + Send + Sync + 'static,
S: Socket,
{
pub fn config(mut self, config: SseConfig) -> Self {
self.config = config;
self
}
pub fn require_accept_header(mut self, require_accept_header: bool) -> Self {
self.require_accept_header = require_accept_header;
self
}
pub fn heartbeat_runtime<R>(mut self, runtime: R) -> Self
where
R: AsyncRuntime<()>,
{
self.heartbeat_starter = Some(Arc::new(move |sender, interval_ms| {
sender.start_heartbeat(runtime.clone(), interval_ms)
}));
self
}
pub fn acceptor<F>(mut self, acceptor: F) -> Self
where
F: for<'a> Fn(SseAccept<'a, S>) -> SseResult<SseAcceptDecision<K>> + Send + Sync + 'static,
{
self.acceptor = Arc::new(acceptor);
self
}
pub fn on_open<F>(mut self, handler: F) -> Self
where
F: Fn(SseOpen<K, S>) -> SseResult<()> + Send + Sync + 'static,
{
self.on_open = Some(Arc::new(handler));
self
}
pub fn build(self) -> SseResult<SseMiddleware<K, S>> {
SseConfigBuilder {
config: self.config.clone(),
}
.build()?;
Ok(SseMiddleware {
hub: self.hub,
config: self.config,
acceptor: self.acceptor,
on_open: self.on_open,
heartbeat_starter: self.heartbeat_starter,
require_accept_header: self.require_accept_header,
})
}
}
/// Numeric identifier of one SSE connection. Ids produced by
/// `next_connection_id` are never zero.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct SseConnectionId(u32);

impl SseConnectionId {
    /// Returns the raw numeric value of this id.
    pub fn get(self) -> u32 {
        let SseConnectionId(raw) = self;
        raw
    }
}
/// Cheaply clonable handle for writing events to one SSE response stream.
/// All clones share the same state and send gate.
pub struct SseSender<S: Socket> {
    inner: Arc<SseSenderInner<S>>,
}
/// Shared state behind every clone of an [`SseSender`].
struct SseSenderInner<S: Socket> {
    /// Connection id allocated when the response was built.
    id: SseConnectionId,
    /// Socket handle of the connection; used for close detection and addresses.
    handle: SocketHandle<S>,
    /// Writer side of the streaming HTTP response.
    response: ResponseHandler,
    /// Limit passed to `SseEvent::encode` for every outgoing event.
    max_event_bytes: usize,
    /// `Last-Event-ID` request header value, if the client sent one.
    last_event_id: Option<String>,
    /// One of the `SENDER_STATE_*` constants.
    state: AtomicU8,
    /// Held for the whole of every send/finish, so only one write is in flight.
    send_gate: Arc<Mutex<()>>,
}
impl<S: Socket> Clone for SseSender<S> {
    fn clone(&self) -> Self {
        SseSender {
            inner: self.inner.clone(),
        }
    }
}
impl<S: Socket> SseSender<S> {
    /// Creates a sender in the open state for connection `id`.
    fn new(
        id: SseConnectionId,
        handle: SocketHandle<S>,
        response: ResponseHandler,
        max_event_bytes: usize,
        last_event_id: Option<String>,
    ) -> Self {
        SseSender {
            inner: Arc::new(SseSenderInner {
                id,
                handle,
                response,
                max_event_bytes,
                last_event_id,
                state: AtomicU8::new(SENDER_STATE_OPEN),
                send_gate: Arc::new(Mutex::new(())),
            }),
        }
    }
    /// Encodes `event` and writes it. Waits for the send gate if another send
    /// or finish is in progress.
    ///
    /// # Errors
    /// Encoding errors from `SseEvent::encode`; `SseError::Closed` if the
    /// sender is not open; otherwise the mapped error from the response write.
    pub async fn send(&self, event: SseEvent) -> SseResult<()> {
        let _guard = self.inner.send_gate.lock().await;
        let encoded = event.encode(self.inner.max_event_bytes)?;
        self.write_encoded(encoded).await
    }
    /// Sends an unnamed event with a single `data` payload.
    pub async fn send_data(&self, data: impl AsRef<str>) -> SseResult<()> {
        self.send(SseEvent::data(data.as_ref().to_string())).await
    }
    /// Sends a comment-only event.
    pub async fn comment(&self, text: impl AsRef<str>) -> SseResult<()> {
        self.send(SseEvent::comment(text.as_ref().to_string()))
            .await
    }
    /// Sends an empty comment, which encodes as `":\n\n"`.
    pub async fn heartbeat(&self) -> SseResult<()> {
        self.comment("").await
    }
    /// Finishes the response stream, waiting for the send gate.
    /// Returns `Ok(())` if the sender is already closed.
    pub async fn finish(&self) -> SseResult<()> {
        let _guard = self.inner.send_gate.lock().await;
        self.finish_locked().await
    }
    /// Non-blocking variant of [`send`](Self::send).
    ///
    /// # Errors
    /// `SseError::Busy` if the send gate is currently held. Otherwise the same
    /// as `send`, with a `WouldBlock` from the response reported as `QueueFull`.
    pub fn try_send(&self, event: SseEvent) -> SseResult<()> {
        let _guard = self.inner.send_gate.try_lock().ok_or(SseError::Busy)?;
        let encoded = event.encode(self.inner.max_event_bytes)?;
        self.try_write_encoded(encoded)
    }
    /// Non-blocking variant of [`send_data`](Self::send_data).
    pub fn try_send_data(&self, data: impl AsRef<str>) -> SseResult<()> {
        self.try_send(SseEvent::data(data.as_ref().to_string()))
    }
    /// Non-blocking variant of [`comment`](Self::comment).
    pub fn try_comment(&self, text: impl AsRef<str>) -> SseResult<()> {
        self.try_send(SseEvent::comment(text.as_ref().to_string()))
    }
    /// Non-blocking variant of [`heartbeat`](Self::heartbeat).
    pub fn try_heartbeat(&self) -> SseResult<()> {
        self.try_comment("")
    }
    /// Non-blocking variant of [`finish`](Self::finish); `SseError::Busy` if the
    /// send gate is held.
    pub fn try_finish(&self) -> SseResult<()> {
        let _guard = self.inner.send_gate.try_lock().ok_or(SseError::Busy)?;
        self.try_finish_locked()
    }
    /// Connection id of this stream.
    pub fn id(&self) -> SseConnectionId {
        self.inner.id
    }
    /// Returns whether the sender is closed. If the socket handle reports
    /// closed, the CLOSED state is latched first. A sender in the FINISHING
    /// state is not reported as closed.
    pub fn is_closed(&self) -> bool {
        if self.inner.handle.is_closed() {
            self.inner
                .state
                .store(SENDER_STATE_CLOSED, Ordering::Release);
            return true;
        }
        self.inner.state.load(Ordering::Acquire) == SENDER_STATE_CLOSED
    }
    /// `Last-Event-ID` the client sent when connecting, if any.
    pub fn last_event_id(&self) -> Option<&str> {
        self.inner.last_event_id.as_deref()
    }
    /// Peer address of the underlying socket.
    pub fn remote_addr(&self) -> SocketAddr {
        self.inner.handle.get_remote().clone()
    }
    /// Local address of the underlying socket.
    pub fn local_addr(&self) -> SocketAddr {
        self.inner.handle.get_local().clone()
    }
    /// Spawns a task on `runtime` that waits `interval_ms` (passed straight to
    /// `AsyncRuntime::timeout`, presumably milliseconds) and then sends a
    /// heartbeat. It repeats until the sender is closed or a heartbeat fails.
    /// The task keeps a clone of the sender alive while it runs.
    ///
    /// # Errors
    /// `SseError::Io` if the runtime refuses to spawn the task.
    fn start_heartbeat<R>(&self, runtime: R, interval_ms: usize) -> SseResult<()>
    where
        R: AsyncRuntime<()>,
    {
        let sender = self.clone();
        let runtime_for_task = runtime.clone();
        runtime
            .spawn(async move {
                loop {
                    runtime_for_task.timeout(interval_ms).await;
                    if sender.is_closed() {
                        break;
                    }
                    if sender.heartbeat().await.is_err() {
                        break;
                    }
                }
            })
            .map(|_| ())
            .map_err(SseError::Io)
    }
    /// Writes pre-encoded bytes (caller holds the send gate). Any write error
    /// latches the CLOSED state.
    async fn write_encoded(&self, encoded: Vec<u8>) -> SseResult<()> {
        if self.is_closed() || self.inner.state.load(Ordering::Acquire) != SENDER_STATE_OPEN {
            return Err(SseError::Closed);
        }
        match self.inner.response.write(encoded).await {
            Ok(_) => Ok(()),
            Err(e) => {
                self.inner
                    .state
                    .store(SENDER_STATE_CLOSED, Ordering::Release);
                Err(SseError::from(e))
            }
        }
    }
    /// Non-blocking write of pre-encoded bytes (caller holds the send gate).
    /// Only errors that map to `SseError::Closed` latch the CLOSED state, so a
    /// `QueueFull` sender stays open and the write can be retried.
    fn try_write_encoded(&self, encoded: Vec<u8>) -> SseResult<()> {
        if self.is_closed() || self.inner.state.load(Ordering::Acquire) != SENDER_STATE_OPEN {
            return Err(SseError::Closed);
        }
        match self.inner.response.try_write(encoded) {
            Ok(_) => Ok(()),
            Err(e) => {
                let error = SseError::from(e);
                if matches!(error, SseError::Closed) {
                    self.inner
                        .state
                        .store(SENDER_STATE_CLOSED, Ordering::Release);
                }
                Err(error)
            }
        }
    }
    /// Finishes the stream; called with the send gate held. OPEN moves to
    /// FINISHING before the response is finished. Whatever the outcome, the
    /// state then becomes CLOSED. An already CLOSED sender returns `Ok(())`.
    async fn finish_locked(&self) -> SseResult<()> {
        match self.inner.state.load(Ordering::Acquire) {
            SENDER_STATE_CLOSED => return Ok(()),
            SENDER_STATE_OPEN => {
                self.inner
                    .state
                    .store(SENDER_STATE_FINISHING, Ordering::Release);
            }
            // FINISHING: a previous attempt did not complete; try again.
            _ => (),
        }
        match self.inner.response.finish().await {
            Ok(_) => {
                self.inner
                    .state
                    .store(SENDER_STATE_CLOSED, Ordering::Release);
                Ok(())
            }
            Err(e) => {
                self.inner
                    .state
                    .store(SENDER_STATE_CLOSED, Ordering::Release);
                Err(SseError::from(e))
            }
        }
    }
    /// Non-blocking finish; called with the send gate held. On `QueueFull` or
    /// `Busy` the sender stays FINISHING so the finish can be retried. Any
    /// other outcome leaves it CLOSED.
    fn try_finish_locked(&self) -> SseResult<()> {
        match self.inner.state.load(Ordering::Acquire) {
            SENDER_STATE_CLOSED => return Ok(()),
            SENDER_STATE_OPEN => {
                self.inner
                    .state
                    .store(SENDER_STATE_FINISHING, Ordering::Release);
            }
            // FINISHING: retrying a finish that previously hit a full queue.
            _ => (),
        }
        match self.inner.response.try_finish() {
            Ok(_) => {
                self.inner
                    .state
                    .store(SENDER_STATE_CLOSED, Ordering::Release);
                Ok(())
            }
            Err(e) => {
                let error = SseError::from(e);
                if !matches!(error, SseError::QueueFull | SseError::Busy) {
                    self.inner
                        .state
                        .store(SENDER_STATE_CLOSED, Ordering::Release);
                }
                Err(error)
            }
        }
    }
}
/// Point-in-time description of a registered connection, as returned by
/// `SseHub::snapshot`.
#[derive(Clone, Debug)]
pub struct SseConnectionInfo<K> {
    /// Key the connection was registered under.
    pub key: K,
    /// Connection id.
    pub id: SseConnectionId,
    /// Peer address of the socket.
    pub remote_addr: SocketAddr,
    /// Local address of the socket.
    pub local_addr: SocketAddr,
    /// `Last-Event-ID` sent by the client when connecting.
    pub last_event_id: Option<String>,
    /// Whether the sender reported itself closed when the snapshot was taken.
    pub closed: bool,
}
/// Hub record: the registration key plus the connection's sender.
#[derive(Clone)]
struct SseHubEntry<K, S: Socket> {
    key: K,
    sender: SseSender<S>,
}
/// Registry of SSE connections, indexed by id and by key. Clones share the
/// same underlying maps.
pub struct SseHub<K, S: Socket> {
    inner: Arc<SseHubInner<K, S>>,
}
/// Shared state of an [`SseHub`].
struct SseHubInner<K, S: Socket> {
    /// Primary index: connection id -> (key, sender).
    by_id: DashMap<SseConnectionId, SseHubEntry<K, S>, WyHasherBuilder>,
    /// Secondary index: key -> ids registered under it (a key may have several).
    by_key: DashMap<K, Vec<SseConnectionId>, WyHasherBuilder>,
    /// When true, closed connections are pruned after every multi-send.
    auto_remove_closed: bool,
}
impl<K, S: Socket> Clone for SseHub<K, S> {
    fn clone(&self) -> Self {
        SseHub {
            inner: self.inner.clone(),
        }
    }
}
impl<K, S> SseHub<K, S>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    S: Socket,
{
    /// Starts building a hub; closed connections are auto-removed by default.
    pub fn builder() -> SseHubBuilder<K, S> {
        SseHubBuilder {
            auto_remove_closed: true,
            marker: PhantomData,
        }
    }

    /// Registers `sender` under `key` and returns its connection id.
    ///
    /// # Errors
    /// `SseError::ConnectionIdExhausted` when the sender's id is 0 or a
    /// connection with the same id is already registered.
    pub fn register(&self, key: K, sender: SseSender<S>) -> SseResult<SseConnectionId> {
        let id = sender.id();
        if id.get() == 0 {
            return Err(SseError::ConnectionIdExhausted);
        }
        match self.inner.by_id.entry(id) {
            Entry::Occupied(_) => return Err(SseError::ConnectionIdExhausted),
            Entry::Vacant(slot) => {
                slot.insert(SseHubEntry {
                    key: key.clone(),
                    sender,
                });
            }
        }
        self.inner
            .by_key
            .entry(key)
            .or_insert_with(Vec::new)
            .push(id);
        Ok(id)
    }

    /// Removes `id` from both indices and returns its sender, or `None` if it
    /// was not registered. The stream itself is not finished.
    pub fn unregister(&self, id: SseConnectionId) -> Option<SseSender<S>> {
        let (_, entry) = self.inner.by_id.remove(&id)?;
        self.remove_id_from_key(&entry.key, id);
        Some(entry.sender)
    }

    /// Finishes the stream of `id` (waiting for its send gate). Unregisters it
    /// when the finish succeeded or the connection was already closed.
    ///
    /// # Errors
    /// `SseError::ConnectionNotFound` for an unknown id; otherwise the finish error.
    pub async fn close(&self, id: SseConnectionId) -> SseResult<()> {
        let sender = self.get_by_id(id).ok_or(SseError::ConnectionNotFound)?;
        let result = sender.finish().await;
        if result.is_ok() || matches!(result, Err(SseError::Closed)) {
            let _ = self.unregister(id);
        }
        result
    }

    /// Non-blocking variant of [`close`](Self::close); may fail with `Busy` or
    /// `QueueFull`, in which case the connection stays registered.
    pub fn try_close(&self, id: SseConnectionId) -> SseResult<()> {
        let sender = self.get_by_id(id).ok_or(SseError::ConnectionNotFound)?;
        let result = sender.try_finish();
        if result.is_ok() || matches!(result, Err(SseError::Closed)) {
            let _ = self.unregister(id);
        }
        result
    }

    /// Returns the senders registered under `key`. Ids removed concurrently
    /// from the id index are skipped.
    pub fn get(&self, key: &K) -> Vec<SseSender<S>> {
        let ids = self
            .inner
            .by_key
            .get(key)
            .map(|ids| ids.clone())
            .unwrap_or_default();
        ids.into_iter()
            .filter_map(|id| self.get_by_id(id))
            .collect()
    }

    /// Returns the sender registered under `id`, if any.
    pub fn get_by_id(&self, id: SseConnectionId) -> Option<SseSender<S>> {
        self.inner.by_id.get(&id).map(|entry| entry.sender.clone())
    }

    /// Describes every registered connection.
    pub fn snapshot(&self) -> Vec<SseConnectionInfo<K>> {
        self.inner
            .by_id
            .iter()
            .map(|entry| {
                let sender = &entry.value().sender;
                SseConnectionInfo {
                    key: entry.value().key.clone(),
                    id: *entry.key(),
                    remote_addr: sender.remote_addr(),
                    local_addr: sender.local_addr(),
                    last_event_id: sender.last_event_id().map(ToOwned::to_owned),
                    closed: sender.is_closed(),
                }
            })
            .collect()
    }

    /// Unregisters every connection whose sender reports closed and returns
    /// how many were found. Ids are collected first, so no map guard is held
    /// while unregistering.
    pub fn remove_closed(&self) -> usize {
        let ids = self
            .inner
            .by_id
            .iter()
            .filter_map(|entry| {
                if entry.value().sender.is_closed() {
                    Some(*entry.key())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        let len = ids.len();
        for id in ids {
            let _ = self.unregister(id);
        }
        len
    }

    /// Sends `event` to every connection registered under `key`.
    pub async fn send_to(&self, key: &K, event: SseEvent) -> SseSendReport {
        let senders = self.get(key);
        self.send_many(senders, event).await
    }

    /// Non-blocking variant of [`send_to`](Self::send_to).
    pub fn try_send_to(&self, key: &K, event: SseEvent) -> SseSendReport {
        let senders = self.get(key);
        self.try_send_many(senders, event)
    }

    /// Sends `event` to the connection `id`.
    pub async fn send_to_id(&self, id: SseConnectionId, event: SseEvent) -> SseSendReport {
        match self.get_by_id(id) {
            Some(sender) => self.send_many(vec![sender], event).await,
            None => SseSendReport::not_found(id),
        }
    }

    /// Non-blocking variant of [`send_to_id`](Self::send_to_id).
    pub fn try_send_to_id(&self, id: SseConnectionId, event: SseEvent) -> SseSendReport {
        match self.get_by_id(id) {
            Some(sender) => self.try_send_many(vec![sender], event),
            None => SseSendReport::not_found(id),
        }
    }

    /// Sends `event` to every registered connection.
    pub async fn broadcast(&self, event: SseEvent) -> SseSendReport {
        let senders = self
            .inner
            .by_id
            .iter()
            .map(|entry| entry.value().sender.clone())
            .collect::<Vec<_>>();
        self.send_many(senders, event).await
    }

    /// Non-blocking variant of [`broadcast`](Self::broadcast).
    pub fn try_broadcast(&self, event: SseEvent) -> SseSendReport {
        let senders = self
            .inner
            .by_id
            .iter()
            .map(|entry| entry.value().sender.clone())
            .collect::<Vec<_>>();
        self.try_send_many(senders, event)
    }

    /// Sends to each sender in turn and records the outcome. Afterwards it
    /// prunes closed connections if `auto_remove_closed` is set.
    async fn send_many(&self, senders: Vec<SseSender<S>>, event: SseEvent) -> SseSendReport {
        let mut report = SseSendReport::new(senders.len());
        for sender in senders {
            let id = sender.id();
            match sender.send(event.clone()).await {
                Ok(_) => report.sent += 1,
                Err(error) => report.record(id, &error),
            }
        }
        if self.inner.auto_remove_closed {
            let _ = self.remove_closed();
        }
        report
    }

    /// Non-blocking counterpart of `send_many`.
    fn try_send_many(&self, senders: Vec<SseSender<S>>, event: SseEvent) -> SseSendReport {
        let mut report = SseSendReport::new(senders.len());
        for sender in senders {
            let id = sender.id();
            match sender.try_send(event.clone()) {
                Ok(_) => report.sent += 1,
                Err(error) => report.record(id, &error),
            }
        }
        if self.inner.auto_remove_closed {
            let _ = self.remove_closed();
        }
        report
    }

    /// Removes `id` from the id list of `key`. The key entry itself is
    /// dropped once its list is empty.
    fn remove_id_from_key(&self, key: &K, id: SseConnectionId) {
        // The `get_mut` shard guard is released at the end of this `if let`,
        // before `remove_if` locks the same shard again.
        if let Some(mut ids) = self.inner.by_key.get_mut(key) {
            ids.retain(|item| item != &id);
        }
        // Check for emptiness and remove under one shard lock. Deciding first
        // and removing later races with a concurrent `register` for the same
        // key: it could push a fresh id in between, and that id would be
        // deleted from `by_key` while remaining in `by_id`.
        self.inner.by_key.remove_if(key, |_, ids| ids.is_empty());
    }
}
/// Builder for [`SseHub`].
pub struct SseHubBuilder<K, S: Socket> {
    /// Whether multi-sends prune closed connections afterwards.
    auto_remove_closed: bool,
    marker: PhantomData<(K, S)>,
}

impl<K, S> SseHubBuilder<K, S>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    S: Socket,
{
    /// Enables or disables pruning of closed connections after each multi-send.
    pub fn auto_remove_closed(self, auto_remove_closed: bool) -> Self {
        Self {
            auto_remove_closed,
            ..self
        }
    }

    /// Creates an empty hub.
    pub fn build(self) -> SseHub<K, S> {
        let inner = SseHubInner {
            by_id: DashMap::with_hasher(WyHasherBuilder::default()),
            by_key: DashMap::with_hasher(WyHasherBuilder::default()),
            auto_remove_closed: self.auto_remove_closed,
        };
        SseHub {
            inner: Arc::new(inner),
        }
    }
}
/// Category of a failed delivery, derived from the [`SseError`] it produced.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SseSendFailureKind {
    Closed,
    QueueFull,
    Busy,
    ConnectionNotFound,
    /// Any error not covered by another kind (I/O, config, id allocation, ...).
    Io,
    /// The event failed validation or exceeded the size limit.
    InvalidEvent,
}
/// One failed delivery: which connection, and why.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SseSendFailure {
    pub id: SseConnectionId,
    pub kind: SseSendFailureKind,
}
/// Summary of a send to one or more connections.
#[derive(Clone, Debug, Default)]
pub struct SseSendReport {
    /// Number of connections targeted.
    pub total: usize,
    /// Deliveries that succeeded.
    pub sent: usize,
    /// Failures counted per category (see `SseSendReport::record`).
    pub closed: usize,
    pub queue_full: usize,
    pub busy: usize,
    pub not_found: usize,
    /// Invalid-event and I/O-style failures.
    pub failed: usize,
    /// One entry per failed delivery.
    pub failures: Vec<SseSendFailure>,
}
impl SseSendReport {
    /// Creates an empty report for a send that targeted `total` connections.
    fn new(total: usize) -> Self {
        SseSendReport {
            total,
            ..SseSendReport::default()
        }
    }

    /// Report for a send addressed to an id that is not registered.
    ///
    /// The missing connection counts as one targeted connection (`total == 1`).
    /// That keeps `total == sent + closed + queue_full + busy + not_found + failed`,
    /// the same invariant that reports built by `send_many` satisfy through
    /// `record`.
    fn not_found(id: SseConnectionId) -> Self {
        let mut report = SseSendReport::new(1);
        report.record(id, &SseError::ConnectionNotFound);
        report
    }

    /// Records one failed delivery. It increments exactly one counter and
    /// appends a failure entry.
    fn record(&mut self, id: SseConnectionId, error: &SseError) {
        let kind = match error {
            SseError::Closed => {
                self.closed += 1;
                SseSendFailureKind::Closed
            }
            SseError::QueueFull => {
                self.queue_full += 1;
                SseSendFailureKind::QueueFull
            }
            SseError::Busy => {
                self.busy += 1;
                SseSendFailureKind::Busy
            }
            SseError::ConnectionNotFound => {
                self.not_found += 1;
                SseSendFailureKind::ConnectionNotFound
            }
            SseError::InvalidEvent(_) | SseError::EventTooLarge { .. } => {
                self.failed += 1;
                SseSendFailureKind::InvalidEvent
            }
            _ => {
                self.failed += 1;
                SseSendFailureKind::Io
            }
        };
        self.failures.push(SseSendFailure { id, kind });
    }
}
/// Returns `true` when the request's `Accept` header admits `text/event-stream`.
///
/// Each comma-separated media range is compared on its type/subtype only,
/// case-insensitively. `text/event-stream`, `text/*` and `*/*` are accepted.
/// A range with `q=0` means "not acceptable" (RFC 9110, Accept) and is
/// skipped. A missing header, or one that is not visible ASCII, yields `false`.
fn accepts_event_stream<S: Socket>(req: &HttpRequest<S>) -> bool {
    match req.headers().get(ACCEPT).and_then(|value| value.to_str().ok()) {
        Some(value) => value.split(',').any(media_range_accepts_event_stream),
        None => false,
    }
}

/// Checks one `Accept` media range such as `text/event-stream;q=0.5`.
fn media_range_accepts_event_stream(range: &str) -> bool {
    let mut parts = range.split(';');
    let media_type = parts.next().unwrap_or("").trim();
    // Exact type/subtype match; a plain prefix test would also accept
    // unrelated types such as `text/event-streamx`.
    let matches_type = media_type.eq_ignore_ascii_case("text/event-stream")
        || media_type.eq_ignore_ascii_case("text/*")
        || media_type == "*/*";
    matches_type && !parts.any(is_zero_quality)
}

/// Returns `true` for a media-range parameter of the form `q=0` (or `q=0.000`).
fn is_zero_quality(param: &str) -> bool {
    let mut pair = param.splitn(2, '=');
    let name = pair.next().unwrap_or("").trim();
    let value = pair.next().unwrap_or("").trim();
    name.eq_ignore_ascii_case("q") && value.parse::<f32>().map_or(false, |q| q <= 0.0)
}
/// Converts an [`SseError`] into a plain-text error response. The status
/// depends on the variant, and the body is the error's `Display` text.
fn sse_error_to_response(error: SseError) -> HttpResponse {
    let status = match &error {
        SseError::InvalidMethod(_) => StatusCode::METHOD_NOT_ALLOWED,
        SseError::InvalidEvent(_) | SseError::EventTooLarge { .. } => StatusCode::BAD_REQUEST,
        SseError::QueueFull | SseError::Busy => StatusCode::SERVICE_UNAVAILABLE,
        SseError::Closed | SseError::ConnectionNotFound => StatusCode::GONE,
        SseError::InvalidConfig(_) | SseError::ConnectionIdExhausted | SseError::Io(_) => {
            StatusCode::INTERNAL_SERVER_ERROR
        }
    };
    let body = error.to_string();
    sse_error_response(status, body)
}
/// Builds a non-streaming `text/plain` response with the given status and body.
fn sse_error_response(status: StatusCode, body: impl AsRef<str>) -> HttpResponse {
    let text = body.as_ref();
    let mut resp = HttpResponse::new(1);
    resp.status(status.as_u16());
    resp.insert_header(CONTENT_TYPE.as_str(), "text/plain; charset=utf-8");
    if let Some(payload) = resp.as_mut_body() {
        // The result of `init` is deliberately discarded.
        let _ = payload.init();
        payload.push(text.as_bytes());
    }
    resp
}
/// Rejects values that contain CR, LF or NUL; used for single-line fields
/// (`id`, `event`, `Last-Event-ID`).
fn validate_single_line_field(name: &str, value: &str) -> SseResult<()> {
    let has_forbidden = value.chars().any(|c| matches!(c, '\r' | '\n' | '\0'));
    if has_forbidden {
        Err(SseError::InvalidEvent(format!(
            "{} must not contain CR, LF or NUL",
            name
        )))
    } else {
        Ok(())
    }
}

/// Rejects values that contain NUL; used for multi-line text (`comment`, `data`).
fn validate_text_field(name: &str, value: &str) -> SseResult<()> {
    if !value.contains('\0') {
        return Ok(());
    }
    Err(SseError::InvalidEvent(format!(
        "{} must not contain NUL",
        name
    )))
}
/// Appends `value` to `buf` as one `prefix`-ed line per input line, each
/// terminated by LF. CRLF, lone CR and LF all count as line breaks. With the
/// bare comment prefix `":"`, a space separates the colon from non-empty text.
fn encode_multiline_field(buf: &mut Vec<u8>, prefix: &str, value: &str, add_space_for_empty: bool) {
    let is_comment_prefix = prefix == ":";
    let unified = value.replace("\r\n", "\n").replace('\r', "\n");
    unified.split('\n').for_each(|segment| {
        let has_text = !segment.is_empty();
        buf.extend_from_slice(prefix.as_bytes());
        if is_comment_prefix && has_text {
            buf.push(b' ');
        }
        if has_text || add_space_for_empty {
            buf.extend_from_slice(segment.as_bytes());
        }
        buf.push(b'\n');
    });
}
/// Reads the `Last-Event-ID` request header.
///
/// Returns `Ok(None)` when the header is absent or empty.
///
/// # Errors
/// `SseError::InvalidEvent` if the value is not a valid header string or
/// contains CR, LF or NUL.
fn read_last_event_id<S: Socket>(req: &HttpRequest<S>) -> SseResult<Option<String>> {
    let raw = match req.headers().get(HeaderName::from_static(LAST_EVENT_ID_HEADER)) {
        Some(raw) => raw,
        None => return Ok(None),
    };
    let text = raw
        .to_str()
        .map_err(|e| SseError::InvalidEvent(format!("invalid Last-Event-ID: {:?}", e)))?;
    if text.is_empty() {
        return Ok(None);
    }
    validate_single_line_field(LAST_EVENT_ID_HEADER, text)?;
    Ok(Some(text.to_owned()))
}
/// Allocates the next connection id from the global counter. Zero is skipped
/// when the counter wraps, because `SseHub::register` rejects id 0.
fn next_connection_id() -> SseConnectionId {
    loop {
        match SSE_CONNECTION_ID_ALLOCATOR.fetch_add(1, Ordering::Relaxed) {
            0 => continue,
            raw => return SseConnectionId(raw),
        }
    }
}
// Unit tests for SSE event encoding, config/response builders, `SseSender` queueing,
// `SseMiddleware` request handling and `SseHub` registration/broadcast/close.
// They run against an in-memory `TestSocket` stub instead of a real TCP connection.
#[cfg(test)]
mod tests {
use std::cell::UnsafeCell;
use std::future::Future;
use std::io::Result as IoResult;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, AtomicU32};
use std::sync::{Arc, Mutex as StdMutex};
use std::task::Waker;
use bytes::BytesMut;
use crossbeam_channel::unbounded;
use futures::executor::block_on;
use https::{HeaderMap, HeaderValue, Version};
use mio::Token;
use pi_async_rt::rt::serial::AsyncValue;
use tcp::{
utils::{Hibernate, Ready, SocketContext},
Socket, SocketEvent, SocketHandle, SocketImage,
};
use super::*;
use crate::gateway::GatewayContext;
use crate::middleware::{Middleware, MiddlewareResult};
use crate::request::HttpRequest;
use crate::response::HttpResponse;
use crate::utils::HttpRecvResult;
/// Stub `Socket` with no real I/O.
///
/// Status queries return fixed values, writes and closes succeed without doing
/// anything, and the methods the SSE tests never reach panic via `unimplemented!`.
struct TestSocket;
impl Socket for TestSocket {
fn is_closed(&self) -> bool {
false
}
fn is_flush(&self) -> bool {
true
}
fn set_flush(&self, _flush: bool) {}
fn get_handle(&self) -> SocketHandle<Self> {
unimplemented!("test socket does not expose handle from inner socket")
}
fn remove_handle(&mut self) -> Option<SocketHandle<Self>> {
None
}
fn get_local(&self) -> &std::net::SocketAddr {
unimplemented!("test socket local addr is read from SocketImage")
}
fn get_remote(&self) -> &std::net::SocketAddr {
unimplemented!("test socket remote addr is read from SocketImage")
}
fn get_token(&self) -> Option<&Token> {
None
}
fn get_uid(&self) -> Option<&usize> {
None
}
fn get_context(&self) -> Rc<UnsafeCell<SocketContext>> {
unimplemented!("unused by SSE tests")
}
fn set_timeout(&self, _timeout: usize, _event: SocketEvent) {}
fn unset_timeout(&self) {}
fn is_security(&self) -> bool {
false
}
fn read_ready(&mut self, _adjust: usize) -> Result<AsyncValue<usize>, usize> {
unimplemented!("unused by SSE tests")
}
fn is_wait_wakeup_read_ready(&self) -> bool {
false
}
fn wakeup_read_ready(&mut self) {}
fn get_read_buffer(&self) -> Rc<UnsafeCell<Option<BytesMut>>> {
unimplemented!("unused by SSE tests")
}
fn get_write_buffer(&mut self) -> Option<&mut BytesMut> {
None
}
// Accepts and discards every write.
fn write_ready<B>(&mut self, _buf: B) -> IoResult<()>
where
B: AsRef<[u8]> + 'static,
{
Ok(())
}
fn reregister_interest(&mut self, _ready: Ready) -> IoResult<()> {
Ok(())
}
fn is_hibernated(&self) -> bool {
false
}
// Hibernated tasks are dropped without being run.
fn push_hibernated_task<F>(&self, _task: F)
where
F: Future<Output = ()> + 'static,
{
}
fn run_hibernated_tasks(&self) {}
fn hibernate(&self, _handle: SocketHandle<Self>, _ready: Ready) -> Option<Hibernate<Self>> {
None
}
fn set_hibernate(&self, _hibernate: Hibernate<Self>) -> bool {
false
}
fn set_hibernate_wakers(&self, _waker: Waker) {}
fn wakeup(&mut self, _result: IoResult<()>) -> bool {
true
}
fn close(&mut self, _reason: IoResult<()>) -> IoResult<()> {
Ok(())
}
}
/// Builds a `SocketHandle` around a fresh `TestSocket`.
///
/// The handle uses local address 127.0.0.1:8080, remote address 127.0.0.1:10000
/// and token 1. The receiving ends of the close and timer channels are dropped
/// immediately.
fn test_handle() -> SocketHandle<TestSocket> {
let local = "127.0.0.1:8080".parse().unwrap();
let remote = "127.0.0.1:10000".parse().unwrap();
let (close_sender, _) = unbounded();
let (timer_sender, _) = unbounded();
let socket = Arc::new(UnsafeCell::new(TestSocket));
let image = SocketImage::new(
&socket,
local,
remote,
Token(1),
1,
false,
Arc::new(AtomicBool::new(false)),
close_sender,
timer_sender,
);
SocketHandle::new(image)
}
/// An `SseSender` together with the `HttpResponse` whose handler it writes into.
///
/// Tests read queued events back through `_resp.as_body()`. Keeping the response
/// in the fixture presumably also keeps its body channel alive for the sender's
/// lifetime; confirm against `HttpResponse`.
struct TestSseFixture {
_resp: HttpResponse,
sender: SseSender<TestSocket>,
}
/// Creates a sender backed by a response with `channel_size` slots.
///
/// The sender gets a fresh connection id, `DEFAULT_MAX_EVENT_BYTES` as its event
/// size limit and `"last-id"` as its last event id.
fn test_sender(channel_size: usize) -> TestSseFixture {
let resp = HttpResponse::new(channel_size);
let handler = resp.get_response_handler().unwrap();
let sender = SseSender::new(
next_connection_id(),
test_handle(),
handler,
DEFAULT_MAX_EVENT_BYTES,
Some("last-id".to_string()),
);
TestSseFixture {
_resp: resp,
sender,
}
}
/// Builds an HTTP/1.1 `GET http://127.0.0.1/sse` request with the given headers and
/// an empty body.
fn test_get_request(headers: HeaderMap) -> HttpRequest<TestSocket> {
HttpRequest::new(
test_handle(),
"GET",
"http://127.0.0.1/sse",
Version::HTTP_11,
headers,
&[],
)
.expect("test HTTP request must be created")
}
/// Middleware config used by most tests.
///
/// Sets an 8-slot channel, `heartbeat_interval_ms(0)` (presumably disabling
/// heartbeats; confirm in `SseConfig`) and no initial comment, so the response
/// body only contains what the test itself sends.
fn test_middleware_config() -> SseConfig {
SseConfig::builder()
.channel_size(8)
.heartbeat_interval_ms(0)
.send_initial_comment(false)
.build()
.expect("test middleware config must be valid")
}
/// Unwraps `MiddlewareResult::Finish`; panics on any other variant.
fn expect_finish(
result: MiddlewareResult<TestSocket>,
) -> (HttpRequest<TestSocket>, HttpResponse) {
match result {
MiddlewareResult::Finish(pair) => pair,
_ => panic!("SSE middleware test expected Finish"),
}
}
/// Unwraps `MiddlewareResult::Break`; panics on any other variant.
fn expect_break(result: MiddlewareResult<TestSocket>) -> HttpResponse {
match result {
MiddlewareResult::Break(resp) => resp,
_ => panic!("SSE middleware test expected Break"),
}
}
/// Serializes the whole response, status line included, into a UTF-8 string.
fn response_text(resp: HttpResponse) -> String {
String::from_utf8(Vec::<u8>::from(resp)).expect("test response must be UTF-8")
}
/// Blocks until the response body yields one chunk and returns it as a string.
///
/// Panics if the body yields anything other than a data chunk.
fn drain_event(resp: &HttpResponse) -> String {
let body = resp.as_body().expect("test response must have body");
match block_on(body.next()) {
HttpRecvResult::Ok(Some((_index, chunk))) => {
String::from_utf8(chunk).expect("test SSE event must be UTF-8")
}
_ => panic!("test response body must yield one SSE event"),
}
}
// Compile-time trait checks: these only type-check when `T` meets the bounds.
fn assert_send_sync<T: Send + Sync>() {}
fn assert_clone<T: Clone>() {}
/// All builder fields are emitted in a fixed order, and multi-line data is split
/// into one `data:` line per input line.
#[test]
fn sse_event_builder_and_encoding() {
let event = SseEvent::builder()
.comment("ready")
.id("42")
.event("notice")
.retry(3000)
.data("hello\nworld")
.build()
.unwrap();
let encoded = event.encode(1024).unwrap();
assert_eq!(
String::from_utf8(encoded).unwrap(),
": ready\nid: 42\nevent: notice\nretry: 3000\ndata: hello\ndata: world\n\n"
);
}
/// A newline in the event name is rejected at build time; an encoding larger than
/// the byte limit is rejected at encode time.
#[test]
fn sse_event_rejects_invalid_or_large_event() {
assert!(matches!(
SseEvent::builder().event("bad\nname").build(),
Err(SseError::InvalidEvent(_))
));
let large = SseEvent::data("0123456789");
assert!(matches!(
large.encode(4),
Err(SseError::EventTooLarge { .. })
));
}
/// A zero channel size or a zero max event size is an invalid config.
#[test]
fn sse_config_builder_rejects_invalid_values() {
assert!(matches!(
SseConfig::builder().channel_size(0).build(),
Err(SseError::InvalidConfig(_))
));
assert!(matches!(
SseConfig::builder().max_event_bytes(0).build(),
Err(SseError::InvalidConfig(_))
));
}
/// Building an SSE response for a non-GET request fails and reports the method.
#[test]
fn sse_response_builder_rejects_non_get_method() {
let req = HttpRequest::new(
test_handle(),
"POST",
"http://127.0.0.1/sse",
Version::HTTP_11,
HeaderMap::new(),
&[],
)
.expect("test HTTP request must be created");
assert!(matches!(
SseResponse::builder(&req).build(),
Err(SseError::InvalidMethod(method)) if method == "POST"
));
}
/// With a one-slot channel, the second non-blocking send reports `QueueFull`.
#[test]
fn sse_sender_try_send_reports_queue_full() {
let fixture = test_sender(1);
let sender = &fixture.sender;
sender.try_send_data("first").unwrap();
assert!(matches!(
sender.try_send_data("second"),
Err(SseError::QueueFull)
));
}
/// `finish` may be called twice without error, and any send after it fails with
/// `Closed`.
#[test]
fn sse_sender_send_and_finish_are_ordered() {
let fixture = test_sender(4);
let sender = &fixture.sender;
block_on(sender.send_data("first")).unwrap();
block_on(sender.finish()).unwrap();
block_on(sender.finish()).unwrap();
assert!(matches!(
block_on(sender.send_data("after-finish")),
Err(SseError::Closed)
));
}
/// Events enqueued from one thread come out of the body in call order, followed by
/// the finish marker (`Fin(None)`).
#[test]
fn sse_sender_preserves_same_thread_successful_enqueue_order() {
let fixture = test_sender(8);
let sender = &fixture.sender;
sender.try_send_data("first").unwrap();
sender
.try_send(SseEvent::named("notice", "second"))
.unwrap();
sender.try_finish().unwrap();
let body = fixture._resp.as_body().unwrap();
let first = match block_on(body.next()) {
HttpRecvResult::Ok(Some((_index, chunk))) => String::from_utf8(chunk).unwrap(),
_ => panic!("test response body must yield the first SSE event"),
};
let second = match block_on(body.next()) {
HttpRecvResult::Ok(Some((_index, chunk))) => String::from_utf8(chunk).unwrap(),
_ => panic!("test response body must yield the second SSE event"),
};
assert!(
first.contains("data: first"),
"first queued event must be the first user call, got: {}",
first
);
assert!(
second.contains("event: notice") && second.contains("data: second"),
"second queued event must be the second user call, got: {}",
second
);
assert!(matches!(block_on(body.next()), HttpRecvResult::Fin(None)));
}
/// Compile-time check that the sender, hub and middleware are `Send + Sync + Clone`.
#[test]
fn sse_sender_and_hub_are_send_sync_clone() {
assert_send_sync::<SseSender<TestSocket>>();
assert_clone::<SseSender<TestSocket>>();
assert_send_sync::<SseHub<String, TestSocket>>();
assert_clone::<SseHub<String, TestSocket>>();
assert_send_sync::<SseMiddleware<SseConnectionId, TestSocket>>();
assert_clone::<SseMiddleware<SseConnectionId, TestSocket>>();
}
/// The default middleware registers each connection under its own connection id,
/// and the response phase passes the stream response through.
#[test]
fn sse_middleware_registers_default_connection_id_key() {
let hub = SseHub::<SseConnectionId, TestSocket>::builder().build();
let middleware = SseMiddleware::builder(hub.clone())
.config(test_middleware_config())
.build()
.expect("default SSE middleware must build");
let mut context = GatewayContext::new();
let req = test_get_request(HeaderMap::new());
let (req, resp) = expect_finish(block_on(middleware.request(&mut context, req)));
assert!(resp.is_stream());
let snapshot = hub.snapshot();
assert_eq!(snapshot.len(), 1);
assert_eq!(snapshot[0].key, snapshot[0].id);
let resp = expect_break(block_on(middleware.response(&mut context, req, resp)));
assert!(resp.is_stream());
}
/// A custom acceptor maps two connections to the same key. `on_open` runs once per
/// connection and can send immediately. A broadcast to that key reaches both
/// connections.
#[test]
fn sse_middleware_custom_acceptor_on_open_and_multi_connection_key() {
let hub = SseHub::<String, TestSocket>::builder().build();
let saved_senders = Arc::new(StdMutex::new(Vec::<SseSender<TestSocket>>::new()));
let opened = Arc::new(AtomicU32::new(0));
let saved_for_open = saved_senders.clone();
let opened_for_open = opened.clone();
let middleware = SseMiddleware::with_acceptor(hub.clone(), |accept| {
assert_eq!(accept.request.url().path(), "/sse");
assert!(!accept.sender.is_closed());
Ok(SseAcceptDecision::accept("user-a".to_string()))
})
.config(test_middleware_config())
.on_open(move |open| {
assert_eq!(open.key, "user-a");
open.sender
.try_send_data(format!("connected-{}", open.id.get()))?;
saved_for_open.lock().unwrap().push(open.sender.clone());
opened_for_open.fetch_add(1, Ordering::Relaxed);
Ok(())
})
.build()
.expect("custom SSE middleware must build");
let mut context = GatewayContext::new();
let mut responses = Vec::new();
for _ in 0..2 {
let req = test_get_request(HeaderMap::new());
let (_req, resp) = expect_finish(block_on(middleware.request(&mut context, req)));
assert!(resp.is_stream());
assert!(drain_event(&resp).contains("data: connected-"));
responses.push(resp);
}
assert_eq!(opened.load(Ordering::Relaxed), 2);
assert_eq!(saved_senders.lock().unwrap().len(), 2);
assert_eq!(hub.get(&"user-a".to_string()).len(), 2);
let report = hub.try_send_to(&"user-a".to_string(), SseEvent::data("broadcast"));
assert_eq!(report.total, 2);
assert_eq!(report.sent, 2);
for resp in &responses {
assert!(drain_event(resp).contains("data: broadcast"));
}
}
/// A rejecting acceptor produces a plain (non-stream) response with the chosen
/// status and message, and leaves the hub empty.
#[test]
fn sse_middleware_acceptor_can_reject_before_stream_is_returned() {
let hub = SseHub::<String, TestSocket>::builder().build();
let middleware = SseMiddleware::with_acceptor(hub.clone(), |_accept| {
Ok(SseAcceptDecision::reject(
StatusCode::FORBIDDEN,
"sse denied",
))
})
.config(test_middleware_config())
.build()
.expect("rejecting SSE middleware must build");
let mut context = GatewayContext::new();
let req = test_get_request(HeaderMap::new());
let resp = expect_break(block_on(middleware.request(&mut context, req)));
assert!(!resp.is_stream());
assert!(hub.snapshot().is_empty());
let text = response_text(resp).to_ascii_lowercase();
assert!(text.contains("http/1.1 403"));
assert!(text.contains("sse denied"));
}
/// With `require_accept_header(true)`, a request without
/// `Accept: text/event-stream` gets a 406 response. A request with that header is
/// accepted and registered.
#[test]
fn sse_middleware_can_require_accept_event_stream_header() {
let hub = SseHub::<SseConnectionId, TestSocket>::builder().build();
let middleware = SseMiddleware::builder(hub.clone())
.config(test_middleware_config())
.require_accept_header(true)
.build()
.expect("accept-check SSE middleware must build");
let mut context = GatewayContext::new();
let req = test_get_request(HeaderMap::new());
let resp = expect_break(block_on(middleware.request(&mut context, req)));
assert!(!resp.is_stream());
assert!(response_text(resp).to_ascii_lowercase().contains("406"));
assert!(hub.snapshot().is_empty());
let mut headers = HeaderMap::new();
headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
let req = test_get_request(headers);
let (_req, resp) = expect_finish(block_on(middleware.request(&mut context, req)));
assert!(resp.is_stream());
assert_eq!(hub.snapshot().len(), 1);
}
/// With a 1 ms heartbeat interval and a heartbeat runtime, the first body chunk is
/// the empty comment `":\n\n"`.
#[test]
fn sse_middleware_heartbeat_runtime_sends_comment() {
// Presumably starts the global timer loop the heartbeat task depends on (confirm
// in pi_async_rt).
let _timer = pi_async_rt::rt::startup_global_time_loop(1);
let runtime = pi_async_rt::rt::AsyncRuntimeBuilder::default_multi_thread(
Some("sse-heartbeat-unit"),
None,
Some(1),
Some(1),
);
let hub = SseHub::<SseConnectionId, TestSocket>::builder().build();
let config = SseConfig::builder()
.channel_size(8)
.heartbeat_interval_ms(1)
.send_initial_comment(false)
.build()
.expect("heartbeat middleware config must be valid");
let middleware = SseMiddleware::builder(hub.clone())
.config(config)
.heartbeat_runtime(runtime.clone())
.build()
.expect("heartbeat SSE middleware must build");
let mut context = GatewayContext::new();
let req = test_get_request(HeaderMap::new());
let (_req, resp) = expect_finish(block_on(middleware.request(&mut context, req)));
assert!(resp.is_stream());
assert_eq!(hub.snapshot().len(), 1);
// Blocks until the heartbeat task pushes its first chunk.
let heartbeat = drain_event(&resp);
assert_eq!(heartbeat, ":\n\n");
let id = hub.snapshot()[0].id;
hub.try_close(id)
.expect("heartbeat test connection must close");
let _ = runtime.close();
}
/// If `on_open` returns an error, the connection is unregistered and the client
/// gets a 400 response containing the error text.
#[test]
fn sse_middleware_on_open_error_unregisters_sender() {
let hub = SseHub::<String, TestSocket>::builder().build();
let middleware = SseMiddleware::with_acceptor(hub.clone(), |_accept| {
Ok(SseAcceptDecision::accept("user-a".to_string()))
})
.config(test_middleware_config())
.on_open(|_open| Err(SseError::InvalidEvent("open rejected".to_string())))
.build()
.expect("on-open-failing SSE middleware must build");
let mut context = GatewayContext::new();
let req = test_get_request(HeaderMap::new());
let resp = expect_break(block_on(middleware.request(&mut context, req)));
assert!(!resp.is_stream());
assert!(hub.snapshot().is_empty());
let text = response_text(resp).to_ascii_lowercase();
assert!(text.contains("400"));
assert!(text.contains("open rejected"));
}
/// `register` / `get` / `get_by_id` / `unregister` round trip. Unregistering
/// returns the sender without closing it.
#[test]
fn sse_hub_register_get_unregister() {
let hub = SseHub::<String, TestSocket>::builder().build();
let fixture = test_sender(4);
let sender = &fixture.sender;
let id = hub.register("user-a".to_string(), sender.clone()).unwrap();
assert_eq!(hub.get(&"user-a".to_string()).len(), 1);
assert!(hub.get_by_id(id).is_some());
let unregistered = hub.unregister(id).unwrap();
assert_eq!(unregistered.id(), id);
assert!(hub.get_by_id(id).is_none());
assert!(!unregistered.is_closed());
}
/// Registering the same sender (same connection id) under a second key fails with
/// `ConnectionIdExhausted` and leaves no entry under the second key.
#[test]
fn sse_hub_register_rejects_duplicate_sender_without_key_leak() {
let hub = SseHub::<String, TestSocket>::builder().build();
let fixture = test_sender(4);
let sender = &fixture.sender;
let id = hub.register("user-a".to_string(), sender.clone()).unwrap();
assert!(matches!(
hub.register("user-b".to_string(), sender.clone()),
Err(SseError::ConnectionIdExhausted)
));
assert_eq!(hub.get(&"user-a".to_string()).len(), 1);
assert_eq!(hub.get(&"user-b".to_string()).len(), 0);
assert_eq!(hub.unregister(id).unwrap().id(), id);
}
/// `try_send_to_id` reports sent, queue-full and not-found outcomes in its report.
#[test]
fn sse_hub_try_send_to_id_reports_result() {
let hub = SseHub::<String, TestSocket>::builder().build();
let fixture = test_sender(1);
let sender = &fixture.sender;
let id = hub.register("user-a".to_string(), sender.clone()).unwrap();
let report = hub.try_send_to_id(id, SseEvent::data("first"));
assert_eq!(report.total, 1);
assert_eq!(report.sent, 1);
let report = hub.try_send_to_id(id, SseEvent::data("second"));
assert_eq!(report.queue_full, 1);
// `u32::MAX` is assumed never to have been allocated during the test run.
let report = hub.try_send_to_id(SseConnectionId(u32::MAX), SseEvent::data("missing"));
assert_eq!(report.not_found, 1);
}
/// The async `close` unregisters the connection and closes its sender.
#[test]
fn sse_hub_close_removes_and_finishes() {
let hub = SseHub::<String, TestSocket>::builder().build();
let fixture = test_sender(4);
let sender = &fixture.sender;
let id = hub.register("user-a".to_string(), sender.clone()).unwrap();
block_on(hub.close(id)).unwrap();
assert!(hub.get_by_id(id).is_none());
assert!(sender.is_closed());
}
/// Behaviour of `try_close` when the queue is full:
///
/// * It returns `QueueFull` and keeps the registration.
/// * Further sends are still refused.
/// * After the queued event is drained, a retry completes the close.
#[test]
fn sse_hub_try_close_keeps_registration_when_queue_full() {
let hub = SseHub::<String, TestSocket>::builder().build();
let fixture = test_sender(1);
let sender = &fixture.sender;
let id = hub.register("user-a".to_string(), sender.clone()).unwrap();
sender.try_send_data("first").unwrap();
assert!(matches!(hub.try_close(id), Err(SseError::QueueFull)));
assert!(hub.get_by_id(id).is_some());
assert!(matches!(
sender.try_send_data("after-close-started"),
Err(SseError::Closed)
));
// Drain the single queued event to free the slot for the finish marker.
let body = fixture._resp.as_body().unwrap();
match block_on(body.next()) {
HttpRecvResult::Ok(Some((_index, chunk))) => {
assert!(String::from_utf8(chunk).unwrap().contains("data: first"));
}
_ => panic!("test response body must yield the queued SSE event"),
}
hub.try_close(id).unwrap();
assert!(hub.get_by_id(id).is_none());
assert!(sender.is_closed());
}
}