reductstore/api/entry/write_batched.rs

// Copyright 2025 ReductSoftware UG
// Licensed under the Business Source License 1.1

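//! Handler for the batched write endpoint (`POST /:bucket/:entry/batch`): one
//! request carries multiple records, each described by an
//! `x-reduct-time-<timestamp>` header, with the payloads concatenated in the body.
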
use crate::api::{Components, HttpError};
use crate::auth::policy::WriteAccessPolicy;
use axum::body::{Body, BodyDataStream};
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use axum_extra::headers::{Expect, Header, HeaderMap};
use bytes::Bytes;
use futures_util::StreamExt;

use crate::api::entry::common::err_to_batched_header;
use crate::api::StateKeeper;
use crate::replication::{Transaction, TransactionNotification};
use crate::storage::bucket::Bucket;
use crate::storage::entry::RecordDrainer;
use log::{debug, error};
use reduct_base::batch::{parse_batched_header, sort_headers_by_time, RecordHeader};
use reduct_base::error::ReductError;
use reduct_base::io::{RecordMeta, WriteRecord};
use reduct_base::{bad_request, internal_server_error, unprocessable_entity};
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tokio::sync::mpsc::Receiver;
use tokio::task::JoinHandle;
use tokio::time::timeout;

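/// Everything needed to write one record of the batch: its timestamp, the
/// parsed batched header, and an open writer (or a drainer if opening failed).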
struct WriteContext {
    time: u64,
    header: RecordHeader,
    writer: Box<dyn WriteRecord + Sync + Send>,
}

/// Per-record write errors keyed by record timestamp; reported back to the
/// client as `x-reduct-error-<timestamp>` headers.
type ErrorMap = BTreeMap<u64, ReductError>;

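/// Writes a batch of records in a single request: parses the
/// `x-reduct-time-<timestamp>` headers, opens a writer per record in a
/// background task, and streams the body into the writers in timestamp order.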
// POST /:bucket/:entry/batch
pub(super) async fn write_batched_records(
    State(keeper): State<Arc<StateKeeper>>,
    headers: HeaderMap,
    Path(path): Path<HashMap<String, String>>,
    body: Body,
) -> Result<impl IntoResponse, HttpError> {
    let bucket = path.get("bucket_name").unwrap();
    let components = keeper
        .get_with_permissions(&headers, WriteAccessPolicy { bucket })
        .await?;

    let entry_name = path.get("entry_name").unwrap().clone();
    let record_headers: Vec<_> = sort_headers_by_time(&headers)?;
    let mut stream = body.into_data_stream();

    let process_stream = async {
        let mut timed_headers: Vec<(u64, RecordHeader)> = Vec::new();
        for (time, v) in record_headers {
            let value = v
                .to_str()
                .map_err(|_| unprocessable_entity!("Invalid batched header"))?;
            timed_headers.push((time, parse_batched_header(value)?));
        }

        let content_length = check_and_get_content_length(&headers, &timed_headers)?;
        let record_count = timed_headers.len();

        let (rx_writer, spawn_handler) =
            spawn_getting_writers(&components, bucket, &entry_name, timed_headers)?;

        if content_length > 0 {
            receive_body_and_write_records(
                bucket,
                entry_name,
                components,
                record_count,
                &mut stream,
                rx_writer,
            )
            .await?;
        } else {
            write_only_metadata(bucket, entry_name, components, rx_writer).await?;
        }

        Ok(spawn_handler
            .await
            .map_err(|e| internal_server_error!("Failed to complete write operation: {}", e))?)
    };

    match process_stream.await {
        Ok(error_map) => {
            let mut headers = HeaderMap::new();
            for (time, err) in error_map.iter() {
                err_to_batched_header(&mut headers, *time, err);
            }

            Ok(headers.into())
        }

        Err(err) => {
            // Without an `Expect: 100-continue` handshake the client is already
            // sending the body, so drain it to keep the connection reusable.
            if !headers.contains_key(Expect::name()) {
                debug!("draining the stream");
                while stream.next().await.is_some() {}
            }
            Err::<HeaderMap, HttpError>(err)
        }
    }
}

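/// Notifies the replication engine that a record was written, so pending
/// replications can pick it up.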
async fn notify_replication_write(
    components: &Arc<Components>,
    bucket: &str,
    entry_name: &str,
    ctx: &WriteContext,
) -> Result<(), ReductError> {
    components
        .replication_repo
        .write()
        .await
        .notify(TransactionNotification {
            bucket: bucket.to_string(),
            entry: entry_name.to_string(),
            meta: RecordMeta::builder()
                .timestamp(ctx.time)
                .labels(ctx.header.labels.clone())
                .build(),
            event: Transaction::WriteRecord(ctx.time),
        })?;
    Ok(())
}

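/// Handles the zero-length case: no body is expected, so each writer only
/// receives an EOF, and replication is notified per record.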
async fn write_only_metadata(
    bucket: &str,
    entry_name: String,
    components: Arc<Components>,
    mut rx_writer: Receiver<WriteContext>,
) -> Result<(), ReductError> {
    while let Some(mut ctx) = rx_writer.recv().await {
        if let Err(err) = ctx
            .writer
            .send_timeout(Ok(None), components.cfg.io_conf.operation_timeout)
            .await
        {
            debug!("Timeout while sending EOF: {}", err);
        }

        notify_replication_write(&components, bucket, &entry_name, &ctx).await?;
    }

    Ok(())
}

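/// Streams the request body into the record writers in order. A chunk may span
/// several records: `write_chunk` returns the unconsumed tail, and the loop
/// switches to the next writer at each record boundary.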
async fn receive_body_and_write_records(
    bucket: &str,
    entry_name: String,
    components: Arc<Components>,
    mut record_count: usize,
    stream: &mut BodyDataStream,
    mut rx_writer: Receiver<WriteContext>,
) -> Result<(), ReductError> {
    let mut written = 0;
    let mut ctx = rx_writer
        .recv()
        .await
        .ok_or(internal_server_error!("No writer found"))?;

    while let Some(chunk) = timeout(components.cfg.io_conf.operation_timeout, stream.next())
        .await
        .map_err(|_| internal_server_error!("Timeout while receiving data"))?
    {
        let mut chunk =
            chunk.map_err(|e| bad_request!("Error while receiving data chunk: {}", e))?;

        while !chunk.is_empty() {
            match write_chunk(
                &mut ctx.writer,
                chunk,
                &mut written,
                ctx.header.content_length,
                components.cfg.io_conf.operation_timeout,
            )
            .await
            {
                Ok(None) => break, // the whole chunk belongs to the current record; fetch the next chunk
                Ok(Some(rest)) => {
                    // the current record is complete: send EOF and start the next one
                    if let Err(err) = ctx
                        .writer
                        .send_timeout(Ok(None), components.cfg.io_conf.operation_timeout)
                        .await
                    {
                        debug!("Timeout while sending EOF: {}", err);
                    }

                    notify_replication_write(&components, bucket, &entry_name, &ctx).await?;

                    chunk = rest;
                    record_count -= 1;
                    written = 0;

                    ctx = match rx_writer.recv().await {
                        Some(ctx) => ctx,
                        None => break, // no more writers - stop the loop
                    };
                }
                Err(err) => return Err(err),
            }
        }
    }

    if record_count != 0 {
        return Err(bad_request!("Content is shorter than expected"));
    }

    Ok(())
}

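/// Spawns a task that opens a writer for every record up front and hands them
/// to the body loop through a bounded channel. Records that fail to open get a
/// `RecordDrainer`; their errors are collected into the returned `ErrorMap`.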
fn spawn_getting_writers(
    components: &Arc<Components>,
    bucket_name: &str,
    entry_name: &str,
    timed_headers: Vec<(u64, RecordHeader)>,
) -> Result<(Receiver<WriteContext>, JoinHandle<ErrorMap>), ReductError> {
    let (tx_writer, rx_writer) = tokio::sync::mpsc::channel(64);

    let bucket = components
        .storage
        .get_bucket(bucket_name)?
        .upgrade_and_unwrap();

    let entry_name = entry_name.to_string();
    let spawn_handler = tokio::spawn(async move {
        let mut error_map = BTreeMap::new();

        for (time, header) in timed_headers.into_iter() {
            let writer =
                start_writing(&entry_name, bucket.clone(), time, &header, &mut error_map).await;

            if let Err(err) = tx_writer
                .send(WriteContext {
                    time,
                    header,
                    writer,
                })
                .await
            {
                error!("Failed to send the writer: {}", err);
            }
        }
        error_map
    });

    Ok((rx_writer, spawn_handler))
}

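/// Writes as much of `chunk` as the current record still needs. Returns
/// `Ok(None)` if the whole chunk was consumed, or `Ok(Some(rest))` with the
/// leftover bytes once `content_size` has been reached.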
async fn write_chunk(
    writer: &mut Box<dyn WriteRecord + Sync + Send>,
    chunk: Bytes,
    written: &mut usize,
    content_size: u64,
    io_timeout: std::time::Duration,
) -> Result<Option<Bytes>, ReductError> {
    let to_write = content_size - *written as u64;
    *written += chunk.len();
    let (chunk, rest) = if (chunk.len() as u64) < to_write {
        (chunk, None)
    } else {
        let chunk_to_write = chunk.slice(0..to_write as usize);
        (chunk_to_write, Some(chunk.slice(to_write as usize..)))
    };

    writer.send_timeout(Ok(Some(chunk)), io_timeout).await?;
    Ok(rest)
}

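/// Validates that the request's `content-length` header equals the sum of the
/// per-record content lengths declared in the batched headers.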
fn check_and_get_content_length(
    headers: &HeaderMap,
    timed_headers: &[(u64, RecordHeader)],
) -> Result<u64, ReductError> {
    let total_content_length: u64 = timed_headers
        .iter()
        .map(|(_, header)| header.content_length)
        .sum();

    let declared_content_length = headers
        .get("content-length")
        .ok_or(unprocessable_entity!("content-length header is required"))?
        .to_str()
        .map_err(|_| unprocessable_entity!("Invalid content-length header"))?
        .parse::<u64>()
        .map_err(|_| unprocessable_entity!("Invalid content-length header"))?;

    if total_content_length != declared_content_length {
        return Err(unprocessable_entity!(
            "content-length header does not match the sum of the content-lengths in the headers"
        ));
    }

    Ok(total_content_length)
}

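/// Begins writing a single record. If the writer cannot be opened (e.g. a
/// record with that timestamp already exists), the error is stored in
/// `error_map` and a `RecordDrainer` is returned so the record's bytes are
/// still consumed from the stream.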
async fn start_writing(
    entry_name: &str,
    bucket: Arc<Bucket>,
    time: u64,
    record_header: &RecordHeader,
    error_map: &mut BTreeMap<u64, ReductError>,
) -> Box<dyn WriteRecord + Sync + Send> {
    match bucket
        .begin_write(
            entry_name,
            time,
            record_header.content_length,
            record_header.content_type.clone(),
            record_header.labels.clone(),
        )
        .await
    {
        Ok(writer) => writer,
        Err(err) => {
            error_map.insert(time, err);
            // drain the failed record's payload instead of storing it
            Box::new(RecordDrainer::new())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::tests::{headers, keeper, path_to_entry_1};

    use axum_extra::headers::HeaderValue;
    use reduct_base::error::ErrorCode;
    use reduct_base::io::ReadRecord;
    use rstest::{fixture, rstest};

    #[rstest]
    #[tokio::test]
    async fn test_write_record_bad_timestamp(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        headers.insert("content-length", "10".parse().unwrap());
        headers.insert("x-reduct-time-yyy", "10".parse().unwrap());

        let err = write_batched_records(
            State(keeper.await),
            headers,
            path_to_entry_1,
            body_stream.await,
        )
        .await
        .err()
        .unwrap();

        assert_eq!(
            err,
            HttpError::new(
                ErrorCode::UnprocessableEntity,
                "Invalid header 'x-reduct-time-yyy': must be an unix timestamp in microseconds",
            )
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_invalid_header(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        headers.insert("content-length", "10".parse().unwrap());
        headers.insert("x-reduct-time-1", "".parse().unwrap());

        let err = write_batched_records(
            State(keeper.await),
            headers,
            path_to_entry_1,
            body_stream.await,
        )
        .await
        .err()
        .unwrap();

        assert_eq!(
            err,
            HttpError::new(ErrorCode::UnprocessableEntity, "Invalid batched header")
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        // three records of 10, 20 and 18 bytes; the body is their concatenation
        headers.insert("content-length", "48".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,text/plain,a=b".parse().unwrap());
        headers.insert(
            "x-reduct-time-2",
            "20,text/plain,c=\"d,f\"".parse().unwrap(),
        );
        headers.insert("x-reduct-time-10", "18,text/plain".parse().unwrap());

        let stream = body_stream.await;

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .unwrap()
            .upgrade_and_unwrap();

        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(&reader.meta().labels()["a"], "b");
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 10);
            assert_eq!(reader.read_chunk().unwrap(), Ok(Bytes::from("1234567890")));
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(&reader.meta().labels()["c"], "d,f");
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 20);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("abcdef1234567890abcd"))
            );
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(10)
                .await
                .unwrap();
            assert!(reader.meta().labels().is_empty());
            assert_eq!(reader.meta().content_type(), "text/plain");
            assert_eq!(reader.meta().content_length(), 18);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("ef1234567890abcdef"))
            );
        }

        let info = components
            .replication_repo
            .read()
            .await
            .get_info("api-test")
            .unwrap();
        assert_eq!(info.info.pending_records, 3);
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_with_empty_bodies(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        headers.insert("content-length", "0".parse().unwrap());
        headers.insert("x-reduct-time-1", "0,,a=b".parse().unwrap());
        headers.insert("x-reduct-time-2", "0,,a=d".parse().unwrap());

        let stream = Body::empty();

        write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
            .await
            .unwrap();

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .unwrap()
            .upgrade_and_unwrap();

        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(1)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
        {
            let mut reader = bucket
                .get_entry("entry-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_read(2)
                .await
                .unwrap();
            assert_eq!(reader.meta().content_length(), 0);
            assert_eq!(reader.read_chunk(), None);
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_write_batched_records_error(
        #[future] keeper: Arc<StateKeeper>,
        mut headers: HeaderMap,
        path_to_entry_1: Path<HashMap<String, String>>,
        #[future] body_stream: Body,
    ) {
        let keeper = keeper.await;
        let components = keeper.get_anonymous().await.unwrap();
        // pre-create a record at timestamp 2 to provoke a conflict for the batch
        {
            let mut writer = components
                .storage
                .get_bucket("bucket-1")
                .unwrap()
                .upgrade_and_unwrap()
                .begin_write("entry-1", 2, 20, "text/plain".to_string(), HashMap::new())
                .await
                .unwrap();
            writer
                .send(Ok(Some(Bytes::from(vec![0; 20]))))
                .await
                .unwrap();
            writer.send(Ok(None)).await.unwrap();
        }

        headers.insert("content-length", "48".parse().unwrap());
        headers.insert("x-reduct-time-1", "10,".parse().unwrap());
        headers.insert("x-reduct-time-2", "20,".parse().unwrap());
        headers.insert("x-reduct-time-3", "18,".parse().unwrap());

        let stream = body_stream.await;

        let resp =
            write_batched_records(State(Arc::clone(&keeper)), headers, path_to_entry_1, stream)
                .await
                .unwrap()
                .into_response();

        let headers = resp.headers();
        assert_eq!(headers.len(), 1);
        assert_eq!(
            headers.get("x-reduct-error-2").unwrap(),
            &HeaderValue::from_static("409,A record with timestamp 2 already exists")
        );

        let bucket = components
            .storage
            .get_bucket("bucket-1")
            .unwrap()
            .upgrade_and_unwrap();
        {
            let mut reader = bucket.begin_read("entry-1", 1).await.unwrap();
            assert_eq!(reader.meta().content_length(), 10);
            assert_eq!(reader.read_chunk().unwrap(), Ok(Bytes::from("1234567890")));
        }
        {
            let mut reader = bucket.begin_read("entry-1", 3).await.unwrap();
            assert_eq!(reader.meta().content_length(), 18);
            assert_eq!(
                reader.read_chunk().unwrap(),
                Ok(Bytes::from("ef1234567890abcdef"))
            );
        }
    }

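    /// 48 bytes consumed by the batched tests as 10 + 20 + 18.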
    #[fixture]
    async fn body_stream() -> Body {
        Body::from("1234567890abcdef1234567890abcdef1234567890abcdef")
    }
}