reductstore/api/entry/write_batched.rs

// Copyright 2025 ReductSoftware UG
// Licensed under the Business Source License 1.1

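//! HTTP handler for writing a batch of records to an entry in a single request.
//!
//! Record metadata is carried in `x-reduct-time-<timestamp>` request headers and the
//! record payloads are concatenated in the request body in timestamp order.
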
use crate::api::{Components, HttpError};
use crate::auth::policy::WriteAccessPolicy;
use axum::body::{Body, BodyDataStream};
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use axum_extra::headers::{Expect, Header, HeaderMap};
use bytes::Bytes;
use futures_util::StreamExt;

use crate::api::entry::common::err_to_batched_header;
use crate::api::StateKeeper;
use crate::replication::{Transaction, TransactionNotification};
use crate::storage::bucket::Bucket;
use crate::storage::entry::RecordDrainer;
use log::{debug, error};
use reduct_base::batch::{parse_batched_header, sort_headers_by_time, RecordHeader};
use reduct_base::error::ReductError;
use reduct_base::io::{RecordMeta, WriteRecord};
use reduct_base::{bad_request, internal_server_error, unprocessable_entity};
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tokio::sync::mpsc::Receiver;
use tokio::task::JoinHandle;
use tokio::time::timeout;

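/// Per-record state passed from the writer-spawning task to the body-consuming loop:
/// the record timestamp, its parsed batch header, and the storage writer to stream into.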
struct WriteContext {
    time: u64,
    header: RecordHeader,
    writer: Box<dyn WriteRecord + Sync + Send>,
}

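/// Per-record errors keyed by record timestamp, reported back as `x-reduct-error-<ts>` headers.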
type ErrorMap = BTreeMap<u64, ReductError>;

// POST /:bucket/:entry/batch
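/// Writes a batch of records to the entry given by the `bucket_name`/`entry_name` path parameters.
///
/// Record metadata is taken from the sorted `x-reduct-time-<ts>` headers, the payloads are read
/// sequentially from the request body, and per-record failures are returned as
/// `x-reduct-error-<ts>` response headers instead of failing the whole batch.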
pub(super) async fn write_batched_records(
    State(keeper): State<Arc<StateKeeper>>,
    headers: HeaderMap,
    Path(path): Path<HashMap<String, String>>,
    body: Body,
) -> Result<impl IntoResponse, HttpError> {
    let bucket = path.get("bucket_name").unwrap();
    let components = keeper
        .get_with_permissions(&headers.clone(), WriteAccessPolicy { bucket })
        .await?;

    let entry_name = path.get("entry_name").unwrap().clone();
    let record_headers: Vec<_> = sort_headers_by_time(&headers)?;
    let mut stream = body.into_data_stream();

    let process_stream = async {
        let mut timed_headers: Vec<(u64, RecordHeader)> = Vec::new();
        for (time, v) in record_headers {
            let header = parse_batched_header(v.to_str().unwrap())?;
            timed_headers.push((time, header));
        }

        let content_length = check_and_get_content_length(&headers, &timed_headers)?;
        let (rx_writer, spawn_handler) =
            spawn_getting_writers(&components, &bucket, &entry_name, timed_headers).await?;
        receive_body_and_write_records(
            bucket,
            entry_name,
            components,
            content_length as i64,
            &mut stream,
            rx_writer,
        )
        .await?;

        Ok(spawn_handler
            .await
            .map_err(|e| internal_server_error!("Failed to complete write operation: {}", e))?)
    };

    match process_stream.await {
        Ok(error_map) => {
            let mut headers = HeaderMap::new();
            error_map.iter().for_each(|(time, err)| {
                err_to_batched_header(&mut headers, *time, err);
            });

            Ok(headers.into())
        }

        Err(err) => {
            if !headers.contains_key(Expect::name()) {
                debug!("draining the stream");
                while let Some(_) = stream.next().await {}
            }
            Err::<HeaderMap, HttpError>(err)
        }
    }
}

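/// Notifies the replication repository that a record was written, so pending
/// replications can pick up the new record by bucket, entry, timestamp, and labels.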
async fn notify_replication_write(
    components: &Arc<Components>,
    bucket: &str,
    entry_name: &str,
    ctx: &WriteContext,
) -> Result<(), ReductError> {
    components
        .replication_repo
        .write()
        .await?
        .notify(TransactionNotification {
            bucket: bucket.to_string(),
            entry: entry_name.to_string(),
            meta: RecordMeta::builder()
                .timestamp(ctx.time)
                .labels(ctx.header.labels.clone())
                .build(),
            event: Transaction::WriteRecord(ctx.time),
        })
        .await?;
    Ok(())
}

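/// Consumes the request body chunk by chunk and distributes the bytes to the record
/// writers received from `rx_writer`, splitting chunks on record boundaries and
/// notifying replication after each record is finished.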
async fn receive_body_and_write_records(
    bucket: &String,
    entry_name: String,
    components: Arc<Components>,
    mut total_content_len: i64,
    stream: &mut BodyDataStream,
    mut rx_writer: Receiver<WriteContext>,
) -> Result<(), ReductError> {
    let mut chunk = Bytes::new();

    let mut read_chunk = async || -> Result<Bytes, ReductError> {
        if total_content_len == 0 {
            // it makes the code simpler to handle the last empty chunk case and empty records
            return Ok(Bytes::new());
        }
        match timeout(components.cfg.io_conf.operation_timeout, stream.next())
            .await
            .map_err(|_| bad_request!("Timeout while receiving data"))?
        {
            Some(Ok(data_chunk)) => {
                total_content_len -= data_chunk.len() as i64;
                Ok(data_chunk)
            }
            Some(Err(e)) => Err(bad_request!("Error while receiving data chunk: {}", e)),
            None => Err(bad_request!(
                "Content is shorter than expected: no more data to read"
            )),
        }
    };

    while let Some(mut ctx) = rx_writer.recv().await {
        let mut written = 0;

        if chunk.is_empty() {
            chunk = read_chunk().await?
        }

        loop {
            match write_chunk(
                &mut ctx.writer,
                chunk,
                &mut written,
                ctx.header.content_length.clone(),
                components.cfg.io_conf.operation_timeout,
            )
            .await
            {
                Ok(None) => {
                    // chunk is written but record is not finished yet
                    chunk = read_chunk().await?;
                    continue;
                }
                Ok(Some(rest)) => {
                    // finished writing the current record: send EOF and move on to the next one
                    if let Err(err) = ctx
                        .writer
                        .send_timeout(Ok(None), components.cfg.io_conf.operation_timeout)
                        .await
                    {
                        debug!("Timeout while sending EOF: {}", err);
                    }

                    notify_replication_write(&components, bucket, &entry_name, &ctx).await?;
                    chunk = rest;
                    break;
                }
                Err(err) => return Err(err),
            }
        }
    }

    Ok(())
}

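/// Spawns a task that opens a writer for every record in `timed_headers` and sends it
/// to the returned channel. Records that fail to start (e.g. timestamp conflicts) get a
/// `RecordDrainer` instead, and the error is collected into the returned `ErrorMap`.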
async fn spawn_getting_writers(
    components: &Arc<Components>,
    bucket_name: &str,
    entry_name: &str,
    timed_headers: Vec<(u64, RecordHeader)>,
) -> Result<(Receiver<WriteContext>, JoinHandle<ErrorMap>), ReductError> {
    let (tx_writer, rx_writer) = tokio::sync::mpsc::channel(64);

    let bucket = components
        .storage
        .get_bucket(&bucket_name)
        .await?
        .upgrade_and_unwrap();

    let entry_name = entry_name.to_string();
    let spawn_handler = tokio::spawn(async move {
        let mut error_map = BTreeMap::new();

        for (time, header) in timed_headers.into_iter() {
            let writer =
                start_writing(&entry_name, bucket.clone(), time, &header, &mut error_map).await;

            tx_writer
                .send(WriteContext {
                    time,
                    header,
                    writer,
                })
                .await
                .map_err(|err| error!("Failed to send the writer: {}", err))
                .unwrap_or(());
        }
        error_map
    });

    Ok((rx_writer, spawn_handler))
}

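/// Writes `chunk` to the current record's writer, capped at the record's `content_size`.
/// Returns `Ok(None)` if the record still needs more data, or `Ok(Some(rest))` with the
/// leftover bytes that belong to the next record once this one is complete.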
async fn write_chunk(
    writer: &mut Box<dyn WriteRecord + Sync + Send>,
    chunk: Bytes,
    written: &mut usize,
    content_size: u64,
    io_timeout: std::time::Duration,
) -> Result<Option<Bytes>, ReductError> {
    let to_write = content_size - *written as u64;
    *written += chunk.len();
    let (chunk, rest) = if (chunk.len() as u64) < to_write {
        (chunk, None)
    } else {
        let chunk_to_write = chunk.slice(0..to_write as usize);
        (chunk_to_write, Some(chunk.slice(to_write as usize..)))
    };

    writer.send_timeout(Ok(Some(chunk)), io_timeout).await?;
    Ok(rest)
}

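/// Validates that the `content-length` header equals the sum of the per-record content
/// lengths from the batch headers and returns the total.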
fn check_and_get_content_length(
    headers: &HeaderMap,
    timed_headers: &Vec<(u64, RecordHeader)>,
) -> Result<u64, ReductError> {
    let total_content_length: u64 = timed_headers
        .iter()
        .map(|(_, header)| header.content_length)
        .sum();

    if total_content_length
        != headers
            .get("content-length")
            .ok_or(unprocessable_entity!("content-length header is required",))?
            .to_str()
            .unwrap()
            .parse::<u64>()
            .map_err(|_| unprocessable_entity!("Invalid content-length header"))?
    {
        return Err(unprocessable_entity!(
            "content-length header does not match the sum of the content-lengths in the headers",
        )
        .into());
    }

    Ok(total_content_length)
}

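/// Begins writing a single record. On failure the error is stored in `error_map` and a
/// `RecordDrainer` is returned so the corresponding bytes of the body are still consumed.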
async fn start_writing(
    entry_name: &str,
    bucket: Arc<Bucket>,
    time: u64,
    record_header: &RecordHeader,
    error_map: &mut BTreeMap<u64, ReductError>,
) -> Box<dyn WriteRecord + Sync + Send> {
    let get_writer = async {
        bucket
            .begin_write(
                entry_name,
                time,
                record_header.content_length.clone(),
                record_header.content_type.clone(),
                record_header.labels.clone(),
            )
            .await
    };

    match get_writer.await {
        Ok(writer) => writer,
        Err(err) => {
            error_map.insert(time, err);
            // drain the stream
            Box::new(RecordDrainer::new())
        }
    }
}

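// Tests cover header validation, the happy path, empty record bodies, per-record
// conflicts, a content-length mismatch, and a failing body stream.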
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::entry::write_batched::write_batched_records;
    use crate::api::tests::{headers, keeper, path_to_entry_1};
    use axum_extra::headers::HeaderValue;
    use futures_util::stream;
    use reduct_base::error::ErrorCode;
    use reduct_base::io::ReadRecord;
    use reduct_base::msg::replication_api::ReplicationMode;
    use rstest::{fixture, rstest};
    use std::time::Duration;
    use tokio::time::sleep;

    #[rstest]
    #[tokio::test]
    async fn test_write_record_bad_timestamp(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        headers.insert("content-length", "10".parse().unwrap());
        headers.insert("x-reduct-time-yyy", "10".parse().unwrap());

        let err = write_batched_records(
            State(keeper.await),
            headers,
            path_to_entry_1,
            body_stream.await,
        )
        .await
        .err()
        .unwrap();

        assert_eq!(
            err,
            HttpError::new(
                ErrorCode::UnprocessableEntity,
                "Invalid header 'x-reduct-time-yyy': must be an unix timestamp in microseconds",
            )
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_invalid_header(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        headers.insert("content-length", "10".parse().unwrap());
        headers.insert("x-reduct-time-1", "".parse().unwrap());

        let err = write_batched_records(
            State(keeper.await),
            headers,
            path_to_entry_1,
            body_stream.await,
        )
        .await
        .err()
        .unwrap();

        assert_eq!(
            err,
            HttpError::new(ErrorCode::UnprocessableEntity, "Invalid batched header")
        );
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_write_batched_records(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        components
            .replication_repo
            .write()
            .await
            .unwrap()
            .set_mode("api-test", ReplicationMode::Paused)
            .await
            .unwrap();
        headers.insert("content-length", "48".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,text/plain,a=b".parse().unwrap());
        headers.insert(
            "x-reduct-time-2",
            "20,text/plain,c=\"d,f\"".parse().unwrap(),
        );
        headers.insert("x-reduct-time-10", "18,text/plain".parse().unwrap());

        let stream = body_stream.await;

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .await
            .unwrap()
            .upgrade_and_unwrap();

        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(&reader.meta().labels()["a"], "b");
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 10);
            assert_eq!(reader.read_chunk().unwrap(), Ok(Bytes::from("1234567890")));
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(&reader.meta().labels()["c"], "d,f");
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 20);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("abcdef1234567890abcd"))
            );
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(10)
                .await
                .unwrap();
            assert!(reader.meta().labels().is_empty());
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 18);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("ef1234567890abcdef"))
            );
        }

        let mut pending_records = 0;
        for _ in 0..20 {
            pending_records = components
                .replication_repo
                .read()
                .await
                .unwrap()
                .get_info("api-test")
                .await
                .unwrap()
                .info
                .pending_records;
            if pending_records == 3 {
                break;
            }
            sleep(Duration::from_millis(25)).await;
        }
        assert_eq!(pending_records, 3);
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_with_empty_bodies(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        headers.insert("content-length", "0".parse().unwrap());
        headers.insert("x-reduct-time-1", "0,,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "0,,a=d".parse().unwrap());

        let stream = Body::empty();

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .await
            .unwrap()
            .upgrade_and_unwrap();

        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_complex(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        headers.insert("content-length", "1000000".parse().unwrap());
        headers.insert("x-reduct-time-1", "500000,text/plain,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "0,text/plain,a=c".parse().unwrap());
        headers.insert("x-reduct-time-3", "500000,text/plain,a=c".parse().unwrap());
        headers.insert("x-reduct-time-4", "0,text/plain,a=c".parse().unwrap());

        // the body will be split into 3 parts: 600000, 300000, 100000
        let stream = Body::from_stream(stream::iter(vec![
            Ok::<Bytes, ReductError>(Bytes::from(vec![0; 600000])),
            Ok(Bytes::from(vec![0; 300000])),
            Ok(Bytes::from(vec![0; 100000])),
        ]));

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .await
            .unwrap()
            .upgrade_and_unwrap();

        {
            let reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 500000);
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
        {
            let reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(3)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 500000);
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(4)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_write_batched_records_error(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        {
            let mut writer = components
                .storage
                .get_bucket("bucket-1")
                .await
                .unwrap()
                .upgrade_and_unwrap()
                .begin_write("entry-1", 2, 20, "text/plain".to_string(), HashMap::new())
                .await
                .unwrap();
            writer
                .send(Ok(Some(Bytes::from(vec![0; 20]))))
                .await
                .unwrap();
            writer.send(Ok(None)).await.unwrap();
        }

        headers.insert("content-length", "48".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,".parse().unwrap());
        headers.insert("x-reduct-time-2", "20,".parse().unwrap());
        headers.insert("x-reduct-time-3", "18,".parse().unwrap());

        let stream = body_stream.await;

        let resp =
            write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
                .await
                .unwrap()
                .into_response();

        let headers = resp.headers();
        assert_eq!(headers.len(), 1);
        assert_eq!(
            headers.get("x-reduct-error-2").unwrap(),
            &HeaderValue::from_static("409,A record with timestamp 2 already exists")
        );

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .await
            .unwrap()
            .upgrade_and_unwrap();
        {
            let mut reader = bucket.begin_read("entry-1", 1).await.unwrap();
            assert_eq!(reader.meta().content_length(), 10);
            assert_eq!(reader.read_chunk().unwrap(), Ok(Bytes::from("1234567890")));
        }
        {
            let mut reader = bucket.begin_read("entry-1", 3).await.unwrap();
            assert_eq!(reader.meta().content_length(), 18);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("ef1234567890abcdef"))
            );
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_content_length_mismatch(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        headers.insert("content-length", "60".parse().unwrap());
        headers.insert("x-reduct-time-1", "40,text/plain,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "20,text/plain,c=d".parse().unwrap());
        let stream = Body::from("123456");
        let err = write_batched_records(
            State(Arc::clone(&keeper.await)),
            headers,
            path_to_entry_1,
            stream,
        )
        .await
        .err()
        .unwrap();

        let err: ReductError = err.into();
        assert_eq!(
            err,
            bad_request!("Content is shorter than expected: no more data to read")
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_errored_chunk(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        headers.insert("content-length", "30".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,text/plain,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "20,text/plain,c=d".parse().unwrap());
        let stream = Body::from_stream(stream::iter(vec![
            Ok::<Bytes, ReductError>(Bytes::from("12345")),
            Err(bad_request!("Simulated chunk error")),
        ]));
        let err = write_batched_records(
            State(Arc::clone(&keeper.await)),
            headers,
            path_to_entry_1,
            stream,
        )
        .await
        .err()
        .unwrap();

        let err: ReductError = err.into();
        assert_eq!(
            err,
            bad_request!("Error while receiving data chunk: [BadRequest] Simulated chunk error")
        );
    }

    #[fixture]
    async fn body_stream() -> Body {
        Body::from("1234567890abcdef1234567890abcdef1234567890abcdef")
    }
}