use crate::bucket::{Bucket, BucketFactory, QuotaInfo, RequestAllowed};
use crate::distribution_formula::DistributionFormula;
use crate::key_manager::{KeyManager, KeyManagerFactory};
use crate::{RateLimit, RateLimitError, RateLimitResult, RateLimitStatistics};
use data_storage_lib::ll::distributed::{
DistributedStorage, DistributedStorageClient, DistributedStorageError, Store,
};
use data_storage_lib::ll::local::{LocalStorage, LocalStorageError, SharedData};
use data_storage_lib::ll::{distributed, local};
use lock_lib::{Lock, LockBuilder, TryLock};
use pdk_core::classy::timer::Timer;
use pdk_core::classy::{Clock, TimeUnit};
use pdk_core::log::{debug, warn};
use std::rc::Rc;
use std::time::Duration;
// Upper bound on round-trips ("hops") to the distributed store for a single
// `is_allowed` call; exceeding it yields `RateLimitError::MaxHops`.
const MAX_HOPS: usize = 200;
// Lifetime of the distributed lock; it must be refreshed (see `refresh_lock`
// call sites) before this expires or the in-flight operation is abandoned.
const LOCK_EXPIRATION: Duration = Duration::from_secs(13);
/// A rate-limiter instance combining a fast local bucket (kept in shared
/// worker memory) with an optional distributed store used to pull additional
/// quota when the local share runs out.
pub struct RateLimitInstance {
    // Name of the distributed store; also reused as the namespace argument
    // when reading/writing remote buckets (see `fetch_quota`/`init_storage`).
    store: String,
    // Builds per-request key sets (data key, storage key, lock key).
    key_manager_factory: KeyManagerFactory,
    // Creates a fresh bucket when none exists in shared data yet.
    bucket_factory: BucketFactory,
    // Formula used to decide how much quota to move from remote to local.
    formula: DistributionFormula,
    clock: Rc<dyn Clock>,
    // Worker-local shared storage holding the serialized bucket (CAS-guarded).
    shared_data: Rc<SharedData>,
    // Optional: only needed on the distributed (out-of-local-quota) path;
    // accessed via an unwrapping getter.
    distributed_storage: Option<Rc<DistributedStorageClient>>,
    // Optional: used to back off while another worker holds the lock.
    timer: Option<Rc<Timer>>,
    lock_builder: Rc<LockBuilder>,
}
impl RateLimit for RateLimitInstance {
    /// Decides whether a request of `increment` units is allowed for the
    /// bucket identified by (`group_selector`, `bucket_selector`).
    ///
    /// Runs an optimistic-concurrency loop: read the local bucket together
    /// with its CAS token, decide, and write back; on a CAS mismatch (another
    /// worker wrote first) the whole read-decide-write cycle is retried.
    /// Returns `Err(RateLimitError::MaxHops)` if fetching remote quota needed
    /// more than `MAX_HOPS` round-trips.
    async fn is_allowed(
        &self,
        group_selector: &str,
        bucket_selector: &str,
        increment: usize,
    ) -> Result<RateLimitResult, RateLimitError> {
        // Counts round-trips to the distributed store across retries.
        let mut hops: usize = 0;
        let keys = self
            .key_manager_factory
            .create(bucket_selector, group_selector);
        loop {
            let now = self.now();
            // Load the local bucket (or create a fresh one) plus its CAS token.
            let (mut bucket, cas) = self.refresh_bucket(&keys, now, group_selector)?;
            match bucket.request_allowed(now, increment) {
                RequestAllowed::OutOfQuota => {
                    // Globally out of quota: reject without touching storage.
                    return Ok(RateLimitResult::TooManyRequests(RateLimitStatistics::from(
                        &bucket, now,
                    )));
                }
                RequestAllowed::OutOfLocalQuota => {
                    // Local share exhausted, but the distributed bucket may
                    // still have quota; try to fetch more under a lock, then
                    // re-evaluate on the next loop iteration.
                    if !self.hops_exceeded(&hops) {
                        debug!("Out of local quota");
                        let try_lock = self.lock(&keys);
                        self.lock_and_fetch_quota(
                            &keys,
                            group_selector,
                            &try_lock,
                            increment,
                            &mut hops,
                        )
                        .await?;
                    } else {
                        debug!("Max hop count reached");
                        return Err(RateLimitError::MaxHops);
                    }
                }
                RequestAllowed::Allowed => {
                    // Persist the updated bucket; the CAS token guards against
                    // concurrent writers sharing the same local storage.
                    let serialized = bincode::serialize(&bucket)?;
                    match self.shared_data.set(keys.data_key(), &serialized, cas) {
                        Ok(()) => {
                            debug!("Bucket saved: {bucket:?}");
                            return Ok(RateLimitResult::Allowed(RateLimitStatistics::from(
                                &bucket, now,
                            )));
                        }
                        Err(LocalStorageError::CasMismatch) => {
                            // Someone else updated the bucket first: retry.
                            debug!("Local cas mismatch.");
                        }
                        Err(e) => return Err(e.into()),
                    }
                }
            }
        }
    }
}
impl RateLimitInstance {
    /// Creates a new instance with all collaborators injected.
    ///
    /// `distributed_storage` and `timer` are optional: they are only touched
    /// on the distributed (out-of-local-quota) path; the accessors below
    /// panic if that path is entered without them being configured.
    #[allow(clippy::too_many_arguments)] // plain constructor mirroring the struct fields
    pub(crate) fn new(
        store: String,
        key_manager_factory: KeyManagerFactory,
        bucket_factory: BucketFactory,
        formula: DistributionFormula,
        clock: Rc<dyn Clock>,
        shared_data: Rc<SharedData>,
        // Renamed from `local_storage`: this is the *distributed* storage
        // client (the old name contradicted both its type and the field it
        // initializes). Rust arguments are positional, so callers are
        // unaffected.
        distributed_storage: Option<Rc<DistributedStorageClient>>,
        timer: Option<Rc<Timer>>,
        lock_builder: Rc<LockBuilder>,
    ) -> Self {
        Self {
            store,
            key_manager_factory,
            bucket_factory,
            formula,
            clock,
            shared_data,
            distributed_storage,
            timer,
            lock_builder,
        }
    }

