1use super::block_on;
2use anyhow::Context;
3use infinitree::{
4 backends::{Backend, Result},
5 object::{Object, ObjectId, ReadBuffer, ReadObject, WriteObject},
6};
7use reqwest::Client;
8pub use rusty_s3::{Bucket, Credentials};
9use rusty_s3::{S3Action, UrlStyle};
10use scc::HashMap;
11use std::{env, future::Future, sync::Arc, time::Duration};
12use tokio::{
13 sync::Semaphore,
14 task::{self, JoinError, JoinHandle},
15};
16
17mod region;
18pub use region::Region;
19
20struct InFlightTracker<TaskResult>
21where
22 TaskResult: 'static + Send,
23{
24 permits: Arc<Semaphore>,
25 active: Arc<HashMap<ObjectId, Arc<Option<JoinHandle<TaskResult>>>>>,
26}
27
28impl<TaskResult> Default for InFlightTracker<TaskResult>
29where
30 TaskResult: 'static + Send,
31{
32 fn default() -> Self {
33 Self {
34 permits: Semaphore::new(std::thread::available_parallelism().unwrap().get()).into(),
35 active: Arc::default(),
36 }
37 }
38}
39
40impl<TaskResult> Clone for InFlightTracker<TaskResult>
41where
42 TaskResult: 'static + Send,
43{
44 fn clone(&self) -> Self {
45 Self {
46 permits: self.permits.clone(),
47 active: self.active.clone(),
48 }
49 }
50}
51
52impl<TaskResult> InFlightTracker<TaskResult>
53where
54 TaskResult: 'static + Send,
55{
56 pub fn complete_all(&self) -> std::result::Result<Vec<TaskResult>, JoinError> {
57 block_on(async move {
58 let mut handles = vec![];
59 self.active
60 .retain_async(|_, v| {
61 if let Some(handle) = std::mem::take(Arc::get_mut(v).unwrap()) {
62 handles.push(handle);
63 }
64
65 false
66 })
67 .await;
68
69 futures::future::join_all(handles).await
70 })
71 .into_iter()
72 .filter(|result| match result {
73 Ok(_) => true,
74 Err(e) => !e.is_cancelled(),
75 })
76 .collect::<std::result::Result<Vec<_>, _>>()
77 }
78
79 pub fn add_task<F: 'static + Send + Future<Output = TaskResult>>(
80 &self,
81 key: ObjectId,
82 task: F,
83 ) {
84 let permits = self.permits.clone();
85 let active = self.active.clone();
86
87 block_on(async {
88 let permit = permits.acquire_owned().await;
89
90 let handle = Arc::new(Some(task::spawn(async move {
91 let _permit = permit;
92 let result = task.await;
93 active.remove_async(&key).await;
94 result
95 })));
96
97 match self.active.entry_async(key).await {
98 scc::hash_map::Entry::Occupied(mut entry) => {
99 if let Some(handle) = entry.get().as_ref() {
100 handle.abort();
101 }
102
103 *entry.get_mut() = handle.clone();
104 }
105 scc::hash_map::Entry::Vacant(entry) => {
106 entry.insert_entry(handle);
107 }
108 }
109 })
110 }
111}
112
113#[derive(Clone)]
114pub struct S3 {
115 base_path: String,
116 client: Client,
117 bucket: Arc<Bucket>,
118 credentials: Arc<Credentials>,
119 in_flight: InFlightTracker<anyhow::Result<u16>>,
120}
121
122impl S3 {
123 pub fn new(region: Region, bucket: impl AsRef<str>) -> Result<Arc<Self>> {
124 let access_key = env::var("AWS_ACCESS_KEY_ID").context("Invalid credentials")?;
125 let secret_key = env::var("AWS_SECRET_ACCESS_KEY").context("Invalid credentials")?;
126
127 let creds = Credentials::new(access_key, secret_key);
128 Self::with_credentials(region, bucket, creds)
129 }
130
131 pub fn with_credentials(
132 region: Region,
133 bucket: impl AsRef<str>,
134 creds: Credentials,
135 ) -> Result<Arc<Self>> {
136 let (bucket_name, base_path) = match bucket.as_ref().split_once('/') {
137 Some((bucket, "")) => (bucket.to_string(), "".to_string()),
138 Some((bucket, path)) => (bucket.to_string(), format!("{path}/")),
139 None => (bucket.as_ref().to_string(), "".to_string()),
140 };
141
142 let bucket = Bucket::new(
143 region.endpoint().parse().context("Invalid endpoint URL")?,
144 if let Region::Custom { .. } = region {
145 UrlStyle::Path
146 } else {
147 UrlStyle::VirtualHost
148 },
149 bucket_name,
150 region.to_string(),
151 )
152 .context("Failed to connect to S3 bucket")?
153 .into();
154
155 Ok(Self {
156 bucket,
157 base_path,
158 client: reqwest::Client::new(),
159 credentials: creds.into(),
160 in_flight: InFlightTracker::default(),
161 }
162 .into())
163 }
164
165 fn get_path(&self, id: &ObjectId) -> String {
166 format!("{}{}", &self.base_path, id.to_string())
168 }
169}
170
171impl Backend for S3 {
172 fn write_object(&self, object: &WriteObject) -> Result<()> {
173 let body = object.as_inner().to_vec();
174 let key = self.get_path(object.id());
175 let id = *object.id();
176
177 let this = self.clone();
178 self.in_flight.add_task(id, async move {
179 let url = this
180 .bucket
181 .put_object(Some(&this.credentials), &key)
182 .sign(Duration::from_secs(30));
183
184 let resp = this
185 .client
186 .put(url)
187 .body(body)
188 .send()
189 .await
190 .expect("Server error");
191
192 let status_code = resp.status().as_u16();
193 let resp_body = resp.bytes().await.expect("Response error");
194 if (200..300).contains(&status_code) {
195 Ok(status_code)
196 } else {
197 panic!(
198 "Bad response: {}, {}",
199 status_code,
200 String::from_utf8_lossy(resp_body.as_ref())
201 )
202 }
203 });
204
205 Ok(())
206 }
207
208 fn read_object(&self, id: &ObjectId) -> Result<Arc<ReadObject>> {
209 let this = self.clone();
210 let object: Result<Vec<u8>> = {
211 let key = self.get_path(id);
212
213 block_on(async move {
214 let url = this
215 .bucket
216 .get_object(Some(&this.credentials), &key)
217 .sign(Duration::from_secs(30));
218
219 let resp = this.client.get(url).send().await.context("Query error")?;
220 let status_code = resp.status().as_u16();
221 let body = resp.bytes().await.context("Read error")?;
222
223 if (200..300).contains(&status_code) {
224 Ok(body.to_vec())
225 } else {
226 Err(anyhow::anyhow!(
227 "Bad response: {}, {}",
228 status_code,
229 String::from_utf8_lossy(body.as_ref())
230 )
231 .into())
232 }
233 })
234 };
235
236 Ok(Arc::new(Object::with_id(*id, ReadBuffer::new(object?))))
237 }
238
239 fn sync(&self) -> Result<()> {
240 self.in_flight
241 .complete_all()
242 .context("Failed transactions with server")?;
243
244 Ok(())
245 }
246}
247
#[cfg(test)]
mod test {
    use super::S3;
    use crate::test::{write_and_wait_for_commit, TEST_DATA_DIR};
    use hyper::server::conn::http1;
    use hyper_util::rt::TokioIo;
    use infinitree::{backends::Backend, object::WriteObject, ObjectId};
    use s3s::{auth::SimpleAuth, service::S3ServiceBuilder};
    use s3s_fs::FileSystem;
    use std::net::SocketAddr;
    use tokio::{net::TcpListener, task};

