use bitvec::order::Lsb0;
use bitvec::slice::BitSlice;
use borsh::{BorshDeserialize, BorshSerialize};
use near_parameters::RuntimeConfig;
use near_primitives_core::hash::CryptoHash;
use near_primitives_core::types::ShardId;
use near_schema_checker_lib::ProtocolSchema;
use std::collections::BTreeMap;
use std::num::NonZeroU64;
/// An amount of data in the context of cross-shard bandwidth.
///
/// Receipt sizes are accumulated directly into values of this type, so the unit is bytes.
pub type Bandwidth = u64;
/// Versioned list of bandwidth requests made by a single shard.
///
/// `BlockBandwidthRequests` stores one of these per `ShardId`. Each contained
/// [`BandwidthRequest`] targets the shard named in its `to_shard` field.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Debug,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
// Borsh writes the explicit discriminant below as the variant tag, so existing
// variants must keep their numbers for serialized data to stay readable.
#[borsh(use_discriminant = true)]
#[repr(u8)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum BandwidthRequests {
    V1(BandwidthRequestsV1) = 0,
}
impl BandwidthRequests {
pub fn empty() -> BandwidthRequests {
BandwidthRequests::V1(BandwidthRequestsV1 { requests: Vec::new() })
}
}
/// Version 1 of [`BandwidthRequests`]: a plain list of per-destination requests.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Default,
    Debug,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct BandwidthRequestsV1 {
    /// The individual requests; each one names its destination in `to_shard`.
    pub requests: Vec<BandwidthRequest>,
}
/// A request for bandwidth towards one shard.
///
/// Built by `BandwidthRequest::make_from_receipt_sizes` from the sizes of the receipts
/// passed to it for `to_shard`.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Debug,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct BandwidthRequest {
    /// Destination shard id, stored as `u16` (converted from `ShardId` with `.into()`).
    pub to_shard: u16,
    /// Bit `i` set means the request asks for `BandwidthRequestValues::values[i]`.
    pub requested_values_bitmap: BandwidthRequestBitmap,
}
impl BandwidthRequest {
    /// Builds the bandwidth request for `to_shard` from the sizes (in bytes) of the
    /// receipts to be sent there, in sending order.
    ///
    /// For every prefix of the receipt list whose total size exceeds
    /// `params.base_bandwidth`, the smallest entry of [`BandwidthRequestValues`] that is
    /// at least that total gets its bit set. Once a prefix total exceeds every value,
    /// iteration stops and the remaining items of `receipt_sizes` are not pulled.
    ///
    /// Returns `Ok(None)` when no bit would be set (no receipts, or all of them fit
    /// within `base_bandwidth`).
    ///
    /// # Errors
    /// Returns the first `Err` yielded by `receipt_sizes` (before the stopping point).
    ///
    /// # Panics
    /// Panics if the running total of receipt sizes overflows `u64`.
    pub fn make_from_receipt_sizes<E>(
        to_shard: ShardId,
        receipt_sizes: impl Iterator<Item = Result<u64, E>>,
        params: &BandwidthSchedulerParams,
    ) -> Result<Option<BandwidthRequest>, E> {
        let request_values = BandwidthRequestValues::new(params).values;
        let mut requested_values_bitmap = BandwidthRequestBitmap::new();

        let mut prefix_size: u64 = 0;
        // Candidate value index. Prefix totals never decrease and the values are
        // non-decreasing, so this cursor only ever moves forward.
        let mut value_idx: usize = 0;
        for size_res in receipt_sizes {
            prefix_size = prefix_size.checked_add(size_res?).expect(
                "Total size of receipts doesn't fit in u64, are there exabytes of receipts?",
            );

            // Prefixes that fit within the base bandwidth don't request anything.
            if prefix_size <= params.base_bandwidth {
                continue;
            }

            // Skip every value that is too small to cover the current prefix.
            value_idx += request_values[value_idx..]
                .iter()
                .take_while(|&&value| value < prefix_size)
                .count();

            // No value is large enough; larger prefixes won't find one either.
            if value_idx >= request_values.len() {
                break;
            }
            requested_values_bitmap.set_bit(value_idx, true);
        }

        // A request that doesn't ask for anything is pointless, so report `None` instead.
        let has_request = !requested_values_bitmap.is_all_zeros();
        Ok(has_request
            .then(|| BandwidthRequest { to_shard: to_shard.into(), requested_values_bitmap }))
    }
}
/// How many distinct bandwidth amounts a [`BandwidthRequest`] can ask for; this is also
/// the number of bits in a `BandwidthRequestBitmap`.
pub const BANDWIDTH_REQUEST_VALUES_NUM: usize = 40;
/// The bandwidth amounts that a [`BandwidthRequest`] can ask for.
///
/// Bit `i` of a request's bitmap refers to `values[i]`. `BandwidthRequestValues::new`
/// fills them by linear interpolation from `base_bandwidth` up to `max_single_grant`,
/// so they are non-decreasing and the last one equals `max_single_grant`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BandwidthRequestValues {
    pub values: [Bandwidth; BANDWIDTH_REQUEST_VALUES_NUM],
}
/// Integer linear interpolation: returns `min + (max - min) * i / n`, with the
/// fractional part truncated.
///
/// `i == 0` yields `min` and `i == n` yields `max`. Panics if `n == 0`. In debug builds it
/// also panics when `max < min` or when `(max - min) * i` overflows `u64`.
fn interpolate(min: u64, max: u64, i: u64, n: u64) -> u64 {
    let range = max - min;
    let offset = range * i / n;
    min + offset
}
impl BandwidthRequestValues {
    /// Computes the `BANDWIDTH_REQUEST_VALUES_NUM` request values for `params`.
    ///
    /// Entry `k` (0-based) is `interpolate(base_bandwidth, max_single_grant, k + 1, N)`.
    /// The values are therefore evenly spaced (rounded down) above `base_bandwidth`, and the
    /// last one equals `max_single_grant`.
    pub fn new(params: &BandwidthSchedulerParams) -> BandwidthRequestValues {
        let steps_total: u64 = BANDWIDTH_REQUEST_VALUES_NUM
            .try_into()
            .expect("Converting usize to u64 shouldn't fail");
        let mut values = [0; BANDWIDTH_REQUEST_VALUES_NUM];
        for (idx, value) in values.iter_mut().enumerate() {
            let step: u64 = idx.try_into().expect("Converting usize to u64 shouldn't fail");
            // Start from step 1 so that even the smallest value is above base_bandwidth.
            *value =
                interpolate(params.base_bandwidth, params.max_single_grant, step + 1, steps_total);
        }
        BandwidthRequestValues { values }
    }
}
/// Fixed-size bitmap with one bit per entry of [`BandwidthRequestValues`].
///
/// `get_bit`/`set_bit` view `data` through `bitvec` with `Lsb0` ordering: bit `i` lives in
/// byte `i / 8`, counted from the least-significant bit. `Debug` is implemented manually
/// to print the bits as a `0`/`1` string.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct BandwidthRequestBitmap {
    pub data: [u8; BANDWIDTH_REQUEST_BITMAP_SIZE],
}
/// Number of bytes in a `BandwidthRequestBitmap`: one bit per request value.
pub const BANDWIDTH_REQUEST_BITMAP_SIZE: usize = BANDWIDTH_REQUEST_VALUES_NUM / 8;
// Compile-time check: the division above must be exact so that no bit is left unused.
const _: () = assert!(
    BANDWIDTH_REQUEST_VALUES_NUM % 8 == 0,
    "Every bit in the bitmap should be used. It's wasteful to have unused bits.
And having unused bits would require extra validation logic"
);
/// Prints the bitmap as `BandwidthRequestBitmap(<bits>)`, where `<bits>` lists every bit
/// from index 0 upwards as `0` or `1`.
impl std::fmt::Debug for BandwidthRequestBitmap {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("BandwidthRequestBitmap(")?;
        for bit_idx in 0..self.len() {
            let digit = if self.get_bit(bit_idx) { "1" } else { "0" };
            f.write_str(digit)?;
        }
        f.write_str(")")
    }
}
impl BandwidthRequestBitmap {
    /// Creates a bitmap with every bit cleared.
    pub fn new() -> BandwidthRequestBitmap {
        BandwidthRequestBitmap { data: [0u8; BANDWIDTH_REQUEST_BITMAP_SIZE] }
    }

