ort_core/
limit.rs

1use std::sync::Arc;
2use tokio::sync::{OwnedSemaphorePermit, Semaphore};
3
4#[async_trait::async_trait]
5pub trait Acquire: Clone + Send + Sync + 'static {
6    type Handle: Send + Sync + 'static;
7
8    async fn acquire(&self) -> Self::Handle;
9}
10
/// Pairs an inner value `M` with an acquire strategy `A`.
///
/// As a `MakeOrt`, it builds the inner ort and wraps the result together with
/// a clone of the strategy. As an `Ort`, it obtains a handle from the strategy
/// before each inner call and holds it until that call returns.
#[derive(Clone, Debug)]
pub struct Limit<A, M> {
    /// Strategy consulted before each inner call.
    acquire: A,
    /// The wrapped value (a `MakeOrt` or an `Ort`).
    inner: M,
}
16
17// === impl Limit ===
18
19impl<A, T> Limit<A, T> {
20    pub fn new(acquire: A, inner: T) -> Self {
21        Self { acquire, inner }
22    }
23}
24
25#[async_trait::async_trait]
26impl<T, A, M> crate::MakeOrt<T> for Limit<A, M>
27where
28    T: Send + 'static,
29    A: Acquire,
30    M: crate::MakeOrt<T>,
31    M::Ort: crate::Ort,
32{
33    type Ort = Limit<A, M::Ort>;
34
35    async fn make_ort(&mut self, target: T) -> Result<Self::Ort, crate::Error> {
36        let inner = self.inner.make_ort(target).await?;
37        let acquire = self.acquire.clone();
38        Ok(Limit { acquire, inner })
39    }
40}
41
42#[async_trait::async_trait]
43impl<A, O> crate::Ort for Limit<A, O>
44where
45    A: Acquire,
46    O: crate::Ort,
47{
48    async fn ort(&mut self, spec: crate::Spec) -> Result<crate::Reply, crate::Error> {
49        let permit = self.acquire.acquire().await;
50        let reply = self.inner.ort(spec).await;
51        drop(permit);
52        reply
53    }
54}
55
// === impl Acquire ===
57
58#[async_trait::async_trait]
59impl<A: Acquire, B: Acquire> Acquire for (A, B) {
60    type Handle = (A::Handle, B::Handle);
61
62    async fn acquire(&self) -> Self::Handle {
63        tokio::join!(self.0.acquire(), self.1.acquire())
64    }
65}
66
67#[async_trait::async_trait]
68impl<A: Acquire> Acquire for Option<A> {
69    type Handle = Option<A::Handle>;
70
71    async fn acquire(&self) -> Self::Handle {
72        match self {
73            Some(semaphore) => Some(semaphore.acquire().await),
74            None => None,
75        }
76    }
77}
78
79#[async_trait::async_trait]
80impl Acquire for Arc<Semaphore> {
81    type Handle = OwnedSemaphorePermit;
82
83    async fn acquire(&self) -> Self::Handle {
84        self.clone()
85            .acquire_owned()
86            .await
87            .expect("Semaphore must not be closed")
88    }
89}