1use std::collections::BTreeMap;
2use std::num::NonZeroU64;
3
4use bitvec::order::Lsb0;
5use bitvec::slice::BitSlice;
6use borsh::{BorshDeserialize, BorshSerialize};
7use near_parameters::RuntimeConfig;
8use near_primitives_core::hash::CryptoHash;
9use near_primitives_core::types::ShardId;
10use near_schema_checker_lib::ProtocolSchema;
11
/// An amount of bandwidth. It is compared directly against (summed) receipt sizes, so it is
/// expressed in the same unit as receipt sizes.
pub type Bandwidth = u64;
15
/// Versioned list of the bandwidth requests of a single shard.
///
/// Borsh writes the explicit `u8` discriminant (`V1` = 0) as the variant tag.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Debug,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[borsh(use_discriminant = true)]
#[repr(u8)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum BandwidthRequests {
    V1(BandwidthRequestsV1) = 0,
}
35
36impl BandwidthRequests {
37 pub fn empty() -> BandwidthRequests {
38 BandwidthRequests::V1(BandwidthRequestsV1 { requests: Vec::new() })
39 }
40}
41
/// First version of the bandwidth request list: a plain list of individual requests.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Default,
    Debug,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct BandwidthRequestsV1 {
    /// The individual requests, each naming the shard it is for.
    pub requests: Vec<BandwidthRequest>,
}
59
/// A request for bandwidth towards the shard named by `to_shard`.
///
/// Instead of a single number, the request is a bitmap over the entries of
/// [`BandwidthRequestValues`]: bit `i` set means the `i`-th value is requested.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Debug,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct BandwidthRequest {
    /// Shard the requested bandwidth is for, stored as a `u16` (converted from `ShardId`).
    pub to_shard: u16,
    /// Bit `i` is set when the `i`-th entry of [`BandwidthRequestValues`] is requested.
    pub requested_values_bitmap: BandwidthRequestBitmap,
}
81
82impl BandwidthRequest {
83 pub fn make_from_receipt_sizes<E>(
86 to_shard: ShardId,
87 receipt_sizes: impl Iterator<Item = Result<u64, E>>,
88 params: &BandwidthSchedulerParams,
89 ) -> Result<Option<BandwidthRequest>, E> {
90 let values = BandwidthRequestValues::new(params).values;
91 let mut bitmap = BandwidthRequestBitmap::new();
92
93 let mut total_size: u64 = 0;
97 let mut cur_value_idx: usize = 0;
98 for receipt_size_res in receipt_sizes {
99 let receipt_size = receipt_size_res?;
100 total_size = total_size.checked_add(receipt_size).expect(
101 "Total size of receipts doesn't fit in u64, are there exabytes of receipts?",
102 );
103
104 if total_size <= params.base_bandwidth {
105 continue;
106 }
107
108 while cur_value_idx < values.len() && values[cur_value_idx] < total_size {
110 cur_value_idx += 1;
111 }
112
113 if cur_value_idx == values.len() {
114 break;
116 }
117
118 bitmap.set_bit(cur_value_idx, true);
120 }
121
122 if bitmap.is_all_zeros() {
123 return Ok(None);
125 }
126
127 Ok(Some(BandwidthRequest { to_shard: to_shard.into(), requested_values_bitmap: bitmap }))
128 }
129}
130
/// Number of distinct bandwidth values a [`BandwidthRequest`] can ask for. This is both the
/// length of [`BandwidthRequestValues::values`] and the number of bits in
/// [`BandwidthRequestBitmap`].
pub const BANDWIDTH_REQUEST_VALUES_NUM: usize = 40;
133
/// The bandwidth amounts that the bits of a [`BandwidthRequestBitmap`] refer to: bit `i` of a
/// request stands for `values[i]`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BandwidthRequestValues {
    pub values: [Bandwidth; BANDWIDTH_REQUEST_VALUES_NUM],
}
141
/// Linear interpolation between `min` and `max`: returns `min + (max - min) * i / n`, rounded down.
///
/// `i == 0` yields `min` and `i == n` yields `max`. The product `(max - min) * i` is computed in
/// `u128`, so it cannot overflow for any `u64` inputs. The result is identical to plain `u64`
/// arithmetic whenever that arithmetic does not overflow.
///
/// # Panics
/// Panics if `n == 0`. It also panics if `max < min` or if `i > n` pushes the result past
/// `u64::MAX`, in builds with overflow checks enabled.
fn interpolate(min: u64, max: u64, i: u64, n: u64) -> u64 {
    let range = u128::from(max - min);
    let offset = range * u128::from(i) / u128::from(n);
    // For `i <= n` the offset is at most `max - min`, so the conversion and the addition below
    // can't overflow.
    let offset = u64::try_from(offset).expect("interpolation offset doesn't fit in u64");
    min + offset
}
149
150impl BandwidthRequestValues {
151 pub fn new(params: &BandwidthSchedulerParams) -> BandwidthRequestValues {
152 let mut values = [0; BANDWIDTH_REQUEST_VALUES_NUM];
157
158 let values_len: u64 =
159 values.len().try_into().expect("Converting usize to u64 shouldn't fail");
160 for i in 0..values.len() {
161 let i_u64: u64 = i.try_into().expect("Converting usize to u64 shouldn't fail");
162
163 values[i] =
164 interpolate(params.base_bandwidth, params.max_single_grant, i_u64 + 1, values_len);
165 }
166
167 BandwidthRequestValues { values }
168 }
169}
170
/// Fixed-size bitmap with one bit per entry of [`BandwidthRequestValues`], packed into bytes.
///
/// `Debug` is implemented by hand so that the individual bits are printed.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct BandwidthRequestBitmap {
    /// Packed bits. Within each byte the least significant bit comes first (bitvec `Lsb0` order).
    pub data: [u8; BANDWIDTH_REQUEST_BITMAP_SIZE],
}
187
/// Number of bytes needed to store one bit per bandwidth request value.
pub const BANDWIDTH_REQUEST_BITMAP_SIZE: usize = BANDWIDTH_REQUEST_VALUES_NUM / 8;
// Compile-time check that the bit count is a whole number of bytes, so the division above drops
// no bits and every bit of `data` is meaningful.
const _: () = assert!(
    BANDWIDTH_REQUEST_VALUES_NUM % 8 == 0,
    "Every bit in the bitmap should be used. It's wasteful to have unused bits.
    And having unused bits would require extra validation logic"
);
194
195impl std::fmt::Debug for BandwidthRequestBitmap {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 write!(f, "BandwidthRequestBitmap(")?;
198 for i in 0..self.len() {
199 if self.get_bit(i) {
200 write!(f, "1")?;
201 } else {
202 write!(f, "0")?;
203 }
204 }
205 write!(f, ")")
206 }
207}
208
209impl BandwidthRequestBitmap {
210 pub fn new() -> BandwidthRequestBitmap {
211 BandwidthRequestBitmap { data: [0u8; BANDWIDTH_REQUEST_BITMAP_SIZE] }
212 }
213
214 pub fn get_bit(&self, idx: usize) -> bool {
215 assert!(idx < self.len());
216
217 let bit_slice = BitSlice::<_, Lsb0>::from_slice(self.data.as_slice());
218 *bit_slice.get(idx).unwrap()
219 }
220
221 pub fn set_bit(&mut self, idx: usize, val: bool) {
222 assert!(idx < self.len());
223
224 let bit_slice = BitSlice::<_, Lsb0>::from_slice_mut(self.data.as_mut_slice());
225 bit_slice.set(idx, val);
226 }
227
228 pub fn len(&self) -> usize {
229 BANDWIDTH_REQUEST_VALUES_NUM
230 }
231
232 pub fn is_all_zeros(&self) -> bool {
233 self.data == [0u8; BANDWIDTH_REQUEST_BITMAP_SIZE]
234 }
235}
236
/// Bandwidth requests of several shards, keyed by `ShardId`.
// NOTE(review): presumably the key is the shard that made the requests — confirm against callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockBandwidthRequests {
    pub shards_bandwidth_requests: BTreeMap<ShardId, BandwidthRequests>,
}
243
244impl BlockBandwidthRequests {
245 pub fn empty() -> BlockBandwidthRequests {
246 BlockBandwidthRequests { shards_bandwidth_requests: BTreeMap::new() }
247 }
248}
249
/// Versioned state of the bandwidth scheduler.
///
/// Borsh writes the explicit `u8` discriminant (`V1` = 0) as the variant tag.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
#[borsh(use_discriminant = true)]
#[repr(u8)]
pub enum BandwidthSchedulerState {
    V1(BandwidthSchedulerStateV1) = 0,
}
261
/// First version of the bandwidth scheduler state.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
pub struct BandwidthSchedulerStateV1 {
    /// Allowance of each (sender, receiver) link, see [`LinkAllowance`].
    pub link_allowances: Vec<LinkAllowance>,
    /// Hash stored alongside the allowances.
    // NOTE(review): what this hash covers isn't visible in this file — presumably a consistency
    // check of the scheduler's computations; confirm and document.
    pub sanity_check_hash: CryptoHash,
}
271
/// The allowance of the link going from shard `sender` to shard `receiver`.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
pub struct LinkAllowance {
    /// Shard on the sending side of the link.
    pub sender: ShardId,
    /// Shard on the receiving side of the link.
    pub receiver: ShardId,
    /// Current allowance of the link.
    // NOTE(review): presumably bounded by `BandwidthSchedulerParams::max_allowance` — confirm.
    pub allowance: Bandwidth,
}
285
/// Parameters of the bandwidth scheduler, derived from the runtime config and the shard count.
#[derive(Clone, Copy, Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
pub struct BandwidthSchedulerParams {
    /// Bandwidth that never needs a request: receipts whose cumulative size fits into it don't
    /// set any bit in a [`BandwidthRequest`]. Derived from the other limits and the shard count.
    pub base_bandwidth: Bandwidth,
    /// `max_shard_bandwidth` from the scheduler config.
    pub max_shard_bandwidth: Bandwidth,
    /// `max_single_grant` from the scheduler config; also the largest entry of
    /// [`BandwidthRequestValues`].
    pub max_single_grant: Bandwidth,
    /// Maximum receipt size, taken from the runtime's `limit_config.max_receipt_size`.
    pub max_receipt_size: Bandwidth,
    /// `max_allowance` from the scheduler config.
    pub max_allowance: Bandwidth,
}
302
303impl BandwidthSchedulerParams {
304 pub fn new(num_shards: NonZeroU64, runtime_config: &RuntimeConfig) -> BandwidthSchedulerParams {
306 let scheduler_config = runtime_config.bandwidth_scheduler_config;
307
308 Self::calculate(
309 scheduler_config.max_shard_bandwidth,
310 scheduler_config.max_single_grant,
311 scheduler_config.max_allowance,
312 scheduler_config.max_base_bandwidth,
313 runtime_config.wasm_config.limit_config.max_receipt_size,
314 num_shards.get(),
315 )
316 }
317
318 fn calculate(
319 max_shard_bandwidth: Bandwidth,
320 max_single_grant: Bandwidth,
321 max_allowance: Bandwidth,
322 max_base_bandwidth: Bandwidth,
323 max_receipt_size: Bandwidth,
324 num_shards: u64,
325 ) -> BandwidthSchedulerParams {
326 assert!(
327 max_single_grant >= max_receipt_size,
328 "A max_single_grant can't be lower than max_receipt_size - it'll be impossible to send a max size receipt"
329 );
330 assert!(
331 max_single_grant <= max_shard_bandwidth,
332 "A single grant must not be greater than max_shard_bandwidth"
333 );
334
335 let available_bandwidth = max_shard_bandwidth - max_single_grant;
341 let mut base_bandwidth = available_bandwidth / std::cmp::max(1, num_shards - 1);
342 if base_bandwidth > max_base_bandwidth {
343 base_bandwidth = max_base_bandwidth;
344 }
345
346 BandwidthSchedulerParams {
347 base_bandwidth,
348 max_shard_bandwidth,
349 max_single_grant,
350 max_receipt_size,
351 max_allowance,
352 }
353 }
354
355 pub fn for_test(num_shards: u64) -> BandwidthSchedulerParams {
357 let max_shard_bandwidth = 4_500_000;
358 let max_single_grant = 4 * 1024 * 1024;
359 let max_allowance = max_shard_bandwidth;
360 let max_base_bandwidth = 100_000;
361 let max_receipt_size = 4 * 1024 * 1024;
362
363 Self::calculate(
364 max_shard_bandwidth,
365 max_single_grant,
366 max_allowance,
367 max_base_bandwidth,
368 max_receipt_size,
369 num_shards,
370 )
371 }
372}
373
#[cfg(test)]
mod tests {
    use std::num::NonZeroU64;
    use std::ops::Deref;
    use std::sync::Arc;

