1use crate::{
21 BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta,
22 ObjectStore, Path, PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, StreamExt,
23 UploadPart,
24};
25use async_trait::async_trait;
26use bytes::Bytes;
27use futures::{FutureExt, Stream};
28use std::ops::Range;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use tokio::sync::{OwnedSemaphorePermit, Semaphore};
33
/// An [`ObjectStore`] wrapper that pairs an inner store with a semaphore
/// created with `max_requests` permits.
///
/// Most operations in the [`ObjectStore`] implementation for this type acquire
/// one permit before forwarding to the inner store, so at most `max_requests`
/// of those operations can be in flight at once.
#[derive(Debug)]
pub struct LimitStore<T: ObjectStore> {
    /// The wrapped store. Held in an `Arc` so the `'static` list streams can
    /// own a handle to it.
    inner: Arc<T>,
    /// The permit count the semaphore was created with; reported by `Display`.
    max_requests: usize,
    /// Gate shared by every request, stream and multipart upload created
    /// through this store.
    semaphore: Arc<Semaphore>,
}
52
53impl<T: ObjectStore> LimitStore<T> {
54 pub fn new(inner: T, max_requests: usize) -> Self {
58 Self {
59 inner: Arc::new(inner),
60 max_requests,
61 semaphore: Arc::new(Semaphore::new(max_requests)),
62 }
63 }
64}
65
66impl<T: ObjectStore> std::fmt::Display for LimitStore<T> {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(f, "LimitStore({}, {})", self.max_requests, self.inner)
69 }
70}
71
72#[async_trait]
73impl<T: ObjectStore> ObjectStore for LimitStore<T> {
74 async fn put(&self, location: &Path, payload: PutPayload) -> Result<PutResult> {
75 let _permit = self.semaphore.acquire().await.unwrap();
76 self.inner.put(location, payload).await
77 }
78
79 async fn put_opts(
80 &self,
81 location: &Path,
82 payload: PutPayload,
83 opts: PutOptions,
84 ) -> Result<PutResult> {
85 let _permit = self.semaphore.acquire().await.unwrap();
86 self.inner.put_opts(location, payload, opts).await
87 }
88 async fn put_multipart(&self, location: &Path) -> Result<Box<dyn MultipartUpload>> {
89 let upload = self.inner.put_multipart(location).await?;
90 Ok(Box::new(LimitUpload {
91 semaphore: Arc::clone(&self.semaphore),
92 upload,
93 }))
94 }
95
96 async fn put_multipart_opts(
97 &self,
98 location: &Path,
99 opts: PutMultipartOptions,
100 ) -> Result<Box<dyn MultipartUpload>> {
101 let upload = self.inner.put_multipart_opts(location, opts).await?;
102 Ok(Box::new(LimitUpload {
103 semaphore: Arc::clone(&self.semaphore),
104 upload,
105 }))
106 }
107
108 async fn get(&self, location: &Path) -> Result<GetResult> {
109 let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
110 let r = self.inner.get(location).await?;
111 Ok(permit_get_result(r, permit))
112 }
113
114 async fn get_opts(&self, location: &Path, options: GetOptions) -> Result<GetResult> {
115 let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
116 let r = self.inner.get_opts(location, options).await?;
117 Ok(permit_get_result(r, permit))
118 }
119
120 async fn get_range(&self, location: &Path, range: Range<u64>) -> Result<Bytes> {
121 let _permit = self.semaphore.acquire().await.unwrap();
122 self.inner.get_range(location, range).await
123 }
124
125 async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> Result<Vec<Bytes>> {
126 let _permit = self.semaphore.acquire().await.unwrap();
127 self.inner.get_ranges(location, ranges).await
128 }
129
130 async fn head(&self, location: &Path) -> Result<ObjectMeta> {
131 let _permit = self.semaphore.acquire().await.unwrap();
132 self.inner.head(location).await
133 }
134
135 async fn delete(&self, location: &Path) -> Result<()> {
136 let _permit = self.semaphore.acquire().await.unwrap();
137 self.inner.delete(location).await
138 }
139
140 fn delete_stream<'a>(
141 &'a self,
142 locations: BoxStream<'a, Result<Path>>,
143 ) -> BoxStream<'a, Result<Path>> {
144 self.inner.delete_stream(locations)
145 }
146
147 fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result<ObjectMeta>> {
148 let prefix = prefix.cloned();
149 let inner = Arc::clone(&self.inner);
150 let fut = Arc::clone(&self.semaphore)
151 .acquire_owned()
152 .map(move |permit| {
153 let s = inner.list(prefix.as_ref());
154 PermitWrapper::new(s, permit.unwrap())
155 });
156 fut.into_stream().flatten().boxed()
157 }
158
159 fn list_with_offset(
160 &self,
161 prefix: Option<&Path>,
162 offset: &Path,
163 ) -> BoxStream<'static, Result<ObjectMeta>> {
164 let prefix = prefix.cloned();
165 let offset = offset.clone();
166 let inner = Arc::clone(&self.inner);
167 let fut = Arc::clone(&self.semaphore)
168 .acquire_owned()
169 .map(move |permit| {
170 let s = inner.list_with_offset(prefix.as_ref(), &offset);
171 PermitWrapper::new(s, permit.unwrap())
172 });
173 fut.into_stream().flatten().boxed()
174 }
175
176 async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
177 let _permit = self.semaphore.acquire().await.unwrap();
178 self.inner.list_with_delimiter(prefix).await
179 }
180
181 async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
182 let _permit = self.semaphore.acquire().await.unwrap();
183 self.inner.copy(from, to).await
184 }
185
186 async fn rename(&self, from: &Path, to: &Path) -> Result<()> {
187 let _permit = self.semaphore.acquire().await.unwrap();
188 self.inner.rename(from, to).await
189 }
190
191 async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
192 let _permit = self.semaphore.acquire().await.unwrap();
193 self.inner.copy_if_not_exists(from, to).await
194 }
195
196 async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
197 let _permit = self.semaphore.acquire().await.unwrap();
198 self.inner.rename_if_not_exists(from, to).await
199 }
200}
201
202fn permit_get_result(r: GetResult, permit: OwnedSemaphorePermit) -> GetResult {
203 let payload = match r.payload {
204 #[cfg(all(feature = "fs", not(target_arch = "wasm32")))]
205 v @ GetResultPayload::File(_, _) => v,
206 GetResultPayload::Stream(s) => {
207 GetResultPayload::Stream(PermitWrapper::new(s, permit).boxed())
208 }
209 };
210 GetResult { payload, ..r }
211}
212
/// Combines a value with an [`OwnedSemaphorePermit`] so that both are dropped
/// together.
struct PermitWrapper<T> {
    /// The wrapped value; the `Stream` implementation delegates to it.
    inner: T,
    // Never read: it is stored only so the permit lives exactly as long as
    // the wrapper, hence the `dead_code` allowance.
    #[allow(dead_code)]
    permit: OwnedSemaphorePermit,
}
219
220impl<T> PermitWrapper<T> {
221 fn new(inner: T, permit: OwnedSemaphorePermit) -> Self {
222 Self { inner, permit }
223 }
224}
225
226impl<T: Stream + Unpin> Stream for PermitWrapper<T> {
227 type Item = T::Item;
228
229 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
230 Pin::new(&mut self.inner).poll_next(cx)
231 }
232
233 fn size_hint(&self) -> (usize, Option<usize>) {
234 self.inner.size_hint()
235 }
236}
237
/// A [`MultipartUpload`] wrapper whose part uploads, `complete` and `abort`
/// calls each acquire a semaphore permit before running against the wrapped
/// upload.
#[derive(Debug)]
pub struct LimitUpload {
    /// The upload every operation is forwarded to.
    upload: Box<dyn MultipartUpload>,
    /// Gate consulted before each forwarded operation.
    semaphore: Arc<Semaphore>,
}
244
245impl LimitUpload {
246 pub fn new(upload: Box<dyn MultipartUpload>, max_concurrency: usize) -> Self {
248 Self {
249 upload,
250 semaphore: Arc::new(Semaphore::new(max_concurrency)),
251 }
252 }
253}
254
255#[async_trait]
256impl MultipartUpload for LimitUpload {
257 fn put_part(&mut self, data: PutPayload) -> UploadPart {
258 let upload = self.upload.put_part(data);
259 let s = Arc::clone(&self.semaphore);
260 Box::pin(async move {
261 let _permit = s.acquire().await.unwrap();
262 upload.await
263 })
264 }
265
266 async fn complete(&mut self) -> Result<PutResult> {
267 let _permit = self.semaphore.acquire().await.unwrap();
268 self.upload.complete().await
269 }
270
271 async fn abort(&mut self) -> Result<()> {
272 let _permit = self.semaphore.acquire().await.unwrap();
273 self.upload.abort().await
274 }
275}
276
#[cfg(test)]
mod tests {
    use crate::integration::*;
    use crate::limit::LimitStore;
    use crate::memory::InMemory;
    use crate::ObjectStore;
    use futures::stream::StreamExt;
    use std::pin::Pin;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Runs the shared integration suite against a `LimitStore`, then checks
    /// that an extra listing stalls while `max_requests` listings are open.
    #[tokio::test]
    async fn limit_test() {
        let max_requests = 10;
        let store = LimitStore::new(InMemory::new(), max_requests);

        put_get_delete_list(&store).await;
        get_opts(&store).await;
        list_uses_directories_correctly(&store).await;
        list_with_delimiter(&store).await;
        rename_and_copy(&store).await;
        stream_get(&store).await;

        // Open `max_requests` listings and keep them alive.
        let mut open_listings = Vec::with_capacity(max_requests);
        for _ in 0..max_requests {
            let mut listing = store.list(None).peekable();
            // The permit is only taken once the stream is polled, so peek to
            // force the acquisition before parking the stream.
            Pin::new(&mut listing).peek().await;
            open_listings.push(listing);
        }

        let wait = Duration::from_millis(20);

        // With the open listings holding the permits, one more listing is
        // expected to stall until the timeout fires.
        let blocked = store.list(None).collect::<Vec<_>>();
        assert!(timeout(wait, blocked).await.is_err());

        // Dropping one open listing hands its permit back...
        drop(open_listings.pop());

        // ...so a fresh listing can now run to completion.
        store.list(None).collect::<Vec<_>>().await;
    }
}