    /// Reads bit `idx` (`Lsb0` order within each byte).
    ///
    /// # Panics
    /// Panics if `idx >= self.len()`.
    pub fn get_bit(&self, idx: usize) -> bool {
        assert!(idx < self.len());
        BitSlice::<u8, Lsb0>::from_slice(&self.data)[idx]
    }

    /// Writes `val` into bit `idx` (`Lsb0` order within each byte).
    ///
    /// # Panics
    /// Panics if `idx >= self.len()`.
    pub fn set_bit(&mut self, idx: usize, val: bool) {
        assert!(idx < self.len());
        BitSlice::<u8, Lsb0>::from_slice_mut(&mut self.data).set(idx, val);
    }

    /// Number of bits, always `BANDWIDTH_REQUEST_VALUES_NUM`.
    pub fn len(&self) -> usize {
        BANDWIDTH_REQUEST_VALUES_NUM
    }

    /// Returns `true` when no bit is set.
    pub fn is_all_zeros(&self) -> bool {
        self.data.iter().all(|&byte| byte == 0)
    }
}
/// Bandwidth requests of all shards, one [`BandwidthRequests`] per shard.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockBandwidthRequests {
    // NOTE(review): presumably keyed by the requesting (sending) shard — confirm at call sites.
    pub shards_bandwidth_requests: BTreeMap<ShardId, BandwidthRequests>,
}
impl BlockBandwidthRequests {
pub fn empty() -> BlockBandwidthRequests {
BlockBandwidthRequests { shards_bandwidth_requests: BTreeMap::new() }
}
}
/// Versioned state of the bandwidth scheduler.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
// Borsh uses the explicit discriminant as the variant tag; keep existing numbers stable.
#[borsh(use_discriminant = true)]
#[repr(u8)]
pub enum BandwidthSchedulerState {
    V1(BandwidthSchedulerStateV1) = 0,
}
/// Version 1 of [`BandwidthSchedulerState`].
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
pub struct BandwidthSchedulerStateV1 {
    /// Allowance per (sender, receiver) link; see [`LinkAllowance`].
    pub link_allowances: Vec<LinkAllowance>,
    // NOTE(review): semantics not visible in this file — presumably a hash used to detect
    // diverging scheduler state between nodes; confirm against the scheduler implementation.
    pub sanity_check_hash: CryptoHash,
}
/// Bandwidth allowance currently associated with the link from `sender` to `receiver`.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
pub struct LinkAllowance {
    pub sender: ShardId,
    pub receiver: ShardId,
    // NOTE(review): presumably bounded by `BandwidthSchedulerParams::max_allowance`; the
    // bound isn't enforced in this file — confirm in the scheduler.
    pub allowance: Bandwidth,
}
/// Parameters of the bandwidth scheduler, derived from the runtime config and the number
/// of shards by `BandwidthSchedulerParams::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct BandwidthSchedulerParams {
    /// Receipt prefixes whose total size is at most this amount don't set any bit in a
    /// `BandwidthRequest`. It is also the lower end of the interpolation that produces
    /// `BandwidthRequestValues`.
    pub base_bandwidth: Bandwidth,
    /// Upper limit used to derive `base_bandwidth`; must be at least `max_single_grant`.
    pub max_shard_bandwidth: Bandwidth,
    /// The largest request value (last entry of `BandwidthRequestValues`); must be at
    /// least `max_receipt_size`.
    pub max_single_grant: Bandwidth,
    /// Taken from the runtime config's wasm limit config.
    pub max_receipt_size: Bandwidth,
    /// Taken from the bandwidth scheduler config; not used in this file.
    pub max_allowance: Bandwidth,
}
impl BandwidthSchedulerParams {
    /// Computes the scheduler parameters for `num_shards` shards. It uses the
    /// `bandwidth_scheduler_config` of `runtime_config` and its wasm `max_receipt_size` limit.
    ///
    /// # Panics
    /// Panics when the config is inconsistent (see `calculate`).
    pub fn new(num_shards: NonZeroU64, runtime_config: &RuntimeConfig) -> BandwidthSchedulerParams {
        let scheduler_config = runtime_config.bandwidth_scheduler_config;
        Self::calculate(
            scheduler_config.max_shard_bandwidth,
            scheduler_config.max_single_grant,
            scheduler_config.max_allowance,
            scheduler_config.max_base_bandwidth,
            runtime_config.wasm_config.limit_config.max_receipt_size,
            num_shards.get(),
        )
    }

