use std::{collections::HashSet, sync::Arc, time::Duration};
use async_trait::async_trait;
use casper_types::PublicKey;
use tokio::{
sync::{mpsc, oneshot},
time::Instant,
};
use tracing::{debug, info, warn};
use crate::types::NodeId;
/// How much unused allowance may be banked for later bursts, expressed as time at the configured
/// rate: the stored maximum is `resources_per_second * STORED_BUFFER_SECS`.
const STORED_BUFFER_SECS: Duration = Duration::from_millis(2_000);
/// A source of per-peer rate-limiting handles.
///
/// Implementations hand out one `LimiterHandle` per peer and can be informed about the current
/// validator sets, which class-aware implementations use to decide how strictly a peer is limited.
pub(super) trait Limiter: Send + Sync {
    /// Informs the limiter about the current sets of active and upcoming validators.
    fn update_validators(
        &self,
        active_validators: HashSet<PublicKey>,
        upcoming_validators: HashSet<PublicKey>,
    );

    /// Creates a handle through which resources are requested on behalf of a single peer.
    ///
    /// `validator_id` is the peer's validator public key, if it has one.
    fn create_handle(
        &self,
        peer_id: NodeId,
        validator_id: Option<PublicKey>,
    ) -> Box<dyn LimiterHandle>;
}
/// Per-peer handle used to obtain permission before consuming a resource.
#[async_trait]
pub(super) trait LimiterHandle: Send + Sync {
    /// Waits until `amount` units of the resource may be consumed.
    ///
    /// Implementations may return immediately if they impose no limit on this peer.
    async fn request_allowance(&self, amount: u32);
}
/// A limiter that never limits: every allowance request completes immediately.
#[derive(Debug)]
pub(super) struct Unlimited;

impl Limiter for Unlimited {
    fn create_handle(
        &self,
        _peer_id: NodeId,
        _validator_id: Option<PublicKey>,
    ) -> Box<dyn LimiterHandle> {
        Box::new(UnlimitedHandle)
    }

    /// Validator status is irrelevant when nothing is limited, so updates are ignored.
    fn update_validators(
        &self,
        _active_validators: HashSet<PublicKey>,
        _upcoming_validators: HashSet<PublicKey>,
    ) {
        // Intentionally empty.
    }
}

/// Handle issued by `Unlimited`; grants every request right away.
struct UnlimitedHandle;

#[async_trait]
impl LimiterHandle for UnlimitedHandle {
    async fn request_allowance(&self, _amount: u32) {
        // Nothing to wait for.
    }
}
/// A limiter that divides peers into classes based on their validator status.
///
/// Peers in the active or upcoming validator sets are not limited. All other ("bulk") peers share
/// a single budget of resources per second. While both validator sets are empty, nobody is
/// limited. The bookkeeping happens in a background worker reached through `sender`.
#[derive(Debug)]
pub(super) struct ClassBasedLimiter {
    /// Command channel to the background worker.
    sender: mpsc::UnboundedSender<ClassBasedCommand>,
}

/// Per-peer handle issued by a `ClassBasedLimiter`.
#[derive(Debug)]
struct ClassBasedHandle {
    /// Command channel to the limiter's background worker.
    sender: mpsc::UnboundedSender<ClassBasedCommand>,
    /// Identity of the peer this handle requests resources for.
    consumer_id: Arc<ConsumerId>,
}

/// Identity of a resource consumer (a peer), used to determine its class.
#[derive(Debug)]
struct ConsumerId {
    /// The peer's node ID. The limiter logic never reads it; it only appears in `Debug` output,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    peer_id: NodeId,
    /// The peer's validator public key, if it has one.
    validator_id: Option<PublicKey>,
}

/// Class a peer is assigned to when it requests resources.
enum PeerClass {
    /// Member of the active validator set; not limited.
    ActiveValidator,
    /// Member of the upcoming validator set; not limited.
    UpcomingValidator,
    /// Any other peer; draws from the shared, limited budget.
    Bulk,
}

