1use std::collections::BTreeMap;
2use std::num::NonZeroU64;
3
4use bitvec::order::Lsb0;
5use bitvec::slice::BitSlice;
6use borsh::{BorshDeserialize, BorshSerialize};
7use near_parameters::RuntimeConfig;
8use near_primitives_core::hash::CryptoHash;
9use near_primitives_core::types::ShardId;
10use near_schema_checker_lib::ProtocolSchema;
11
/// An amount of bandwidth. It is compared directly against receipt sizes in this module, so it
/// is presumably measured in bytes — TODO confirm.
pub type Bandwidth = u64;
15
/// Versioned container for a list of bandwidth requests.
///
/// Borsh encodes the variant index, so new versions must be appended as new variants rather
/// than inserted or reordered.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Debug,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum BandwidthRequests {
    V1(BandwidthRequestsV1),
}
33
34impl BandwidthRequests {
35 pub fn empty() -> BandwidthRequests {
36 BandwidthRequests::V1(BandwidthRequestsV1 { requests: Vec::new() })
37 }
38}
39
/// Version 1 of `BandwidthRequests`: a plain list of requests.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Default,
    Debug,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct BandwidthRequestsV1 {
    /// The requests; each one names its own receiving shard (`BandwidthRequest::to_shard`).
    pub requests: Vec<BandwidthRequest>,
}
56
/// A request for bandwidth to send receipts to the shard `to_shard`.
///
/// Field order is part of the Borsh wire format and must not change.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Debug,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct BandwidthRequest {
    /// Receiving shard, stored as `u16` (converted from a `ShardId` via `.into()`).
    pub to_shard: u16,
    /// Bit `i` set means the `i`-th entry of `BandwidthRequestValues::values` is requested.
    pub requested_values_bitmap: BandwidthRequestBitmap,
}
78
79impl BandwidthRequest {
80 pub fn make_from_receipt_sizes<E>(
83 to_shard: ShardId,
84 receipt_sizes: impl Iterator<Item = Result<u64, E>>,
85 params: &BandwidthSchedulerParams,
86 ) -> Result<Option<BandwidthRequest>, E> {
87 let values = BandwidthRequestValues::new(params).values;
88 let mut bitmap = BandwidthRequestBitmap::new();
89
90 let mut total_size: u64 = 0;
94 let mut cur_value_idx: usize = 0;
95 for receipt_size_res in receipt_sizes {
96 let receipt_size = receipt_size_res?;
97 total_size = total_size.checked_add(receipt_size).expect(
98 "Total size of receipts doesn't fit in u64, are there exabytes of receipts?",
99 );
100
101 if total_size <= params.base_bandwidth {
102 continue;
103 }
104
105 while cur_value_idx < values.len() && values[cur_value_idx] < total_size {
107 cur_value_idx += 1;
108 }
109
110 if cur_value_idx == values.len() {
111 break;
113 }
114
115 bitmap.set_bit(cur_value_idx, true);
117 }
118
119 if bitmap.is_all_zeros() {
120 return Ok(None);
122 }
123
124 Ok(Some(BandwidthRequest { to_shard: to_shard.into(), requested_values_bitmap: bitmap }))
125 }
126}
127
/// Number of distinct bandwidth values a `BandwidthRequest` can ask for. This is also the
/// number of bits in a `BandwidthRequestBitmap` (see `BandwidthRequestBitmap::len`).
pub const BANDWIDTH_REQUEST_VALUES_NUM: usize = 40;
130
/// The concrete bandwidth amounts that the bits of a `BandwidthRequestBitmap` refer to.
///
/// Produced by `BandwidthRequestValues::new` by linear interpolation between
/// `base_bandwidth` and `max_single_grant`. The last entry always equals `max_single_grant`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BandwidthRequestValues {
    /// `values[i]` is the amount requested when bit `i` of the bitmap is set.
    pub values: [Bandwidth; BANDWIDTH_REQUEST_VALUES_NUM],
}
138
/// Linearly interpolates between `min` and `max`: returns `min + (max - min) * i / n`,
/// rounded down.
///
/// `i == 0` yields `min` and `i == n` yields `max`. The intermediate product is computed in
/// `u128`, so it cannot overflow for any `u64` inputs with `min <= max` and `i <= n`. The
/// previous all-`u64` form overflowed once `(max - min) * i` exceeded `u64::MAX`.
///
/// # Panics
/// Panics if `n == 0`. `max < min` underflows (and panics when overflow checks are enabled).
fn interpolate(min: u64, max: u64, i: u64, n: u64) -> u64 {
    let scaled = u128::from(max - min) * u128::from(i) / u128::from(n);
    // With `i <= n` we have `scaled <= max - min`, so it always fits back into a `u64`.
    min + u64::try_from(scaled).expect("interpolated offset doesn't fit in u64; is i > n?")
}
146
147impl BandwidthRequestValues {
148 pub fn new(params: &BandwidthSchedulerParams) -> BandwidthRequestValues {
149 let mut values = [0; BANDWIDTH_REQUEST_VALUES_NUM];
154
155 let values_len: u64 =
156 values.len().try_into().expect("Converting usize to u64 shouldn't fail");
157 for i in 0..values.len() {
158 let i_u64: u64 = i.try_into().expect("Converting usize to u64 shouldn't fail");
159
160 values[i] =
161 interpolate(params.base_bandwidth, params.max_single_grant, i_u64 + 1, values_len);
162 }
163
164 BandwidthRequestValues { values }
165 }
166}
167
/// Fixed-size bitmap with one bit per entry of `BandwidthRequestValues`, packed into
/// `BANDWIDTH_REQUEST_BITMAP_SIZE` bytes.
///
/// `Debug` is implemented by hand so that the individual bits are printed.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Clone,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct BandwidthRequestBitmap {
    /// Raw bit storage, accessed through `bitvec` with `Lsb0` ordering (bit 0 is the least
    /// significant bit of `data[0]`).
    pub data: [u8; BANDWIDTH_REQUEST_BITMAP_SIZE],
}
184
/// Size in bytes of `BandwidthRequestBitmap::data`: one bit per request value.
pub const BANDWIDTH_REQUEST_BITMAP_SIZE: usize = BANDWIDTH_REQUEST_VALUES_NUM / 8;
// Compile-time check that the number of request values is a whole number of bytes, so the
// division above is exact and no padding bits exist.
const _: () = assert!(
    BANDWIDTH_REQUEST_VALUES_NUM % 8 == 0,
    "Every bit in the bitmap should be used. It's wasteful to have unused bits.
        And having unused bits would require extra validation logic"
);
191
192impl std::fmt::Debug for BandwidthRequestBitmap {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 write!(f, "BandwidthRequestBitmap(")?;
195 for i in 0..self.len() {
196 if self.get_bit(i) {
197 write!(f, "1")?;
198 } else {
199 write!(f, "0")?;
200 }
201 }
202 write!(f, ")")
203 }
204}
205
206impl BandwidthRequestBitmap {
207 pub fn new() -> BandwidthRequestBitmap {
208 BandwidthRequestBitmap { data: [0u8; BANDWIDTH_REQUEST_BITMAP_SIZE] }
209 }
210
211 pub fn get_bit(&self, idx: usize) -> bool {
212 assert!(idx < self.len());
213
214 let bit_slice = BitSlice::<_, Lsb0>::from_slice(self.data.as_slice());
215 *bit_slice.get(idx).unwrap()
216 }
217
218 pub fn set_bit(&mut self, idx: usize, val: bool) {
219 assert!(idx < self.len());
220
221 let bit_slice = BitSlice::<_, Lsb0>::from_slice_mut(self.data.as_mut_slice());
222 bit_slice.set(idx, val);
223 }
224
225 pub fn len(&self) -> usize {
226 BANDWIDTH_REQUEST_VALUES_NUM
227 }
228
229 pub fn is_all_zeros(&self) -> bool {
230 self.data == [0u8; BANDWIDTH_REQUEST_BITMAP_SIZE]
231 }
232}
233
/// Bandwidth requests collected per shard, keyed by `ShardId`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockBandwidthRequests {
    /// Requests for each shard. Each `BandwidthRequest` names its receiver, so the key is
    /// presumably the sending shard — TODO confirm against callers.
    pub shards_bandwidth_requests: BTreeMap<ShardId, BandwidthRequests>,
}
240
241impl BlockBandwidthRequests {
242 pub fn empty() -> BlockBandwidthRequests {
243 BlockBandwidthRequests { shards_bandwidth_requests: BTreeMap::new() }
244 }
245}
246
/// Versioned, Borsh-serializable state of the bandwidth scheduler (presumably persisted between
/// blocks — TODO confirm).
///
/// Borsh encodes the variant index, so new versions must be appended as new variants.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
pub enum BandwidthSchedulerState {
    V1(BandwidthSchedulerStateV1),
}
256
/// Version 1 of the scheduler state. Field order is part of the Borsh encoding.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
pub struct BandwidthSchedulerStateV1 {
    /// Allowance for each (sender, receiver) link stored in the state.
    pub link_allowances: Vec<LinkAllowance>,
    /// Hash kept alongside the state. Judging by the name, it is used to detect diverging
    /// scheduler state — TODO confirm where it is updated and checked.
    pub sanity_check_hash: CryptoHash,
}
266
/// The bandwidth allowance of the link from `sender` to `receiver`.
/// Field order is part of the Borsh encoding.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
pub struct LinkAllowance {
    /// Shard on the sending side of the link.
    pub sender: ShardId,
    /// Shard on the receiving side of the link.
    pub receiver: ShardId,
    /// Current allowance of the link. Presumably bounded by
    /// `BandwidthSchedulerParams::max_allowance` — TODO confirm.
    pub allowance: Bandwidth,
}
280
/// Parameters of the bandwidth scheduler, derived from the runtime config and the number of
/// shards by `BandwidthSchedulerParams::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
pub struct BandwidthSchedulerParams {
    /// Bandwidth left after reserving one `max_single_grant`, split across `num_shards - 1`
    /// shares and capped by the configured maximum base bandwidth. Receipt totals up to this
    /// amount do not produce a bandwidth request.
    pub base_bandwidth: Bandwidth,
    /// Total bandwidth budget of a shard; `num_shards - 1` base grants plus one
    /// `max_single_grant` always fit within it.
    pub max_shard_bandwidth: Bandwidth,
    /// Largest single grant; equals the last entry of `BandwidthRequestValues`. Never smaller
    /// than `max_receipt_size` and never larger than `max_shard_bandwidth`.
    pub max_single_grant: Bandwidth,
    /// Maximum size of a single receipt, taken from the runtime limit config.
    pub max_receipt_size: Bandwidth,
    /// Maximum allowance, taken from the bandwidth scheduler config.
    pub max_allowance: Bandwidth,
}
297
298impl BandwidthSchedulerParams {
299 pub fn new(num_shards: NonZeroU64, runtime_config: &RuntimeConfig) -> BandwidthSchedulerParams {
301 let scheduler_config = runtime_config.bandwidth_scheduler_config;
302
303 Self::calculate(
304 scheduler_config.max_shard_bandwidth,
305 scheduler_config.max_single_grant,
306 scheduler_config.max_allowance,
307 scheduler_config.max_base_bandwidth,
308 runtime_config.wasm_config.limit_config.max_receipt_size,
309 num_shards.get(),
310 )
311 }
312
313 fn calculate(
314 max_shard_bandwidth: Bandwidth,
315 max_single_grant: Bandwidth,
316 max_allowance: Bandwidth,
317 max_base_bandwidth: Bandwidth,
318 max_receipt_size: Bandwidth,
319 num_shards: u64,
320 ) -> BandwidthSchedulerParams {
321 assert!(
322 max_single_grant >= max_receipt_size,
323 "A max_single_grant can't be lower than max_receipt_size - it'll be impossible to send a max size receipt"
324 );
325 assert!(
326 max_single_grant <= max_shard_bandwidth,
327 "A single grant must not be greater than max_shard_bandwidth"
328 );
329
330 let available_bandwidth = max_shard_bandwidth - max_single_grant;
336 let mut base_bandwidth = available_bandwidth / std::cmp::max(1, num_shards - 1);
337 if base_bandwidth > max_base_bandwidth {
338 base_bandwidth = max_base_bandwidth;
339 }
340
341 BandwidthSchedulerParams {
342 base_bandwidth,
343 max_shard_bandwidth,
344 max_single_grant,
345 max_receipt_size,
346 max_allowance,
347 }
348 }
349
350 pub fn for_test(num_shards: u64) -> BandwidthSchedulerParams {
352 let max_shard_bandwidth = 4_500_000;
353 let max_single_grant = 4 * 1024 * 1024;
354 let max_allowance = max_shard_bandwidth;
355 let max_base_bandwidth = 100_000;
356 let max_receipt_size = 4 * 1024 * 1024;
357
358 Self::calculate(
359 max_shard_bandwidth,
360 max_single_grant,
361 max_allowance,
362 max_base_bandwidth,
363 max_receipt_size,
364 num_shards,
365 )
366 }
367}
368
#[cfg(test)]
mod tests {
    use std::num::NonZeroU64;
    use std::ops::Deref;
    use std::sync::Arc;

