use std::collections::HashMap;
use bytes::{Buf as _, BytesMut};
use futures_util::stream::StreamExt as _;
use http_body_util::BodyExt as _;
use osproxy_core::Target;
use osproxy_rewrite::{parse_bulk, parse_bulk_action, BulkAction, BulkItem, RewriteError};
use osproxy_sink::{ByteBody, DocOp, OpResult, Sink, SinkError, WriteAck, WriteBatch, WriteOp};
use osproxy_spi::RequestCtx;
use osproxy_tenancy::Router;
use crate::asyncwrite::{
op_id_for, unavailable_response, unsupported_async, unsupported_response, QueuedWrite,
WriteQueue,
};
use crate::bulkline::{BulkBody, Line};
use crate::bulkprep::{prepare, Prepared};
use crate::error::RequestError;
use crate::pipeline::PipelineResponse;
/// A target's buffer is flushed once it holds this many ops.
const FLUSH_THRESHOLD: usize = 256;
/// A target's buffer is flushed once the summed document-body bytes of its ops reach this (4 MiB).
const BYTE_FLUSH_THRESHOLD: usize = 4 << 20;
/// Maximum number of per-target writes the final flush drives concurrently.
const MAX_DISPATCH_CONCURRENCY: usize = 8;
/// Buffered prepared ops, each tagged with its ordinal (position in the bulk request) so the
/// resulting response line can be written back into the matching slot of the response.
type Entries = Vec<(usize, Prepared)>;
/// Executes a fully-buffered bulk request synchronously against `sink`.
///
/// The whole body is parsed up front. Each op is prepared, then buffered per target. A buffer
/// is flushed early when it reaches the count or byte threshold, and the remaining buffers are
/// flushed at the end. The response holds one line per op, in request order. Ops that fail
/// preparation get their failure line directly.
///
/// # Errors
/// Returns an error if the body cannot be parsed or the response cannot be serialized.
pub(crate) async fn ingest_bulk<R: Router, S: Sink>(
    router: &R,
    sink: &S,
    ctx: &RequestCtx<'_>,
    retry: crate::RetryPolicy,
    up_trace: Option<osproxy_core::TraceContext>,
) -> Result<PipelineResponse, RequestError> {
    let parsed = parse_bulk(ctx.body())?;
    let mut slots: Vec<Option<Line>> = std::iter::repeat_with(|| None)
        .take(parsed.len())
        .collect();
    let mut pending: HashMap<Target, Entries> = HashMap::new();
    let mut pending_bytes: HashMap<Target, usize> = HashMap::new();
    let mut resolutions = crate::bulkprep::ResolutionCache::new();
    let trace = up_trace.as_ref();
    for (ordinal, item) in parsed.into_iter().enumerate() {
        let prepared = match prepare(router, ctx, &mut resolutions, item, retry, trace).await {
            Ok(prepared) => prepared,
            Err(fail) => {
                slots[ordinal] = Some(fail.into_line());
                continue;
            }
        };
        buffer_and_flush(
            router,
            sink,
            &mut pending,
            &mut pending_bytes,
            &mut slots,
            ordinal,
            prepared,
        )
        .await;
    }
    flush_remaining(router, sink, pending, &mut slots).await;
    render(&slots)
}
/// Serializes the per-op response lines into a `200` bulk response body.
///
/// The top-level `errors` flag is set when any filled slot is an error line.
///
/// # Errors
/// `RequestError::Internal` if JSON serialization fails.
fn render(lines: &[Option<Line>]) -> Result<PipelineResponse, RequestError> {
    let has_errors = lines.iter().filter_map(Option::as_ref).any(Line::is_error);
    let payload = BulkBody {
        took: 0,
        errors: has_errors,
        items: lines,
    };
    let Ok(body) = serde_json::to_vec(&payload) else {
        return Err(RequestError::Internal {
            reason: "serializing bulk response",
        });
    };
    Ok(PipelineResponse {
        status: 200,
        body,
        content_type: None,
    })
}
/// Largest single NDJSON line the streaming bulk reader accepts (64 MiB).
const MAX_LINE_BYTES: usize = 64 << 20;
/// Executes a bulk request whose body is read op-by-op from a stream.
///
/// This behaves like `ingest_bulk`, except that ops are pulled from `body` through
/// `NdjsonReader` instead of being parsed all at once. A response slot is reserved for each op
/// as soon as it is read.
///
/// # Errors
/// Propagates reader errors (malformed lines, oversized lines, body stream failures) and
/// response serialization errors.
pub(crate) async fn ingest_bulk_streamed<R: Router, S: Sink>(
    router: &R,
    sink: &S,
    ctx: &RequestCtx<'_>,
    body: ByteBody,
    retry: crate::RetryPolicy,
    up_trace: Option<osproxy_core::TraceContext>,
) -> Result<PipelineResponse, RequestError> {
    let mut reader = NdjsonReader::new(body);
    let mut slots: Vec<Option<Line>> = Vec::new();
    let mut pending = HashMap::<Target, Entries>::new();
    let mut pending_bytes = HashMap::<Target, usize>::new();
    let mut resolutions = crate::bulkprep::ResolutionCache::new();
    let trace = up_trace.as_ref();
    while let Some(item) = reader.next_op().await? {
        // Exactly one slot is pushed per op, so the current length is this op's ordinal.
        let ordinal = slots.len();
        slots.push(None);
        match prepare(router, ctx, &mut resolutions, item, retry, trace).await {
            Ok(prepared) => {
                buffer_and_flush(
                    router,
                    sink,
                    &mut pending,
                    &mut pending_bytes,
                    &mut slots,
                    ordinal,
                    prepared,
                )
                .await;
            }
            Err(fail) => slots[ordinal] = Some(fail.into_line()),
        }
    }
    flush_remaining(router, sink, pending, &mut slots).await;
    render(&slots)
}
/// Incremental NDJSON reader over a streamed bulk body.
///
/// Incoming frames are appended to `buf`. `scan` records how much of `buf` has already been
/// searched for a newline, so each pass only scans newly arrived bytes. `done` is set once the
/// body stream reports no more frames.
struct NdjsonReader {
    body: ByteBody,
    buf: BytesMut,
    scan: usize,
    done: bool,
}
impl NdjsonReader {
    fn new(body: ByteBody) -> Self {
        Self {
            body,
            buf: BytesMut::new(),
            scan: 0,
            done: false,
        }
    }

