1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::io;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures_util::{Stream, StreamExt as _};
9use multer::Field;
10use objectstore_types::metadata::{Compression, Metadata};
11use percent_encoding::NON_ALPHANUMERIC;
12use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use reqwest::multipart::Part;
14
15use crate::error::Error;
16use crate::put::PutBody;
17use crate::{
18 DeleteBuilder, DeleteResponse, GetBuilder, GetResponse, ObjectKey, PutBuilder, PutResponse,
19 Session, get, put,
20};
21
/// Header identifying the object key of one operation within a batch part.
const HEADER_BATCH_OPERATION_KEY: &str = "x-sn-batch-operation-key";
/// Header identifying the operation kind (`get`, `insert`, or `delete`).
const HEADER_BATCH_OPERATION_KIND: &str = "x-sn-batch-operation-kind";
/// Header on response parts correlating them with the submitted operation's index.
const HEADER_BATCH_OPERATION_INDEX: &str = "x-sn-batch-operation-index";
/// Header on response parts carrying the per-operation HTTP status.
const HEADER_BATCH_OPERATION_STATUS: &str = "x-sn-batch-operation-status";

/// Maximum number of operations bundled into a single batch request.
const MAX_BATCH_OPS: usize = 1000;

/// Inserts with a known body size above this threshold are executed as
/// individual requests instead of being batched (see `classify`).
const MAX_BATCH_PART_SIZE: u32 = 1024 * 1024;
/// Maximum number of individually-executed operations in flight at once.
const MAX_INDIVIDUAL_CONCURRENCY: usize = 5;

/// Maximum number of batch requests in flight at once.
const MAX_BATCH_CONCURRENCY: usize = 3;

/// Upper bound on the combined body size of all parts in one batch request.
const MAX_BATCH_BODY_SIZE: u64 = 100 * 1024 * 1024;

