Skip to main content

opendal_layer_concurrent_limit/
lib.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Concurrent request limit layer implementation for Apache OpenDAL.
19
20#![cfg_attr(docsrs, feature(doc_cfg))]
21#![deny(missing_docs)]
22
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::Arc;
26use std::task::Context;
27use std::task::Poll;
28
29use futures::Stream;
30use futures::StreamExt;
31use mea::semaphore::OwnedSemaphorePermit;
32use mea::semaphore::Semaphore;
33use opendal_core::raw::*;
34use opendal_core::*;
35
/// ConcurrentLimitSemaphore abstracts a semaphore-like concurrency primitive
/// that yields an owned permit released on drop.
///
/// Implement this trait to plug a custom limiter into [`ConcurrentLimitLayer`]
/// via [`ConcurrentLimitLayer::with_semaphore`] /
/// [`ConcurrentLimitLayer::with_http_semaphore`].
///
/// The `Clone` bound exists because the layer hands a clone of the semaphore
/// to every accessor and HTTP fetcher it builds. For the limit to actually be
/// shared, clones should refer to the same underlying pool of permits (as
/// `Arc<Semaphore>` does).
pub trait ConcurrentLimitSemaphore: Send + Sync + Clone + Unpin + 'static {
    /// The owned permit type associated with the semaphore. Dropping it
    /// must release the permit back to the semaphore.
    ///
    /// Note: the `Layer`, `LayeredAccess` and `HttpFetch` impls in this crate
    /// additionally require `Self::Permit: Unpin`.
    type Permit: Send + Sync + 'static;

    /// Acquire an owned permit asynchronously.
    ///
    /// The layer awaits this once per intercepted operation or HTTP request.
    /// The built-in `Arc<Semaphore>` implementation takes exactly one permit
    /// per call.
    fn acquire(&self) -> impl Future<Output = Self::Permit> + MaybeSend;
}
46
47impl ConcurrentLimitSemaphore for Arc<Semaphore> {
48    type Permit = OwnedSemaphorePermit;
49
50    async fn acquire(&self) -> Self::Permit {
51        self.clone().acquire_owned(1).await
52    }
53}
54
/// Add concurrent request limit.
///
/// # Notes
///
/// Users can control how many operations may run concurrently against the
/// underlying storage service and, optionally, how many HTTP requests may be
/// in flight at the same time.
///
/// ## Permit lifetime
///
/// Every intercepted operation acquires a permit from the operation semaphore:
///
/// - `create_dir`, `stat` and `rename` hold it only for the duration of the
///   call.
/// - `read`, `write`, `list`, `delete` and `copy` move it into the returned
///   reader / writer / lister / deleter / copier, so it is released only when
///   that handle is dropped. Long-lived handles therefore count against the
///   limit for as long as they are alive.
///
/// When an HTTP limit is configured (see
/// [`ConcurrentLimitLayer::with_http_concurrent_limit`] and
/// [`ConcurrentLimitLayer::with_http_semaphore`]), each HTTP request
/// additionally holds an HTTP permit until its response body is dropped.
///
/// ## Sharing
///
/// All operators wrapped by this layer will share a common semaphore. This
/// allows you to reuse the same layer across multiple operators, ensuring
/// that the total number of concurrent requests across the entire
/// application does not exceed the limit. (This holds for the default
/// `Arc<Semaphore>` backend, whose clones share one permit pool; a custom
/// [`ConcurrentLimitSemaphore`] must provide the same clone semantics.)
///
/// # Examples
///
/// Add a concurrent limit layer to the operator:
///
/// ```no_run
/// # use opendal_core::services;
/// # use opendal_core::Operator;
/// # use opendal_core::Result;
/// # use opendal_layer_concurrent_limit::ConcurrentLimitLayer;
/// #
/// # fn main() -> Result<()> {
/// let _ = Operator::new(services::Memory::default())?
///     .layer(ConcurrentLimitLayer::new(1024))
///     .finish();
/// # Ok(())
/// # }
/// ```
///
/// Share a concurrent limit layer between the operators:
///
/// ```no_run
/// # use opendal_core::services;
/// # use opendal_core::Operator;
/// # use opendal_core::Result;
/// # use opendal_layer_concurrent_limit::ConcurrentLimitLayer;
/// #
/// # fn main() -> Result<()> {
/// let limit = ConcurrentLimitLayer::new(1024);
///
/// let _operator_a = Operator::new(services::Memory::default())?
///     .layer(limit.clone())
///     .finish();
/// let _operator_b = Operator::new(services::Memory::default())?
///     .layer(limit.clone())
///     .finish();
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct ConcurrentLimitLayer<S: ConcurrentLimitSemaphore = Arc<Semaphore>> {
    /// Semaphore limiting concurrent operations (and their returned handles).
    operation_semaphore: S,
    /// Optional semaphore limiting concurrent HTTP requests; `None` means
    /// HTTP requests are not limited by this layer.
    http_semaphore: Option<S>,
}
110
111impl ConcurrentLimitLayer<Arc<Semaphore>> {
112    /// Create a new `ConcurrentLimitLayer` with the specified number of
113    /// permits.
114    ///
115    /// These permits will be applied to all operations.
116    pub fn new(permits: usize) -> Self {
117        Self::with_semaphore(Arc::new(Semaphore::new(permits)))
118    }
119
120    /// Set a concurrent limit for HTTP requests.
121    ///
122    /// This convenience helper constructs a new semaphore with the specified
123    /// number of permits and calls [`ConcurrentLimitLayer::with_http_semaphore`].
124    /// Use [`ConcurrentLimitLayer::with_http_semaphore`] directly when reusing
125    /// a shared semaphore.
126    pub fn with_http_concurrent_limit(self, permits: usize) -> Self {
127        self.with_http_semaphore(Arc::new(Semaphore::new(permits)))
128    }
129}
130
131impl<S: ConcurrentLimitSemaphore> ConcurrentLimitLayer<S> {
132    /// Create a layer with any ConcurrentLimitSemaphore implementation.
133    ///
134    /// ```
135    /// # use std::sync::Arc;
136    /// # use mea::semaphore::Semaphore;
137    /// # use opendal_layer_concurrent_limit::ConcurrentLimitLayer;
138    /// let semaphore = Arc::new(Semaphore::new(1024));
139    /// let _layer = ConcurrentLimitLayer::with_semaphore(semaphore);
140    /// ```
141    pub fn with_semaphore(operation_semaphore: S) -> Self {
142        Self {
143            operation_semaphore,
144            http_semaphore: None,
145        }
146    }
147
148    /// Provide a custom HTTP concurrency semaphore instance.
149    pub fn with_http_semaphore(mut self, semaphore: S) -> Self {
150        self.http_semaphore = Some(semaphore);
151        self
152    }
153}
154
155impl<A: Access, S: ConcurrentLimitSemaphore> Layer<A> for ConcurrentLimitLayer<S>
156where
157    S::Permit: Unpin,
158{
159    type LayeredAccess = ConcurrentLimitAccessor<A, S>;
160
161    fn layer(&self, inner: A) -> Self::LayeredAccess {
162        let info = inner.info();
163
164        // Update http client with concurrent limit http fetcher.
165        info.update_http_client(|client| {
166            HttpClient::with(ConcurrentLimitHttpFetcher::<S> {
167                inner: client.into_inner(),
168                http_semaphore: self.http_semaphore.clone(),
169            })
170        });
171
172        ConcurrentLimitAccessor {
173            inner,
174            semaphore: self.operation_semaphore.clone(),
175        }
176    }
177}
178
/// HTTP fetcher installed by [`ConcurrentLimitLayer`] around an accessor's
/// existing HTTP fetcher to bound the number of in-flight HTTP requests.
#[doc(hidden)]
pub struct ConcurrentLimitHttpFetcher<S: ConcurrentLimitSemaphore> {
    /// The previously configured fetcher that requests are forwarded to.
    inner: HttpFetcher,
    /// Semaphore guarding HTTP requests; with `None`, requests are forwarded
    /// without acquiring any permit.
    http_semaphore: Option<S>,
}
184
185impl<S: ConcurrentLimitSemaphore> HttpFetch for ConcurrentLimitHttpFetcher<S>
186where
187    S::Permit: Unpin,
188{
189    async fn fetch(&self, req: http::Request<Buffer>) -> Result<http::Response<HttpBody>> {
190        let Some(semaphore) = self.http_semaphore.clone() else {
191            return self.inner.fetch(req).await;
192        };
193
194        let permit = semaphore.acquire().await;
195
196        let resp = self.inner.fetch(req).await?;
197        let (parts, body) = resp.into_parts();
198        let body = body.map_inner(|s| {
199            Box::new(ConcurrentLimitStream::<_, S::Permit> {
200                inner: s,
201                _permit: permit,
202            })
203        });
204        Ok(http::Response::from_parts(parts, body))
205    }
206}
207
/// HTTP response-body stream that keeps an HTTP permit alive.
///
/// Built by `ConcurrentLimitHttpFetcher::fetch`. It yields the items of
/// `inner` unchanged and releases the permit when the stream is dropped.
struct ConcurrentLimitStream<S, P> {
    // Field order matters: fields drop in declaration order, so the inner
    // body stream is torn down before the permit is released.
    inner: S,
    // Hold on this permit until this stream has been dropped.
    _permit: P,
}
213
214impl<S, P> Stream for ConcurrentLimitStream<S, P>
215where
216    S: Stream<Item = Result<Buffer>> + Unpin + 'static,
217    P: Unpin,
218{
219    type Item = Result<Buffer>;
220
221    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222        // Safe due to Unpin bounds on S and P (thus on Self).
223        let this = self.get_mut();
224        this.inner.poll_next_unpin(cx)
225    }
226}
227
/// Accessor produced by [`ConcurrentLimitLayer`]. Each operation it
/// intercepts acquires a permit from `semaphore` before delegating to `inner`.
#[doc(hidden)]
#[derive(Clone)]
pub struct ConcurrentLimitAccessor<A: Access, S: ConcurrentLimitSemaphore> {
    /// The wrapped accessor.
    inner: A,
    /// Operation semaphore, cloned from the layer that built this accessor.
    semaphore: S,
}
234
235impl<A: Access, S: ConcurrentLimitSemaphore> std::fmt::Debug for ConcurrentLimitAccessor<A, S> {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        f.debug_struct("ConcurrentLimitAccessor")
238            .field("inner", &self.inner)
239            .finish_non_exhaustive()
240    }
241}
242
243impl<A: Access, S: ConcurrentLimitSemaphore> LayeredAccess for ConcurrentLimitAccessor<A, S>
244where
245    S::Permit: Unpin,
246{
247    type Inner = A;
248    type Reader = ConcurrentLimitWrapper<A::Reader, S::Permit>;
249    type Writer = ConcurrentLimitWrapper<A::Writer, S::Permit>;
250    type Lister = ConcurrentLimitWrapper<A::Lister, S::Permit>;
251    type Deleter = ConcurrentLimitWrapper<A::Deleter, S::Permit>;
252    type Copier = ConcurrentLimitWrapper<A::Copier, S::Permit>;
253
254    fn inner(&self) -> &Self::Inner {
255        &self.inner
256    }
257
258    async fn create_dir(&self, path: &str, args: OpCreateDir) -> Result<RpCreateDir> {
259        let _permit = self.semaphore.acquire().await;
260
261        self.inner.create_dir(path, args).await
262    }
263
264    async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
265        let permit = self.semaphore.acquire().await;
266
267        self.inner
268            .read(path, args)
269            .await
270            .map(|(rp, r)| (rp, ConcurrentLimitWrapper::new(r, permit)))
271    }
272
273    async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
274        let permit = self.semaphore.acquire().await;
275
276        self.inner
277            .write(path, args)
278            .await
279            .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
280    }
281
282    async fn copy(
283        &self,
284        from: &str,
285        to: &str,
286        args: OpCopy,
287        opts: OpCopier,
288    ) -> Result<(RpCopy, Self::Copier)> {
289        let permit = self.semaphore.acquire().await;
290
291        self.inner
292            .copy(from, to, args, opts.clone())
293            .await
294            .map(|(rp, c)| (rp, ConcurrentLimitWrapper::new(c, permit)))
295    }
296
297    async fn rename(&self, from: &str, to: &str, args: OpRename) -> Result<RpRename> {
298        let _permit = self.semaphore.acquire().await;
299
300        self.inner.rename(from, to, args).await
301    }
302
303    async fn stat(&self, path: &str, args: OpStat) -> Result<RpStat> {
304        let _permit = self.semaphore.acquire().await;
305
306        self.inner.stat(path, args).await
307    }
308
309    async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
310        let permit = self.semaphore.acquire().await;
311
312        self.inner
313            .delete()
314            .await
315            .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
316    }
317
318    async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
319        let permit = self.semaphore.acquire().await;
320
321        self.inner
322            .list(path, args)
323            .await
324            .map(|(rp, s)| (rp, ConcurrentLimitWrapper::new(s, permit)))
325    }
326}
327
/// Wraps a reader, writer, lister, deleter or copier together with the
/// operation permit that was acquired to create it.
#[doc(hidden)]
pub struct ConcurrentLimitWrapper<R, P> {
    inner: R,