    /// Reads the next bulk op: an action line plus, when the action requires one, a source
    /// line.
    ///
    /// Returns `Ok(None)` once the body is exhausted.
    ///
    /// # Errors
    /// Action parse errors, `MalformedBulkAction` when a required source line is missing, and
    /// any error from `next_line`.
    async fn next_op(&mut self) -> Result<Option<BulkItem>, RequestError> {
        let Some(action_line) = self.next_line().await? else {
            return Ok(None);
        };
        let parsed = parse_bulk_action(&action_line).map_err(RequestError::from)?;
        let source = if parsed.has_source() {
            Some(
                self.next_line()
                    .await?
                    .ok_or_else(|| RequestError::from(RewriteError::MalformedBulkAction))?,
            )
        } else {
            None
        };
        parsed
            .into_item(source.as_deref())
            .map(Some)
            .map_err(RequestError::from)
    }

    /// Returns the next non-blank line, without its `\n` or any trailing `\r`. Returns
    /// `Ok(None)` once the body is exhausted.
    ///
    /// Safe to call again after it has returned the final line or `None`.
    ///
    /// # Errors
    /// `PayloadTooLarge` if a line exceeds `MAX_LINE_BYTES`; `Internal` if the body stream
    /// yields an error.
    async fn next_line(&mut self) -> Result<Option<BytesMut>, RequestError> {
        loop {
            if let Some(rel) = self.buf[self.scan..].iter().position(|&b| b == b'\n') {
                let line = self.buf.split_to(self.scan + rel);
                self.buf.advance(1); // drop the '\n'
                self.scan = 0;
                match Self::finish_line(line)? {
                    Some(line) => return Ok(Some(line)),
                    // Blank line: search the rest of the buffer from the start again.
                    None => continue,
                }
            }
            if self.done {
                // Final, unterminated line. Reset `scan` together with draining `buf`.
                // Otherwise the next call would slice `buf[scan..]` past the end of the now
                // empty buffer and panic.
                let rest = std::mem::take(&mut self.buf);
                self.scan = 0;
                return Self::finish_line(rest);
            }
            // Everything buffered has been searched; only bytes from the next frame need it.
            self.scan = self.buf.len();
            if self.buf.len() > MAX_LINE_BYTES {
                return Err(RequestError::PayloadTooLarge {
                    reason: "bulk line exceeds the per-op size cap",
                });
            }
            match self.body.frame().await {
                Some(Ok(frame)) => {
                    // Non-data frames (e.g. trailers) carry no NDJSON bytes.
                    if let Ok(data) = frame.into_data() {
                        self.buf.extend_from_slice(&data);
                    }
                }
                Some(Err(_)) => {
                    return Err(RequestError::Internal {
                        reason: "reading bulk body stream",
                    })
                }
                None => self.done = true,
            }
        }
    }

