1use std::fmt::Debug;
2
3use mountpoint_s3_client::checksums::{Crc32c, crc32c, crc32c_from_base64};
4use mountpoint_s3_client::error::{ObjectClientError, PutObjectError};
5use mountpoint_s3_client::types::{
6 ChecksumAlgorithm, PutObjectParams, PutObjectResult, PutObjectTrailingChecksums, UploadReview,
7};
8use mountpoint_s3_client::{ObjectClient, PutObjectRequest};
9use tracing::error;
10
11use crate::ServerSideEncryption;
12use crate::async_util::{RemoteResult, Runtime};
13use crate::checksums::combine_checksums;
14
15use super::UploadError;
16
/// S3 allows at most 10,000 parts in a multipart upload, so the largest
/// object we can write is `write_part_size * 10,000` bytes.
const MAX_S3_MULTIPART_UPLOAD_PARTS: usize = 10000;
18
/// An in-progress atomic upload of a single object to S3.
///
/// Bytes are appended strictly in order via `write` and the object is
/// finalized with `complete`, which reviews the uploaded parts' checksums
/// and the response's server-side encryption settings.
pub struct UploadRequest<Client: ObjectClient> {
    // Handle to the in-flight PutObject request spawned on the runtime.
    request: RemoteResult<Client::PutObjectRequest, ObjectClientError<PutObjectError, Client::ClientError>>,
    bucket: String,
    key: String,
    // Offset the next `write` must start at; doubles as the current size.
    next_request_offset: u64,
    // Running CRC32C over every byte written, compared against the combined
    // per-part checksums when the upload completes.
    hasher: crc32c::Hasher,
    // write_part_size * MAX_S3_MULTIPART_UPLOAD_PARTS: hard size cap.
    maximum_upload_size: usize,
    // Expected SSE settings, verified against the completion response.
    sse: ServerSideEncryption,
}
31
/// Parameters describing a new atomic upload.
pub struct UploadRequestParams {
    pub bucket: String,
    pub key: String,
    pub server_side_encryption: ServerSideEncryption,
    // Only `Some(ChecksumAlgorithm::Crc32c)` is currently supported; `None`
    // computes checksums locally for review without sending them to S3.
    pub default_checksum_algorithm: Option<ChecksumAlgorithm>,
    pub storage_class: Option<String>,
}
40
41impl<Client> UploadRequest<Client>
42where
43 Client: ObjectClient + Send + 'static,
44{
45 pub fn new(
46 runtime: &Runtime,
47 client: Client,
48 params: UploadRequestParams,
49 ) -> Result<Self, UploadError<Client::ClientError>> {
50 let mut put_object_params = PutObjectParams::new();
51
52 match ¶ms.default_checksum_algorithm {
53 Some(ChecksumAlgorithm::Crc32c) => {
54 put_object_params = put_object_params.trailing_checksums(PutObjectTrailingChecksums::Enabled);
55 }
56 Some(unsupported) => {
57 unimplemented!("checksum algorithm not supported: {:?}", unsupported);
58 }
59 None => {
60 put_object_params = put_object_params.trailing_checksums(PutObjectTrailingChecksums::ReviewOnly);
61 }
62 }
63
64 if let Some(storage_class) = ¶ms.storage_class {
65 put_object_params = put_object_params.storage_class(storage_class.clone());
66 }
67 let (sse_type, key_id) = params.server_side_encryption.clone().into_inner()?;
72 put_object_params = put_object_params.server_side_encryption(sse_type);
73 put_object_params = put_object_params.ssekms_key_id(key_id);
74
75 let put_bucket = params.bucket.to_owned();
76 let put_key = params.key.to_owned();
77 let maximum_upload_size = client.write_part_size().saturating_mul(MAX_S3_MULTIPART_UPLOAD_PARTS);
78 let request = runtime
79 .spawn_with_result(async move { client.put_object(&put_bucket, &put_key, &put_object_params).await })
80 .unwrap();
81
82 Ok(UploadRequest {
83 request,
84 bucket: params.bucket,
85 key: params.key,
86 next_request_offset: 0,
87 hasher: crc32c::Hasher::new(),
88 maximum_upload_size,
89 sse: params.server_side_encryption,
90 })
91 }
92
93 pub fn size(&self) -> u64 {
94 self.next_request_offset
95 }
96
97 pub async fn write(&mut self, offset: i64, data: &[u8]) -> Result<usize, UploadError<Client::ClientError>> {
98 let next_offset = self.next_request_offset;
99 if offset != next_offset as i64 {
100 return Err(UploadError::OutOfOrderWrite {
101 write_offset: offset as u64,
102 expected_offset: next_offset,
103 });
104 }
105 if next_offset + data.len() as u64 > self.maximum_upload_size as u64 {
106 return Err(UploadError::ObjectTooBig {
107 maximum_size: self.maximum_upload_size,
108 });
109 }
110
111 self.hasher.update(data);
112 self.request
113 .get_mut()
114 .await?
115 .ok_or(UploadError::UploadAlreadyTerminated)?
116 .write(data)
117 .await?;
118
119 self.next_request_offset += data.len() as u64;
120 Ok(data.len())
121 }
122
123 pub async fn complete(self) -> Result<PutObjectResult, UploadError<Client::ClientError>> {
124 let size = self.size();
125 let checksum = self.hasher.finalize();
126 let result = self
127 .request
128 .into_inner()
129 .await?
130 .ok_or(UploadError::UploadAlreadyTerminated)?
131 .review_and_complete(move |review| verify_checksums(review, size, checksum))
132 .await?;
133 if let Err(err) = self
134 .sse
135 .verify_response(result.sse_type.as_deref(), result.sse_kms_key_id.as_deref())
136 {
137 error!(key=?self.key, error=?err, "SSE settings were corrupted after the upload completion");
138 std::process::exit(1);
144 }
145 Ok(result)
146 }
147}
148
149impl<Client: ObjectClient> Debug for UploadRequest<Client> {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 f.debug_struct("UploadRequest")
152 .field("bucket", &self.bucket)
153 .field("key", &self.key)
154 .field("next_request_offset", &self.next_request_offset)
155 .field("hasher", &self.hasher)
156 .finish_non_exhaustive()
157 }
158}
159
160fn verify_checksums(review: UploadReview, expected_size: u64, expected_checksum: Crc32c) -> bool {
161 let mut uploaded_size = 0u64;
162 let mut uploaded_checksum = Crc32c::new(0);
163 for (i, part) in review.parts.iter().enumerate() {
164 uploaded_size += part.size;
165
166 let Some(checksum) = &part.checksum else {
167 error!(part_number = i + 1, "missing part checksum");
168 return false;
169 };
170 let checksum = match crc32c_from_base64(checksum) {
171 Ok(checksum) => checksum,
172 Err(error) => {
173 error!(part_number = i + 1, ?error, "error decoding part checksum");
174 return false;
175 }
176 };
177
178 uploaded_checksum = combine_checksums(uploaded_checksum, checksum, part.size as usize);
179 }
180
181 if uploaded_size != expected_size {
182 error!(
183 uploaded_size,
184 expected_size, "Total uploaded size differs from expected size"
185 );
186 return false;
187 }
188
189 if uploaded_checksum != expected_checksum {
190 error!(
191 ?uploaded_checksum,
192 ?expected_checksum,
193 "Combined checksum of all uploaded parts differs from expected checksum"
194 );
195 return false;
196 }
197
198 true
199}
200
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::fs::SseCorruptedError;
    use crate::mem_limiter::{MINIMUM_MEM_LIMIT, MemoryLimiter};
    use crate::memory::PagedPool;
    use crate::sync::Arc;
    use crate::upload::{Uploader, UploaderConfig};

    use futures::executor::ThreadPool;
    use mountpoint_s3_client::failure_client::{CountdownFailureConfig, countdown_failure_client};
    use mountpoint_s3_client::mock_client::{MockClient, MockClientError};
    use mountpoint_s3_client::types::ChecksumAlgorithm;
    use test_case::test_case;

    use super::*;

    /// Build an `Uploader` over `client` with a single-threaded runtime,
    /// a pool sized to the client's write part size, and the minimum
    /// memory limit — shared setup for the tests below.
    fn new_uploader_for_test<Client>(
        client: Client,
        storage_class: Option<String>,
        server_side_encryption: ServerSideEncryption,
        use_additional_checksums: bool,
    ) -> Uploader<Client>
    where
        Client: ObjectClient + Clone + Send + Sync + 'static,
    {
        let buffer_size = client.write_part_size();
        let pool = PagedPool::new_with_candidate_sizes([buffer_size]);
        let runtime = Runtime::new(ThreadPool::builder().pool_size(1).create().unwrap());
        let mem_limiter = MemoryLimiter::new(pool.clone(), MINIMUM_MEM_LIMIT);
        Uploader::new(
            client,
            runtime,
            pool,
            mem_limiter.into(),
            UploaderConfig::new(buffer_size)
                .storage_class(storage_class)
                .server_side_encryption(server_side_encryption)
                // `None` falls back to review-only checksums in `new`.
                .default_checksum_algorithm(use_additional_checksums.then_some(ChecksumAlgorithm::Crc32c)),
        )
    }

    /// An empty upload is in progress until `complete`, after which the
    /// (empty) object exists in the bucket.
    #[tokio::test]
    async fn complete_test() {
        let bucket = "bucket";
        let name = "hello";
        let key = name;

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(32).build());
        let uploader = new_uploader_for_test(client.clone(), None, ServerSideEncryption::default(), true);
        let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

        _ = request.write(0, &[]).await.unwrap();

        assert!(!client.contains_key(key));
        assert!(client.is_upload_in_progress(key));

        request.complete().await.unwrap();

        assert!(client.contains_key(key));
        assert!(!client.is_upload_in_progress(key));
    }

    /// Writes must be sequential: a repeated offset fails, but the upload
    /// survives and the next in-order write succeeds.
    #[tokio::test]
    async fn write_order_test() {
        let bucket = "bucket";
        let name = "hello";
        let key = name;
        let storage_class = "INTELLIGENT_TIERING";

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(32).build());
        let uploader = new_uploader_for_test(
            client.clone(),
            Some(storage_class.to_owned()),
            ServerSideEncryption::default(),
            true,
        );

        let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

        let data = b"foo";
        let mut offset = 0;
        offset += request.write(offset, data).await.unwrap() as i64;

        // Re-writing at offset 0 is out of order and must be rejected.
        request
            .write(0, data)
            .await
            .expect_err("out of order write should fail");

        offset += request
            .write(offset, data)
            .await
            .expect("subsequent in order write should succeed") as i64;

        let size = request.size();
        assert_eq!(offset, size as i64);

        request.complete().await.unwrap();
        assert!(client.contains_key(key));
    }

    /// Once a write or completion fails, the upload is terminated: further
    /// writes and `complete` return `UploadAlreadyTerminated` and nothing
    /// is left behind in the bucket.
    #[tokio::test]
    async fn failure_test() {
        let bucket = "bucket";
        let name = "hello";
        let key = name;

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(32).build());

        // Inject failures into the first three PUT attempts: two part-level
        // failures and one request-level failure.
        let mut put_failures = HashMap::new();
        put_failures.insert(1, Ok((1, MockClientError("error".to_owned().into()))));
        put_failures.insert(2, Ok((2, MockClientError("error".to_owned().into()))));
        put_failures.insert(3, Err(MockClientError("error".to_owned().into()).into()));

        let failure_client = Arc::new(countdown_failure_client(
            client.clone(),
            CountdownFailureConfig {
                put_failures,
                ..Default::default()
            },
        ));

        let uploader = new_uploader_for_test(failure_client.clone(), None, ServerSideEncryption::default(), true);

        // First attempt: the write itself fails.
        {
            let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

            let data = b"foo";
            request.write(0, data).await.expect_err("first write should fail");
        }
        assert!(!client.is_upload_in_progress(key));
        assert!(!client.contains_key(key));

        // Second attempt: the write succeeds but completion fails.
        {
            let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

            let data = b"foo";
            _ = request.write(0, data).await.unwrap();

            request.complete().await.expect_err("complete should fail");
        }
        assert!(!client.is_upload_in_progress(key));
        assert!(!client.contains_key(key));

        // Third attempt: after the first failure the request is terminated,
        // so both a retry write and `complete` report termination.
        {
            let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

            let data = b"foo";
            request.write(0, data).await.expect_err("first write should fail");

            let err = request
                .write(0, data)
                .await
                .expect_err("subsequent writes should also fail");
            assert!(matches!(err, UploadError::UploadAlreadyTerminated));

            let err = request.complete().await.expect_err("complete should also fail");
            assert!(matches!(err, UploadError::UploadAlreadyTerminated));
        }
        assert!(!client.is_upload_in_progress(key));
        assert!(!client.contains_key(key));
    }

    /// Writes that stay within `part_size * MAX_S3_MULTIPART_UPLOAD_PARTS`
    /// succeed; the first write crossing the cap fails with `ObjectTooBig`.
    #[test_case(8000; "divisible by max size")]
    #[test_case(7000; "not divisible by max size")]
    #[test_case(320001; "single write too big")]
    #[tokio::test]
    async fn maximum_size_test(write_size: usize) {
        const PART_SIZE: usize = 32;

        let bucket = "bucket";
        let name = "hello";
        let key = name;

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(PART_SIZE).build());
        let uploader = new_uploader_for_test(client.clone(), None, ServerSideEncryption::default(), true);
        let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

        // Integer division: the last partial write (if any) is the one
        // expected to exceed the limit below.
        let successful_writes = PART_SIZE * MAX_S3_MULTIPART_UPLOAD_PARTS / write_size;
        let data = vec![0xaa; write_size];
        for i in 0..successful_writes {
            let offset = i * write_size;
            request.write(offset as i64, &data).await.expect("object should fit");
            assert!(client.is_upload_in_progress(key));
        }

        let offset = successful_writes * write_size;
        request
            .write(offset as i64, &data)
            .await
            .expect_err("object should be too big");

        drop(request);

        assert!(!client.contains_key(key));
        assert!(!client.is_upload_in_progress(key));
    }

    /// Deliberately corrupting the stored SSE settings (type and/or key id)
    /// must be detected before the upload even starts.
    #[test_case(Some("aws:kmr"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_ali`s"))]
    #[test_case(None, Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), None)]
    #[tokio::test]
    async fn put_with_corrupted_sse_test(sse_type_corrupted: Option<&str>, key_id_corrupted: Option<&str>) {
        let client = Arc::new(MockClient::config().build());
        let mut uploader = new_uploader_for_test(
            client,
            None,
            ServerSideEncryption::new(Some("aws:kms".to_string()), Some("some_key_alias".to_string())),
            true,
        );
        uploader
            .server_side_encryption
            .corrupt_data(sse_type_corrupted.map(String::from), key_id_corrupted.map(String::from));
        let err = uploader
            .start_atomic_upload("bucket".to_owned(), "hello".to_owned())
            .expect_err("sse checksum must be checked");
        assert!(matches!(
            err,
            UploadError::SseCorruptedError(SseCorruptedError::ChecksumMismatch(_, _))
        ));
    }

    /// Valid SSE settings pass the integrity check and the upload starts.
    #[tokio::test]
    async fn put_with_good_sse_test() {
        let bucket = "bucket";
        let name = "hello";
        let key = name;

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(32).build());
        let uploader = new_uploader_for_test(
            client,
            None,
            ServerSideEncryption::new(Some("aws:kms".to_string()), Some("some_key".to_string())),
            true,
        );
        uploader
            .start_atomic_upload(bucket.to_owned(), key.to_owned())
            .expect("put with sse should succeed");
    }
}