    use near_parameters::RuntimeConfig;
    use rand::{Rng, SeedableRng};

    use crate::bandwidth_scheduler::{BANDWIDTH_REQUEST_VALUES_NUM, interpolate};
    use crate::shard_layout::ShardUId;

    use super::{
        BandwidthRequest, BandwidthRequestBitmap, BandwidthRequestValues, BandwidthSchedulerParams,
    };
    use rand_chacha::ChaCha20Rng;

    /// Returns `RuntimeConfig::test()` with `max_receipt_size` overridden.
    fn make_runtime_config(max_receipt_size: u64) -> RuntimeConfig {
        let mut runtime_config = RuntimeConfig::test();

        // `wasm_config` sits behind an `Arc`: clone the inner config, patch the copy and
        // install it as a fresh `Arc`.
        let mut wasm_config = runtime_config.wasm_config.deref().clone();
        wasm_config.limit_config.max_receipt_size = max_receipt_size;
        runtime_config.wasm_config = Arc::new(wasm_config);

        runtime_config
    }

    /// Asserts that a max-size receipt still fits into `max_shard_bandwidth` when each of the
    /// other `num_shards - 1` shards has been granted `base_bandwidth`.
    fn assert_max_size_can_get_through(params: &BandwidthSchedulerParams, num_shards: u64) {
        assert!(
            (num_shards - 1) * params.base_bandwidth + params.max_receipt_size
                <= params.max_shard_bandwidth
        )
    }

