1#![cfg_attr(docsrs, feature(doc_cfg))]
21#![deny(missing_docs)]
22
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::Arc;
26use std::task::Context;
27use std::task::Poll;
28
29use futures::Stream;
30use futures::StreamExt;
31use mea::semaphore::OwnedSemaphorePermit;
32use mea::semaphore::Semaphore;
33use opendal_core::raw::*;
34use opendal_core::*;
35
/// Abstraction over the semaphore that [`ConcurrentLimitLayer`] uses to hand
/// out permits.
///
/// The layer clones the semaphore into every accessor it builds and awaits
/// [`acquire`](ConcurrentLimitSemaphore::acquire) before each guarded
/// operation, which is why implementors must be cloneable and shareable.
pub trait ConcurrentLimitSemaphore: Send + Sync + Clone + Unpin + 'static {
    /// Token returned by [`acquire`](ConcurrentLimitSemaphore::acquire).
    ///
    /// The layer only stores it and later drops it; it never calls anything on
    /// the permit. Custom implementations are therefore expected to give the
    /// slot back when the permit is dropped.
    type Permit: Send + Sync + 'static;

    /// Waits until a slot is available and resolves to the permit for it.
    fn acquire(&self) -> impl Future<Output = Self::Permit> + MaybeSend;
}
46
47impl ConcurrentLimitSemaphore for Arc<Semaphore> {
48 type Permit = OwnedSemaphorePermit;
49
50 async fn acquire(&self) -> Self::Permit {
51 self.clone().acquire_owned(1).await
52 }
53}
54
/// Layer that requires a permit from a semaphore before each operation of the
/// wrapped accessor starts.
///
/// `operation_semaphore` guards the accessor-level operations. The optional
/// `http_semaphore`, when set, additionally guards the requests sent through
/// the accessor's HTTP client.
#[derive(Clone)]
pub struct ConcurrentLimitLayer<S: ConcurrentLimitSemaphore = Arc<Semaphore>> {
    // Cloned into every `ConcurrentLimitAccessor` built by this layer.
    operation_semaphore: S,
    // `None` means HTTP requests are forwarded without acquiring a permit.
    http_semaphore: Option<S>,
}
110
111impl ConcurrentLimitLayer<Arc<Semaphore>> {
112 pub fn new(permits: usize) -> Self {
117 Self::with_semaphore(Arc::new(Semaphore::new(permits)))
118 }
119
120 pub fn with_http_concurrent_limit(self, permits: usize) -> Self {
127 self.with_http_semaphore(Arc::new(Semaphore::new(permits)))
128 }
129}
130
131impl<S: ConcurrentLimitSemaphore> ConcurrentLimitLayer<S> {
132 pub fn with_semaphore(operation_semaphore: S) -> Self {
142 Self {
143 operation_semaphore,
144 http_semaphore: None,
145 }
146 }
147
148 pub fn with_http_semaphore(mut self, semaphore: S) -> Self {
150 self.http_semaphore = Some(semaphore);
151 self
152 }
153}
154
155impl<A: Access, S: ConcurrentLimitSemaphore> Layer<A> for ConcurrentLimitLayer<S>
156where
157 S::Permit: Unpin,
158{
159 type LayeredAccess = ConcurrentLimitAccessor<A, S>;
160
161 fn layer(&self, inner: A) -> Self::LayeredAccess {
162 let info = inner.info();
163
164 info.update_http_client(|client| {
166 HttpClient::with(ConcurrentLimitHttpFetcher::<S> {
167 inner: client.into_inner(),
168 http_semaphore: self.http_semaphore.clone(),
169 })
170 });
171
172 ConcurrentLimitAccessor {
173 inner,
174 semaphore: self.operation_semaphore.clone(),
175 }
176 }
177}
178
/// HTTP fetcher installed by [`ConcurrentLimitLayer`]: forwards requests to
/// `inner`, acquiring a permit from `http_semaphore` first when one is set.
#[doc(hidden)]
pub struct ConcurrentLimitHttpFetcher<S: ConcurrentLimitSemaphore> {
    // The fetcher that was in place before the layer was applied.
    inner: HttpFetcher,
    // `None` turns this fetcher into a plain pass-through.
    http_semaphore: Option<S>,
}
184
185impl<S: ConcurrentLimitSemaphore> HttpFetch for ConcurrentLimitHttpFetcher<S>
186where
187 S::Permit: Unpin,
188{
189 async fn fetch(&self, req: http::Request<Buffer>) -> Result<http::Response<HttpBody>> {
190 let Some(semaphore) = self.http_semaphore.clone() else {
191 return self.inner.fetch(req).await;
192 };
193
194 let permit = semaphore.acquire().await;
195
196 let resp = self.inner.fetch(req).await?;
197 let (parts, body) = resp.into_parts();
198 let body = body.map_inner(|s| {
199 Box::new(ConcurrentLimitStream::<_, S::Permit> {
200 inner: s,
201 _permit: permit,
202 })
203 });
204 Ok(http::Response::from_parts(parts, body))
205 }
206}
207
/// Body stream wrapper that keeps an HTTP permit alive for as long as the
/// stream itself exists.
struct ConcurrentLimitStream<S, P> {
    // The original response body stream; all polling is forwarded to it.
    inner: S,
    // Never read; held only so the permit is dropped together with the stream.
    _permit: P,
}
213
214impl<S, P> Stream for ConcurrentLimitStream<S, P>
215where
216 S: Stream<Item = Result<Buffer>> + Unpin + 'static,
217 P: Unpin,
218{
219 type Item = Result<Buffer>;
220
221 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222 let this = self.get_mut();
224 this.inner.poll_next_unpin(cx)
225 }
226}
227
/// Accessor produced by [`ConcurrentLimitLayer`]: acquires a permit from
/// `semaphore` before delegating each operation to `inner`.
#[doc(hidden)]
#[derive(Clone)]
pub struct ConcurrentLimitAccessor<A: Access, S: ConcurrentLimitSemaphore> {
    // The wrapped accessor that performs the actual work.
    inner: A,
    // Clone of the layer's operation semaphore.
    semaphore: S,
}
234
235impl<A: Access, S: ConcurrentLimitSemaphore> std::fmt::Debug for ConcurrentLimitAccessor<A, S> {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 f.debug_struct("ConcurrentLimitAccessor")
238 .field("inner", &self.inner)
239 .finish_non_exhaustive()
240 }
241}
242
243impl<A: Access, S: ConcurrentLimitSemaphore> LayeredAccess for ConcurrentLimitAccessor<A, S>
244where
245 S::Permit: Unpin,
246{
247 type Inner = A;
248 type Reader = ConcurrentLimitWrapper<A::Reader, S::Permit>;
249 type Writer = ConcurrentLimitWrapper<A::Writer, S::Permit>;
250 type Lister = ConcurrentLimitWrapper<A::Lister, S::Permit>;
251 type Deleter = ConcurrentLimitWrapper<A::Deleter, S::Permit>;
252 type Copier = ConcurrentLimitWrapper<A::Copier, S::Permit>;
253
254 fn inner(&self) -> &Self::Inner {
255 &self.inner
256 }
257
258 async fn create_dir(&self, path: &str, args: OpCreateDir) -> Result<RpCreateDir> {
259 let _permit = self.semaphore.acquire().await;
260
261 self.inner.create_dir(path, args).await
262 }
263
264 async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
265 let permit = self.semaphore.acquire().await;
266
267 self.inner
268 .read(path, args)
269 .await
270 .map(|(rp, r)| (rp, ConcurrentLimitWrapper::new(r, permit)))
271 }
272
273 async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
274 let permit = self.semaphore.acquire().await;
275
276 self.inner
277 .write(path, args)
278 .await
279 .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
280 }
281
282 async fn copy(
283 &self,
284 from: &str,
285 to: &str,
286 args: OpCopy,
287 opts: OpCopier,
288 ) -> Result<(RpCopy, Self::Copier)> {
289 let permit = self.semaphore.acquire().await;
290
291 self.inner
292 .copy(from, to, args, opts.clone())
293 .await
294 .map(|(rp, c)| (rp, ConcurrentLimitWrapper::new(c, permit)))
295 }
296
297 async fn rename(&self, from: &str, to: &str, args: OpRename) -> Result<RpRename> {
298 let _permit = self.semaphore.acquire().await;
299
300 self.inner.rename(from, to, args).await
301 }
302
303 async fn stat(&self, path: &str, args: OpStat) -> Result<RpStat> {
304 let _permit = self.semaphore.acquire().await;
305
306 self.inner.stat(path, args).await
307 }
308
309 async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
310 let permit = self.semaphore.acquire().await;
311
312 self.inner
313 .delete()
314 .await
315 .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
316 }
317
318 async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
319 let permit = self.semaphore.acquire().await;
320
321 self.inner
322 .list(path, args)
323 .await
324 .map(|(rp, s)| (rp, ConcurrentLimitWrapper::new(s, permit)))
325 }
326}
327
/// Pairs a reader, writer, lister, deleter or copier with the operation permit
/// that was acquired when it was created.
#[doc(hidden)]
pub struct ConcurrentLimitWrapper<R, P> {
    // The wrapped object; every trait method is forwarded to it.
    inner: R,

