use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use serde::Serialize;
use serde_json::Value;
use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::time::Duration;
/// Pinned, boxed stream of body chunks used by [`ReplyData::Stream`].
///
/// Each item is either the next chunk of bytes or an I/O error. The stream
/// must be `Send + Sync`.
pub type BodyStream = Pin<Box<dyn Stream<Item = Result<Bytes, io::Error>> + Send + Sync>>;

/// The body (and, via [`ReplyData::Rich`], the status and headers) of a reply.
///
/// `Debug` and `PartialEq` are implemented by hand because `Stream` and
/// `Upgrade` hold trait objects: those variants debug-print their contents as
/// `"..."` and never compare equal.
pub enum ReplyData {
    /// A JSON value.
    Json(Value),
    /// Raw bytes together with their content type.
    Bytes {
        content_type: Cow<'static, str>,
        data: Vec<u8>,
    },
    /// No body.
    Empty,
    /// A payload plus an optional status and headers; see [`ReplySpec`].
    Rich(Box<ReplySpec>),
    /// A streaming body.
    Stream(BodyStream),
    /// An opaque, type-erased value. `add_header` and `set_status` leave it
    /// untouched.
    /// NOTE(review): presumably a protocol-upgrade handler consumed by the
    /// server layer — confirm.
    Upgrade(Box<dyn Any + Send>),
}
impl fmt::Debug for ReplyData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ReplyData::Json(j) => f.debug_tuple("Json").field(j).finish(),
ReplyData::Bytes { content_type, data } => f
.debug_struct("Bytes")
.field("content_type", content_type)
.field("data", data)
.finish(),
ReplyData::Empty => write!(f, "Empty"),
ReplyData::Rich(r) => f.debug_tuple("Rich").field(r).finish(),
ReplyData::Stream(_) => f.debug_tuple("Stream").field(&"...").finish(),
ReplyData::Upgrade(_) => f.debug_tuple("Upgrade").field(&"...").finish(),
}
}
}
impl PartialEq for ReplyData {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(ReplyData::Json(l), ReplyData::Json(r)) => l == r,
(
ReplyData::Bytes {
content_type: lc,
data: ld,
},
ReplyData::Bytes {
content_type: rc,
data: rd,
},
) => lc == rc && ld == rd,
(ReplyData::Empty, ReplyData::Empty) => true,
(ReplyData::Rich(l), ReplyData::Rich(r)) => l == r,
_ => false, }
}
}
pub type Reply = Result<ReplyData, WebError>;
impl ReplyData {
pub fn add_header(&mut self, name: impl Into<String>, value: impl Into<String>) {
if matches!(self, ReplyData::Upgrade(_)) {
return;
}
let name = name.into();
let value = value.into();
if let ReplyData::Rich(spec) = self {
spec.headers.insert(name, value);
return;
}
let payload = std::mem::replace(self, ReplyData::Empty);
let mut headers = HashMap::new();
headers.insert(name, value);
*self = ReplyData::Rich(Box::new(ReplySpec {
payload,
status: None,
headers,
}));
}
pub fn set_status(&mut self, status: http::StatusCode) {
if matches!(self, ReplyData::Upgrade(_)) {
return;
}
if let ReplyData::Rich(spec) = self {
spec.status = Some(status);
return;
}
let payload = std::mem::replace(self, ReplyData::Empty);
*self = ReplyData::Rich(Box::new(ReplySpec {
payload,
status: Some(status),
headers: HashMap::new(),
}));
}
}
pub fn json<T: Serialize>(val: T) -> ReplyData {
ReplyData::Json(serde_json::to_value(val).expect("serialize"))
}
/// Fallible counterpart of [`json`]: serializes `val` into a
/// [`ReplyData::Json`] reply, or returns the `serde_json` error.
pub fn try_json<T: Serialize>(val: T) -> Result<ReplyData, serde_json::Error> {
    serde_json::to_value(val).map(ReplyData::Json)
}
/// Builds a [`ReplyData::Bytes`] reply from a content type and raw bytes.
pub fn bytes(content_type: impl Into<Cow<'static, str>>, data: impl Into<Vec<u8>>) -> ReplyData {
    let content_type: Cow<'static, str> = content_type.into();
    let data: Vec<u8> = data.into();
    ReplyData::Bytes { content_type, data }
}
/// A reply with no body: [`ReplyData::Empty`].
pub fn empty() -> ReplyData {
    ReplyData::Empty
}
/// Starts a [`ReplySpec`] builder: empty payload, no status, no headers.
/// Finish it with `.done()`.
pub fn build_reply() -> ReplySpec {
    ReplySpec::default()
}
pub fn stream<S>(s: S) -> ReplyData
where
S: Stream<Item = Result<Bytes, io::Error>> + Send + Sync + 'static,
{
ReplyData::Stream(Box::pin(s))
}
/// One Server-Sent Events frame.
///
/// Start from [`SseEvent::data`] or [`SseEvent::comment`], optionally chain
/// [`SseEvent::event`], [`SseEvent::id`] and [`SseEvent::retry`], then encode
/// it with [`SseEvent::to_bytes`].
#[derive(Default, Debug, Clone)]
pub struct SseEvent {
    event: Option<String>,
    id: Option<String>,
    retry: Option<Duration>,
    data: Option<String>,
    comment: Option<String>,
}

