object_store/
limit.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! An object store that limits the maximum concurrency of the wrapped implementation
19
20use crate::{
21    BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta,
22    ObjectStore, Path, PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, StreamExt,
23    UploadPart,
24};
25use async_trait::async_trait;
26use bytes::Bytes;
27use futures::{FutureExt, Stream};
28use std::ops::Range;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use tokio::sync::{OwnedSemaphorePermit, Semaphore};
33
/// Store wrapper that wraps an inner store and limits the maximum number of concurrent
/// object store operations. Each call to an [`ObjectStore`] member function is
/// considered a single operation, even if it may result in more than one network call.
///
/// Operations that return a stream (`get`, `get_opts`, `list`, `list_with_offset`)
/// keep their permit until the returned stream is dropped.
///
/// Not every call takes a permit: `delete_stream` is forwarded to the inner store
/// as-is, and `put_multipart` / `put_multipart_opts` do not take a permit for the call
/// that creates the upload. The upload they return shares this store's semaphore and
/// acquires a permit for each part, for `complete` and for `abort`.
///
/// ```
/// # use object_store::memory::InMemory;
/// # use object_store::limit::LimitStore;
///
/// // Create an in-memory `ObjectStore` limited to 20 concurrent requests
/// let store = LimitStore::new(InMemory::new(), 20);
/// ```
///
#[derive(Debug)]
pub struct LimitStore<T: ObjectStore> {
    /// The wrapped store. Kept behind an [`Arc`] so that the `'static` listing
    /// streams can hold their own handle to it.
    inner: Arc<T>,
    /// The limit passed to [`LimitStore::new`]; reported by the `Display` output.
    max_requests: usize,
    /// One permit per in-flight operation; created with `max_requests` permits.
    semaphore: Arc<Semaphore>,
}
52
53impl<T: ObjectStore> LimitStore<T> {
54    /// Create new limit store that will limit the maximum
55    /// number of outstanding concurrent requests to
56    /// `max_requests`
57    pub fn new(inner: T, max_requests: usize) -> Self {
58        Self {
59            inner: Arc::new(inner),
60            max_requests,
61            semaphore: Arc::new(Semaphore::new(max_requests)),
62        }
63    }
64}
65
66impl<T: ObjectStore> std::fmt::Display for LimitStore<T> {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        write!(f, "LimitStore({}, {})", self.max_requests, self.inner)
69    }
70}
71
72#[async_trait]
73impl<T: ObjectStore> ObjectStore for LimitStore<T> {
74    async fn put(&self, location: &Path, payload: PutPayload) -> Result<PutResult> {
75        let _permit = self.semaphore.acquire().await.unwrap();
76        self.inner.put(location, payload).await
77    }
78
79    async fn put_opts(
80        &self,
81        location: &Path,
82        payload: PutPayload,
83        opts: PutOptions,
84    ) -> Result<PutResult> {
85        let _permit = self.semaphore.acquire().await.unwrap();
86        self.inner.put_opts(location, payload, opts).await
87    }
88    async fn put_multipart(&self, location: &Path) -> Result<Box<dyn MultipartUpload>> {
89        let upload = self.inner.put_multipart(location).await?;
90        Ok(Box::new(LimitUpload {
91            semaphore: Arc::clone(&self.semaphore),
92            upload,
93        }))
94    }
95
96    async fn put_multipart_opts(
97        &self,
98        location: &Path,
99        opts: PutMultipartOptions,
100    ) -> Result<Box<dyn MultipartUpload>> {
101        let upload = self.inner.put_multipart_opts(location, opts).await?;
102        Ok(Box::new(LimitUpload {
103            semaphore: Arc::clone(&self.semaphore),
104            upload,
105        }))
106    }
107
108    async fn get(&self, location: &Path) -> Result<GetResult> {
109        let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
110        let r = self.inner.get(location).await?;
111        Ok(permit_get_result(r, permit))
112    }
113
114    async fn get_opts(&self, location: &Path, options: GetOptions) -> Result<GetResult> {
115        let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
116        let r = self.inner.get_opts(location, options).await?;
117        Ok(permit_get_result(r, permit))
118    }
119
120    async fn get_range(&self, location: &Path, range: Range<u64>) -> Result<Bytes> {
121        let _permit = self.semaphore.acquire().await.unwrap();
122        self.inner.get_range(location, range).await
123    }
124
125    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> Result<Vec<Bytes>> {
126        let _permit = self.semaphore.acquire().await.unwrap();
127        self.inner.get_ranges(location, ranges).await
128    }
129
130    async fn head(&self, location: &Path) -> Result<ObjectMeta> {
131        let _permit = self.semaphore.acquire().await.unwrap();
132        self.inner.head(location).await
133    }
134
135    async fn delete(&self, location: &Path) -> Result<()> {
136        let _permit = self.semaphore.acquire().await.unwrap();
137        self.inner.delete(location).await
138    }
139
140    fn delete_stream<'a>(
141        &'a self,
142        locations: BoxStream<'a, Result<Path>>,
143    ) -> BoxStream<'a, Result<Path>> {
144        self.inner.delete_stream(locations)
145    }
146
147    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result<ObjectMeta>> {
148        let prefix = prefix.cloned();
149        let inner = Arc::clone(&self.inner);
150        let fut = Arc::clone(&self.semaphore)
151            .acquire_owned()
152            .map(move |permit| {
153                let s = inner.list(prefix.as_ref());
154                PermitWrapper::new(s, permit.unwrap())
155            });
156        fut.into_stream().flatten().boxed()
157    }
158
159    fn list_with_offset(
160        &self,
161        prefix: Option<&Path>,
162        offset: &Path,
163    ) -> BoxStream<'static, Result<ObjectMeta>> {
164        let prefix = prefix.cloned();
165        let offset = offset.clone();
166        let inner = Arc::clone(&self.inner);
167        let fut = Arc::clone(&self.semaphore)
168            .acquire_owned()
169            .map(move |permit| {
170                let s = inner.list_with_offset(prefix.as_ref(), &offset);
171                PermitWrapper::new(s, permit.unwrap())
172            });
173        fut.into_stream().flatten().boxed()
174    }
175
176    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
177        let _permit = self.semaphore.acquire().await.unwrap();
178        self.inner.list_with_delimiter(prefix).await
179    }
180
181    async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
182        let _permit = self.semaphore.acquire().await.unwrap();
183        self.inner.copy(from, to).await
184    }
185
186    async fn rename(&self, from: &Path, to: &Path) -> Result<()> {
187        let _permit = self.semaphore.acquire().await.unwrap();
188        self.inner.rename(from, to).await
189    }
190
191    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
192        let _permit = self.semaphore.acquire().await.unwrap();
193        self.inner.copy_if_not_exists(from, to).await
194    }
195
196    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
197        let _permit = self.semaphore.acquire().await.unwrap();
198        self.inner.rename_if_not_exists(from, to).await
199    }
200}
201
202fn permit_get_result(r: GetResult, permit: OwnedSemaphorePermit) -> GetResult {
203    let payload = match r.payload {
204        #[cfg(all(feature = "fs", not(target_arch = "wasm32")))]
205        v @ GetResultPayload::File(_, _) => v,
206        GetResultPayload::Stream(s) => {
207            GetResultPayload::Stream(PermitWrapper::new(s, permit).boxed())
208        }
209    };
210    GetResult { payload, ..r }
211}
212
/// Combines an [`OwnedSemaphorePermit`] with some other type
///
/// The permit is never read. It is stored so that it lives exactly as long as
/// `inner` and is given up when the wrapper is dropped.
struct PermitWrapper<T> {
    /// The wrapped value; when it is a [`Stream`], polling is forwarded to it.
    inner: T,
    // Held only for its lifetime and never read, hence the lint allowance.
    #[allow(dead_code)]
    permit: OwnedSemaphorePermit,
}
219
220impl<T> PermitWrapper<T> {
221    fn new(inner: T, permit: OwnedSemaphorePermit) -> Self {
222        Self { inner, permit }
223    }
224}
225
226impl<T: Stream + Unpin> Stream for PermitWrapper<T> {
227    type Item = T::Item;
228
229    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
230        Pin::new(&mut self.inner).poll_next(cx)
231    }
232
233    fn size_hint(&self) -> (usize, Option<usize>) {
234        self.inner.size_hint()
235    }
236}
237
/// A [`MultipartUpload`] wrapper that limits the maximum number of concurrent requests
///
/// A permit is acquired for each uploaded part, for `complete` and for `abort`.
#[derive(Debug)]
pub struct LimitUpload {
    /// The upload every call is forwarded to.
    upload: Box<dyn MultipartUpload>,
    /// Source of permits. It is either private to this upload (see
    /// [`LimitUpload::new`]) or shared with the [`LimitStore`] that created it.
    semaphore: Arc<Semaphore>,
}
244
245impl LimitUpload {
246    /// Create a new [`LimitUpload`] limiting `upload` to `max_concurrency` concurrent requests
247    pub fn new(upload: Box<dyn MultipartUpload>, max_concurrency: usize) -> Self {
248        Self {
249            upload,
250            semaphore: Arc::new(Semaphore::new(max_concurrency)),
251        }
252    }
253}
254
255#[async_trait]
256impl MultipartUpload for LimitUpload {
257    fn put_part(&mut self, data: PutPayload) -> UploadPart {
258        let upload = self.upload.put_part(data);
259        let s = Arc::clone(&self.semaphore);
260        Box::pin(async move {
261            let _permit = s.acquire().await.unwrap();
262            upload.await
263        })
264    }
265
266    async fn complete(&mut self) -> Result<PutResult> {
267        let _permit = self.semaphore.acquire().await.unwrap();
268        self.upload.complete().await
269    }
270
271    async fn abort(&mut self) -> Result<()> {
272        let _permit = self.semaphore.acquire().await.unwrap();
273        self.upload.abort().await
274    }
275}
276
#[cfg(test)]
mod tests {
    use crate::integration::*;
    use crate::limit::LimitStore;
    use crate::memory::InMemory;
    use crate::ObjectStore;
    use futures::stream::StreamExt;
    use std::pin::Pin;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Runs the shared integration suite through a [`LimitStore`], then checks
    /// that the limit is enforced and that dropping a stream frees its permit.
    #[tokio::test]
    async fn limit_test() {
        let max_requests = 10;
        let store = LimitStore::new(InMemory::new(), max_requests);

        put_get_delete_list(&store).await;
        get_opts(&store).await;
        list_uses_directories_correctly(&store).await;
        list_with_delimiter(&store).await;
        rename_and_copy(&store).await;
        stream_get(&store).await;

        // Use up every permit with a listing stream that is kept alive.
        let mut held = Vec::with_capacity(max_requests);
        for _ in 0..max_requests {
            let mut listing = store.list(None).peekable();
            // Peeking polls the stream, which ensures the semaphore is acquired
            Pin::new(&mut listing).peek().await;
            held.push(listing);
        }

        // With no permit left, one more request must not finish within the timeout
        let blocked = store.list(None).collect::<Vec<_>>();
        let outcome = timeout(Duration::from_millis(20), blocked).await;
        assert!(outcome.is_err());

        // Dropping one of the held streams hands its permit back
        drop(held.pop());

        // Can now make another request
        store.list(None).collect::<Vec<_>>().await;
    }
}