load_balancer/
limit.rs

1use crate::{BoxLoadBalancer, LoadBalancer};
2use async_trait::async_trait;
3use std::sync::atomic::Ordering::Acquire;
4use std::sync::atomic::Ordering::Release;
5use std::{
6    future::Future,
7    sync::{Arc, atomic::AtomicU64},
8    time::Duration,
9};
10use tokio::{
11    spawn,
12    sync::{Mutex, RwLock},
13    task::{JoinHandle, yield_now},
14    time::sleep,
15};
16
/// One slot of a `LimitLoadBalancer`: a value together with its quota.
///
/// `max_count` caps how many times the value may be handed out, and
/// `count` records how many times it has been handed out so far.
pub struct Entry<T: Send + Sync + Clone + 'static> {
    /// Upper bound on allocations for this entry.
    ///
    /// `try_alloc_skip` treats a value of `0` as "no limit".
    pub max_count: u64,
    /// Number of allocations already handed out; the balancer's timer task
    /// stores `0` here each time it fires.
    pub count: AtomicU64,
    /// The resource (of type `T`) that is cloned out to callers.
    pub value: T,
}
32
33impl<T> Clone for Entry<T>
34where
35    T: Send + Sync + Clone + 'static,
36{
37    fn clone(&self) -> Self {
38        Self {
39            max_count: self.max_count.clone(),
40            count: self.count.load(Acquire).into(),
41            value: self.value.clone(),
42        }
43    }
44}
45
46/// Internal reference structure for `LimitLoadBalancer`.
47///
48/// Holds the entries and the interval timer.
49pub struct LimitLoadBalancerRef<T>
50where
51    T: Send + Sync + Clone + 'static,
52{
53    /// The entries managed by this load balancer.
54    pub entries: RwLock<Vec<Entry<T>>>,
55    /// Timer task handle for resetting counts periodically.
56    pub timer: Mutex<Option<JoinHandle<()>>>,
57    /// The interval at which counts are reset.
58    pub interval: RwLock<Duration>,
59}
60
61/// A load balancer that limits the number of allocations per entry
62/// over a specified time interval.
63///
64/// This implementation supports both async and sync allocation.
65#[derive(Clone)]
66pub struct LimitLoadBalancer<T>
67where
68    T: Send + Sync + Clone + 'static,
69{
70    /// Shared reference to the internal state.
71    inner: Arc<LimitLoadBalancerRef<T>>,
72}
73
74impl<T> LimitLoadBalancer<T>
75where
76    T: Send + Sync + Clone + 'static,
77{
78    /// Create a new `LimitLoadBalancer` with the default interval of 1 second.
79    ///
80    /// # Arguments
81    ///
82    /// * `entries` - A vector of tuples `(max_count, value)`.
83    pub fn new(entries: Vec<(u64, T)>) -> Self {
84        Self {
85            inner: Arc::new(LimitLoadBalancerRef {
86                entries: entries
87                    .into_iter()
88                    .map(|(max_count, value)| Entry {
89                        max_count,
90                        value,
91                        count: 0.into(),
92                    })
93                    .collect::<Vec<_>>()
94                    .into(),
95                timer: None.into(),
96                interval: Duration::from_secs(1).into(),
97            }),
98        }
99    }
100
101    /// Create a new `LimitLoadBalancer` with a custom interval duration.
102    ///
103    /// # Arguments
104    ///
105    /// * `entries` - A vector of tuples `(max_count, value)`.
106    /// * `interval` - Duration after which allocation counts are reset.
107    pub fn new_interval(entries: Vec<(u64, T)>, interval: Duration) -> Self {
108        Self {
109            inner: Arc::new(LimitLoadBalancerRef {
110                entries: entries
111                    .into_iter()
112                    .map(|(max_count, value)| Entry {
113                        max_count,
114                        value,
115                        count: 0.into(),
116                    })
117                    .collect::<Vec<_>>()
118                    .into(),
119                timer: Mutex::new(None),
120                interval: interval.into(),
121            }),
122        }
123    }
124
125    /// Update the load balancer using an async callback.
126    pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
127    where
128        F: Fn(Arc<LimitLoadBalancerRef<T>>) -> R,
129        R: Future<Output = anyhow::Result<()>>,
130    {
131        handle(self.inner.clone()).await
132    }
133
134    /// Asynchronously allocate an entry, skipping the specified index.
135    /// Loops until a valid entry is found.
136    pub async fn alloc_skip(&self, index: usize) -> (usize, T) {
137        loop {
138            match self.try_alloc_skip(index) {
139                Some(v) => return v,
140                _ => yield_now().await,
141            };
142        }
143    }
144
145    /// Try to allocate an entry without awaiting.
146    /// Returns `None` immediately if no entry is available.
147    pub fn try_alloc_skip(&self, index: usize) -> Option<(usize, T)> {
148        if let Ok(mut v) = self.inner.timer.try_lock() {
149            if v.is_none() {
150                let this = self.inner.clone();
151
152                *v = Some(spawn(async move {
153                    let mut interval = *this.interval.read().await;
154
155                    loop {
156                        sleep(match this.interval.try_read() {
157                            Ok(v) => {
158                                interval = *v;
159                                interval
160                            }
161                            Err(_) => interval,
162                        })
163                        .await;
164
165                        for i in this.entries.read().await.iter() {
166                            i.count.store(0, Release);
167                        }
168                    }
169                }));
170            }
171        }
172
173        if let Ok(v) = self.inner.entries.try_read() {
174            for (i, n) in v.iter().enumerate() {
175                if i == index {
176                    continue;
177                }
178
179                let count = n.count.load(Acquire);
180
181                if n.max_count == 0
182                    || count < n.max_count
183                        && n.count
184                            .compare_exchange(count, count + 1, Release, Acquire)
185                            .is_ok()
186                {
187                    return Some((i, n.value.clone()));
188                }
189            }
190        }
191
192        None
193    }
194}
195
196impl<T> LoadBalancer<T> for LimitLoadBalancer<T>
197where
198    T: Send + Sync + Clone + 'static,
199{
200    /// Asynchronously allocate a resource from the load balancer.
201    fn alloc(&self) -> impl Future<Output = T> + Send {
202        async move { self.alloc_skip(usize::MAX).await.1 }
203    }
204
205    /// Synchronously try to allocate a resource.
206    fn try_alloc(&self) -> Option<T> {
207        self.try_alloc_skip(usize::MAX).map(|v| v.1)
208    }
209}
210
211#[async_trait]
212impl<T> BoxLoadBalancer<T> for LimitLoadBalancer<T>
213where
214    T: Send + Sync + Clone + 'static,
215{
216    /// Asynchronously allocate a resource from the load balancer.
217    async fn alloc(&self) -> T {
218        self.alloc_skip(usize::MAX).await.1
219    }
220
221    /// Synchronously try to allocate a resource.
222    fn try_alloc(&self) -> Option<T> {
223        self.try_alloc_skip(usize::MAX).map(|v| v.1)
224    }
225}