    #[test]
    fn test_scheduler_params_one_shard() {
        let max_receipt_size = 4 * 1024 * 1024;
        let num_shards = 1;

        let runtime_config = make_runtime_config(max_receipt_size);
        let scheduler_params =
            BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
        let expected = BandwidthSchedulerParams {
            base_bandwidth: 100_000,
            max_shard_bandwidth: 4_500_000,
            max_single_grant: 4 * 1024 * 1024,
            max_receipt_size,
            max_allowance: 4_500_000,
        };
        assert_eq!(scheduler_params, expected);
        assert_max_size_can_get_through(&scheduler_params, num_shards);
    }

    #[test]
    fn test_scheduler_params_six_shards() {
        let max_receipt_size = 4 * 1024 * 1024;
        let num_shards = 6;

        let runtime_config = make_runtime_config(max_receipt_size);
        let scheduler_params =
            BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
        let expected = BandwidthSchedulerParams {
            base_bandwidth: (4_500_000 - max_receipt_size) / 5,
            max_shard_bandwidth: 4_500_000,
            max_single_grant: 4 * 1024 * 1024,
            max_receipt_size,
            max_allowance: 4_500_000,
        };
        assert_eq!(scheduler_params, expected);
        assert_max_size_can_get_through(&scheduler_params, num_shards);
    }