    use near_parameters::RuntimeConfig;
    use rand::{Rng, SeedableRng};

    use crate::bandwidth_scheduler::{BANDWIDTH_REQUEST_VALUES_NUM, interpolate};
    use crate::shard_layout::ShardUId;

    use super::{
        BandwidthRequest, BandwidthRequestBitmap, BandwidthRequestValues, BandwidthSchedulerParams,
    };
    use rand_chacha::ChaCha20Rng;

    /// Returns `RuntimeConfig::test()` with `max_receipt_size` replaced by the given value.
    fn make_runtime_config(max_receipt_size: u64) -> RuntimeConfig {
        let mut runtime_config = RuntimeConfig::test();

        // `wasm_config` sits behind an `Arc`: clone it, modify the copy and store it back.
        let mut wasm_config = runtime_config.wasm_config.deref().clone();
        wasm_config.limit_config.max_receipt_size = max_receipt_size;
        runtime_config.wasm_config = Arc::new(wasm_config);

        runtime_config
    }

    /// Asserts that `num_shards - 1` grants of base bandwidth plus one max-size receipt fit
    /// within `max_shard_bandwidth`.
    fn assert_max_size_can_get_through(params: &BandwidthSchedulerParams, num_shards: u64) {
        assert!(
            (num_shards - 1) * params.base_bandwidth + params.max_receipt_size
                <= params.max_shard_bandwidth
        )
    }

