use crate::http_client::HttpClient;
use crate::record::write_batched_records_v1::WriteBatchType;
use crate::Record;
use async_stream::stream;
use futures_util::StreamExt;
use reduct_base::batch::v2::{encode_entry_name, make_batched_header_name};
use reduct_base::error::{ErrorCode, ReductError};
use reqwest::header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE};
use reqwest::{Body, Method};
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::sync::Arc;
use std::time::SystemTime;
/// Accumulates records for a single multi-entry batched write/update/remove
/// request and sends them to the server in one HTTP call.
pub struct WriteRecordBatchBuilder {
    // Target bucket name, interpolated into the request path.
    bucket: String,
    // Which batched operation this builder performs (write/update/remove).
    batch_type: WriteBatchType,
    // Records queued for the next `send`, in insertion order.
    records: VecDeque<Record>,
    // Shared HTTP client used to issue the batched request.
    client: Arc<HttpClient>,
    // Refreshed whenever records are added; exposed via `last_access()`.
    last_access: SystemTime,
}
/// Per-record failures reported by the server, keyed by `(entry_name, timestamp_us)`.
type FailedRecordMap = BTreeMap<(String, u64), ReductError>;
impl WriteRecordBatchBuilder {
    /// Create an empty batch builder bound to `bucket` and the shared HTTP client.
    pub(crate) fn new(bucket: String, client: Arc<HttpClient>, batch_type: WriteBatchType) -> Self {
        Self {
            bucket,
            batch_type,
            records: VecDeque::new(),
            client,
            last_access: SystemTime::now(),
        }
    }

    /// Add a single record (builder style), refreshing the last-access time.
    pub fn add_record(mut self, record: Record) -> Self {
        self.records.push_back(record);
        self.last_access = SystemTime::now();
        self
    }

    /// Add a single record in place, refreshing the last-access time.
    pub fn append_record(&mut self, record: Record) {
        self.records.push_back(record);
        self.last_access = SystemTime::now();
    }

    /// Add several records (builder style), refreshing the last-access time.
    pub fn add_records(mut self, records: Vec<Record>) -> Self {
        self.records.extend(records);
        self.last_access = SystemTime::now();
        self
    }

    /// Add several records in place, refreshing the last-access time.
    pub fn append_records(&mut self, records: Vec<Record>) {
        self.records.extend(records);
        self.last_access = SystemTime::now();
    }

