1use std::sync::Arc;
2use tokio::sync::{OwnedSemaphorePermit, Semaphore};
3
4#[async_trait::async_trait]
5pub trait Acquire: Clone + Send + Sync + 'static {
6 type Handle: Send + Sync + 'static;
7
8 async fn acquire(&self) -> Self::Handle;
9}
10
/// Wrapper that gates an inner `MakeOrt`/`Ort` behind an [`Acquire`] source.
///
/// Cloning a `Limit` clones both the acquire source and the inner value.
#[derive(Clone)]
pub struct Limit<A, M> {
    /// Where permission for each call is obtained.
    acquire: A,
    /// The wrapped maker or client that does the actual work.
    inner: M,
}
16
17impl<A, T> Limit<A, T> {
20 pub fn new(acquire: A, inner: T) -> Self {
21 Self { acquire, inner }
22 }
23}
24
25#[async_trait::async_trait]
26impl<T, A, M> crate::MakeOrt<T> for Limit<A, M>
27where
28 T: Send + 'static,
29 A: Acquire,
30 M: crate::MakeOrt<T>,
31 M::Ort: crate::Ort,
32{
33 type Ort = Limit<A, M::Ort>;
34
35 async fn make_ort(&mut self, target: T) -> Result<Self::Ort, crate::Error> {
36 let inner = self.inner.make_ort(target).await?;
37 let acquire = self.acquire.clone();
38 Ok(Limit { acquire, inner })
39 }
40}
41
42#[async_trait::async_trait]
43impl<A, O> crate::Ort for Limit<A, O>
44where
45 A: Acquire,
46 O: crate::Ort,
47{
48 async fn ort(&mut self, spec: crate::Spec) -> Result<crate::Reply, crate::Error> {
49 let permit = self.acquire.acquire().await;
50 let reply = self.inner.ort(spec).await;
51 drop(permit);
52 reply
53 }
54}
55
56#[async_trait::async_trait]
59impl<A: Acquire, B: Acquire> Acquire for (A, B) {
60 type Handle = (A::Handle, B::Handle);
61
62 async fn acquire(&self) -> Self::Handle {
63 tokio::join!(self.0.acquire(), self.1.acquire())
64 }
65}
66
67#[async_trait::async_trait]
68impl<A: Acquire> Acquire for Option<A> {
69 type Handle = Option<A::Handle>;
70
71 async fn acquire(&self) -> Self::Handle {
72 match self {
73 Some(semaphore) => Some(semaphore.acquire().await),
74 None => None,
75 }
76 }
77}
78
79#[async_trait::async_trait]
80impl Acquire for Arc<Semaphore> {
81 type Handle = OwnedSemaphorePermit;
82
83 async fn acquire(&self) -> Self::Handle {
84 self.clone()
85 .acquire_owned()
86 .await
87 .expect("Semaphore must not be closed")
88 }
89}