use std::collections::HashMap;
use bytes::{Buf as _, BytesMut};
use futures_util::stream::StreamExt as _;
use http_body_util::BodyExt as _;
use osproxy_core::{PartitionId, Target};
use osproxy_rewrite::{
parse_bulk, parse_bulk_action, parse_bulk_op, BulkAction, BulkItem, RewriteError,
};
use osproxy_sink::{ByteBody, DocOp, OpResult, Sink, SinkError, WriteAck, WriteBatch, WriteOp};
use osproxy_spi::RequestCtx;
use osproxy_tenancy::{Resolved, Router};
use serde_json::{json, Value};
use crate::asyncwrite::{
op_id_for, unavailable_response, unsupported_async, unsupported_response, QueuedWrite,
WriteQueue,
};
use crate::bulkprep::{prepare, Prepared};
use crate::error::RequestError;
use crate::pipeline::PipelineResponse;
/// Number of buffered ops for a single target at which that target's buffer is flushed.
const FLUSH_THRESHOLD: usize = 256;
/// Accumulated op body bytes for a single target at which that target's buffer is
/// flushed (4 MiB).
const BYTE_FLUSH_THRESHOLD: usize = 4 * 1024 * 1024;
/// Upper bound on per-target flushes kept in flight at once by `flush_remaining`.
const MAX_DISPATCH_CONCURRENCY: usize = 8;
/// Buffered ops, each paired with its ordinal (its index in the response `items`
/// array) so a result can be written back to the matching response line.
type Entries = Vec<(usize, Prepared)>;
/// Synchronous bulk ingest over a fully buffered request body.
///
/// The body is parsed into bulk items up front. Each item is prepared and
/// buffered per target; a target's buffer is flushed to `sink` as soon as it
/// crosses a flush threshold, and whatever is still buffered at the end is
/// flushed afterwards. Every item owns one response line, addressed by its
/// ordinal, so the response `items` array keeps the request order.
///
/// # Errors
///
/// Returns an error when the body cannot be parsed as a bulk request or when
/// the response cannot be serialized. Per-item failures do not fail the
/// request: they are reported in the item's response line, and the response
/// status stays 200.
pub(crate) async fn ingest_bulk<R: Router, S: Sink>(
    router: &R,
    sink: &S,
    ctx: &RequestCtx<'_>,
    retry: crate::RetryPolicy,
    up_trace: Option<osproxy_core::TraceContext>,
) -> Result<PipelineResponse, RequestError> {
    let items = parse_bulk(ctx.body())?;
    let mut lines: Vec<Value> = vec![Value::Null; items.len()];
    let mut buffers: HashMap<Target, Entries> = HashMap::new();
    let mut sizes: HashMap<Target, usize> = HashMap::new();
    let mut cache: HashMap<(PartitionId, String), Resolved> = HashMap::new();
    let trace = up_trace.as_ref();

    for (ordinal, item) in items.into_iter().enumerate() {
        let prepared = match prepare(router, ctx, &mut cache, item, retry, trace).await {
            Ok(prepared) => prepared,
            Err(fail) => {
                // A failed preparation only affects this item's line.
                lines[ordinal] = fail.into_line();
                continue;
            }
        };
        buffer_and_flush(
            router,
            sink,
            &mut buffers,
            &mut sizes,
            &mut lines,
            ordinal,
            prepared,
        )
        .await;
    }
    flush_remaining(router, sink, buffers, &mut lines).await;

    let errors = lines.iter().any(is_error_line);
    let payload = json!({ "took": 0, "errors": errors, "items": lines });
    let body = serde_json::to_vec(&payload).map_err(|_| RequestError::Internal {
        reason: "serializing bulk response",
    })?;
    Ok(PipelineResponse {
        status: 200,
        body,
        content_type: None,
    })
}
/// Cap (64 MiB) on bytes buffered by `NdjsonReader` without finding a newline.
/// The reader checks it before pulling another body frame and fails with
/// `PayloadTooLarge` once the buffer is larger than this.
const MAX_LINE_BYTES: usize = 64 * 1024 * 1024;
/// Synchronous bulk ingest that reads the NDJSON body incrementally.
///
/// Operations are pulled one at a time from `body` through an
/// [`NdjsonReader`], prepared, and buffered per target, with the same
/// threshold-driven flushing as the buffered path. The number of operations is
/// not known up front, so a response line is appended for each operation as it
/// is read.
///
/// # Errors
///
/// Returns an error when reading or parsing the stream fails, or when the
/// response cannot be serialized. A stream error aborts the request even if
/// earlier batches have already been flushed to `sink`.
pub(crate) async fn ingest_bulk_streamed<R: Router, S: Sink>(
    router: &R,
    sink: &S,
    ctx: &RequestCtx<'_>,
    body: ByteBody,
    retry: crate::RetryPolicy,
    up_trace: Option<osproxy_core::TraceContext>,
) -> Result<PipelineResponse, RequestError> {
    let mut reader = NdjsonReader::new(body);
    let mut lines: Vec<Value> = Vec::new();
    let mut buffers: HashMap<Target, Entries> = HashMap::new();
    let mut sizes: HashMap<Target, usize> = HashMap::new();
    let mut cache: HashMap<(PartitionId, String), Resolved> = HashMap::new();
    let trace = up_trace.as_ref();

    while let Some(item) = reader.next_op().await? {
        // The next free slot in `lines` is this operation's ordinal.
        let slot = lines.len();
        lines.push(Value::Null);
        match prepare(router, ctx, &mut cache, item, retry, trace).await {
            Err(fail) => lines[slot] = fail.into_line(),
            Ok(prepared) => {
                buffer_and_flush(
                    router,
                    sink,
                    &mut buffers,
                    &mut sizes,
                    &mut lines,
                    slot,
                    prepared,
                )
                .await;
            }
        }
    }
    flush_remaining(router, sink, buffers, &mut lines).await;

    let errors = lines.iter().any(is_error_line);
    let payload = json!({ "took": 0, "errors": errors, "items": lines });
    let body = serde_json::to_vec(&payload).map_err(|_| RequestError::Internal {
        reason: "serializing bulk response",
    })?;
    Ok(PipelineResponse {
        status: 200,
        body,
        content_type: None,
    })
}
/// Incremental reader that splits a streamed request body into NDJSON lines
/// and assembles them into bulk operations.
struct NdjsonReader {
// Body stream that frames are pulled from.
body: ByteBody,
// Bytes received from the body that have not yet been handed out as a line.
buf: BytesMut,
// Offset into `buf` where the newline search resumes; bytes before it were
// already searched without finding a newline.
scan: usize,
// Set once the body stream has reported its end.
done: bool,
}
impl NdjsonReader {
    /// Creates a reader over `body` with an empty line buffer.
    fn new(body: ByteBody) -> Self {
        Self {
            body,
            buf: BytesMut::new(),
            scan: 0,
            done: false,
        }
    }

