use bytes::Bytes;
use futures_util::{stream, StreamExt as _};
use http_body::Frame;
use http_body_util::{BodyExt as _, BodyStream, StreamBody};
use osproxy_observe::RequestTrace;
use osproxy_sink::{buffered, BodyError, ByteBody, Reader, Sink};
use osproxy_spi::RequestCtx;
use osproxy_tenancy::Router;
use crate::cursor::{forwardable_query, pit_id_in_body};
use crate::error::RequestError;
use crate::observe::{read_dispatch_info, resolve_info};
use crate::pipeline::Pipeline;
use crate::read::build_search_op;
use crate::search_scan::{HitShaper, SearchHitsScanner};
impl<R: Router, S: Sink + Reader> Pipeline<R, S> {
    /// Runs a search and returns the response as a (possibly streaming) body.
    ///
    /// Two paths:
    /// - If a cursor signer is configured and `pit_id_in_body` finds a wrapped id in
    ///   the request body, the request is delegated to `pit_search`. Its response body
    ///   is returned through [`StreamSearch::buffered`].
    /// - Otherwise the target is resolved via `resolve_with_retry` and a search op is
    ///   built with `build_search_op`. The op is dispatched with the sink's
    ///   `search_stream`, and the upstream body is piped through `shape_hits_stream`
    ///   so hits are shaped incrementally.
    ///
    /// On the second path, resolution and dispatch details are recorded on `trace`.
    ///
    /// # Errors
    ///
    /// Propagates failures from `pit_search`, `resolve_with_retry`, `build_search_op`
    /// and the upstream `search_stream` call, converted into [`RequestError`].
    pub(crate) async fn run_search_stream(
        &self,
        ctx: &RequestCtx<'_>,
        trace: &mut RequestTrace,
    ) -> Result<StreamSearch, RequestError> {
        // The body is only inspected for a wrapped PIT id when cursor signing is on.
        let wrapped_pit = if self.cursor_signer.is_some() {
            pit_id_in_body(ctx.body())
        } else {
            None
        };
        if let Some(wrapped) = wrapped_pit {
            let pit = self.pit_search(ctx, trace, &wrapped).await?;
            return Ok(StreamSearch::buffered(pit.status, pit.body));
        }

        let resolved = self.resolve_with_retry(ctx).await?;
        trace.record_resolve(resolve_info(&resolved));

        let (search_op, shape) = build_search_op(&resolved, ctx.body())?;
        let request = search_op
            .with_query(forwardable_query(ctx.query()))
            .with_trace(self.upstream_trace(ctx))
            .with_forward_headers(ctx.forward_headers().to_vec());
        let upstream = self.sink().search_stream(request).await?;
        trace.record_dispatch(read_dispatch_info(
            &resolved,
            upstream.status,
            upstream.pool_reuse,
        ));

        let shaper = HitShaper {
            logical_index: ctx.logical_index().to_owned(),
            partition: resolved.partition.as_str().to_owned(),
            shape,
        };
        let status = upstream.status;
        let body = shape_hits_stream(upstream.body, shaper);
        Ok(StreamSearch::stream(status, body))
    }
}
/// Outcome of a streaming search: a status code plus the response body.
///
/// The body is either a shaped upstream stream ([`StreamSearch::stream`]) or an
/// already materialized payload ([`StreamSearch::buffered`]).
pub struct StreamSearch {
    /// Status code for the response (taken from the upstream or PIT search response).
    pub status: u16,
    /// Response body.
    pub body: ByteBody,
}

impl std::fmt::Debug for StreamSearch {
    /// Formats only `status`; the body is elided as `..` because it is a stream.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("StreamSearch");
        out.field("status", &self.status);
        out.finish_non_exhaustive()
    }
}

impl StreamSearch {
    /// Wraps a body that is already a stream.
    #[must_use]
    pub fn stream(status: u16, body: ByteBody) -> Self {
        Self { body, status }
    }

    /// Wraps a fully materialized payload, converting it into a body with
    /// `osproxy_sink::buffered`.
    #[must_use]
    pub fn buffered(status: u16, body: Vec<u8>) -> Self {
        let payload = Bytes::from(body);
        Self::stream(status, buffered(payload))
    }
}
/// State carried between `next_frame` calls while the upstream body is still being read.
struct Active {
    /// Upstream response body, consumed one frame at a time.
    frames: BodyStream<ByteBody>,
    /// Incremental rewriter fed with upstream data; built from a `HitShaper`
    /// in `shape_hits_stream`.
    scanner: SearchHitsScanner,
}
/// Unfold state for `shape_hits_stream`, advanced by `next_frame`.
enum Stage {
    /// Terminal state: upstream ended or errored, so no further frames are produced.
    Done,
    /// Upstream still being consumed. Boxed so `Stage` itself stays small while
    /// it is moved in and out of each `next_frame` call.
    Active(Box<Active>),
}
#[must_use]
pub(crate) fn shape_hits_stream(upstream: ByteBody, shaper: HitShaper) -> ByteBody {
let init = Stage::Active(Box::new(Active {
frames: BodyStream::new(upstream),
scanner: SearchHitsScanner::new(shaper),
}));
let stream = stream::unfold(init, |stage| async move { next_frame(stage).await });
StreamBody::new(stream).boxed_unsync()
}
/// Advances the shaping stream by one output frame.
///
/// Reads upstream frames until the scanner yields non-empty output, then returns it
/// with the stage still active. Non-data frames (e.g. trailers) are skipped. The
/// first upstream error is returned and the stage becomes `Done`. When the upstream
/// is exhausted, the scanner's remaining bytes (if any) are returned as a final
/// frame. Returns `None` when the stage is already `Done`, or when there is nothing
/// left to flush.
async fn next_frame(stage: Stage) -> Option<(Result<Frame<Bytes>, BodyError>, Stage)> {
    let mut active = match stage {
        Stage::Active(active) => active,
        Stage::Done => return None,
    };

    while let Some(item) = active.frames.next().await {
        let frame = match item {
            Ok(frame) => frame,
            Err(err) => return Some((Err(err), Stage::Done)),
        };
        // Only data frames go through the scanner; anything else is ignored.
        if let Ok(chunk) = frame.into_data() {
            let shaped = active.scanner.feed(&chunk);
            // Skip empty output so callers never see zero-length data frames.
            if !shaped.is_empty() {
                return Some((Ok(Frame::data(Bytes::from(shaped))), Stage::Active(active)));
            }
        }
    }

    // Upstream exhausted: flush whatever the scanner is still holding.
    let tail = active.scanner.finish();
    if tail.is_empty() {
        return None;
    }
    Some((Ok(Frame::data(Bytes::from(tail))), Stage::Done))
}
#[cfg(test)]
#[path = "search_stream_tests.rs"]
mod tests;