debian_packaging/repository/
s3.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5use {
6    crate::{
7        error::{DebianError, Result},
8        io::{ContentDigest, MultiDigester},
9        repository::{
10            RepositoryPathVerification, RepositoryPathVerificationState, RepositoryWrite,
11            RepositoryWriter,
12        },
13    },
14    async_trait::async_trait,
15    futures::{AsyncRead, AsyncReadExt as FuturesAsyncReadExt},
16    rusoto_core::{ByteStream, Client, Region, RusotoError},
17    rusoto_s3::{
18        GetBucketLocationRequest, GetObjectError, GetObjectRequest, HeadObjectError,
19        HeadObjectRequest, PutObjectRequest, S3Client, S3,
20    },
21    std::{borrow::Cow, pin::Pin, str::FromStr},
22    tokio::io::AsyncReadExt as TokioAsyncReadExt,
23};
24
25pub struct S3Writer {
26    client: S3Client,
27    bucket: String,
28    key_prefix: Option<String>,
29}
30
31impl S3Writer {
32    /// Create a new S3 writer bound to a named bucket with optional key prefix.
33    ///
34    /// This will construct a default AWS [Client].
35    pub fn new(region: Region, bucket: impl ToString, key_prefix: Option<&str>) -> Self {
36        Self {
37            client: S3Client::new(region),
38            bucket: bucket.to_string(),
39            key_prefix: key_prefix.map(|x| x.trim_matches('/').to_string()),
40        }
41    }
42
43    /// Create a new S3 writer bound to a named bucket, optional key prefix, with an AWS [Client].
44    ///
45    /// This is like [Self::new()] except the caller can pass in the AWS [Client] to use.
46    pub fn new_with_client(
47        client: Client,
48        region: Region,
49        bucket: impl ToString,
50        key_prefix: Option<&str>,
51    ) -> Self {
52        Self {
53            client: S3Client::new_with_client(client, region),
54            bucket: bucket.to_string(),
55            key_prefix: key_prefix.map(|x| x.trim_matches('/').to_string()),
56        }
57    }
58
59    /// Compute the S3 key name given a repository relative path.
60    pub fn path_to_key(&self, path: &str) -> String {
61        if let Some(prefix) = &self.key_prefix {
62            format!("{}/{}", prefix, path.trim_matches('/'))
63        } else {
64            path.trim_matches('/').to_string()
65        }
66    }
67}
68
69#[async_trait]
70impl RepositoryWriter for S3Writer {
71    async fn verify_path<'path>(
72        &self,
73        path: &'path str,
74        expected_content: Option<(u64, ContentDigest)>,
75    ) -> Result<RepositoryPathVerification<'path>> {
76        if let Some((expected_size, expected_digest)) = expected_content {
77            let req = GetObjectRequest {
78                bucket: self.bucket.clone(),
79                key: self.path_to_key(path),
80                ..Default::default()
81            };
82
83            match self.client.get_object(req).await {
84                Ok(output) => {
85                    // Fast path without having to ready the body.
86                    if let Some(cl) = output.content_length {
87                        if cl as u64 != expected_size {
88                            return Ok(RepositoryPathVerification {
89                                path,
90                                state: RepositoryPathVerificationState::ExistsIntegrityMismatch,
91                            });
92                        }
93                    }
94
95                    if let Some(body) = output.body {
96                        let mut digester = MultiDigester::default();
97
98                        let mut remaining = expected_size;
99                        let mut reader = body.into_async_read();
100                        let mut buf = [0u8; 16384];
101
102                        loop {
103                            let size = reader
104                                .read(&mut buf[..])
105                                .await
106                                .map_err(|e| DebianError::RepositoryIoPath(path.to_string(), e))?;
107
108                            digester.update(&buf[0..size]);
109
110                            let size = size as u64;
111
112                            if size >= remaining || size == 0 {
113                                break;
114                            }
115
116                            remaining -= size;
117                        }
118
119                        let digests = digester.finish();
120
121                        Ok(RepositoryPathVerification {
122                            path,
123                            state: if !digests.matches_digest(&expected_digest) {
124                                RepositoryPathVerificationState::ExistsIntegrityMismatch
125                            } else {
126                                RepositoryPathVerificationState::ExistsIntegrityVerified
127                            },
128                        })
129                    } else {
130                        Ok(RepositoryPathVerification {
131                            path,
132                            state: RepositoryPathVerificationState::Missing,
133                        })
134                    }
135                }
136                Err(RusotoError::Service(GetObjectError::NoSuchKey(_))) => {
137                    Ok(RepositoryPathVerification {
138                        path,
139                        state: RepositoryPathVerificationState::Missing,
140                    })
141                }
142                Err(e) => Err(DebianError::RepositoryIoPath(
143                    path.to_string(),
144                    std::io::Error::new(std::io::ErrorKind::Other, format!("S3 error: {:?}", e)),
145                )),
146            }
147        } else {
148            let req = HeadObjectRequest {
149                bucket: self.bucket.clone(),
150                key: self.path_to_key(path),
151                ..Default::default()
152            };
153
154            match self.client.head_object(req).await {
155                Ok(_) => Ok(RepositoryPathVerification {
156                    path,
157                    state: RepositoryPathVerificationState::ExistsNoIntegrityCheck,
158                }),
159                Err(RusotoError::Service(HeadObjectError::NoSuchKey(_))) => {
160                    Ok(RepositoryPathVerification {
161                        path,
162                        state: RepositoryPathVerificationState::Missing,
163                    })
164                }
165                Err(e) => Err(DebianError::RepositoryIoPath(
166                    path.to_string(),
167                    std::io::Error::new(std::io::ErrorKind::Other, format!("S3 error: {:?}", e)),
168                )),
169            }
170        }
171    }
172
173    async fn write_path<'path, 'reader>(
174        &self,
175        path: Cow<'path, str>,
176        mut reader: Pin<Box<dyn AsyncRead + Send + 'reader>>,
177    ) -> Result<RepositoryWrite<'path>> {
178        // rusoto wants a Stream<Bytes>. There's no easy way to convert from an AsyncRead to a
179        // Stream. So we just buffer content locally.
180        // TODO implement this properly
181        let mut buf = vec![];
182        reader
183            .read_to_end(&mut buf)
184            .await
185            .map_err(|e| DebianError::RepositoryIoPath(path.to_string(), e))?;
186
187        let bytes_written = buf.len() as u64;
188        let stream = futures::stream::once(async { Ok(bytes::Bytes::from(buf)) });
189
190        let req = PutObjectRequest {
191            bucket: self.bucket.clone(),
192            key: self.path_to_key(path.as_ref()),
193            body: Some(ByteStream::new(stream)),
194            ..Default::default()
195        };
196
197        match self.client.put_object(req).await {
198            Ok(_) => Ok(RepositoryWrite {
199                path,
200                bytes_written,
201            }),
202            Err(e) => Err(DebianError::RepositoryIoPath(
203                path.to_string(),
204                std::io::Error::new(std::io::ErrorKind::Other, format!("S3 error: {:?}", e)),
205            )),
206        }
207    }
208}
209
210/// Attempt to resolve the AWS region of an S3 bucket.
211pub async fn get_bucket_region(bucket: impl ToString) -> Result<Region> {
212    get_bucket_region_with_client(S3Client::new(Region::UsEast1), bucket).await
213}
214
215/// Attempt to resolve the AWS region of an S3 bucket using a provided [S3Client].
216pub async fn get_bucket_region_with_client(
217    client: S3Client,
218    bucket: impl ToString,
219) -> Result<Region> {
220    let req = GetBucketLocationRequest {
221        bucket: bucket.to_string(),
222        ..Default::default()
223    };
224
225    match client.get_bucket_location(req).await {
226        Ok(res) => {
227            if let Some(constraint) = res.location_constraint {
228                Ok(Region::from_str(&constraint)
229                    .map_err(|_| DebianError::S3BadRegion(constraint))?)
230            } else {
231                Ok(Region::UsEast1)
232            }
233        }
234        Err(e) => Err(DebianError::Io(std::io::Error::new(
235            std::io::ErrorKind::Other,
236            format!("S3 error: {:?}", e),
237        ))),
238    }
239}