Skip to main content

pdk_rate_limit_lib/
implementation.rs

1// Copyright (c) 2026, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5use crate::bucket::{Bucket, BucketFactory, QuotaInfo, RequestAllowed};
6use crate::distribution_formula::DistributionFormula;
7use crate::key_manager::{KeyManager, KeyManagerFactory};
8use crate::{RateLimit, RateLimitError, RateLimitResult, RateLimitStatistics};
9use data_storage_lib::ll::distributed::{
10    DistributedStorage, DistributedStorageClient, DistributedStorageError, Store,
11};
12use data_storage_lib::ll::local::{LocalStorage, LocalStorageError, SharedData};
13use data_storage_lib::ll::{distributed, local};
14use lock_lib::{Lock, LockBuilder, TryLock};
15use pdk_core::classy::timer::Timer;
16use pdk_core::classy::{Clock, TimeUnit};
17use pdk_core::log::{debug, warn};
18use std::rc::Rc;
19use std::time::Duration;
20
/// Maximum number of remote round trips `is_allowed` may spend on one request.
///
/// `fetch_quota` counts one hop per remote `get` and one per CAS `store`. Once the count
/// reaches this value, `is_allowed` fails with `RateLimitError::MaxHops` instead of trying
/// to fetch more quota.
const MAX_HOPS: usize = 200;
/// Expiration configured on the lock that guards remote quota fetches (see `RateLimitInstance::lock`).
const LOCK_EXPIRATION: Duration = Duration::from_millis(13_000);
23
24/// An implementation of the RateLimit trait. For creating a new instance, use the `RateLimitBuilder`.  
25pub struct RateLimitInstance {
26    store: String,
27    key_manager_factory: KeyManagerFactory,
28    bucket_factory: BucketFactory,
29    formula: DistributionFormula,
30    clock: Rc<dyn Clock>,
31    shared_data: Rc<SharedData>,
32    distributed_storage: Option<Rc<DistributedStorageClient>>,
33    timer: Option<Rc<Timer>>,
34    lock_builder: Rc<LockBuilder>,
35}
36
37impl RateLimit for RateLimitInstance {
38    async fn is_allowed(
39        &self,
40        group_selector: &str,
41        bucket_selector: &str,
42        increment: usize,
43    ) -> Result<RateLimitResult, RateLimitError> {
44        let mut hops: usize = 0;
45
46        let keys = self
47            .key_manager_factory
48            .create(bucket_selector, group_selector);
49
50        loop {
51            let now = self.now();
52            let (mut bucket, cas) = self.refresh_bucket(&keys, now, group_selector)?;
53
54            match bucket.request_allowed(now, increment) {
55                RequestAllowed::OutOfQuota => {
56                    return Ok(RateLimitResult::TooManyRequests(RateLimitStatistics::from(
57                        &bucket, now,
58                    )));
59                }
60                RequestAllowed::OutOfLocalQuota => {
61                    if !self.hops_exceeded(&hops) {
62                        debug!("Out of local quota");
63                        let try_lock = self.lock(&keys);
64                        self.lock_and_fetch_quota(
65                            &keys,
66                            group_selector,
67                            &try_lock,
68                            increment,
69                            &mut hops,
70                        )
71                        .await?;
72                    } else {
73                        debug!("Max hop count reached");
74                        return Err(RateLimitError::MaxHops);
75                    }
76                }
77                RequestAllowed::Allowed => {
78                    let serialized = bincode::serialize(&bucket)?;
79                    match self.shared_data.set(keys.data_key(), &serialized, cas) {
80                        Ok(()) => {
81                            debug!("Bucket saved: {bucket:?}");
82                            return Ok(RateLimitResult::Allowed(RateLimitStatistics::from(
83                                &bucket, now,
84                            )));
85                        }
86                        Err(LocalStorageError::CasMismatch) => {
87                            // In case of failed to store we go to the next loop
88                            debug!("Local cas mismatch.");
89                        }
90                        Err(e) => return Err(e.into()),
91                    }
92                }
93            }
94        }
95    }
96}
97
98impl RateLimitInstance {
99    #[allow(clippy::too_many_arguments)]
100    pub(crate) fn new(
101        store: String,
102        key_manager_factory: KeyManagerFactory,
103        bucket_factory: BucketFactory,
104        formula: DistributionFormula,
105        clock: Rc<dyn Clock>,
106        shared_data: Rc<SharedData>,
107        local_storage: Option<Rc<DistributedStorageClient>>,
108        timer: Option<Rc<Timer>>,
109        lock_builder: Rc<LockBuilder>,
110    ) -> Self {
111        Self {
112            store,
113            key_manager_factory,
114            bucket_factory,
115            formula,
116            clock,
117            shared_data,
118            distributed_storage: local_storage,
119            timer,
120            lock_builder,
121        }
122    }
123
124    async fn lock_and_fetch_quota(
125        &self,
126        keys: &KeyManager,
127        group_selector: &str,
128        try_lock: &TryLock,
129        amount: usize,
130        hops: &mut usize,
131    ) -> Result<(), RateLimitError> {
132        let lock = try_lock.try_lock();
133
134        if let Some(lock) = lock {
135            let result = self
136                .fetch_quota(keys, group_selector, &lock, amount, hops)
137                .await;
138            drop(lock);
139            result
140        } else {
141            // Another request is currently handling the fetch requests, we sleep for a tick and
142            // retry on the next loop.
143            // THIS IS A BUSY WAIT!!!
144            debug!("Other worker has the lock.");
145            self.timer().sleep(Duration::from_millis(100)).await;
146            Ok(())
147        }
148    }
149
150    async fn fetch_quota(
151        &self,
152        keys: &KeyManager,
153        group_selector: &str,
154        lock: &Lock<'_>,
155        amount: usize,
156        hops: &mut usize,
157    ) -> Result<(), RateLimitError> {
158        let get = self
159            .distributed_storage()
160            .get(&self.store, &self.store, keys.storage_key())
161            .await;
162
163        *hops += 1;
164
165        if !lock.refresh_lock() {
166            debug!("Lost the lock!!!");
167            return Ok(());
168        }
169
170        match get {
171            Ok((remote_bucket, cas)) => {
172                let mut remote_bucket = bincode::deserialize::<Bucket>(remote_bucket.as_slice())?;
173                let (local_bucket, _) = self.refresh_bucket(keys, self.now(), group_selector)?;
174
175                let retrieved_quota =
176                    remote_bucket.get_quota(self.now(), &local_bucket, &self.formula, amount);
177
178                if !self.quota_given(&retrieved_quota) {
179                    debug!("No quota could be obtained from the bucket.");
180                    self.update_quota(keys, &retrieved_quota, group_selector)?;
181
182                    Ok(())
183                } else {
184                    debug!(
185                        "Storing updated remote bucket: {}:{}:{}",
186                        self.store,
187                        self.store,
188                        keys.storage_key()
189                    );
190
191                    debug!("Storing updated remote bucket (data): {remote_bucket:?}");
192
193                    let remote_bucket = bincode::serialize(&remote_bucket)?;
194
195                    let store = self
196                        .distributed_storage()
197                        .store(
198                            &self.store,
199                            &self.store,
200                            keys.storage_key(),
201                            &distributed::StoreMode::Cas(cas),
202                            &remote_bucket,
203                        )
204                        .await;
205                    *hops += 1;
206
207                    if !lock.refresh_lock() {
208                        debug!("Lost the lock!!!");
209                        return Ok(());
210                    }
211
212                    match store {
213                        Ok(()) => {
214                            self.update_quota(keys, &retrieved_quota, group_selector)?;
215                            Ok(())
216                        }
217                        Err(DistributedStorageError::CasMismatch) => {
218                            debug!("Remote cas mismatch while updating quota.");
219                            // retry on next loop
220                            Ok(())
221                        }
222                        Err(
223                            DistributedStorageError::KeyNotFound
224                            | DistributedStorageError::StoreNotFound,
225                        ) => self.init_storage(keys, group_selector, lock).await,
226                        Err(e) => Err(e.into()),
227                    }
228                }
229            }
230            Err(DistributedStorageError::KeyNotFound | DistributedStorageError::StoreNotFound) => {
231                self.init_storage(keys, group_selector, lock).await
232            }
233            Err(e) => Err(e.into()),
234        }
235    }
236
237    async fn init_storage(
238        &self,
239        keys: &KeyManager,
240        group_selector: &str,
241        lock: &Lock<'_>,
242    ) -> Result<(), RateLimitError> {
243        debug!("Initializing storage for key {}.", self.store);
244
245        let store = Store::new(self.store.clone(), None, None);
246
247        if let Err(e) = self.distributed_storage().upsert_store(&store).await {
248            warn!("Ignoring error creating store: {e}");
249        }
250
251        if !lock.refresh_lock() {
252            debug!("Lost the lock!!!");
253            return Ok(());
254        }
255
256        // If the storage was missing then key was also missing, pre-emtive init.
257
258        debug!("Initializing key {}.", keys.storage_key());
259
260        let (bucket, _) = self.refresh_bucket(keys, self.now(), group_selector)?;
261        let bucket = bincode::serialize(&bucket)?;
262
263        let store = self
264            .distributed_storage()
265            .store(
266                &self.store,
267                &self.store,
268                keys.storage_key(),
269                &distributed::StoreMode::Absent,
270                &bucket,
271            )
272            .await;
273
274        match store {
275            Ok(()) | Err(DistributedStorageError::CasMismatch) => Ok(()),
276            Err(e) => Err(e.into()),
277        }
278    }
279
280    fn quota_given(&self, retrieved_quota: &[QuotaInfo]) -> bool {
281        retrieved_quota.iter().any(QuotaInfo::is_some)
282    }
283
284    fn timer(&self) -> &Timer {
285        self.timer.as_ref().unwrap()
286    }
287
288    fn distributed_storage(&self) -> &DistributedStorageClient {
289        self.distributed_storage.as_ref().unwrap()
290    }
291
292    fn update_quota(
293        &self,
294        keys: &KeyManager,
295        retrieved_quota: &[QuotaInfo],
296        group_selector: &str,
297    ) -> Result<(), RateLimitError> {
298        debug!("Updating quota");
299        loop {
300            let (mut bucket, cas) = self.refresh_bucket(keys, self.now(), group_selector)?;
301            bucket.update_quota(retrieved_quota);
302            let serialized = bincode::serialize(&bucket)?;
303            match self.shared_data.set(keys.data_key(), &serialized, cas) {
304                Ok(()) => {
305                    debug!("Bucket saved: {bucket:?}");
306                    return Ok(());
307                }
308                Err(LocalStorageError::CasMismatch) => {
309                    // In case of failed to store we go to the next loop
310                    debug!("Local cas mismatch while updating quota.");
311                }
312                Err(e) => return Err(e.into()),
313            }
314        }
315    }
316
317    fn now(&self) -> u128 {
318        self.clock.get_current_time_unit(TimeUnit::Milliseconds)
319    }
320
321    fn hops_exceeded(&self, hops: &usize) -> bool {
322        *hops >= MAX_HOPS
323    }
324
325    fn refresh_bucket(
326        &self,
327        keys: &KeyManager,
328        now: u128,
329        group_selector: &str,
330    ) -> Result<(Bucket, local::StoreMode), RateLimitError> {
331        let bucket = self
332            .shared_data
333            .get(keys.data_key())
334            .map_err(RateLimitError::from)?;
335
336        let result = match bucket {
337            Some((bytes, cas)) => {
338                let bucket = bincode::deserialize(&bytes)?;
339                debug!("Retrieved Bucket from shared data: {bucket:?}");
340                (bucket, local::StoreMode::Cas(cas))
341            }
342            None => (
343                self.bucket_factory.create(now, group_selector),
344                local::StoreMode::Absent,
345            ),
346        };
347
348        Ok(result)
349    }
350
351    fn lock(&self, keys: &KeyManager) -> TryLock {
352        self.lock_builder
353            .new(keys.lock_key().to_string())
354            .expiration(LOCK_EXPIRATION)
355            .shared() // Isolation already provided by the keyManager.
356            .build()
357    }
358}