#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(missing_docs)]
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use futures::Stream;
use futures::StreamExt;
use mea::semaphore::OwnedSemaphorePermit;
use mea::semaphore::Semaphore;
use opendal_core::raw::*;
use opendal_core::*;
/// Abstraction over the semaphore that [`ConcurrentLimitLayer`] uses to bound concurrency.
///
/// The layer calls [`acquire`](ConcurrentLimitSemaphore::acquire) before starting guarded
/// work and keeps the returned [`Permit`](ConcurrentLimitSemaphore::Permit) alive while that
/// work is in progress. The work may be a single operation, a returned
/// reader/writer/lister/deleter/copier, or an HTTP response body. The layer never releases a
/// permit explicitly; it only drops it.
///
/// An implementation for `Arc<mea::semaphore::Semaphore>` is provided. Because implementors
/// are `Clone`, a semaphore whose clones share state can be handed to several layers (or
/// used by application code) so that they all draw from one budget.
pub trait ConcurrentLimitSemaphore: Send + Sync + Clone + Unpin + 'static {
    /// Guard returned by [`acquire`](ConcurrentLimitSemaphore::acquire).
    ///
    /// Dropping it must give the slot back to the semaphore, because dropping is the only
    /// way the layer releases capacity.
    type Permit: Send + Sync + 'static;

    /// Waits until a slot is available and returns a permit that represents it.
    fn acquire(&self) -> impl Future<Output = Self::Permit> + MaybeSend;
}
/// Default semaphore: a shared `mea` semaphore where each acquisition takes exactly one permit.
impl ConcurrentLimitSemaphore for Arc<Semaphore> {
    type Permit = OwnedSemaphorePermit;

    async fn acquire(&self) -> Self::Permit {
        // `acquire_owned` consumes an `Arc`, so give it its own handle; the resulting
        // permit is therefore not tied to the lifetime of `&self`.
        let shared = Arc::clone(self);
        shared.acquire_owned(1).await
    }
}
/// Layer that caps how many operations (and, optionally, HTTP requests) run concurrently.
///
/// Every operation on the layered accessor first acquires a permit from the operation
/// semaphore. For operations that return a handle (`read`, `write`, `list`, `delete`,
/// `copy`), the permit is stored inside the returned reader/writer/lister/deleter/copier, so
/// the slot stays occupied until that handle is dropped. `create_dir`, `rename` and `stat`
/// hold their permit only while the call runs.
///
/// Optionally, a second semaphore limits HTTP requests sent through the accessor's HTTP
/// client; see [`ConcurrentLimitLayer::with_http_concurrent_limit`] and
/// [`ConcurrentLimitLayer::with_http_semaphore`]. An HTTP permit is held until the response
/// body is dropped.
///
/// Use [`ConcurrentLimitLayer::new`] for a semaphore owned by this layer, or
/// [`ConcurrentLimitLayer::with_semaphore`] to supply (and possibly share) your own.
#[derive(Clone)]
pub struct ConcurrentLimitLayer<S: ConcurrentLimitSemaphore = Arc<Semaphore>> {
    /// Limits concurrent accessor operations.
    operation_semaphore: S,
    /// Limits concurrent HTTP requests; `None` leaves HTTP traffic unlimited.
    http_semaphore: Option<S>,
}
impl ConcurrentLimitLayer<Arc<Semaphore>> {
    /// Creates a layer that allows at most `permits` accessor operations in flight at once.
    ///
    /// The semaphore is created here and is shared only by clones of this layer and the
    /// accessors they build. HTTP requests are not limited unless
    /// [`with_http_concurrent_limit`](Self::with_http_concurrent_limit) is also called.
    ///
    /// Nothing in this layer adds permits later, so with `permits == 0` every operation
    /// would wait indefinitely.
    pub fn new(permits: usize) -> Self {
        Self::with_semaphore(Arc::new(Semaphore::new(permits)))
    }