impl SseEvent {
    /// Creates a frame whose payload is `data`.
    pub fn data(data: impl Into<String>) -> Self {
        Self {
            data: Some(data.into()),
            ..Self::default()
        }
    }

    /// Creates a comment-only frame (for example, a keep-alive heartbeat).
    pub fn comment(text: impl Into<String>) -> Self {
        Self {
            comment: Some(text.into()),
            ..Self::default()
        }
    }

    /// Sets the `event:` field.
    pub fn event(mut self, name: impl Into<String>) -> Self {
        self.event = Some(name.into());
        self
    }

    /// Sets the `id:` field.
    pub fn id(mut self, id: impl Into<String>) -> Self {
        self.id = Some(id.into());
        self
    }

    /// Sets the `retry:` field. It is encoded in whole milliseconds and
    /// saturates at `u64::MAX`.
    pub fn retry(mut self, d: Duration) -> Self {
        self.retry = Some(d);
        self
    }

    /// Encodes the frame in SSE wire format, terminated by a blank line.
    ///
    /// Fields are written in this order: `event`, `id`, `retry`, comment
    /// lines, then `data` lines. `event` and `id` are flattened to a single
    /// line (CR and LF become spaces).
    ///
    /// Comment and data text is split at every SSE line terminator (`\r\n`,
    /// `\r` or `\n`). Each piece is written as its own `:` / `data:` line.
    /// Splitting on `\n` alone would leave a bare `\r` inside a field. A
    /// client treats that `\r` as end-of-line and parses the rest as a new
    /// field: data `"x\revent: spoofed"` would inject an `event:` field.
    pub fn to_bytes(&self) -> Bytes {
        let mut out = String::new();
        if let Some(event) = &self.event {
            out.push_str("event: ");
            out.push_str(&single_line(event));
            out.push('\n');
        }
        if let Some(id) = &self.id {
            out.push_str("id: ");
            out.push_str(&single_line(id));
            out.push('\n');
        }
        if let Some(retry) = self.retry {
            let ms = u64::try_from(retry.as_millis()).unwrap_or(u64::MAX);
            out.push_str("retry: ");
            out.push_str(&ms.to_string());
            out.push('\n');
        }
        if let Some(comment) = &self.comment {
            for line in Self::wire_lines(comment) {
                out.push(':');
                if !line.is_empty() {
                    out.push(' ');
                    out.push_str(line);
                }
                out.push('\n');
            }
        }
        if let Some(data) = &self.data {
            for line in Self::wire_lines(data) {
                out.push_str("data: ");
                out.push_str(line);
                out.push('\n');
            }
        }
        // A blank line ends the frame.
        out.push('\n');
        Bytes::from(out)
    }

