use crate::domain::entities::Frame;
use async_stream::try_stream;
use axum::{
http::{HeaderMap, StatusCode, header},
response::Response,
};
use futures::{Stream, StreamExt};
/// Wire format used when serializing frames for a streaming HTTP response.
///
/// A format is selected from the request's `Accept` header by
/// `StreamFormat::from_accept_header` and mapped to a MIME type by
/// `StreamFormat::content_type`.
#[derive(Debug, Clone, Copy)]
pub enum StreamFormat {
/// `application/json`; also the fallback when no other format is requested.
Json,
/// `application/x-ndjson`: each frame is a JSON object followed by a newline.
NdJson,
/// `text/event-stream`: each frame is emitted as a `data: <json>` event.
ServerSentEvents,
/// `application/octet-stream`. The formatters in this file still emit JSON text for it.
Binary,
}
impl StreamFormat {
    /// Picks a stream format from the request's `Accept` header.
    ///
    /// The header value is searched for the known media types in a fixed order of
    /// preference — `text/event-stream`, then `application/x-ndjson`, then
    /// `application/octet-stream` — regardless of their position in the header
    /// or any `q` weights. Matching is ASCII case-insensitive because media type
    /// names are case-insensitive.
    ///
    /// Returns [`StreamFormat::Json`] when the header is absent, cannot be read
    /// as a string, or names none of the media types above.
    pub fn from_accept_header(headers: &HeaderMap) -> Self {
        let Some(accept) = headers
            .get(header::ACCEPT)
            .and_then(|value| value.to_str().ok())
        else {
            return Self::Json;
        };

        // Normalise once so `Text/Event-Stream` and friends are recognised too.
        let accept = accept.to_ascii_lowercase();
        if accept.contains("text/event-stream") {
            Self::ServerSentEvents
        } else if accept.contains("application/x-ndjson") {
            Self::NdJson
        } else if accept.contains("application/octet-stream") {
            Self::Binary
        } else {
            Self::Json
        }
    }

