// load_balancer/limit.rs

use crate::{BoxLoadBalancer, LoadBalancer};
use async_trait::async_trait;
use std::sync::atomic::Ordering::Acquire;
use std::sync::atomic::Ordering::Release;
use std::{
    future::Future,
    sync::{Arc, Weak, atomic::AtomicU64},
    time::Duration,
};
use tokio::{
    spawn,
    sync::{Mutex, RwLock},
    task::{JoinHandle, yield_now},
    time::sleep,
};

/// A single entry in the `LimitLoadBalancer`.
///
/// Tracks the maximum number of allowed allocations within the interval
/// and the current allocation count.
pub struct Entry<T>
where
    T: Send + Sync + Clone + 'static,
{
    /// Maximum number of allocations allowed for this entry in the interval.
    /// A value of `0` means the entry is never limited.
    pub max_count: u64,
    /// Current allocation count within the interval.
    pub count: AtomicU64,
    /// The underlying value/resource of type `T`.
    pub value: T,
}

/// Clones an entry by snapshotting its current allocation count.
///
/// The clone gets its own counter, initialised to the count observed at the
/// moment of cloning; later changes to either entry's count do not affect
/// the other.
impl<T> Clone for Entry<T>
where
    T: Send + Sync + Clone + 'static,
{
    fn clone(&self) -> Self {
        Self {
            // `u64` is `Copy`, so a plain read is enough (no `.clone()`).
            max_count: self.max_count,
            count: AtomicU64::new(self.count.load(Acquire)),
            value: self.value.clone(),
        }
    }
}

46/// Internal reference structure for `LimitLoadBalancer`.
47///
48/// Holds the entries and the interval timer.
49pub struct LimitLoadBalancerRef<T>
50where
51    T: Send + Sync + Clone + 'static,
52{
53    /// The entries managed by this load balancer.
54    pub entries: RwLock<Vec<Entry<T>>>,
55    /// Timer task handle for resetting counts periodically.
56    pub timer: Mutex<Option<JoinHandle<()>>>,
57    /// The interval at which counts are reset.
58    pub interval: RwLock<Duration>,
59}
60
61/// A load balancer that limits the number of allocations per entry
62/// over a specified time interval.
63///
64/// This implementation supports both async and sync allocation.
65#[derive(Clone)]
66pub struct LimitLoadBalancer<T>
67where
68    T: Send + Sync + Clone + 'static,
69{
70    /// Shared reference to the internal state.
71    inner: Arc<LimitLoadBalancerRef<T>>,
72}
73
74impl<T> LimitLoadBalancer<T>
75where
76    T: Send + Sync + Clone + 'static,
77{
78    /// Create a new `LimitLoadBalancer` with the default interval of 1 second.
79    pub fn new(entries: Vec<(u64, T)>) -> Self {
80        Self {
81            inner: Arc::new(LimitLoadBalancerRef {
82                entries: entries
83                    .into_iter()
84                    .map(|(max_count, value)| Entry {
85                        max_count,
86                        value,
87                        count: 0.into(),
88                    })
89                    .collect::<Vec<_>>()
90                    .into(),
91                timer: None.into(),
92                interval: Duration::from_secs(1).into(),
93            }),
94        }
95    }
96
97    /// Create a new `LimitLoadBalancer` with a custom interval duration.
98    pub fn new_interval(entries: Vec<(u64, T)>, interval: Duration) -> Self {
99        Self {
100            inner: Arc::new(LimitLoadBalancerRef {
101                entries: entries
102                    .into_iter()
103                    .map(|(max_count, value)| Entry {
104                        max_count,
105                        value,
106                        count: 0.into(),
107                    })
108                    .collect::<Vec<_>>()
109                    .into(),
110                timer: Mutex::new(None),
111                interval: interval.into(),
112            }),
113        }
114    }
115
116    /// Update the load balancer using an async callback.
117    pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
118    where
119        F: Fn(Arc<LimitLoadBalancerRef<T>>) -> R,
120        R: Future<Output = anyhow::Result<()>>,
121    {
122        handle(self.inner.clone()).await
123    }
124
125    /// Asynchronously allocate an entry, skipping the specified index.
126    /// Loops until a valid entry is found.
127    async fn alloc_skip(&self, index: usize) -> Option<(usize, T)> {
128        loop {
129            match self.try_alloc_skip(index) {
130                Some(v) => return Some(v),
131                _ => yield_now().await,
132            };
133        }
134    }
135
136    /// Try to allocate an entry without awaiting.
137    /// Returns `None` immediately if no entry is available.
138    fn try_alloc_skip(&self, index: usize) -> Option<(usize, T)> {
139        if let Ok(mut v) = self.inner.timer.try_lock() {
140            if v.is_none() {
141                let this = self.inner.clone();
142
143                *v = Some(spawn(async move {
144                    let mut interval = *this.interval.read().await;
145
146                    loop {
147                        sleep(match this.interval.try_read() {
148                            Ok(v) => {
149                                interval = *v;
150                                interval
151                            }
152                            Err(_) => interval,
153                        })
154                        .await;
155
156                        for i in this.entries.read().await.iter() {
157                            i.count.store(0, Release);
158                        }
159                    }
160                }));
161            }
162        }
163
164        if let Ok(v) = self.inner.entries.try_read() {
165            for (i, n) in v.iter().enumerate() {
166                if i == index {
167                    continue;
168                }
169
170                let count = n.count.load(Acquire);
171
172                if n.max_count == 0
173                    || count < n.max_count
174                        && n.count
175                            .compare_exchange(count, count + 1, Release, Acquire)
176                            .is_ok()
177                {
178                    return Some((i, n.value.clone()));
179                }
180            }
181        }
182
183        None
184    }
185}
186
187impl<T> LoadBalancer<T> for LimitLoadBalancer<T>
188where
189    T: Send + Sync + Clone + 'static,
190{
191    /// Asynchronously allocate a resource from the load balancer.
192    /// Returns `Some(T)` if an entry is available.
193    fn alloc(&self) -> impl Future<Output = Option<T>> + Send {
194        async move { self.alloc_skip(usize::MAX).await.map(|v| v.1) }
195    }
196
197    /// Synchronously try to allocate a resource.
198    /// Returns `Some(T)` if an entry is available.
199    fn try_alloc(&self) -> Option<T> {
200        self.try_alloc_skip(usize::MAX).map(|v| v.1)
201    }
202}
203
204#[async_trait]
205impl<T> BoxLoadBalancer<T> for LimitLoadBalancer<T>
206where
207    T: Send + Sync + Clone + 'static,
208{
209    /// Asynchronously allocate a resource from the load balancer.
210    /// Returns `Some(T)` if an entry is available.
211    async fn alloc(&self) -> Option<T> {
212        self.alloc_skip(usize::MAX).await.map(|v| v.1)
213    }
214
215    /// Synchronously try to allocate a resource.
216    fn try_alloc(&self) -> Option<T> {
217        self.try_alloc_skip(usize::MAX).map(|v| v.1)
218    }
219}