1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::io;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use async_stream::stream;
8use futures_util::{Stream, StreamExt as _};
9use multer::Field;
10use objectstore_types::metadata::Metadata;
11use percent_encoding::NON_ALPHANUMERIC;
12use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use reqwest::multipart::Part;
14
15use crate::error::Error;
16use crate::put::PutBody;
17use crate::{
18 DeleteBuilder, DeleteResponse, GetBuilder, GetResponse, ObjectKey, PutBuilder, PutResponse,
19 Session, get, put,
20};
21
/// Upper bound on the number of operations placed into a single batch request.
/// Larger operation lists are split into consecutive requests of at most this size.
const MAX_BATCH_SIZE: usize = 1000;

/// Part header naming the operation kind (`"get"`, `"insert"` or `"delete"`).
const HEADER_BATCH_OPERATION_KIND: &str = "x-sn-batch-operation-kind";
/// Part header carrying the percent-encoded object key of an operation.
const HEADER_BATCH_OPERATION_KEY: &str = "x-sn-batch-operation-key";
/// Response part header giving the index of the operation the part answers.
const HEADER_BATCH_OPERATION_INDEX: &str = "x-sn-batch-operation-index";
/// Response part header giving the per-operation status; the numeric code is
/// the text before the first space, if any.
const HEADER_BATCH_OPERATION_STATUS: &str = "x-sn-batch-operation-status";
30
31#[derive(Debug)]
36pub struct ManyBuilder {
37 session: Session,
38 operations: Vec<BatchOperation>,
39}
40
41impl Session {
42 pub fn many(&self) -> ManyBuilder {
47 ManyBuilder {
48 session: self.clone(),
49 operations: vec![],
50 }
51 }
52}
53
54#[derive(Debug)]
55enum BatchOperation {
56 Get {
57 key: ObjectKey,
58 decompress: bool,
59 },
60 Insert {
61 key: Option<ObjectKey>,
62 metadata: Metadata,
63 body: PutBody,
64 },
65 Delete {
66 key: ObjectKey,
67 },
68}
69
70impl From<GetBuilder> for BatchOperation {
71 fn from(value: GetBuilder) -> Self {
72 let GetBuilder {
73 key,
74 decompress,
75 session: _session,
76 } = value;
77 BatchOperation::Get { key, decompress }
78 }
79}
80
81impl From<PutBuilder> for BatchOperation {
82 fn from(value: PutBuilder) -> Self {
83 let PutBuilder {
84 key,
85 metadata,
86 body,
87 session: _session,
88 } = value;
89 BatchOperation::Insert {
90 key,
91 metadata,
92 body,
93 }
94 }
95}
96
97impl From<DeleteBuilder> for BatchOperation {
98 fn from(value: DeleteBuilder) -> Self {
99 let DeleteBuilder {
100 key,
101 session: _session,
102 } = value;
103 BatchOperation::Delete { key }
104 }
105}
106
107impl BatchOperation {
108 async fn into_part(self) -> crate::Result<Part> {
109 match self {
110 BatchOperation::Get { key, .. } => {
111 let headers = operation_headers("get", Some(&key));
112 Ok(Part::text("").headers(headers))
113 }
114 BatchOperation::Insert {
115 key,
116 metadata,
117 body,
118 } => {
119 let mut headers = operation_headers("insert", key.as_deref());
120 headers.extend(metadata.to_headers("")?);
121
122 let body = put::maybe_compress(body, metadata.compression);
123 Ok(Part::stream(body).headers(headers))
124 }
125 BatchOperation::Delete { key } => {
126 let headers = operation_headers("delete", Some(&key));
127 Ok(Part::text("").headers(headers))
128 }
129 }
130 }
131}
132
133fn operation_headers(operation: &str, key: Option<&str>) -> HeaderMap {
134 let mut headers = HeaderMap::new();
135 headers.insert(
136 HeaderName::from_static(HEADER_BATCH_OPERATION_KIND),
137 HeaderValue::from_str(operation).expect("operation kind is always a valid header value"),
138 );
139 if let Some(key) = key {
140 let encoded =
141 percent_encoding::percent_encode(key.as_bytes(), NON_ALPHANUMERIC).to_string();
142 headers.insert(
143 HeaderName::from_static(HEADER_BATCH_OPERATION_KEY),
144 HeaderValue::try_from(encoded)
145 .expect("percent-encoded string is always a valid header value"),
146 );
147 }
148 headers
149}
150
151#[derive(Debug)]
153pub enum OperationResult {
154 Get(ObjectKey, Result<Option<GetResponse>, Error>),
158 Put(ObjectKey, Result<PutResponse, Error>),
160 Delete(ObjectKey, Result<DeleteResponse, Error>),
162 Error(Error),
168}
169
170enum OperationContext {
172 Get { key: ObjectKey, decompress: bool },
173 Insert { key: Option<ObjectKey> },
174 Delete { key: ObjectKey },
175}
176
177impl From<&BatchOperation> for OperationContext {
178 fn from(op: &BatchOperation) -> Self {
179 match op {
180 BatchOperation::Get { key, decompress } => OperationContext::Get {
181 key: key.clone(),
182 decompress: *decompress,
183 },
184 BatchOperation::Insert { key, .. } => OperationContext::Insert { key: key.clone() },
185 BatchOperation::Delete { key } => OperationContext::Delete { key: key.clone() },
186 }
187 }
188}
189
190impl OperationContext {
191 fn key(&self) -> Option<&str> {
192 match self {
193 OperationContext::Get { key, .. } | OperationContext::Delete { key } => Some(key),
194 OperationContext::Insert { key } => key.as_deref(),
195 }
196 }
197}
198
199impl OperationResult {
200 async fn from_field(
201 field: Field<'_>,
202 context_map: &HashMap<usize, OperationContext>,
203 ) -> (Option<usize>, Self) {
204 match Self::try_from_field(field, context_map).await {
205 Ok((index, result)) => (Some(index), result),
206 Err(e) => (None, OperationResult::Error(e)),
207 }
208 }
209
210 async fn try_from_field(
211 field: Field<'_>,
212 context_map: &HashMap<usize, OperationContext>,
213 ) -> Result<(usize, Self), Error> {
214 let mut headers = field.headers().clone();
215
216 let index: usize = headers
217 .remove(HEADER_BATCH_OPERATION_INDEX)
218 .and_then(|v| v.to_str().ok().and_then(|s| s.parse().ok()))
219 .ok_or_else(|| {
220 Error::MalformedResponse(format!(
221 "missing or invalid {HEADER_BATCH_OPERATION_INDEX} header"
222 ))
223 })?;
224
225 let status: u16 = headers
226 .remove(HEADER_BATCH_OPERATION_STATUS)
227 .and_then(|v| {
228 v.to_str().ok().and_then(|s| {
229 s.split_once(' ')
232 .map(|(code, _)| code)
233 .unwrap_or(s)
234 .parse()
235 .ok()
236 })
237 })
238 .ok_or_else(|| {
239 Error::MalformedResponse(format!(
240 "missing or invalid {HEADER_BATCH_OPERATION_STATUS} header"
241 ))
242 })?;
243
244 let ctx = context_map.get(&index).ok_or_else(|| {
245 Error::MalformedResponse(format!(
246 "response references unknown operation index {index}"
247 ))
248 })?;
249
250 let key = headers
252 .remove(HEADER_BATCH_OPERATION_KEY)
253 .and_then(|v| {
254 v.to_str()
255 .ok()
256 .and_then(|encoded| {
257 percent_encoding::percent_decode_str(encoded)
258 .decode_utf8()
259 .ok()
260 })
261 .map(|s| s.into_owned())
262 })
263 .or_else(|| ctx.key().map(str::to_owned));
264
265 let body = field.bytes().await?;
266
267 let is_error =
268 status >= 400 && !(matches!(ctx, OperationContext::Get { .. }) && status == 404);
269
270 let key = match key {
275 Some(key) => key,
276 None if is_error => "<unknown>".to_owned(),
277 None => {
278 return Err(Error::MalformedResponse(format!(
279 "missing or invalid {HEADER_BATCH_OPERATION_KEY} header"
280 )));
281 }
282 };
283 if is_error {
284 let message = String::from_utf8_lossy(&body).into_owned();
285 let error = Error::OperationFailure { status, message };
286
287 return Ok((
288 index,
289 match ctx {
290 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
291 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
292 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
293 },
294 ));
295 }
296
297 let result = match ctx {
298 OperationContext::Get { decompress, .. } => {
299 if status == 404 {
300 OperationResult::Get(key, Ok(None))
301 } else {
302 let mut metadata = Metadata::from_headers(&headers, "")?;
303
304 let stream =
305 futures_util::stream::once(async move { Ok::<_, io::Error>(body) }).boxed();
306 let stream = get::maybe_decompress(stream, &mut metadata, *decompress);
307
308 OperationResult::Get(key, Ok(Some(GetResponse { metadata, stream })))
309 }
310 }
311 OperationContext::Insert { .. } => {
312 OperationResult::Put(key.clone(), Ok(PutResponse { key }))
313 }
314 OperationContext::Delete { .. } => OperationResult::Delete(key, Ok(())),
315 };
316 Ok((index, result))
317 }
318}
319
320pub struct OperationResults(Pin<Box<dyn Stream<Item = OperationResult> + Send>>);
322
323impl fmt::Debug for OperationResults {
324 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325 f.write_str("OperationResults([Stream])")
326 }
327}
328
329impl Stream for OperationResults {
330 type Item = OperationResult;
331
332 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
333 self.0.as_mut().poll_next(cx)
334 }
335}
336
337impl OperationResults {
338 pub async fn error_for_failures(
343 mut self,
344 ) -> crate::Result<(), impl Iterator<Item = crate::Error>> {
345 let mut errs = Vec::new();
346 while let Some(res) = self.next().await {
347 match res {
348 OperationResult::Get(_, get) => {
349 if let Err(e) = get {
350 errs.push(e);
351 }
352 }
353 OperationResult::Put(_, put) => {
354 if let Err(e) = put {
355 errs.push(e);
356 }
357 }
358 OperationResult::Delete(_, delete) => {
359 if let Err(e) = delete {
360 errs.push(e);
361 }
362 }
363 OperationResult::Error(error) => errs.push(error),
364 }
365 }
366 if errs.is_empty() {
367 return Ok(());
368 }
369 Err(errs.into_iter())
370 }
371}
372
373async fn send_batch(
374 session: &Session,
375 operations: Vec<BatchOperation>,
376) -> crate::Result<Vec<OperationResult>> {
377 let context_map: HashMap<usize, OperationContext> = operations
378 .iter()
379 .enumerate()
380 .map(|(idx, op)| (idx, OperationContext::from(op)))
381 .collect();
382 let num_operations = operations.len();
383
384 let mut form = reqwest::multipart::Form::new();
385 for op in operations.into_iter() {
386 let part = op.into_part().await?;
387 form = form.part("part", part);
388 }
389
390 let request = session.batch_request()?.multipart(form);
391 let response = request.send().await?.error_for_status()?;
392
393 let boundary = response
394 .headers()
395 .get(CONTENT_TYPE)
396 .and_then(|v| v.to_str().ok())
397 .ok_or_else(|| Error::MalformedResponse("missing Content-Type header".to_owned()))
398 .map(multer::parse_boundary)??;
399
400 let byte_stream = response.bytes_stream().map(|r| r.map_err(io::Error::other));
401 let mut multipart = multer::Multipart::new(byte_stream, boundary);
402
403 let mut results = Vec::new();
404 let mut seen_indices = HashSet::new();
405 while let Some(field) = multipart.next_field().await? {
406 let (index, result) = OperationResult::from_field(field, &context_map).await;
407 if let Some(idx) = index {
408 seen_indices.insert(idx);
409 }
410 results.push(result);
411 }
412
413 for idx in 0..num_operations {
414 if !seen_indices.contains(&idx) {
415 let error = Error::MalformedResponse(format!(
416 "server did not return a response for operation at index {idx}"
417 ));
418 let result = match context_map.get(&idx) {
419 Some(ctx) => {
420 let key = ctx.key().unwrap_or("<unknown>").to_owned();
421 match ctx {
422 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
423 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
424 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
425 }
426 }
427 None => OperationResult::Error(error),
428 };
429 results.push(result);
430 }
431 }
432
433 Ok(results)
434}
435
436impl ManyBuilder {
437 pub fn send(self) -> OperationResults {
441 let session = self.session;
442 let mut operations = self.operations;
443
444 let inner = stream! {
445 while !operations.is_empty() {
446 let chunk_size = operations.len().min(MAX_BATCH_SIZE);
447 let chunk: Vec<_> = operations.drain(..chunk_size).collect();
448
449 let contexts: Vec<_> =
452 chunk.iter().map(OperationContext::from).collect();
453
454 match send_batch(&session, chunk).await {
455 Ok(results) => {
456 for result in results {
457 yield result;
458 }
459 }
460 Err(e) => {
461 let shared = std::sync::Arc::new(e);
462 for ctx in contexts {
463 let error = Error::Batch(shared.clone());
464 let key = ctx.key().unwrap_or("<unknown>").to_owned();
465 yield match ctx {
466 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
467 OperationContext::Insert { .. } => {
468 OperationResult::Put(key, Err(error))
469 }
470 OperationContext::Delete { .. } => {
471 OperationResult::Delete(key, Err(error))
472 }
473 };
474 }
475 }
476 }
477 }
478 };
479
480 OperationResults(Box::pin(inner))
481 }
482
483 #[allow(private_bounds)]
493 pub fn push<B: Into<BatchOperation>>(mut self, builder: B) -> Self {
494 self.operations.push(builder.into());
495 self
496 }
497}