// infinitree_backends/s3.rs

use super::block_on;
use anyhow::Context;
use infinitree::{
    backends::{Backend, Result},
    object::{Object, ObjectId, ReadBuffer, ReadObject, WriteObject},
};
use reqwest::Client;
pub use rusty_s3::{Bucket, Credentials};
use rusty_s3::{S3Action, UrlStyle};
use scc::HashMap;
use std::{env, future::Future, sync::Arc, time::Duration};
use tokio::{
    sync::Semaphore,
    task::{self, JoinError, JoinHandle},
};

mod region;
pub use region::Region;

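/// Tracks S3 requests that are still in flight.
///
/// Each task is keyed by the [`ObjectId`] it operates on, and overall
/// concurrency is capped by a semaphore sized to the number of available
/// CPU cores (see the `Default` impl below).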
struct InFlightTracker<TaskResult>
where
    TaskResult: 'static + Send,
{
    permits: Arc<Semaphore>,
    active: Arc<HashMap<ObjectId, Arc<Option<JoinHandle<TaskResult>>>>>,
}

impl<TaskResult> Default for InFlightTracker<TaskResult>
where
    TaskResult: 'static + Send,
{
    fn default() -> Self {
        Self {
            permits: Semaphore::new(std::thread::available_parallelism().unwrap().get()).into(),
            active: Arc::default(),
        }
    }
}

impl<TaskResult> Clone for InFlightTracker<TaskResult>
where
    TaskResult: 'static + Send,
{
    fn clone(&self) -> Self {
        Self {
            permits: self.permits.clone(),
            active: self.active.clone(),
        }
    }
}

impl<TaskResult> InFlightTracker<TaskResult>
where
    TaskResult: 'static + Send,
{
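    /// Waits for every tracked task to finish and collects their results.
    ///
    /// Handles are drained from the map before being awaited; results of
    /// cancelled tasks are filtered out, while any other join error (for
    /// example a panic inside a task) is propagated to the caller.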
    pub fn complete_all(&self) -> std::result::Result<Vec<TaskResult>, JoinError> {
        block_on(async move {
            let mut handles = vec![];
            self.active
                .retain_async(|_, v| {
                    if let Some(handle) = std::mem::take(Arc::get_mut(v).unwrap()) {
                        handles.push(handle);
                    }

                    false
                })
                .await;

            futures::future::join_all(handles).await
        })
        .into_iter()
        .filter(|result| match result {
            Ok(_) => true,
            Err(e) => !e.is_cancelled(),
        })
        .collect::<std::result::Result<Vec<_>, _>>()
    }

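    /// Spawns `task` on the runtime and tracks it under `key`.
    ///
    /// The spawned task holds a semaphore permit for its lifetime, which
    /// bounds the number of concurrent uploads. If a task is already
    /// registered for the same `key`, it is aborted and replaced.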
    pub fn add_task<F: 'static + Send + Future<Output = TaskResult>>(
        &self,
        key: ObjectId,
        task: F,
    ) {
        let permits = self.permits.clone();
        let active = self.active.clone();

        block_on(async {
            let permit = permits.acquire_owned().await;

            let handle = Arc::new(Some(task::spawn(async move {
                let _permit = permit;
                let result = task.await;
                active.remove_async(&key).await;
                result
            })));

            match self.active.entry_async(key).await {
                scc::hash_map::Entry::Occupied(mut entry) => {
                    if let Some(handle) = entry.get().as_ref() {
                        handle.abort();
                    }

                    *entry.get_mut() = handle.clone();
                }
                scc::hash_map::Entry::Vacant(entry) => {
                    entry.insert_entry(handle);
                }
            }
        })
    }
}

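/// Infinitree [`Backend`] that stores objects in an S3-compatible bucket.
///
/// Writes are queued as background uploads through an [`InFlightTracker`];
/// call [`Backend::sync`] to wait for them to finish. Reads are blocking.
///
/// A minimal usage sketch (the endpoint and bucket name are illustrative,
/// and the import path depends on how the crate re-exports this module):
///
/// ```ignore
/// // Credentials are taken from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY.
/// let backend = S3::new("http://127.0.0.1:9000".parse().unwrap(), "my-bucket/prefix").unwrap();
/// ```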
#[derive(Clone)]
pub struct S3 {
    base_path: String,
    client: Client,
    bucket: Arc<Bucket>,
    credentials: Arc<Credentials>,
    in_flight: InFlightTracker<anyhow::Result<u16>>,
}

impl S3 {
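    /// Creates a backend for `bucket` in `region`, reading credentials from
    /// the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment
    /// variables.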
    pub fn new(region: Region, bucket: impl AsRef<str>) -> Result<Arc<Self>> {
        let access_key = env::var("AWS_ACCESS_KEY_ID").context("Invalid credentials")?;
        let secret_key = env::var("AWS_SECRET_ACCESS_KEY").context("Invalid credentials")?;

        let creds = Credentials::new(access_key, secret_key);
        Self::with_credentials(region, bucket, creds)
    }

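    /// Creates a backend for `bucket` in `region` using explicit credentials.
    ///
    /// The `bucket` argument may contain a path prefix, e.g. `"bucket/some/dir"`;
    /// everything after the first `/` becomes the base path under which
    /// objects are stored.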
    pub fn with_credentials(
        region: Region,
        bucket: impl AsRef<str>,
        creds: Credentials,
    ) -> Result<Arc<Self>> {
        let (bucket_name, base_path) = match bucket.as_ref().split_once('/') {
            Some((bucket, "")) => (bucket.to_string(), "".to_string()),
            Some((bucket, path)) => (bucket.to_string(), format!("{path}/")),
            None => (bucket.as_ref().to_string(), "".to_string()),
        };

        let bucket = Bucket::new(
            region.endpoint().parse().context("Invalid endpoint URL")?,
            if let Region::Custom { .. } = region {
                UrlStyle::Path
            } else {
                UrlStyle::VirtualHost
            },
            bucket_name,
            region.to_string(),
        )
        .context("Failed to configure S3 bucket")?
        .into();

        Ok(Self {
            bucket,
            base_path,
            client: reqwest::Client::new(),
            credentials: creds.into(),
            in_flight: InFlightTracker::default(),
        }
        .into())
    }

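    /// Returns the object's key within the bucket: the base path followed by
    /// the textual form of the [`ObjectId`].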
    fn get_path(&self, id: &ObjectId) -> String {
        // `base_path` is either empty or already ends with a "/"
        // (see `with_credentials`), so the id can be appended directly.
        format!("{}{}", self.base_path, id)
    }
}

impl Backend for S3 {
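    // Uploads are fire-and-forget: the object body is handed to a background
    // task that PUTs it through a presigned URL. Failures panic inside the
    // task and only surface as an error when `sync()` drains the tracker.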
    fn write_object(&self, object: &WriteObject) -> Result<()> {
        let body = object.as_inner().to_vec();
        let key = self.get_path(object.id());
        let id = *object.id();

        let this = self.clone();
        self.in_flight.add_task(id, async move {
            let url = this
                .bucket
                .put_object(Some(&this.credentials), &key)
                .sign(Duration::from_secs(30));

            let resp = this
                .client
                .put(url)
                .body(body)
                .send()
                .await
                .expect("Server error");

            let status_code = resp.status().as_u16();
            let resp_body = resp.bytes().await.expect("Response error");
            if (200..300).contains(&status_code) {
                Ok(status_code)
            } else {
                panic!(
                    "Bad response: {}, {}",
                    status_code,
                    String::from_utf8_lossy(resp_body.as_ref())
                )
            }
        });

        Ok(())
    }

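    // Reads are synchronous: the calling thread blocks on a presigned GET and
    // returns the full object body as a ReadObject.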
    fn read_object(&self, id: &ObjectId) -> Result<Arc<ReadObject>> {
        let this = self.clone();
        let object: Result<Vec<u8>> = {
            let key = self.get_path(id);

            block_on(async move {
                let url = this
                    .bucket
                    .get_object(Some(&this.credentials), &key)
                    .sign(Duration::from_secs(30));

                let resp = this.client.get(url).send().await.context("Query error")?;
                let status_code = resp.status().as_u16();
                let body = resp.bytes().await.context("Read error")?;

                if (200..300).contains(&status_code) {
                    Ok(body.to_vec())
                } else {
                    Err(anyhow::anyhow!(
                        "Bad response: {}, {}",
                        status_code,
                        String::from_utf8_lossy(body.as_ref())
                    )
                    .into())
                }
            })
        };

        Ok(Arc::new(Object::with_id(*id, ReadBuffer::new(object?))))
    }

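    // Blocks until every queued upload has completed, surfacing panics from
    // the background tasks as an error.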
    fn sync(&self) -> Result<()> {
        self.in_flight
            .complete_all()
            .context("Failed transactions with server")?;

        Ok(())
    }
}

#[cfg(test)]
mod test {
    use super::S3;
    use crate::test::{write_and_wait_for_commit, TEST_DATA_DIR};
    use hyper::server::conn::http1;
    use hyper_util::rt::TokioIo;
    use infinitree::{backends::Backend, object::WriteObject, ObjectId};
    use s3s::{auth::SimpleAuth, service::S3ServiceBuilder};
    use s3s_fs::FileSystem;
    use std::net::SocketAddr;
    use tokio::{net::TcpListener, task};

    const AWS_ACCESS_KEY_ID: &str = "MEEMIEW3EEKI8IEY1U";
    const AWS_SECRET_ACCESS_KEY_ID: &str = "noh8xah2thohv7laehei2lahBuno5FameiNi";

    const SERVER_ADDR_RW: ([u8; 4], u16) = ([127, 0, 0, 1], 12312);
    const SERVER_ADDR_RO: ([u8; 4], u16) = ([127, 0, 0, 1], 12313);

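    // Starts an in-process S3-compatible server (s3s + s3s-fs) backed by the
    // local test data directory, so the backend can be exercised without a
    // real AWS endpoint.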
    fn setup_s3_server(addr: &SocketAddr) {
        let fs = FileSystem::new(TEST_DATA_DIR).unwrap();
        let mut auth = SimpleAuth::new();
        std::env::set_var("AWS_ACCESS_KEY_ID", AWS_ACCESS_KEY_ID);
        std::env::set_var("AWS_SECRET_ACCESS_KEY", AWS_SECRET_ACCESS_KEY_ID);
        auth.register(AWS_ACCESS_KEY_ID.into(), AWS_SECRET_ACCESS_KEY_ID.into());

        let service = {
            let mut b = S3ServiceBuilder::new(fs);
            b.set_auth(auth);
            b.build().into_shared()
        };

        let server = {
            let addr = addr.clone();
            task::spawn(async move {
                let listener = TcpListener::bind(addr).await.unwrap();
                loop {
                    let (tcp, _) = listener.accept().await.unwrap();
                    let io = TokioIo::new(tcp);
                    let service = service.clone();

                    tokio::task::spawn(async move {
                        if let Err(err) = http1::Builder::new().serve_connection(io, service).await
                        {
                            println!("Error serving connection: {:?}", err);
                        }
                    });
                }
            })
        };

        let _server_handle = task::spawn(server);
    }

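    // Round-trips an object through the local server: write it, read it back,
    // then write it again under a second id.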
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn s3_write_read() {
        let addr = SocketAddr::from(SERVER_ADDR_RW);
        setup_s3_server(&addr);

        let backend = S3::new(format!("http://{addr}").parse().unwrap(), "bucket").unwrap();

        let mut object = WriteObject::default();
        let id_2 = ObjectId::from_bytes(b"1234567890abcdef1234567890abcdef");

        write_and_wait_for_commit(backend.as_ref(), &object);
        let _obj_1_read_ref = backend.read_object(object.id()).unwrap();

        object.set_id(id_2);
        write_and_wait_for_commit(backend.as_ref(), &object);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    #[should_panic(
        expected = r#"Generic { source: Bad response: 404, <?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code></Error> }"#
    )]
    async fn s3_reading_nonexistent_object() {
        let addr = SocketAddr::from(SERVER_ADDR_RO);
        setup_s3_server(&addr);

        let backend = S3::new(format!("http://{addr}").parse().unwrap(), "bucket").unwrap();

        let id = ObjectId::from_bytes(b"2222222222abcdef1234567890abcdef");

        let _obj_1_read_ref = backend.read_object(&id).unwrap();
    }
}