    /// Splits `text` at every SSE line terminator (`\r\n`, `\r`, `\n`).
    ///
    /// Empty segments are kept, so blank lines and a trailing terminator
    /// survive the round trip. `\r\n` is split first so it counts as one
    /// terminator, not two.
    fn wire_lines(text: &str) -> impl Iterator<Item = &str> {
        text.split("\r\n")
            .flat_map(|chunk| chunk.split(['\r', '\n']))
    }
}
pub fn sse<S>(events: S) -> ReplyData
where
S: Stream<Item = SseEvent> + Send + Sync + 'static,
{
let byte_stream = events.map(|ev| Ok::<Bytes, io::Error>(ev.to_bytes()));
let mut headers = HashMap::new();
headers.insert("content-type".to_string(), "text/event-stream".to_string());
headers.insert("cache-control".to_string(), "no-cache".to_string());
ReplyData::Rich(Box::new(ReplySpec {
payload: ReplyData::Stream(Box::pin(byte_stream)),
status: None,
headers,
}))
}
/// Flattens `s` onto one line by turning every `\n` and `\r` into a space.
fn single_line(s: &str) -> String {
    s.chars()
        .map(|c| match c {
            '\n' | '\r' => ' ',
            other => other,
        })
        .collect()
}
/// A payload together with an explicit status and headers.
///
/// Usually built with [`build_reply`] or [`ReplySpec::new`], configured with
/// the chaining methods, and finished with [`ReplySpec::done`], which wraps it
/// in [`ReplyData::Rich`].
#[derive(Debug, PartialEq)]
pub struct ReplySpec {
    /// The body of the reply.
    pub payload: ReplyData,
    /// Explicit status code; `None` means none was set.
    pub status: Option<http::StatusCode>,
    /// Header name to value. Keys are compared as exact strings, so
    /// differently-cased spellings of one header are separate entries.
    pub headers: HashMap<String, String>,
}

impl Default for ReplySpec {
    fn default() -> Self {
        Self::new()
    }
}

impl ReplySpec {
    /// Creates a spec with an [`ReplyData::Empty`] payload, no status and no
    /// headers.
    #[must_use]
    pub fn new() -> Self {
        Self {
            payload: ReplyData::Empty,
            status: None,
            headers: HashMap::new(),
        }
    }

    /// Replaces the payload.
    #[must_use = "builder methods return the updated spec; finish it with `.done()`"]
    pub fn body(mut self, data: ReplyData) -> Self {
        self.payload = data;
        self
    }

    /// Sets the status code, replacing any previous one.
    #[must_use = "builder methods return the updated spec; finish it with `.done()`"]
    pub fn status(mut self, s: http::StatusCode) -> Self {
        self.status = Some(s);
        self
    }

    /// Inserts a header, overwriting any value already stored under the same
    /// key.
    #[must_use = "builder methods return the updated spec; finish it with `.done()`"]
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Finishes the builder by boxing it into [`ReplyData::Rich`].
    #[must_use]
    pub fn done(self) -> ReplyData {
        ReplyData::Rich(Box::new(self))
    }
}
/// Builds a `Reply` (a `Result<ReplyData, WebError>`) with JSON serialization
/// errors mapped to `WebError::Internal` instead of panicking.
///
/// Accepted forms:
/// - `reply!()` — `Ok` with an empty body.
/// - `reply!(expr)` — `expr` serialized to a JSON body.
/// - `reply!(status = S, expr)` — JSON body with an explicit status.
/// - `reply!(status = S, headers = { "Name": value, ... }, expr)` — JSON body,
///   status and headers; each header value is converted with `to_string()`.
/// - `reply!(stream: s)` — a streaming byte body.
/// - `reply!(sse: s)` — a Server-Sent Events body.
#[macro_export]
macro_rules! reply {
    (status = $status:expr, headers = { $($k:literal : $v:expr),* $(,)? }, $body:expr) => {
        match $crate::reply::try_json($body) {
            ::std::result::Result::Ok(__rd) => {
                let __spec = $crate::reply::build_reply()
                    .status($status)
                    .body(__rd);
                // Shadow rather than reassign a `mut` binding, so an empty
                // header list does not trigger `unused_mut`. The value is
                // passed as an owned `String` to avoid a second copy.
                $(
                    let __spec = __spec.header($k, $v.to_string());
                )*
                ::std::result::Result::Ok(__spec.done())
            }
            ::std::result::Result::Err(e) => ::std::result::Result::Err(
                $crate::reply::WebError::Internal(::std::format!(
                    "serialize response: {}",
                    e
                )),
            ),
        }
    };
    (status = $status:expr, $body:expr) => {
        match $crate::reply::try_json($body) {
            ::std::result::Result::Ok(__rd) => ::std::result::Result::Ok(
                $crate::reply::build_reply()
                    .status($status)
                    .body(__rd)
                    .done(),
            ),
            ::std::result::Result::Err(e) => ::std::result::Result::Err(
                $crate::reply::WebError::Internal(::std::format!(
                    "serialize response: {}",
                    e
                )),
            ),
        }
    };
    (stream: $s:expr) => {
        ::std::result::Result::Ok($crate::reply::stream($s))
    };
    (sse: $s:expr) => {
        ::std::result::Result::Ok($crate::reply::sse($s))
    };
    () => {
        ::std::result::Result::Ok($crate::reply::empty())
    };
    ($expr:expr) => {
        match $crate::reply::try_json($expr) {
            ::std::result::Result::Ok(__rd) => ::std::result::Result::Ok(__rd),
            ::std::result::Result::Err(e) => ::std::result::Result::Err(
                $crate::reply::WebError::Internal(::std::format!(
                    "serialize response: {}",
                    e
                )),
            ),
        }
    };
}
/// Structured error information: status, title, optional detail and extra
/// members. Carried by `WebError::Problem`.
#[derive(Debug, Clone)]
pub struct ProblemDetails {
    /// Status code associated with the problem.
    pub status: http::StatusCode,
    /// Short summary; always part of the `Display` output.
    pub title: String,
    /// Optional longer explanation; `Display` appends it as `title: detail`.
    pub detail: Option<String>,
    /// Additional named members.
    pub extra: Box<serde_json::Map<String, serde_json::Value>>,
}