    const AWS_ACCESS_KEY_ID: &str = "MEEMIEW3EEKI8IEY1U";
    const AWS_SECRET_ACCESS_KEY_ID: &str = "noh8xah2thohv7laehei2lahBuno5FameiNi";

    const SERVER_ADDR_RW: ([u8; 4], u16) = ([127, 0, 0, 1], 12312);
    const SERVER_ADDR_RO: ([u8; 4], u16) = ([127, 0, 0, 1], 12313);

    /// Starts an S3-compatible server backed by `TEST_DATA_DIR` on `addr`.
    ///
    /// Also exports the matching credentials through the environment for
    /// `S3::new`. The listener is bound before this returns, so clients may
    /// connect immediately.
    async fn setup_s3_server(addr: &SocketAddr) {
        let fs = FileSystem::new(TEST_DATA_DIR).unwrap();
        let mut auth = SimpleAuth::new();
        // Every caller in this module exports identical values, so tests
        // running concurrently still observe the same credentials.
        std::env::set_var("AWS_ACCESS_KEY_ID", AWS_ACCESS_KEY_ID);
        std::env::set_var("AWS_SECRET_ACCESS_KEY", AWS_SECRET_ACCESS_KEY_ID);
        auth.register(AWS_ACCESS_KEY_ID.into(), AWS_SECRET_ACCESS_KEY_ID.into());

        let service = {
            let mut b = S3ServiceBuilder::new(fs);
            b.set_auth(auth);
            b.build().into_shared()
        };

        // Bind before returning. Binding inside the spawned task let the first
        // client request race the bind and be refused.
        let listener = TcpListener::bind(*addr).await.unwrap();

        task::spawn(async move {
            loop {
                let (tcp, _) = listener.accept().await.unwrap();
                let io = TokioIo::new(tcp);
                let service = service.clone();

                task::spawn(async move {
                    if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
                        println!("Error serving connection: {:?}", err);
                    }
                });
            }
        });
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn s3_write_read() {
        let addr = SocketAddr::from(SERVER_ADDR_RW);
        setup_s3_server(&addr).await;

        let backend = S3::new(format!("http://{addr}").parse().unwrap(), "bucket").unwrap();

        let mut object = WriteObject::default();
        let id_2 = ObjectId::from_bytes(b"1234567890abcdef1234567890abcdef");

        write_and_wait_for_commit(backend.as_ref(), &object);
        let _obj_1_read_ref = backend.read_object(object.id()).unwrap();

        object.set_id(id_2);
        write_and_wait_for_commit(backend.as_ref(), &object);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    #[should_panic(
        expected = r#"Generic { source: Bad response: 404, <?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code></Error> }"#
    )]
    async fn s3_reading_nonexistent_object() {
        let addr = SocketAddr::from(SERVER_ADDR_RO);
        setup_s3_server(&addr).await;

        let backend = S3::new(format!("http://{addr}").parse().unwrap(), "bucket").unwrap();

        let id = ObjectId::from_bytes(b"2222222222abcdef1234567890abcdef");

        let _obj_1_read_ref = backend.read_object(&id).unwrap();
    }
}