    /// Send all queued records to the server as one multi-entry batched request.
    ///
    /// On success returns a map of `(entry_name, timestamp_us) -> error` for
    /// records the server rejected individually; an empty map means every
    /// record was accepted.
    ///
    /// # Errors
    ///
    /// * `InvalidRequest` when the server API version is below v1.18, when the
    ///   batch is empty, or when any record lacks an entry name.
    /// * Any transport/server error propagated from the HTTP client.
    pub async fn send(mut self) -> Result<FailedRecordMap, ReductError> {
        // Multi-entry batching is only available from API v1.18 on.
        // NOTE(review): only the minor version is compared; this assumes the
        // major version stays at 1 — confirm before a major-version bump.
        if let Some(version) = self.client.get_api_version().await {
            if version.1 < 18 {
                let message = match self.batch_type {
                    WriteBatchType::Write => {
                        "Multi-entry batch writes are not supported in API versions below v1.18"
                    }
                    WriteBatchType::Update => {
                        "Multi-entry batch updates are not supported in API versions below v1.18"
                    }
                    WriteBatchType::Remove => {
                        "Multi-entry batch remove is not supported in API versions below v1.18"
                    }
                };
                return Err(ReductError::new(ErrorCode::InvalidRequest, message));
            }
        }
        if self.records.is_empty() {
            return Err(ReductError::new(
                ErrorCode::InvalidRequest,
                "Batch must contain at least one record",
            ));
        }
        // Assign each distinct entry name an index in first-seen order; the
        // per-record batched headers below reference entries by this index.
        let mut entries = Vec::new();
        let mut entry_index = HashMap::new();
        for record in &self.records {
            if record.entry().is_empty() {
                return Err(ReductError::new(
                    ErrorCode::InvalidRequest,
                    "Record entry name is required for multi-entry batch operations",
                ));
            }
            if !entry_index.contains_key(record.entry()) {
                let index = entries.len();
                entries.push(record.entry().to_string());
                entry_index.insert(record.entry().to_string(), index);
            }
        }
        let mut records: Vec<Record> = self.records.drain(..).collect();
        // Timestamps are transmitted as deltas relative to the smallest
        // timestamp in the batch. `unwrap` is safe: emptiness was rejected above.
        let start_ts = records
            .iter()
            .map(|record| record.timestamp_us())
            .min()
            .unwrap();
        // Order records by entry index, then timestamp, so the emitted headers
        // (and, for writes, the streamed body) follow a deterministic order.
        records.sort_by(|left, right| {
            let left_idx = entry_index.get(left.entry()).unwrap();
            let right_idx = entry_index.get(right.entry()).unwrap();
            left_idx
                .cmp(right_idx)
                .then_with(|| left.timestamp_us().cmp(&right.timestamp_us()))
        });
        // Build the request skeleton. Only writes carry a body, so only they
        // get a non-zero Content-Length (the sum of all record payloads).
        let mut request = match self.batch_type {
            WriteBatchType::Write => self
                .client
                .request(Method::POST, &format!("/io/{}/write", self.bucket))
                .header(
                    CONTENT_TYPE,
                    HeaderValue::from_static("application/octet-stream"),
                )
                .header(
                    CONTENT_LENGTH,
                    HeaderValue::from_str(
                        &records
                            .iter()
                            .map(|r| r.content_length())
                            .sum::<usize>()
                            .to_string(),
                    )
                    .unwrap(),
                ),
            WriteBatchType::Update => self
                .client
                .request(Method::PATCH, &format!("/io/{}/update", self.bucket))
                .header(
                    CONTENT_TYPE,
                    HeaderValue::from_static("application/octet-stream"),
                )
                .header(CONTENT_LENGTH, HeaderValue::from_static("0")),
            WriteBatchType::Remove => self
                .client
                .request(Method::DELETE, &format!("/io/{}/remove", self.bucket))
                .header(CONTENT_LENGTH, HeaderValue::from_static("0")),
        };
        // Common batch headers: the base timestamp and the ordered entry list.
        request = request
            .header(
                "x-reduct-start-ts",
                HeaderValue::from_str(&start_ts.to_string()).unwrap(),
            )
            .header(
                "x-reduct-entries",
                HeaderValue::from_str(&encode_entries(&entries)).unwrap(),
            );
        // One metadata header per record, named from (entry index, timestamp
        // delta) via `make_batched_header_name`. The value carries
        // length/content-type/labels for writes, label deltas for updates,
        // and is empty for removals.
        for record in &records {
            let idx = *entry_index.get(record.entry()).unwrap();
            let delta = record.timestamp_us() - start_ts;
            let value = match self.batch_type {
                WriteBatchType::Write => make_record_header_value(record),
                WriteBatchType::Update => make_update_header_value(record),
                WriteBatchType::Remove => String::new(),
            };
            let header_value = if value.is_empty() {
                HeaderValue::from_static("")
            } else {
                HeaderValue::from_str(&value).unwrap()
            };
            request = request.header(make_batched_header_name(idx, delta), header_value);
        }
        // Writes stream every record payload back-to-back in the same order
        // as the headers; updates and removals send no body.
        let response = match self.batch_type {
            WriteBatchType::Write => {
                let client = Arc::clone(&self.client);
                let stream = stream! {
                    for record in records {
                        let mut stream = record.stream_bytes();
                        while let Some(bytes) = stream.next().await {
                            yield bytes;
                        }
                    }
                };
                client
                    .send_request(request.body(Body::wrap_stream(stream)))
                    .await?
            }
            WriteBatchType::Update | WriteBatchType::Remove => {
                self.client.send_request(request).await?
            }
        };
        // Collect per-record failures reported via `x-reduct-error-<idx>-<delta>`
        // response headers, each valued `<status>,<message>`. Malformed or
        // non-UTF-8 header values are skipped rather than panicking, since
        // header contents come from the server and are not guaranteed valid.
        let mut failed_records = FailedRecordMap::new();
        for (key, value) in response.headers() {
            if !key.as_str().starts_with("x-reduct-error-") {
                continue;
            }
            let (entry_idx, delta) = match parse_error_key(key.as_str()) {
                Some(parsed) => parsed,
                None => continue,
            };
            let entry = match entries.get(entry_idx) {
                Some(entry) => entry,
                None => continue,
            };
            let text = match value.to_str() {
                Ok(text) => text,
                Err(_) => continue,
            };
            if let Some((status, message)) = text.split_once(',') {
                if let Ok(status) = status.parse::<i16>() {
                    if let Ok(code) = ErrorCode::try_from(status) {
                        failed_records.insert(
                            (entry.to_string(), start_ts + delta),
                            ReductError::new(code, message),
                        );
                    }
                }
            }
        }
        Ok(failed_records)
    }