impl ProblemDetails {
    /// Creates a problem with the given status and title, no detail and no
    /// extra members.
    pub fn new(status: http::StatusCode, title: impl Into<String>) -> Self {
        let title: String = title.into();
        Self {
            status,
            title,
            detail: None,
            extra: Box::default(),
        }
    }

    /// Sets (or replaces) the detail text.
    pub fn detail(self, detail: impl Into<String>) -> Self {
        Self {
            detail: Some(detail.into()),
            ..self
        }
    }

    /// Adds an extra member, overwriting any previous value under `key`.
    pub fn extra(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        let key: String = key.into();
        let value: serde_json::Value = value.into();
        self.extra.insert(key, value);
        self
    }
}

/// Renders `title`, or `title: detail` when a detail is present.
impl fmt::Display for ProblemDetails {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.title)?;
        if let Some(detail) = &self.detail {
            write!(f, ": {detail}")?;
        }
        Ok(())
    }
}
/// Error carried on the `Err` side of [`Reply`].
///
/// Each variant's `Display` text comes from its `#[error(...)]` attribute.
#[derive(thiserror::Error, Debug)]
pub enum WebError {
/// Displayed as "Not found"; carries no data.
#[error("Not found")]
NotFound,
/// The payload lists the allowed methods. It is rendered with `{:?}` in the
/// message, e.g. `["GET", "POST"]`.
#[error("Method not allowed; allowed methods: {0:?}")]
MethodNotAllowed(Vec<&'static str>),
/// Carries a human-readable reason that is appended to the message.
#[error("Bad request: {0}")]
BadRequest(String),
/// Displayed as "Payload too large"; carries no data.
#[error("Payload too large")]
PayloadTooLarge,
/// The optional `Duration` is not part of the message.
/// NOTE(review): presumably a retry-after hint for whoever renders the
/// response — confirm.
#[error("Too many requests")]
TooManyRequests(Option<Duration>),
/// Displayed as "Request timeout"; carries no data.
#[error("Request timeout")]
Timeout,
/// The optional `Duration` is not part of the message.
/// NOTE(review): presumably a retry-after hint — confirm.
#[error("Server busy")]
Busy(Option<Duration>),
/// Displayed as "Unauthorized"; carries no data.
#[error("Unauthorized")]
Unauthorized,
/// Displayed as "Forbidden"; carries no data.
#[error("Forbidden")]
Forbidden,
/// Carries a description appended to the message. The `reply!` macro uses
/// this variant when a response body fails to serialize.
#[error("Internal error: {0}")]
Internal(String),
/// Display delegates to `ProblemDetails`: `title`, or `title: detail`.
#[error("{0}")]
Problem(ProblemDetails),
}
#[cfg(test)]
mod tests {
    //! Unit tests for reply construction, the `reply!` macro and SSE encoding.
    //!
    //! `matches!` only produces a `bool`; every variant check below is wrapped
    //! in `assert!` so that it actually fails when the variant is wrong.
    use super::*;
    use futures_util::{StreamExt, stream};
    use http::StatusCode;
    use serde_json::json;

