1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::io;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use async_stream::stream;
8use futures_util::{Stream, StreamExt as _};
9use multer::Field;
10use objectstore_types::metadata::{Compression, Metadata};
11use percent_encoding::NON_ALPHANUMERIC;
12use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use reqwest::multipart::Part;
14
15use crate::error::Error;
16use crate::put::PutBody;
17use crate::{
18 DeleteBuilder, DeleteResponse, GetBuilder, GetResponse, ObjectKey, PutBuilder, PutResponse,
19 Session, get, put,
20};
21
/// Part header carrying the percent-encoded object key of an operation.
///
/// It is set on request parts when the key is known and read back from response parts.
const HEADER_BATCH_OPERATION_KEY: &str = "x-sn-batch-operation-key";
/// Request-part header naming the kind of operation (`get`, `insert` or `delete`).
const HEADER_BATCH_OPERATION_KIND: &str = "x-sn-batch-operation-kind";
/// Response-part header holding the index of the operation the part answers.
const HEADER_BATCH_OPERATION_INDEX: &str = "x-sn-batch-operation-index";
/// Response-part header holding the status code of the answered operation.
const HEADER_BATCH_OPERATION_STATUS: &str = "x-sn-batch-operation-status";

/// Upper bound on the number of operations placed into a single batch request.
const MAX_BATCH_SIZE: usize = 1_000;
30
31#[derive(Debug)]
36pub struct ManyBuilder {
37 session: Session,
38 operations: Vec<BatchOperation>,
39}
40
41impl Session {
42 pub fn many(&self) -> ManyBuilder {
47 ManyBuilder {
48 session: self.clone(),
49 operations: vec![],
50 }
51 }
52}
53
54#[derive(Debug)]
55enum BatchOperation {
56 Get {
57 key: ObjectKey,
58 decompress: bool,
59 accept_encoding: Vec<Compression>,
60 },
61 Insert {
62 key: Option<ObjectKey>,
63 metadata: Metadata,
64 body: PutBody,
65 },
66 Delete {
67 key: ObjectKey,
68 },
69}
70
71impl From<GetBuilder> for BatchOperation {
72 fn from(value: GetBuilder) -> Self {
73 let GetBuilder {
74 key,
75 decompress,
76 accept_encoding,
77 session: _session,
78 } = value;
79 BatchOperation::Get {
80 key,
81 decompress,
82 accept_encoding,
83 }
84 }
85}
86
87impl From<PutBuilder> for BatchOperation {
88 fn from(value: PutBuilder) -> Self {
89 let PutBuilder {
90 key,
91 metadata,
92 body,
93 session: _session,
94 } = value;
95 BatchOperation::Insert {
96 key,
97 metadata,
98 body,
99 }
100 }
101}
102
103impl From<DeleteBuilder> for BatchOperation {
104 fn from(value: DeleteBuilder) -> Self {
105 let DeleteBuilder {
106 key,
107 session: _session,
108 } = value;
109 BatchOperation::Delete { key }
110 }
111}
112
113impl BatchOperation {
114 async fn into_part(self) -> crate::Result<Part> {
115 match self {
116 BatchOperation::Get { key, .. } => {
117 let headers = operation_headers("get", Some(&key));
118 Ok(Part::text("").headers(headers))
119 }
120 BatchOperation::Insert {
121 key,
122 metadata,
123 body,
124 } => {
125 let mut headers = operation_headers("insert", key.as_deref());
126 headers.extend(metadata.to_headers("")?);
127
128 let body = put::maybe_compress(body, metadata.compression);
129 Ok(Part::stream(body).headers(headers))
130 }
131 BatchOperation::Delete { key } => {
132 let headers = operation_headers("delete", Some(&key));
133 Ok(Part::text("").headers(headers))
134 }
135 }
136 }
137}
138
139fn operation_headers(operation: &str, key: Option<&str>) -> HeaderMap {
140 let mut headers = HeaderMap::new();
141 headers.insert(
142 HeaderName::from_static(HEADER_BATCH_OPERATION_KIND),
143 HeaderValue::from_str(operation).expect("operation kind is always a valid header value"),
144 );
145 if let Some(key) = key {
146 let encoded =
147 percent_encoding::percent_encode(key.as_bytes(), NON_ALPHANUMERIC).to_string();
148 headers.insert(
149 HeaderName::from_static(HEADER_BATCH_OPERATION_KEY),
150 HeaderValue::try_from(encoded)
151 .expect("percent-encoded string is always a valid header value"),
152 );
153 }
154 headers
155}
156
157#[derive(Debug)]
159pub enum OperationResult {
160 Get(ObjectKey, Result<Option<GetResponse>, Error>),
164 Put(ObjectKey, Result<PutResponse, Error>),
166 Delete(ObjectKey, Result<DeleteResponse, Error>),
168 Error(Error),
174}
175
176enum OperationContext {
178 Get {
179 key: ObjectKey,
180 decompress: bool,
181 accept_encoding: Vec<Compression>,
182 },
183 Insert {
184 key: Option<ObjectKey>,
185 },
186 Delete {
187 key: ObjectKey,
188 },
189}
190
191impl From<&BatchOperation> for OperationContext {
192 fn from(op: &BatchOperation) -> Self {
193 match op {
194 BatchOperation::Get {
195 key,
196 decompress,
197 accept_encoding,
198 } => OperationContext::Get {
199 key: key.clone(),
200 decompress: *decompress,
201 accept_encoding: accept_encoding.clone(),
202 },
203 BatchOperation::Insert { key, .. } => OperationContext::Insert { key: key.clone() },
204 BatchOperation::Delete { key } => OperationContext::Delete { key: key.clone() },
205 }
206 }
207}
208
209impl OperationContext {
210 fn key(&self) -> Option<&str> {
211 match self {
212 OperationContext::Get { key, .. } | OperationContext::Delete { key } => Some(key),
213 OperationContext::Insert { key } => key.as_deref(),
214 }
215 }
216}
217
218impl OperationResult {
219 async fn from_field(
220 field: Field<'_>,
221 context_map: &HashMap<usize, OperationContext>,
222 ) -> (Option<usize>, Self) {
223 match Self::try_from_field(field, context_map).await {
224 Ok((index, result)) => (Some(index), result),
225 Err(e) => (None, OperationResult::Error(e)),
226 }
227 }
228
229 async fn try_from_field(
230 field: Field<'_>,
231 context_map: &HashMap<usize, OperationContext>,
232 ) -> Result<(usize, Self), Error> {
233 let mut headers = field.headers().clone();
234
235 let index: usize = headers
236 .remove(HEADER_BATCH_OPERATION_INDEX)
237 .and_then(|v| v.to_str().ok().and_then(|s| s.parse().ok()))
238 .ok_or_else(|| {
239 Error::MalformedResponse(format!(
240 "missing or invalid {HEADER_BATCH_OPERATION_INDEX} header"
241 ))
242 })?;
243
244 let status: u16 = headers
245 .remove(HEADER_BATCH_OPERATION_STATUS)
246 .and_then(|v| {
247 v.to_str().ok().and_then(|s| {
248 s.split_once(' ')
251 .map(|(code, _)| code)
252 .unwrap_or(s)
253 .parse()
254 .ok()
255 })
256 })
257 .ok_or_else(|| {
258 Error::MalformedResponse(format!(
259 "missing or invalid {HEADER_BATCH_OPERATION_STATUS} header"
260 ))
261 })?;
262
263 let ctx = context_map.get(&index).ok_or_else(|| {
264 Error::MalformedResponse(format!(
265 "response references unknown operation index {index}"
266 ))
267 })?;
268
269 let key = headers
271 .remove(HEADER_BATCH_OPERATION_KEY)
272 .and_then(|v| {
273 v.to_str()
274 .ok()
275 .and_then(|encoded| {
276 percent_encoding::percent_decode_str(encoded)
277 .decode_utf8()
278 .ok()
279 })
280 .map(|s| s.into_owned())
281 })
282 .or_else(|| ctx.key().map(str::to_owned));
283
284 let body = field.bytes().await?;
285
286 let is_error =
287 status >= 400 && !(matches!(ctx, OperationContext::Get { .. }) && status == 404);
288
289 let key = match key {
294 Some(key) => key,
295 None if is_error => "<unknown>".to_owned(),
296 None => {
297 return Err(Error::MalformedResponse(format!(
298 "missing or invalid {HEADER_BATCH_OPERATION_KEY} header"
299 )));
300 }
301 };
302 if is_error {
303 let message = String::from_utf8_lossy(&body).into_owned();
304 let error = Error::OperationFailure { status, message };
305
306 return Ok((
307 index,
308 match ctx {
309 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
310 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
311 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
312 },
313 ));
314 }
315
316 let result = match ctx {
317 OperationContext::Get {
318 decompress,
319 accept_encoding,
320 ..
321 } => {
322 if status == 404 {
323 OperationResult::Get(key, Ok(None))
324 } else {
325 let mut metadata = Metadata::from_headers(&headers, "")?;
326
327 let stream =
328 futures_util::stream::once(async move { Ok::<_, io::Error>(body) }).boxed();
329 let stream =
330 get::maybe_decompress(stream, &mut metadata, *decompress, accept_encoding);
331
332 OperationResult::Get(key, Ok(Some(GetResponse { metadata, stream })))
333 }
334 }
335 OperationContext::Insert { .. } => {
336 OperationResult::Put(key.clone(), Ok(PutResponse { key }))
337 }
338 OperationContext::Delete { .. } => OperationResult::Delete(key, Ok(())),
339 };
340 Ok((index, result))
341 }
342}
343
344pub struct OperationResults(Pin<Box<dyn Stream<Item = OperationResult> + Send>>);
346
347impl fmt::Debug for OperationResults {
348 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349 f.write_str("OperationResults([Stream])")
350 }
351}
352
353impl Stream for OperationResults {
354 type Item = OperationResult;
355
356 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
357 self.0.as_mut().poll_next(cx)
358 }
359}
360
361impl OperationResults {
362 pub async fn error_for_failures(
367 mut self,
368 ) -> crate::Result<(), impl Iterator<Item = crate::Error>> {
369 let mut errs = Vec::new();
370 while let Some(res) = self.next().await {
371 match res {
372 OperationResult::Get(_, get) => {
373 if let Err(e) = get {
374 errs.push(e);
375 }
376 }
377 OperationResult::Put(_, put) => {
378 if let Err(e) = put {
379 errs.push(e);
380 }
381 }
382 OperationResult::Delete(_, delete) => {
383 if let Err(e) = delete {
384 errs.push(e);
385 }
386 }
387 OperationResult::Error(error) => errs.push(error),
388 }
389 }
390 if errs.is_empty() {
391 return Ok(());
392 }
393 Err(errs.into_iter())
394 }
395}
396
397async fn send_batch(
398 session: &Session,
399 operations: Vec<BatchOperation>,
400) -> crate::Result<Vec<OperationResult>> {
401 let context_map: HashMap<usize, OperationContext> = operations
402 .iter()
403 .enumerate()
404 .map(|(idx, op)| (idx, OperationContext::from(op)))
405 .collect();
406 let num_operations = operations.len();
407
408 let mut form = reqwest::multipart::Form::new();
409 for op in operations.into_iter() {
410 let part = op.into_part().await?;
411 form = form.part("part", part);
412 }
413
414 let request = session.batch_request()?.multipart(form);
415 let response = request.send().await?.error_for_status()?;
416
417 let boundary = response
418 .headers()
419 .get(CONTENT_TYPE)
420 .and_then(|v| v.to_str().ok())
421 .ok_or_else(|| Error::MalformedResponse("missing Content-Type header".to_owned()))
422 .map(multer::parse_boundary)??;
423
424 let byte_stream = response.bytes_stream().map(|r| r.map_err(io::Error::other));
425 let mut multipart = multer::Multipart::new(byte_stream, boundary);
426
427 let mut results = Vec::new();
428 let mut seen_indices = HashSet::new();
429 while let Some(field) = multipart.next_field().await? {
430 let (index, result) = OperationResult::from_field(field, &context_map).await;
431 if let Some(idx) = index {
432 seen_indices.insert(idx);
433 }
434 results.push(result);
435 }
436
437 for idx in 0..num_operations {
438 if !seen_indices.contains(&idx) {
439 let error = Error::MalformedResponse(format!(
440 "server did not return a response for operation at index {idx}"
441 ));
442 let result = match context_map.get(&idx) {
443 Some(ctx) => {
444 let key = ctx.key().unwrap_or("<unknown>").to_owned();
445 match ctx {
446 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
447 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
448 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
449 }
450 }
451 None => OperationResult::Error(error),
452 };
453 results.push(result);
454 }
455 }
456
457 Ok(results)
458}
459
460impl ManyBuilder {
461 pub fn send(self) -> OperationResults {
465 let session = self.session;
466 let mut operations = self.operations;
467
468 let inner = stream! {
469 while !operations.is_empty() {
470 let chunk_size = operations.len().min(MAX_BATCH_SIZE);
471 let chunk: Vec<_> = operations.drain(..chunk_size).collect();
472
473 let contexts: Vec<_> =
476 chunk.iter().map(OperationContext::from).collect();
477
478 match send_batch(&session, chunk).await {
479 Ok(results) => {
480 for result in results {
481 yield result;
482 }
483 }
484 Err(e) => {
485 let shared = std::sync::Arc::new(e);
486 for ctx in contexts {
487 let error = Error::Batch(shared.clone());
488 let key = ctx.key().unwrap_or("<unknown>").to_owned();
489 yield match ctx {
490 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
491 OperationContext::Insert { .. } => {
492 OperationResult::Put(key, Err(error))
493 }
494 OperationContext::Delete { .. } => {
495 OperationResult::Delete(key, Err(error))
496 }
497 };
498 }
499 }
500 }
501 }
502 };
503
504 OperationResults(Box::pin(inner))
505 }
506
507 #[allow(private_bounds)]
517 pub fn push<B: Into<BatchOperation>>(mut self, builder: B) -> Self {
518 self.operations.push(builder.into());
519 self
520 }
521}