    /// Tries to take the shared lock and, on success, fetches quota from the
    /// distributed store. If another worker holds the lock, backs off for
    /// 100 ms so the caller's retry loop does not spin.
    async fn lock_and_fetch_quota(
        &self,
        keys: &KeyManager,
        group_selector: &str,
        try_lock: &TryLock,
        amount: usize,
        hops: &mut usize,
    ) -> Result<(), RateLimitError> {
        if let Some(lock) = try_lock.try_lock() {
            let result = self
                .fetch_quota(keys, group_selector, &lock, amount, hops)
                .await;
            // Release explicitly before propagating the result.
            drop(lock);
            result
        } else {
            debug!("Other worker has the lock.");
            self.timer().sleep(Duration::from_millis(100)).await;
            Ok(())
        }
    }

    /// Pulls quota from the remote bucket and applies it to the local one.
    ///
    /// Each round-trip to the distributed store counts as one hop; after each
    /// hop the lock must still be held (`refresh_lock`) or the attempt is
    /// abandoned (returning `Ok` so the outer loop retries from scratch).
    async fn fetch_quota(
        &self,
        keys: &KeyManager,
        group_selector: &str,
        lock: &Lock<'_>,
        amount: usize,
        hops: &mut usize,
    ) -> Result<(), RateLimitError> {
        let get = self
            .distributed_storage()
            .get(&self.store, &self.store, keys.storage_key())
            .await;
        *hops += 1;
        if !lock.refresh_lock() {
            debug!("Lost the lock!!!");
            return Ok(());
        }
        match get {
            Ok((remote_bucket, cas)) => {
                let mut remote_bucket = bincode::deserialize::<Bucket>(remote_bucket.as_slice())?;
                let (local_bucket, _) = self.refresh_bucket(keys, self.now(), group_selector)?;
                // Ask the remote bucket for quota according to the
                // distribution formula; this mutates `remote_bucket` when
                // quota is actually handed out.
                let retrieved_quota =
                    remote_bucket.get_quota(self.now(), &local_bucket, &self.formula, amount);
                if !self.quota_given(&retrieved_quota) {
                    // Nothing was taken from the remote bucket, so there is no
                    // remote write to perform; the (empty) result is still
                    // pushed into the local bucket, mirroring the original
                    // flow. NOTE(review): presumably `update_quota` ignores
                    // all-None entries — confirm in Bucket::update_quota.
                    debug!("No quota could be obtained from the bucket.");
                    self.update_quota(keys, &retrieved_quota, group_selector)?;
                    Ok(())
                } else {
                    debug!(
                        "Storing updated remote bucket: {}:{}:{}",
                        self.store,
                        self.store,
                        keys.storage_key()
                    );
                    debug!("Storing updated remote bucket (data): {remote_bucket:?}");
                    let remote_bucket = bincode::serialize(&remote_bucket)?;
                    // Write the decremented remote bucket back under CAS so a
                    // concurrent remote writer cannot be silently overwritten.
                    let store = self
                        .distributed_storage()
                        .store(
                            &self.store,
                            &self.store,
                            keys.storage_key(),
                            &distributed::StoreMode::Cas(cas),
                            &remote_bucket,
                        )
                        .await;
                    *hops += 1;
                    if !lock.refresh_lock() {
                        debug!("Lost the lock!!!");
                        return Ok(());
                    }
                    match store {
                        Ok(()) => {
                            // Remote write succeeded: apply the quota locally.
                            self.update_quota(keys, &retrieved_quota, group_selector)?;
                            Ok(())
                        }
                        Err(DistributedStorageError::CasMismatch) => {
                            // Remote raced us; the quota we took is discarded
                            // and the outer loop will retry.
                            debug!("Remote cas mismatch while updating quota.");
                            Ok(())
                        }
                        Err(
                            DistributedStorageError::KeyNotFound
                            | DistributedStorageError::StoreNotFound,
                        ) => self.init_storage(keys, group_selector, lock).await,
                        Err(e) => Err(e.into()),
                    }
                }
            }
            Err(DistributedStorageError::KeyNotFound | DistributedStorageError::StoreNotFound) => {
                // First access for this store/key: lazily create both.
                self.init_storage(keys, group_selector, lock).await
            }
            Err(e) => Err(e.into()),
        }
    }

