use crate::api::{Components, HttpError};
use crate::auth::policy::WriteAccessPolicy;
use axum::body::{Body, BodyDataStream};
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use axum_extra::headers::{Expect, Header, HeaderMap};
use bytes::Bytes;
use futures_util::StreamExt;

use crate::api::entry::common::err_to_batched_header;
use crate::api::StateKeeper;
use crate::replication::{Transaction, TransactionNotification};
use crate::storage::bucket::Bucket;
use crate::storage::entry::RecordDrainer;
use log::{debug, error};
use reduct_base::batch::{parse_batched_header, sort_headers_by_time, RecordHeader};
use reduct_base::error::ReductError;
use reduct_base::io::{RecordMeta, WriteRecord};
use reduct_base::{bad_request, internal_server_error, unprocessable_entity};
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tokio::sync::mpsc::Receiver;
use tokio::task::JoinHandle;
use tokio::time::timeout;

struct WriteContext {
    time: u64,
    header: RecordHeader,
    writer: Box<dyn WriteRecord + Sync + Send>,
}

type ErrorMap = BTreeMap<u64, ReductError>;

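/// Handle a batched write request.
///
/// The records are described by `x-reduct-time-<ts>` request headers and their
/// payloads are concatenated in the body in timestamp order. A failure to write
/// a single record does not abort the batch: per-record errors are reported
/// back as `x-reduct-error-<ts>` response headers.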
pub(super) async fn write_batched_records(
    State(keeper): State<Arc<StateKeeper>>,
    headers: HeaderMap,
    Path(path): Path<HashMap<String, String>>,
    body: Body,
) -> Result<impl IntoResponse, HttpError> {
    let bucket = path.get("bucket_name").unwrap();
    let components = keeper
        .get_with_permissions(&headers, WriteAccessPolicy { bucket })
        .await?;

    let entry_name = path.get("entry_name").unwrap().clone();
    let record_headers: Vec<_> = sort_headers_by_time(&headers)?;
    let mut stream = body.into_data_stream();

    let process_stream = async {
        let mut timed_headers: Vec<(u64, RecordHeader)> = Vec::new();
        for (time, v) in record_headers {
            let header = parse_batched_header(v.to_str().unwrap())?;
            timed_headers.push((time, header));
        }

        let content_length = check_and_get_content_length(&headers, &timed_headers)?;
        let record_count = timed_headers.len();

        let (rx_writer, spawn_handler) =
            spawn_getting_writers(&components, bucket, &entry_name, timed_headers)?;

        if content_length > 0 {
            receive_body_and_write_records(
                bucket,
                entry_name,
                components,
                record_count,
                &mut stream,
                rx_writer,
            )
            .await?;
        } else {
            write_only_metadata(bucket, entry_name, components, rx_writer).await?;
        }

        Ok(spawn_handler
            .await
            .map_err(|e| internal_server_error!("Failed to complete write operation: {}", e))?)
    };

    match process_stream.await {
        Ok(error_map) => {
            let mut headers = HeaderMap::new();
            error_map.iter().for_each(|(time, err)| {
                err_to_batched_header(&mut headers, *time, err);
            });

            Ok(headers.into())
        }

        Err(err) => {
            if !headers.contains_key(Expect::name()) {
                debug!("draining the stream");
                while stream.next().await.is_some() {}
            }
            Err::<HeaderMap, HttpError>(err)
        }
    }
}

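/// Notify the replication engine about a written record so that it can be
/// scheduled for replication.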
async fn notify_replication_write(
    components: &Arc<Components>,
    bucket: &str,
    entry_name: &str,
    ctx: &WriteContext,
) -> Result<(), ReductError> {
    components
        .replication_repo
        .write()
        .await
        .notify(TransactionNotification {
            bucket: bucket.to_string(),
            entry: entry_name.to_string(),
            meta: RecordMeta::builder()
                .timestamp(ctx.time)
                .labels(ctx.header.labels.clone())
                .build(),
            event: Transaction::WriteRecord(ctx.time),
        })?;
    Ok(())
}

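/// Drain the writer queue when the batch has no body: every record in the
/// batch is empty, so each writer only receives an EOF before the replication
/// engine is notified.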
async fn write_only_metadata(
    bucket: &str,
    entry_name: String,
    components: Arc<Components>,
    mut rx_writer: Receiver<WriteContext>,
) -> Result<(), ReductError> {
    while let Some(mut ctx) = rx_writer.recv().await {
        if let Err(err) = ctx
            .writer
            .send_timeout(Ok(None), components.cfg.io_conf.operation_timeout)
            .await
        {
            debug!("Timeout while sending EOF: {}", err);
        }

        notify_replication_write(&components, bucket, &entry_name, &ctx).await?;
    }

    Ok(())
}

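/// Receive the request body and distribute it over the queued record writers.
/// A single body chunk may span several records; the leftover bytes of a
/// finished record are carried over to the next writer.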
async fn receive_body_and_write_records(
    bucket: &str,
    entry_name: String,
    components: Arc<Components>,
    mut record_count: usize,
    stream: &mut BodyDataStream,
    mut rx_writer: Receiver<WriteContext>,
) -> Result<(), ReductError> {
    let mut written = 0;
    let mut ctx = rx_writer
        .recv()
        .await
        .ok_or(internal_server_error!("No writer found"))?;

    while let Some(chunk) = timeout(components.cfg.io_conf.operation_timeout, stream.next())
        .await
        .map_err(|_| internal_server_error!("Timeout while receiving data"))?
    {
        let mut chunk =
            chunk.map_err(|e| bad_request!("Error while receiving data chunk: {}", e))?;

        while !chunk.is_empty() {
            match write_chunk(
                &mut ctx.writer,
                chunk,
                &mut written,
                ctx.header.content_length,
                components.cfg.io_conf.operation_timeout,
            )
            .await
            {
                // the whole chunk belongs to the current record
                Ok(None) => break,
                // the current record is complete: send EOF, notify replication,
                // and move on to the next writer with the leftover bytes
                Ok(Some(rest)) => {
                    if let Err(err) = ctx
                        .writer
                        .send_timeout(Ok(None), components.cfg.io_conf.operation_timeout)
                        .await
                    {
                        debug!("Timeout while sending EOF: {}", err);
                    }

                    notify_replication_write(&components, bucket, &entry_name, &ctx).await?;

                    chunk = rest;
                    record_count -= 1;
                    written = 0;

                    ctx = match rx_writer.recv().await {
                        Some(ctx) => ctx,
                        None => break, // no more writers in the queue
                    };
                }
                Err(err) => return Err(err),
            }
        }
    }

    if record_count != 0 {
        return Err(bad_request!("Content is shorter than expected"));
    }

    Ok(())
}

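/// Spawn a task that opens a writer for each record in the batch and feeds
/// them into a bounded channel. If a writer cannot be created, the error is
/// collected into the returned `ErrorMap` and a `RecordDrainer` is queued
/// instead, so the record's part of the body can still be consumed.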
fn spawn_getting_writers(
    components: &Arc<Components>,
    bucket_name: &str,
    entry_name: &str,
    timed_headers: Vec<(u64, RecordHeader)>,
) -> Result<(Receiver<WriteContext>, JoinHandle<ErrorMap>), ReductError> {
    let (tx_writer, rx_writer) = tokio::sync::mpsc::channel(64);

    let bucket = components
        .storage
        .get_bucket(bucket_name)?
        .upgrade_and_unwrap();

    let entry_name = entry_name.to_string();
    let spawn_handler = tokio::spawn(async move {
        let mut error_map = BTreeMap::new();

        for (time, header) in timed_headers.into_iter() {
            let writer =
                start_writing(&entry_name, bucket.clone(), time, &header, &mut error_map).await;

            tx_writer
                .send(WriteContext {
                    time,
                    header,
                    writer,
                })
                .await
                .map_err(|err| error!("Failed to send the writer: {}", err))
                .unwrap_or(());
        }
        error_map
    });

    Ok((rx_writer, spawn_handler))
}

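/// Write as much of `chunk` into the current record as it still expects.
/// Returns `Ok(None)` if the chunk was consumed completely, or `Ok(Some(rest))`
/// with the remaining bytes once the record is full.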
async fn write_chunk(
    writer: &mut Box<dyn WriteRecord + Sync + Send>,
    chunk: Bytes,
    written: &mut usize,
    content_size: u64,
    io_timeout: std::time::Duration,
) -> Result<Option<Bytes>, ReductError> {
    let to_write = content_size - *written as u64;
    *written += chunk.len();
    let (chunk, rest) = if (chunk.len() as u64) < to_write {
        (chunk, None)
    } else {
        let chunk_to_write = chunk.slice(0..to_write as usize);
        (chunk_to_write, Some(chunk.slice(to_write as usize..)))
    };

    writer.send_timeout(Ok(Some(chunk)), io_timeout).await?;
    Ok(rest)
}

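/// Check that the `content-length` header equals the sum of the per-record
/// content lengths from the batched headers and return the total.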
fn check_and_get_content_length(
    headers: &HeaderMap,
    timed_headers: &[(u64, RecordHeader)],
) -> Result<u64, ReductError> {
    let total_content_length: u64 = timed_headers
        .iter()
        .map(|(_, header)| header.content_length)
        .sum();

    if total_content_length
        != headers
            .get("content-length")
            .ok_or(unprocessable_entity!("content-length header is required"))?
            .to_str()
            .unwrap()
            .parse::<u64>()
            .map_err(|_| unprocessable_entity!("Invalid content-length header"))?
    {
        return Err(unprocessable_entity!(
            "content-length header does not match the sum of the content-lengths in the headers",
        ));
    }

    Ok(total_content_length)
}

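/// Begin writing a record to the bucket. On failure, the error is stored in
/// `error_map` and a `RecordDrainer` is returned so that the record's payload
/// can be discarded without interrupting the rest of the batch.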
async fn start_writing(
    entry_name: &str,
    bucket: Arc<Bucket>,
    time: u64,
    record_header: &RecordHeader,
    error_map: &mut BTreeMap<u64, ReductError>,
) -> Box<dyn WriteRecord + Sync + Send> {
    let get_writer = async {
        bucket
            .begin_write(
                entry_name,
                time,
                record_header.content_length,
                record_header.content_type.clone(),
                record_header.labels.clone(),
            )
            .await
    };

    match get_writer.await {
        Ok(writer) => writer,
        Err(err) => {
            error_map.insert(time, err);
            // drain the body of the failed record
            Box::new(RecordDrainer::new())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::entry::write_batched::write_batched_records;
    use crate::api::tests::{headers, keeper, path_to_entry_1};

    use axum_extra::headers::HeaderValue;
    use reduct_base::error::ErrorCode;
    use reduct_base::io::ReadRecord;
    use rstest::{fixture, rstest};

    #[rstest]
    #[tokio::test]
    async fn test_write_record_bad_timestamp(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        headers.insert("content-length", "10".parse().unwrap());
        headers.insert("x-reduct-time-yyy", "10".parse().unwrap());

        let err = write_batched_records(
            State(keeper.await),
            headers,
            path_to_entry_1,
            body_stream.await,
        )
        .await
        .err()
        .unwrap();

        assert_eq!(
            err,
            HttpError::new(
                ErrorCode::UnprocessableEntity,
                "Invalid header 'x-reduct-time-yyy': must be an unix timestamp in microseconds",
            )
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_invalid_header(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        headers.insert("content-length", "10".parse().unwrap());
        headers.insert("x-reduct-time-1", "".parse().unwrap());

        let err = write_batched_records(
            State(keeper.await),
            headers,
            path_to_entry_1,
            body_stream.await,
        )
        .await
        .err()
        .unwrap();

        assert_eq!(
            err,
            HttpError::new(ErrorCode::UnprocessableEntity, "Invalid batched header")
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        headers.insert("content-length", "48".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,text/plain,a=b".parse().unwrap());
        headers.insert(
            "x-reduct-time-2",
            "20,text/plain,c=\"d,f\"".parse().unwrap(),
        );
        headers.insert("x-reduct-time-10", "18,text/plain".parse().unwrap());

        let stream = body_stream.await;

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .unwrap()
            .upgrade_and_unwrap();

        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(&reader.meta().labels()["a"], "b");
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 10);
            assert_eq!(reader.read_chunk().unwrap(), Ok(Bytes::from("1234567890")));
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(&reader.meta().labels()["c"], "d,f");
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 20);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("abcdef1234567890abcd"))
            );
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(10)
                .await
                .unwrap();
            assert!(reader.meta().labels().is_empty());
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 18);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("ef1234567890abcdef"))
            );
        }

        let info = components
            .replication_repo
            .read()
            .await
            .get_info("api-test")
            .unwrap();
        assert_eq!(info.info.pending_records, 3);
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_with_empty_bodies(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        headers.insert("content-length", "0".parse().unwrap());
        headers.insert("x-reduct-time-1", "0,,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "0,,a=d".parse().unwrap());

        let stream = Body::empty();

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .unwrap()
            .upgrade_and_unwrap();

        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_error(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        {
            let mut writer = components
                .storage
                .get_bucket("bucket-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_write("entry-1", 2, 20, "text/plain".to_string(), HashMap::new())
                .await
                .unwrap();
            writer
                .send(Ok(Some(Bytes::from(vec![0; 20]))))
                .await
                .unwrap();
            writer.send(Ok(None)).await.unwrap();
        }

        headers.insert("content-length", "48".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,".parse().unwrap());
        headers.insert("x-reduct-time-2", "20,".parse().unwrap());
        headers.insert("x-reduct-time-3", "18,".parse().unwrap());

        let stream = body_stream.await;

        let resp =
            write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
                .await
                .unwrap()
                .into_response();

        let headers = resp.headers();
        assert_eq!(headers.len(), 1);
        assert_eq!(
            headers.get("x-reduct-error-2").unwrap(),
            &HeaderValue::from_static("409,A record with timestamp 2 already exists")
        );

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .unwrap()
            .upgrade_and_unwrap();
        {
            let mut reader = bucket.begin_read("entry-1", 1).await.unwrap();
            assert_eq!(reader.meta().content_length(), 10);
            assert_eq!(reader.read_chunk().unwrap(), Ok(Bytes::from("1234567890")));
        }
        {
            let mut reader = bucket.begin_read("entry-1", 3).await.unwrap();
            assert_eq!(reader.meta().content_length(), 18);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("ef1234567890abcdef"))
            );
        }
    }

    #[fixture]
    async fn body_stream() -> Body {
        Body::from("1234567890abcdef1234567890abcdef1234567890abcdef")
    }
}