pub mod helpers;
#[cfg(test)]
mod tests;
use std::sync::Arc;
use axum::http::{HeaderMap, HeaderValue};
use bytes::Bytes;
use fraiseql_core::{db::traits::DatabaseAdapter, security::SecurityContext};
use futures::stream;
use helpers::{StreamState, fetch_and_serialize_batch};
use super::{
handler::{PreferHeader, ResolvedGetQuery, RestError, RestHandler, set_request_id},
params::PaginationParams,
};
/// Media type used for newline-delimited JSON streaming responses.
pub const NDJSON_CONTENT_TYPE: &str = "application/x-ndjson";

/// Returns `true` when the request's `Accept` header(s) list `application/x-ndjson`.
///
/// Every `Accept` header line is inspected, because HTTP allows the field to be split
/// across several lines. Each comma-separated media range is compared case-insensitively
/// after stripping its parameters (`charset`, `q`, ...). A range weighted `q=0`
/// ("not acceptable") does not count as a match. Header values that are not valid
/// visible ASCII are skipped.
#[must_use]
pub fn accepts_ndjson(headers: &HeaderMap) -> bool {
    headers
        .get_all("accept")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|accept| accept.split(','))
        .any(media_range_accepts_ndjson)
}

/// Checks one `Accept` media range (e.g. `application/x-ndjson;q=0.5`) against
/// [`NDJSON_CONTENT_TYPE`], ignoring parameters but honouring a `q=0` rejection.
fn media_range_accepts_ndjson(range: &str) -> bool {
    let mut parts = range.split(';');
    // `split` always yields at least one item; the first is the bare media type.
    let media_type = parts.next().unwrap_or("").trim();
    if !media_type.eq_ignore_ascii_case(NDJSON_CONTENT_TYPE) {
        return false;
    }
    // RFC 9110 quality values: a weight of 0 means the client refuses this media type.
    !parts.any(|param| {
        param.split_once('=').is_some_and(|(name, value)| {
            name.trim().eq_ignore_ascii_case("q")
                && value.trim().parse::<f32>().is_ok_and(|q| q <= 0.0)
        })
    })
}
/// Rejects request options that a streaming NDJSON response cannot honour.
///
/// Any `Prefer: count=...` option (exact, planned or estimated) is refused. Cursor
/// pagination is always refused. Offset pagination is refused only when the offset is
/// non-zero; a zero offset is accepted, and the `limit` part is not inspected here.
///
/// # Errors
///
/// Returns a `RestError::bad_request` naming the unsupported option.
pub fn validate_ndjson_request(
    prefer: &PreferHeader,
    pagination: &PaginationParams,
) -> Result<(), RestError> {
    let requests_count = [prefer.count_exact, prefer.count_planned, prefer.count_estimated]
        .into_iter()
        .any(|flag| flag);
    if requests_count {
        return Err(RestError::bad_request("count not available for streaming responses"));
    }

    let offset_paginated =
        matches!(pagination, PaginationParams::Offset { offset, .. } if *offset > 0);
    let cursor_paginated = matches!(pagination, PaginationParams::Cursor { .. });
    if offset_paginated || cursor_paginated {
        return Err(RestError::bad_request(
            "pagination not available for streaming; use filters to narrow results",
        ));
    }

    Ok(())
}
/// Serves a REST `GET` request as an `application/x-ndjson` stream.
///
/// The steps are:
/// 1. Resolve the request with `resolve_get_query`.
/// 2. Reject unsupported options via [`validate_ndjson_request`].
/// 3. Build a lazily-polled body that pulls `batch_size` rows at a time through
///    `fetch_and_serialize_batch`, starting at offset 0.
///
/// The returned headers carry:
/// - the request id (copied by `set_request_id`);
/// - `content-type: application/x-ndjson`;
/// - `x-stream-batch-size` with the effective batch size.
///
/// If fetching a batch fails mid-stream, the error bytes produced by
/// `fetch_and_serialize_batch` are emitted as the final body chunk and the stream ends.
/// The body's item type is infallible, so such failures can only be reported in-band.
///
/// # Errors
///
/// Returns [`RestError`] if the query cannot be resolved, or if the request asks for a
/// count or for pagination (see [`validate_ndjson_request`]).
pub async fn handle_ndjson_get<A: DatabaseAdapter + 'static>(
    handler: &RestHandler<'_, A>,
    relative_path: &str,
    query_pairs: &[(&str, &str)],
    headers: &HeaderMap,
    security_context: Option<&SecurityContext>,
) -> Result<NdjsonResponse, RestError> {
    let resolved = handler.resolve_get_query(relative_path, query_pairs, security_context)?;
    let prefer = PreferHeader::from_headers(headers);
    validate_ndjson_request(&prefer, &resolved.params.pagination)?;

    let ResolvedGetQuery {
        query_name,
        query_match,
        variables,
        ..
    } = resolved;

    // A configured batch size of 0 is treated as 1.
    let batch_size = handler.config().ndjson_batch_size.max(1);

    let mut response_headers = HeaderMap::new();
    set_request_id(headers, &mut response_headers);
    response_headers.insert("content-type", HeaderValue::from_static(NDJSON_CONTENT_TYPE));
    // An integer's decimal form is pure ASCII digits, which is always a valid header
    // value. The header therefore always reports the batch size actually used.
    response_headers.insert(
        "x-stream-batch-size",
        HeaderValue::from_str(&batch_size.to_string())
            .expect("an integer's decimal form is always a valid header value"),
    );

    // The stream outlives this call, so it needs owned copies of borrowed inputs.
    let executor = Arc::clone(handler.executor());
    let security_ctx_owned = security_context.cloned();

    let ndjson_stream = stream::unfold(
        StreamState {
            executor,
            query_name,
            query_match,
            variables,
            security_ctx: security_ctx_owned,
            batch_size,
            offset: 0,
            done: false,
        },
        |mut state| async move {
            // Set after an error chunk has been emitted; ends the stream on the next poll.
            if state.done {
                return None;
            }
            match fetch_and_serialize_batch(&mut state).await {
                Ok(Some(bytes)) => Some((Ok(bytes), state)),
                Ok(None) => None,
                Err(err_bytes) => {
                    // Emit the serialized error as the last chunk, then stop.
                    state.done = true;
                    Some((Ok(err_bytes), state))
                },
            }
        },
    );

    Ok(NdjsonResponse {
        headers: response_headers,
        body: NdjsonBody::Stream(Box::pin(ndjson_stream)),
    })
}
/// Result of [`handle_ndjson_get`]: the response headers plus the streaming body.
pub struct NdjsonResponse {
    /// Headers to send with the response. When built by `handle_ndjson_get` these
    /// include the request id, `content-type`, and `x-stream-batch-size`.
    pub headers: HeaderMap,
    /// The NDJSON payload; turn it into an axum body with `NdjsonBody::into_body`.
    pub body: NdjsonBody,
}
/// Body of an NDJSON response.
///
/// Marked `#[non_exhaustive]` so that other body representations can be added later.
#[non_exhaustive]
pub enum NdjsonBody {
    /// A boxed, `Send`, `'static` stream of serialized NDJSON chunks. The item error
    /// type is `Infallible`, so the stream itself never yields an error item.
    Stream(futures::stream::BoxStream<'static, Result<Bytes, std::convert::Infallible>>),
}
impl NdjsonBody {
    /// Consumes the body and wraps its chunk stream in an [`axum::body::Body`] via
    /// `Body::from_stream`.
    pub fn into_body(self) -> axum::body::Body {
        let Self::Stream(chunks) = self;
        axum::body::Body::from_stream(chunks)
    }
}