// load_balancer/limit.rs

use crate::{BoxLoadBalancer, LoadBalancer};
use async_trait::async_trait;
use std::sync::atomic::Ordering::{Acquire, Release};
use std::{
    future::Future,
    sync::{Arc, Weak, atomic::AtomicU64},
    time::{Duration, Instant},
};
use tokio::{
    spawn,
    sync::{Mutex, RwLock},
    task::{JoinHandle, yield_now},
    time::sleep,
};
15
/// One allocatable resource managed by a `LimitLoadBalancer`, together with
/// its per-interval allocation budget.
///
/// The balancer's background timer zeroes `count` at the end of every
/// interval.
pub struct Entry<T: Send + Sync + Clone + 'static> {
    /// How many times this entry may be handed out per interval; `0` disables
    /// the limit entirely.
    pub max_count: u64,
    /// How many times this entry has been handed out in the current interval.
    pub count: AtomicU64,
    /// The resource returned to callers when this entry is allocated.
    pub value: T,
}
31
32impl<T> Clone for Entry<T>
33where
34    T: Send + Sync + Clone + 'static,
35{
36    fn clone(&self) -> Self {
37        Self {
38            max_count: self.max_count.clone(),
39            count: self.count.load(Acquire).into(),
40            value: self.value.clone(),
41        }
42    }
43}
44
45/// Internal reference structure for `LimitLoadBalancer`.
46///
47/// Holds the entries and the interval timer.
48pub struct LimitLoadBalancerRef<T>
49where
50    T: Send + Sync + Clone + 'static,
51{
52    /// The entries managed by this load balancer.
53    pub entries: RwLock<Vec<Entry<T>>>,
54    /// Timer task handle for resetting counts periodically.
55    pub timer: Mutex<Option<JoinHandle<()>>>,
56    /// The interval at which counts are reset.
57    pub interval: RwLock<Duration>,
58    /// The next scheduled reset time.
59    pub next_reset: RwLock<Instant>,
60}
61
62/// A load balancer that limits the number of allocations per entry
63/// over a specified time interval.
64///
65/// This implementation supports both async and sync allocation.
66#[derive(Clone)]
67pub struct LimitLoadBalancer<T>
68where
69    T: Send + Sync + Clone + 'static,
70{
71    /// Shared reference to the internal state.
72    inner: Arc<LimitLoadBalancerRef<T>>,
73}
74
75impl<T> LimitLoadBalancer<T>
76where
77    T: Send + Sync + Clone + 'static,
78{
79    /// Create a new `LimitLoadBalancer` with the default interval of 1 second.
80    ///
81    /// # Arguments
82    ///
83    /// * `entries` - A vector of tuples `(max_count, value)`.
84    pub fn new(entries: Vec<(u64, T)>) -> Self {
85        Self::new_interval(entries, Duration::from_secs(1))
86    }
87
88    /// Create a new `LimitLoadBalancer` with a custom interval duration.
89    ///
90    /// # Arguments
91    ///
92    /// * `entries` - A vector of tuples `(max_count, value)`.
93    /// * `interval` - Duration after which allocation counts are reset.
94    pub fn new_interval(entries: Vec<(u64, T)>, interval: Duration) -> Self {
95        Self {
96            inner: Arc::new(LimitLoadBalancerRef {
97                entries: entries
98                    .into_iter()
99                    .map(|(max_count, value)| Entry {
100                        max_count,
101                        value,
102                        count: 0.into(),
103                    })
104                    .collect::<Vec<_>>()
105                    .into(),
106                timer: Mutex::new(None),
107                interval: interval.into(),
108                next_reset: RwLock::new(Instant::now() + interval),
109            }),
110        }
111    }
112
113    /// Update the load balancer using an async callback.
114    pub async fn update<F, R, N>(&self, handle: F) -> anyhow::Result<N>
115    where
116        F: Fn(Arc<LimitLoadBalancerRef<T>>) -> R,
117        R: Future<Output = anyhow::Result<N>>,
118    {
119        handle(self.inner.clone()).await
120    }
121
122    /// Asynchronously allocate an entry, skipping the specified index.
123    /// Loops until a valid entry is found.
124    pub async fn alloc_skip(&self, index: usize) -> (usize, T) {
125        loop {
126            if let Some(v) = self.try_alloc_skip(index) {
127                return v;
128            }
129
130            let now = Instant::now();
131
132            let next = *self.inner.next_reset.read().await;
133
134            let remaining = if now < next {
135                next - now
136            } else {
137                Duration::ZERO
138            };
139
140            if remaining > Duration::ZERO {
141                sleep(remaining).await;
142            } else {
143                yield_now().await;
144            }
145        }
146    }
147
148    /// Try to allocate an entry without awaiting.
149    /// Returns `None` immediately if no entry is available.
150    pub fn try_alloc_skip(&self, index: usize) -> Option<(usize, T)> {
151        if let Ok(mut timer_guard) = self.inner.timer.try_lock() {
152            if timer_guard.is_none() {
153                let this = self.inner.clone();
154
155                *timer_guard = Some(spawn(async move {
156                    let mut interval = *this.interval.read().await;
157
158                    *this.next_reset.write().await = Instant::now() + interval;
159
160                    loop {
161                        sleep(match this.interval.try_read() {
162                            Ok(v) => {
163                                interval = *v;
164                                interval
165                            }
166                            Err(_) => interval,
167                        })
168                        .await;
169
170                        let now = Instant::now();
171
172                        let entries = this.entries.read().await;
173
174                        for entry in entries.iter() {
175                            entry.count.store(0, Release);
176                        }
177
178                        *this.next_reset.write().await = now + interval;
179                    }
180                }));
181            }
182        }
183
184        if let Ok(entries) = self.inner.entries.try_read() {
185            for (i, n) in entries.iter().enumerate() {
186                if i == index {
187                    continue;
188                }
189
190                let count = n.count.load(Acquire);
191
192                if n.max_count == 0
193                    || count < n.max_count
194                        && n.count
195                            .compare_exchange(count, count + 1, Release, Acquire)
196                            .is_ok()
197                {
198                    return Some((i, n.value.clone()));
199                }
200            }
201        }
202
203        None
204    }
205}
206
207impl<T> LoadBalancer<T> for LimitLoadBalancer<T>
208where
209    T: Send + Sync + Clone + 'static,
210{
211    /// Asynchronously allocate a resource from the load balancer.
212    fn alloc(&self) -> impl Future<Output = T> + Send {
213        async move { self.alloc_skip(usize::MAX).await.1 }
214    }
215
216    /// Synchronously try to allocate a resource.
217    fn try_alloc(&self) -> Option<T> {
218        self.try_alloc_skip(usize::MAX).map(|v| v.1)
219    }
220}
221
222#[async_trait]
223impl<T> BoxLoadBalancer<T> for LimitLoadBalancer<T>
224where
225    T: Send + Sync + Clone + 'static,
226{
227    /// Asynchronously allocate a resource from the load balancer.
228    async fn alloc(&self) -> T {
229        self.alloc_skip(usize::MAX).await.1
230    }
231
232    /// Synchronously try to allocate a resource.
233    fn try_alloc(&self) -> Option<T> {
234        self.try_alloc_skip(usize::MAX).map(|v| v.1)
235    }
236}