near_primitives/
bandwidth_scheduler.rs

1use std::collections::BTreeMap;
2use std::num::NonZeroU64;
3
4use bitvec::order::Lsb0;
5use bitvec::slice::BitSlice;
6use borsh::{BorshDeserialize, BorshSerialize};
7use near_parameters::RuntimeConfig;
8use near_primitives_core::hash::CryptoHash;
9use near_primitives_core::types::ShardId;
10use near_schema_checker_lib::ProtocolSchema;
11
/// An amount of cross-shard bandwidth, expressed as a size of receipts in bytes.
/// TODO(bandwidth_scheduler) - consider using ByteSize
pub type Bandwidth = u64;
15
16/// A list of shard's bandwidth requests.
17/// Describes how much the shard would like to send to other shards.
18#[derive(
19    BorshSerialize,
20    BorshDeserialize,
21    serde::Serialize,
22    serde::Deserialize,
23    Debug,
24    Clone,
25    PartialEq,
26    Eq,
27    ProtocolSchema,
28)]
29#[borsh(use_discriminant = true)]
30#[repr(u8)]
31#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
32pub enum BandwidthRequests {
33    V1(BandwidthRequestsV1) = 0,
34}
35
36impl BandwidthRequests {
37    pub fn empty() -> BandwidthRequests {
38        BandwidthRequests::V1(BandwidthRequestsV1 { requests: Vec::new() })
39    }
40}
41
42/// Version 1 of [`BandwidthRequest`].
43#[derive(
44    BorshSerialize,
45    BorshDeserialize,
46    serde::Serialize,
47    serde::Deserialize,
48    Default,
49    Debug,
50    Clone,
51    PartialEq,
52    Eq,
53    ProtocolSchema,
54)]
55#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
56pub struct BandwidthRequestsV1 {
57    pub requests: Vec<BandwidthRequest>,
58}
59
60/// `BandwidthRequest` describes the size of receipts that a shard would like to send to another shard.
61/// When a shard wants to send a lot of receipts to another shard, it needs to create a request and wait
62/// for a bandwidth grant from the bandwidth scheduler.
63#[derive(
64    BorshSerialize,
65    BorshDeserialize,
66    serde::Serialize,
67    serde::Deserialize,
68    Debug,
69    Clone,
70    PartialEq,
71    Eq,
72    ProtocolSchema,
73)]
74#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
75pub struct BandwidthRequest {
76    /// Requesting bandwidth to this shard.
77    pub to_shard: u16,
78    /// Bitmap which describes what values of bandwidth are requested.
79    pub requested_values_bitmap: BandwidthRequestBitmap,
80}
81
82impl BandwidthRequest {
83    /// Creates a bandwidth request based on the sizes of receipts in the outgoing buffer.
84    /// Returns None when a request is not needed (receipt size below base bandwidth).
85    pub fn make_from_receipt_sizes<E>(
86        to_shard: ShardId,
87        receipt_sizes: impl Iterator<Item = Result<u64, E>>,
88        params: &BandwidthSchedulerParams,
89    ) -> Result<Option<BandwidthRequest>, E> {
90        let values = BandwidthRequestValues::new(params).values;
91        let mut bitmap = BandwidthRequestBitmap::new();
92
93        // For every receipt find out how much bandwidth would be needed to send out
94        // all the receipts up to this one. Then find the value that is at least as
95        // large as the required bandwidth and request it in the request bitmap.
96        let mut total_size: u64 = 0;
97        let mut cur_value_idx: usize = 0;
98        for receipt_size_res in receipt_sizes {
99            let receipt_size = receipt_size_res?;
100            total_size = total_size.checked_add(receipt_size).expect(
101                "Total size of receipts doesn't fit in u64, are there exabytes of receipts?",
102            );
103
104            if total_size <= params.base_bandwidth {
105                continue;
106            }
107
108            // Find a value that is at least as big as the total_size
109            while cur_value_idx < values.len() && values[cur_value_idx] < total_size {
110                cur_value_idx += 1;
111            }
112
113            if cur_value_idx == values.len() {
114                // There is no value to request this much, stop the loop.
115                break;
116            }
117
118            // Request the value that is at least as large as total_size
119            bitmap.set_bit(cur_value_idx, true);
120        }
121
122        if bitmap.is_all_zeros() {
123            // No point in making a bandwidth request that doesn't request anything
124            return Ok(None);
125        }
126
127        Ok(Some(BandwidthRequest { to_shard: to_shard.into(), requested_values_bitmap: bitmap }))
128    }
129}
130
/// Number of predefined bandwidth values a [`BandwidthRequest`] can ask for. Each value
/// corresponds to one bit of the request bitmap.
pub const BANDWIDTH_REQUEST_VALUES_NUM: usize = 40;
133
134/// Values of bandwidth that can be requested in a bandwidth request.
135/// When the nth bit is set in a request bitmap, it means that a shard is requesting the nth value from this list.
136/// The list is sorted, from smallest to largest values.
137#[derive(Clone, Debug, PartialEq, Eq)]
138pub struct BandwidthRequestValues {
139    pub values: [Bandwidth; BANDWIDTH_REQUEST_VALUES_NUM],
140}
141
/// Performs linear interpolation between `min` and `max`, returning
/// `min + (max - min) * i / n` rounded down.
///
/// ```text
/// interpolate(100, 200, 0, 10) = 100
/// interpolate(100, 200, 5, 10) = 150
/// interpolate(100, 200, 10, 10) = 200
/// ```
///
/// The intermediate product `(max - min) * i` is computed in `u128`. The result therefore stays
/// correct even when that product does not fit in `u64`, as long as the result itself does.
///
/// Expects `min <= max`, `n > 0` and `i <= n`. Under those conditions the result lies in
/// `min..=max`.
///
/// # Panics
/// Panics if `n == 0`. Also panics if the result would not fit in `u64`, which is only possible
/// when `i > n`.
fn interpolate(min: u64, max: u64, i: u64, n: u64) -> u64 {
    let span = u128::from(max - min);
    // span < 2^64 and i < 2^64, so the product is < 2^128 and can't overflow u128.
    let offset = span * u128::from(i) / u128::from(n);
    let offset = u64::try_from(offset).expect("interpolated offset doesn't fit in u64");
    min + offset
}
149
150impl BandwidthRequestValues {
151    pub fn new(params: &BandwidthSchedulerParams) -> BandwidthRequestValues {
152        // values[-1] = base_bandwidth
153        // values[values.len() - 1] = max_single_grant
154        // values[i] = linear interpolation between values[-1] and values[values.len() - 1]
155        // TODO(bandwidth_scheduler) - consider using exponential interpolation.
156        let mut values = [0; BANDWIDTH_REQUEST_VALUES_NUM];
157
158        let values_len: u64 =
159            values.len().try_into().expect("Converting usize to u64 shouldn't fail");
160        for i in 0..values.len() {
161            let i_u64: u64 = i.try_into().expect("Converting usize to u64 shouldn't fail");
162
163            values[i] =
164                interpolate(params.base_bandwidth, params.max_single_grant, i_u64 + 1, values_len);
165        }
166
167        BandwidthRequestValues { values }
168    }
169}
170
171/// Bitmap which describes which values from the predefined list are being requested.
172/// The nth bit is set to 1 when the nth value from the list is being requested.
173#[derive(
174    BorshSerialize,
175    BorshDeserialize,
176    serde::Serialize,
177    serde::Deserialize,
178    Clone,
179    PartialEq,
180    Eq,
181    ProtocolSchema,
182)]
183#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
184pub struct BandwidthRequestBitmap {
185    pub data: [u8; BANDWIDTH_REQUEST_BITMAP_SIZE],
186}
187
188pub const BANDWIDTH_REQUEST_BITMAP_SIZE: usize = BANDWIDTH_REQUEST_VALUES_NUM / 8;
189const _: () = assert!(
190    BANDWIDTH_REQUEST_VALUES_NUM % 8 == 0,
191    "Every bit in the bitmap should be used. It's wasteful to have unused bits.
192    And having unused bits would require extra validation logic"
193);
194
195impl std::fmt::Debug for BandwidthRequestBitmap {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        write!(f, "BandwidthRequestBitmap(")?;
198        for i in 0..self.len() {
199            if self.get_bit(i) {
200                write!(f, "1")?;
201            } else {
202                write!(f, "0")?;
203            }
204        }
205        write!(f, ")")
206    }
207}
208
209impl BandwidthRequestBitmap {
210    pub fn new() -> BandwidthRequestBitmap {
211        BandwidthRequestBitmap { data: [0u8; BANDWIDTH_REQUEST_BITMAP_SIZE] }
212    }
213
214    pub fn get_bit(&self, idx: usize) -> bool {
215        assert!(idx < self.len());
216
217        let bit_slice = BitSlice::<_, Lsb0>::from_slice(self.data.as_slice());
218        *bit_slice.get(idx).unwrap()
219    }
220
221    pub fn set_bit(&mut self, idx: usize, val: bool) {
222        assert!(idx < self.len());
223
224        let bit_slice = BitSlice::<_, Lsb0>::from_slice_mut(self.data.as_mut_slice());
225        bit_slice.set(idx, val);
226    }
227
228    pub fn len(&self) -> usize {
229        BANDWIDTH_REQUEST_VALUES_NUM
230    }
231
232    pub fn is_all_zeros(&self) -> bool {
233        self.data == [0u8; BANDWIDTH_REQUEST_BITMAP_SIZE]
234    }
235}
236
237/// `BandwidthRequests` from all chunks in a block.
238#[derive(Debug, Clone, PartialEq, Eq)]
239pub struct BlockBandwidthRequests {
240    /// For every shard - all the bandwidth requests generated by this shard.
241    pub shards_bandwidth_requests: BTreeMap<ShardId, BandwidthRequests>,
242}
243
244impl BlockBandwidthRequests {
245    pub fn empty() -> BlockBandwidthRequests {
246        BlockBandwidthRequests { shards_bandwidth_requests: BTreeMap::new() }
247    }
248}
249
250/// Persistent state used by the bandwidth scheduler.
251/// It is kept in the shard trie.
252/// The state should be the same on all shards. All shards start with the same state
253/// and apply the same bandwidth scheduler algorithm at the same heights, so the resulting
254/// scheduler state stays the same.
255#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
256#[borsh(use_discriminant = true)]
257#[repr(u8)]
258pub enum BandwidthSchedulerState {
259    V1(BandwidthSchedulerStateV1) = 0,
260}
261
262#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
263pub struct BandwidthSchedulerStateV1 {
264    /// Allowance for every pair of (sender, receiver). Used in the scheduler algorithm.
265    /// Bandwidth scheduler updates the allowances on every run.
266    pub link_allowances: Vec<LinkAllowance>,
267    /// Sanity check hash to assert that all shards run bandwidth scheduler in the exact same way.
268    /// Hash of previous scheduler state and (some) scheduler inputs.
269    pub sanity_check_hash: CryptoHash,
270}
271
272/// Allowance for a (sender, receiver) pair of shards.
273/// Used in bandwidth scheduler.
274#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq, ProtocolSchema)]
275pub struct LinkAllowance {
276    /// Sender shard
277    pub sender: ShardId,
278    /// Receiver shard
279    pub receiver: ShardId,
280    /// Link allowance, determines priority for granting bandwidth.
281    /// See the bandwidth scheduler module-level comment for a more
282    /// detailed description.
283    pub allowance: Bandwidth,
284}
285
286/// Parameters used in the bandwidth scheduler algorithm.
287#[derive(Clone, Copy, Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
288pub struct BandwidthSchedulerParams {
289    /// This much bandwidth is granted by default.
290    /// base_bandwidth = (max_shard_bandwidth - max_single_grant) / (num_shards - 1)
291    pub base_bandwidth: Bandwidth,
292    /// The maximum amount of data that a shard can send or receive at a single height.
293    pub max_shard_bandwidth: Bandwidth,
294    /// The maximum amount of bandwidth that can be granted on a single link.
295    /// Should be at least as big as `max_receipt_size`.
296    pub max_single_grant: Bandwidth,
297    /// Maximum size of a single receipt.
298    pub max_receipt_size: Bandwidth,
299    /// Maximum bandwidth allowance that a link can accumulate.
300    pub max_allowance: Bandwidth,
301}
302
303impl BandwidthSchedulerParams {
304    /// Calculate values of scheduler params based on the current configuration
305    pub fn new(num_shards: NonZeroU64, runtime_config: &RuntimeConfig) -> BandwidthSchedulerParams {
306        let scheduler_config = runtime_config.bandwidth_scheduler_config;
307
308        Self::calculate(
309            scheduler_config.max_shard_bandwidth,
310            scheduler_config.max_single_grant,
311            scheduler_config.max_allowance,
312            scheduler_config.max_base_bandwidth,
313            runtime_config.wasm_config.limit_config.max_receipt_size,
314            num_shards.get(),
315        )
316    }
317
318    fn calculate(
319        max_shard_bandwidth: Bandwidth,
320        max_single_grant: Bandwidth,
321        max_allowance: Bandwidth,
322        max_base_bandwidth: Bandwidth,
323        max_receipt_size: Bandwidth,
324        num_shards: u64,
325    ) -> BandwidthSchedulerParams {
326        assert!(
327            max_single_grant >= max_receipt_size,
328            "A max_single_grant can't be lower than max_receipt_size - it'll be impossible to send a max size receipt"
329        );
330        assert!(
331            max_single_grant <= max_shard_bandwidth,
332            "A single grant must not be greater than max_shard_bandwidth"
333        );
334
335        // Granting `max_single_grant` on one link and `base_bandwidth` on all other links can't
336        // exceed `max_shard_bandwidth`, we have to ensure that:
337        // base_bandwidth * (num_shards - 1) + max_single_grant <= max_shard_bandwidth
338        // Base bandwidth is calculated by taking the bandwidth that would remain available after
339        // granting `max_single_grant` on one link and dividing it equally between the other links.
340        let available_bandwidth = max_shard_bandwidth - max_single_grant;
341        let mut base_bandwidth = available_bandwidth / std::cmp::max(1, num_shards - 1);
342        if base_bandwidth > max_base_bandwidth {
343            base_bandwidth = max_base_bandwidth;
344        }
345
346        BandwidthSchedulerParams {
347            base_bandwidth,
348            max_shard_bandwidth,
349            max_single_grant,
350            max_receipt_size,
351            max_allowance,
352        }
353    }
354
355    /// Example params, used only in tests
356    pub fn for_test(num_shards: u64) -> BandwidthSchedulerParams {
357        let max_shard_bandwidth = 4_500_000;
358        let max_single_grant = 4 * 1024 * 1024;
359        let max_allowance = max_shard_bandwidth;
360        let max_base_bandwidth = 100_000;
361        let max_receipt_size = 4 * 1024 * 1024;
362
363        Self::calculate(
364            max_shard_bandwidth,
365            max_single_grant,
366            max_allowance,
367            max_base_bandwidth,
368            max_receipt_size,
369            num_shards,
370        )
371    }
372}
373
#[cfg(test)]
mod tests {
    use std::num::NonZeroU64;
    use std::ops::Deref;
    use std::sync::Arc;

