use crate::api::components::CLIENT_IP_HEADER;
use crate::api::http::entry::MethodExtractor;
use crate::api::http::Components;
use crate::api::http::HttpError;
use crate::auth::policy::ReadAccessPolicy;
use reduct_base::error::ReductError;
use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum::response::IntoResponse;
use axum_extra::headers::HeaderMap;
use crate::api::http::entry::common::check_and_extract_ts_or_query_id;
use crate::api::http::utils::{make_headers_from_reader, RecordStream};
use crate::api::http::StateKeeper;
use crate::api::limits::{limit_scope_from_client_ip, LimitScope};
use crate::core::sync::AsyncRwLock;
use crate::core::weak::Weak;
use crate::storage::entry::{Entry, RecordReader};
use crate::storage::query::QueryRx;
use reduct_base::bad_request;
use reduct_base::io::ReadRecord;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Handler that reads a single record from the entry addressed by the `bucket_name` and
/// `entry_name` path parameters.
///
/// Steps, in order:
/// 1. Obtain the components through the state keeper, using a [`ReadAccessPolicy`] for the
///    requested bucket.
/// 2. Look up the bucket and then the entry in storage.
/// 3. Hand the entry's `latest_record` and the query parameters to
///    `check_and_extract_ts_or_query_id` to obtain a timestamp and/or a query id.
/// 4. Derive the limit scope from the `CLIENT_IP_HEADER` request header (ignored when the
///    header is absent or is not valid visible ASCII).
/// 5. Delegate to `fetch_and_response_single_record`; the body is left empty when the
///    method name is `"HEAD"`.
///
/// # Errors
/// Propagates any error produced by the steps above as an [`HttpError`].
///
/// # Panics
/// Panics if the path does not contain `bucket_name` or `entry_name`.
pub(super) async fn read_record(
    State(keeper): State<Arc<StateKeeper>>,
    Path(path): Path<HashMap<String, String>>,
    Query(params): Query<HashMap<String, String>>,
    headers: HeaderMap,
    method: MethodExtractor,
) -> Result<impl IntoResponse, HttpError> {
    let bucket_name = path.get("bucket_name").unwrap();
    let entry_name = path.get("entry_name").unwrap();

    // Access check comes first: nothing is looked up before the policy passes.
    let policy = ReadAccessPolicy {
        bucket: bucket_name,
    };
    let components = keeper.get_with_permissions(&headers, policy).await?;

    let bucket = components.storage.get_bucket(bucket_name).await?;
    let entry = bucket.upgrade()?.get_entry(entry_name).await?;

    let latest = entry.upgrade()?.info().await?.latest_record;
    let (query_id, ts) = check_and_extract_ts_or_query_id(params, latest)?;

    let client_ip = headers
        .get(CLIENT_IP_HEADER)
        .and_then(|raw| raw.to_str().ok());
    let scope = limit_scope_from_client_ip(client_ip);
    let head_only = method.name() == "HEAD";

    fetch_and_response_single_record(components, entry, ts, query_id, head_only, scope).await
}
/// Opens a reader for one record and turns it into an HTTP response (headers + streamed body).
///
/// The record is selected by `ts` when it is `Some`; otherwise the next record is pulled from
/// the receiver of the query identified by `query_id`.
///
/// When `empty_body` is `false`, the egress limit for `scope` is checked against the record's
/// content length before the response is built. When it is `true` the check is skipped.
/// In both cases the headers are derived from the reader's metadata and `empty_body` is
/// forwarded to [`RecordStream`].
///
/// # Errors
/// Propagates failures from upgrading the entry handle, opening the reader, receiving from
/// the query, and the egress limit check.
///
/// # Panics
/// Panics if both `ts` and `query_id` are `None`.
async fn fetch_and_response_single_record(
    components: Arc<Components>,
    entry: Weak<Entry>,
    ts: Option<u64>,
    query_id: Option<u64>,
    empty_body: bool,
    scope: LimitScope,
) -> Result<impl IntoResponse, HttpError> {
    let entry = entry.upgrade()?;

    let reader = match ts {
        Some(timestamp) => entry.begin_read(timestamp).await?,
        None => {
            let id = query_id.unwrap();
            let (receiver, _) = entry.get_query_receiver(id).await?;
            // The path only labels the query in error messages.
            let query_path = format!("{}/{}/{}", entry.bucket_name(), entry.name(), id);
            next_record_reader(receiver, &query_path).await?
        }
    };

    // Only responses that carry a payload count against the egress limit.
    if !empty_body {
        components
            .limits
            .check_egress_for(scope, reader.meta().content_length())
            .await?;
    }

    let headers = make_headers_from_reader(reader.meta());
    let stream = RecordStream::new(Arc::new(Mutex::new(Box::new(reader))), empty_body);
    Ok((headers, Body::from_stream(stream)))
}
/// Receives the next record reader from a query's receiver.
///
/// `query_path` is used only to label the query in error messages.
///
/// # Errors
/// - A bad-request error (`Query '<path>' was closed`) when the weak receiver handle can no
///   longer be upgraded.
/// - Any error raised while acquiring the write lock on the receiver.
/// - The error carried by the received item, converted into an [`HttpError`].
/// - A bad-request error (`Query '<path>' was closed: broken channel`) when the channel
///   yields no item.
async fn next_record_reader(
    rx: Weak<AsyncRwLock<QueryRx>>,
    query_path: &str,
) -> Result<RecordReader, HttpError> {
    let rc = rx
        .upgrade()
        .map_err(|_| bad_request!("Query '{}' was closed", query_path))?;
    let mut rx = rc.write().await?;

    match rx.recv().await {
        Some(reader) => reader.map_err(HttpError::from),
        // `None` means the sending side is gone and nothing is left to read.
        None => Err(bad_request!("Query '{}' was closed: broken channel", query_path).into()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::http::entry::tests::query;
    use crate::api::http::tests::{egress_limited_keeper, headers, keeper, path_to_entry_1};
    use axum::body::to_bytes;
    use bytes::Bytes;
    use reduct_base::error::ErrorCode;
    use reduct_base::error::ErrorCode::NotFound;
    use rstest::*;

    /// Builds a query-string extractor that holds exactly one `key=value` pair.
    fn single_param(key: &str, value: &str) -> Query<HashMap<String, String>> {
        Query(HashMap::from([(key.to_string(), value.to_string())]))
    }

    /// Reading by timestamp returns the record headers for both methods and the payload only
    /// for `GET`.
    #[rstest]
    #[case("GET", "Hey!!!")]
    #[case("HEAD", "")]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_single_read_ts(
        #[future] keeper: Arc<StateKeeper>,
        path_to_entry_1: Path<HashMap<String, String>>,
        headers: HeaderMap,
        #[case] method: String,
        #[case] body: String,
    ) {
        let keeper = keeper.await;
        let result = read_record(
            State(Arc::clone(&keeper)),
            path_to_entry_1,
            single_param("ts", "0"),
            headers,
            MethodExtractor::new(&method),
        )
        .await;
        let response = result.unwrap().into_response();

        let response_headers = response.headers();
        for (name, expected) in [
            ("x-reduct-time", "0"),
            ("content-type", "text/plain"),
            ("content-length", "6"),
        ] {
            assert_eq!(response_headers[name], expected);
        }

        let payload = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(payload, Bytes::from(body));
    }

    /// Reading through a query id returns the record headers for both methods and the payload
    /// only for `GET`.
    #[rstest]
    #[case("GET", "Hey!!!")]
    #[case("HEAD", "")]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_single_read_query(
        #[future] keeper: Arc<StateKeeper>,
        path_to_entry_1: Path<HashMap<String, String>>,
        headers: HeaderMap,
        #[case] method: String,
        #[case] body: String,
    ) {
        let keeper = keeper.await;
        let query_id = query(&path_to_entry_1, Arc::clone(&keeper), None).await;
        let result = read_record(
            State(Arc::clone(&keeper)),
            path_to_entry_1,
            single_param("q", &query_id.to_string()),
            headers,
            MethodExtractor::new(&method),
        )
        .await;
        let response = result.unwrap().into_response();

        let response_headers = response.headers();
        for (name, expected) in [
            ("x-reduct-time", "0"),
            ("content-type", "text/plain"),
            ("content-length", "6"),
        ] {
            assert_eq!(response_headers[name], expected);
        }

        let payload = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        assert_eq!(payload, Bytes::from(body));
    }

    /// An unknown bucket name is reported as `NotFound`.
    #[rstest]
    #[tokio::test]
    async fn test_single_read_bucket_not_found(
        #[future] keeper: Arc<StateKeeper>,
        headers: HeaderMap,
    ) {
        let keeper = keeper.await;
        let path = Path(HashMap::from([
            ("bucket_name".to_string(), "XXX".to_string()),
            ("entry_name".to_string(), "entru-1".to_string()),
        ]));

        let outcome = read_record(
            State(Arc::clone(&keeper)),
            path,
            single_param("ts", "0"),
            headers,
            MethodExtractor::new("GET"),
        )
        .await;

        let err = outcome.err().unwrap();
        assert_eq!(err, HttpError::new(NotFound, "Bucket 'XXX' is not found"));
    }

    /// A timestamp with no record behind it is reported as `NotFound`.
    #[rstest]
    #[tokio::test]
    async fn test_single_read_ts_not_found(
        #[future] keeper: Arc<StateKeeper>,
        path_to_entry_1: Path<HashMap<String, String>>,
        headers: HeaderMap,
    ) {
        let keeper = keeper.await;

        let outcome = read_record(
            State(Arc::clone(&keeper)),
            path_to_entry_1,
            single_param("ts", "1"),
            headers,
            MethodExtractor::new("GET"),
        )
        .await;

        let err = outcome.err().unwrap();
        let expected = HttpError::new(NotFound, "Record 1 not found in block bucket-1/entry-1/0");
        assert_eq!(err, expected);
    }

    /// A non-numeric `ts` parameter is rejected as `UnprocessableEntity`.
    #[rstest]
    #[tokio::test]
    async fn test_single_read_bad_ts(
        #[future] keeper: Arc<StateKeeper>,
        path_to_entry_1: Path<HashMap<String, String>>,
        headers: HeaderMap,
    ) {
        let keeper = keeper.await;

        let outcome = read_record(
            State(Arc::clone(&keeper)),
            path_to_entry_1,
            single_param("ts", "bad"),
            headers,
            MethodExtractor::new("GET"),
        )
        .await;

        let err = outcome.err().unwrap();
        let expected = HttpError::new(
            ErrorCode::UnprocessableEntity,
            "'ts' must be an unix timestamp in microseconds",
        );
        assert_eq!(err, expected);
    }

    /// A query id that was never created is reported as `NotFound`.
    #[rstest]
    #[tokio::test]
    async fn test_single_read_query_not_found(
        #[future] keeper: Arc<StateKeeper>,
        path_to_entry_1: Path<HashMap<String, String>>,
        headers: HeaderMap,
    ) {
        let keeper = keeper.await;

        let outcome = read_record(
            State(Arc::clone(&keeper)),
            path_to_entry_1,
            single_param("q", "1"),
            headers,
            MethodExtractor::new("GET"),
        )
        .await;

        let err = outcome.err().unwrap();
        let expected = HttpError::new(
            NotFound,
            "Query 1 not found and it might have expired. Check TTL in your query request.",
        );
        assert_eq!(err, expected);
    }

    mod next_record_reader {
        use super::*;

        /// A receiver handle that cannot be upgraded yields the "was closed" bad request.
        #[rstest]
        #[tokio::test]
        async fn test_next_record_reader_err() {
            let (sender, receiver) = tokio::sync::mpsc::channel(1);
            let shared = Arc::new(AsyncRwLock::new(receiver));
            drop(sender);

            let err = next_record_reader(shared.into(), "test")
                .await
                .err()
                .unwrap();
            assert_eq!(
                err,
                HttpError::new(ErrorCode::BadRequest, "Query 'test' was closed")
            );
        }
    }

    mod steam_wrapper {
        use super::*;
        use crate::storage::entry::io::record_reader::tests::MockRecord;
        use futures_util::Stream;

        /// The record stream reports `(0, None)` as its size hint.
        #[rstest]
        fn test_size_hint() {
            let stream = RecordStream::new(
                Arc::new(Mutex::new(Box::new(MockRecord::new()))),
                false,
            );
            assert_eq!(stream.size_hint(), (0, None));
        }
    }

    /// With the egress-limited keeper, a `GET` is rejected with `TooManyRequests`.
    #[rstest]
    #[tokio::test]
    async fn test_single_read_egress_rate_limit(
        #[future] egress_limited_keeper: Arc<StateKeeper>,
        path_to_entry_1: Path<HashMap<String, String>>,
        headers: HeaderMap,
    ) {
        let limited_keeper = egress_limited_keeper.await;

        let outcome = read_record(
            State(limited_keeper),
            path_to_entry_1,
            single_param("ts", "0"),
            headers,
            MethodExtractor::new("GET"),
        )
        .await;

        let err = outcome.err().unwrap();
        assert_eq!(err.status(), ErrorCode::TooManyRequests);
        assert!(err.message().contains("egress bytes"));
    }
}