    // Hold on this permit until this wrapper (and therefore the wrapped
    // handle) has been dropped. Declared after `inner` because fields drop in
    // declaration order: the wrapped handle is torn down before the permit is
    // returned.
    _permit: P,
}
335
336impl<R, P> ConcurrentLimitWrapper<R, P> {
337    fn new(inner: R, permit: P) -> Self {
338        Self {
339            inner,
340            _permit: permit,
341        }
342    }
343}
344
345impl<R: oio::Read, P: Send + Sync + 'static + Unpin> oio::Read for ConcurrentLimitWrapper<R, P> {
346    async fn read(&mut self) -> Result<Buffer> {
347        self.inner.read().await
348    }
349}
350
351impl<R: oio::Write, P: Send + Sync + 'static + Unpin> oio::Write for ConcurrentLimitWrapper<R, P> {
352    async fn write(&mut self, bs: Buffer) -> Result<()> {
353        self.inner.write(bs).await
354    }
355
356    async fn close(&mut self) -> Result<Metadata> {
357        self.inner.close().await
358    }
359
360    async fn abort(&mut self) -> Result<()> {
361        self.inner.abort().await
362    }
363}
364
365impl<R: oio::List, P: Send + Sync + 'static + Unpin> oio::List for ConcurrentLimitWrapper<R, P> {
366    async fn next(&mut self) -> Result<Option<oio::Entry>> {
367        self.inner.next().await
368    }
369}
370
371impl<R: oio::Delete, P: Send + Sync + 'static + Unpin> oio::Delete
372    for ConcurrentLimitWrapper<R, P>
373{
374    async fn delete(&mut self, path: &str, args: OpDelete) -> Result<()> {
375        self.inner.delete(path, args).await
376    }
377
378    async fn close(&mut self) -> Result<()> {
379        self.inner.close().await
380    }
381}
382
383impl<C: oio::Copy, P: Send + Sync + 'static + Unpin> oio::Copy for ConcurrentLimitWrapper<C, P> {
384    async fn next(&mut self) -> Result<Option<usize>> {
385        self.inner.next().await
386    }
387
388    async fn close(&mut self) -> Result<Metadata> {
389        self.inner.close().await
390    }
391
392    async fn abort(&mut self) -> Result<()> {
393        self.inner.abort().await
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use opendal_core::Operator;
    use opendal_core::OperatorBuilder;
    use opendal_core::services;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::timeout;

    use futures::stream;
    use http::Response;

    // A permit held outside the layer must block operations until released.
    #[tokio::test]
    async fn operation_semaphore_can_be_shared() {
        let semaphore = Arc::new(Semaphore::new(1));
        let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());

        let permit = semaphore.clone().acquire_owned(1).await;

        let op = Operator::new(services::Memory::default())
            .expect("operator must build")
            .layer(layer)
            .finish();

        let blocked = timeout(Duration::from_millis(50), op.stat("any")).await;
        assert!(
            blocked.is_err(),
            "operation should be limited by shared semaphore"
        );

        drop(permit);

        let completed = timeout(Duration::from_millis(50), op.stat("any")).await;
        assert!(
            completed.is_ok(),
            "operation should proceed once permit is released"
        );
    }

    // `copy` and `rename` must both go through the operation semaphore.
    #[tokio::test]
    async fn operation_semaphore_limits_copy_and_rename() {
        #[derive(Clone, Debug)]
        struct CopyRenameBackend {
            info: Arc<AccessorInfo>,
        }

        impl Access for CopyRenameBackend {
            type Reader = ();
            type Writer = ();
            type Lister = ();
            type Deleter = ();
            type Copier = oio::Copier;

            fn info(&self) -> Arc<AccessorInfo> {
                self.info.clone()
            }

            async fn copy(
                &self,
                _: &str,
                _: &str,
                _: OpCopy,
                _: OpCopier,
            ) -> Result<(RpCopy, Self::Copier)> {
                Ok((RpCopy::default(), Box::new(())))
            }

            async fn rename(&self, _: &str, _: &str, _: OpRename) -> Result<RpRename> {
                Ok(RpRename::default())
            }
        }

        let semaphore = Arc::new(Semaphore::new(1));
        let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());
        let info = Arc::new(AccessorInfo::default());
        info.set_native_capability(Capability {
            copy: true,
            rename: true,
            ..Default::default()
        });
        let op = OperatorBuilder::new(CopyRenameBackend { info })
            .layer(layer)
            .finish();

        let permit = semaphore.clone().acquire_owned(1).await;

        let copy = timeout(Duration::from_millis(50), op.copy("from", "to")).await;
        assert!(copy.is_err(), "copy should wait for the operation permit");

        let rename = timeout(Duration::from_millis(50), op.rename("from", "to")).await;
        assert!(
            rename.is_err(),
            "rename should wait for the operation permit"
        );

        drop(permit);

        timeout(Duration::from_millis(50), op.copy("from", "to"))
            .await
            .expect("copy should proceed once permit is released")
            .expect("copy should succeed");
        timeout(Duration::from_millis(50), op.rename("from", "to"))
            .await
            .expect("rename should proceed once permit is released")
            .expect("rename should succeed");
    }

    // A live copier keeps its operation permit until it is dropped.
    #[tokio::test]
    async fn operation_semaphore_held_until_copier_dropped() {
        #[derive(Clone, Debug)]
        struct CopierBackend {
            info: Arc<AccessorInfo>,
        }

        impl Access for CopierBackend {
            type Reader = ();
            type Writer = ();
            type Lister = ();
            type Deleter = ();
            type Copier = oio::Copier;

            fn info(&self) -> Arc<AccessorInfo> {
                self.info.clone()
            }

            async fn copy(
                &self,
                _: &str,
                _: &str,
                _: OpCopy,
                _: OpCopier,
            ) -> Result<(RpCopy, Self::Copier)> {
                Ok((RpCopy::default(), Box::new(())))
            }

            async fn stat(&self, _: &str, _: OpStat) -> Result<RpStat> {
                Ok(RpStat::new(Metadata::new(EntryMode::FILE)))
            }
        }

        let semaphore = Arc::new(Semaphore::new(1));
        let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());
        let info = Arc::new(AccessorInfo::default());
        info.set_native_capability(Capability {
            copy: true,
            stat: true,
            ..Default::default()
        });
        let op = OperatorBuilder::new(CopierBackend { info })
            .layer(layer)
            .finish();

        let copier = timeout(Duration::from_millis(50), op.copier("from", "to"))
            .await
            .expect("copier setup should not block")
            .expect("copier should be created");

        // The permit is held by the live copier, so concurrent operations
        // must time out until the copier is dropped.
        let blocked = timeout(Duration::from_millis(50), op.stat("any")).await;
        assert!(
            blocked.is_err(),
            "stat should wait while the copier holds the permit"
        );

        drop(copier);

        timeout(Duration::from_millis(50), op.stat("any"))
            .await
            .expect("stat should proceed once the copier is dropped")
            .expect("stat should succeed");
    }

    // A chunked, concurrent read with fewer HTTP permits than concurrent
    // chunks must still complete (no deadlock) and return the full content.
    #[tokio::test]
    async fn concurrent_chunked_read_with_http_limit() {
        use opendal_core::raw::*;

        struct EchoFetcher;

        impl HttpFetch for EchoFetcher {
            async fn fetch(&self, req: http::Request<Buffer>) -> Result<http::Response<HttpBody>> {
                let data = req.into_body();
                let len = data.len() as u64;
                let body =
                    HttpBody::new(Box::pin(stream::once(async move { Ok(data) })), Some(len));
                Ok(http::Response::builder()
                    .status(http::StatusCode::OK)
                    .body(body)
                    .unwrap())
            }
        }

        #[derive(Clone, Debug)]
        struct HttpBackend {
            info: Arc<AccessorInfo>,
            content: Buffer,
        }

        impl Access for HttpBackend {
            type Reader = HttpBody;
            type Writer = ();
            type Lister = ();
            type Deleter = ();
            type Copier = oio::Copier;

            fn info(&self) -> Arc<AccessorInfo> {
                self.info.clone()
            }

            async fn read(&self, _: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
                let range = args.range();
                let start = range.offset() as usize;
                let data = match range.size() {
                    Some(sz) => self.content.slice(start..start + sz as usize),
                    None => self.content.slice(start..),
                };
                let req = http::Request::get("http://fake").body(data).unwrap();
                let resp = self.info.http_client().fetch(req).await?;
                Ok((
                    RpRead::new(Metadata::new(EntryMode::FILE).with_content_length(0)),
                    resp.into_body(),
                ))
            }

            async fn stat(&self, _: &str, _: OpStat) -> Result<RpStat> {
                Ok(RpStat::new(
                    Metadata::new(EntryMode::FILE).with_content_length(self.content.len() as u64),
                ))
            }

            async fn write(&self, _: &str, _: OpWrite) -> Result<(RpWrite, Self::Writer)> {
                Err(Error::new(ErrorKind::Unsupported, "not needed"))
            }
            async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
                Err(Error::new(ErrorKind::Unsupported, "not needed"))
            }
            async fn list(&self, _: &str, _: OpList) -> Result<(RpList, Self::Lister)> {
                Err(Error::new(ErrorKind::Unsupported, "not needed"))
            }
        }

        let content = Buffer::from(vec![0u8; 4096]);
        let info = Arc::new(AccessorInfo::default());
        info.update_http_client(|_| HttpClient::with(EchoFetcher));

        let op = OperatorBuilder::new(HttpBackend {
            info,
            content: content.clone(),
        })
        .layer(ConcurrentLimitLayer::new(1024).with_http_concurrent_limit(2))
        .finish();

        // chunk=256 ⇒ 16 HTTP requests, concurrent=4, but only 2 HTTP permits.
        let result = timeout(Duration::from_secs(5), async {
            op.reader_with("test")
                .chunk(256)
                .concurrent(4)
                .await
                .expect("reader must build")
                .read(..)
                .await
        })
        .await;

        let buf = result
            .expect("read must not deadlock (timeout)")
            .expect("read must succeed");
        assert_eq!(buf.to_bytes(), content.to_bytes());
    }

    // The HTTP permit must be held by the response body and released again
    // once that body is dropped.
    #[tokio::test]
    async fn http_semaphore_holds_until_body_dropped() {
        struct DummyFetcher;

        impl HttpFetch for DummyFetcher {
            async fn fetch(&self, _req: http::Request<Buffer>) -> Result<Response<HttpBody>> {
                let body = HttpBody::new(stream::empty(), None);
                Ok(Response::builder()
                    .status(http::StatusCode::OK)
                    .body(body)
                    .expect("response must build"))
            }
        }

        let semaphore = Arc::new(Semaphore::new(1));
        let layer = ConcurrentLimitLayer::new(1).with_http_semaphore(semaphore.clone());
        let fetcher = ConcurrentLimitHttpFetcher::<Arc<Semaphore>> {
            inner: HttpClient::with(DummyFetcher).into_inner(),
            http_semaphore: layer.http_semaphore.clone(),
        };

        // `fetch` consumes its request, so build a fresh one for every call.
        let new_request = || {
            http::Request::builder()
                .uri("http://example.invalid/")
                .body(Buffer::new())
                .expect("request must build")
        };

        let resp = fetcher
            .fetch(new_request())
            .await
            .expect("first fetch should succeed");

        let blocked = timeout(Duration::from_millis(50), fetcher.fetch(new_request())).await;
        assert!(
            blocked.is_err(),
            "http fetch should block while the body holds the permit"
        );

        // Dropping the response (and with it the body) must return the permit.
        drop(resp);

        timeout(Duration::from_millis(50), fetcher.fetch(new_request()))
            .await
            .expect("http fetch should proceed once the body is dropped")
            .expect("fetch after release should succeed");
    }
}