    /// Validates the raw limits and derives `base_bandwidth` from them.
    ///
    /// `base_bandwidth` is computed in three steps:
    /// 1. Reserve one `max_single_grant` out of `max_shard_bandwidth`.
    /// 2. Divide the remainder evenly among `num_shards - 1` shards; with a single shard,
    ///    that shard gets the whole remainder.
    /// 3. Cap the result at `max_base_bandwidth`.
    ///
    /// For `num_shards >= 1` this guarantees
    /// `(num_shards - 1) * base_bandwidth + max_single_grant <= max_shard_bandwidth`.
    /// A `num_shards` of 0 is treated like 1 instead of underflowing.
    ///
    /// # Panics
    /// Panics if `max_single_grant < max_receipt_size` or
    /// `max_single_grant > max_shard_bandwidth`.
    fn calculate(
        max_shard_bandwidth: Bandwidth,
        max_single_grant: Bandwidth,
        max_allowance: Bandwidth,
        max_base_bandwidth: Bandwidth,
        max_receipt_size: Bandwidth,
        num_shards: u64,
    ) -> BandwidthSchedulerParams {
        assert!(
            max_single_grant >= max_receipt_size,
            "A max_single_grant can't be lower than max_receipt_size - it'll be impossible to send a max size receipt"
        );
        assert!(
            max_single_grant <= max_shard_bandwidth,
            "A single grant must not be greater than max_shard_bandwidth"
        );
        let available_bandwidth = max_shard_bandwidth - max_single_grant;
        // `saturating_sub` keeps `num_shards == 0` (accepted by `for_test`) from underflowing.
        // A plain `num_shards - 1` would panic in debug builds and silently produce
        // `base_bandwidth == 0` in release builds. For every `num_shards >= 1` the divisor is
        // identical to `max(1, num_shards - 1)`.
        let other_shards = num_shards.saturating_sub(1).max(1);
        let base_bandwidth = (available_bandwidth / other_shards).min(max_base_bandwidth);
        BandwidthSchedulerParams {
            base_bandwidth,
            max_shard_bandwidth,
            max_single_grant,
            max_receipt_size,
            max_allowance,
        }
    }