    /// Total payload size of the queued records, in bytes.
    pub fn size(&self) -> usize {
        self.records.iter().map(|r| r.content_length()).sum()
    }

    /// Number of records currently queued.
    pub fn record_count(&self) -> usize {
        self.records.len()
    }

    /// Time the batch was created or last appended to.
    pub fn last_access(&self) -> SystemTime {
        self.last_access
    }

    /// Drop all queued records without sending them.
    pub fn clear(&mut self) {
        self.records.clear();
    }
}
/// Join entry names into one comma-separated header value, encoding each
/// name with `encode_entry_name` so it is safe to place in an HTTP header.
fn encode_entries(entries: &[String]) -> String {
    let mut encoded = Vec::with_capacity(entries.len());
    for entry in entries {
        encoded.push(encode_entry_name(entry));
    }
    encoded.join(",")
}
/// Build the batched-header value for a write record:
/// `<content_length>,<content_type>[,<label_deltas>]`.
/// An empty content type falls back to `application/octet-stream`.
fn make_record_header_value(record: &Record) -> String {
    let mut value = record.content_length().to_string();
    value.push(',');
    match record.content_type() {
        "" => value.push_str("application/octet-stream"),
        content_type => value.push_str(content_type),
    }
    let labels = record.labels();
    if !labels.is_empty() {
        value.push(',');
        value.push_str(&format_label_delta(labels));
    }
    value
}
/// Build the batched-header value for a label update. The length is always
/// `0` and the content type is fixed, since updates carry only label deltas.
fn make_update_header_value(record: &Record) -> String {
    let mut value = String::from("0,application/octet-stream");
    let labels = record.labels();
    if !labels.is_empty() {
        value.push(',');
        value.push_str(&format_label_delta(labels));
    }
    value
}
/// Serialize labels as `key=value` pairs joined by commas, sorted by key for
/// a deterministic header value. A value containing a comma is wrapped in
/// double quotes so it survives comma-splitting on the receiving side.
fn format_label_delta(labels: &crate::Labels) -> String {
    let mut pairs: Vec<_> = labels.iter().collect();
    pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
    let mut parts = Vec::with_capacity(pairs.len());
    for (key, value) in pairs {
        let pair = if value.contains(',') {
            format!("{}=\"{}\"", key, value)
        } else {
            format!("{}={}", key, value)
        };
        parts.push(pair);
    }
    parts.join(",")
}
/// Parse an `x-reduct-error-<entry_idx>-<delta>` header name into its entry
/// index and timestamp delta. Returns `None` for any other header shape or
/// when either component is not a valid number.
fn parse_error_key(key: &str) -> Option<(usize, u64)> {
    let suffix = key.strip_prefix("x-reduct-error-")?;
    let (idx_str, delta_str) = suffix.rsplit_once('-')?;
    let entry_idx: usize = idx_str.parse().ok()?;
    let delta: u64 = delta_str.parse().ok()?;
    Some((entry_idx, delta))
}