    /// Reads the next bulk operation from the stream.
    ///
    /// An operation is an action line, followed by a source line when the
    /// action carries one. Returns `Ok(None)` once the stream holds no further
    /// non-blank lines.
    ///
    /// # Errors
    ///
    /// Fails when the action line cannot be parsed, when an action that needs
    /// a source line is the last line of the body, when the operation itself
    /// cannot be parsed, or when reading the next line fails.
    async fn next_op(&mut self) -> Result<Option<BulkItem>, RequestError> {
        let Some(action_line) = self.next_line().await? else {
            return Ok(None);
        };
        let action = parse_bulk_action(&action_line).map_err(RequestError::from)?;
        let source = if action.has_source() {
            Some(
                self.next_line()
                    .await?
                    .ok_or_else(|| RequestError::from(RewriteError::MalformedBulkAction))?,
            )
        } else {
            None
        };
        parse_bulk_op(&action_line, source.as_deref())
            .map(Some)
            .map_err(RequestError::from)
    }

    /// Returns the next non-blank line without its terminator.
    ///
    /// Lines end at `\n`, and a `\r` directly before it is dropped. Lines made
    /// up only of ASCII whitespace are skipped. A final line with no trailing
    /// newline is returned once the body has ended. Returns `Ok(None)` when
    /// the body has ended and only whitespace (or nothing) is left.
    ///
    /// # Errors
    ///
    /// Returns `PayloadTooLarge` when more than `MAX_LINE_BYTES` are buffered
    /// without a newline, and `Internal` when the body stream reports an error.
    async fn next_line(&mut self) -> Result<Option<BytesMut>, RequestError> {
        loop {
            // Resume the search at `scan`: earlier bytes were already checked.
            if let Some(rel) = self.buf[self.scan..].iter().position(|&b| b == b'\n') {
                let nl = self.scan + rel;
                let mut line = self.buf.split_to(nl);
                // Drop the newline itself; the next search starts from the front.
                self.buf.advance(1);
                self.scan = 0;
                if line.last() == Some(&b'\r') {
                    line.truncate(line.len() - 1);
                }
                if line.iter().all(u8::is_ascii_whitespace) {
                    continue;
                }
                return Ok(Some(line));
            }
            // Nothing buffered so far contains a newline.
            self.scan = self.buf.len();
            if self.done {
                // Drain the buffer and reset the scan offset together. Leaving
                // `scan` at the old length would make the next call slice past
                // the end of the now-empty buffer and panic.
                let rest = std::mem::take(&mut self.buf);
                self.scan = 0;
                if rest.iter().all(u8::is_ascii_whitespace) {
                    return Ok(None);
                }
                return Ok(Some(rest));
            }
            if self.buf.len() > MAX_LINE_BYTES {
                return Err(RequestError::PayloadTooLarge {
                    reason: "bulk line exceeds the per-op size cap",
                });
            }
            match self.body.frame().await {
                Some(Ok(frame)) => {
                    // Frames that carry no data are skipped.
                    if let Ok(data) = frame.into_data() {
                        self.buf.extend_from_slice(&data);
                    }
                }
                Some(Err(_)) => {
                    return Err(RequestError::Internal {
                        reason: "reading bulk body stream",
                    })
                }
                None => self.done = true,
            }
        }
    }
}
/// Asynchronous bulk ingest: operations are enqueued on `queue` instead of
/// being written to a sink.
///
/// Returns early with an "unsupported" response when `unsupported_async`
/// rejects the request, and with an "unavailable" response when the queue is
/// not enabled. Otherwise each item gets its own response line:
///
/// * updates and items using concurrency control: status 400, `unsupported_async`;
/// * items that fail preparation: the failure's own line;
/// * enqueued items: status 202 with an `op_id` of the form `<batch id>:<ordinal>`;
/// * items the queue refused: status 503, `enqueue_failed`.
///
/// # Errors
///
/// Returns an error when the body cannot be parsed as a bulk request or when
/// the response cannot be serialized.
pub(crate) async fn ingest_bulk_async<R: Router>(
    router: &R,
    queue: &dyn WriteQueue,
    ctx: &RequestCtx<'_>,
    retry: crate::RetryPolicy,
    up_trace: Option<osproxy_core::TraceContext>,
) -> Result<PipelineResponse, RequestError> {
    let index = ctx.logical_index();
    if let Some(reason) = unsupported_async(ctx) {
        return Ok(unsupported_response(reason, index));
    }
    if !queue.enabled() {
        return Ok(unavailable_response(index));
    }

    let items = parse_bulk(ctx.body())?;
    let batch_id = op_id_for(ctx, ctx.request_id());
    let mut lines: Vec<Value> = vec![Value::Null; items.len()];
    let mut cache: HashMap<(PartitionId, String), Resolved> = HashMap::new();
    let trace = up_trace.as_ref();

    for (ordinal, item) in items.into_iter().enumerate() {
        let sync_only = matches!(item.action, BulkAction::Update) || item.concurrency_control;
        if sync_only {
            // Fall back to the request's logical index when the item names none.
            let item_index = item.index.clone().unwrap_or_else(|| index.to_owned());
            lines[ordinal] = json!({ item.action.keyword(): {
                "_index": item_index,
                "_id": item.id,
                "status": 400,
                "error": { "type": "unsupported_async" },
            }});
            continue;
        }

        let p = match prepare(router, ctx, &mut cache, item, retry, trace).await {
            Ok(p) => p,
            Err(fail) => {
                lines[ordinal] = fail.into_line();
                continue;
            }
        };

        let op_id = format!("{batch_id}:{ordinal}");
        let write = QueuedWrite {
            op_id: op_id.clone(),
            partition_key: p.partition.as_str().to_owned(),
            batch: WriteBatch::single(p.op.clone()),
        };
        lines[ordinal] = if queue.enqueue(write).await.is_ok() {
            queued_line(&p, &op_id)
        } else {
            enqueue_failed_line(&p)
        };
    }

    let errors = lines.iter().any(is_error_line);
    let payload = json!({ "took": 0, "errors": errors, "items": lines });
    let body = serde_json::to_vec(&payload).map_err(|_| RequestError::Internal {
        reason: "serializing bulk response",
    })?;
    Ok(PipelineResponse {
        status: 200,
        body,
        content_type: None,
    })
}
/// Response line for an operation accepted onto the write queue: status 202,
/// result `"queued"`, plus the `op_id` under which it was enqueued.
fn queued_line(p: &Prepared, op_id: &str) -> Value {
    let detail = json!({
        "_index": p.logical_index,
        "_id": p.logical_id,
        "op_id": op_id,
        "status": 202,
        "result": "queued",
    });
    json!({ p.action: detail })
}
/// Response line for an operation the write queue refused: status 503 with
/// error type `enqueue_failed`.
fn enqueue_failed_line(p: &Prepared) -> Value {
    let detail = json!({
        "_index": p.logical_index,
        "_id": p.logical_id,
        "status": 503,
        "error": { "type": "enqueue_failed" },
    });
    json!({ p.action: detail })
}
/// Adds `prepared` to its target's buffer and flushes that target if a
/// threshold is reached.
///
/// A target is flushed when it holds at least `FLUSH_THRESHOLD` ops or when
/// the op body bytes accumulated for it reach `BYTE_FLUSH_THRESHOLD`. Flushing
/// removes both the target's buffer and its byte counter, so the next op for
/// that target starts from zero.
async fn buffer_and_flush<R: Router, S: Sink>(
    router: &R,
    sink: &S,
    buffers: &mut HashMap<Target, Entries>,
    sizes: &mut HashMap<Target, usize>,
    lines: &mut [Value],
    ordinal: usize,
    prepared: Prepared,
) {
    // Read what is needed from `prepared` before it moves into the buffer.
    let target = prepared.op.target.clone();
    let op_bytes = op_body_len(&prepared.op);

    let pending = buffers.entry(target.clone()).or_default();
    pending.push((ordinal, prepared));
    let count_reached = pending.len() >= FLUSH_THRESHOLD;

    let buffered_bytes = sizes.entry(target.clone()).or_default();
    *buffered_bytes += op_bytes;
    let bytes_reached = *buffered_bytes >= BYTE_FLUSH_THRESHOLD;

    if !(count_reached || bytes_reached) {
        return;
    }
    let entries = buffers.remove(&target).unwrap_or_default();
    sizes.remove(&target);
    flush(router, sink, entries, lines).await;
}
/// Length of the document body carried by `op`; deletes count as zero.
fn op_body_len(op: &WriteOp) -> usize {
    match &op.doc {
        DocOp::Delete { .. } => 0,
        DocOp::Index { body, .. } => body.len(),
        DocOp::Create { body, .. } => body.len(),
        DocOp::Update { body, .. } => body.len(),
    }
}
/// Flushes one target's buffered entries.
///
/// Entries are first passed through the router's admission check. Rejected
/// entries get a stale-epoch line; admitted ones are written to `sink` as a
/// single batch, and the outcome is recorded in their lines.
async fn flush<R: Router, S: Sink>(router: &R, sink: &S, entries: Entries, lines: &mut [Value]) {
    let (admitted, rejected) = gate(router, entries).await;
    rejected
        .iter()
        .for_each(|(ordinal, p)| lines[*ordinal] = stale_epoch_line(p));
    let outcome = sink.write(build_batch(&admitted)).await;
    apply_results(&admitted, outcome, lines);
}
/// Flushes every buffer that still holds entries once the request is fully read.
///
/// Each non-empty target buffer is gated and written as its own batch, with at
/// most `MAX_DISPATCH_CONCURRENCY` of these in flight at once. Response lines
/// are filled in only after all flushes have completed; they are addressed by
/// ordinal, so the completion order of the flushes does not affect the
/// response order.
async fn flush_remaining<R: Router, S: Sink>(
    router: &R,
    sink: &S,
    buffers: HashMap<Target, Entries>,
    lines: &mut [Value],
) {
    // (admitted entries, rejected entries, sink outcome for the admitted ones)
    type Outcome = (Entries, Entries, Result<WriteAck, SinkError>);

    let non_empty = buffers.into_values().filter(|entries| !entries.is_empty());
    let outcomes: Vec<Outcome> = futures_util::stream::iter(non_empty)
        .map(|entries| async move {
            let (admitted, rejected) = gate(router, entries).await;
            let ack = sink.write(build_batch(&admitted)).await;
            (admitted, rejected, ack)
        })
        .buffer_unordered(MAX_DISPATCH_CONCURRENCY)
        .collect()
        .await;

    for (admitted, rejected, ack) in outcomes {
        for (ordinal, stale) in &rejected {
            lines[*ordinal] = stale_epoch_line(stale);
        }
        apply_results(&admitted, ack, lines);
    }
}
/// Splits `entries` into `(admitted, rejected)` by asking the router, one
/// entry at a time, whether a write for the entry's partition and epoch is
/// admitted. Relative order is preserved within each group.
async fn gate<R: Router>(router: &R, entries: Entries) -> (Entries, Entries) {
    let mut admitted = Entries::new();
    let mut rejected = Entries::new();
    for entry in entries {
        let allowed = router
            .admit_write(&entry.1.partition, entry.1.op.epoch)
            .await;
        let bucket = if allowed {
            &mut admitted
        } else {
            &mut rejected
        };
        bucket.push(entry);
    }
    (admitted, rejected)
}
/// Response line for an operation the router refused to admit: status 409
/// with error type `stale_epoch`.
fn stale_epoch_line(p: &Prepared) -> Value {
    let detail = json!({
        "_index": p.logical_index,
        "_id": p.logical_id,
        "status": 409,
        "error": { "type": "stale_epoch" },
    });
    json!({ p.action: detail })
}
/// Builds a write batch holding a clone of every entry's op, in entry order.
fn build_batch(entries: &[(usize, Prepared)]) -> WriteBatch {
    let mut batch = WriteBatch::new();
    for (_, p) in entries {
        batch = batch.with(p.op.clone());
    }
    batch
}
/// Records the outcome of one sink write in the response lines of its entries.
///
/// * `entries` - the ops that were sent, in batch order, with their ordinals.
/// * `result` - what the sink returned for the batch.
/// * `lines` - response lines, indexed by ordinal.
///
/// On a sink error every entry gets an upstream-failure line. On success the
/// ack's results are matched to the entries by position. Entries the ack holds
/// no result for are also reported as upstream failures, so that no response
/// line is left as `null` and missed by the `errors` flag.
fn apply_results(
    entries: &[(usize, Prepared)],
    result: Result<WriteAck, SinkError>,
    lines: &mut [Value],
) {
    let Ok(ack) = result else {
        for (ordinal, p) in entries {
            lines[*ordinal] = upstream_failure_line(p);
        }
        return;
    };

    // `zip` stops at the shorter side, so count how many entries got a result.
    let mut acknowledged = 0usize;
    for ((ordinal, p), op_result) in entries.iter().zip(ack.results()) {
        lines[*ordinal] = success_line(p, op_result);
        acknowledged += 1;
    }
    // An op without a result has no confirmed outcome; report it as failed.
    for (ordinal, p) in entries.iter().skip(acknowledged) {
        lines[*ordinal] = upstream_failure_line(p);
    }
}
/// Response line for an op the sink returned a result for.
///
/// A status of 400 or above yields an error line whose type comes from
/// `error_type_for`. Anything lower yields a result line saying `"created"`
/// when the sink reports the document as created and `"updated"` otherwise.
///
/// NOTE(review): only those two outcomes exist here, so a successful delete
/// would presumably also be reported as "updated" — confirm this is intended.
fn success_line(p: &Prepared, result: &OpResult) -> Value {
    let status = result.status;
    let detail = if status >= 400 {
        json!({
            "_index": p.logical_index,
            "_id": p.logical_id,
            "status": status,
            "error": { "type": error_type_for(status) },
        })
    } else {
        let outcome = if result.created { "created" } else { "updated" };
        json!({
            "_index": p.logical_index,
            "_id": p.logical_id,
            "status": status,
            "result": outcome,
        })
    };
    json!({ p.action: detail })
}
/// Maps a per-op status code to the error type used in the response line:
/// 409 is `"conflict"`, 404 is `"not_found"`, and every other status is
/// `"rejected"`.
fn error_type_for(status: u16) -> &'static str {
    if status == 409 {
        "conflict"
    } else if status == 404 {
        "not_found"
    } else {
        "rejected"
    }
}
/// Response line for an op whose batch write failed at the sink: status 502
/// with error type `upstream_failed`.
fn upstream_failure_line(p: &Prepared) -> Value {
    let detail = json!({
        "_index": p.logical_index,
        "_id": p.logical_id,
        "status": 502,
        "error": { "type": "upstream_failed" },
    });
    json!({ p.action: detail })
}
/// Reports whether a response line describes a failed op.
///
/// A line counts as an error when it is a JSON object whose first value has an
/// `"error"` member. Anything else, including `null`, does not.
fn is_error_line(line: &Value) -> bool {
    let Some(object) = line.as_object() else {
        return false;
    };
    let Some(detail) = object.values().next() else {
        return false;
    };
    detail.get("error").is_some()
}