    /// Returns the MIME type advertised in `Content-Type` for this format.
    pub fn content_type(&self) -> &'static str {
        match self {
            Self::Json => "application/json",
            Self::NdJson => "application/x-ndjson",
            Self::ServerSentEvents => "text/event-stream",
            Self::Binary => "application/octet-stream",
        }
    }
}
/// Builds the JSON object that represents `frame` on the wire.
///
/// The object carries the keys `type` (the `Debug` rendering of the frame
/// type), `priority`, `sequence`, `timestamp` (RFC 3339 text), `payload` and
/// `metadata`.
fn frame_to_value(frame: &Frame) -> serde_json::Value {
    let kind = format!("{:?}", frame.frame_type());
    let priority = frame.priority().value();
    let sequence = frame.sequence();
    let timestamp = frame.timestamp().to_rfc3339();

    serde_json::json!({
        "type": kind,
        "priority": priority,
        "sequence": sequence,
        "timestamp": timestamp,
        "payload": frame.payload(),
        "metadata": frame.metadata()
    })
}
/// Serializes a single frame and wraps it in the framing required by `format`.
///
/// * `Json` / `Binary` — the bare JSON text.
/// * `NdJson` — the JSON text followed by one `\n`.
/// * `ServerSentEvents` — `data: <json>` followed by a blank line.
///
/// # Errors
///
/// Returns [`StreamError::Serialization`] when the frame cannot be encoded as JSON.
fn format_frame_owned(frame: &Frame, format: StreamFormat) -> Result<String, StreamError> {
    // Every format carries the same JSON text; only the surrounding framing differs.
    let json = serde_json::to_string(&frame_to_value(frame))?;
    let framed = match format {
        StreamFormat::Json | StreamFormat::Binary => json,
        StreamFormat::NdJson => format!("{json}\n"),
        StreamFormat::ServerSentEvents => format!("data: {json}\n\n"),
    };
    Ok(framed)
}
fn format_batch_owned(frames: &[Frame], format: StreamFormat) -> Result<String, StreamError> {
let values: Vec<_> = frames.iter().map(frame_to_value).collect();
match format {
StreamFormat::Json | StreamFormat::NdJson => {
let mut out = String::new();
for v in values {
out.push_str(&serde_json::to_string(&v)?);
out.push('\n');
}
Ok(out)
}
StreamFormat::ServerSentEvents => {
let mut out = String::new();
for v in values {
out.push_str(&format!("data: {}\n\n", serde_json::to_string(&v)?));
}
Ok(out)
}
StreamFormat::Binary => Ok(serde_json::to_string(&values)?),
}
}
/// Optionally compresses an already formatted chunk.
///
/// When the crate is built with the `compression` feature and `enabled` is
/// `true`, the bytes of `s` are compressed with a `SecureCompressor` configured
/// for `ByteCodec::Gzip`, and the compressed bytes are converted back into a
/// `String`. In every other case `s` is returned unchanged.
///
/// # Errors
///
/// Returns `StreamError::Io` when compression fails or when the compressed
/// bytes are not valid UTF-8.
///
/// NOTE(review): a standard gzip stream starts with the bytes 0x1f 0x8b, and
/// 0x8b is not a valid UTF-8 lead byte. If `ByteCodec::Gzip` emits a standard
/// gzip header, the `String::from_utf8` conversion below would fail for every
/// input — confirm, and consider a byte-oriented return type for compressed output.
fn maybe_compress_owned(s: String, enabled: bool) -> Result<String, StreamError> {
#[cfg(feature = "compression")]
if enabled {
use crate::compression::secure::{ByteCodec, SecureCompressor};
let compressor = SecureCompressor::with_default_security(ByteCodec::Gzip);
let compressed = compressor
.compress(s.as_bytes())
.map_err(|e| StreamError::Io(e.to_string()))?;
return String::from_utf8(compressed.data)
.map_err(|e| StreamError::Io(format!("compressed output is not valid UTF-8: {e}")));
}
// Without the feature the flag is intentionally ignored; this silences the
// unused-variable warning.
#[cfg(not(feature = "compression"))]
let _ = enabled;
Ok(s)
}
/// Adapter that turns a stream of frames into a stream of formatted strings,
/// one output item per frame.
pub struct AdaptiveFrameStream<S> {
// Source of frames.
inner: S,
// Framing applied to every frame.
format: StreamFormat,
// Forwarded to `maybe_compress_owned` for every formatted frame.
compression: bool,
// Capacity handed to `ready_chunks` when the adapter is converted into a stream.
buffer_size: usize,
}
impl<S> AdaptiveFrameStream<S>
where
    S: Stream<Item = Frame> + Unpin + Send + 'static,
{
    /// Creates an adapter over `stream` with compression disabled and a
    /// buffer size of 10.
    pub fn new(stream: S, format: StreamFormat) -> Self {
        Self {
            inner: stream,
            format,
            compression: false,
            buffer_size: 10,
        }
    }

    /// Enables or disables compression of each formatted frame.
    pub fn with_compression(mut self, enabled: bool) -> Self {
        self.compression = enabled;
        self
    }

    /// Sets how many ready frames are drained from the source per poll.
    ///
    /// A value of `0` is treated as `1` when the stream is built.
    pub fn with_buffer_size(mut self, size: usize) -> Self {
        self.buffer_size = size;
        self
    }

    /// Consumes the adapter and returns a stream of formatted frames.
    ///
    /// Frames that are ready at the same time are pulled from the source in
    /// chunks of at most `buffer_size`, but every frame is still formatted
    /// (and optionally compressed) and yielded as its own item. The stream
    /// ends after the first error.
    pub fn into_stream(self) -> impl Stream<Item = Result<String, StreamError>> + Send + 'static {
        let Self {
            inner,
            format,
            compression,
            buffer_size,
        } = self;
        // `ready_chunks` panics on a zero capacity; clamp so a caller-supplied
        // `with_buffer_size(0)` degrades to frame-at-a-time instead of crashing.
        let chunk_capacity = buffer_size.max(1);
        try_stream! {
            let mut chunked = inner.ready_chunks(chunk_capacity);
            while let Some(batch) = chunked.next().await {
                for frame in batch {
                    let s = format_frame_owned(&frame, format)?;
                    let s = maybe_compress_owned(s, compression)?;
                    yield s;
                }
            }
        }
    }
}
/// Adapter that groups frames from a stream into batches and emits one
/// formatted string per batch.
pub struct BatchFrameStream<S> {
// Source of frames.
inner: S,
// Framing applied to every batch.
format: StreamFormat,
// Number of frames collected before a batch is emitted.
batch_size: usize,
}
impl<S> BatchFrameStream<S>
where
    S: Stream<Item = Frame> + Unpin + Send + 'static,
{
    /// Creates a batching adapter over `stream` that emits a chunk every
    /// `batch_size` frames.
    pub fn new(stream: S, format: StreamFormat, batch_size: usize) -> Self {
        Self {
            inner: stream,
            format,
            batch_size,
        }
    }

    /// Returns the MIME type that matches what this adapter actually emits.
    ///
    /// Batches in `Json` format are written as newline-delimited objects, so
    /// they are reported as `application/x-ndjson`; every other format uses
    /// its regular content type.
    pub fn content_type(&self) -> &'static str {
        if matches!(self.format, StreamFormat::Json) {
            "application/x-ndjson"
        } else {
            self.format.content_type()
        }
    }

    /// Consumes the adapter and returns a stream with one formatted string
    /// per batch.
    ///
    /// A batch is emitted as soon as it holds `batch_size` frames; whatever
    /// is left when the source ends is emitted as a final, shorter batch.
    /// The stream ends after the first error.
    pub fn into_stream(self) -> impl Stream<Item = Result<String, StreamError>> + Send + 'static {
        let Self {
            inner,
            format,
            batch_size,
        } = self;
        try_stream! {
            futures::pin_mut!(inner);
            let mut pending: Vec<Frame> = Vec::with_capacity(batch_size);
            while let Some(frame) = inner.next().await {
                pending.push(frame);
                if pending.len() < batch_size {
                    // Keep collecting until the batch is full.
                    continue;
                }
                let chunk = format_batch_owned(&pending, format)?;
                pending.clear();
                yield chunk;
            }
            // Flush the partially filled batch left over at end of input.
            if !pending.is_empty() {
                let tail = format_batch_owned(&pending, format)?;
                yield tail;
            }
        }
    }
}
/// Adapter that buffers frames from a stream in a priority heap and emits
/// them highest priority first, one formatted string per frame.
pub struct PriorityFrameStream<S> {
// Source of frames.
inner: S,
// Framing applied to every frame.
format: StreamFormat,
// Number of frames the heap is filled to before the next frame is emitted.
buffer_size: usize,
}
/// Heap entry that orders a buffered [`Frame`] for [`PriorityFrameStream`].
///
/// Entries compare by `priority` first. Among equal priorities the entry that
/// arrived earlier compares greater, so a max-heap pops them in arrival order
/// instead of the unspecified order `BinaryHeap` gives to equal elements.
#[derive(Debug, Clone)]
struct PriorityFrame {
    frame: Frame,
    priority: u8,
    /// Position of the frame in the source stream; used only as a tie-break.
    arrival: u64,
}