    /// Additionally limits the number of concurrent HTTP requests to `permits`.
    ///
    /// Each request keeps its permit until its response body is dropped, so bodies that are
    /// kept alive count against the limit even while idle. Replaces any HTTP semaphore
    /// configured earlier.
    pub fn with_http_concurrent_limit(self, permits: usize) -> Self {
        self.with_http_semaphore(Arc::new(Semaphore::new(permits)))
    }
}
impl<S: ConcurrentLimitSemaphore> ConcurrentLimitLayer<S> {
    /// Creates a layer that limits accessor operations with `operation_semaphore`.
    ///
    /// When clones of `S` share state (as `Arc<Semaphore>` does), passing a clone of a
    /// semaphore that is also used elsewhere makes all users draw from the same budget.
    /// HTTP requests are not limited unless [`with_http_semaphore`](Self::with_http_semaphore)
    /// is also called.
    pub fn with_semaphore(operation_semaphore: S) -> Self {
        Self {
            operation_semaphore,
            http_semaphore: None,
        }
    }

    /// Limits concurrent HTTP requests with `semaphore`.
    ///
    /// Each request holds a permit from `semaphore` until its response body is dropped.
    /// Replaces any HTTP semaphore configured earlier.
    pub fn with_http_semaphore(mut self, semaphore: S) -> Self {
        self.http_semaphore = Some(semaphore);
        self
    }
}
impl<A, S> Layer<A> for ConcurrentLimitLayer<S>
where
    A: Access,
    S: ConcurrentLimitSemaphore,
    S::Permit: Unpin,
{
    type LayeredAccess = ConcurrentLimitAccessor<A, S>;

    /// Wraps `inner` so that its operations are gated by the operation semaphore.
    ///
    /// It also replaces the HTTP client in `inner`'s `AccessorInfo` with one that gates
    /// requests by the HTTP semaphore, when one is configured.
    fn layer(&self, inner: A) -> Self::LayeredAccess {
        // The HTTP client is reached through `inner.info()`, so it is swapped there rather
        // than stored on the returned accessor.
        inner.info().update_http_client(|client| {
            let fetcher = ConcurrentLimitHttpFetcher::<S> {
                inner: client.into_inner(),
                http_semaphore: self.http_semaphore.clone(),
            };
            HttpClient::with(fetcher)
        });

        let semaphore = self.operation_semaphore.clone();
        ConcurrentLimitAccessor { inner, semaphore }
    }
}
/// `HttpFetch` wrapper that optionally bounds the number of in-flight HTTP requests.
///
/// Installed into the accessor's HTTP client by `ConcurrentLimitLayer::layer`.
#[doc(hidden)]
pub struct ConcurrentLimitHttpFetcher<S: ConcurrentLimitSemaphore> {
    /// The original fetcher that actually performs requests.
    inner: HttpFetcher,
    /// Semaphore gating requests; `None` forwards requests without limiting them.
    http_semaphore: Option<S>,
}
impl<S: ConcurrentLimitSemaphore> HttpFetch for ConcurrentLimitHttpFetcher<S>
where
    S::Permit: Unpin,
{
    /// Sends `req` through the inner fetcher, holding an HTTP permit for the response's lifetime.
    ///
    /// Without an HTTP semaphore, the request is forwarded unchanged. Otherwise the call:
    /// 1. waits for a permit,
    /// 2. performs the request,
    /// 3. moves the permit into the response body stream.
    ///
    /// The slot is therefore released when the body is dropped, not when headers arrive.
    /// If the inner fetch fails, the permit is dropped immediately.
    async fn fetch(&self, req: http::Request<Buffer>) -> Result<http::Response<HttpBody>> {
        // Borrow the semaphore rather than cloning the handle on every request: `acquire`
        // only needs `&S`.
        let Some(semaphore) = self.http_semaphore.as_ref() else {
            return self.inner.fetch(req).await;
        };

        let permit = semaphore.acquire().await;
        // On error `?` returns early and `permit` is dropped, freeing the slot.
        let resp = self.inner.fetch(req).await?;

        let (parts, body) = resp.into_parts();
        let body = body.map_inner(|s| {
            Box::new(ConcurrentLimitStream::<_, S::Permit> {
                inner: s,
                _permit: permit,
            })
        });
        Ok(http::Response::from_parts(parts, body))
    }
}
/// Response-body stream that keeps an HTTP permit alive until the body is dropped.
///
/// Rust drops struct fields in declaration order, so the wrapped body (`inner`) is dropped
/// before `_permit` gives its slot back. Keep `_permit` declared last to preserve that.
struct ConcurrentLimitStream<S, P> {
    /// The wrapped body stream; all items are forwarded from here.
    inner: S,
    /// Never read; held so that dropping the stream drops the permit.
    _permit: P,
}
impl<S, P> Stream for ConcurrentLimitStream<S, P>
where
    S: Stream<Item = Result<Buffer>> + Unpin + 'static,
    P: Unpin,
{
    type Item = Result<Buffer>;

    /// Polls the wrapped body; the permit is only carried along, never touched.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` is available because both `S` and `P` are `Unpin`.
        self.get_mut().inner.poll_next_unpin(cx)
    }

    /// Forwards the wrapped stream's size hint instead of the uninformative default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// Accessor produced by `ConcurrentLimitLayer`.