    #[test]
    fn test_macro_returns_result() {
        let success_reply: Reply = reply!(json!({"status": "ok"}));
        assert!(success_reply.is_ok());
        let empty_reply: Reply = reply!();
        assert!(empty_reply.is_ok());
    }

    #[test]
    fn test_constructors() {
        assert!(matches!(json(json!({"ok": true})), ReplyData::Json(_)));
        assert!(matches!(
            bytes("application/octet-stream", vec![0, 1]),
            ReplyData::Bytes { .. }
        ));
        assert!(matches!(empty(), ReplyData::Empty));
    }

    #[tokio::test]
    async fn test_stream_constructor() {
        let test_stream = stream::once(async { Ok::<_, io::Error>(Bytes::from("test")) });
        let reply_data = stream(test_stream);
        if let ReplyData::Stream(mut s) = reply_data {
            assert_eq!(s.next().await.unwrap().unwrap(), "test");
        } else {
            panic!("Expected a stream reply");
        }
    }

    #[test]
    fn test_reply_macro_variants() {
        let data_res: Reply = reply!(json!({"a":1}));
        assert!(matches!(data_res.unwrap(), ReplyData::Json(_)));

        #[derive(Serialize)]
        struct TestStruct {
            val: i32,
        }
        let data_res: Reply = reply!(TestStruct { val: 10 });
        assert!(matches!(data_res.unwrap(), ReplyData::Json(_)));

        let data_res: Reply = reply!();
        assert!(matches!(data_res.unwrap(), ReplyData::Empty));

        let data_res: Reply = reply!(status = StatusCode::CREATED, json!({"id": 1}));
        match data_res.unwrap() {
            ReplyData::Rich(spec) => {
                assert_eq!(spec.status, Some(StatusCode::CREATED));
            }
            _ => panic!("Expected Rich reply"),
        }

        let data_res: Reply =
            reply!(status = StatusCode::CREATED, headers = {"X-Test": "value"}, json!({"id": 1}));
        match data_res.unwrap() {
            ReplyData::Rich(spec) => {
                assert_eq!(spec.headers.get("X-Test"), Some(&"value".to_string()));
            }
            _ => panic!("Expected Rich reply"),
        }
    }