/// Commands processed by the background worker of a `ClassBasedLimiter`.
enum ClassBasedCommand {
    /// Replaces the known validator sets.
    UpdateValidators {
        active_validators: HashSet<PublicKey>,
        upcoming_validators: HashSet<PublicKey>,
    },
    /// Requests `amount` units for consumer `id`; the worker answers through `responder` once the
    /// request has been granted.
    RequestResource {
        amount: u32,
        id: Arc<ConsumerId>,
        responder: oneshot::Sender<()>,
    },
    /// Makes the worker close its command channel, so it accepts no further commands.
    Shutdown,
}
impl ClassBasedLimiter {
    /// Creates a limiter that lets non-validator peers consume `resources_per_second` units per
    /// second in total, while peers in the validator sets are never throttled.
    ///
    /// The background worker is started with `tokio::spawn`, so this must be called from within a
    /// Tokio runtime.
    pub(super) fn new(resources_per_second: u32) -> Self {
        // Allowance that may be banked while idle: `STORED_BUFFER_SECS` worth at the given rate.
        let max_stored_resource =
            (f64::from(resources_per_second) * STORED_BUFFER_SECS.as_secs_f64()) as u32;

        let (sender, receiver) = mpsc::unbounded_channel();
        tokio::spawn(worker(receiver, resources_per_second, max_stored_resource));

        Self { sender }
    }
}
impl Limiter for ClassBasedLimiter {
    /// Creates a handle that forwards the peer's requests to the shared background worker.
    fn create_handle(
        &self,
        peer_id: NodeId,
        validator_id: Option<PublicKey>,
    ) -> Box<dyn LimiterHandle> {
        let consumer_id = Arc::new(ConsumerId {
            peer_id,
            validator_id,
        });
        let sender = self.sender.clone();

        Box::new(ClassBasedHandle {
            sender,
            consumer_id,
        })
    }