///
/// Every operation waits for a permit from `semaphore` before it is forwarded to `inner`.
#[doc(hidden)]
#[derive(Clone)]
pub struct ConcurrentLimitAccessor<A: Access, S: ConcurrentLimitSemaphore> {
    /// The accessor being limited.
    inner: A,
    /// Clone of the layer's operation semaphore.
    semaphore: S,
}
impl<A: Access, S: ConcurrentLimitSemaphore> std::fmt::Debug for ConcurrentLimitAccessor<A, S> {
    /// Prints only the wrapped accessor; the semaphore is elided because `S` is not
    /// required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ConcurrentLimitAccessor");
        builder.field("inner", &self.inner);
        builder.finish_non_exhaustive()
    }
}
impl<A: Access, S: ConcurrentLimitSemaphore> LayeredAccess for ConcurrentLimitAccessor<A, S>
where
    S::Permit: Unpin,
{
    type Inner = A;
    // Handle types carry the operation permit, so the slot stays occupied until the
    // handle itself is dropped.
    type Reader = ConcurrentLimitWrapper<A::Reader, S::Permit>;
    type Writer = ConcurrentLimitWrapper<A::Writer, S::Permit>;
    type Lister = ConcurrentLimitWrapper<A::Lister, S::Permit>;
    type Deleter = ConcurrentLimitWrapper<A::Deleter, S::Permit>;
    type Copier = ConcurrentLimitWrapper<A::Copier, S::Permit>;

    fn inner(&self) -> &Self::Inner {
        &self.inner
    }

    // Permit is held only while the inner call runs.
    async fn create_dir(&self, path: &str, args: OpCreateDir) -> Result<RpCreateDir> {
        let _permit = self.semaphore.acquire().await;
        self.inner.create_dir(path, args).await
    }

    // Permit moves into the returned reader; if the inner call fails, the closure holding
    // it is dropped and the slot is released right away.
    async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
        let permit = self.semaphore.acquire().await;
        self.inner
            .read(path, args)
            .await
            .map(|(rp, r)| (rp, ConcurrentLimitWrapper::new(r, permit)))
    }

    // Permit moves into the returned writer (released on error, as for `read`).
    async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
        let permit = self.semaphore.acquire().await;
        self.inner
            .write(path, args)
            .await
            .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
    }

    // Permit moves into the returned copier (released on error, as for `read`).
    async fn copy(
        &self,
        from: &str,
        to: &str,
        args: OpCopy,
        opts: OpCopier,
    ) -> Result<(RpCopy, Self::Copier)> {
        let permit = self.semaphore.acquire().await;
        // `opts` is owned and not used afterwards, so it is passed through without cloning.
        self.inner
            .copy(from, to, args, opts)
            .await
            .map(|(rp, c)| (rp, ConcurrentLimitWrapper::new(c, permit)))
    }

    // Permit is held only while the inner call runs.
    async fn rename(&self, from: &str, to: &str, args: OpRename) -> Result<RpRename> {
        let _permit = self.semaphore.acquire().await;
        self.inner.rename(from, to, args).await
    }

    // Permit is held only while the inner call runs.
    async fn stat(&self, path: &str, args: OpStat) -> Result<RpStat> {
        let _permit = self.semaphore.acquire().await;
        self.inner.stat(path, args).await
    }

    // Permit moves into the returned deleter (released on error, as for `read`).
    async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
        let permit = self.semaphore.acquire().await;
        self.inner
            .delete()
            .await
            .map(|(rp, d)| (rp, ConcurrentLimitWrapper::new(d, permit)))
    }

    // Permit moves into the returned lister (released on error, as for `read`).
    async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
        let permit = self.semaphore.acquire().await;
        self.inner
            .list(path, args)
            .await
            .map(|(rp, l)| (rp, ConcurrentLimitWrapper::new(l, permit)))
    }
}
/// Handle wrapper that keeps an operation permit alive for as long as the handle exists.
///
/// Used for the reader, writer, lister, deleter and copier returned by
/// `ConcurrentLimitAccessor`; every call is forwarded to `inner` unchanged.
#[doc(hidden)]
pub struct ConcurrentLimitWrapper<R, P> {
    /// The wrapped handle. Declared before `_permit` so that it is dropped first.
    inner: R,
    /// Never read; held so that dropping the wrapper drops the permit.
    _permit: P,
}

