use crate::config::{LongPollTimeout, SseReconnectInterval};
use crate::protocol::cursor;
use crate::protocol::error::Error;
use crate::protocol::headers::names;
use crate::protocol::json_mode;
use crate::protocol::offset::Offset;
use crate::protocol::problem::{Result, request_instance};
use crate::protocol::sse::{self, ControlPayload};
use crate::protocol::stream_name::StreamName;
use crate::router::ShutdownToken;
use crate::storage::{ReadResult, Storage};
use axum::{
Extension,
body::Body,
extract::{OriginalUri, Query, State},
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use bytes::{BufMut, BytesMut};
use serde::Deserialize;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
/// Query parameters accepted by the stream read endpoint.
#[derive(Debug, Deserialize)]
pub struct ReadQuery {
    // Starting offset; required when `live` is set, otherwise defaults to
    // "-1" (see `read_stream`).
    offset: Option<String>,
    // Live mode selector: "long-poll" or "sse"; absent means catch-up read.
    live: Option<String>,
    // Accepted on the wire but currently unused by the handler.
    #[allow(dead_code)]
    cursor: Option<String>,
}
/// HTTP handler for reading a stream.
///
/// Dispatches on the `live` query parameter:
/// - absent      -> one-shot catch-up read (`read_catch_up`)
/// - "long-poll" -> wait up to `timeout` for new messages (`read_long_poll`)
/// - "sse"       -> open a server-sent-events response (`read_sse`)
///
/// The `offset` query parameter is mandatory in live modes and defaults to
/// "-1" for catch-up. Any error is tagged with the request instance URI
/// before being returned.
pub async fn read_stream<S: Storage + 'static>(
    State(storage): State<Arc<S>>,
    StreamName(name): StreamName,
    original_uri: OriginalUri,
    Query(query): Query<ReadQuery>,
    Extension(LongPollTimeout(timeout)): Extension<LongPollTimeout>,
    Extension(SseReconnectInterval(reconnect_interval_secs)): Extension<SseReconnectInterval>,
    Extension(ShutdownToken(shutdown)): Extension<ShutdownToken>,
    headers: HeaderMap,
) -> Result<Response> {
    let instance = request_instance(&original_uri);
    // Run the whole read in an inner async block so every error path can be
    // decorated with the request instance in one place at the bottom.
    let result = async {
        // Live modes must name an explicit starting offset; catch-up falls
        // back to the "-1" sentinel.
        let raw_offset = if let Some(ref live) = query.live {
            match query.offset {
                Some(ref o) => o.clone(),
                None => {
                    return Err(Error::InvalidHeader {
                        header: "offset".to_string(),
                        reason: format!("offset query parameter is required for live={live} mode"),
                    }
                    .into());
                }
            }
        } else {
            query.offset.clone().unwrap_or_else(|| "-1".to_string())
        };
        let offset = Offset::from_str(&raw_offset)?;
        // head() is consulted before dispatch; presumably it errors for a
        // missing stream so all three modes share that check — TODO confirm.
        let metadata = storage.head(&name)?;
        let content_type = metadata.config.content_type.clone();
        if let Some(ref live) = query.live {
            match live.as_str() {
                "long-poll" => {
                    // HeaderMap lookup is case-insensitive.
                    let if_none_match = headers.get("if-none-match").and_then(|v| v.to_str().ok());
                    read_long_poll(
                        &storage,
                        &name,
                        &offset,
                        &raw_offset,
                        if_none_match,
                        &content_type,
                        timeout,
                        shutdown,
                    )
                    .await
                }
                "sse" => read_sse(
                    storage,
                    name,
                    &offset,
                    &content_type,
                    reconnect_interval_secs,
                    shutdown,
                ),
                other => Err(Error::InvalidHeader {
                    header: "live".to_string(),
                    reason: format!("unsupported live mode: {other}"),
                }
                .into()),
            }
        } else {
            let if_none_match = headers.get("if-none-match").and_then(|v| v.to_str().ok());
            read_catch_up(
                &storage,
                &name,
                &offset,
                &raw_offset,
                if_none_match,
                &content_type,
            )
        }
    }
    .await;
    result.map_err(|problem| problem.with_instance(instance))
}
/// One-shot read: returns whatever is currently available at `offset`.
///
/// Honors `If-None-Match`: when the client's ETag equals the ETag computed
/// for this read, a 304 is returned instead of the payload.
fn read_catch_up<S: Storage>(
    storage: &Arc<S>,
    name: &str,
    offset: &Offset,
    raw_offset: &str,
    if_none_match: Option<&str>,
    content_type: &str,
) -> Result<Response> {
    let read_result = storage.read(name, offset)?;
    let etag = generate_etag(raw_offset, &read_result);
    let not_modified = if_none_match == Some(etag.as_str());
    if not_modified {
        Ok(build_304_response(&read_result))
    } else {
        // Catch-up responses never carry a cursor header.
        Ok(build_data_response(&read_result, content_type, &etag, None))
    }
}
/// Long-poll read: answers immediately when data (or a 304) is available,
/// otherwise waits for a storage notification, the poll `timeout`, or server
/// shutdown — whichever comes first.
#[allow(clippy::too_many_arguments)]
async fn read_long_poll<S: Storage>(
    storage: &Arc<S>,
    name: &str,
    offset: &Offset,
    raw_offset: &str,
    if_none_match: Option<&str>,
    content_type: &str,
    timeout: Duration,
    shutdown: CancellationToken,
) -> Result<Response> {
    // Subscribe BEFORE the initial read so a write landing between the read
    // and the wait still produces a wake-up.
    let mut receiver = storage
        .subscribe(name)
        .ok_or_else(|| Error::NotFound(name.to_string()))?;
    let read_result = storage.read(name, offset)?;
    let etag = generate_etag(raw_offset, &read_result);
    // Conditional-request short-circuit: the client already has this state.
    if let Some(client_etag) = if_none_match
        && client_etag == etag
    {
        return Ok(build_304_response(&read_result));
    }
    // Data already available: respond without waiting.
    if !read_result.messages.is_empty() {
        let cursor_val = cursor::generate(&read_result.next_offset);
        return Ok(build_data_response(
            &read_result,
            content_type,
            &etag,
            Some(&cursor_val),
        ));
    }
    // Stream closed and fully consumed: nothing will ever arrive.
    if read_result.closed && read_result.at_tail {
        return Ok(build_204_response(&read_result.next_offset, true));
    }
    // Park at the tail and wait for whichever event fires first.
    let tail_offset = read_result.next_offset.clone();
    let tail_offset_str = tail_offset.to_string();
    tokio::select! {
        // Storage signalled activity: re-read from the parked tail.
        _ = receiver.recv() => {
            handle_long_poll_wake(storage, name, &tail_offset, &tail_offset_str, content_type)
        }
        // Poll window elapsed: report "no new content".
        () = tokio::time::sleep(timeout) => {
            let read_result = storage.read(name, &tail_offset)?;
            let is_closed = read_result.closed && read_result.at_tail;
            Ok(build_204_response(&read_result.next_offset, is_closed))
        }
        // Graceful shutdown: release the client with a 204 as well.
        () = shutdown.cancelled() => {
            let read_result = storage.read(name, &tail_offset)?;
            let is_closed = read_result.closed && read_result.at_tail;
            Ok(build_204_response(&read_result.next_offset, is_closed))
        }
    }
}
/// Server-sent-events read: emits the stream's current contents, then keeps
/// streaming frames as new messages arrive.
///
/// For binary content types the `stream-sse-data-encoding: base64` response
/// header advertises that data frames are base64-encoded.
fn read_sse<S: Storage + 'static>(
    storage: Arc<S>,
    name: String,
    offset: &Offset,
    content_type: &str,
    reconnect_interval_secs: u64,
    shutdown: CancellationToken,
) -> Result<Response> {
    let binary = sse::is_binary_content_type(content_type);
    let json = json_mode::is_json_content_type(content_type);
    // Subscription is created before the initial read (mirrors the long-poll
    // path, so writes landing in between are not missed).
    let rx = storage
        .subscribe(&name)
        .ok_or_else(|| Error::NotFound(name.clone()))?;
    let initial = storage.read(&name, offset)?;
    let body = Body::from_stream(build_sse_byte_stream(
        storage,
        name,
        initial,
        rx,
        binary,
        json,
        reconnect_interval_secs,
        shutdown,
    ));
    let mut response_headers = HeaderMap::new();
    response_headers.insert("content-type", "text/event-stream".parse().unwrap());
    if binary {
        response_headers.insert("stream-sse-data-encoding", "base64".parse().unwrap());
    }
    Ok((StatusCode::OK, response_headers, body).into_response())
}
/// Produces the SSE frame stream: initial snapshot first, then a notify/read
/// loop interleaved with keepalives, terminating when the stream closes at
/// the tail, the idle reconnect interval expires, or the server shuts down.
#[allow(clippy::too_many_arguments)]
fn build_sse_byte_stream<S: Storage + 'static>(
    storage: Arc<S>,
    name: String,
    initial_read: ReadResult,
    mut receiver: tokio::sync::broadcast::Receiver<()>,
    is_binary: bool,
    is_json: bool,
    reconnect_interval_secs: u64,
    shutdown: CancellationToken,
) -> impl futures_util::stream::Stream<Item = std::result::Result<String, std::convert::Infallible>> + Send
{
    async_stream::stream! {
        // Emit the initial snapshot: data frames (if any) plus one control frame.
        let read_result = initial_read;
        let data_frames = sse::format_data_frames(&read_result.messages, is_binary, is_json);
        if !data_frames.is_empty() {
            yield Ok(data_frames);
        }
        let control = build_sse_control(&read_result);
        yield Ok(sse::format_control_frame(&control));
        if read_result.closed && read_result.at_tail {
            return;
        }
        let mut tail_offset = read_result.next_offset;
        // reconnect_interval_secs == 0 disables the idle disconnect entirely.
        let idle_timeout = if reconnect_interval_secs > 0 {
            Some(Duration::from_secs(reconnect_interval_secs))
        } else {
            None
        };
        let mut idle_deadline = idle_timeout.map(|timeout| Instant::now() + timeout);
        let keepalive_interval = Duration::from_secs(15);
        loop {
            tokio::select! {
                recv_result = receiver.recv() => {
                    match recv_result {
                        Ok(()) | Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                            // New activity (or we lagged behind the broadcast
                            // channel); fall through to the read below either way.
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                            // Sender side gone: best-effort flush of whatever
                            // remains, then end the stream.
                            if let Ok(rr) = storage.read(&name, &tail_offset) {
                                let data_frames = sse::format_data_frames(&rr.messages, is_binary, is_json);
                                if !data_frames.is_empty() {
                                    yield Ok(data_frames);
                                }
                                let ctrl = build_sse_control(&rr);
                                yield Ok(sse::format_control_frame(&ctrl));
                            }
                            return;
                        }
                    }
                }
                // NOTE: this sleep future is re-created on every select pass,
                // so the keepalive fires only after 15s with no other events.
                () = tokio::time::sleep(keepalive_interval) => {
                    yield Ok(sse::format_keepalive_frame().to_string());
                    continue;
                }
                // Idle disconnect deadline; pends forever when disabled.
                () = async {
                    match idle_deadline {
                        Some(deadline) => tokio::time::sleep_until(deadline).await,
                        None => std::future::pending().await,
                    }
                } => {
                    return;
                }
                () = shutdown.cancelled() => {
                    return;
                }
            }
            // Woken by a notification: read from the parked tail and emit.
            let Ok(rr) = storage.read(&name, &tail_offset) else {
                return;
            };
            let data_frames = sse::format_data_frames(&rr.messages, is_binary, is_json);
            if !data_frames.is_empty() {
                yield Ok(data_frames);
            }
            let ctrl = build_sse_control(&rr);
            yield Ok(sse::format_control_frame(&ctrl));
            // Only actual data pushes the idle deadline forward; keepalives
            // and empty wake-ups do not.
            if !rr.messages.is_empty()
                && let Some(timeout) = idle_timeout
            {
                idle_deadline = Some(Instant::now() + timeout);
            }
            if rr.closed && rr.at_tail {
                return;
            }
            tail_offset = rr.next_offset;
        }
    }
}
/// Builds the SSE control-frame payload for a read result.
///
/// The cursor is omitted once the stream is closed and fully read; the
/// `up_to_date` / `stream_closed` flags are present only when true.
fn build_sse_control(read_result: &ReadResult) -> ControlPayload {
    let closed_at_tail = read_result.closed && read_result.at_tail;
    let stream_cursor = if closed_at_tail {
        None
    } else {
        Some(cursor::generate(&read_result.next_offset))
    };
    ControlPayload {
        stream_next_offset: read_result.next_offset.to_string(),
        stream_cursor,
        up_to_date: read_result.at_tail.then_some(true),
        stream_closed: closed_at_tail.then_some(true),
    }
}
/// Called after a long-poll wake-up: re-reads from the parked tail offset and
/// turns the result into a 200 (data) or 204 (still nothing) response.
fn handle_long_poll_wake<S: Storage>(
    storage: &Arc<S>,
    name: &str,
    offset: &Offset,
    raw_offset: &str,
    content_type: &str,
) -> Result<Response> {
    let read_result = storage.read(name, offset)?;
    if !read_result.messages.is_empty() {
        let etag = generate_etag(raw_offset, &read_result);
        let cursor_val = cursor::generate(&read_result.next_offset);
        return Ok(build_data_response(
            &read_result,
            content_type,
            &etag,
            Some(&cursor_val),
        ));
    }
    // Woken but nothing new at this offset (e.g. a notification for data we
    // already served): answer with "no content".
    let is_closed = read_result.closed && read_result.at_tail;
    Ok(build_204_response(&read_result.next_offset, is_closed))
}
/// Computes a strong ETag of the form `"start:end"`, with a `:c` suffix once
/// the stream is closed and the read reached the tail, distinguishing the
/// closed state from the open one.
fn generate_etag(start_offset: &str, read_result: &ReadResult) -> String {
    let end_offset = read_result.next_offset.as_str();
    let closed_suffix = if read_result.closed && read_result.at_tail {
        ":c"
    } else {
        ""
    };
    format!("\"{start_offset}:{end_offset}{closed_suffix}\"")
}
/// Builds a 304 Not Modified response: no body, just the next-offset and
/// up-to-date headers.
fn build_304_response(read_result: &ReadResult) -> Response {
    let next = read_result.next_offset.as_str();
    let mut response_headers = HeaderMap::new();
    response_headers.insert(
        names::STREAM_NEXT_OFFSET,
        axum::http::HeaderValue::from_bytes(next.as_bytes()).unwrap(),
    );
    response_headers.insert(names::STREAM_UP_TO_DATE, "true".parse().unwrap());
    (StatusCode::NOT_MODIFIED, response_headers).into_response()
}
/// Builds a 200 response carrying the concatenated message payload plus the
/// protocol headers: next offset, up-to-date flag, ETag, and optionally the
/// closed flag and a cursor.
fn build_data_response(
    read_result: &ReadResult,
    content_type: &str,
    etag: &str,
    cursor_val: Option<&str>,
) -> Response {
    let body = build_body(read_result, content_type);
    let up_to_date = if read_result.at_tail { "true" } else { "false" };
    let mut response_headers = HeaderMap::new();
    response_headers.insert("content-type", content_type.parse().unwrap());
    response_headers.insert(
        names::STREAM_NEXT_OFFSET,
        axum::http::HeaderValue::from_bytes(read_result.next_offset.as_str().as_bytes()).unwrap(),
    );
    response_headers.insert(names::STREAM_UP_TO_DATE, up_to_date.parse().unwrap());
    response_headers.insert("etag", etag.parse().unwrap());
    // The closed flag is only surfaced once the reader has consumed the tail.
    if read_result.closed && read_result.at_tail {
        response_headers.insert(names::STREAM_CLOSED, "true".parse().unwrap());
    }
    if let Some(cursor_header) = cursor_val {
        response_headers.insert(names::STREAM_CURSOR, cursor_header.parse().unwrap());
    }
    (StatusCode::OK, response_headers, body).into_response()
}
/// Builds a 204 No Content response for a poll that found nothing new.
/// Always carries next-offset, up-to-date, and cursor headers; adds the
/// closed flag when the stream is closed at the tail.
///
/// NOTE(review): unlike `build_sse_control`, a cursor is emitted here even
/// when `is_closed` — confirm this asymmetry is intentional.
fn build_204_response(next_offset: &Offset, is_closed: bool) -> Response {
    let mut response_headers = HeaderMap::new();
    response_headers.insert(
        names::STREAM_NEXT_OFFSET,
        axum::http::HeaderValue::from_bytes(next_offset.as_str().as_bytes()).unwrap(),
    );
    response_headers.insert(names::STREAM_UP_TO_DATE, "true".parse().unwrap());
    if is_closed {
        response_headers.insert(names::STREAM_CLOSED, "true".parse().unwrap());
    }
    response_headers.insert(
        names::STREAM_CURSOR,
        cursor::generate(next_offset).parse().unwrap(),
    );
    (StatusCode::NO_CONTENT, response_headers).into_response()
}
/// Assembles the response body: JSON content types get the JSON wrapping from
/// `json_mode`; everything else is the raw concatenation of message payloads.
fn build_body(read_result: &ReadResult, content_type: &str) -> bytes::Bytes {
    if json_mode::is_json_content_type(content_type) {
        return json_mode::wrap_read_iter(read_result.messages.iter());
    }
    let messages = &read_result.messages;
    if messages.is_empty() {
        return bytes::Bytes::new();
    }
    // Size the buffer up front so concatenation never reallocates.
    let total: usize = messages.iter().map(bytes::Bytes::len).sum();
    let mut out = BytesMut::with_capacity(total);
    for chunk in messages {
        out.extend_from_slice(chunk);
    }
    out.freeze()
}