/// Accumulates get/insert/delete operations to be executed together.
#[derive(Debug)]
pub struct ManyBuilder {
    /// Session used to issue all resulting requests.
    session: Session,
    /// Operations queued so far, in submission order.
    operations: Vec<BatchOperation>,
}
53
54impl Session {
55 pub fn many(&self) -> ManyBuilder {
60 ManyBuilder {
61 session: self.clone(),
62 operations: vec![],
63 }
64 }
65}
66
/// A single queued operation within a [`ManyBuilder`] batch.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum BatchOperation {
    /// Fetch an object; mirrors the request-relevant fields of [`GetBuilder`].
    Get {
        key: ObjectKey,
        decompress: bool,
        accept_encoding: Vec<Compression>,
    },
    /// Store an object; `key` is `None` when the server assigns one.
    Insert {
        key: Option<ObjectKey>,
        metadata: Metadata,
        body: PutBody,
    },
    /// Delete an object by key.
    Delete {
        key: ObjectKey,
    },
}
84
85impl From<GetBuilder> for BatchOperation {
86 fn from(value: GetBuilder) -> Self {
87 let GetBuilder {
88 key,
89 decompress,
90 accept_encoding,
91 session: _session,
92 } = value;
93 BatchOperation::Get {
94 key,
95 decompress,
96 accept_encoding,
97 }
98 }
99}
100
101impl From<PutBuilder> for BatchOperation {
102 fn from(value: PutBuilder) -> Self {
103 let PutBuilder {
104 key,
105 metadata,
106 body,
107 session: _session,
108 } = value;
109 BatchOperation::Insert {
110 key,
111 metadata,
112 body,
113 }
114 }
115}
116
117impl From<DeleteBuilder> for BatchOperation {
118 fn from(value: DeleteBuilder) -> Self {
119 let DeleteBuilder {
120 key,
121 session: _session,
122 } = value;
123 BatchOperation::Delete { key }
124 }
125}
126
127impl BatchOperation {
128 async fn into_part(self) -> crate::Result<Part> {
129 match self {
130 BatchOperation::Get { key, .. } => {
131 let headers = operation_headers("get", Some(&key));
132 Ok(Part::text("").headers(headers))
133 }
134 BatchOperation::Insert {
135 key,
136 metadata,
137 body,
138 } => {
139 let mut headers = operation_headers("insert", key.as_deref());
140 headers.extend(metadata.to_headers("")?);
141
142 let body = put::maybe_compress(body, metadata.compression).await?;
143 Ok(Part::stream(body).headers(headers))
144 }
145 BatchOperation::Delete { key } => {
146 let headers = operation_headers("delete", Some(&key));
147 Ok(Part::text("").headers(headers))
148 }
149 }
150 }
151}
152
153fn operation_headers(operation: &str, key: Option<&str>) -> HeaderMap {
154 let mut headers = HeaderMap::new();
155 headers.insert(
156 HeaderName::from_static(HEADER_BATCH_OPERATION_KIND),
157 HeaderValue::from_str(operation).expect("operation kind is always a valid header value"),
158 );
159 if let Some(key) = key {
160 let encoded =
161 percent_encoding::percent_encode(key.as_bytes(), NON_ALPHANUMERIC).to_string();
162 headers.insert(
163 HeaderName::from_static(HEADER_BATCH_OPERATION_KEY),
164 HeaderValue::try_from(encoded)
165 .expect("percent-encoded string is always a valid header value"),
166 );
167 }
168 headers
169}
170
/// Outcome of a single operation submitted through [`ManyBuilder`].
#[derive(Debug)]
pub enum OperationResult {
    /// Result of a get; `Ok(None)` means the object was not found (404).
    Get(ObjectKey, Result<Option<GetResponse>, Error>),
    /// Result of an insert.
    Put(ObjectKey, Result<PutResponse, Error>),
    /// Result of a delete.
    Delete(ObjectKey, Result<DeleteResponse, Error>),
    /// A failure that could not be attributed to a specific operation,
    /// e.g. a malformed multipart response part.
    Error(Error),
}
189
/// Request-side state retained per operation so responses can be interpreted
/// (e.g. decompression settings for gets) and errors attributed to a key.
enum OperationContext {
    Get {
        key: ObjectKey,
        decompress: bool,
        accept_encoding: Vec<Compression>,
    },
    Insert {
        // `None` when the server assigns the key.
        key: Option<ObjectKey>,
    },
    Delete {
        key: ObjectKey,
    },
}
204
205impl From<&BatchOperation> for OperationContext {
206 fn from(op: &BatchOperation) -> Self {
207 match op {
208 BatchOperation::Get {
209 key,
210 decompress,
211 accept_encoding,
212 } => OperationContext::Get {
213 key: key.clone(),
214 decompress: *decompress,
215 accept_encoding: accept_encoding.clone(),
216 },
217 BatchOperation::Insert { key, .. } => OperationContext::Insert { key: key.clone() },
218 BatchOperation::Delete { key } => OperationContext::Delete { key: key.clone() },
219 }
220 }
221}
222
223impl OperationContext {
224 fn key(&self) -> Option<&str> {
225 match self {
226 OperationContext::Get { key, .. } | OperationContext::Delete { key } => Some(key),
227 OperationContext::Insert { key } => key.as_deref(),
228 }
229 }
230}
231
/// Routing decision for a single operation (see `classify`).
enum Classified {
    /// Can go into a multipart batch; the `u64` is the known body size in bytes.
    Batchable(BatchOperation, u64),
    /// Must be executed as an individual request (large or unknown-size body).
    Individual(BatchOperation),
    /// Failed before execution (e.g. the body size could not be determined).
    Failed(OperationResult),
}
241
242fn error_result(ctx: OperationContext, error: Error) -> OperationResult {
244 let key = ctx.key().unwrap_or("<unknown>").to_owned();
245 match ctx {
246 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
247 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
248 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
249 }
250}
251
impl OperationResult {
    /// Parses one multipart response field into a result.
    ///
    /// Returns the operation index the field referenced (when parseable) so
    /// the caller can track which operations were answered; unparseable
    /// fields become an unattributed [`OperationResult::Error`].
    async fn from_field(
        field: Field<'_>,
        context_map: &HashMap<usize, OperationContext>,
    ) -> (Option<usize>, Self) {
        match Self::try_from_field(field, context_map).await {
            Ok((index, result)) => (Some(index), result),
            Err(e) => (None, OperationResult::Error(e)),
        }
    }

