use crate::api::components::CLIENT_IP_HEADER;
use crate::api::http::HttpError;
use crate::auth::policy::WriteAccessPolicy;
use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum_extra::headers::{Expect, Header, HeaderMap};
use crate::api::http::entry::common::{
parse_content_length_from_header, parse_timestamp_from_query,
};
use crate::api::http::StateKeeper;
use crate::api::limits::limit_scope_from_client_ip;
use crate::replication::Transaction::WriteRecord;
use crate::replication::TransactionNotification;
use futures_util::StreamExt;
use log::{debug, error};
use reduct_base::error::ReductError;
use reduct_base::io::RecordMeta;
use reduct_base::{bad_request, unprocessable_entity, Labels};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::time::timeout;
/// Writes one record, streamed from the request body, into the entry named by the
/// `bucket_name` and `entry_name` path parameters.
///
/// The request is validated before any body data is forwarded:
/// - write access to the bucket;
/// - the `ts` query parameter;
/// - the `content-length` header;
/// - the client's ingress limit;
/// - the `content-type` header, which defaults to `application/octet-stream` when absent;
/// - every `x-reduct-label-*` header.
///
/// The body is then forwarded chunk by chunk to the storage writer, with each receive and
/// each send bounded by the configured I/O operation timeout. A `WriteRecord` notification
/// is then handed to the replication repository.
///
/// # Errors
///
/// - permission errors from `get_with_permissions`;
/// - errors from parsing `ts` / `content-length`, from the ingress limit check, or from
///   `begin_write`;
/// - an `unprocessable_entity` error if the `content-type` header or a label value is
///   not a valid header string;
/// - a `bad_request` error if no body chunk arrives within the I/O timeout;
/// - errors from the body stream, the storage writer, or the replication notification.
///
/// If validation fails and the request has no `Expect` header, the remaining body is
/// drained before the error is returned. Each read during draining is bounded by the
/// I/O timeout.
pub(super) async fn write_record(
    State(keeper): State<Arc<StateKeeper>>,
    headers: HeaderMap,
    Path(path): Path<HashMap<String, String>>,
    Query(params): Query<HashMap<String, String>>,
    body: Body,
) -> Result<(), HttpError> {
    // Presumably guaranteed by the route this handler is mounted on; a missing
    // parameter would be a routing bug, hence the unwraps.
    let bucket = path.get("bucket_name").unwrap().clone();
    let entry_name = path.get("entry_name").unwrap();

    let components = keeper
        .get_with_permissions(&headers, WriteAccessPolicy { bucket: &bucket })
        .await?;
    let io_timeout = components.cfg.io_conf.operation_timeout;
    let mut stream = body.into_data_stream();

    // All request validation happens here so that a failure can be reported
    // before any body data reaches storage.
    let check_request_and_get_sender = async {
        let ts = parse_timestamp_from_query(&params)?;
        let content_size = parse_content_length_from_header(&headers)?;
        let scope = limit_scope_from_client_ip(
            headers
                .get(CLIENT_IP_HEADER)
                .and_then(|value| value.to_str().ok()),
        );
        components
            .limits
            .check_ingress_for(scope, content_size)
            .await?;

        // `HeaderValue::to_str` fails on non-visible-ASCII bytes. The header is
        // client-controlled, so reject it instead of panicking.
        let content_type = match headers.get("content-type") {
            Some(value) => value
                .to_str()
                .map_err(|_| {
                    unprocessable_entity!(
                        "Content-Type header must contain only visible ASCII characters"
                    )
                })?
                .to_string(),
            None => "application/octet-stream".to_string(),
        };

        let mut labels = Labels::new();
        for (name, value) in headers.iter() {
            if let Some(key) = name.as_str().strip_prefix("x-reduct-label-") {
                // NOTE: `to_str` actually requires visible ASCII; the error text is
                // left unchanged.
                let value = value.to_str().map_err(|_| {
                    unprocessable_entity!("Label values for {} must be valid UTF-8 strings", name)
                })?;
                labels.insert(key.to_string(), value.to_string());
            }
        }

        let writer = components
            .storage
            .begin_write(
                &bucket,
                entry_name,
                ts,
                content_size,
                content_type,
                labels.clone(),
            )
            .await?;
        Ok((ts, labels, writer))
    };

    match check_request_and_get_sender.await {
        Ok((ts, labels, mut writer)) => {
            while let Some(chunk) = timeout(io_timeout, stream.next())
                .await
                .map_err(|_| bad_request!("Timeout while receiving data"))?
            {
                let chunk = match chunk {
                    Ok(chunk) => chunk,
                    Err(e) => {
                        error!("Error while receiving data: {}", e);
                        let err = HttpError::from(e).into_inner();
                        // Forward the failure to the writer (presumably so storage can
                        // discard the partial record) before reporting it.
                        writer.send_timeout(Err(err.clone()), io_timeout).await?;
                        return Err(err.into());
                    }
                };
                writer.send_timeout(Ok(Some(chunk)), io_timeout).await?;
            }

            // `Ok(None)` marks end of data. A delivery failure is only logged; the
            // replication notification below is still sent.
            if let Err(err) = writer.send_timeout(Ok(None), io_timeout).await {
                debug!("Timeout while sending EOF: {}", err);
            }

            components
                .replication_repo
                .write()
                .await?
                .notify(TransactionNotification {
                    bucket: bucket.clone(),
                    entry: entry_name.clone(),
                    meta: RecordMeta::builder().timestamp(ts).labels(labels).build(),
                    event: WriteRecord(ts),
                })
                .await?;
            Ok(())
        }
        Err(e) => {
            // Without `Expect`, the client is presumably already sending the body.
            // Consume it before replying.
            if !headers.contains_key(Expect::name()) {
                debug!("draining the stream");
                // Bound each read so a stalled client cannot keep the handler alive
                // indefinitely.
                while let Ok(Some(_)) = timeout(io_timeout, stream.next()).await {}
            }
            Err(e)
        }
    }
}
#[cfg(test)]
mod tests {
    // Tests for the `write_record` handler.
    use super::*;
    use crate::api::http::tests::{
        empty_body, ingress_limited_keeper, keeper, path_to_entry_1, storage_limited_keeper,
    };
    use axum_extra::headers::{Authorization, HeaderMapExt};
    use reduct_base::error::ErrorCode;
    use reduct_base::io::ReadRecord;
    use reduct_base::not_found;
    use rstest::*;
    use tokio::time::{sleep, Duration};