    use near_parameters::RuntimeConfig;
    use rand::{Rng, SeedableRng};

    use crate::bandwidth_scheduler::{BANDWIDTH_REQUEST_VALUES_NUM, interpolate};
    use crate::shard_layout::ShardUId;

    use super::{
        BandwidthRequest, BandwidthRequestBitmap, BandwidthRequestValues, BandwidthSchedulerParams,
    };
    use rand_chacha::ChaCha20Rng;

    /// Returns `RuntimeConfig::test()` with `max_receipt_size` overridden.
    fn make_runtime_config(max_receipt_size: u64) -> RuntimeConfig {
        let mut runtime_config = RuntimeConfig::test();

        // wasm_config is in Arc, need to clone, modify and set new Arc to modify parameter
        let mut wasm_config = runtime_config.wasm_config.deref().clone();
        wasm_config.limit_config.max_receipt_size = max_receipt_size;
        runtime_config.wasm_config = Arc::new(wasm_config);

        runtime_config
    }

    /// Ensure that a maximum size receipt can still be sent after granting everyone
    /// base bandwidth without going over the max_shard_bandwidth limit.
    fn assert_max_size_can_get_through(params: &BandwidthSchedulerParams, num_shards: u64) {
        assert!(
            (num_shards - 1) * params.base_bandwidth + params.max_receipt_size
                <= params.max_shard_bandwidth
        )
    }