    /// Parses one multipart response field, expecting the batch control
    /// headers (index, status, optionally key) on it. Remaining headers are
    /// interpreted as object metadata for successful gets.
    async fn try_from_field(
        field: Field<'_>,
        context_map: &HashMap<usize, OperationContext>,
    ) -> Result<(usize, Self), Error> {
        // Clone so control headers can be removed, leaving only the
        // metadata headers behind for `Metadata::from_headers` below.
        let mut headers = field.headers().clone();

        let index: usize = headers
            .remove(HEADER_BATCH_OPERATION_INDEX)
            .and_then(|v| v.to_str().ok().and_then(|s| s.parse().ok()))
            .ok_or_else(|| {
                Error::MalformedResponse(format!(
                    "missing or invalid {HEADER_BATCH_OPERATION_INDEX} header"
                ))
            })?;

        // Accepts both a bare code ("200") and a code with a reason phrase
        // ("200 OK") — only the part before the first space is parsed.
        let status: u16 = headers
            .remove(HEADER_BATCH_OPERATION_STATUS)
            .and_then(|v| {
                v.to_str().ok().and_then(|s| {
                    s.split_once(' ')
                        .map(|(code, _)| code)
                        .unwrap_or(s)
                        .parse()
                        .ok()
                })
            })
            .ok_or_else(|| {
                Error::MalformedResponse(format!(
                    "missing or invalid {HEADER_BATCH_OPERATION_STATUS} header"
                ))
            })?;

        // The index must refer to an operation we actually submitted.
        let ctx = context_map.get(&index).ok_or_else(|| {
            Error::MalformedResponse(format!(
                "response references unknown operation index {index}"
            ))
        })?;

        // Prefer the (percent-decoded) key from the response; fall back to
        // the key recorded at request time, if any.
        let key = headers
            .remove(HEADER_BATCH_OPERATION_KEY)
            .and_then(|v| {
                v.to_str()
                    .ok()
                    .and_then(|encoded| {
                        percent_encoding::percent_decode_str(encoded)
                            .decode_utf8()
                            .ok()
                    })
                    .map(|s| s.into_owned())
            })
            .or_else(|| ctx.key().map(str::to_owned));

        let body = field.bytes().await?;

        // A 404 on a get is not an error — it maps to `Ok(None)` below.
        let is_error =
            status >= 400 && !(matches!(ctx, OperationContext::Get { .. }) && status == 404);

        let key = match key {
            Some(key) => key,
            // For error reporting a placeholder key is acceptable...
            None if is_error => "<unknown>".to_owned(),
            // ...but a successful result without a key is malformed.
            None => {
                return Err(Error::MalformedResponse(format!(
                    "missing or invalid {HEADER_BATCH_OPERATION_KEY} header"
                )));
            }
        };
        if is_error {
            // The part body carries the server's error message.
            let message = String::from_utf8_lossy(&body).into_owned();
            let error = Error::OperationFailure { status, message };

            return Ok((
                index,
                match ctx {
                    OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
                    OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
                    OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
                },
            ));
        }

        let result = match ctx {
            OperationContext::Get {
                decompress,
                accept_encoding,
                ..
            } => {
                if status == 404 {
                    OperationResult::Get(key, Ok(None))
                } else {
                    // Remaining headers are the object's metadata.
                    let mut metadata = Metadata::from_headers(&headers, "")?;

                    // Wrap the already-buffered body in a one-shot stream so
                    // the regular decompression path can be reused.
                    let stream =
                        futures_util::stream::once(async move { Ok::<_, io::Error>(body) }).boxed();
                    let stream =
                        get::maybe_decompress(stream, &mut metadata, *decompress, accept_encoding);

                    OperationResult::Get(key, Ok(Some(GetResponse { metadata, stream })))
                }
            }
            OperationContext::Insert { .. } => {
                OperationResult::Put(key.clone(), Ok(PutResponse { key }))
            }
            OperationContext::Delete { .. } => OperationResult::Delete(key, Ok(())),
        };
        Ok((index, result))
    }
}
377
/// Stream of [`OperationResult`]s produced by [`ManyBuilder::send`].
///
/// NOTE: results are generally not yielded in submission order (operations
/// execute concurrently; see `ManyBuilder::send`).
pub struct OperationResults(Pin<Box<dyn Stream<Item = OperationResult> + Send>>);