    /// Forwards the new validator sets to the background worker.
    fn update_validators(
        &self,
        active_validators: HashSet<PublicKey>,
        upcoming_validators: HashSet<PublicKey>,
    ) {
        let command = ClassBasedCommand::UpdateValidators {
            active_validators,
            upcoming_validators,
        };

        // A send error only means the worker has stopped; there is nothing left to update.
        let delivered = self.sender.send(command).is_ok();
        if !delivered {
            debug!("could not update validator data set of limiter, channel closed");
        }
    }
}
#[async_trait]
impl LimiterHandle for ClassBasedHandle {
async fn request_allowance(&self, amount: u32) {
let (responder, waiter) = oneshot::channel();
if self
.sender
.send(ClassBasedCommand::RequestResource {
amount,
id: self.consumer_id.clone(),
responder,
})
.is_err()
{
debug!("worker was shutdown, sending is unlimited");
} else if waiter.await.is_err() {
debug!("failed to await resource allowance, unlimited");
}
}
}
impl Drop for ClassBasedLimiter {
    /// Tells the background worker to shut down.
    ///
    /// Every handle holds its own clone of the sender, so the command channel stays open after the
    /// limiter itself is gone; an explicit command is therefore required.
    fn drop(&mut self) {
        let outcome = self.sender.send(ClassBasedCommand::Shutdown);
        if outcome.is_err() {
            warn!("error sending shutdown command to class based limiter");
        }
    }
}
async fn worker(
mut receiver: mpsc::UnboundedReceiver<ClassBasedCommand>,
resources_per_second: u32,
max_stored_resource: u32,
) {
let mut active_validators = HashSet::new();
let mut upcoming_validators = HashSet::new();
let mut resources_available: i64 = 0;
let mut last_refill: Instant = Instant::now();
let mut logged_uninitialized = false;
while let Some(msg) = receiver.recv().await {
match msg {
ClassBasedCommand::UpdateValidators {
active_validators: new_active_validators,
upcoming_validators: new_upcoming_validators,
} => {
active_validators = new_active_validators;
upcoming_validators = new_upcoming_validators;
debug!(
?active_validators,
?upcoming_validators,
"resource classes updated"
);
}
ClassBasedCommand::RequestResource {
amount,
id,
responder,
} => {
if active_validators.is_empty() && upcoming_validators.is_empty() {
if !logged_uninitialized {
logged_uninitialized = true;
info!("empty set of validators, not limiting resources at all");
}
continue;
}
let peer_class = if let Some(ref validator_id) = id.validator_id {
if active_validators.contains(validator_id) {
PeerClass::ActiveValidator
} else if upcoming_validators.contains(validator_id) {
PeerClass::UpcomingValidator
} else {
PeerClass::Bulk
}
} else {
PeerClass::Bulk
};
match peer_class {
PeerClass::ActiveValidator | PeerClass::UpcomingValidator => {
}
PeerClass::Bulk => {
while resources_available < 0 {
let now = Instant::now();
let elapsed = now - last_refill;
last_refill = now;
resources_available +=
((elapsed.as_nanos() * resources_per_second as u128)
/ 1_000_000_000) as i64;
resources_available =
resources_available.min(max_stored_resource as i64);
if resources_available < 0 {
let estimated_time_remaining = Duration::from_millis(
(-resources_available) as u64 * 1000
/ resources_per_second as u64,
);
tokio::time::sleep(estimated_time_remaining).await;
}
}
resources_available -= amount as i64;
}
}
if responder.send(()).is_err() {
debug!("resource requester disappeared before we could answer.")
}
}
ClassBasedCommand::Shutdown => {
receiver.close();
}
}
}
debug!("class based worker exiting");
}
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, time::Duration};

    use tokio::time::Instant;

    use super::{ClassBasedLimiter, Limiter, NodeId, PublicKey, Unlimited};
    use crate::{crypto::AsymmetricKeyExt, testing::init_logging};

    /// Upper bound on the duration of request sequences that are expected not to be throttled.
    const SHORT_TIME: Duration = Duration::from_millis(250);

    #[tokio::test]
    async fn unlimited_limiter_is_unlimited() {
        let mut rng = crate::new_rng();
        let unlimited = Unlimited;
        let handle = unlimited.create_handle(NodeId::random(&mut rng), None);

        let start = Instant::now();
        handle.request_allowance(0).await;
        handle.request_allowance(u32::MAX).await;
        handle.request_allowance(1).await;
        let end = Instant::now();

        assert!(end - start < SHORT_TIME);
    }

    #[tokio::test]
    async fn active_validator_is_unlimited() {
        let mut rng = crate::new_rng();
        let validator_id = PublicKey::random(&mut rng);
        let limiter = ClassBasedLimiter::new(1_000);

        let mut active_validators = HashSet::new();
        active_validators.insert(validator_id.clone());
        limiter.update_validators(active_validators, HashSet::new());

        let handle = limiter.create_handle(NodeId::random(&mut rng), Some(validator_id));

        let start = Instant::now();
        handle.request_allowance(0).await;
        handle.request_allowance(u32::MAX).await;
        handle.request_allowance(1).await;
        let end = Instant::now();

        assert!(end - start < SHORT_TIME);
    }

    #[tokio::test]
    async fn inactive_validator_limited() {
        let mut rng = crate::new_rng();
        let validator_id = PublicKey::random(&mut rng);
        let limiter = ClassBasedLimiter::new(1_000);

        // A non-empty validator set is required, otherwise nobody is limited at all.
        let mut active_validators = HashSet::new();
        active_validators.insert(PublicKey::random(&mut rng));
        limiter.update_validators(active_validators, HashSet::new());

        let handles = vec![
            limiter.create_handle(NodeId::random(&mut rng), Some(validator_id)),
            limiter.create_handle(NodeId::random(&mut rng), None),
        ];

        for handle in handles {
            let start = Instant::now();

            // Nothing is banked at the start, so at 1000/s the final request can only be granted
            // once the 9000 units granted before it have been refilled: about nine seconds.
            handle.request_allowance(1000).await;
            handle.request_allowance(1000).await;
            handle.request_allowance(1000).await;
            handle.request_allowance(2000).await;
            handle.request_allowance(4000).await;
            handle.request_allowance(1).await;
            let end = Instant::now();

            let diff = end - start;
            assert!(diff >= Duration::from_secs(9));
            assert!(diff <= Duration::from_secs(10));
        }
    }

    #[tokio::test]
    async fn nonvalidators_parallel_limited() {
        let mut rng = crate::new_rng();
        let validator_id = PublicKey::random(&mut rng);
        let limiter = ClassBasedLimiter::new(1_000);

        let start = Instant::now();

        let mut active_validators = HashSet::new();
        active_validators.insert(PublicKey::random(&mut rng));
        limiter.update_validators(active_validators, HashSet::new());

        // Collect eagerly so that all tasks are spawned, and run concurrently, before we start
        // joining them. A lazy iterator would only spawn each task after the previous one had
        // already been joined, making this test sequential.
        let join_handles: Vec<_> = (0..5)
            .map(|_| limiter.create_handle(NodeId::random(&mut rng), Some(validator_id.clone())))
            .map(|handle| {
                tokio::spawn(async move {
                    handle.request_allowance(500).await;
                    handle.request_allowance(150).await;
                    handle.request_allowance(350).await;
                    handle.request_allowance(1).await;
                })
            })
            .collect();

        for join_handle in join_handles {
            join_handle.await.expect("could not join task");
        }

        // 5 x 1001 units shared at 1000/s: roughly five seconds in total.
        let end = Instant::now();
        let diff = end - start;
        assert!(diff >= Duration::from_secs(5));
        assert!(diff <= Duration::from_secs(6));
    }

    #[tokio::test]
    async fn inactive_validators_unlimited_when_no_validators_known() {
        init_logging();

        let mut rng = crate::new_rng();
        let validator_id = PublicKey::random(&mut rng);
        let limiter = ClassBasedLimiter::new(1_000);

        limiter.update_validators(HashSet::new(), HashSet::new());

        let handles = vec![
            limiter.create_handle(NodeId::random(&mut rng), Some(validator_id)),
            limiter.create_handle(NodeId::random(&mut rng), None),
        ];

        for handle in handles {
            let start = Instant::now();
            handle.request_allowance(1000).await;
            handle.request_allowance(1000).await;
            handle.request_allowance(1000).await;
            handle.request_allowance(2000).await;
            handle.request_allowance(4000).await;
            handle.request_allowance(1).await;
            let end = Instant::now();

            let diff = end - start;
            assert!(diff <= SHORT_TIME);
        }
    }
}