    /// A `max_receipt_size` larger than `max_single_grant` (4 MiB in the test config) must be
    /// rejected.
    #[test]
    #[should_panic]
    fn test_scheduler_params_invalid_config() {
        let max_receipt_size = 40 * 1024 * 1024;
        let num_shards = 6;
        let runtime_config = make_runtime_config(max_receipt_size);
        BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
    }

    /// Random set/clear operations must match a plain `bool` array.
    #[test]
    fn test_bandwidth_request_bitmap() {
        let mut bitmap = BandwidthRequestBitmap::new();
        assert_eq!(bitmap.len(), BANDWIDTH_REQUEST_VALUES_NUM);

        let mut fake_bitmap = [false; BANDWIDTH_REQUEST_VALUES_NUM];
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);

        for _ in 0..(BANDWIDTH_REQUEST_VALUES_NUM * 5) {
            let random_index = rng.gen_range(0..BANDWIDTH_REQUEST_VALUES_NUM);
            let value = rng.gen_bool(0.5);

            bitmap.set_bit(random_index, value);
            fake_bitmap[random_index] = value;

            for i in 0..BANDWIDTH_REQUEST_VALUES_NUM {
                assert_eq!(bitmap.get_bit(i), fake_bitmap[i]);
            }
        }
    }

    #[test]
    fn test_bandwidth_request_values() {
        let max_receipt_size = 4 * 1024 * 1024;

        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );
        let values = BandwidthRequestValues::new(&params);

        assert!(values.values[0] > params.base_bandwidth);
        assert_eq!(values.values[BANDWIDTH_REQUEST_VALUES_NUM - 1], params.max_single_grant);

        assert_eq!(params.base_bandwidth, 61139);
        assert_eq!(
            values.values,
            [
                164468, 267797, 371126, 474455, 577784, 681113, 784442, 887772, 991101, 1094430,
                1197759, 1301088, 1404417, 1507746, 1611075, 1714405, 1817734, 1921063, 2024392,
                2127721, 2231050, 2334379, 2437708, 2541038, 2644367, 2747696, 2851025, 2954354,
                3057683, 3161012, 3264341, 3367671, 3471000, 3574329, 3677658, 3780987, 3884316,
                3987645, 4090974, 4194304
            ]
        );
    }

    /// Builds a request to the single test shard with exactly the bits in `ones_indexes` set.
    fn make_request_with_ones(ones_indexes: &[usize]) -> BandwidthRequest {
        let mut req = BandwidthRequest {
            to_shard: ShardUId::single_shard().shard_id().into(),
            requested_values_bitmap: BandwidthRequestBitmap::new(),
        };
        for i in ones_indexes {
            req.requested_values_bitmap.set_bit(*i, true);
        }
        req
    }

    /// Wraps plain sizes into the infallible `Result` items expected by
    /// `make_from_receipt_sizes`.
    fn make_sizes_iter(
        sizes: &[u64],
    ) -> impl Iterator<Item = Result<u64, std::convert::Infallible>> + '_ {
        sizes.iter().map(|&size| Ok(size))
    }

    #[test]
    fn test_make_bandwidth_request_from_receipt_sizes() {
        let max_receipt_size = 4 * 1024 * 1024;
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );
        let values = BandwidthRequestValues::new(&params).values;

        let get_request = |receipt_sizes: &[u64]| -> Option<BandwidthRequest> {
            BandwidthRequest::make_from_receipt_sizes(
                ShardUId::single_shard().shard_id(),
                make_sizes_iter(receipt_sizes),
                &params,
            )
            .unwrap()
        };

        // No receipts - no request.
        assert_eq!(get_request(&[]), None);

        // Receipts whose total stays below the base bandwidth need no request.
        let below_base_bandwidth_receipts = [10_000, 20, 999, 2362, 3343, 232, 22];
        assert!(below_base_bandwidth_receipts.iter().sum::<u64>() < params.base_bandwidth);
        assert_eq!(get_request(&below_base_bandwidth_receipts), None);

        // A total exactly equal to the base bandwidth still needs no request.
        let equal_to_base_bandwidth_receipts = [10_000, 20_000, params.base_bandwidth - 30_000];
        assert_eq!(equal_to_base_bandwidth_receipts.iter().sum::<u64>(), params.base_bandwidth);
        assert_eq!(get_request(&equal_to_base_bandwidth_receipts), None);

        // One byte over the base bandwidth requests the smallest value.
        let above_base_bandwidth_receipts = [10_000, 20_000, params.base_bandwidth - 30_000, 1];
        assert_eq!(above_base_bandwidth_receipts.iter().sum::<u64>(), params.base_bandwidth + 1);
        assert_eq!(get_request(&above_base_bandwidth_receipts), Some(make_request_with_ones(&[0])));

        let above_base_bandwidth_one_receipt = [params.base_bandwidth + 1];
        assert_eq!(
            get_request(&above_base_bandwidth_one_receipt),
            Some(make_request_with_ones(&[0]))
        );

        // A size strictly between two values requests the larger of the two.
        let in_between_value = (values[values.len() / 2] + values[values.len() / 2 + 1]) / 2;
        assert!(!values.contains(&in_between_value));
        let in_between_size_receipt = [in_between_value];
        assert_eq!(
            get_request(&in_between_size_receipt),
            Some(make_request_with_ones(&[values.len() / 2 + 1]))
        );

        // A single max-size receipt requests the value equal to `max_receipt_size`.
        let max_size_receipt = [max_receipt_size];
        let max_size_receipt_value_idx =
            values.iter().position(|v| *v == max_receipt_size).unwrap();
        assert_eq!(
            get_request(&max_size_receipt),
            Some(make_request_with_ones(&[max_size_receipt_value_idx]))
        );

        // Two max-size receipts exceed every value, so only the first receipt's bit is set.
        assert!(2 * params.max_receipt_size > params.max_shard_bandwidth);
        let two_max_size_receipts = [max_receipt_size, max_receipt_size];
        assert_eq!(
            get_request(&two_max_size_receipts),
            Some(make_request_with_ones(&[max_size_receipt_value_idx]))
        );

        // Many small receipts step through every value one by one.
        let lots_of_small_receipts: Vec<u64> = vec![1_000; 10_000];
        assert!(lots_of_small_receipts.iter().sum::<u64>() > params.max_shard_bandwidth);
        let all_bitmap_indices: Vec<usize> = (0..BANDWIDTH_REQUEST_VALUES_NUM).collect();
        assert_eq!(
            get_request(&lots_of_small_receipts),
            Some(make_request_with_ones(&all_bitmap_indices))
        );
    }

    /// Compares `make_from_receipt_sizes` against the simple reference implementation on
    /// random lists of receipt sizes.
    #[test]
    fn test_make_bandwidth_request_random() {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let max_receipt_size = 4 * 1024 * 1024;
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );

        let min_receipt_size = 5_000;
        let max_receipts_num = params.max_shard_bandwidth / min_receipt_size * 3 / 2;

        for _test_idx in 0..100 {
            let num_receipts = rng.gen_range(0..=max_receipts_num);
            let receipt_sizes: Vec<u64> = (0..num_receipts)
                .map(|_| rng.gen_range(min_receipt_size..=max_receipt_size))
                .collect();

            let request = BandwidthRequest::make_from_receipt_sizes(
                ShardUId::single_shard().shard_id(),
                make_sizes_iter(&receipt_sizes),
                &params,
            )
            .unwrap();

            let expected_request =
                make_bandwidth_request_slow(receipt_sizes.iter().copied(), &params);
            assert_eq!(request, expected_request);
        }
    }

    /// The reference implementation must agree with the real one when the first receipts
    /// fit within the base bandwidth.
    #[test]
    fn test_make_bandwidth_request_slow_skips_base_bandwidth() {
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(4 * 1024 * 1024),
        );

        for receipt_sizes in
            [vec![10_000], vec![10_000, 300_000], vec![params.base_bandwidth, 1]]
        {
            let request = BandwidthRequest::make_from_receipt_sizes(
                ShardUId::single_shard().shard_id(),
                make_sizes_iter(&receipt_sizes),
                &params,
            )
            .unwrap();
            assert_eq!(
                request,
                make_bandwidth_request_slow(receipt_sizes.iter().copied(), &params)
            );
        }
    }

    /// Straightforward reference version of `BandwidthRequest::make_from_receipt_sizes`.
    /// For every prefix total above the base bandwidth, it linearly searches for the first
    /// value that covers the total.
    fn make_bandwidth_request_slow(
        receipt_sizes: impl Iterator<Item = u64>,
        params: &BandwidthSchedulerParams,
    ) -> Option<BandwidthRequest> {
        let mut request = BandwidthRequest {
            to_shard: ShardUId::single_shard().shard_id().into(),
            requested_values_bitmap: BandwidthRequestBitmap::new(),
        };
        let values = BandwidthRequestValues::new(params).values;

        let mut total_size = 0;
        for receipt_size in receipt_sizes {
            total_size += receipt_size;

            // Totals within the base bandwidth need no request. Without this check the
            // reference would set bit 0 for small prefixes. That would disagree with
            // `make_from_receipt_sizes` and with the `None` cases asserted in
            // `test_make_bandwidth_request_from_receipt_sizes`.
            if total_size <= params.base_bandwidth {
                continue;
            }

            if let Some(idx) = values.iter().position(|&value| value >= total_size) {
                request.requested_values_bitmap.set_bit(idx, true);
            }
        }

        if request.requested_values_bitmap.is_all_zeros() {
            return None;
        }

        Some(request)
    }

    #[test]
    fn test_interpolate() {
        assert_eq!(interpolate(100, 200, 0, 10), 100);
        assert_eq!(interpolate(100, 200, 5, 10), 150);
        assert_eq!(interpolate(100, 200, 10, 10), 200);
    }
}