use crate::http_client::HttpClient;
use crate::{Record, RecordBuilder};
use async_stream::stream;
use futures_util::StreamExt;
use reqwest::header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE};
use reqwest::{Body, Method};
use std::collections::{BTreeMap, VecDeque};
use reduct_base::error::{ErrorCode, ReductError};
use std::sync::Arc;
use std::time::SystemTime;
/// Kind of operation a batch performs; selects the HTTP method and whether
/// record payloads are sent (see `WriteBatchBuilder::send`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WriteBatchType {
    /// Write new records with payloads (sent as `POST`).
    Write,
    /// Update existing records (sent as `PATCH` with an empty body).
    Update,
    /// Remove records (sent as `DELETE` with an empty body).
    Remove,
}
/// Accumulates records and sends them to the server as a single batch
/// request (write, update or remove, depending on `batch_type`).
pub struct WriteBatchBuilder {
    // Target bucket name; used to build the `/b/{bucket}/{entry}/batch` path.
    bucket: String,
    // Target entry name within the bucket.
    entry: String,
    // Which batch operation `send` performs (POST/PATCH/DELETE).
    batch_type: WriteBatchType,
    // Records queued for the next `send`; drained front-to-back when the
    // write body is streamed.
    records: VecDeque<Record>,
    // Shared HTTP client the batch request is sent through.
    client: Arc<HttpClient>,
    // Set on construction and refreshed whenever records are added;
    // exposed via `last_access()`.
    last_access: SystemTime,
}
// Map from record timestamp (microseconds) to the error the server reported
// for that record in the batch response.
type FailedRecordMap = BTreeMap<u64, ReductError>;
impl WriteBatchBuilder {
pub(crate) fn new(
bucket: String,
entry: String,
client: Arc<HttpClient>,
batch_type: WriteBatchType,
) -> Self {
Self {
bucket,
entry,
batch_type,
records: VecDeque::new(),
client,
last_access: SystemTime::now(),
}
}
pub fn add_record(mut self, record: Record) -> Self {
self.records.push_back(record);
self.last_access = SystemTime::now();
self
}
pub fn append_record(&mut self, record: Record) {
self.records.push_back(record);
self.last_access = SystemTime::now();
}
pub fn add_records(mut self, records: Vec<Record>) -> Self {
self.records.extend(records);
self.last_access = SystemTime::now();
self
}
pub fn append_records(&mut self, records: Vec<Record>) {
self.records.extend(records);
self.last_access = SystemTime::now();
}
pub fn add_timestamp_us(mut self, timestamp: u64) -> Self {
self.records
.push_back(RecordBuilder::new().timestamp_us(timestamp).build());
self.last_access = SystemTime::now();
self
}
pub fn append_timestamp_us(&mut self, timestamp: u64) {
self.records
.push_back(RecordBuilder::new().timestamp_us(timestamp).build());
self.last_access = SystemTime::now();
}
pub fn add_timestamp(mut self, timestamp: SystemTime) -> Self {
self.records
.push_back(RecordBuilder::new().timestamp(timestamp).build());
self.last_access = SystemTime::now();
self
}
pub fn append_timestamp(&mut self, timestamp: SystemTime) {
self.records
.push_back(RecordBuilder::new().timestamp(timestamp).build());
self.last_access = SystemTime::now();
}
pub fn add_timestamps_us(mut self, timestamps: Vec<u64>) -> Self {
self.records.extend(
timestamps
.into_iter()
.map(|t| RecordBuilder::new().timestamp_us(t).build()),
);
self.last_access = SystemTime::now();
self
}
pub fn append_timestamps_us(&mut self, timestamps: Vec<u64>) {
self.records.extend(
timestamps
.into_iter()
.map(|t| RecordBuilder::new().timestamp_us(t).build()),
);
self.last_access = SystemTime::now();
}
pub fn add_timestamps(mut self, timestamps: Vec<SystemTime>) -> Self {
self.records.extend(
timestamps
.into_iter()
.map(|t| RecordBuilder::new().timestamp(t).build()),
);
self.last_access = SystemTime::now();
self
}
pub fn append_timestamps(&mut self, timestamps: Vec<SystemTime>) {
self.records.extend(
timestamps
.into_iter()
.map(|t| RecordBuilder::new().timestamp(t).build()),
);
self.last_access = SystemTime::now();
}
pub async fn send(mut self) -> Result<FailedRecordMap, ReductError> {
let method = match self.batch_type {
WriteBatchType::Write => Method::POST,
WriteBatchType::Update => Method::PATCH,
WriteBatchType::Remove => Method::DELETE,
};
let request = self
.client
.request(method, &format!("/b/{}/{}/batch", self.bucket, self.entry));
let content_length: usize = self.records.iter().map(|r| r.content_length()).sum();
let mut request = request.header(
CONTENT_TYPE,
HeaderValue::from_str("application/octet-stream").unwrap(),
);
request = match self.batch_type {
WriteBatchType::Update => {
request.header(CONTENT_LENGTH, HeaderValue::from_str("0").unwrap())
}
WriteBatchType::Write => request.header(
CONTENT_LENGTH,
HeaderValue::from_str(&content_length.to_string()).unwrap(),
),
WriteBatchType::Remove => {
request.header(CONTENT_LENGTH, HeaderValue::from_str("0").unwrap())
}
};
for record in &self.records {
let mut header_values = Vec::new();
match self.batch_type {
WriteBatchType::Update => {
header_values.push("0".to_string());
header_values.push("".to_string());
}
WriteBatchType::Write => {
header_values.push(record.content_length().to_string());
header_values.push(record.content_type().to_string());
}
WriteBatchType::Remove => {
header_values.push("0".to_string());
header_values.push("".to_string());
}
}
if !record.labels().is_empty() {
for (key, value) in record.labels() {
if value.contains(',') {
header_values.push(format!("{}=\"{}\"", key, value));
} else {
header_values.push(format!("{}={}", key, value));
}
}
}
request = request.header(
&format!("x-reduct-time-{}", record.timestamp_us()),
HeaderValue::from_str(&header_values.join(",").to_string()).unwrap(),
);
}
let client = Arc::clone(&self.client);
let response = match self.batch_type {
WriteBatchType::Update => client.send_request(request).await?,
WriteBatchType::Write => {
let stream = stream! {
while let Some(record) = self.records.pop_front() {
let mut stream = record.stream_bytes();
while let Some(bytes) = stream.next().await {
yield bytes;
}
}
};
client
.send_request(request.body(Body::wrap_stream(stream)))
.await?
}
WriteBatchType::Remove => client.send_request(request).await?,
};
let mut failed_records = FailedRecordMap::new();
response
.headers()
.iter()
.filter(|(key, _)| key.as_str().starts_with("x-reduct-error"))
.for_each(|(key, value)| {
let record_ts = key
.as_str()
.trim_start_matches("x-reduct-error-")
.parse::<u64>()
.unwrap();
let (status, message) = value.to_str().unwrap().split_once(',').unwrap();
failed_records.insert(
record_ts,
ReductError::new(
ErrorCode::try_from(status.parse::<i16>().unwrap()).unwrap(),
message,
),
);
});
Ok(failed_records)
}
pub fn size(&self) -> usize {
self.records.iter().map(|r| r.content_length()).sum()
}
pub fn record_count(&self) -> usize {
self.records.len()
}
pub fn last_access(&self) -> SystemTime {
self.last_access
}
pub fn clear(&mut self) {
self.records.clear();
}
}