    #[test]
    fn add_header_lifts_non_rich_into_rich_preserving_payload() {
        let mut r = ReplyData::Json(json!({"x": 1}));
        r.add_header("X-Trace-Id", "abc");
        match &r {
            ReplyData::Rich(spec) => {
                assert!(matches!(spec.payload, ReplyData::Json(_)));
                assert_eq!(spec.headers.get("X-Trace-Id"), Some(&"abc".to_string()));
                assert_eq!(spec.status, None);
            }
            _ => panic!("expected Rich, got {r:?}"),
        }

        let mut r = ReplyData::Empty;
        r.add_header("X-K", "v");
        assert!(matches!(&r, ReplyData::Rich(s) if matches!(s.payload, ReplyData::Empty)));

        let mut r = ReplyData::Rich(Box::new(ReplySpec {
            payload: ReplyData::Json(json!({"y": 2})),
            status: Some(StatusCode::CREATED),
            headers: HashMap::from([("Existing".into(), "1".into())]),
        }));
        r.add_header("New", "2");
        match &r {
            ReplyData::Rich(spec) => {
                assert_eq!(spec.status, Some(StatusCode::CREATED));
                assert_eq!(spec.headers.get("Existing"), Some(&"1".to_string()));
                assert_eq!(spec.headers.get("New"), Some(&"2".to_string()));
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn set_status_lifts_or_mutates() {
        let mut r = ReplyData::Json(json!({"x": 1}));
        r.set_status(StatusCode::ACCEPTED);
        match &r {
            ReplyData::Rich(spec) => {
                assert_eq!(spec.status, Some(StatusCode::ACCEPTED));
                assert!(matches!(spec.payload, ReplyData::Json(_)));
            }
            _ => panic!("expected Rich"),
        }

        let mut r = ReplyData::Rich(Box::new(ReplySpec {
            payload: ReplyData::Empty,
            status: None,
            headers: HashMap::new(),
        }));
        r.set_status(StatusCode::NOT_FOUND);
        // Fail loudly if the reply stopped being Rich, instead of skipping.
        match &r {
            ReplyData::Rich(spec) => assert_eq!(spec.status, Some(StatusCode::NOT_FOUND)),
            _ => panic!("expected Rich, got {r:?}"),
        }
    }

    #[test]
    fn add_header_and_set_status_are_a_noop_on_upgrade() {
        let task: Box<dyn std::any::Any + Send> = Box::new(42i32);
        let mut r = ReplyData::Upgrade(task);
        r.add_header("X-K", "v");
        r.set_status(StatusCode::CREATED);
        assert!(matches!(r, ReplyData::Upgrade(_)));
    }

    /// Views encoded SSE bytes as text.
    fn s(b: &Bytes) -> &str {
        std::str::from_utf8(b).expect("SSE encoding is always valid UTF-8")
    }

    #[test]
    fn sse_data_only_event_emits_one_data_line_and_blank_separator() {
        let bytes = SseEvent::data("hello").to_bytes();
        assert_eq!(s(&bytes), "data: hello\n\n");
    }

    #[test]
    fn sse_multi_line_data_emits_one_data_line_per_source_line() {
        let bytes = SseEvent::data("a\nb\nc").to_bytes();
        assert_eq!(s(&bytes), "data: a\ndata: b\ndata: c\n\n");
    }

    #[test]
    fn sse_event_id_retry_data_all_in_one_frame() {
        let bytes = SseEvent::data("payload")
            .event("update")
            .id("17")
            .retry(Duration::from_millis(2500))
            .to_bytes();
        assert_eq!(
            s(&bytes),
            "event: update\nid: 17\nretry: 2500\ndata: payload\n\n",
        );
    }

    #[test]
    fn sse_event_and_id_strip_embedded_newlines() {
        let bytes = SseEvent::data("ok").event("a\nb").id("x\ry").to_bytes();
        assert_eq!(s(&bytes), "event: a b\nid: x y\ndata: ok\n\n");
    }

    #[test]
    fn sse_comment_only_event_is_a_heartbeat() {
        let single = SseEvent::comment("keep-alive").to_bytes();
        assert_eq!(s(&single), ": keep-alive\n\n");
        let multi = SseEvent::comment("line1\n\nline3").to_bytes();
        assert_eq!(s(&multi), ": line1\n:\n: line3\n\n");
    }

    #[test]
    fn sse_empty_data_string_still_emits_a_data_field() {
        let bytes = SseEvent::data("").to_bytes();
        assert_eq!(s(&bytes), "data: \n\n");
    }

    #[tokio::test]
    async fn sse_stream_wraps_in_rich_with_correct_headers_and_streams_each_event() {
        let events = stream::iter(vec![
            SseEvent::data("first").id("1"),
            SseEvent::data("second").event("update"),
        ]);
        let reply = sse(events);
        let ReplyData::Rich(spec) = reply else {
            panic!("sse() produces a Rich reply for the headers");
        };
        assert_eq!(
            spec.headers.get("content-type"),
            Some(&"text/event-stream".to_string())
        );
        assert_eq!(
            spec.headers.get("cache-control"),
            Some(&"no-cache".to_string())
        );
        let ReplyData::Stream(byte_stream) = spec.payload else {
            panic!("sse() payload is a Stream of bytes");
        };
        let chunks: Vec<Bytes> = byte_stream
            .map(|r| r.expect("infallible map"))
            .collect()
            .await;
        let joined: String = chunks
            .iter()
            .map(|b| std::str::from_utf8(b).unwrap())
            .collect();
        assert_eq!(
            joined,
            "id: 1\ndata: first\n\nevent: update\ndata: second\n\n",
        );
    }

    #[test]
    fn test_builder_pattern() {
        let data = build_reply()
            .status(StatusCode::ACCEPTED)
            .header("X-Custom", "value123")
            .body(bytes("application/zip", vec![1, 2, 3]))
            .done();
        match data {
            ReplyData::Rich(spec) => {
                assert_eq!(spec.status, Some(StatusCode::ACCEPTED));
                assert_eq!(spec.headers.get("X-Custom"), Some(&"value123".to_string()));
                assert_eq!(
                    spec.payload,
                    ReplyData::Bytes {
                        content_type: Cow::Borrowed("application/zip"),
                        data: vec![1, 2, 3],
                    }
                );
            }
            _ => panic!("Expected Rich reply"),
        }
    }
}