    #[test]
    fn test_scheduler_params_one_shard() {
        let max_receipt_size = 4 * 1024 * 1024;
        let num_shards = 1;

        let runtime_config = make_runtime_config(max_receipt_size);
        let scheduler_params =
            BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
        let expected = BandwidthSchedulerParams {
            base_bandwidth: 100_000,
            max_shard_bandwidth: 4_500_000,
            max_single_grant: 4 * 1024 * 1024,
            max_receipt_size,
            max_allowance: 4_500_000,
        };
        assert_eq!(scheduler_params, expected);
        assert_max_size_can_get_through(&scheduler_params, num_shards);
    }

    #[test]
    fn test_scheduler_params_six_shards() {
        let max_receipt_size = 4 * 1024 * 1024;
        let num_shards = 6;

        let runtime_config = make_runtime_config(max_receipt_size);
        let scheduler_params =
            BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
        let expected = BandwidthSchedulerParams {
            base_bandwidth: (4_500_000 - max_receipt_size) / 5,
            max_shard_bandwidth: 4_500_000,
            max_single_grant: 4 * 1024 * 1024,
            max_receipt_size,
            max_allowance: 4_500_000,
        };
        assert_eq!(scheduler_params, expected);
        assert_max_size_can_get_through(&scheduler_params, num_shards);
    }