    /// Builds the `?ts=<ts>` query extractor passed to `write_record`.
    fn ts_query(ts: &str) -> Query<HashMap<String, String>> {
        Query(HashMap::from_iter(vec![("ts".to_string(), ts.to_string())]))
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_with_label_ok(
        #[future] keeper: Arc<StateKeeper>,
        headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] empty_body: Body,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        write_record(
            State(keeper),
            headers,
            path_to_entry_1,
            ts_query("1"),
            empty_body.await,
        )
        .await
        .unwrap();

        let record = components
            .storage
            .get_bucket("bucket-1")
            .await
            .unwrap()
            .upgrade_and_unwrap()
            .begin_read("entry-1", 1)
            .await
            .unwrap();
        assert_eq!(&record.meta().labels()["x"], "y");

        // The pending-records counter may lag behind the write. Poll for a bounded
        // time instead of relying on a single fixed sleep, which is flaky on slow
        // machines.
        let mut pending_records = 0;
        for _ in 0..20 {
            pending_records = components
                .replication_repo
                .read()
                .await
                .unwrap()
                .get_info("api-test")
                .await
                .unwrap()
                .info
                .pending_records;
            if pending_records >= 1 {
                break;
            }
            sleep(Duration::from_millis(25)).await;
        }
        assert!(pending_records >= 1);
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_bucket_not_found(
        #[future] keeper: Arc<StateKeeper>,
        headers: HeaderMap,
        #[future] empty_body: Body,
    ) {
        let path = Path(HashMap::from_iter(vec![
            ("bucket_name".to_string(), "XXX".to_string()),
            ("entry_name".to_string(), "entry-1".to_string()),
        ]));
        let err = write_record(
            State(keeper.await),
            headers,
            path,
            ts_query("1"),
            empty_body.await,
        )
        .await
        .err()
        .unwrap();
        assert_eq!(err, not_found!("Bucket 'XXX' is not found").into());
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_bad_ts(
        #[future] keeper: Arc<StateKeeper>,
        headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] empty_body: Body,
    ) {
        let err = write_record(
            State(keeper.await),
            headers,
            path_to_entry_1,
            ts_query("bad"),
            empty_body.await,
        )
        .await
        .err()
        .unwrap();
        assert_eq!(
            err,
            unprocessable_entity!("'ts' must be an unix timestamp in microseconds").into()
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_ingress_rate_limit(
        #[future] ingress_limited_keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        headers.insert("content-length", "2".parse().unwrap());
        let err = write_record(
            State(ingress_limited_keeper.await),
            headers,
            path_to_entry_1,
            ts_query("2"),
            Body::from("ab"),
        )
        .await
        .err()
        .unwrap();
        let err: ReductError = err.into();
        assert_eq!(err.status, ErrorCode::TooManyRequests);
        assert!(err.message.contains("ingress bytes"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_storage_limit_exceeded(
        #[future] storage_limited_keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        headers.insert("content-length", "1".parse().unwrap());
        let err = write_record(
            State(storage_limited_keeper.await),
            headers,
            path_to_entry_1,
            ts_query("2"),
            Body::from("a"),
        )
        .await
        .err()
        .unwrap();
        let err: ReductError = err.into();
        assert_eq!(err.status, ErrorCode::InternalServerError);
        assert_eq!(err.message, "storage limit exceeded");
    }

    /// Default request headers: an empty body, one label (`x` = `y`), and the
    /// `init-token` bearer token.
    #[fixture]
    pub fn headers() -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("content-length", "0".parse().unwrap());
        headers.insert("x-reduct-label-x", "y".parse().unwrap());
        headers.typed_insert(Authorization::bearer("init-token").unwrap());
        headers
    }
}