    #[test]
    fn test_scheduler_params_one_shard() {
        let max_receipt_size = 4 * 1024 * 1024;
        let num_shards = 1;

        let runtime_config = make_runtime_config(max_receipt_size);
        let scheduler_params =
            BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
        let expected = BandwidthSchedulerParams {
            base_bandwidth: 100_000,
            max_shard_bandwidth: 4_500_000,
            max_single_grant: 4 * 1024 * 1024,
            max_receipt_size,
            max_allowance: 4_500_000,
        };
        assert_eq!(scheduler_params, expected);
        assert_max_size_can_get_through(&scheduler_params, num_shards);
    }

    #[test]
    fn test_scheduler_params_six_shards() {
        let max_receipt_size = 4 * 1024 * 1024;
        let num_shards = 6;

        let runtime_config = make_runtime_config(max_receipt_size);
        let scheduler_params =
            BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
        let expected = BandwidthSchedulerParams {
            base_bandwidth: (4_500_000 - max_receipt_size) / 5,
            max_shard_bandwidth: 4_500_000,
            max_single_grant: 4 * 1024 * 1024,
            max_receipt_size,
            max_allowance: 4_500_000,
        };
        assert_eq!(scheduler_params, expected);
        assert_max_size_can_get_through(&scheduler_params, num_shards);
    }