    /// Normalizes one extracted line: strips a trailing `\r`, enforces the per-line size cap,
    /// and maps whitespace-only lines to `None`.
    fn finish_line(mut line: BytesMut) -> Result<Option<BytesMut>, RequestError> {
        if line.last() == Some(&b'\r') {
            line.truncate(line.len() - 1);
        }
        if line.len() > MAX_LINE_BYTES {
            return Err(RequestError::PayloadTooLarge {
                reason: "bulk line exceeds the per-op size cap",
            });
        }
        if line.iter().all(u8::is_ascii_whitespace) {
            return Ok(None);
        }
        Ok(Some(line))
    }
}
pub(crate) async fn ingest_bulk_async<R: Router>(
router: &R,
queue: &dyn WriteQueue,
ctx: &RequestCtx<'_>,
retry: crate::RetryPolicy,
up_trace: Option<osproxy_core::TraceContext>,
) -> Result<PipelineResponse, RequestError> {
let index = ctx.logical_index();
if let Some(reason) = unsupported_async(ctx) {
return Ok(unsupported_response(reason, index));
}
if !queue.enabled() {
return Ok(unavailable_response(index));
}
let items = parse_bulk(ctx.body())?;
let batch_id = op_id_for(ctx, ctx.request_id());
let mut lines: Vec<Option<Line>> = Vec::new();
lines.resize_with(items.len(), || None);
let mut cache = crate::bulkprep::ResolutionCache::new();
for (ordinal, item) in items.into_iter().enumerate() {
if matches!(item.action, BulkAction::Update) || item.concurrency_control {
let logical_index = item.index.clone().unwrap_or_else(|| index.to_owned());
lines[ordinal] = Some(Line::error(
item.action.keyword(),
logical_index,
item.id.clone(),
400,
"unsupported_async",
));
continue;
}
match prepare(router, ctx, &mut cache, item, retry, up_trace.as_ref()).await {
Ok(p) => {
let op_id = format!("{batch_id}:{ordinal}");
let write = QueuedWrite {
op_id: op_id.clone(),
partition_key: p.partition.as_str().to_owned(),
batch: WriteBatch::single(p.op.clone()),
};
lines[ordinal] = Some(match queue.enqueue(write).await {
Ok(()) => queued_line(&p, op_id),
Err(_) => enqueue_failed_line(&p),
});
}
Err(fail) => lines[ordinal] = Some(fail.into_line()),
}
}
render(&lines)
}
/// Builds the `202` response line for an op that was enqueued under `op_id`.
fn queued_line(p: &Prepared, op_id: String) -> Line {
    let index = p.logical_index.clone();
    let id = p.logical_id.clone();
    Line::queued(p.action, index, Some(id), 202, op_id)
}
/// Builds the `503 enqueue_failed` line for an op the write queue refused.
fn enqueue_failed_line(p: &Prepared) -> Line {
    let (status, kind) = (503, "enqueue_failed");
    error_line(p, status, kind)
}
/// Adds a prepared op to its target's buffer. Flushes that buffer once it holds
/// `FLUSH_THRESHOLD` ops, or once its summed body bytes reach `BYTE_FLUSH_THRESHOLD`.
///
/// A flush removes the target from both `buffers` and `sizes`, so accumulation restarts.
async fn buffer_and_flush<R: Router, S: Sink>(
    router: &R,
    sink: &S,
    buffers: &mut HashMap<Target, Entries>,
    sizes: &mut HashMap<Target, usize>,
    lines: &mut [Option<Line>],
    ordinal: usize,
    prepared: Prepared,
) {
    let key = prepared.op.target.clone();
    let weight = op_body_len(&prepared.op);
    let queued = buffers.entry(key.clone()).or_default();
    queued.push((ordinal, prepared));
    let count_reached = queued.len() >= FLUSH_THRESHOLD;
    let total = sizes.entry(key.clone()).or_default();
    *total += weight;
    let bytes_reached = *total >= BYTE_FLUSH_THRESHOLD;
    if !(count_reached || bytes_reached) {
        return;
    }
    let batch = buffers.remove(&key).unwrap_or_default();
    sizes.remove(&key);
    flush(router, sink, batch, lines).await;
}
/// Returns the length of the op's document body; deletes have no body and count as zero.
fn op_body_len(op: &WriteOp) -> usize {
    match &op.doc {
        DocOp::Delete { .. } => 0,
        DocOp::Index { body, .. } => body.len(),
        DocOp::Create { body, .. } => body.len(),
        DocOp::Update { body, .. } => body.len(),
    }
}
/// Gates one target's buffered ops by epoch and writes the admitted ones to the sink.
///
/// Rejected ops get a `409 stale_epoch` line. Admitted ops get lines derived from the sink's
/// result. If the gate rejects every op, no write is issued: an empty batch would only cost an
/// upstream round-trip, and any resulting error could not be attributed to an op.
async fn flush<R: Router, S: Sink>(
    router: &R,
    sink: &S,
    entries: Entries,
    lines: &mut [Option<Line>],
) {
    let (admitted, rejected) = gate(router, entries).await;
    for (ordinal, p) in &rejected {
        lines[*ordinal] = Some(stale_epoch_line(p));
    }
    if admitted.is_empty() {
        return;
    }
    let ack = sink.write(build_batch(&admitted)).await;
    apply_results(&admitted, ack, lines);
}
/// Flushes every non-empty remaining buffer. At most `MAX_DISPATCH_CONCURRENCY` targets are
/// gated and written at once.
///
/// Results are collected first and then applied to `lines`, so the concurrent futures never
/// need mutable access to the response slots. A target whose ops were all rejected by the gate
/// issues no write (its ack is `None`).
async fn flush_remaining<R: Router, S: Sink>(
    router: &R,
    sink: &S,
    buffers: HashMap<Target, Entries>,
    lines: &mut [Option<Line>],
) {
    // (admitted, rejected, ack) — `None` ack means nothing was admitted, so nothing was sent.
    type Flushed = (Entries, Entries, Option<Result<WriteAck, SinkError>>);
    let pending = buffers.into_values().filter(|v| !v.is_empty());
    let results: Vec<Flushed> = futures_util::stream::iter(pending)
        .map(|entries| async move {
            let (admitted, rejected) = gate(router, entries).await;
            let ack = if admitted.is_empty() {
                None
            } else {
                Some(sink.write(build_batch(&admitted)).await)
            };
            (admitted, rejected, ack)
        })
        .buffer_unordered(MAX_DISPATCH_CONCURRENCY)
        .collect()
        .await;
    for (admitted, rejected, ack) in results {
        for (ordinal, p) in &rejected {
            lines[*ordinal] = Some(stale_epoch_line(p));
        }
        if let Some(ack) = ack {
            apply_results(&admitted, ack, lines);
        }
    }
}
/// Asks the router, one op at a time, whether each op's epoch is still admitted for its
/// partition.
///
/// Returns `(admitted, rejected)`, each keeping the input order.
async fn gate<R: Router>(router: &R, entries: Entries) -> (Entries, Entries) {
    let (mut passed, mut stale) = (Entries::new(), Entries::new());
    for entry in entries {
        let (_, prepared) = &entry;
        let fresh = router
            .admit_write(&prepared.partition, prepared.op.epoch)
            .await;
        let bucket = if fresh { &mut passed } else { &mut stale };
        bucket.push(entry);
    }
    (passed, stale)
}
/// Builds the `409 stale_epoch` line for an op the epoch gate rejected.
fn stale_epoch_line(p: &Prepared) -> Line {
    let (status, kind) = (409, "stale_epoch");
    error_line(p, status, kind)
}
fn error_line(p: &Prepared, status: u16, error: &'static str) -> Line {
Line::error(
p.action,
p.logical_index.clone(),
Some(p.logical_id.clone()),
status,
error,
)
}
/// Builds a sink batch holding a clone of each entry's write op, in entry order.
fn build_batch(entries: &[(usize, Prepared)]) -> WriteBatch {
    let mut batch = WriteBatch::new();
    for (_, prepared) in entries {
        batch = batch.with(prepared.op.clone());
    }
    batch
}
/// Writes per-op response lines for a batch that was sent to the sink.
///
/// On success, ack results are paired with `entries` by position. An entry with no
/// corresponding result (the ack is shorter than the batch) is reported as an upstream failure.
/// Previously such entries were left as an empty slot in the response. If the sink returned an
/// error, every entry gets an upstream-failure line.
fn apply_results(
    entries: &[(usize, Prepared)],
    result: Result<WriteAck, SinkError>,
    lines: &mut [Option<Line>],
) {
    match result {
        Ok(ack) => {
            let mut answered = 0;
            for ((ordinal, p), op_result) in entries.iter().zip(ack.results()) {
                lines[*ordinal] = Some(success_line(p, op_result));
                answered += 1;
            }
            // `zip` stops at the shorter side; never leave an op without a response line.
            for (ordinal, p) in &entries[answered..] {
                lines[*ordinal] = Some(upstream_failure_line(p));
            }
        }
        Err(_) => {
            for (ordinal, p) in entries {
                lines[*ordinal] = Some(upstream_failure_line(p));
            }
        }
    }
}
fn success_line(p: &Prepared, result: &OpResult) -> Line {
if result.status >= 400 {
return error_line(p, result.status, error_type_for(result.status));
}
let outcome = if result.created { "created" } else { "updated" };
Line::result(
p.action,
p.logical_index.clone(),
Some(p.logical_id.clone()),
result.status,
outcome,
)
}
/// Maps a failing per-op status code to the error type reported in its response line.
///
/// `404` is `"not_found"`, `409` is `"conflict"`, and any other status is `"rejected"`.
fn error_type_for(status: u16) -> &'static str {
    if status == 404 {
        "not_found"
    } else if status == 409 {
        "conflict"
    } else {
        "rejected"
    }
}
/// Builds the `502 upstream_failed` line for an op whose sink write did not yield a result.
fn upstream_failure_line(p: &Prepared) -> Line {
    let (status, kind) = (502, "upstream_failed");
    error_line(p, status, kind)
}