    // Never read; stored so the permit is dropped together with the wrapper.
    // Declared after `inner`, so `inner` is dropped first.
    _permit: P,
}

impl<R, P> ConcurrentLimitWrapper<R, P> {
    /// Bundles `inner` with the `permit` that should live as long as it does.
    fn new(inner: R, permit: P) -> Self {
        let _permit = permit;
        ConcurrentLimitWrapper { inner, _permit }
    }
}
344
345impl<R: oio::Read, P: Send + Sync + 'static + Unpin> oio::Read for ConcurrentLimitWrapper<R, P> {
346 async fn read(&mut self) -> Result<Buffer> {
347 self.inner.read().await
348 }
349}
350
351impl<R: oio::Write, P: Send + Sync + 'static + Unpin> oio::Write for ConcurrentLimitWrapper<R, P> {
352 async fn write(&mut self, bs: Buffer) -> Result<()> {
353 self.inner.write(bs).await
354 }
355
356 async fn close(&mut self) -> Result<Metadata> {
357 self.inner.close().await
358 }
359
360 async fn abort(&mut self) -> Result<()> {
361 self.inner.abort().await
362 }
363}
364
365impl<R: oio::List, P: Send + Sync + 'static + Unpin> oio::List for ConcurrentLimitWrapper<R, P> {
366 async fn next(&mut self) -> Result<Option<oio::Entry>> {
367 self.inner.next().await
368 }
369}
370
371impl<R: oio::Delete, P: Send + Sync + 'static + Unpin> oio::Delete
372 for ConcurrentLimitWrapper<R, P>
373{
374 async fn delete(&mut self, path: &str, args: OpDelete) -> Result<()> {
375 self.inner.delete(path, args).await
376 }
377
378 async fn close(&mut self) -> Result<()> {
379 self.inner.close().await
380 }
381}
382
383impl<C: oio::Copy, P: Send + Sync + 'static + Unpin> oio::Copy for ConcurrentLimitWrapper<C, P> {
384 async fn next(&mut self) -> Result<Option<usize>> {
385 self.inner.next().await
386 }
387
388 async fn close(&mut self) -> Result<Metadata> {
389 self.inner.close().await
390 }
391
392 async fn abort(&mut self) -> Result<()> {
393 self.inner.abort().await
394 }
395}
396
397#[cfg(test)]
398mod tests {
399 use super::*;
400 use opendal_core::Operator;
401 use opendal_core::OperatorBuilder;
402 use opendal_core::services;
403 use std::sync::Arc;
404 use std::time::Duration;
405 use tokio::time::timeout;
406
407 use futures::stream;
408 use http::Response;
409
410 #[tokio::test]
411 async fn operation_semaphore_can_be_shared() {
412 let semaphore = Arc::new(Semaphore::new(1));
413 let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());
414
415 let permit = semaphore.clone().acquire_owned(1).await;
416
417 let op = Operator::new(services::Memory::default())
418 .expect("operator must build")
419 .layer(layer)
420 .finish();
421
422 let blocked = timeout(Duration::from_millis(50), op.stat("any")).await;
423 assert!(
424 blocked.is_err(),
425 "operation should be limited by shared semaphore"
426 );
427
428 drop(permit);
429
430 let completed = timeout(Duration::from_millis(50), op.stat("any")).await;
431 assert!(
432 completed.is_ok(),
433 "operation should proceed once permit is released"
434 );
435 }
436
437 #[tokio::test]
438 async fn operation_semaphore_limits_copy_and_rename() {
439 #[derive(Clone, Debug)]
440 struct CopyRenameBackend {
441 info: Arc<AccessorInfo>,
442 }
443
444 impl Access for CopyRenameBackend {
445 type Reader = ();
446 type Writer = ();
447 type Lister = ();
448 type Deleter = ();
449 type Copier = oio::Copier;
450
451 fn info(&self) -> Arc<AccessorInfo> {
452 self.info.clone()
453 }
454
455 async fn copy(
456 &self,
457 _: &str,
458 _: &str,
459 _: OpCopy,
460 _: OpCopier,
461 ) -> Result<(RpCopy, Self::Copier)> {
462 Ok((RpCopy::default(), Box::new(())))
463 }
464
465 async fn rename(&self, _: &str, _: &str, _: OpRename) -> Result<RpRename> {
466 Ok(RpRename::default())
467 }
468 }
469
470 let semaphore = Arc::new(Semaphore::new(1));
471 let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());
472 let info = Arc::new(AccessorInfo::default());
473 info.set_native_capability(Capability {
474 copy: true,
475 rename: true,
476 ..Default::default()
477 });
478 let op = OperatorBuilder::new(CopyRenameBackend { info })
479 .layer(layer)
480 .finish();
481
482 let permit = semaphore.clone().acquire_owned(1).await;
483
484 let copy = timeout(Duration::from_millis(50), op.copy("from", "to")).await;
485 assert!(copy.is_err(), "copy should wait for the operation permit");
486
487 let rename = timeout(Duration::from_millis(50), op.rename("from", "to")).await;
488 assert!(
489 rename.is_err(),
490 "rename should wait for the operation permit"
491 );
492
493 drop(permit);
494
495 timeout(Duration::from_millis(50), op.copy("from", "to"))
496 .await
497 .expect("copy should proceed once permit is released")
498 .expect("copy should succeed");
499 timeout(Duration::from_millis(50), op.rename("from", "to"))
500 .await
501 .expect("rename should proceed once permit is released")
502 .expect("rename should succeed");
503 }
504
505 #[tokio::test]
506 async fn operation_semaphore_held_until_copier_dropped() {
507 #[derive(Clone, Debug)]
508 struct CopierBackend {
509 info: Arc<AccessorInfo>,
510 }
511
512 impl Access for CopierBackend {
513 type Reader = ();
514 type Writer = ();
515 type Lister = ();
516 type Deleter = ();
517 type Copier = oio::Copier;
518
519 fn info(&self) -> Arc<AccessorInfo> {
520 self.info.clone()
521 }
522
523 async fn copy(
524 &self,
525 _: &str,
526 _: &str,
527 _: OpCopy,
528 _: OpCopier,
529 ) -> Result<(RpCopy, Self::Copier)> {
530 Ok((RpCopy::default(), Box::new(())))
531 }
532
533 async fn stat(&self, _: &str, _: OpStat) -> Result<RpStat> {
534 Ok(RpStat::new(Metadata::new(EntryMode::FILE)))
535 }
536 }
537
538 let semaphore = Arc::new(Semaphore::new(1));
539 let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());
540 let info = Arc::new(AccessorInfo::default());
541 info.set_native_capability(Capability {
542 copy: true,
543 stat: true,
544 ..Default::default()
545 });
546 let op = OperatorBuilder::new(CopierBackend { info })
547 .layer(layer)
548 .finish();
549
550 let copier = timeout(Duration::from_millis(50), op.copier("from", "to"))
551 .await
552 .expect("copier setup should not block")
553 .expect("copier should be created");
554
555 let blocked = timeout(Duration::from_millis(50), op.stat("any")).await;
558 assert!(
559 blocked.is_err(),
560 "stat should wait while the copier holds the permit"
561 );
562
563 drop(copier);
564
565 timeout(Duration::from_millis(50), op.stat("any"))
566 .await
567 .expect("stat should proceed once the copier is dropped")
568 .expect("stat should succeed");
569 }
570
571 #[tokio::test]
572 async fn concurrent_chunked_read_with_http_limit() {
573 use opendal_core::raw::*;
574
575 struct EchoFetcher;
576
577 impl HttpFetch for EchoFetcher {
578 async fn fetch(&self, req: http::Request<Buffer>) -> Result<http::Response<HttpBody>> {
579 let data = req.into_body();
580 let len = data.len() as u64;
581 let body =
582 HttpBody::new(Box::pin(stream::once(async move { Ok(data) })), Some(len));
583 Ok(http::Response::builder()
584 .status(http::StatusCode::OK)
585 .body(body)
586 .unwrap())
587 }
588 }
589
590 #[derive(Clone, Debug)]
591 struct HttpBackend {
592 info: Arc<AccessorInfo>,
593 content: Buffer,
594 }
595
596 impl Access for HttpBackend {
597 type Reader = HttpBody;
598 type Writer = ();
599 type Lister = ();
600 type Deleter = ();
601 type Copier = oio::Copier;
602
603 fn info(&self) -> Arc<AccessorInfo> {
604 self.info.clone()
605 }
606
607 async fn read(&self, _: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
608 let range = args.range();
609 let start = range.offset() as usize;
610 let data = match range.size() {
611 Some(sz) => self.content.slice(start..start + sz as usize),
612 None => self.content.slice(start..),
613 };
614 let req = http::Request::get("http://fake").body(data).unwrap();
615 let resp = self.info.http_client().fetch(req).await?;
616 Ok((
617 RpRead::new(Metadata::new(EntryMode::FILE).with_content_length(0)),
618 resp.into_body(),
619 ))
620 }
621
622 async fn stat(&self, _: &str, _: OpStat) -> Result<RpStat> {
623 Ok(RpStat::new(
624 Metadata::new(EntryMode::FILE).with_content_length(self.content.len() as u64),
625 ))
626 }
627
628 async fn write(&self, _: &str, _: OpWrite) -> Result<(RpWrite, Self::Writer)> {
629 Err(Error::new(ErrorKind::Unsupported, "not needed"))
630 }
631 async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
632 Err(Error::new(ErrorKind::Unsupported, "not needed"))
633 }
634 async fn list(&self, _: &str, _: OpList) -> Result<(RpList, Self::Lister)> {
635 Err(Error::new(ErrorKind::Unsupported, "not needed"))
636 }
637 }
638
639 let content = Buffer::from(vec![0u8; 4096]);
640 let info = Arc::new(AccessorInfo::default());
641 info.update_http_client(|_| HttpClient::with(EchoFetcher));
642
643 let op = OperatorBuilder::new(HttpBackend {
644 info,
645 content: content.clone(),
646 })
647 .layer(ConcurrentLimitLayer::new(1024).with_http_concurrent_limit(2))
648 .finish();
649
650 let result = timeout(Duration::from_secs(5), async {
652 op.reader_with("test")
653 .chunk(256)
654 .concurrent(4)
655 .await
656 .expect("reader must build")
657 .read(..)
658 .await
659 })
660 .await;
661
662 let buf = result
663 .expect("read must not deadlock (timeout)")
664 .expect("read must succeed");
665 assert_eq!(buf.to_bytes(), content.to_bytes());
666 }
667
668 #[tokio::test]
669 async fn http_semaphore_holds_until_body_dropped() {
670 struct DummyFetcher;
671
672 impl HttpFetch for DummyFetcher {
673 async fn fetch(&self, _req: http::Request<Buffer>) -> Result<Response<HttpBody>> {
674 let body = HttpBody::new(stream::empty(), None);
675 Ok(Response::builder()
676 .status(http::StatusCode::OK)
677 .body(body)
678 .expect("response must build"))
679 }
680 }
681
682 let semaphore = Arc::new(Semaphore::new(1));
683 let layer = ConcurrentLimitLayer::new(1).with_http_semaphore(semaphore.clone());
684 let fetcher = ConcurrentLimitHttpFetcher::<Arc<Semaphore>> {
685 inner: HttpClient::with(DummyFetcher).into_inner(),
686 http_semaphore: layer.http_semaphore.clone(),
687 };
688
689 let request = http::Request::builder()
690 .uri("http://example.invalid/")
691 .body(Buffer::new())
692 .expect("request must build");
693 let _resp = fetcher
694 .fetch(request)
695 .await
696 .expect("first fetch should succeed");
697
698 let request = http::Request::builder()
699 .uri("http://example.invalid/")
700 .body(Buffer::new())
701 .expect("request must build");
702 let blocked = timeout(Duration::from_millis(50), fetcher.fetch(request)).await;
703 assert!(
704 blocked.is_err(),
705 "http fetch should block while the body holds the permit"
706 );
707 }
708}