impl fmt::Debug for OperationResults {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The boxed stream is opaque; print a placeholder instead.
        f.write_str("OperationResults([Stream])")
    }
}

impl Stream for OperationResults {
    type Item = OperationResult;

    // Delegates directly to the boxed inner stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.0.as_mut().poll_next(cx)
    }
}
394
395impl OperationResults {
396 pub async fn error_for_failures(
401 mut self,
402 ) -> crate::Result<(), impl Iterator<Item = crate::Error>> {
403 let mut errs = Vec::new();
404 while let Some(res) = self.next().await {
405 match res {
406 OperationResult::Get(_, get) => {
407 if let Err(e) = get {
408 errs.push(e);
409 }
410 }
411 OperationResult::Put(_, put) => {
412 if let Err(e) = put {
413 errs.push(e);
414 }
415 }
416 OperationResult::Delete(_, delete) => {
417 if let Err(e) = delete {
418 errs.push(e);
419 }
420 }
421 OperationResult::Error(error) => errs.push(error),
422 }
423 }
424 if errs.is_empty() {
425 return Ok(());
426 }
427 Err(errs.into_iter())
428 }
429}
430
/// Sends `operations` as one multipart batch request and parses the
/// multipart response into one [`OperationResult`] per operation.
///
/// Response parts are correlated with operations via the index header on
/// each part; operations the server did not answer are reported as
/// malformed-response errors.
async fn send_batch(
    session: &Session,
    operations: Vec<BatchOperation>,
) -> crate::Result<Vec<OperationResult>> {
    // Keep request-side context (keys, decompression settings) around so
    // responses can be interpreted and errors attributed.
    let mut context_map: HashMap<usize, OperationContext> = operations
        .iter()
        .enumerate()
        .map(|(idx, op)| (idx, OperationContext::from(op)))
        .collect();
    let num_operations = operations.len();

    let mut form = reqwest::multipart::Form::new();
    for op in operations.into_iter() {
        let part = op.into_part().await?;
        // All parts share the name "part"; operations are distinguished by
        // the per-part headers (and, presumably, submission order — response
        // parts carry an explicit index header).
        form = form.part("part", part);
    }

    let request = session.batch_request()?.multipart(form);
    let response = request.send().await?.error_for_status()?;

    // The response body is itself multipart; the boundary comes from its
    // Content-Type header.
    let boundary = response
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .ok_or_else(|| Error::MalformedResponse("missing Content-Type header".to_owned()))
        .map(multer::parse_boundary)??;

    let byte_stream = response.bytes_stream().map(|r| r.map_err(io::Error::other));
    let mut multipart = multer::Multipart::new(byte_stream, boundary);

    let mut results = Vec::new();
    let mut seen_indices = HashSet::new();
    while let Some(field) = multipart.next_field().await? {
        let (index, result) = OperationResult::from_field(field, &context_map).await;
        if let Some(idx) = index {
            seen_indices.insert(idx);
        }
        results.push(result);
    }

    // Synthesize an error result for every operation the server left
    // unanswered, so callers always receive one result per operation.
    for idx in 0..num_operations {
        if !seen_indices.contains(&idx) {
            let error = Error::MalformedResponse(format!(
                "server did not return a response for operation at index {idx}"
            ));
            let result = match context_map.remove(&idx) {
                Some(ctx) => error_result(ctx, error),
                None => OperationResult::Error(error),
            };
            results.push(result);
        }
    }

    Ok(results)
}
486
487fn classify_fail(key: Option<ObjectKey>, error: Error) -> Classified {
488 Classified::Failed(OperationResult::Put(
489 key.unwrap_or_else(|| "<unknown>".to_owned()),
490 Err(error),
491 ))
492}
493
494async fn classify(op: BatchOperation) -> Classified {
499 match op {
500 BatchOperation::Insert {
501 key,
502 metadata,
503 body,
504 } => {
505 let size = match &body {
506 PutBody::Buffer(bytes) => Some(bytes.len() as u64),
507 PutBody::File(file) => match file.metadata().await {
508 Ok(meta) => Some(meta.len()),
509 Err(err) => return classify_fail(key, err.into()),
510 },
511 PutBody::Path(path) => match tokio::fs::metadata(path).await {
512 Ok(meta) => Some(meta.len()),
513 Err(err) => return classify_fail(key, err.into()),
514 },
515 PutBody::Stream(_) => None,
517 };
518
519 let op = BatchOperation::Insert {
520 key,
521 metadata,
522 body,
523 };
524
525 match size {
526 Some(s) if s <= MAX_BATCH_PART_SIZE as u64 => Classified::Batchable(op, s),
527 _ => Classified::Individual(op),
528 }
529 }
530 other => Classified::Batchable(other, 0),
531 }
532}
533
534async fn partition(
538 operations: Vec<BatchOperation>,
539) -> (
540 Vec<(BatchOperation, u64)>,
541 Vec<BatchOperation>,
542 Vec<OperationResult>,
543) {
544 let classified = futures_util::future::join_all(operations.into_iter().map(classify)).await;
545 let mut batchable = Vec::new();
546 let mut individual = Vec::new();
547 let mut failed = Vec::new();
548 for item in classified {
549 match item {
550 Classified::Batchable(op, size) => batchable.push((op, size)),
551 Classified::Individual(op) => individual.push(op),
552 Classified::Failed(result) => failed.push(result),
553 }
554 }
555 (batchable, individual, failed)
556}
557
558async fn execute_individual(op: BatchOperation, session: &Session) -> OperationResult {
560 match op {
561 BatchOperation::Get {
562 key,
563 decompress,
564 accept_encoding,
565 } => {
566 let get = GetBuilder {
567 session: session.clone(),
568 key: key.clone(),
569 decompress,
570 accept_encoding,
571 };
572 OperationResult::Get(key, get.send().await)
573 }
574 BatchOperation::Insert {
575 key,
576 metadata,
577 body,
578 } => {
579 let error_key = key.clone().unwrap_or_else(|| "<unknown>".to_owned());
580 let put = PutBuilder {
581 session: session.clone(),
582 metadata,
583 key,
584 body,
585 };
586 match put.send().await {
587 Ok(response) => OperationResult::Put(response.key.clone(), Ok(response)),
588 Err(err) => OperationResult::Put(error_key, Err(err)),
589 }
590 }
591 BatchOperation::Delete { key } => {
592 let delete = DeleteBuilder {
593 session: session.clone(),
594 key: key.clone(),
595 };
596 OperationResult::Delete(key, delete.send().await)
597 }
598 }
599}
600
601async fn execute_batch(operations: Vec<BatchOperation>, session: &Session) -> Vec<OperationResult> {
605 let contexts: Vec<_> = operations.iter().map(OperationContext::from).collect();
606 match send_batch(session, operations).await {
607 Ok(results) => results,
608 Err(e) => {
609 let shared = Arc::new(e);
610 contexts
611 .into_iter()
612 .map(|ctx| error_result(ctx, Error::Batch(shared.clone())))
613 .collect()
614 }
615 }
616}
617
618fn iter_batches(ops: Vec<(BatchOperation, u64)>) -> impl Iterator<Item = Vec<BatchOperation>> {
623 let mut remaining = ops.into_iter().peekable();
624
625 std::iter::from_fn(move || {
626 remaining.peek()?;
627 let mut batch_size = 0;
628 let mut batch = Vec::new();
629
630 while let Some((_, op_size)) = remaining.peek() {
631 if batch.len() >= MAX_BATCH_OPS
632 || (!batch.is_empty() && batch_size + op_size > MAX_BATCH_BODY_SIZE)
633 {
634 break;
635 }
636
637 let (op, op_size) = remaining.next().expect("peeked above");
638 batch_size += op_size;
639 batch.push(op);
640 }
641
642 Some(batch)
643 })
644}
645
impl ManyBuilder {
    /// Executes all queued operations and returns a stream of their results.
    ///
    /// Small-bodied operations are grouped into multipart batch requests;
    /// large or unknown-size inserts run as individual requests. Requests
    /// execute concurrently, so results are not yielded in submission order.
    pub async fn send(self) -> OperationResults {
        let session = self.session;

        // Split operations into batchable / individual / pre-failed.
        let (batchable, individual, failed) = partition(self.operations).await;

        // Individual operations, up to MAX_INDIVIDUAL_CONCURRENCY in flight;
        // `buffer_unordered` yields results as they complete.
        let individual_results = futures_util::stream::iter(individual)
            .map({
                let session = session.clone();
                move |op| {
                    let session = session.clone();
                    async move { execute_individual(op, &session).await }
                }
            })
            .buffer_unordered(MAX_INDIVIDUAL_CONCURRENCY);

        // Batch requests, up to MAX_BATCH_CONCURRENCY in flight; each batch
        // yields a Vec of results which is flattened into the stream.
        let batch_results = futures_util::stream::iter(iter_batches(batchable))
            .map(move |chunk| {
                let session = session.clone();
                async move { execute_batch(chunk, &session).await }
            })
            .buffer_unordered(MAX_BATCH_CONCURRENCY)
            .flat_map(futures_util::stream::iter);

        // `chain` is sequential: pre-failed results first, then all
        // individual results, then all batch results.
        let results = futures_util::stream::iter(failed)
            .chain(individual_results)
            .chain(batch_results);

        OperationResults(results.boxed())
    }

