use std::future::Future;
use std::pin::Pin;
use osproxy_core::RequestId;
use osproxy_sink::{Reader, Sink, WriteBatch};
use osproxy_spi::RequestCtx;
use osproxy_tenancy::{Resolved, Router};
use serde_json::json;
use crate::pipeline::{Pipeline, PipelineResponse};
/// How a write request is carried out by the proxy.
///
/// `Sync` is the default variant. `Async` selects the path that hands the
/// write to a queue instead of applying it inline.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum WriteMode {
    /// The default mode.
    #[default]
    Sync,
    /// The queued mode.
    Async,
}

impl WriteMode {
    /// Parses a mode name, ignoring ASCII case.
    ///
    /// Returns `Some(WriteMode::Sync)` for `"sync"`, `Some(WriteMode::Async)`
    /// for `"async"`, and `None` for anything else. The input is not trimmed,
    /// so surrounding whitespace makes a value unrecognised.
    #[must_use]
    pub fn parse(value: &str) -> Option<Self> {
        // Every recognised spelling and the variant it maps to.
        const KNOWN: [(&str, WriteMode); 2] =
            [("sync", WriteMode::Sync), ("async", WriteMode::Async)];

        KNOWN
            .iter()
            .find(|(name, _)| value.eq_ignore_ascii_case(name))
            .map(|&(_, mode)| mode)
    }
}
/// Longest operation id, in bytes, that [`valid_op_id`] accepts.
const MAX_OP_ID_LEN: usize = 128;

/// Reports whether `candidate` is acceptable as an operation id.
///
/// An id is valid when it is non-empty, at most [`MAX_OP_ID_LEN`] bytes long,
/// and made up only of ASCII letters, ASCII digits, `-`, `_`, `.` and `:`.
#[must_use]
pub fn valid_op_id(candidate: &str) -> bool {
    // Length bounds first: both checks are O(1) and reject the common bad cases.
    if candidate.is_empty() || candidate.len() > MAX_OP_ID_LEN {
        return false;
    }

    candidate.bytes().all(|byte| match byte {
        b'-' | b'_' | b'.' | b':' => true,
        other => other.is_ascii_alphanumeric(),
    })
}
/// Chooses the operation id for a request.
///
/// The value of the `x-op-id` header is used when the header is present and
/// passes [`valid_op_id`]. In every other case (header missing or rejected)
/// the id falls back to `request_id`.
#[must_use]
pub fn op_id_for(ctx: &RequestCtx<'_>, request_id: &RequestId) -> String {
    match ctx.headers().get("x-op-id") {
        Some(supplied) if valid_op_id(supplied) => supplied.to_owned(),
        _ => request_id.as_str().to_owned(),
    }
}
/// Explains why a request cannot be served in async write mode, if it cannot.
///
/// Returns `Some(reason)` when either
/// * the query string carries an `if_seq_no`, `if_primary_term` or `version`
///   parameter, or
/// * any `/`-separated path segment is exactly `_update`.
///
/// Returns `None` otherwise. The query check runs first, so its reason wins
/// when both conditions hold. Parameter names are compared as raw text; no
/// percent-decoding is applied.
#[must_use]
pub fn unsupported_async(ctx: &RequestCtx<'_>) -> Option<&'static str> {
    // Query parameters that request optimistic concurrency control.
    const CAS_PARAMS: [&str; 3] = ["if_seq_no", "if_primary_term", "version"];

    if let Some(query) = ctx.query() {
        for pair in query.split('&') {
            // A pair without `=` is treated as a bare key.
            let name = pair.split_once('=').map_or(pair, |(name, _)| name);
            if CAS_PARAMS.contains(&name) {
                return Some("optimistic concurrency (if_seq_no/if_primary_term/version) is not supported in async write mode");
            }
        }
    }

    let targets_update = ctx.path().split('/').any(|segment| segment == "_update");
    if targets_update {
        Some("scripted/partial _update is not supported in async write mode")
    } else {
        None
    }
}
/// A write handed to a [`WriteQueue`].
#[derive(Debug, Clone)]
pub struct QueuedWrite {
    /// Operation id reported back to the client for this write.
    pub op_id: String,
    /// Partition key associated with the write.
    pub partition_key: String,
    /// The batch of write operations to apply.
    pub batch: WriteBatch,
}
/// A queue that accepts writes for async write mode.
///
/// Implementations must be `Send + Sync`.
pub trait WriteQueue: Send + Sync {
    /// Whether this queue can accept writes.
    ///
    /// The provided implementation returns `false`; implementations that can
    /// actually enqueue are expected to override it.
    fn enabled(&self) -> bool {
        false
    }

    /// Submits `write` to the queue.
    ///
    /// Returns a boxed, `Send` future that borrows `self` and resolves to
    /// `Ok(())` on success or a [`QueueError`] on failure.
    fn enqueue<'a>(
        &'a self,
        write: QueuedWrite,
    ) -> Pin<Box<dyn Future<Output = Result<(), QueueError>> + Send + 'a>>;
}
/// Error returned when a write could not be enqueued.
///
/// Implements [`std::fmt::Display`] (printing `reason` verbatim) and
/// [`std::error::Error`], so it can be propagated with `?` into
/// `Box<dyn Error>` and similar error containers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct QueueError {
    /// Static, human-readable description of the failure.
    pub reason: &'static str,
}