impl PartialEq for PriorityFrame {
    fn eq(&self, other: &Self) -> bool {
        // Kept consistent with `Ord`, as the trait contract requires.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for PriorityFrame {}

impl PartialOrd for PriorityFrame {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PriorityFrame {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.priority
            .cmp(&other.priority)
            // Reversed on purpose: the smaller arrival index must rank higher.
            .then_with(|| other.arrival.cmp(&self.arrival))
    }
}

impl<S> PriorityFrameStream<S>
where
    S: Stream<Item = Frame> + Unpin + Send + 'static,
{
    /// Creates a prioritising adapter over `stream` that keeps up to
    /// `buffer_size` frames buffered. A `buffer_size` of `0` is treated as `1`.
    pub fn new(stream: S, format: StreamFormat, buffer_size: usize) -> Self {
        Self {
            inner: stream,
            format,
            buffer_size,
        }
    }

    /// Consumes the adapter and returns a stream of formatted frames ordered
    /// by priority within the buffering window.
    ///
    /// Before each emission the heap is refilled from the source until it
    /// holds `buffer_size` frames or the source ends; then the highest
    /// priority frame is emitted. Frames with equal priority are emitted in
    /// arrival order. Once the source is exhausted the remaining frames are
    /// drained. The stream ends after the first error.
    pub fn into_stream(self) -> impl Stream<Item = Result<String, StreamError>> + Send + 'static {
        let Self {
            inner,
            format,
            buffer_size,
        } = self;
        // With a zero capacity the fill loop would never run and the stream
        // could never make progress; always buffer at least one frame.
        let capacity = buffer_size.max(1);
        try_stream! {
            let mut heap = std::collections::BinaryHeap::<PriorityFrame>::with_capacity(capacity);
            let mut inner_done = false;
            let mut next_arrival: u64 = 0;
            futures::pin_mut!(inner);
            loop {
                while !inner_done && heap.len() < capacity {
                    match inner.next().await {
                        Some(frame) => {
                            let priority = frame.priority().value();
                            heap.push(PriorityFrame {
                                frame,
                                priority,
                                arrival: next_arrival,
                            });
                            next_arrival += 1;
                        }
                        None => inner_done = true,
                    }
                }
                match heap.pop() {
                    Some(entry) => {
                        let s = format_frame_owned(&entry.frame, format)?;
                        yield s;
                    }
                    // `capacity >= 1`, so an empty heap here means the source
                    // has ended and everything buffered was already emitted.
                    None => break,
                }
            }
        }
    }
}
/// Errors produced while formatting frames or building a streaming response.
#[derive(Debug, thiserror::Error)]
pub enum StreamError {
/// A frame could not be encoded as JSON; converted automatically from
/// `serde_json::Error`.
#[error("Serialization error: {0}")]
Serialization(#[from] serde_json::Error),
/// Free-form failure message. In this file it carries compression failures
/// and HTTP response builder errors.
#[error("IO error: {0}")]
Io(String),
/// Not constructed anywhere in this file.
#[error("Buffer overflow")]
BufferOverflow,
/// Not constructed anywhere in this file.
#[error("Stream closed")]
StreamClosed,
}
/// Wraps `stream` in a `200 OK` response whose headers match `format`.
///
/// Every response gets `Content-Type` (from [`StreamFormat::content_type`])
/// and `Cache-Control: no-cache`. In addition:
///
/// * `ServerSentEvents` adds `Connection: keep-alive` and `X-Accel-Buffering: no`.
/// * `NdJson` adds `Transfer-Encoding: chunked`.
///
/// NOTE(review): `Connection` and `Transfer-Encoding` are hop-by-hop headers
/// that the HTTP server layer normally manages itself — confirm they need to
/// be set explicitly here.
///
/// # Errors
///
/// Returns [`StreamError::Io`] when the response builder rejects the headers
/// or body.
pub fn create_streaming_response<S>(
    stream: S,
    format: StreamFormat,
) -> Result<Response, StreamError>
where
    S: Stream<Item = Result<String, StreamError>> + Send + 'static,
{
    let base = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, format.content_type())
        .header(header::CACHE_CONTROL, "no-cache");

