//! pdk-rate-limit-lib 1.7.0
//!
//! PDK Rate Limit Library documentation.
// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use crate::bucket::{Bucket, BucketFactory, QuotaInfo, RequestAllowed};
use crate::distribution_formula::DistributionFormula;
use crate::key_manager::{KeyManager, KeyManagerFactory};
use crate::{RateLimit, RateLimitError, RateLimitResult, RateLimitStatistics};
use data_storage_lib::ll::distributed::{
    DistributedStorage, DistributedStorageClient, DistributedStorageError, Store,
};
use data_storage_lib::ll::local::{LocalStorage, LocalStorageError, SharedData};
use data_storage_lib::ll::{distributed, local};
use lock_lib::{Lock, LockBuilder, TryLock};
use pdk_core::classy::timer::Timer;
use pdk_core::classy::{Clock, TimeUnit};
use pdk_core::log::{debug, warn};
use std::rc::Rc;
use std::time::Duration;

const MAX_HOPS: usize = 200;
const LOCK_EXPIRATION: Duration = Duration::from_secs(13);

/// An implementation of the RateLimit trait. For creating a new instance, use the `RateLimitBuilder`.  
pub struct RateLimitInstance {
    store: String,
    key_manager_factory: KeyManagerFactory,
    bucket_factory: BucketFactory,
    formula: DistributionFormula,
    clock: Rc<dyn Clock>,
    shared_data: Rc<SharedData>,
    distributed_storage: Option<Rc<DistributedStorageClient>>,
    timer: Option<Rc<Timer>>,
    lock_builder: Rc<LockBuilder>,
}

impl RateLimit for RateLimitInstance {
    async fn is_allowed(
        &self,
        group_selector: &str,
        bucket_selector: &str,
        increment: usize,
    ) -> Result<RateLimitResult, RateLimitError> {
        let mut hops: usize = 0;

        let keys = self
            .key_manager_factory
            .create(bucket_selector, group_selector);

        loop {
            let now = self.now();
            let (mut bucket, cas) = self.refresh_bucket(&keys, now, group_selector)?;

            match bucket.request_allowed(now, increment) {
                RequestAllowed::OutOfQuota => {
                    return Ok(RateLimitResult::TooManyRequests(RateLimitStatistics::from(
                        &bucket, now,
                    )));
                }
                RequestAllowed::OutOfLocalQuota => {
                    if !self.hops_exceeded(&hops) {
                        debug!("Out of local quota");
                        let try_lock = self.lock(&keys);
                        self.lock_and_fetch_quota(
                            &keys,
                            group_selector,
                            &try_lock,
                            increment,
                            &mut hops,
                        )
                        .await?;
                    } else {
                        debug!("Max hop count reached");
                        return Err(RateLimitError::MaxHops);
                    }
                }
                RequestAllowed::Allowed => {
                    let serialized = bincode::serialize(&bucket)?;
                    match self.shared_data.set(keys.data_key(), &serialized, cas) {
                        Ok(()) => {
                            debug!("Bucket saved: {bucket:?}");
                            return Ok(RateLimitResult::Allowed(RateLimitStatistics::from(
                                &bucket, now,
                            )));
                        }
                        Err(LocalStorageError::CasMismatch) => {
                            // In case of failed to store we go to the next loop
                            debug!("Local cas mismatch.");
                        }
                        Err(e) => return Err(e.into()),
                    }
                }
            }
        }
    }
}

