// mountpoint_s3_fs/upload/atomic.rs

1use std::fmt::Debug;
2
3use mountpoint_s3_client::checksums::{Crc32c, crc32c, crc32c_from_base64};
4use mountpoint_s3_client::error::{ObjectClientError, PutObjectError};
5use mountpoint_s3_client::types::{
6    ChecksumAlgorithm, PutObjectParams, PutObjectResult, PutObjectTrailingChecksums, UploadReview,
7};
8use mountpoint_s3_client::{ObjectClient, PutObjectRequest};
9use tracing::error;
10
11use crate::ServerSideEncryption;
12use crate::async_util::{RemoteResult, Runtime};
13use crate::checksums::combine_checksums;
14
15use super::UploadError;
16
/// Maximum number of parts in a single S3 multipart upload (a documented S3 service limit).
/// Used to bound the total object size this uploader will accept: part size × this constant.
const MAX_S3_MULTIPART_UPLOAD_PARTS: usize = 10000;
18
/// Manages the upload of an object to S3.
///
/// Wraps a PutObject request and enforces sequential writes.
pub struct UploadRequest<Client: ObjectClient> {
    /// Handle to the in-flight PutObject request, driven on the spawned runtime task.
    /// Yields `None` once the request has been terminated (see `UploadAlreadyTerminated` usage below).
    request: RemoteResult<Client::PutObjectRequest, ObjectClientError<PutObjectError, Client::ClientError>>,
    /// Destination bucket name.
    bucket: String,
    /// Destination object key.
    key: String,
    /// Offset at which the next `write` must start; doubles as the total bytes written so far.
    next_request_offset: u64,
    /// Running CRC32C over all bytes accepted by `write`, checked against S3's
    /// per-part checksums when the upload is completed.
    hasher: crc32c::Hasher,
    /// Largest object size this upload can reach: the client's write part size
    /// times the S3 multipart part-count limit.
    maximum_upload_size: usize,
    /// SSE settings used for the upload, kept so the response can be verified on completion.
    sse: ServerSideEncryption,
}
31
/// Parameters to initialize an [UploadRequest].
pub struct UploadRequestParams {
    /// Destination bucket name.
    pub bucket: String,
    /// Destination object key.
    pub key: String,
    /// Server-side encryption settings; verified for integrity before and after the upload.
    pub server_side_encryption: ServerSideEncryption,
    /// Checksum algorithm to apply to the upload. Only CRC32C is supported;
    /// `None` enables review-only trailing checksums (see `UploadRequest::new`).
    pub default_checksum_algorithm: Option<ChecksumAlgorithm>,
    /// Storage class for the new object, if one should be set explicitly.
    pub storage_class: Option<String>,
}
40
impl<Client> UploadRequest<Client>
where
    Client: ObjectClient + Send + 'static,
{
    /// Start a new atomic upload by spawning a PutObject request on the runtime.
    ///
    /// Returns an error if the server-side encryption settings fail their
    /// integrity check before the request is issued.
    pub fn new(
        runtime: &Runtime,
        client: Client,
        params: UploadRequestParams,
    ) -> Result<Self, UploadError<Client::ClientError>> {
        let mut put_object_params = PutObjectParams::new();

        // Only CRC32C is supported as an explicit checksum algorithm. With no
        // algorithm configured, trailing checksums are still computed so the
        // upload can be reviewed on completion, but are not enforced by S3.
        match &params.default_checksum_algorithm {
            Some(ChecksumAlgorithm::Crc32c) => {
                put_object_params = put_object_params.trailing_checksums(PutObjectTrailingChecksums::Enabled);
            }
            Some(unsupported) => {
                unimplemented!("checksum algorithm not supported: {:?}", unsupported);
            }
            None => {
                put_object_params = put_object_params.trailing_checksums(PutObjectTrailingChecksums::ReviewOnly);
            }
        }

        if let Some(storage_class) = &params.storage_class {
            put_object_params = put_object_params.storage_class(storage_class.clone());
        }
        // If we have detected corruption of SSE settings, we return an error, which will currently be reported as
        // `libc::EIO` on `open()`. MP won't be able to open files for write from this point, but this is a relatively
        // low-risk error as data can not be uploaded with wrong SSE settings yet. Thus there is no strong reason for
        // MP to crash and it may continue serving read's.
        let (sse_type, key_id) = params.server_side_encryption.clone().into_inner()?;
        put_object_params = put_object_params.server_side_encryption(sse_type);
        put_object_params = put_object_params.ssekms_key_id(key_id);

        let put_bucket = params.bucket.to_owned();
        let put_key = params.key.to_owned();
        // Bound the total object size by S3's multipart part-count limit.
        let maximum_upload_size = client.write_part_size().saturating_mul(MAX_S3_MULTIPART_UPLOAD_PARTS);
        // NOTE(review): the `unwrap` treats a failure to spawn on the runtime as a
        // bug rather than a recoverable error — presumably spawning cannot fail in
        // practice; confirm against `Runtime::spawn_with_result`.
        let request = runtime
            .spawn_with_result(async move { client.put_object(&put_bucket, &put_key, &put_object_params).await })
            .unwrap();

        Ok(UploadRequest {
            request,
            bucket: params.bucket,
            key: params.key,
            next_request_offset: 0,
            hasher: crc32c::Hasher::new(),
            maximum_upload_size,
            sse: params.server_side_encryption,
        })
    }

    /// Total number of bytes written to this upload so far.
    pub fn size(&self) -> u64 {
        self.next_request_offset
    }

    /// Write `data` at `offset`, which must equal the current size of the upload
    /// (writes are strictly sequential). Returns the number of bytes written.
    ///
    /// Errors with `OutOfOrderWrite` on a non-sequential offset, `ObjectTooBig`
    /// when the write would exceed the maximum multipart object size, and
    /// `UploadAlreadyTerminated` if the underlying request has already failed.
    pub async fn write(&mut self, offset: i64, data: &[u8]) -> Result<usize, UploadError<Client::ClientError>> {
        let next_offset = self.next_request_offset;
        if offset != next_offset as i64 {
            return Err(UploadError::OutOfOrderWrite {
                write_offset: offset as u64,
                expected_offset: next_offset,
            });
        }
        if next_offset + data.len() as u64 > self.maximum_upload_size as u64 {
            return Err(UploadError::ObjectTooBig {
                maximum_size: self.maximum_upload_size,
            });
        }

        // The hasher is updated before the write is awaited. NOTE(review): if the
        // write itself fails the hasher has still consumed `data` while the offset
        // has not advanced — presumably the request is unusable after a failed
        // write (subsequent calls see `UploadAlreadyTerminated`), so the mismatch
        // is never observed; confirm against `RemoteResult` semantics.
        self.hasher.update(data);
        self.request
            .get_mut()
            .await?
            .ok_or(UploadError::UploadAlreadyTerminated)?
            .write(data)
            .await?;

        self.next_request_offset += data.len() as u64;
        Ok(data.len())
    }

    /// Complete the upload, verifying the uploaded parts' combined size and
    /// CRC32C checksum against what was written, and verifying the SSE settings
    /// echoed back in the response.
    pub async fn complete(self) -> Result<PutObjectResult, UploadError<Client::ClientError>> {
        let size = self.size();
        let checksum = self.hasher.finalize();
        let result = self
            .request
            .into_inner()
            .await?
            .ok_or(UploadError::UploadAlreadyTerminated)?
            .review_and_complete(move |review| verify_checksums(review, size, checksum))
            .await?;
        if let Err(err) = self
            .sse
            .verify_response(result.sse_type.as_deref(), result.sse_kms_key_id.as_deref())
        {
            error!(key=?self.key, error=?err, "SSE settings were corrupted after the upload completion");
            // Reaching this point is very unlikely and means that SSE settings were corrupted in transit or on S3 side, this may be a sign of a bug
            // in CRT code or S3. Thus, we terminate Mountpoint to send the most noticeable signal to customer about the issue. We prefer exiting
            // instead of returning an error because:
            // 1. this error would only be reported on `flush` which many applications ignore and
            // 2. the reported error is severe as the object was already uploaded to S3.
            std::process::exit(1);
        }
        Ok(result)
    }
}
148
149impl<Client: ObjectClient> Debug for UploadRequest<Client> {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("UploadRequest")
152            .field("bucket", &self.bucket)
153            .field("key", &self.key)
154            .field("next_request_offset", &self.next_request_offset)
155            .field("hasher", &self.hasher)
156            .finish_non_exhaustive()
157    }
158}
159
160fn verify_checksums(review: UploadReview, expected_size: u64, expected_checksum: Crc32c) -> bool {
161    let mut uploaded_size = 0u64;
162    let mut uploaded_checksum = Crc32c::new(0);
163    for (i, part) in review.parts.iter().enumerate() {
164        uploaded_size += part.size;
165
166        let Some(checksum) = &part.checksum else {
167            error!(part_number = i + 1, "missing part checksum");
168            return false;
169        };
170        let checksum = match crc32c_from_base64(checksum) {
171            Ok(checksum) => checksum,
172            Err(error) => {
173                error!(part_number = i + 1, ?error, "error decoding part checksum");
174                return false;
175            }
176        };
177
178        uploaded_checksum = combine_checksums(uploaded_checksum, checksum, part.size as usize);
179    }
180
181    if uploaded_size != expected_size {
182        error!(
183            uploaded_size,
184            expected_size, "Total uploaded size differs from expected size"
185        );
186        return false;
187    }
188
189    if uploaded_checksum != expected_checksum {
190        error!(
191            ?uploaded_checksum,
192            ?expected_checksum,
193            "Combined checksum of all uploaded parts differs from expected checksum"
194        );
195        return false;
196    }
197
198    true
199}
200
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::fs::SseCorruptedError;
    use crate::mem_limiter::{MINIMUM_MEM_LIMIT, MemoryLimiter};
    use crate::memory::PagedPool;
    use crate::sync::Arc;
    use crate::upload::{Uploader, UploaderConfig};

    use futures::executor::ThreadPool;
    use mountpoint_s3_client::failure_client::{CountdownFailureConfig, countdown_failure_client};
    use mountpoint_s3_client::mock_client::{MockClient, MockClientError};
    use mountpoint_s3_client::types::ChecksumAlgorithm;
    use test_case::test_case;

    use super::*;

    /// Build an `Uploader` backed by `client`, with a single-threaded runtime,
    /// a buffer pool sized to the client's write part size, and the minimum
    /// memory limit.
    fn new_uploader_for_test<Client>(
        client: Client,
        storage_class: Option<String>,
        server_side_encryption: ServerSideEncryption,
        use_additional_checksums: bool,
    ) -> Uploader<Client>
    where
        Client: ObjectClient + Clone + Send + Sync + 'static,
    {
        let buffer_size = client.write_part_size();
        let pool = PagedPool::new_with_candidate_sizes([buffer_size]);
        let runtime = Runtime::new(ThreadPool::builder().pool_size(1).create().unwrap());
        let mem_limiter = MemoryLimiter::new(pool.clone(), MINIMUM_MEM_LIMIT);
        Uploader::new(
            client,
            runtime,
            pool,
            mem_limiter.into(),
            UploaderConfig::new(buffer_size)
                .storage_class(storage_class)
                .server_side_encryption(server_side_encryption)
                .default_checksum_algorithm(use_additional_checksums.then_some(ChecksumAlgorithm::Crc32c)),
        )
    }

    /// An empty write followed by `complete` produces the object and finishes
    /// the in-progress upload.
    #[tokio::test]
    async fn complete_test() {
        let bucket = "bucket";
        let name = "hello";
        let key = name;

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(32).build());
        let uploader = new_uploader_for_test(client.clone(), None, ServerSideEncryption::default(), true);
        let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

        _ = request.write(0, &[]).await.unwrap();

        // The object only becomes visible once the upload is completed.
        assert!(!client.contains_key(key));
        assert!(client.is_upload_in_progress(key));

        request.complete().await.unwrap();

        assert!(client.contains_key(key));
        assert!(!client.is_upload_in_progress(key));
    }

    /// Writes must be strictly sequential: an out-of-order write fails, but a
    /// subsequent write at the correct offset still succeeds.
    #[tokio::test]
    async fn write_order_test() {
        let bucket = "bucket";
        let name = "hello";
        let key = name;
        let storage_class = "INTELLIGENT_TIERING";

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(32).build());
        let uploader = new_uploader_for_test(
            client.clone(),
            Some(storage_class.to_owned()),
            ServerSideEncryption::default(),
            true,
        );

        let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

        let data = b"foo";
        let mut offset = 0;
        offset += request.write(offset, data).await.unwrap() as i64;

        // Rewriting at offset 0 is out of order and must be rejected.
        request
            .write(0, data)
            .await
            .expect_err("out of order write should fail");

        offset += request
            .write(offset, data)
            .await
            .expect("subsequent in order write should succeed") as i64;

        let size = request.size();
        assert_eq!(offset, size as i64);

        request.complete().await.unwrap();
        assert!(client.contains_key(key));
    }

    /// Failures at different points of the upload (first write, complete, and
    /// CreateMPU itself) leave no object behind and terminate the request.
    #[tokio::test]
    async fn failure_test() {
        let bucket = "bucket";
        let name = "hello";
        let key = name;

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(32).build());

        // Inject failures into the first three PutObject requests: the first two
        // fail mid-request (after some parts), the third fails at creation.
        let mut put_failures = HashMap::new();
        put_failures.insert(1, Ok((1, MockClientError("error".to_owned().into()))));
        put_failures.insert(2, Ok((2, MockClientError("error".to_owned().into()))));
        put_failures.insert(3, Err(MockClientError("error".to_owned().into()).into()));

        let failure_client = Arc::new(countdown_failure_client(
            client.clone(),
            CountdownFailureConfig {
                put_failures,
                ..Default::default()
            },
        ));

        let uploader = new_uploader_for_test(failure_client.clone(), None, ServerSideEncryption::default(), true);

        // First request fails on first write.
        {
            let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

            let data = b"foo";
            request.write(0, data).await.expect_err("first write should fail");
        }
        assert!(!client.is_upload_in_progress(key));
        assert!(!client.contains_key(key));

        // Second request fails on complete (after one write).
        {
            let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

            let data = b"foo";
            _ = request.write(0, data).await.unwrap();

            request.complete().await.expect_err("complete should fail");
        }
        assert!(!client.is_upload_in_progress(key));
        assert!(!client.contains_key(key));

        // Third request fails on first write (because CreateMPU returns an error).
        {
            let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

            let data = b"foo";
            request.write(0, data).await.expect_err("first write should fail");

            // Once terminated, every further operation reports UploadAlreadyTerminated.
            let err = request
                .write(0, data)
                .await
                .expect_err("subsequent writes should also fail");
            assert!(matches!(err, UploadError::UploadAlreadyTerminated));

            let err = request.complete().await.expect_err("complete should also fail");
            assert!(matches!(err, UploadError::UploadAlreadyTerminated));
        }
        assert!(!client.is_upload_in_progress(key));
        assert!(!client.contains_key(key));
    }

    /// Writing past PART_SIZE * MAX_S3_MULTIPART_UPLOAD_PARTS (= 320000 here)
    /// must fail with ObjectTooBig; 320001 makes even the first write too big.
    #[test_case(8000; "divisible by max size")]
    #[test_case(7000; "not divisible by max size")]
    #[test_case(320001; "single write too big")]
    #[tokio::test]
    async fn maximum_size_test(write_size: usize) {
        const PART_SIZE: usize = 32;

        let bucket = "bucket";
        let name = "hello";
        let key = name;

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(PART_SIZE).build());
        let uploader = new_uploader_for_test(client.clone(), None, ServerSideEncryption::default(), true);
        let mut request = uploader.start_atomic_upload(bucket.to_owned(), key.to_owned()).unwrap();

        // Integer division: the number of full writes that still fit under the cap
        // (zero for the "single write too big" case).
        let successful_writes = PART_SIZE * MAX_S3_MULTIPART_UPLOAD_PARTS / write_size;
        let data = vec![0xaa; write_size];
        for i in 0..successful_writes {
            let offset = i * write_size;
            request.write(offset as i64, &data).await.expect("object should fit");
            assert!(client.is_upload_in_progress(key));
        }

        let offset = successful_writes * write_size;
        request
            .write(offset as i64, &data)
            .await
            .expect_err("object should be too big");

        // Dropping the failed request must abandon the upload without creating the object.
        drop(request);

        assert!(!client.contains_key(key));
        assert!(!client.is_upload_in_progress(key));
    }

    /// Each case corrupts the SSE type and/or key id after construction
    /// (e.g. "aws:kmr", a mangled key alias, or a dropped field); starting an
    /// upload must then fail the SSE checksum verification.
    #[test_case(Some("aws:kmr"), Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), Some("some_key_ali`s"))]
    #[test_case(None, Some("some_key_alias"))]
    #[test_case(Some("aws:kms"), None)]
    #[tokio::test]
    async fn put_with_corrupted_sse_test(sse_type_corrupted: Option<&str>, key_id_corrupted: Option<&str>) {
        let client = Arc::new(MockClient::config().build());
        let mut uploader = new_uploader_for_test(
            client,
            None,
            ServerSideEncryption::new(Some("aws:kms".to_string()), Some("some_key_alias".to_string())),
            true,
        );
        uploader
            .server_side_encryption
            .corrupt_data(sse_type_corrupted.map(String::from), key_id_corrupted.map(String::from));
        let err = uploader
            .start_atomic_upload("bucket".to_owned(), "hello".to_owned())
            .expect_err("sse checksum must be checked");
        assert!(matches!(
            err,
            UploadError::SseCorruptedError(SseCorruptedError::ChecksumMismatch(_, _))
        ));
    }

    /// Uncorrupted SSE settings pass the integrity check and the upload starts.
    #[tokio::test]
    async fn put_with_good_sse_test() {
        let bucket = "bucket";
        let name = "hello";
        let key = name;

        let client = Arc::new(MockClient::config().bucket(bucket).part_size(32).build());
        let uploader = new_uploader_for_test(
            client,
            None,
            ServerSideEncryption::new(Some("aws:kms".to_string()), Some("some_key".to_string())),
            true,
        );
        uploader
            .start_atomic_upload(bucket.to_owned(), key.to_owned())
            .expect("put with sse should succeed");
    }
}
445}