impl std::fmt::Display for QueueError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.reason)
    }
}

impl std::error::Error for QueueError {}
/// A [`WriteQueue`] that never accepts writes.
///
/// It keeps the trait's default `enabled()` (which returns `false`) and
/// rejects every `enqueue` call.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoQueue;

impl WriteQueue for NoQueue {
    /// Drops `_write` and resolves immediately to a [`QueueError`] stating
    /// that no queue is configured.
    fn enqueue<'a>(
        &'a self,
        _write: QueuedWrite,
    ) -> Pin<Box<dyn Future<Output = Result<(), QueueError>> + Send + 'a>> {
        let refusal = QueueError {
            reason: "async write queue is not configured",
        };
        // Nothing to await: hand back an already-completed future.
        Box::pin(std::future::ready(Err(refusal)))
    }
}
/// Builds the HTTP 202 response for a write that was queued.
///
/// The JSON body carries `op_id`, `"status": "accepted"`,
/// `"result": "queued"` and `_index`. If serialisation fails the body falls
/// back to `{}`. `content_type` is left unset.
#[must_use]
pub(crate) fn accepted_response(op_id: &str, index: &str) -> PipelineResponse {
    let payload = json!({
        "op_id": op_id,
        "status": "accepted",
        "result": "queued",
        "_index": index,
    });
    let body = match serde_json::to_vec(&payload) {
        Ok(bytes) => bytes,
        Err(_) => b"{}".to_vec(),
    };

    PipelineResponse {
        status: 202,
        body,
        content_type: None,
    }
}
/// Builds the HTTP 400 response for a request that async mode cannot serve.
///
/// The JSON body carries `"status": "rejected"`, the given `reason` under
/// `error`, and `_index`. If serialisation fails the body falls back to `{}`.
/// `content_type` is left unset.
#[must_use]
pub(crate) fn unsupported_response(reason: &str, index: &str) -> PipelineResponse {
    let payload = json!({
        "status": "rejected",
        "error": reason,
        "_index": index,
    });
    let body = match serde_json::to_vec(&payload) {
        Ok(bytes) => bytes,
        Err(_) => b"{}".to_vec(),
    };

    PipelineResponse {
        status: 400,
        body,
        content_type: None,
    }
}
/// Builds the response returned when async write mode is not available.
///
/// The status is 422 and the JSON body carries `"status": "rejected"`, a
/// fixed `error` message and `_index`. If serialisation fails the body falls
/// back to `{}`. `content_type` is left unset.
///
/// NOTE(review): 422 is an unusual status for a capability that is not
/// enabled on the proxy; confirm it is the intended code before relying on it.
#[must_use]
pub(crate) fn unavailable_response(index: &str) -> PipelineResponse {
    let payload = json!({
        "status": "rejected",
        "error": "async write mode is not available on this proxy",
        "_index": index,
    });
    let body = match serde_json::to_vec(&payload) {
        Ok(bytes) => bytes,
        Err(_) => b"{}".to_vec(),
    };

    PipelineResponse {
        status: 422,
        body,
        content_type: None,
    }
}
/// Builds the HTTP 503 response for a write the queue refused.
///
/// The JSON body carries `op_id`, `"status": "rejected"`, a fixed `error`
/// message and `_index`. If serialisation fails the body falls back to `{}`.
/// `content_type` is left unset.
#[must_use]
pub(crate) fn enqueue_failed_response(op_id: &str, index: &str) -> PipelineResponse {
    let payload = json!({
        "op_id": op_id,
        "status": "rejected",
        "error": "async write could not be enqueued",
        "_index": index,
    });
    let body = match serde_json::to_vec(&payload) {
        Ok(bytes) => bytes,
        Err(_) => b"{}".to_vec(),
    };

    PipelineResponse {
        status: 503,
        body,
        content_type: None,
    }
}
impl<R, S> Pipeline<R, S>
where
    R: Router,
    S: Sink + Reader,
{
    /// Handles a write in async mode by handing `batch` to the write queue.
    ///
    /// Steps, in order:
    /// 1. If [`unsupported_async`] reports a reason, reply with
    ///    [`unsupported_response`].
    /// 2. If the write queue is not enabled, reply with
    ///    [`unavailable_response`].
    /// 3. Otherwise derive the op id via [`op_id_for`], enqueue the write
    ///    keyed by the resolved partition, and reply with
    ///    [`accepted_response`] on success or [`enqueue_failed_response`] on
    ///    failure. The queue's error value is not surfaced to the client.
    pub(crate) async fn enqueue_async(
        &self,
        ctx: &RequestCtx<'_>,
        resolved: &Resolved,
        batch: WriteBatch,
    ) -> PipelineResponse {
        let index = ctx.logical_index();

        if let Some(reason) = unsupported_async(ctx) {
            return unsupported_response(reason, index);
        }
        if !self.write_queue.enabled() {
            return unavailable_response(index);
        }

        let op_id = op_id_for(ctx, ctx.request_id());
        let outcome = self
            .write_queue
            .enqueue(QueuedWrite {
                // The id is needed again for the response, hence the clone.
                op_id: op_id.clone(),
                partition_key: resolved.partition.as_str().to_owned(),
                batch,
            })
            .await;

        if outcome.is_ok() {
            accepted_response(&op_id, index)
        } else {
            enqueue_failed_response(&op_id, index)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `WriteMode::parse` accepts the two known names in any ASCII case and
    /// rejects everything else, including the empty string.
    #[test]
    fn parses_known_modes_case_insensitively_and_rejects_unknown() {
        let cases = [
            ("sync", Some(WriteMode::Sync)),
            ("ASYNC", Some(WriteMode::Async)),
            ("queue", None),
            ("", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(WriteMode::parse(input), expected, "input: {:?}", input);
        }
    }

    /// `valid_op_id` enforces both the length bounds and the allowed charset.
    #[test]
    fn op_id_validation_bounds_length_and_charset() {
        let at_limit = "x".repeat(MAX_OP_ID_LEN);
        let over_limit = "x".repeat(MAX_OP_ID_LEN + 1);

        let accepted = ["a-b_c.d:1", at_limit.as_str()];
        for candidate in &accepted {
            assert!(valid_op_id(candidate), "should accept {:?}", candidate);
        }

        let rejected = ["", "has space", "inject\nkey", over_limit.as_str()];
        for candidate in &rejected {
            assert!(!valid_op_id(candidate), "should reject {:?}", candidate);
        }
    }

    /// `NoQueue` reports itself disabled and fails every enqueue with the
    /// "not configured" reason.
    #[tokio::test]
    async fn no_queue_is_disabled_and_refuses() {
        let queue = NoQueue;
        assert!(!queue.enabled());

        let outcome = queue
            .enqueue(QueuedWrite {
                op_id: "op-1".to_owned(),
                partition_key: "acme".to_owned(),
                batch: WriteBatch::single(test_op()),
            })
            .await;

        assert_eq!(
            outcome,
            Err(QueueError {
                reason: "async write queue is not configured",
            })
        );
    }

    /// Minimal index operation used as queue payload in the tests above.
    fn test_op() -> osproxy_sink::WriteOp {
        use osproxy_core::{ClusterId, Epoch, IndexName, Target};
        use osproxy_sink::{DocOp, WriteOp};

        let target = Target::new(ClusterId::from("c"), IndexName::from("i"));
        let doc = DocOp::Index {
            id: Some("p:1".to_owned()),
            routing: None,
            body: bytes::Bytes::from_static(b"{}"),
        };
        WriteOp::new(target, doc, Epoch::new(1))
    }
}