impl RateLimitInstance {
    /// Creates a new instance; prefer the `RateLimitBuilder` over calling this
    /// directly.
    ///
    /// `distributed_storage` and `timer` are only unwrapped on the distributed
    /// (out-of-local-quota) path, so they must be `Some` whenever that path can
    /// be reached.
    // NOTE(review): the parameter was previously (and misleadingly) named
    // `local_storage` even though it carries the distributed-storage client;
    // Rust call sites are positional, so the rename does not affect callers.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        store: String,
        key_manager_factory: KeyManagerFactory,
        bucket_factory: BucketFactory,
        formula: DistributionFormula,
        clock: Rc<dyn Clock>,
        shared_data: Rc<SharedData>,
        distributed_storage: Option<Rc<DistributedStorageClient>>,
        timer: Option<Rc<Timer>>,
        lock_builder: Rc<LockBuilder>,
    ) -> Self {
        Self {
            store,
            key_manager_factory,
            bucket_factory,
            formula,
            clock,
            shared_data,
            distributed_storage,
            timer,
            lock_builder,
        }
    }

    /// Tries to take the per-key lock and, if acquired, fetches quota from the
    /// distributed storage. If another worker already holds the lock, sleeps
    /// for a short tick and returns `Ok(())` so the caller retries on its next
    /// loop iteration.
    async fn lock_and_fetch_quota(
        &self,
        keys: &KeyManager,
        group_selector: &str,
        try_lock: &TryLock,
        amount: usize,
        hops: &mut usize,
    ) -> Result<(), RateLimitError> {
        let lock = try_lock.try_lock();

        if let Some(lock) = lock {
            let result = self
                .fetch_quota(keys, group_selector, &lock, amount, hops)
                .await;
            // Release the lock before propagating the result.
            drop(lock);
            result
        } else {
            // Another request is currently handling the fetch requests, we sleep for a tick and
            // retry on the next loop.
            // THIS IS A BUSY WAIT!!!
            debug!("Other worker has the lock.");
            self.timer().sleep(Duration::from_millis(100)).await;
            Ok(())
        }
    }

    /// Fetches the remote bucket, takes quota from it for the local bucket,
    /// writes the decremented remote bucket back (CAS), and records the
    /// retrieved quota locally. Each remote round-trip increments `*hops`, and
    /// the lock is re-validated after every await; if the lock was lost the
    /// function bails out with `Ok(())` so the caller's loop retries.
    async fn fetch_quota(
        &self,
        keys: &KeyManager,
        group_selector: &str,
        lock: &Lock<'_>,
        amount: usize,
        hops: &mut usize,
    ) -> Result<(), RateLimitError> {
        let get = self
            .distributed_storage()
            .get(&self.store, &self.store, keys.storage_key())
            .await;

        *hops += 1;

        if !lock.refresh_lock() {
            debug!("Lost the lock!!!");
            return Ok(());
        }

        match get {
            Ok((remote_bucket, cas)) => {
                let mut remote_bucket = bincode::deserialize::<Bucket>(remote_bucket.as_slice())?;
                let (local_bucket, _) = self.refresh_bucket(keys, self.now(), group_selector)?;

                let retrieved_quota =
                    remote_bucket.get_quota(self.now(), &local_bucket, &self.formula, amount);

                if !self.quota_given(&retrieved_quota) {
                    // Nothing was taken from the remote bucket, so there is no
                    // remote write to do; only record the (empty) result locally.
                    debug!("No quota could be obtained from the bucket.");
                    self.update_quota(keys, &retrieved_quota, group_selector)?;

                    Ok(())
                } else {
                    debug!(
                        "Storing updated remote bucket: {}:{}:{}",
                        self.store,
                        self.store,
                        keys.storage_key()
                    );

                    debug!("Storing updated remote bucket (data): {remote_bucket:?}");

                    let remote_bucket = bincode::serialize(&remote_bucket)?;

                    // CAS write: only record the quota locally if the remote
                    // bucket we decremented is still the one we read.
                    let store = self
                        .distributed_storage()
                        .store(
                            &self.store,
                            &self.store,
                            keys.storage_key(),
                            &distributed::StoreMode::Cas(cas),
                            &remote_bucket,
                        )
                        .await;
                    *hops += 1;

                    if !lock.refresh_lock() {
                        debug!("Lost the lock!!!");
                        return Ok(());
                    }

                    match store {
                        Ok(()) => {
                            self.update_quota(keys, &retrieved_quota, group_selector)?;
                            Ok(())
                        }
                        Err(DistributedStorageError::CasMismatch) => {
                            debug!("Remote cas mismatch while updating quota.");
                            // retry on next loop
                            Ok(())
                        }
                        Err(
                            DistributedStorageError::KeyNotFound
                            | DistributedStorageError::StoreNotFound,
                        ) => self.init_storage(keys, group_selector, lock).await,
                        Err(e) => Err(e.into()),
                    }
                }
            }
            Err(DistributedStorageError::KeyNotFound | DistributedStorageError::StoreNotFound) => {
                self.init_storage(keys, group_selector, lock).await
            }
            Err(e) => Err(e.into()),
        }
    }

    /// Creates the remote store (best effort) and seeds the remote key with a
    /// freshly created bucket. A concurrent initialization (CAS mismatch on the
    /// `Absent` write) is treated as success; the caller's loop re-reads next
    /// time around.
    async fn init_storage(
        &self,
        keys: &KeyManager,
        group_selector: &str,
        lock: &Lock<'_>,
    ) -> Result<(), RateLimitError> {
        debug!("Initializing storage for store {}.", self.store);

        let store = Store::new(self.store.clone(), None, None);

        // The store may already exist (created by another worker); that error
        // is expected and ignored.
        if let Err(e) = self.distributed_storage().upsert_store(&store).await {
            warn!("Ignoring error creating store: {e}");
        }

        if !lock.refresh_lock() {
            debug!("Lost the lock!!!");
            return Ok(());
        }

        // If the storage was missing then key was also missing, pre-emptive init.

        debug!("Initializing key {}.", keys.storage_key());

        let (bucket, _) = self.refresh_bucket(keys, self.now(), group_selector)?;
        let bucket = bincode::serialize(&bucket)?;

        // `Absent` mode: only write if the key does not exist yet.
        let store = self
            .distributed_storage()
            .store(
                &self.store,
                &self.store,
                keys.storage_key(),
                &distributed::StoreMode::Absent,
                &bucket,
            )
            .await;

        match store {
            Ok(()) | Err(DistributedStorageError::CasMismatch) => Ok(()),
            Err(e) => Err(e.into()),
        }
    }

    /// True if at least one entry of the retrieved quota carries a value.
    fn quota_given(&self, retrieved_quota: &[QuotaInfo]) -> bool {
        retrieved_quota.iter().any(QuotaInfo::is_some)
    }

    /// # Panics
    /// Panics if the instance was built without a timer; only call on the
    /// distributed path.
    fn timer(&self) -> &Timer {
        self.timer.as_ref().unwrap()
    }

    /// # Panics
    /// Panics if the instance was built without a distributed-storage client;
    /// only call on the distributed path.
    fn distributed_storage(&self) -> &DistributedStorageClient {
        self.distributed_storage.as_ref().unwrap()
    }

    /// Merges `retrieved_quota` into the locally stored bucket, retrying the
    /// CAS write until it lands or a non-CAS error occurs.
    fn update_quota(
        &self,
        keys: &KeyManager,
        retrieved_quota: &[QuotaInfo],
        group_selector: &str,
    ) -> Result<(), RateLimitError> {
        debug!("Updating quota");
        loop {
            let (mut bucket, cas) = self.refresh_bucket(keys, self.now(), group_selector)?;
            bucket.update_quota(retrieved_quota);
            let serialized = bincode::serialize(&bucket)?;
            match self.shared_data.set(keys.data_key(), &serialized, cas) {
                Ok(()) => {
                    debug!("Bucket saved: {bucket:?}");
                    return Ok(());
                }
                Err(LocalStorageError::CasMismatch) => {
                    // In case of failed to store we go to the next loop
                    debug!("Local cas mismatch while updating quota.");
                }
                Err(e) => return Err(e.into()),
            }
        }
    }

    /// Current time in milliseconds, per the configured clock.
    fn now(&self) -> u128 {
        self.clock.get_current_time_unit(TimeUnit::Milliseconds)
    }

    /// Whether the hop budget (`MAX_HOPS`) has been used up.
    // NOTE(review): taking `&usize` for a `Copy` value is non-idiomatic, but
    // the signature is kept to match the existing call site in `is_allowed`.
    fn hops_exceeded(&self, hops: &usize) -> bool {
        *hops >= MAX_HOPS
    }

    /// Loads the bucket for `keys` from shared data, or creates a fresh one if
    /// none is stored. Returns the bucket together with the `StoreMode` (CAS
    /// token or `Absent`) to use when writing it back.
    fn refresh_bucket(
        &self,
        keys: &KeyManager,
        now: u128,
        group_selector: &str,
    ) -> Result<(Bucket, local::StoreMode), RateLimitError> {
        let bucket = self
            .shared_data
            .get(keys.data_key())
            .map_err(RateLimitError::from)?;

        let result = match bucket {
            Some((bytes, cas)) => {
                let bucket = bincode::deserialize(&bytes)?;
                debug!("Retrieved Bucket from shared data: {bucket:?}");
                (bucket, local::StoreMode::Cas(cas))
            }
            None => (
                self.bucket_factory.create(now, group_selector),
                local::StoreMode::Absent,
            ),
        };

        Ok(result)
    }

    /// Builds the `TryLock` guarding remote quota fetches for these keys.
    fn lock(&self, keys: &KeyManager) -> TryLock {
        self.lock_builder
            .new(keys.lock_key().to_string())
            .expiration(LOCK_EXPIRATION)
            .shared() // Isolation already provided by the keyManager.
            .build()
    }
}