impl<R, P> ConcurrentLimitWrapper<R, P> {
    /// Pairs `inner` with the permit that must outlive it.
    fn new(inner: R, permit: P) -> Self {
        ConcurrentLimitWrapper {
            _permit: permit,
            inner,
        }
    }
}
impl<R, P> oio::Read for ConcurrentLimitWrapper<R, P>
where
    R: oio::Read,
    P: Send + Sync + Unpin + 'static,
{
    /// Delegates to the wrapped reader; the permit stays held between reads.
    async fn read(&mut self) -> Result<Buffer> {
        oio::Read::read(&mut self.inner).await
    }
}
impl<W, P> oio::Write for ConcurrentLimitWrapper<W, P>
where
    W: oio::Write,
    P: Send + Sync + Unpin + 'static,
{
    /// Delegates to the wrapped writer.
    async fn write(&mut self, bs: Buffer) -> Result<()> {
        oio::Write::write(&mut self.inner, bs).await
    }

    /// Delegates to the wrapped writer; the permit is released only when the wrapper is
    /// dropped, not by closing.
    async fn close(&mut self) -> Result<Metadata> {
        oio::Write::close(&mut self.inner).await
    }

    /// Delegates to the wrapped writer; the permit is released only on drop.
    async fn abort(&mut self) -> Result<()> {
        oio::Write::abort(&mut self.inner).await
    }
}
impl<L, P> oio::List for ConcurrentLimitWrapper<L, P>
where
    L: oio::List,
    P: Send + Sync + Unpin + 'static,
{
    /// Delegates to the wrapped lister; the permit stays held until the lister is dropped.
    async fn next(&mut self) -> Result<Option<oio::Entry>> {
        oio::List::next(&mut self.inner).await
    }
}
impl<D, P> oio::Delete for ConcurrentLimitWrapper<D, P>
where
    D: oio::Delete,
    P: Send + Sync + Unpin + 'static,
{
    /// Delegates to the wrapped deleter.
    async fn delete(&mut self, path: &str, args: OpDelete) -> Result<()> {
        oio::Delete::delete(&mut self.inner, path, args).await
    }

    /// Delegates to the wrapped deleter; the permit is released only on drop.
    async fn close(&mut self) -> Result<()> {
        oio::Delete::close(&mut self.inner).await
    }
}
impl<C, P> oio::Copy for ConcurrentLimitWrapper<C, P>
where
    C: oio::Copy,
    P: Send + Sync + Unpin + 'static,
{
    /// Delegates to the wrapped copier.
    async fn next(&mut self) -> Result<Option<usize>> {
        oio::Copy::next(&mut self.inner).await
    }

    /// Delegates to the wrapped copier; the permit is released only on drop.
    async fn close(&mut self) -> Result<Metadata> {
        oio::Copy::close(&mut self.inner).await
    }

    /// Delegates to the wrapped copier; the permit is released only on drop.
    async fn abort(&mut self) -> Result<()> {
        oio::Copy::abort(&mut self.inner).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use opendal_core::Operator;
    use opendal_core::OperatorBuilder;
    use opendal_core::services;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::timeout;

    use futures::stream;
    use http::Response;

    // A semaphore passed to `with_semaphore` is shared with the caller: holding its only
    // permit externally blocks operations, and releasing it lets them through.
    #[tokio::test]
    async fn operation_semaphore_can_be_shared() {
        let semaphore = Arc::new(Semaphore::new(1));
        let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());

        let permit = semaphore.clone().acquire_owned(1).await;

        let op = Operator::new(services::Memory::default())
            .expect("operator must build")
            .layer(layer)
            .finish();

        let blocked = timeout(Duration::from_millis(50), op.stat("any")).await;
        assert!(
            blocked.is_err(),
            "operation should be limited by shared semaphore"
        );

        drop(permit);

        let completed = timeout(Duration::from_millis(50), op.stat("any")).await;
        assert!(
            completed.is_ok(),
            "operation should proceed once permit is released"
        );
    }

    // `copy` and `rename` must also wait for the operation permit.
    #[tokio::test]
    async fn operation_semaphore_limits_copy_and_rename() {
        #[derive(Clone, Debug)]
        struct CopyRenameBackend {
            info: Arc<AccessorInfo>,
        }

        impl Access for CopyRenameBackend {
            type Reader = ();
            type Writer = ();
            type Lister = ();
            type Deleter = ();
            type Copier = oio::Copier;

            fn info(&self) -> Arc<AccessorInfo> {
                self.info.clone()
            }

            async fn copy(
                &self,
                _: &str,
                _: &str,
                _: OpCopy,
                _: OpCopier,
            ) -> Result<(RpCopy, Self::Copier)> {
                Ok((RpCopy::default(), Box::new(())))
            }

            async fn rename(&self, _: &str, _: &str, _: OpRename) -> Result<RpRename> {
                Ok(RpRename::default())
            }
        }

        let semaphore = Arc::new(Semaphore::new(1));
        let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());
        let info = Arc::new(AccessorInfo::default());
        info.set_native_capability(Capability {
            copy: true,
            rename: true,
            ..Default::default()
        });

        let op = OperatorBuilder::new(CopyRenameBackend { info })
            .layer(layer)
            .finish();

        let permit = semaphore.clone().acquire_owned(1).await;

        let copy = timeout(Duration::from_millis(50), op.copy("from", "to")).await;
        assert!(copy.is_err(), "copy should wait for the operation permit");

        let rename = timeout(Duration::from_millis(50), op.rename("from", "to")).await;
        assert!(
            rename.is_err(),
            "rename should wait for the operation permit"
        );

        drop(permit);

        timeout(Duration::from_millis(50), op.copy("from", "to"))
            .await
            .expect("copy should proceed once permit is released")
            .expect("copy should succeed");
        timeout(Duration::from_millis(50), op.rename("from", "to"))
            .await
            .expect("rename should proceed once permit is released")
            .expect("rename should succeed");
    }

    // A live copier keeps its operation permit until it is dropped.
    #[tokio::test]
    async fn operation_semaphore_held_until_copier_dropped() {
        #[derive(Clone, Debug)]
        struct CopierBackend {
            info: Arc<AccessorInfo>,
        }

        impl Access for CopierBackend {
            type Reader = ();
            type Writer = ();
            type Lister = ();
            type Deleter = ();
            type Copier = oio::Copier;

            fn info(&self) -> Arc<AccessorInfo> {
                self.info.clone()
            }

            async fn copy(
                &self,
                _: &str,
                _: &str,
                _: OpCopy,
                _: OpCopier,
            ) -> Result<(RpCopy, Self::Copier)> {
                Ok((RpCopy::default(), Box::new(())))
            }

            async fn stat(&self, _: &str, _: OpStat) -> Result<RpStat> {
                Ok(RpStat::new(Metadata::new(EntryMode::FILE)))
            }
        }

        let semaphore = Arc::new(Semaphore::new(1));
        let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());
        let info = Arc::new(AccessorInfo::default());
        info.set_native_capability(Capability {
            copy: true,
            stat: true,
            ..Default::default()
        });

        let op = OperatorBuilder::new(CopierBackend { info })
            .layer(layer)
            .finish();

        let copier = timeout(Duration::from_millis(50), op.copier("from", "to"))
            .await
            .expect("copier setup should not block")
            .expect("copier should be created");

        let blocked = timeout(Duration::from_millis(50), op.stat("any")).await;
        assert!(
            blocked.is_err(),
            "stat should wait while the copier holds the permit"
        );

        drop(copier);

        timeout(Duration::from_millis(50), op.stat("any"))
            .await
            .expect("stat should proceed once the copier is dropped")
            .expect("stat should succeed");
    }

    // Reads 4096 bytes in 256-byte chunks with 4-way concurrency through an HTTP limit of 2.
    // The read must finish (no deadlock) and return the original content.
    #[tokio::test]
    async fn concurrent_chunked_read_with_http_limit() {
        use opendal_core::raw::*;

        // Fetcher that responds with the request body as the response body.
        struct EchoFetcher;

        impl HttpFetch for EchoFetcher {
            async fn fetch(&self, req: http::Request<Buffer>) -> Result<http::Response<HttpBody>> {
                let data = req.into_body();
                let len = data.len() as u64;
                let body =
                    HttpBody::new(Box::pin(stream::once(async move { Ok(data) })), Some(len));
                Ok(http::Response::builder()
                    .status(http::StatusCode::OK)
                    .body(body)
                    .unwrap())
            }
        }

        #[derive(Clone, Debug)]
        struct HttpBackend {
            info: Arc<AccessorInfo>,
            content: Buffer,
        }

        impl Access for HttpBackend {
            type Reader = HttpBody;
            type Writer = ();
            type Lister = ();
            type Deleter = ();
            type Copier = oio::Copier;

            fn info(&self) -> Arc<AccessorInfo> {
                self.info.clone()
            }

            async fn read(&self, _: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
                let range = args.range();
                let start = range.offset() as usize;
                let data = match range.size() {
                    Some(sz) => self.content.slice(start..start + sz as usize),
                    None => self.content.slice(start..),
                };
                // Route the read through the (layer-wrapped) HTTP client.
                let req = http::Request::get("http://fake").body(data).unwrap();
                let resp = self.info.http_client().fetch(req).await?;
                Ok((
                    RpRead::new(Metadata::new(EntryMode::FILE).with_content_length(0)),
                    resp.into_body(),
                ))
            }

            async fn stat(&self, _: &str, _: OpStat) -> Result<RpStat> {
                Ok(RpStat::new(
                    Metadata::new(EntryMode::FILE).with_content_length(self.content.len() as u64),
                ))
            }

            async fn write(&self, _: &str, _: OpWrite) -> Result<(RpWrite, Self::Writer)> {
                Err(Error::new(ErrorKind::Unsupported, "not needed"))
            }

            async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
                Err(Error::new(ErrorKind::Unsupported, "not needed"))
            }

            async fn list(&self, _: &str, _: OpList) -> Result<(RpList, Self::Lister)> {
                Err(Error::new(ErrorKind::Unsupported, "not needed"))
            }
        }

        let content = Buffer::from(vec![0u8; 4096]);
        let info = Arc::new(AccessorInfo::default());
        info.update_http_client(|_| HttpClient::with(EchoFetcher));

        let op = OperatorBuilder::new(HttpBackend {
            info,
            content: content.clone(),
        })
        .layer(ConcurrentLimitLayer::new(1024).with_http_concurrent_limit(2))
        .finish();

        let result = timeout(Duration::from_secs(5), async {
            op.reader_with("test")
                .chunk(256)
                .concurrent(4)
                .await
                .expect("reader must build")
                .read(..)
                .await
        })
        .await;

        let buf = result
            .expect("read must not deadlock (timeout)")
            .expect("read must succeed");
        assert_eq!(buf.to_bytes(), content.to_bytes());
    }

    // An HTTP permit stays held while the response body is alive and is released once the
    // body is dropped.
    #[tokio::test]
    async fn http_semaphore_holds_until_body_dropped() {
        struct DummyFetcher;

        impl HttpFetch for DummyFetcher {
            async fn fetch(&self, _req: http::Request<Buffer>) -> Result<Response<HttpBody>> {
                let body = HttpBody::new(stream::empty(), None);
                Ok(Response::builder()
                    .status(http::StatusCode::OK)
                    .body(body)
                    .expect("response must build"))
            }
        }

        let semaphore = Arc::new(Semaphore::new(1));
        let layer = ConcurrentLimitLayer::new(1).with_http_semaphore(semaphore.clone());
        let fetcher = ConcurrentLimitHttpFetcher::<Arc<Semaphore>> {
            inner: HttpClient::with(DummyFetcher).into_inner(),
            http_semaphore: layer.http_semaphore.clone(),
        };

        // Each fetch consumes its request, so build a fresh one every time.
        let new_request = || {
            http::Request::builder()
                .uri("http://example.invalid/")
                .body(Buffer::new())
                .expect("request must build")
        };

        let resp = fetcher
            .fetch(new_request())
            .await
            .expect("first fetch should succeed");

        let blocked = timeout(Duration::from_millis(50), fetcher.fetch(new_request())).await;
        assert!(
            blocked.is_err(),
            "http fetch should block while the body holds the permit"
        );

        // Dropping the response (and with it the body) must hand the permit back.
        drop(resp);

        timeout(Duration::from_millis(50), fetcher.fetch(new_request()))
            .await
            .expect("http fetch should proceed once the body is dropped")
            .expect("fetch after release should succeed");
    }
}