    let builder = match format {
        StreamFormat::ServerSentEvents => base
            .header(header::CONNECTION, "keep-alive")
            .header("X-Accel-Buffering", "no"),
        StreamFormat::NdJson => base.header("Transfer-Encoding", "chunked"),
        StreamFormat::Json | StreamFormat::Binary => base,
    };

    builder
        .body(axum::body::Body::from_stream(stream))
        .map_err(|e| StreamError::Io(e.to_string()))
}
/// Wraps `stream` in a `200 OK` response with an explicitly chosen
/// `Content-Type`.
///
/// Only `Content-Type` and `Cache-Control: no-cache` are set; no
/// format-specific headers are added.
///
/// # Errors
///
/// Returns [`StreamError::Io`] when the response builder rejects the headers
/// (for example an invalid `content_type` value) or the body.
pub fn create_streaming_response_with_content_type<S>(
    stream: S,
    content_type: &str,
) -> Result<Response, StreamError>
where
    S: Stream<Item = Result<String, StreamError>> + Send + 'static,
{
    let builder = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, content_type)
        .header(header::CACHE_CONTROL, "no-cache");

    match builder.body(axum::body::Body::from_stream(stream)) {
        Ok(response) => Ok(response),
        Err(err) => Err(StreamError::Io(err.to_string())),
    }
}
// Unit tests for the stream adapters and response builders in this file.
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::entities::Frame;
use crate::domain::value_objects::{JsonData, JsonPath, Priority, StreamId};
use axum::http::header;
use futures::StreamExt;
use futures::stream;
use pjson_rs_domain::entities::frame::FramePatch;
use std::pin::Pin;
use std::task::{Context, Poll};
// Builds a skeleton frame with a fresh stream id, sequence 1 and a null payload.
fn make_skeleton_frame() -> Frame {
Frame::skeleton(StreamId::new(), 1, JsonData::Null)
}
// Builds a single-patch frame (set `$.x` to null) carrying the given priority.
fn make_patch_frame(priority: Priority) -> Frame {
let path = JsonPath::new("$.x").expect("valid path");
let patch = FramePatch::set(path, JsonData::Null);
Frame::patch(StreamId::new(), 1, priority, vec![patch]).expect("valid patch frame")
}
// Test stream that reports `Poll::Pending` a fixed number of times before
// every item (and before the final `None`), waking itself each time so the
// executor polls it again.
struct PendingThenReady<I: Iterator> {
iter: I,
// Pending polls still owed before the next item may be produced.
pending_remaining: usize,
// Value `pending_remaining` is reset to after each produced item.
pending_per_item: usize,
// Set once the iterator is exhausted so later polls keep returning `None`.
done: bool,
}
impl<I: Iterator> PendingThenReady<I> {
fn new(iter: I, pending_per_item: usize) -> Self {
Self {
iter,
pending_remaining: pending_per_item,
pending_per_item,
done: false,
}
}
}
impl<I: Iterator + Unpin> Stream for PendingThenReady<I> {
type Item = I::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}
if self.pending_remaining > 0 {
self.pending_remaining -= 1;
// Wake immediately so the pending state never stalls the test.
cx.waker().wake_by_ref();
return Poll::Pending;
}
match self.iter.next() {
Some(item) => {
self.pending_remaining = self.pending_per_item;
Poll::Ready(Some(item))
}
None => {
self.done = true;
Poll::Ready(None)
}
}
}
}
// `Accept: text/event-stream` must select the SSE format.
#[test]
fn test_stream_format_detection() {
let mut headers = HeaderMap::new();
headers.insert(header::ACCEPT, "text/event-stream".parse().unwrap());
let format = StreamFormat::from_accept_header(&headers);
assert!(matches!(format, StreamFormat::ServerSentEvents));
}
// An empty source must produce an empty output stream.
#[tokio::test]
async fn test_adaptive_stream_empty() {
let frame_stream = stream::iter(Vec::<Frame>::new());
let adaptive = AdaptiveFrameStream::new(frame_stream, StreamFormat::Json);
let collected: Vec<_> = adaptive.into_stream().collect().await;
assert!(collected.is_empty());
}
// Five frames at batch size 2 yield batches of 2, 2 and 1; every line of
// every batch must parse as a JSON object.
#[tokio::test]
async fn test_batch_frame_stream_multiple_batches() {
let frames: Vec<Frame> = (0..5).map(|_| make_skeleton_frame()).collect();
let frame_stream = stream::iter(frames);
let batch_stream = BatchFrameStream::new(frame_stream, StreamFormat::Json, 2);
let collected: Vec<Result<String, StreamError>> =
batch_stream.into_stream().collect().await;
assert_eq!(
collected.len(),
3,
"expected 3 batches for 5 frames with batch_size=2"
);
let mut total_objects = 0usize;
for result in &collected {
let batch_str = result.as_ref().expect("batch should not error");
for line in batch_str.lines() {
if line.is_empty() {
continue;
}
let parsed: serde_json::Value =
serde_json::from_str(line).expect("each line must be valid JSON");
assert!(
parsed.is_object(),
"each line must be a JSON object (NDJSON-of-objects), got: {line}"
);
total_objects += 1;
}
}
assert_eq!(
total_objects, 5,
"total parsed objects across all batches must equal 5"
);
}
// The priority stream must end once the source is exhausted, even when the
// source holds fewer frames than the buffer size.
#[tokio::test]
async fn test_priority_stream_terminates() {
let frames: Vec<Frame> = (0..4).map(|_| make_skeleton_frame()).collect();
let frame_stream = stream::iter(frames);
let priority_stream = PriorityFrameStream::new(frame_stream, StreamFormat::Json, 8);
let collected: Vec<Result<String, StreamError>> =
priority_stream.into_stream().collect().await;
assert_eq!(collected.len(), 4);
for result in &collected {
assert!(result.is_ok());
}
}
// Frames that all fit in the buffer must come out highest priority first.
#[tokio::test]
async fn test_priority_stream_ordering() {
let frames = vec![
make_patch_frame(Priority::new(10).unwrap()),
make_patch_frame(Priority::new(50).unwrap()),
make_patch_frame(Priority::new(30).unwrap()),
];
let frame_stream = stream::iter(frames);
let priority_stream = PriorityFrameStream::new(frame_stream, StreamFormat::Json, 8);
let collected: Vec<_> = priority_stream
.into_stream()
.collect::<Vec<_>>()
.await
.into_iter()
.map(|r| r.expect("no error"))
.collect();
let priorities: Vec<u64> = collected
.iter()
.map(|s| {
let v: serde_json::Value = serde_json::from_str(s).unwrap();
v["priority"].as_u64().unwrap()
})
.collect();
assert_eq!(
priorities,
vec![50, 30, 10],
"frames must be ordered highest priority first"
);
}
// A source that is repeatedly pending must not stall or drop frames.
#[test]
fn test_adaptive_stream_makes_progress_under_pending() {
tokio_test::block_on(async {
let frames: Vec<Frame> = (0..6).map(|_| make_skeleton_frame()).collect();
let inner = PendingThenReady::new(frames.into_iter(), 3);
let adaptive = AdaptiveFrameStream::new(inner, StreamFormat::Json);
let collected: Vec<_> = adaptive.into_stream().collect().await;
assert_eq!(collected.len(), 6);
for r in collected {
assert!(r.is_ok());
}
});
}
// Pending polls must not cause partial batches to be flushed early.
#[test]
fn test_batch_stream_emits_only_full_batches_under_pending() {
tokio_test::block_on(async {
let frames: Vec<Frame> = (0..6).map(|_| make_skeleton_frame()).collect();
let inner = PendingThenReady::new(frames.into_iter(), 2);
let batch = BatchFrameStream::new(inner, StreamFormat::Json, 3);
let collected: Vec<_> = batch.into_stream().collect().await;
assert_eq!(
collected.len(),
2,
"6 frames at batch_size=3 must yield exactly 2 batches"
);
for r in collected {
assert!(r.is_ok());
}
});
}
// Checks the batch layout of every format: Json and NdJson are one object
// per line, SSE is one `data:` event per frame, Binary is a single JSON array.
#[tokio::test]
async fn test_batch_stream_ndjson_objects_per_line() {
let make_frames = || -> Vec<Frame> { (0..3).map(|_| make_skeleton_frame()).collect() };
let result_json: Vec<_> =
BatchFrameStream::new(stream::iter(make_frames()), StreamFormat::Json, 10)
.into_stream()
.collect()
.await;
assert_eq!(result_json.len(), 1);
let json_str = result_json[0].as_ref().unwrap();
for line in json_str.lines() {
if line.is_empty() {
continue;
}
let v: serde_json::Value = serde_json::from_str(line).unwrap();
assert!(v.is_object(), "Json format: each line must be an object");
}
let result_ndjson: Vec<_> =
BatchFrameStream::new(stream::iter(make_frames()), StreamFormat::NdJson, 10)
.into_stream()
.collect()
.await;
assert_eq!(result_ndjson.len(), 1);
let ndjson_str = result_ndjson[0].as_ref().unwrap();
for line in ndjson_str.lines() {
if line.is_empty() {
continue;
}
let v: serde_json::Value = serde_json::from_str(line).unwrap();
assert!(v.is_object(), "NdJson format: each line must be an object");
}
let json_count = json_str.lines().filter(|l| !l.is_empty()).count();
let ndjson_count = ndjson_str.lines().filter(|l| !l.is_empty()).count();
assert_eq!(
json_count, ndjson_count,
"Json and NdJson must produce the same object count"
);
let result_sse: Vec<_> = BatchFrameStream::new(
stream::iter(make_frames()),
StreamFormat::ServerSentEvents,
10,
)
.into_stream()
.collect()
.await;
assert_eq!(result_sse.len(), 1);
let sse_str = result_sse[0].as_ref().unwrap();
// Events are separated by a blank line.
let sse_frames: Vec<&str> = sse_str.split("\n\n").filter(|s| !s.is_empty()).collect();
assert_eq!(sse_frames.len(), 3);
for frame_str in sse_frames {
assert!(frame_str.starts_with("data: "));
let json_part = &frame_str["data: ".len()..];
let v: serde_json::Value = serde_json::from_str(json_part).unwrap();
assert!(v.is_object());
}
let result_binary: Vec<_> =
BatchFrameStream::new(stream::iter(make_frames()), StreamFormat::Binary, 10)
.into_stream()
.collect()
.await;
assert_eq!(result_binary.len(), 1);
let binary_str = result_binary[0].as_ref().unwrap();
let v: serde_json::Value = serde_json::from_str(binary_str).unwrap();
assert!(v.is_array());
assert_eq!(v.as_array().unwrap().len(), 3);
}
// The priority stream must still terminate and emit every frame when the
// source is repeatedly pending.
#[test]
fn test_priority_stream_terminates_under_pending() {
tokio_test::block_on(async {
let frames: Vec<Frame> = (0..5).map(|_| make_skeleton_frame()).collect();
let inner = PendingThenReady::new(frames.into_iter(), 4);
let priority = PriorityFrameStream::new(inner, StreamFormat::Json, 8);
let collected: Vec<_> = priority.into_stream().collect().await;
assert_eq!(collected.len(), 5);
for r in collected {
assert!(r.is_ok());
}
});
}
// The explicit-content-type builder must use the type reported by
// `BatchFrameStream::content_type`, which is NDJSON for the Json format.
#[tokio::test]
async fn test_create_streaming_response_with_content_type_uses_explicit_type() {
let frames: Vec<Frame> = (0..2).map(|_| make_skeleton_frame()).collect();
let batch = BatchFrameStream::new(stream::iter(frames), StreamFormat::Json, 10);
let expected_ct = batch.content_type();
assert_eq!(
expected_ct, "application/x-ndjson",
"BatchFrameStream with Json format must report application/x-ndjson"
);
let response =
create_streaming_response_with_content_type(batch.into_stream(), expected_ct)
.expect("response must be built");
let ct = response
.headers()
.get(header::CONTENT_TYPE)
.expect("Content-Type header must be present")
.to_str()
.unwrap();
assert_eq!(ct, "application/x-ndjson");
}
// The format-based builder must use `StreamFormat::content_type` verbatim.
#[tokio::test]
async fn test_create_streaming_response_uses_format_content_type() {
let frames: Vec<Frame> = (0..1).map(|_| make_skeleton_frame()).collect();
let batch = BatchFrameStream::new(stream::iter(frames), StreamFormat::Json, 10);
let response = create_streaming_response(batch.into_stream(), StreamFormat::Json)
.expect("response must be built");
let ct = response
.headers()
.get(header::CONTENT_TYPE)
.expect("Content-Type header must be present")
.to_str()
.unwrap();
assert_eq!(ct, "application/json");
}
// Priority ordering must hold even when frames arrive with pending polls
// in between, as long as they all fit in the buffer.
#[test]
fn test_priority_stream_ordering_preserved_under_pending() {
tokio_test::block_on(async {
let frames = vec![
make_patch_frame(Priority::new(10).unwrap()),
make_patch_frame(Priority::new(50).unwrap()),
make_patch_frame(Priority::new(30).unwrap()),
make_patch_frame(Priority::new(80).unwrap()),
];
let inner = PendingThenReady::new(frames.into_iter(), 2);
let priority = PriorityFrameStream::new(inner, StreamFormat::Json, 10);
let collected: Vec<_> = priority
.into_stream()
.collect::<Vec<_>>()
.await
.into_iter()
.map(|r| r.expect("no error"))
.collect();
assert_eq!(collected.len(), 4);
let priorities: Vec<u64> = collected
.iter()
.map(|s| {
let v: serde_json::Value = serde_json::from_str(s).unwrap();
v["priority"].as_u64().unwrap()
})
.collect();
assert_eq!(priorities, vec![80, 50, 30, 10]);
});
}
}