1use crate::bucket::{Bucket, BucketFactory, QuotaInfo, RequestAllowed};
6use crate::distribution_formula::DistributionFormula;
7use crate::key_manager::{KeyManager, KeyManagerFactory};
8use crate::{RateLimit, RateLimitError, RateLimitResult, RateLimitStatistics};
9use data_storage_lib::ll::distributed::{
10 DistributedStorage, DistributedStorageClient, DistributedStorageError, Store,
11};
12use data_storage_lib::ll::local::{LocalStorage, LocalStorageError, SharedData};
13use data_storage_lib::ll::{distributed, local};
14use lock_lib::{Lock, LockBuilder, TryLock};
15use pdk_core::classy::timer::Timer;
16use pdk_core::classy::{Clock, TimeUnit};
17use pdk_core::log::{debug, warn};
18use std::rc::Rc;
19use std::time::Duration;
20
/// Upper bound on remote round trips (`hops`) that a single `is_allowed` call may perform
/// before giving up with `RateLimitError::MaxHops`.
const MAX_HOPS: usize = 200;

/// Expiration applied to every quota lock produced by `RateLimitInstance::lock`.
const LOCK_EXPIRATION: Duration = Duration::from_secs(13);
23
24pub struct RateLimitInstance {
26 store: String,
27 key_manager_factory: KeyManagerFactory,
28 bucket_factory: BucketFactory,
29 formula: DistributionFormula,
30 clock: Rc<dyn Clock>,
31 shared_data: Rc<SharedData>,
32 distributed_storage: Option<Rc<DistributedStorageClient>>,
33 timer: Option<Rc<Timer>>,
34 lock_builder: Rc<LockBuilder>,
35}
36
37impl RateLimit for RateLimitInstance {
38 async fn is_allowed(
39 &self,
40 group_selector: &str,
41 bucket_selector: &str,
42 increment: usize,
43 ) -> Result<RateLimitResult, RateLimitError> {
44 let mut hops: usize = 0;
45
46 let keys = self
47 .key_manager_factory
48 .create(bucket_selector, group_selector);
49
50 loop {
51 let now = self.now();
52 let (mut bucket, cas) = self.refresh_bucket(&keys, now, group_selector)?;
53
54 match bucket.request_allowed(now, increment) {
55 RequestAllowed::OutOfQuota => {
56 return Ok(RateLimitResult::TooManyRequests(RateLimitStatistics::from(
57 &bucket, now,
58 )));
59 }
60 RequestAllowed::OutOfLocalQuota => {
61 if !self.hops_exceeded(&hops) {
62 debug!("Out of local quota");
63 let try_lock = self.lock(&keys);
64 self.lock_and_fetch_quota(
65 &keys,
66 group_selector,
67 &try_lock,
68 increment,
69 &mut hops,
70 )
71 .await?;
72 } else {
73 debug!("Max hop count reached");
74 return Err(RateLimitError::MaxHops);
75 }
76 }
77 RequestAllowed::Allowed => {
78 let serialized = bincode::serialize(&bucket)?;
79 match self.shared_data.set(keys.data_key(), &serialized, cas) {
80 Ok(()) => {
81 debug!("Bucket saved: {bucket:?}");
82 return Ok(RateLimitResult::Allowed(RateLimitStatistics::from(
83 &bucket, now,
84 )));
85 }
86 Err(LocalStorageError::CasMismatch) => {
87 debug!("Local cas mismatch.");
89 }
90 Err(e) => return Err(e.into()),
91 }
92 }
93 }
94 }
95 }
96}
97
98impl RateLimitInstance {
99 #[allow(clippy::too_many_arguments)]
100 pub(crate) fn new(
101 store: String,
102 key_manager_factory: KeyManagerFactory,
103 bucket_factory: BucketFactory,
104 formula: DistributionFormula,
105 clock: Rc<dyn Clock>,
106 shared_data: Rc<SharedData>,
107 local_storage: Option<Rc<DistributedStorageClient>>,
108 timer: Option<Rc<Timer>>,
109 lock_builder: Rc<LockBuilder>,
110 ) -> Self {
111 Self {
112 store,
113 key_manager_factory,
114 bucket_factory,
115 formula,
116 clock,
117 shared_data,
118 distributed_storage: local_storage,
119 timer,
120 lock_builder,
121 }
122 }
123
124 async fn lock_and_fetch_quota(
125 &self,
126 keys: &KeyManager,
127 group_selector: &str,
128 try_lock: &TryLock,
129 amount: usize,
130 hops: &mut usize,
131 ) -> Result<(), RateLimitError> {
132 let lock = try_lock.try_lock();
133
134 if let Some(lock) = lock {
135 let result = self
136 .fetch_quota(keys, group_selector, &lock, amount, hops)
137 .await;
138 drop(lock);
139 result
140 } else {
141 debug!("Other worker has the lock.");
145 self.timer().sleep(Duration::from_millis(100)).await;
146 Ok(())
147 }
148 }
149
150 async fn fetch_quota(
151 &self,
152 keys: &KeyManager,
153 group_selector: &str,
154 lock: &Lock<'_>,
155 amount: usize,
156 hops: &mut usize,
157 ) -> Result<(), RateLimitError> {
158 let get = self
159 .distributed_storage()
160 .get(&self.store, &self.store, keys.storage_key())
161 .await;
162
163 *hops += 1;
164
165 if !lock.refresh_lock() {
166 debug!("Lost the lock!!!");
167 return Ok(());
168 }
169
170 match get {
171 Ok((remote_bucket, cas)) => {
172 let mut remote_bucket = bincode::deserialize::<Bucket>(remote_bucket.as_slice())?;
173 let (local_bucket, _) = self.refresh_bucket(keys, self.now(), group_selector)?;
174
175 let retrieved_quota =
176 remote_bucket.get_quota(self.now(), &local_bucket, &self.formula, amount);
177
178 if !self.quota_given(&retrieved_quota) {
179 debug!("No quota could be obtained from the bucket.");
180 self.update_quota(keys, &retrieved_quota, group_selector)?;
181
182 Ok(())
183 } else {
184 debug!(
185 "Storing updated remote bucket: {}:{}:{}",
186 self.store,
187 self.store,
188 keys.storage_key()
189 );
190
191 debug!("Storing updated remote bucket (data): {remote_bucket:?}");
192
193 let remote_bucket = bincode::serialize(&remote_bucket)?;
194
195 let store = self
196 .distributed_storage()
197 .store(
198 &self.store,
199 &self.store,
200 keys.storage_key(),
201 &distributed::StoreMode::Cas(cas),
202 &remote_bucket,
203 )
204 .await;
205 *hops += 1;
206
207 if !lock.refresh_lock() {
208 debug!("Lost the lock!!!");
209 return Ok(());
210 }
211
212 match store {
213 Ok(()) => {
214 self.update_quota(keys, &retrieved_quota, group_selector)?;
215 Ok(())
216 }
217 Err(DistributedStorageError::CasMismatch) => {
218 debug!("Remote cas mismatch while updating quota.");
219 Ok(())
221 }
222 Err(
223 DistributedStorageError::KeyNotFound
224 | DistributedStorageError::StoreNotFound,
225 ) => self.init_storage(keys, group_selector, lock).await,
226 Err(e) => Err(e.into()),
227 }
228 }
229 }
230 Err(DistributedStorageError::KeyNotFound | DistributedStorageError::StoreNotFound) => {
231 self.init_storage(keys, group_selector, lock).await
232 }
233 Err(e) => Err(e.into()),
234 }
235 }
236
237 async fn init_storage(
238 &self,
239 keys: &KeyManager,
240 group_selector: &str,
241 lock: &Lock<'_>,
242 ) -> Result<(), RateLimitError> {
243 debug!("Initializing storage for key {}.", self.store);
244
245 let store = Store::new(self.store.clone(), None, None);
246
247 if let Err(e) = self.distributed_storage().upsert_store(&store).await {
248 warn!("Ignoring error creating store: {e}");
249 }
250
251 if !lock.refresh_lock() {
252 debug!("Lost the lock!!!");
253 return Ok(());
254 }
255
256 debug!("Initializing key {}.", keys.storage_key());
259
260 let (bucket, _) = self.refresh_bucket(keys, self.now(), group_selector)?;
261 let bucket = bincode::serialize(&bucket)?;
262
263 let store = self
264 .distributed_storage()
265 .store(
266 &self.store,
267 &self.store,
268 keys.storage_key(),
269 &distributed::StoreMode::Absent,
270 &bucket,
271 )
272 .await;
273
274 match store {
275 Ok(()) | Err(DistributedStorageError::CasMismatch) => Ok(()),
276 Err(e) => Err(e.into()),
277 }
278 }
279
280 fn quota_given(&self, retrieved_quota: &[QuotaInfo]) -> bool {
281 retrieved_quota.iter().any(QuotaInfo::is_some)
282 }
283
284 fn timer(&self) -> &Timer {
285 self.timer.as_ref().unwrap()
286 }
287
288 fn distributed_storage(&self) -> &DistributedStorageClient {
289 self.distributed_storage.as_ref().unwrap()
290 }
291
292 fn update_quota(
293 &self,
294 keys: &KeyManager,
295 retrieved_quota: &[QuotaInfo],
296 group_selector: &str,
297 ) -> Result<(), RateLimitError> {
298 debug!("Updating quota");
299 loop {
300 let (mut bucket, cas) = self.refresh_bucket(keys, self.now(), group_selector)?;
301 bucket.update_quota(retrieved_quota);
302 let serialized = bincode::serialize(&bucket)?;
303 match self.shared_data.set(keys.data_key(), &serialized, cas) {
304 Ok(()) => {
305 debug!("Bucket saved: {bucket:?}");
306 return Ok(());
307 }
308 Err(LocalStorageError::CasMismatch) => {
309 debug!("Local cas mismatch while updating quota.");
311 }
312 Err(e) => return Err(e.into()),
313 }
314 }
315 }
316
317 fn now(&self) -> u128 {
318 self.clock.get_current_time_unit(TimeUnit::Milliseconds)
319 }
320
321 fn hops_exceeded(&self, hops: &usize) -> bool {
322 *hops >= MAX_HOPS
323 }
324
325 fn refresh_bucket(
326 &self,
327 keys: &KeyManager,
328 now: u128,
329 group_selector: &str,
330 ) -> Result<(Bucket, local::StoreMode), RateLimitError> {
331 let bucket = self
332 .shared_data
333 .get(keys.data_key())
334 .map_err(RateLimitError::from)?;
335
336 let result = match bucket {
337 Some((bytes, cas)) => {
338 let bucket = bincode::deserialize(&bytes)?;
339 debug!("Retrieved Bucket from shared data: {bucket:?}");
340 (bucket, local::StoreMode::Cas(cas))
341 }
342 None => (
343 self.bucket_factory.create(now, group_selector),
344 local::StoreMode::Absent,
345 ),
346 };
347
348 Ok(result)
349 }
350
351 fn lock(&self, keys: &KeyManager) -> TryLock {
352 self.lock_builder
353 .new(keys.lock_key().to_string())
354 .expiration(LOCK_EXPIRATION)
355 .shared() .build()
357 }
358}