    /// Queues another operation built from a get/put/delete builder.
    #[allow(private_bounds)]
    pub fn push<B: Into<BatchOperation>>(mut self, builder: B) -> Self {
        self.operations.push(builder.into());
        self
    }
}
698
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dummy batchable operation with the given claimed body size.
    fn op(size: u64) -> (BatchOperation, u64) {
        (
            BatchOperation::Delete {
                key: "k".to_owned(),
            },
            size,
        )
    }

    /// Returns the number of operations in each batch.
    fn batch_sizes(batches: &[Vec<BatchOperation>]) -> Vec<usize> {
        batches.iter().map(Vec::len).collect()
    }

    /// Collects the output of `iter_batches` for assertions.
    fn batches(ops: Vec<(BatchOperation, u64)>) -> Vec<Vec<BatchOperation>> {
        iter_batches(ops).collect()
    }

    #[test]
    fn iter_batches_empty() {
        assert!(batches(vec![]).is_empty());
    }

    #[test]
    fn iter_batches_single_batch_count_limit() {
        // Exactly MAX_BATCH_OPS operations fit in one batch.
        let ops: Vec<_> = (0..1000).map(|_| op(1)).collect();
        assert_eq!(batch_sizes(&batches(ops)), vec![1000]);
    }

    #[test]
    fn iter_batches_splits_on_count_limit() {
        // One over MAX_BATCH_OPS spills into a second batch.
        let ops: Vec<_> = (0..1001).map(|_| op(1)).collect();
        assert_eq!(batch_sizes(&batches(ops)), vec![1000, 1]);
    }

    #[test]
    fn iter_batches_exactly_at_size_limit() {
        // 100 x 1 MiB == MAX_BATCH_BODY_SIZE still fits in one batch.
        let ops: Vec<_> = (0..100).map(|_| op(1024 * 1024)).collect();
        assert_eq!(batch_sizes(&batches(ops)), vec![100]);
    }

    #[test]
    fn iter_batches_splits_on_size_limit() {
        // One more 1 MiB part exceeds the body-size limit.
        let ops: Vec<_> = (0..101).map(|_| op(1024 * 1024)).collect();
        assert_eq!(batch_sizes(&batches(ops)), vec![100, 1]);
    }

    #[test]
    fn iter_batches_size_limit_hits_before_count_limit() {
        // With 600 KiB parts the size limit is reached well before the
        // operation-count limit; every full batch should be equally sized.
        let op_size = 600 * 1024;
        let ops: Vec<_> = (0..200).map(|_| op(op_size)).collect();
        let result = batches(ops);
        let per_batch = (MAX_BATCH_BODY_SIZE / op_size) as usize;
        assert!(result.len() > 1, "expected multiple batches");
        for batch in &result[..result.len() - 1] {
            assert_eq!(batch.len(), per_batch);
        }
    }
}