    /// max_receipt_size is larger than max_single_grant (and max_shard_bandwidth) - incorrect
    /// configuration.
    #[test]
    #[should_panic(expected = "A max_single_grant can't be lower than max_receipt_size")]
    fn test_scheduler_params_invalid_config() {
        let max_receipt_size = 40 * 1024 * 1024;
        let num_shards = 6;
        let runtime_config = make_runtime_config(max_receipt_size);
        BandwidthSchedulerParams::new(NonZeroU64::new(num_shards).unwrap(), &runtime_config);
    }

    #[test]
    fn test_bandwidth_request_bitmap() {
        let mut bitmap = BandwidthRequestBitmap::new();
        assert_eq!(bitmap.len(), BANDWIDTH_REQUEST_VALUES_NUM);

        let mut fake_bitmap = [false; BANDWIDTH_REQUEST_VALUES_NUM];
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);

        for _ in 0..(BANDWIDTH_REQUEST_VALUES_NUM * 5) {
            let random_index = rng.gen_range(0..BANDWIDTH_REQUEST_VALUES_NUM);
            let value = rng.gen_bool(0.5);

            bitmap.set_bit(random_index, value);
            fake_bitmap[random_index] = value;

            for i in 0..BANDWIDTH_REQUEST_VALUES_NUM {
                assert_eq!(bitmap.get_bit(i), fake_bitmap[i]);
            }
        }
    }

    #[test]
    fn test_bandwidth_request_values() {
        let max_receipt_size = 4 * 1024 * 1024;

        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );
        let values = BandwidthRequestValues::new(&params);

        assert!(values.values[0] > params.base_bandwidth);
        assert_eq!(values.values[BANDWIDTH_REQUEST_VALUES_NUM - 1], params.max_single_grant);

        assert_eq!(params.base_bandwidth, 61139);
        assert_eq!(
            values.values,
            [
                164468, 267797, 371126, 474455, 577784, 681113, 784442, 887772, 991101, 1094430,
                1197759, 1301088, 1404417, 1507746, 1611075, 1714405, 1817734, 1921063, 2024392,
                2127721, 2231050, 2334379, 2437708, 2541038, 2644367, 2747696, 2851025, 2954354,
                3057683, 3161012, 3264341, 3367671, 3471000, 3574329, 3677658, 3780987, 3884316,
                3987645, 4090974, 4194304
            ]
        );
    }

    // Make a bandwidth request to shard 0 with a bitmap which has ones at the specified indices.
    fn make_request_with_ones(ones_indexes: &[usize]) -> BandwidthRequest {
        let mut req = BandwidthRequest {
            to_shard: ShardUId::single_shard().shard_id().into(),
            requested_values_bitmap: BandwidthRequestBitmap::new(),
        };
        for i in ones_indexes {
            req.requested_values_bitmap.set_bit(*i, true);
        }
        req
    }

    fn make_sizes_iter<'a>(
        sizes: &'a [u64],
    ) -> impl Iterator<Item = Result<u64, std::convert::Infallible>> + 'a {
        sizes.iter().map(|&size| Ok(size))
    }

    #[test]
    fn test_make_bandwidth_request_from_receipt_sizes() {
        let max_receipt_size = 4 * 1024 * 1024;
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );
        let values = BandwidthRequestValues::new(&params).values;

        let get_request = |receipt_sizes: &[u64]| -> Option<BandwidthRequest> {
            BandwidthRequest::make_from_receipt_sizes(
                ShardUId::single_shard().shard_id(),
                make_sizes_iter(receipt_sizes),
                &params,
            )
            .unwrap()
        };

        // No receipts - no bandwidth request.
        assert_eq!(get_request(&[]), None);

        // Receipts with total size smaller than base_bandwidth don't need a bandwidth request.
        let below_base_bandwidth_receipts = [10_000, 20, 999, 2362, 3343, 232, 22];
        assert!(below_base_bandwidth_receipts.iter().sum::<u64>() < params.base_bandwidth);
        assert_eq!(get_request(&below_base_bandwidth_receipts), None);

        // Receipts with total size equal to base_bandwidth don't need a bandwidth_request
        let equal_to_base_bandwidth_receipts = [10_000, 20_000, params.base_bandwidth - 30_000];
        assert_eq!(equal_to_base_bandwidth_receipts.iter().sum::<u64>(), params.base_bandwidth);
        assert_eq!(get_request(&equal_to_base_bandwidth_receipts), None);

        // Receipts with total size barely larger than base_bandwidth need a bandwidth request.
        // Only the first bit in the bitmap should be set to 1.
        let above_base_bandwidth_receipts = [10_000, 20_000, params.base_bandwidth - 30_000, 1];
        assert_eq!(above_base_bandwidth_receipts.iter().sum::<u64>(), params.base_bandwidth + 1);
        assert_eq!(get_request(&above_base_bandwidth_receipts), Some(make_request_with_ones(&[0])));

        // A single receipt which is slightly larger than base_bandwidth needs a bandwidth request.
        let above_base_bandwidth_one_receipt = [params.base_bandwidth + 1];
        assert_eq!(
            get_request(&above_base_bandwidth_one_receipt),
            Some(make_request_with_ones(&[0]))
        );

        // When requesting bandwidth that is between two values on the list, the request
        // should ask for the first value that is bigger than the needed bandwidth.
        let in_between_value = (values[values.len() / 2] + values[values.len() / 2 + 1]) / 2;
        assert!(!values.contains(&in_between_value));
        let in_between_size_receipt = [in_between_value];
        assert_eq!(
            get_request(&in_between_size_receipt),
            Some(make_request_with_ones(&[values.len() / 2 + 1]))
        );

        // A single max size receipt should have the corresponding value set to one.
        let max_size_receipt = [max_receipt_size];
        let max_size_receipt_value_idx =
            values.iter().position(|v| *v == max_receipt_size).unwrap();
        assert_eq!(
            get_request(&max_size_receipt),
            Some(make_request_with_ones(&[max_size_receipt_value_idx]))
        );

        // Two max size receipts should produce the same bandwidth request as one max size receipt.
        // 2 * max_size_receipt > max_shard_bandwidth, so it doesn't make sense to request more bandwidth.
        assert!(2 * params.max_receipt_size > params.max_shard_bandwidth);
        let two_max_size_receipts = [max_receipt_size, max_receipt_size];
        assert_eq!(
            get_request(&two_max_size_receipts),
            Some(make_request_with_ones(&[max_size_receipt_value_idx]))
        );

        // A ton of small receipts should cause all bits to be set to one.
        // 10_000 receipts, each with size 1000. More than a shard can send out at a single height.
        let lots_of_small_receipts: Vec<u64> = vec![1_000; 10_000];
        assert!(lots_of_small_receipts.iter().sum::<u64>() > params.max_shard_bandwidth);
        let all_bitmap_indices: Vec<usize> = (0..BANDWIDTH_REQUEST_VALUES_NUM).collect();
        assert_eq!(
            get_request(&lots_of_small_receipts),
            Some(make_request_with_ones(&all_bitmap_indices))
        );
    }

    /// Generate random receipt sizes and create a bandwidth request from them.
    /// Compare the created bandwidth request with a request created using simpler logic.
    #[test]
    fn test_make_bandwidth_request_random() {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let max_receipt_size = 4 * 1024 * 1024;
        let params = BandwidthSchedulerParams::new(
            NonZeroU64::new(6).unwrap(),
            &make_runtime_config(max_receipt_size),
        );

        let min_receipt_size = 5_000;
        let max_receipts_num = params.max_shard_bandwidth / min_receipt_size * 3 / 2;

        for _test_idx in 0..100 {
            let num_receipts = rng.gen_range(0..=max_receipts_num);
            let receipt_sizes: Vec<u64> = (0..num_receipts)
                .map(|_| rng.gen_range(min_receipt_size..=max_receipt_size))
                .collect();

            let request = BandwidthRequest::make_from_receipt_sizes(
                ShardUId::single_shard().shard_id(),
                make_sizes_iter(&receipt_sizes),
                &params,
            )
            .unwrap();

            let expected_request =
                make_bandwidth_request_slow(receipt_sizes.iter().copied(), &params);
            assert_eq!(request, expected_request);
        }
    }

    /// A more naive implementation of bandwidth request generation.
    /// For every running total of receipt sizes that exceeds `base_bandwidth`, find the smallest
    /// value that is at least this large and request it. Totals that fit within `base_bandwidth`
    /// need no request, which matches the contract of `make_from_receipt_sizes`.
    fn make_bandwidth_request_slow(
        receipt_sizes: impl Iterator<Item = u64>,
        params: &BandwidthSchedulerParams,
    ) -> Option<BandwidthRequest> {
        let mut request = BandwidthRequest {
            to_shard: ShardUId::single_shard().shard_id().into(),
            requested_values_bitmap: BandwidthRequestBitmap::new(),
        };
        let values = BandwidthRequestValues::new(params).values;

        let mut total_size = 0;
        for receipt_size in receipt_sizes {
            total_size += receipt_size;

            if total_size <= params.base_bandwidth {
                // Base bandwidth is granted by default, there's nothing to request.
                continue;
            }

            if let Some(i) = values.iter().position(|&value| value >= total_size) {
                request.requested_values_bitmap.set_bit(i, true);
            }
        }

        if request.requested_values_bitmap.is_all_zeros() {
            return None;
        }

        Some(request)
    }

    #[test]
    fn test_interpolate() {
        assert_eq!(interpolate(100, 200, 0, 10), 100);
        assert_eq!(interpolate(100, 200, 5, 10), 150);
        assert_eq!(interpolate(100, 200, 10, 10), 200);
    }
}