    /// Parameters for tests, built from hard-coded limits:
    /// - `max_shard_bandwidth` = 4_500_000
    /// - `max_single_grant` = `max_receipt_size` = 4 MiB
    /// - `max_allowance` = `max_shard_bandwidth`
    /// - `max_base_bandwidth` = 100_000
    pub fn for_test(num_shards: u64) -> BandwidthSchedulerParams {
        let max_shard_bandwidth = 4_500_000;
        let max_single_grant = 4 * 1024 * 1024;
        let max_allowance = max_shard_bandwidth;
        let max_base_bandwidth = 100_000;
        let max_receipt_size = 4 * 1024 * 1024;
        Self::calculate(
            max_shard_bandwidth,
            max_single_grant,
            max_allowance,
            max_base_bandwidth,
            max_receipt_size,
            num_shards,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::{
        BandwidthRequest, BandwidthRequestBitmap, BandwidthRequestValues, BandwidthSchedulerParams,
    };
    use crate::bandwidth_scheduler::{BANDWIDTH_REQUEST_VALUES_NUM, interpolate};
    use crate::shard_layout::ShardUId;
    use near_parameters::RuntimeConfig;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha20Rng;
    use std::num::NonZeroU64;
    use std::ops::Deref;
    use std::sync::Arc;

    /// Returns `RuntimeConfig::test()` with the wasm `max_receipt_size` limit overridden.
    fn make_runtime_config(max_receipt_size: u64) -> RuntimeConfig {
        let mut runtime_config = RuntimeConfig::test();
        let mut wasm_config = runtime_config.wasm_config.deref().clone();
        wasm_config.limit_config.max_receipt_size = max_receipt_size;
        runtime_config.wasm_config = Arc::new(wasm_config);
        runtime_config
    }

    /// Asserts that, with every other shard using `base_bandwidth`, a receipt of
    /// `max_receipt_size` still fits within `max_shard_bandwidth`.
    fn assert_max_size_can_get_through(params: &BandwidthSchedulerParams, num_shards: u64) {
        assert!(
            (num_shards - 1) * params.base_bandwidth + params.max_receipt_size
                <= params.max_shard_bandwidth
        )
    }

    #[test]
    fn test_scheduler_params_one_shard() {
        let max_receipt_size = 4 * 1024 * 1024;
        let num_shards = 1;
        let runtime_config = make_runtime_config(max_receipt_size);
        let scheduler_params =
            BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
        let expected = BandwidthSchedulerParams {
            base_bandwidth: 100_000,
            max_shard_bandwidth: 4_500_000,
            max_single_grant: 4 * 1024 * 1024,
            max_receipt_size,
            max_allowance: 4_500_000,
        };
        assert_eq!(scheduler_params, expected);
        assert_max_size_can_get_through(&scheduler_params, num_shards);
    }

    #[test]
    fn test_scheduler_params_six_shards() {
        let max_receipt_size = 4 * 1024 * 1024;
        let num_shards = 6;
        let runtime_config = make_runtime_config(max_receipt_size);
        let scheduler_params =
            BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
        let expected = BandwidthSchedulerParams {
            base_bandwidth: (4_500_000 - max_receipt_size) / 5,
            max_shard_bandwidth: 4_500_000,
            max_single_grant: 4 * 1024 * 1024,
            max_receipt_size,
            max_allowance: 4_500_000,
        };
        assert_eq!(scheduler_params, expected);
        assert_max_size_can_get_through(&scheduler_params, num_shards);
    }

    /// A `max_receipt_size` above the configured `max_single_grant` must be rejected.
    #[test]
    #[should_panic]
    fn test_scheduler_params_invalid_config() {
        let max_receipt_size = 40 * 1024 * 1024;
        let num_shards = 6;
        let runtime_config = make_runtime_config(max_receipt_size);
        BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
    }

    /// Random get/set operations must behave like a plain `[bool; N]`.
    #[test]
    fn test_bandwidth_request_bitmap() {
        let mut bitmap = BandwidthRequestBitmap::new();
        assert_eq!(bitmap.len(), BANDWIDTH_REQUEST_VALUES_NUM);
        let mut fake_bitmap = [false; BANDWIDTH_REQUEST_VALUES_NUM];
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        for _ in 0..(BANDWIDTH_REQUEST_VALUES_NUM * 5) {
            let random_index = rng.gen_range(0..BANDWIDTH_REQUEST_VALUES_NUM);
            let value = rng.gen_bool(0.5);
            bitmap.set_bit(random_index, value);
            fake_bitmap[random_index] = value;
            for i in 0..BANDWIDTH_REQUEST_VALUES_NUM {
                assert_eq!(bitmap.get_bit(i), fake_bitmap[i]);
            }
        }
    }

    #[test]
    fn test_bandwidth_request_values() {
        let max_receipt_size = 4 * 1024 * 1024;
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );
        let values = BandwidthRequestValues::new(&params);
        assert!(values.values[0] > params.base_bandwidth);
        assert_eq!(values.values[BANDWIDTH_REQUEST_VALUES_NUM - 1], params.max_single_grant);
        assert_eq!(params.base_bandwidth, 61139);
        assert_eq!(
            values.values,
            [
                164468, 267797, 371126, 474455, 577784, 681113, 784442, 887772, 991101, 1094430,
                1197759, 1301088, 1404417, 1507746, 1611075, 1714405, 1817734, 1921063, 2024392,
                2127721, 2231050, 2334379, 2437708, 2541038, 2644367, 2747696, 2851025, 2954354,
                3057683, 3161012, 3264341, 3367671, 3471000, 3574329, 3677658, 3780987, 3884316,
                3987645, 4090974, 4194304
            ]
        );
    }

    /// Builds a request to the single shard with exactly the given bits set.
    fn make_request_with_ones(ones_indexes: &[usize]) -> BandwidthRequest {
        let mut req = BandwidthRequest {
            to_shard: ShardUId::single_shard().shard_id().into(),
            requested_values_bitmap: BandwidthRequestBitmap::new(),
        };
        for i in ones_indexes {
            req.requested_values_bitmap.set_bit(*i, true);
        }
        req
    }

    /// Wraps plain sizes into the infallible `Result` iterator expected by
    /// `make_from_receipt_sizes`.
    fn make_sizes_iter<'a>(
        sizes: &'a [u64],
    ) -> impl Iterator<Item = Result<u64, std::convert::Infallible>> + 'a {
        sizes.iter().map(|&size| Ok(size))
    }

    #[test]
    fn test_make_bandwidth_request_from_receipt_sizes() {
        let max_receipt_size = 4 * 1024 * 1024;
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );
        let values = BandwidthRequestValues::new(&params).values;
        let get_request = |receipt_sizes: &[u64]| -> Option<BandwidthRequest> {
            BandwidthRequest::make_from_receipt_sizes(
                ShardUId::single_shard().shard_id(),
                make_sizes_iter(receipt_sizes),
                &params,
            )
            .unwrap()
        };

        assert_eq!(get_request(&[]), None);

        let below_base_bandwidth_receipts = [10_000, 20, 999, 2362, 3343, 232, 22];
        assert!(below_base_bandwidth_receipts.iter().sum::<u64>() < params.base_bandwidth);
        assert_eq!(get_request(&below_base_bandwidth_receipts), None);

        let equal_to_base_bandwidth_receipts = [10_000, 20_000, params.base_bandwidth - 30_000];
        assert_eq!(equal_to_base_bandwidth_receipts.iter().sum::<u64>(), params.base_bandwidth);
        assert_eq!(get_request(&equal_to_base_bandwidth_receipts), None);

        let above_base_bandwidth_receipts = [10_000, 20_000, params.base_bandwidth - 30_000, 1];
        assert_eq!(above_base_bandwidth_receipts.iter().sum::<u64>(), params.base_bandwidth + 1);
        assert_eq!(get_request(&above_base_bandwidth_receipts), Some(make_request_with_ones(&[0])));

        let above_base_bandwidth_one_receipt = [params.base_bandwidth + 1];
        assert_eq!(
            get_request(&above_base_bandwidth_one_receipt),
            Some(make_request_with_ones(&[0]))
        );

        // A prefix that fits within base_bandwidth must not request anything, even when a
        // later, larger prefix does. The reference implementation must agree.
        let small_then_big_receipts: [u64; 2] = [10_000, 200_000];
        assert!(small_then_big_receipts[0] <= params.base_bandwidth);
        assert_eq!(get_request(&small_then_big_receipts), Some(make_request_with_ones(&[1])));
        assert_eq!(
            make_bandwidth_request_slow(small_then_big_receipts.iter().copied(), &params),
            get_request(&small_then_big_receipts)
        );

        let in_between_value = (values[values.len() / 2] + values[values.len() / 2 + 1]) / 2;
        assert!(!values.contains(&in_between_value));
        let in_between_size_receipt = [in_between_value];
        assert_eq!(
            get_request(&in_between_size_receipt),
            Some(make_request_with_ones(&[values.len() / 2 + 1]))
        );

        let max_size_receipt = [max_receipt_size];
        let max_size_receipt_value_idx =
            values.iter().position(|v| *v == max_receipt_size).unwrap();
        assert_eq!(
            get_request(&max_size_receipt),
            Some(make_request_with_ones(&[max_size_receipt_value_idx]))
        );

        assert!(2 * params.max_receipt_size > params.max_shard_bandwidth);
        let two_max_size_receipts = [max_receipt_size, max_receipt_size];
        assert_eq!(
            get_request(&two_max_size_receipts),
            Some(make_request_with_ones(&[max_size_receipt_value_idx]))
        );

        let lots_of_small_receipts: Vec<u64> = vec![1_000; 10_000];
        assert!(lots_of_small_receipts.iter().sum::<u64>() > params.max_shard_bandwidth);
        let all_bitmap_indices: Vec<usize> = (0..BANDWIDTH_REQUEST_VALUES_NUM).collect();
        assert_eq!(
            get_request(&lots_of_small_receipts),
            Some(make_request_with_ones(&all_bitmap_indices))
        );
    }

    /// Compares the real implementation against the slow reference on random inputs.
    #[test]
    fn test_make_bandwidth_request_random() {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let max_receipt_size = 4 * 1024 * 1024;
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );
        let min_receipt_size = 5_000;
        let max_receipts_num = params.max_shard_bandwidth / min_receipt_size * 3 / 2;
        for _test_idx in 0..100 {
            let num_receipts = rng.gen_range(0..=max_receipts_num);
            let receipt_sizes: Vec<u64> = (0..num_receipts)
                .map(|_| rng.gen_range(min_receipt_size..=max_receipt_size))
                .collect();
            let request = BandwidthRequest::make_from_receipt_sizes(
                ShardUId::single_shard().shard_id(),
                make_sizes_iter(&receipt_sizes),
                &params,
            )
            .unwrap();
            let expected_request =
                make_bandwidth_request_slow(receipt_sizes.iter().copied(), &params);
            assert_eq!(request, expected_request);
        }
    }

    /// Slow, obviously-correct reference for `BandwidthRequest::make_from_receipt_sizes`.
    ///
    /// For every prefix whose total exceeds `base_bandwidth`, it linearly searches for the
    /// smallest value covering that total and requests it.
    fn make_bandwidth_request_slow(
        receipt_sizes: impl Iterator<Item = u64>,
        params: &BandwidthSchedulerParams,
    ) -> Option<BandwidthRequest> {
        let mut request = BandwidthRequest {
            to_shard: ShardUId::single_shard().shard_id().into(),
            requested_values_bitmap: BandwidthRequestBitmap::new(),
        };
        let values = BandwidthRequestValues::new(params).values;

        let mut total_size = 0;
        for receipt_size in receipt_sizes {
            total_size += receipt_size;
            // Same rule as the real implementation: prefixes that fit within the base
            // bandwidth don't request anything. Without this check the reference sets bit 0
            // for small leading receipts and disagrees with the real implementation.
            if total_size <= params.base_bandwidth {
                continue;
            }
            if let Some(i) = values.iter().position(|&value| value >= total_size) {
                request.requested_values_bitmap.set_bit(i, true);
            }
        }

        if request.requested_values_bitmap.is_all_zeros() {
            return None;
        }
        Some(request)
    }

    #[test]
    fn test_interpolate() {
        assert_eq!(interpolate(100, 200, 0, 10), 100);
        assert_eq!(interpolate(100, 200, 5, 10), 150);
        assert_eq!(interpolate(100, 200, 10, 10), 200);
    }
}