    /// Lazily creates the distributed store and seeds the remote bucket from
    /// the current local state. A concurrent initializer winning the race
    /// (`CasMismatch` on `StoreMode::Absent`) is treated as success.
    async fn init_storage(
        &self,
        keys: &KeyManager,
        group_selector: &str,
        lock: &Lock<'_>,
    ) -> Result<(), RateLimitError> {
        // Fixed log text: this line initializes the *store*, not a key
        // (the key is logged below).
        debug!("Initializing store {}.", self.store);
        let store = Store::new(self.store.clone(), None, None);
        if let Err(e) = self.distributed_storage().upsert_store(&store).await {
            // Best-effort: the store may already exist (another worker won).
            warn!("Ignoring error creating store: {e}");
        }
        if !lock.refresh_lock() {
            debug!("Lost the lock!!!");
            return Ok(());
        }
        debug!("Initializing key {}.", keys.storage_key());
        let (bucket, _) = self.refresh_bucket(keys, self.now(), group_selector)?;
        let bucket = bincode::serialize(&bucket)?;
        // `StoreMode::Absent` makes this a create-if-missing write.
        let store = self
            .distributed_storage()
            .store(
                &self.store,
                &self.store,
                keys.storage_key(),
                &distributed::StoreMode::Absent,
                &bucket,
            )
            .await;
        match store {
            Ok(()) | Err(DistributedStorageError::CasMismatch) => Ok(()),
            Err(e) => Err(e.into()),
        }
    }

    /// True when at least one entry of the retrieved quota was granted.
    fn quota_given(&self, retrieved_quota: &[QuotaInfo]) -> bool {
        retrieved_quota.iter().any(QuotaInfo::is_some)
    }

    /// Timer accessor. Panics if no timer was configured; only reachable from
    /// the distributed path, which requires one.
    fn timer(&self) -> &Timer {
        self.timer
            .as_ref()
            .expect("timer must be configured for distributed rate limiting")
    }

    /// Distributed-storage accessor. Panics if no client was configured; only
    /// reachable from the distributed path, which requires one.
    fn distributed_storage(&self) -> &DistributedStorageClient {
        self.distributed_storage
            .as_ref()
            .expect("distributed storage must be configured for distributed rate limiting")
    }

    /// Applies `retrieved_quota` to the local bucket under a CAS retry loop
    /// (same optimistic-concurrency pattern as `is_allowed`).
    fn update_quota(
        &self,
        keys: &KeyManager,
        retrieved_quota: &[QuotaInfo],
        group_selector: &str,
    ) -> Result<(), RateLimitError> {
        debug!("Updating quota");
        loop {
            let (mut bucket, cas) = self.refresh_bucket(keys, self.now(), group_selector)?;
            bucket.update_quota(retrieved_quota);
            let serialized = bincode::serialize(&bucket)?;
            match self.shared_data.set(keys.data_key(), &serialized, cas) {
                Ok(()) => {
                    debug!("Bucket saved: {bucket:?}");
                    return Ok(());
                }
                Err(LocalStorageError::CasMismatch) => {
                    // Another worker wrote first; reload and retry.
                    debug!("Local cas mismatch while updating quota.");
                }
                Err(e) => return Err(e.into()),
            }
        }
    }

    /// Current time in milliseconds, as provided by the injected clock.
    fn now(&self) -> u128 {
        self.clock.get_current_time_unit(TimeUnit::Milliseconds)
    }

    /// Whether the hop budget for a single `is_allowed` call is spent.
    // `&usize` (rather than `usize`) is kept to preserve the existing call
    // signature used by `is_allowed`.
    fn hops_exceeded(&self, hops: &usize) -> bool {
        *hops >= MAX_HOPS
    }

    /// Reads the bucket from shared data, returning it with the CAS mode to
    /// use when writing it back: `Cas(token)` if it existed, `Absent` if a
    /// fresh bucket had to be created.
    fn refresh_bucket(
        &self,
        keys: &KeyManager,
        now: u128,
        group_selector: &str,
    ) -> Result<(Bucket, local::StoreMode), RateLimitError> {
        let bucket = self
            .shared_data
            .get(keys.data_key())
            .map_err(RateLimitError::from)?;
        let result = match bucket {
            Some((bytes, cas)) => {
                let bucket = bincode::deserialize(&bytes)?;
                debug!("Retrieved Bucket from shared data: {bucket:?}");
                (bucket, local::StoreMode::Cas(cas))
            }
            None => (
                self.bucket_factory.create(now, group_selector),
                local::StoreMode::Absent,
            ),
        };
        Ok(result)
    }

    /// Builds a shared try-lock for this key set with the standard expiration.
    fn lock(&self, keys: &KeyManager) -> TryLock {
        self.lock_builder
            .new(keys.lock_key().to_string())
            .expiration(LOCK_EXPIRATION)
            .shared()
            .build()
    }
}