use crate::api::{Components, HttpError};
use crate::auth::policy::WriteAccessPolicy;
use axum::body::{Body, BodyDataStream};
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use axum_extra::headers::{Expect, Header, HeaderMap};
use bytes::Bytes;
use futures_util::StreamExt;

use crate::api::entry::common::err_to_batched_header;
use crate::api::StateKeeper;
use crate::replication::{Transaction, TransactionNotification};
use crate::storage::bucket::Bucket;
use crate::storage::entry::RecordDrainer;
use log::{debug, error};
use reduct_base::batch::{parse_batched_header, sort_headers_by_time, RecordHeader};
use reduct_base::error::ReductError;
use reduct_base::io::{RecordMeta, WriteRecord};
use reduct_base::{bad_request, internal_server_error, unprocessable_entity};
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tokio::sync::mpsc::Receiver;
use tokio::task::JoinHandle;
use tokio::time::timeout;

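/// A queued write for one record in the batch: its timestamp, the parsed
/// `x-reduct-time-*` header, and the writer that receives the record body.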
struct WriteContext {
    time: u64,
    header: RecordHeader,
    writer: Box<dyn WriteRecord + Sync + Send>,
}

type ErrorMap = BTreeMap<u64, ReductError>;

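/// Writes a batch of records into an entry. Each record is described by an
/// `x-reduct-time-<timestamp>` header, and all record bodies are concatenated
/// in the request body in timestamp order. Failures of individual records are
/// reported back as `x-reduct-error-<timestamp>` response headers instead of
/// failing the whole batch.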
pub(super) async fn write_batched_records(
    State(keeper): State<Arc<StateKeeper>>,
    headers: HeaderMap,
    Path(path): Path<HashMap<String, String>>,
    body: Body,
) -> Result<impl IntoResponse, HttpError> {
    let bucket = path.get("bucket_name").unwrap();
    let components = keeper
        .get_with_permissions(&headers, WriteAccessPolicy { bucket })
        .await?;

    let entry_name = path.get("entry_name").unwrap().clone();
    let record_headers: Vec<_> = sort_headers_by_time(&headers)?;
    let mut stream = body.into_data_stream();

    let process_stream = async {
        let mut timed_headers: Vec<(u64, RecordHeader)> = Vec::new();
        for (time, v) in record_headers {
            let header = parse_batched_header(
                v.to_str()
                    .map_err(|_| unprocessable_entity!("Invalid batched header"))?,
            )?;
            timed_headers.push((time, header));
        }

        let content_length = check_and_get_content_length(&headers, &timed_headers)?;
        let (rx_writer, spawn_handler) =
            spawn_getting_writers(&components, bucket, &entry_name, timed_headers).await?;
        receive_body_and_write_records(
            bucket,
            entry_name,
            components,
            content_length as i64,
            &mut stream,
            rx_writer,
        )
        .await?;

        Ok(spawn_handler
            .await
            .map_err(|e| internal_server_error!("Failed to complete write operation: {}", e))?)
    };

    match process_stream.await {
        Ok(error_map) => {
            let mut headers = HeaderMap::new();
            error_map.iter().for_each(|(time, err)| {
                err_to_batched_header(&mut headers, *time, err);
            });

            Ok(headers.into())
        }

        Err(err) => {
            if !headers.contains_key(Expect::name()) {
                debug!("draining the stream");
                while stream.next().await.is_some() {}
            }
            Err::<HeaderMap, HttpError>(err)
        }
    }
}

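/// Notifies the replication engine that a record was written so that pending
/// replication tasks can pick it up.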
async fn notify_replication_write(
    components: &Arc<Components>,
    bucket: &str,
    entry_name: &str,
    ctx: &WriteContext,
) -> Result<(), ReductError> {
    components
        .replication_repo
        .write()
        .await?
        .notify(TransactionNotification {
            bucket: bucket.to_string(),
            entry: entry_name.to_string(),
            meta: RecordMeta::builder()
                .timestamp(ctx.time)
                .labels(ctx.header.labels.clone())
                .build(),
            event: Transaction::WriteRecord(ctx.time),
        })
        .await?;
    Ok(())
}

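/// Receives the request body chunk by chunk and feeds each record's writer
/// with exactly `content_length` bytes, carrying leftover bytes of a chunk
/// over to the next record. Every read and write is bounded by the configured
/// I/O timeout.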
async fn receive_body_and_write_records(
    bucket: &str,
    entry_name: String,
    components: Arc<Components>,
    mut total_content_len: i64,
    stream: &mut BodyDataStream,
    mut rx_writer: Receiver<WriteContext>,
) -> Result<(), ReductError> {
    let mut chunk = Bytes::new();

    let mut read_chunk = async || -> Result<Bytes, ReductError> {
        if total_content_len == 0 {
            return Ok(Bytes::new());
        }
        match timeout(components.cfg.io_conf.operation_timeout, stream.next())
            .await
            .map_err(|_| bad_request!("Timeout while receiving data"))?
        {
            Some(Ok(data_chunk)) => {
                total_content_len -= data_chunk.len() as i64;
                Ok(data_chunk)
            }
            Some(Err(e)) => Err(bad_request!("Error while receiving data chunk: {}", e)),
            None => Err(bad_request!(
                "Content is shorter than expected: no more data to read"
            )),
        }
    };

    while let Some(mut ctx) = rx_writer.recv().await {
        let mut written = 0;

        if chunk.is_empty() {
            chunk = read_chunk().await?;
        }

        loop {
            match write_chunk(
                &mut ctx.writer,
                chunk,
                &mut written,
                ctx.header.content_length,
                components.cfg.io_conf.operation_timeout,
            )
            .await
            {
                Ok(None) => {
                    chunk = read_chunk().await?;
                    continue;
                }
                Ok(Some(rest)) => {
                    if let Err(err) = ctx
                        .writer
                        .send_timeout(Ok(None), components.cfg.io_conf.operation_timeout)
                        .await
                    {
                        debug!("Timeout while sending EOF: {}", err);
                    }

                    notify_replication_write(&components, bucket, &entry_name, &ctx).await?;
                    chunk = rest;
                    break;
                }
                Err(err) => return Err(err),
            }
        }
    }

    Ok(())
}

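/// Spawns a task that creates a writer for every record of the batch and
/// hands them over through a bounded channel, so that writer creation can
/// overlap with receiving the body. The task returns a map of the errors
/// collected while creating the writers, keyed by record timestamp.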
async fn spawn_getting_writers(
    components: &Arc<Components>,
    bucket_name: &str,
    entry_name: &str,
    timed_headers: Vec<(u64, RecordHeader)>,
) -> Result<(Receiver<WriteContext>, JoinHandle<ErrorMap>), ReductError> {
    let (tx_writer, rx_writer) = tokio::sync::mpsc::channel(64);

    let bucket = components
        .storage
        .get_bucket(bucket_name)
        .await?
        .upgrade_and_unwrap();

    let entry_name = entry_name.to_string();
    let spawn_handler = tokio::spawn(async move {
        let mut error_map = BTreeMap::new();

        for (time, header) in timed_headers.into_iter() {
            let writer =
                start_writing(&entry_name, bucket.clone(), time, &header, &mut error_map).await;

            if let Err(err) = tx_writer
                .send(WriteContext {
                    time,
                    header,
                    writer,
                })
                .await
            {
                error!("Failed to send the writer: {}", err);
            }
        }
        error_map
    });

    Ok((rx_writer, spawn_handler))
}

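/// Sends `chunk` to the writer up to the record's `content_size` boundary.
/// Returns `Ok(None)` if the whole chunk was consumed and more data is needed,
/// or `Ok(Some(rest))` with the leftover bytes once the record is complete.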
async fn write_chunk(
    writer: &mut Box<dyn WriteRecord + Sync + Send>,
    chunk: Bytes,
    written: &mut usize,
    content_size: u64,
    io_timeout: std::time::Duration,
) -> Result<Option<Bytes>, ReductError> {
    let to_write = content_size - *written as u64;
    *written += chunk.len();
    let (chunk, rest) = if (chunk.len() as u64) < to_write {
        (chunk, None)
    } else {
        let chunk_to_write = chunk.slice(0..to_write as usize);
        (chunk_to_write, Some(chunk.slice(to_write as usize..)))
    };

    writer.send_timeout(Ok(Some(chunk)), io_timeout).await?;
    Ok(rest)
}

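/// Checks that the `content-length` header equals the sum of the per-record
/// content lengths declared in the batched headers and returns that sum.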
fn check_and_get_content_length(
    headers: &HeaderMap,
    timed_headers: &[(u64, RecordHeader)],
) -> Result<u64, ReductError> {
    let total_content_length: u64 = timed_headers
        .iter()
        .map(|(_, header)| header.content_length)
        .sum();

    let declared_length = headers
        .get("content-length")
        .ok_or(unprocessable_entity!("content-length header is required"))?
        .to_str()
        .map_err(|_| unprocessable_entity!("Invalid content-length header"))?
        .parse::<u64>()
        .map_err(|_| unprocessable_entity!("Invalid content-length header"))?;

    if total_content_length != declared_length {
        return Err(unprocessable_entity!(
            "content-length header does not match the sum of the content-lengths in the headers"
        ));
    }

    Ok(total_content_length)
}

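/// Starts writing a record to the entry. On failure the error is stored in
/// `error_map` and a `RecordDrainer` is returned instead, so the record's
/// bytes are still consumed and the rest of the batch can proceed.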
async fn start_writing(
    entry_name: &str,
    bucket: Arc<Bucket>,
    time: u64,
    record_header: &RecordHeader,
    error_map: &mut BTreeMap<u64, ReductError>,
) -> Box<dyn WriteRecord + Sync + Send> {
    match bucket
        .begin_write(
            entry_name,
            time,
            record_header.content_length,
            record_header.content_type.clone(),
            record_header.labels.clone(),
        )
        .await
    {
        Ok(writer) => writer,
        Err(err) => {
            error_map.insert(time, err);
            Box::new(RecordDrainer::new())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::entry::write_batched::write_batched_records;
    use crate::api::tests::{headers, keeper, path_to_entry_1};
    use axum_extra::headers::HeaderValue;
    use futures_util::stream;
    use reduct_base::error::ErrorCode;
    use reduct_base::io::ReadRecord;
    use reduct_base::msg::replication_api::ReplicationMode;
    use rstest::{fixture, rstest};
    use std::time::Duration;
    use tokio::time::sleep;

    #[rstest]
    #[tokio::test]
    async fn test_write_record_bad_timestamp(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        headers.insert("content-length", "10".parse().unwrap());
        headers.insert("x-reduct-time-yyy", "10".parse().unwrap());

        let err = write_batched_records(
            State(keeper.await),
            headers,
            path_to_entry_1,
            body_stream.await,
        )
        .await
        .err()
        .unwrap();

        assert_eq!(
            err,
            HttpError::new(
                ErrorCode::UnprocessableEntity,
                "Invalid header 'x-reduct-time-yyy': must be an unix timestamp in microseconds",
            )
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_invalid_header(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        headers.insert("content-length", "10".parse().unwrap());
        headers.insert("x-reduct-time-1", "".parse().unwrap());

        let err = write_batched_records(
            State(keeper.await),
            headers,
            path_to_entry_1,
            body_stream.await,
        )
        .await
        .err()
        .unwrap();

        assert_eq!(
            err,
            HttpError::new(ErrorCode::UnprocessableEntity, "Invalid batched header")
        );
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_write_batched_records(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        components
            .replication_repo
            .write()
            .await
            .unwrap()
            .set_mode("api-test", ReplicationMode::Paused)
            .await
            .unwrap();
        headers.insert("content-length", "48".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,text/plain,a=b".parse().unwrap());
        headers.insert(
            "x-reduct-time-2",
            "20,text/plain,c=\"d,f\"".parse().unwrap(),
        );
        headers.insert("x-reduct-time-10", "18,text/plain".parse().unwrap());

        let stream = body_stream.await;

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .await
            .unwrap()
            .upgrade_and_unwrap();

        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(&reader.meta().labels()["a"], "b");
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 10);
            assert_eq!(reader.read_chunk().unwrap(), Ok(Bytes::from("1234567890")));
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(&reader.meta().labels()["c"], "d,f");
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 20);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("abcdef1234567890abcd"))
            );
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(10)
                .await
                .unwrap();
            assert!(reader.meta().labels().is_empty());
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 18);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("ef1234567890abcdef"))
            );
        }

        let mut pending_records = 0;
        for _ in 0..20 {
            pending_records = components
                .replication_repo
                .read()
                .await
                .unwrap()
                .get_info("api-test")
                .await
                .unwrap()
                .info
                .pending_records;
            if pending_records == 3 {
                break;
            }
            sleep(Duration::from_millis(25)).await;
        }
        assert_eq!(pending_records, 3);
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_with_empty_bodies(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        headers.insert("content-length", "0".parse().unwrap());
        headers.insert("x-reduct-time-1", "0,,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "0,,a=d".parse().unwrap());

        let stream = Body::empty();

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .await
            .unwrap()
            .upgrade_and_unwrap();

        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_complex(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        headers.insert("content-length", "1000000".parse().unwrap());
        headers.insert("x-reduct-time-1", "500000,text/plain,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "0,text/plain,a=c".parse().unwrap());
        headers.insert("x-reduct-time-3", "500000,text/plain,a=c".parse().unwrap());
        headers.insert("x-reduct-time-4", "0,text/plain,a=c".parse().unwrap());

        let stream = Body::from_stream(stream::iter(vec![
            Ok::<Bytes, ReductError>(Bytes::from(vec![0; 600000])),
            Ok(Bytes::from(vec![0; 300000])),
            Ok(Bytes::from(vec![0; 100000])),
        ]));

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .await
            .unwrap()
            .upgrade_and_unwrap();

        {
            let reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 500000);
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
        {
            let reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(3)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 500000);
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(4)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_write_batched_records_error(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        {
            let mut writer = components
                .storage
                .get_bucket("bucket-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_write("entry-1", 2, 20, "text/plain".to_string(), HashMap::new())
                .await
                .unwrap();
            writer
                .send(Ok(Some(Bytes::from(vec![0; 20]))))
                .await
                .unwrap();
            writer.send(Ok(None)).await.unwrap();
        }

        headers.insert("content-length", "48".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,".parse().unwrap());
        headers.insert("x-reduct-time-2", "20,".parse().unwrap());
        headers.insert("x-reduct-time-3", "18,".parse().unwrap());

        let stream = body_stream.await;

        let resp =
            write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
                .await
                .unwrap()
                .into_response();

        let headers = resp.headers();
        assert_eq!(headers.len(), 1);
        assert_eq!(
            headers.get("x-reduct-error-2").unwrap(),
            &HeaderValue::from_static("409,A record with timestamp 2 already exists")
        );

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .await
            .unwrap()
            .upgrade_and_unwrap();
        {
            let mut reader = bucket.begin_read("entry-1", 1).await.unwrap();
            assert_eq!(reader.meta().content_length(), 10);
            assert_eq!(reader.read_chunk().unwrap(), Ok(Bytes::from("1234567890")));
        }
        {
            let mut reader = bucket.begin_read("entry-1", 3).await.unwrap();
            assert_eq!(reader.meta().content_length(), 18);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("ef1234567890abcdef"))
            );
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_content_length_mismatch(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        headers.insert("content-length", "60".parse().unwrap());
        headers.insert("x-reduct-time-1", "40,text/plain,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "20,text/plain,c=d".parse().unwrap());
        let stream = Body::from("123456");
        let err = write_batched_records(
            State(Arc::clone(&keeper.await)),
            headers,
            path_to_entry_1,
            stream,
        )
        .await
        .err()
        .unwrap();

        let err: ReductError = err.into();
        assert_eq!(
            err,
            bad_request!("Content is shorter than expected: no more data to read")
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_errored_chunk(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        headers.insert("content-length", "30".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,text/plain,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "20,text/plain,c=d".parse().unwrap());
        let stream = Body::from_stream(stream::iter(vec![
            Ok::<Bytes, ReductError>(Bytes::from("12345")),
            Err(bad_request!("Simulated chunk error")),
        ]));
        let err = write_batched_records(
            State(Arc::clone(&keeper.await)),
            headers,
            path_to_entry_1,
            stream,
        )
        .await
        .err()
        .unwrap();

        let err: ReductError = err.into();
        assert_eq!(
            err,
            bad_request!("Error while receiving data chunk: [BadRequest] Simulated chunk error")
        );
    }

    #[fixture]
    async fn body_stream() -> Body {
        Body::from("1234567890abcdef1234567890abcdef1234567890abcdef")
    }
}