    /// A max receipt size above `max_single_grant` could never be sent, so it must be rejected.
    #[test]
    #[should_panic(expected = "max_single_grant can't be lower than max_receipt_size")]
    fn test_scheduler_params_invalid_config() {
        let max_receipt_size = 40 * 1024 * 1024;
        let num_shards = 6;
        let runtime_config = make_runtime_config(max_receipt_size);
        BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
    }

    /// Random `set_bit` calls must be reflected by `get_bit`, checked against a plain bool array.
    #[test]
    fn test_bandwidth_request_bitmap() {
        let mut bitmap = BandwidthRequestBitmap::new();
        assert_eq!(bitmap.len(), BANDWIDTH_REQUEST_VALUES_NUM);

        let mut fake_bitmap = [false; BANDWIDTH_REQUEST_VALUES_NUM];
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);

        for _ in 0..(BANDWIDTH_REQUEST_VALUES_NUM * 5) {
            let random_index = rng.gen_range(0..BANDWIDTH_REQUEST_VALUES_NUM);
            let value = rng.gen_bool(0.5);

            bitmap.set_bit(random_index, value);
            fake_bitmap[random_index] = value;

            for i in 0..BANDWIDTH_REQUEST_VALUES_NUM {
                assert_eq!(bitmap.get_bit(i), fake_bitmap[i]);
            }
        }
    }

    #[test]
    fn test_bandwidth_request_values() {
        let max_receipt_size = 4 * 1024 * 1024;

        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );
        let values = BandwidthRequestValues::new(&params);

        assert!(values.values[0] > params.base_bandwidth);
        assert_eq!(values.values[BANDWIDTH_REQUEST_VALUES_NUM - 1], params.max_single_grant);

        assert_eq!(params.base_bandwidth, 61139);
        assert_eq!(
            values.values,
            [
                164468, 267797, 371126, 474455, 577784, 681113, 784442, 887772, 991101, 1094430,
                1197759, 1301088, 1404417, 1507746, 1611075, 1714405, 1817734, 1921063, 2024392,
                2127721, 2231050, 2334379, 2437708, 2541038, 2644367, 2747696, 2851025, 2954354,
                3057683, 3161012, 3264341, 3367671, 3471000, 3574329, 3677658, 3780987, 3884316,
                3987645, 4090974, 4194304
            ]
        );
    }

    /// Builds a request to the single shard with exactly the bits at `ones_indexes` set.
    fn make_request_with_ones(ones_indexes: &[usize]) -> BandwidthRequest {
        let mut req = BandwidthRequest {
            to_shard: ShardUId::single_shard().shard_id().into(),
            requested_values_bitmap: BandwidthRequestBitmap::new(),
        };
        for &i in ones_indexes {
            req.requested_values_bitmap.set_bit(i, true);
        }
        req
    }

    /// Wraps receipt sizes into the fallible-item iterator expected by `make_from_receipt_sizes`.
    fn make_sizes_iter<'a>(
        sizes: &'a [u64],
    ) -> impl Iterator<Item = Result<u64, std::convert::Infallible>> + 'a {
        sizes.iter().map(|&size| Ok(size))
    }

    #[test]
    fn test_make_bandwidth_request_from_receipt_sizes() {
        let max_receipt_size = 4 * 1024 * 1024;
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );
        let values = BandwidthRequestValues::new(&params).values;

        let get_request = |receipt_sizes: &[u64]| -> Option<BandwidthRequest> {
            BandwidthRequest::make_from_receipt_sizes(
                ShardUId::single_shard().shard_id(),
                make_sizes_iter(receipt_sizes),
                &params,
            )
            .unwrap()
        };

        // No receipts - no request.
        assert_eq!(get_request(&[]), None);

        // Receipts that fit into the base bandwidth don't need a request.
        let below_base_bandwidth_receipts = [10_000, 20, 999, 2362, 3343, 232, 22];
        assert!(below_base_bandwidth_receipts.iter().sum::<u64>() < params.base_bandwidth);
        assert_eq!(get_request(&below_base_bandwidth_receipts), None);

        let equal_to_base_bandwidth_receipts = [10_000, 20_000, params.base_bandwidth - 30_000];
        assert_eq!(equal_to_base_bandwidth_receipts.iter().sum::<u64>(), params.base_bandwidth);
        assert_eq!(get_request(&equal_to_base_bandwidth_receipts), None);

        // One byte over the base bandwidth requests the smallest value.
        let above_base_bandwidth_receipts = [10_000, 20_000, params.base_bandwidth - 30_000, 1];
        assert_eq!(above_base_bandwidth_receipts.iter().sum::<u64>(), params.base_bandwidth + 1);
        assert_eq!(get_request(&above_base_bandwidth_receipts), Some(make_request_with_ones(&[0])));

        let above_base_bandwidth_one_receipt = [params.base_bandwidth + 1];
        assert_eq!(
            get_request(&above_base_bandwidth_one_receipt),
            Some(make_request_with_ones(&[0]))
        );

        // A size between two values requests the larger one.
        let in_between_value = (values[values.len() / 2] + values[values.len() / 2 + 1]) / 2;
        assert!(!values.contains(&in_between_value));
        let in_between_size_receipt = [in_between_value];
        assert_eq!(
            get_request(&in_between_size_receipt),
            Some(make_request_with_ones(&[values.len() / 2 + 1]))
        );

        // A max-size receipt requests exactly the value equal to its size.
        let max_size_receipt = [max_receipt_size];
        let max_size_receipt_value_idx =
            values.iter().position(|v| *v == max_receipt_size).unwrap();
        assert_eq!(
            get_request(&max_size_receipt),
            Some(make_request_with_ones(&[max_size_receipt_value_idx]))
        );

        // Two max-size receipts don't fit into any value, so only the first one is requested.
        assert!(2 * params.max_receipt_size > params.max_shard_bandwidth);
        let two_max_size_receipts = [max_receipt_size, max_receipt_size];
        assert_eq!(
            get_request(&two_max_size_receipts),
            Some(make_request_with_ones(&[max_size_receipt_value_idx]))
        );

        // Many small receipts walk through every value, so every bit gets set.
        let lots_of_small_receipts: Vec<u64> = vec![1_000; 10_000];
        assert!(lots_of_small_receipts.iter().sum::<u64>() > params.max_shard_bandwidth);
        let all_bitmap_indices: Vec<usize> = (0..BANDWIDTH_REQUEST_VALUES_NUM).collect();
        assert_eq!(
            get_request(&lots_of_small_receipts),
            Some(make_request_with_ones(&all_bitmap_indices))
        );
    }

    /// Compares `make_from_receipt_sizes` with the simple reference implementation on random
    /// inputs.
    #[test]
    fn test_make_bandwidth_request_random() {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let max_receipt_size = 4 * 1024 * 1024;
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );

        let min_receipt_size = 5_000;
        let max_receipts_num = params.max_shard_bandwidth / min_receipt_size * 3 / 2;

        for _test_idx in 0..100 {
            let num_receipts = rng.gen_range(0..=max_receipts_num);
            let receipt_sizes: Vec<u64> = (0..num_receipts)
                .map(|_| rng.gen_range(min_receipt_size..=max_receipt_size))
                .collect();

            let request = BandwidthRequest::make_from_receipt_sizes(
                ShardUId::single_shard().shard_id(),
                make_sizes_iter(&receipt_sizes),
                &params,
            )
            .unwrap();

            let expected_request =
                make_bandwidth_request_slow(receipt_sizes.iter().copied(), &params);
            assert_eq!(request, expected_request);
        }
    }

    /// Reference implementation of `BandwidthRequest::make_from_receipt_sizes`: slow but simple.
    ///
    /// For every prefix of the receipts whose total size exceeds `base_bandwidth`, this sets the
    /// bit of the smallest value that covers the total, if such a value exists. Prefixes that
    /// still fit into the base bandwidth request nothing, exactly like in the real
    /// implementation.
    fn make_bandwidth_request_slow(
        receipt_sizes: impl Iterator<Item = u64>,
        params: &BandwidthSchedulerParams,
    ) -> Option<BandwidthRequest> {
        let mut request = BandwidthRequest {
            to_shard: ShardUId::single_shard().shard_id().into(),
            requested_values_bitmap: BandwidthRequestBitmap::new(),
        };
        let values = BandwidthRequestValues::new(params).values;

        let mut total_size = 0;
        for receipt_size in receipt_sizes {
            total_size += receipt_size;

            if total_size <= params.base_bandwidth {
                continue;
            }

            if let Some(value_idx) = values.iter().position(|&value| value >= total_size) {
                request.requested_values_bitmap.set_bit(value_idx, true);
            }
        }

        if request.requested_values_bitmap.is_all_zeros() {
            return None;
        }

        Some(request)
    }

    #[test]
    fn test_interpolate() {
        assert_eq!(interpolate(100, 200, 0, 10), 100);
        assert_eq!(interpolate(100, 200, 5, 10), 150);
        assert_eq!(interpolate(100, 200, 10, 10), 200);
    }
}