near_primitives/
congestion_info.rs

1use std::collections::BTreeMap;
2
3use crate::errors::RuntimeError;
4use borsh::{BorshDeserialize, BorshSerialize};
5use near_parameters::config::CongestionControlConfig;
6use near_primitives_core::types::{Gas, ShardId};
7use near_schema_checker_lib::ProtocolSchema;
8use ordered_float::NotNan;
9
/// Combines the congestion control config, the congestion info and the
/// missed chunks count. It contains the main congestion control logic and
/// exposes methods that can be used for congestion control.
///
/// Use this struct to make congestion control decisions, by looking at the
/// congestion info of a previous chunk produced on a remote shard. For building
/// up a congestion info for the local shard, this struct should not be
/// necessary. Use `CongestionInfo` directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CongestionControl {
    /// Limits and thresholds that the congestion info is measured against.
    config: CongestionControlConfig,
    /// Finalized congestion info of a previous chunk.
    info: CongestionInfo,
    /// How many block heights had no chunk since the last successful chunk on
    /// the respective shard.
    missed_chunks_count: u64,
}
27
28impl CongestionControl {
29    pub fn new(
30        config: CongestionControlConfig,
31        info: CongestionInfo,
32        missed_chunks_count: u64,
33    ) -> Self {
34        Self { config, info, missed_chunks_count }
35    }
36
37    pub fn config(&self) -> &CongestionControlConfig {
38        &self.config
39    }
40
41    pub fn congestion_info(&self) -> &CongestionInfo {
42        &self.info
43    }
44
45    pub fn congestion_level(&self) -> f64 {
46        let incoming_congestion = self.incoming_congestion();
47        let outgoing_congestion = self.outgoing_congestion();
48        let memory_congestion = self.memory_congestion();
49        let missed_chunks_congestion = self.missed_chunks_congestion();
50
51        incoming_congestion
52            .max(outgoing_congestion)
53            .max(memory_congestion)
54            .max(missed_chunks_congestion)
55    }
56
57    fn incoming_congestion(&self) -> f64 {
58        self.info.incoming_congestion(&self.config)
59    }
60
61    fn outgoing_congestion(&self) -> f64 {
62        self.info.outgoing_congestion(&self.config)
63    }
64
65    fn memory_congestion(&self) -> f64 {
66        self.info.memory_congestion(&self.config)
67    }
68
69    fn missed_chunks_congestion(&self) -> f64 {
70        if self.missed_chunks_count <= 1 {
71            return 0.0;
72        }
73
74        clamped_f64_fraction(
75            self.missed_chunks_count as u128,
76            self.config.max_congestion_missed_chunks,
77        )
78    }
79
80    /// How much gas another shard can send to us in the next block.
81    pub fn outgoing_gas_limit(&self, sender_shard: ShardId) -> Gas {
82        let congestion = self.congestion_level();
83
84        if Self::is_fully_congested(congestion) {
85            // Red traffic light: reduce to minimum speed
86            if sender_shard == ShardId::from(self.info.allowed_shard()) {
87                self.config.allowed_shard_outgoing_gas
88            } else {
89                0
90            }
91        } else {
92            mix(self.config.max_outgoing_gas, self.config.min_outgoing_gas, congestion)
93        }
94    }
95
96    pub fn is_fully_congested(congestion_level: f64) -> bool {
97        // note: using float equality is okay here because
98        // `clamped_f64_fraction` clamps to exactly 1.0.
99        debug_assert!(congestion_level <= 1.0);
100        congestion_level == 1.0
101    }
102
103    /// How much data another shard can send to us in the next block.
104    pub fn outgoing_size_limit(&self, sender_shard: ShardId) -> Gas {
105        if sender_shard == ShardId::from(self.info.allowed_shard()) {
106            // The allowed shard is allowed to send more data to us.
107            self.config.outgoing_receipts_big_size_limit
108        } else {
109            // Other shards have a low standard limit.
110            self.config.outgoing_receipts_usual_size_limit
111        }
112    }
113
114    /// How much gas we accept for executing new transactions going to any
115    /// uncongested shards.
116    pub fn process_tx_limit(&self) -> Gas {
117        mix(self.config.max_tx_gas, self.config.min_tx_gas, self.incoming_congestion())
118    }
119
120    /// Whether we can accept new transaction with the receiver set to this shard.
121    ///
122    /// If the shard doesn't accept new transaction, provide the reason for
123    /// extra debugging information.
124    pub fn shard_accepts_transactions(&self) -> ShardAcceptsTransactions {
125        let incoming_congestion = self.incoming_congestion();
126        let outgoing_congestion = self.outgoing_congestion();
127        let memory_congestion = self.memory_congestion();
128        let missed_chunks_congestion = self.missed_chunks_congestion();
129
130        let congestion_level = incoming_congestion
131            .max(outgoing_congestion)
132            .max(memory_congestion)
133            .max(missed_chunks_congestion);
134
135        // Convert to NotNan here, if not possible, the max above is already meaningless.
136        let congestion_level =
137            NotNan::new(congestion_level).unwrap_or_else(|_| NotNan::new(1.0).unwrap());
138        if *congestion_level < self.config.reject_tx_congestion_threshold {
139            return ShardAcceptsTransactions::Yes;
140        }
141
142        let reason = if missed_chunks_congestion >= *congestion_level {
143            RejectTransactionReason::MissedChunks { missed_chunks: self.missed_chunks_count }
144        } else if incoming_congestion >= *congestion_level {
145            RejectTransactionReason::IncomingCongestion { congestion_level }
146        } else if outgoing_congestion >= *congestion_level {
147            RejectTransactionReason::OutgoingCongestion { congestion_level }
148        } else {
149            RejectTransactionReason::MemoryCongestion { congestion_level }
150        };
151        ShardAcceptsTransactions::No(reason)
152    }
153}
154
155/// Result of [`CongestionControl::shard_accepts_transactions`].
156pub enum ShardAcceptsTransactions {
157    Yes,
158    No(RejectTransactionReason),
159}
160
161/// Detailed information for why a shard rejects new transactions.
162pub enum RejectTransactionReason {
163    IncomingCongestion { congestion_level: NotNan<f64> },
164    OutgoingCongestion { congestion_level: NotNan<f64> },
165    MemoryCongestion { congestion_level: NotNan<f64> },
166    MissedChunks { missed_chunks: u64 },
167}
168
/// Stores the congestion level of a shard.
///
/// The CongestionInfo is a part of the ChunkHeader. It is versioned and each
/// version should not be changed. Rather a new version with the desired changes
/// should be added and used in place of the old one. When adding new versions
/// please also update the default.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum CongestionInfo {
    /// First version of the congestion info; currently the only one and the
    /// one produced by `Default`.
    V1(CongestionInfoV1),
}
191
192impl Default for CongestionInfo {
193    fn default() -> Self {
194        Self::V1(CongestionInfoV1::default())
195    }
196}
197
198impl CongestionInfo {
199    // A helper method to compare the congestion info from the chunk extra of
200    // the previous chunk and the header of the current chunk. It returns true
201    // if the congestion info was correctly set in the chunk header based on the
202    // information from the chunk extra.
203    //
204    // TODO(congestion_control) validate allowed shard
205    pub fn validate_extra_and_header(extra: &CongestionInfo, header: &CongestionInfo) -> bool {
206        match (extra, header) {
207            (CongestionInfo::V1(extra), CongestionInfo::V1(header)) => {
208                extra.delayed_receipts_gas == header.delayed_receipts_gas
209                    && extra.buffered_receipts_gas == header.buffered_receipts_gas
210                    && extra.receipt_bytes == header.receipt_bytes
211                    && extra.allowed_shard == header.allowed_shard
212            }
213        }
214    }
215
216    pub fn delayed_receipts_gas(&self) -> u128 {
217        match self {
218            CongestionInfo::V1(inner) => inner.delayed_receipts_gas,
219        }
220    }
221
222    pub fn buffered_receipts_gas(&self) -> u128 {
223        match self {
224            CongestionInfo::V1(inner) => inner.buffered_receipts_gas,
225        }
226    }
227
228    pub fn receipt_bytes(&self) -> u64 {
229        match self {
230            CongestionInfo::V1(inner) => inner.receipt_bytes,
231        }
232    }
233
234    pub fn allowed_shard(&self) -> u16 {
235        match self {
236            CongestionInfo::V1(inner) => inner.allowed_shard,
237        }
238    }
239
240    pub fn set_allowed_shard(&mut self, allowed_shard: u16) {
241        match self {
242            CongestionInfo::V1(inner) => inner.allowed_shard = allowed_shard,
243        }
244    }
245
246    pub fn add_receipt_bytes(&mut self, bytes: u64) -> Result<(), RuntimeError> {
247        match self {
248            CongestionInfo::V1(inner) => {
249                inner.receipt_bytes = inner.receipt_bytes.checked_add(bytes).ok_or_else(|| {
250                    RuntimeError::UnexpectedIntegerOverflow("add_receipt_bytes".into())
251                })?;
252            }
253        }
254        Ok(())
255    }
256
257    pub fn remove_receipt_bytes(&mut self, bytes: u64) -> Result<(), RuntimeError> {
258        match self {
259            CongestionInfo::V1(inner) => {
260                inner.receipt_bytes = inner.receipt_bytes.checked_sub(bytes).ok_or_else(|| {
261                    RuntimeError::UnexpectedIntegerOverflow("remove_receipt_bytes".into())
262                })?;
263            }
264        }
265        Ok(())
266    }
267
268    pub fn add_delayed_receipt_gas(&mut self, gas: Gas) -> Result<(), RuntimeError> {
269        match self {
270            CongestionInfo::V1(inner) => {
271                inner.delayed_receipts_gas =
272                    inner.delayed_receipts_gas.checked_add(gas as u128).ok_or_else(|| {
273                        RuntimeError::UnexpectedIntegerOverflow("add_delayed_receipt_gas".into())
274                    })?;
275            }
276        }
277        Ok(())
278    }
279
280    pub fn remove_delayed_receipt_gas(&mut self, gas: Gas) -> Result<(), RuntimeError> {
281        match self {
282            CongestionInfo::V1(inner) => {
283                inner.delayed_receipts_gas =
284                    inner.delayed_receipts_gas.checked_sub(gas as u128).ok_or_else(|| {
285                        RuntimeError::UnexpectedIntegerOverflow("remove_delayed_receipt_gas".into())
286                    })?;
287            }
288        }
289        Ok(())
290    }
291
292    pub fn add_buffered_receipt_gas(&mut self, gas: Gas) -> Result<(), RuntimeError> {
293        match self {
294            CongestionInfo::V1(inner) => {
295                inner.buffered_receipts_gas =
296                    inner.buffered_receipts_gas.checked_add(gas as u128).ok_or_else(|| {
297                        RuntimeError::UnexpectedIntegerOverflow("add_buffered_receipt_gas".into())
298                    })?;
299            }
300        }
301        Ok(())
302    }
303
304    pub fn remove_buffered_receipt_gas(&mut self, gas: u128) -> Result<(), RuntimeError> {
305        match self {
306            CongestionInfo::V1(inner) => {
307                inner.buffered_receipts_gas =
308                    inner.buffered_receipts_gas.checked_sub(gas).ok_or_else(|| {
309                        RuntimeError::UnexpectedIntegerOverflow(
310                            "remove_buffered_receipt_gas".into(),
311                        )
312                    })?;
313            }
314        }
315        Ok(())
316    }
317
318    /// Congestion level ignoring the chain context (missed chunks count).
319    pub fn localized_congestion_level(&self, config: &CongestionControlConfig) -> f64 {
320        let incoming_congestion = self.incoming_congestion(config);
321        let outgoing_congestion = self.outgoing_congestion(config);
322        let memory_congestion = self.memory_congestion(config);
323        incoming_congestion.max(outgoing_congestion).max(memory_congestion)
324    }
325
326    pub fn incoming_congestion(&self, config: &CongestionControlConfig) -> f64 {
327        clamped_f64_fraction(self.delayed_receipts_gas(), config.max_congestion_incoming_gas)
328    }
329
330    pub fn outgoing_congestion(&self, config: &CongestionControlConfig) -> f64 {
331        clamped_f64_fraction(self.buffered_receipts_gas(), config.max_congestion_outgoing_gas)
332    }
333
334    pub fn memory_congestion(&self, config: &CongestionControlConfig) -> f64 {
335        clamped_f64_fraction(self.receipt_bytes() as u128, config.max_congestion_memory_consumption)
336    }
337
338    /// Computes and sets the `allowed_shard` field.
339    ///
340    /// If in a fully congested state, decide which shard of the shards is
341    /// allowed to forward gas to `own_shard` this round. In this case, we stop all
342    /// of the shards from sending anything to `own_shard`. But to guarantee
343    /// progress, we allow one shard to send `allowed_shard_outgoing_gas`
344    /// in the next chunk.
345    ///
346    /// It is also used to determine the size limit for outgoing receipts from sender shards.
347    /// Only the allowed shard can send receipts of size `outgoing_receipts_big_size_limit`.
348    /// Other shards can only send receipts of size `outgoing_receipts_usual_size_limit`.
349    pub fn finalize_allowed_shard(
350        &mut self,
351        own_shard: ShardId,
352        all_shards: &[ShardId],
353        congestion_seed: u64,
354    ) {
355        let allowed_shard = Self::get_new_allowed_shard(own_shard, all_shards, congestion_seed);
356        self.set_allowed_shard(allowed_shard.into());
357    }
358
359    fn get_new_allowed_shard(
360        own_shard: ShardId,
361        all_shards: &[ShardId],
362        congestion_seed: u64,
363    ) -> ShardId {
364        if let Some(index) = congestion_seed.checked_rem(all_shards.len() as u64) {
365            // round robin for other shards based on the seed
366            return *all_shards
367                .get(index as usize)
368                .expect("`checked_rem` should have ensured array access is in bound");
369        }
370        // checked_rem failed, hence all_shards.len() is 0
371        // own_shard is the only choice.
372        return own_shard;
373    }
374}
375
/// The block congestion info contains the congestion info for all shards in the
/// block extended with the missed chunks count.
///
/// The default value contains no shards.
#[derive(Clone, Debug, Default)]
pub struct BlockCongestionInfo {
    /// The per shard congestion info. It's important that the data structure is
    /// deterministic because the allowed shard id selection depends on the
    /// order of shard ids in this map. Ideally it should also be sorted by shard id.
    shards_congestion_info: BTreeMap<ShardId, ExtendedCongestionInfo>,
}
385
386impl BlockCongestionInfo {
387    pub fn new(shards_congestion_info: BTreeMap<ShardId, ExtendedCongestionInfo>) -> Self {
388        Self { shards_congestion_info }
389    }
390
391    pub fn iter(&self) -> impl Iterator<Item = (&ShardId, &ExtendedCongestionInfo)> {
392        self.shards_congestion_info.iter()
393    }
394
395    pub fn all_shards(&self) -> Vec<ShardId> {
396        self.shards_congestion_info.keys().copied().collect()
397    }
398
399    pub fn get(&self, shard_id: &ShardId) -> Option<&ExtendedCongestionInfo> {
400        self.shards_congestion_info.get(shard_id)
401    }
402
403    pub fn get_mut(&mut self, shard_id: &ShardId) -> Option<&mut ExtendedCongestionInfo> {
404        self.shards_congestion_info.get_mut(shard_id)
405    }
406
407    pub fn insert(
408        &mut self,
409        shard_id: ShardId,
410        value: ExtendedCongestionInfo,
411    ) -> Option<ExtendedCongestionInfo> {
412        self.shards_congestion_info.insert(shard_id, value)
413    }
414
415    pub fn is_empty(&self) -> bool {
416        self.shards_congestion_info.is_empty()
417    }
418}
419
420/// The extended congestion info contains the congestion info and extra
421/// information extracted from the block that is needed for congestion control.
422#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
423pub struct ExtendedCongestionInfo {
424    pub congestion_info: CongestionInfo,
425    pub missed_chunks_count: u64,
426}
427
428impl ExtendedCongestionInfo {
429    pub fn new(congestion_info: CongestionInfo, missed_chunks_count: u64) -> Self {
430        Self { congestion_info, missed_chunks_count }
431    }
432}
433
/// Stores the congestion level of a shard.
///
/// This is the payload of `CongestionInfo::V1`. The default value has all
/// counters at zero.
#[derive(
    BorshSerialize,
    BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
    Default,
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    ProtocolSchema,
)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct CongestionInfoV1 {
    /// Sum of gas in currently delayed receipts.
    pub delayed_receipts_gas: u128,
    /// Sum of gas in currently buffered receipts.
    pub buffered_receipts_gas: u128,
    /// Size of borsh serialized receipts stored in state because they
    /// were delayed, buffered, postponed, or yielded.
    pub receipt_bytes: u64,
    /// If fully congested, only this shard can forward receipts.
    /// This shard is also the one that gets the big outgoing receipts size
    /// limit, see `CongestionControl::outgoing_size_limit`.
    pub allowed_shard: u16,
}
460
/// Returns `value / max` clamped to the range [0,1].
///
/// Any `value` greater than or equal to `max` yields exactly 1.0.
///
/// # Panics
///
/// Panics if `max` is zero.
#[inline]
fn clamped_f64_fraction(value: u128, max: u64) -> f64 {
    assert!(max > 0);
    let limit = u128::from(max);
    if value >= limit {
        return 1.0;
    }
    value as f64 / max as f64
}
467
/// Linearly interpolates between two values.
///
/// A `ratio` of 0.0 yields `left`, a `ratio` of 1.0 yields `right`, and values
/// in between are weighted accordingly. The result is rounded to the nearest
/// integer. `ratio` is expected to be in the range [0,1]; this is only
/// checked in builds with debug assertions.
fn mix(left: u64, right: u64, ratio: f64) -> u64 {
    debug_assert!(ratio >= 0.0);
    debug_assert!(ratio <= 1.0);

    // Note on precision: f64 is only precise to 53 binary digits. That is
    // enough to represent ~9 PGAS without error. Precision above that is
    // rounded according to the IEEE 754-2008 standard which Rust's f64
    // implements.
    // For example, a value of 100 Pgas is rounded to steps of 8 gas, and
    // adding the two weighted parts can double that error again, up to
    // 16 gas for 100 Pgas.
    let left_weight = 1.0 - ratio;
    let blended = left as f64 * left_weight + right as f64 * ratio;

    // The conversion is safe because left and right were both u64 and the
    // result is between the two. Should rounding push the float just past
    // `u64::MAX`, the `as` cast saturates instead of wrapping.
    blended.round() as u64
}
492
493impl ShardAcceptsTransactions {
494    pub fn is_yes(&self) -> bool {
495        matches!(self, ShardAcceptsTransactions::Yes)
496    }
497
498    pub fn is_no(&self) -> bool {
499        !self.is_yes()
500    }
501}
502
503#[cfg(test)]
504mod tests {
505    use itertools::Itertools;
506    use near_parameters::RuntimeConfigStore;
507    use near_primitives_core::version::PROTOCOL_VERSION;
508
509    use super::*;
510
    /// Returns the congestion control config of the runtime config at
    /// `PROTOCOL_VERSION`. The tests below use it as their base config.
    fn get_config() -> CongestionControlConfig {
        // Fix the initial configuration of congestion control for the tests.
        let runtime_config_store = RuntimeConfigStore::new(None);
        let runtime_config = runtime_config_store.get_config(PROTOCOL_VERSION);
        runtime_config.congestion_control_config
    }
517
    /// Basic interpolation cases for `mix`, including rounding to the nearest
    /// integer.
    #[test]
    fn test_mix() {
        assert_eq!(500, mix(0, 1000, 0.5));
        assert_eq!(0, mix(0, 0, 0.3));
        assert_eq!(1000, mix(1000, 1000, 0.1));
        // 50 * 0.67 + 80 * 0.33 = 59.9, which rounds up to 60
        assert_eq!(60, mix(50, 80, 0.33));
    }
525
    /// Pins the behavior of `mix` for inputs close to `u64::MAX`, where f64
    /// can no longer represent every integer.
    #[test]
    fn test_mix_edge_cases() {
        // at `u64::MAX` we should see no precision errors
        assert_eq!(u64::MAX, mix(u64::MAX, u64::MAX, 0.33));
        assert_eq!(u64::MAX, mix(u64::MAX, u64::MAX, 0.63));
        assert_eq!(u64::MAX, mix(u64::MAX, u64::MAX, 0.99));

        // precision errors must be consistent
        assert_eq!(u64::MAX, mix(u64::MAX - 1, u64::MAX, 0.25));
        assert_eq!(u64::MAX, mix(u64::MAX - 255, u64::MAX, 0.25));
        assert_eq!(u64::MAX, mix(u64::MAX - 1023, u64::MAX, 0.25));

        // from here on the result drops to the next lower step
        assert_eq!(u64::MAX - 2047, mix(u64::MAX - 1024, u64::MAX, 0.25));
        assert_eq!(u64::MAX - 2047, mix(u64::MAX - 1500, u64::MAX, 0.25));
        assert_eq!(u64::MAX - 2047, mix(u64::MAX - 2047, u64::MAX, 0.25));
        assert_eq!(u64::MAX - 2047, mix(u64::MAX - 2048, u64::MAX, 0.25));
        assert_eq!(u64::MAX - 2047, mix(u64::MAX - 2049, u64::MAX, 0.25));
        assert_eq!(u64::MAX - 2047, mix(u64::MAX - 3000, u64::MAX, 0.25));

        assert_eq!(u64::MAX - 4095, mix(u64::MAX - 4000, u64::MAX, 0.25));
    }
547
    /// Checks exact fractions, the zero case and the clamp to 1.0 of
    /// `clamped_f64_fraction`.
    #[test]
    fn test_clamped_f64_fraction() {
        assert_eq!(0.0, clamped_f64_fraction(0, 10));
        assert_eq!(0.5, clamped_f64_fraction(5, 10));
        assert_eq!(1.0, clamped_f64_fraction(10, 10));

        // a value of zero is no congestion, whatever the max is
        assert_eq!(0.0, clamped_f64_fraction(0, 1));
        assert_eq!(0.0, clamped_f64_fraction(0, u64::MAX));

        assert_eq!(0.5, clamped_f64_fraction(1, 2));
        assert_eq!(0.5, clamped_f64_fraction(100, 200));
        assert_eq!(0.5, clamped_f64_fraction(u64::MAX as u128 / 2, u64::MAX));

        // test clamp
        assert_eq!(1.0, clamped_f64_fraction(11, 10));
        assert_eq!(1.0, clamped_f64_fraction(u128::MAX, 10));
        assert_eq!(1.0, clamped_f64_fraction(u128::MAX, u64::MAX));
    }
566
    /// Default congestion info should be no congestion => maximally permissive.
    #[test]
    fn test_default_congestion() {
        let config = get_config();
        let info = CongestionInfo::default();
        let congestion_control = CongestionControl::new(config, info, 0);

        // every kind of congestion is zero on the plain congestion info
        assert_eq!(0.0, info.memory_congestion(&config));
        assert_eq!(0.0, info.incoming_congestion(&config));
        assert_eq!(0.0, info.outgoing_congestion(&config));
        assert_eq!(0.0, info.localized_congestion_level(&config));

        // and also through the congestion control wrapper
        assert_eq!(0.0, congestion_control.memory_congestion());
        assert_eq!(0.0, congestion_control.incoming_congestion());
        assert_eq!(0.0, congestion_control.outgoing_congestion());
        assert_eq!(0.0, congestion_control.congestion_level());

        // the limits are at their maximum, allowing a difference of 1 for
        // rounding
        assert!(
            config
                .max_outgoing_gas
                .abs_diff(congestion_control.outgoing_gas_limit(ShardId::new(0)))
                <= 1
        );

        assert!(config.max_tx_gas.abs_diff(congestion_control.process_tx_limit()) <= 1);
        assert!(congestion_control.shard_accepts_transactions().is_yes());
    }
594
    /// Walks the memory congestion from 100% down to 80% and 10% and checks
    /// the resulting limits and transaction acceptance at each step.
    #[test]
    fn test_memory_congestion() {
        let config = get_config();
        let mut info = CongestionInfo::default();

        // fill up to exactly the maximum; adding and removing 500 bytes on
        // top must not change the end result
        info.add_receipt_bytes(config.max_congestion_memory_consumption).unwrap();
        info.add_receipt_bytes(500).unwrap();
        info.remove_receipt_bytes(500).unwrap();

        {
            let control = CongestionControl::new(config, info, 0);
            assert_eq!(1.0, control.congestion_level());
            // fully congested, no more forwarding allowed
            assert_eq!(0, control.outgoing_gas_limit(ShardId::new(1)));
            assert!(control.shard_accepts_transactions().is_no());
            // processing to other shards is not restricted by memory congestion
            assert_eq!(config.max_tx_gas, control.process_tx_limit());
        }

        // Assert threshold is 80%. Change this number if the config changes
        assert_eq!(0.8, config.reject_tx_congestion_threshold);

        // reduce congestion to 80%
        info.remove_receipt_bytes(config.max_congestion_memory_consumption / 5).unwrap();
        {
            let control = CongestionControl::new(config, info, 0);
            assert_eq!(0.8, control.congestion_level());
            assert_eq!(
                mix(config.max_outgoing_gas, config.min_outgoing_gas, 0.8),
                control.outgoing_gas_limit(ShardId::new(1))
            );
            // at 80%, still no new transactions are allowed
            assert!(control.shard_accepts_transactions().is_no());
        }

        // reduce congestion to 10%
        info.remove_receipt_bytes(7 * config.max_congestion_memory_consumption / 10).unwrap();
        {
            let control = CongestionControl::new(config, info, 0);
            assert_eq!(0.1, control.congestion_level());
            assert_eq!(
                mix(config.max_outgoing_gas, config.min_outgoing_gas, 0.1),
                control.outgoing_gas_limit(ShardId::new(1))
            );
            // at 10%, new transactions are allowed (threshold is 80%)
            assert!(control.shard_accepts_transactions().is_yes());
        }
    }
643
    /// Walks the incoming congestion from 100% down to 80% and 10% and checks
    /// the resulting limits and transaction acceptance at each step.
    #[test]
    fn test_incoming_congestion() {
        let config = get_config();
        let mut info = CongestionInfo::default();

        // fill up to exactly the maximum; adding and removing 500 gas on top
        // must not change the end result
        info.add_delayed_receipt_gas(config.max_congestion_incoming_gas).unwrap();
        info.add_delayed_receipt_gas(500).unwrap();
        info.remove_delayed_receipt_gas(500).unwrap();

        {
            let control = CongestionControl::new(config, info, 0);
            assert_eq!(1.0, control.congestion_level());
            // fully congested, no more forwarding allowed
            assert_eq!(0, control.outgoing_gas_limit(ShardId::new(1)));
            assert!(control.shard_accepts_transactions().is_no());
            // processing to other shards is restricted by own incoming congestion
            assert_eq!(config.min_tx_gas, control.process_tx_limit());
        }

        // Assert threshold is 80%. Change this number if the config changes
        assert_eq!(0.8, config.reject_tx_congestion_threshold);

        // reduce congestion to 80%
        info.remove_delayed_receipt_gas(config.max_congestion_incoming_gas / 5).unwrap();
        {
            let control = CongestionControl::new(config, info, 0);
            assert_eq!(0.8, control.congestion_level());
            assert_eq!(
                mix(config.max_outgoing_gas, config.min_outgoing_gas, 0.8),
                control.outgoing_gas_limit(ShardId::new(1))
            );
            // at 80%, still no new transactions are allowed
            assert!(control.shard_accepts_transactions().is_no());
        }

        // reduce congestion to 10%
        info.remove_delayed_receipt_gas(7 * config.max_congestion_incoming_gas / 10).unwrap();
        {
            let control = CongestionControl::new(config, info, 0);
            assert_eq!(0.1, control.congestion_level());
            assert_eq!(
                mix(config.max_outgoing_gas, config.min_outgoing_gas, 0.1),
                control.outgoing_gas_limit(ShardId::new(1))
            );
            // at 10%, new transactions are allowed (threshold is 80%)
            assert!(control.shard_accepts_transactions().is_yes());
        }
    }
692
    /// Walks the outgoing congestion from 100% down to 80% and 10% and checks
    /// the resulting limits and transaction acceptance at each step.
    #[test]
    fn test_outgoing_congestion() {
        let config = get_config();
        let mut info = CongestionInfo::default();

        // fill up to exactly the maximum; adding and removing 500 gas on top
        // must not change the end result
        info.add_buffered_receipt_gas(config.max_congestion_outgoing_gas).unwrap();
        info.add_buffered_receipt_gas(500).unwrap();
        info.remove_buffered_receipt_gas(500).unwrap();

        let control = CongestionControl::new(config, info, 0);
        assert_eq!(1.0, control.congestion_level());
        // fully congested, no more forwarding allowed
        assert_eq!(0, control.outgoing_gas_limit(ShardId::new(1)));
        assert!(control.shard_accepts_transactions().is_no());
        // processing to other shards is not restricted by own outgoing congestion
        assert_eq!(config.max_tx_gas, control.process_tx_limit());

        // Assert threshold is 80%. Change this number if the config changes
        assert_eq!(0.8, config.reject_tx_congestion_threshold);

        // reduce congestion to 80%
        let gas_diff = config.max_congestion_outgoing_gas / 5;
        info.remove_buffered_receipt_gas(gas_diff.into()).unwrap();
        let control = CongestionControl::new(config, info, 0);
        assert_eq!(0.8, control.congestion_level());
        assert_eq!(
            mix(config.max_outgoing_gas, config.min_outgoing_gas, 0.8),
            control.outgoing_gas_limit(ShardId::new(1))
        );
        // at 80%, still no new transactions to us are allowed
        assert!(control.shard_accepts_transactions().is_no());

        // reduce congestion to 10%
        let gas_diff = 7 * config.max_congestion_outgoing_gas / 10;
        info.remove_buffered_receipt_gas(gas_diff.into()).unwrap();
        let control = CongestionControl::new(config, info, 0);
        assert_eq!(0.1, control.congestion_level());
        assert_eq!(
            mix(config.max_outgoing_gas, config.min_outgoing_gas, 0.1),
            control.outgoing_gas_limit(ShardId::new(1))
        );
        // at 10%, new transactions are allowed
        assert!(control.shard_accepts_transactions().is_yes());
    }
737
    /// Checks how the missed chunks count feeds into the congestion level,
    /// on its own and combined with outgoing congestion.
    #[test]
    fn test_missed_chunks_congestion() {
        // The default config is quite restricting, allow more missed chunks for
        // this test to check the middle cases.
        let mut config = get_config();
        config.max_congestion_missed_chunks = 10;

        let info = CongestionInfo::default();

        // Test missed chunks congestion without any other congestion
        let make = |count| CongestionControl::new(config, info, count);

        // up to one missed chunk is not counted as congestion
        assert_eq!(make(0).congestion_level(), 0.0);
        assert_eq!(make(1).congestion_level(), 0.0);
        assert_eq!(make(2).congestion_level(), 0.2);
        assert_eq!(make(3).congestion_level(), 0.3);
        assert_eq!(make(10).congestion_level(), 1.0);
        assert_eq!(make(20).congestion_level(), 1.0);

        // Test missed chunks congestion with outgoing congestion
        let mut info = CongestionInfo::default();
        info.add_buffered_receipt_gas(config.max_congestion_outgoing_gas / 2).unwrap();
        let make = |count| CongestionControl::new(config, info, count);

        // include missing chunks congestion; the level is the larger of the
        // two, so it stays at 0.5 until the missed chunks exceed it
        assert_eq!(make(0).congestion_level(), 0.5);
        assert_eq!(make(1).congestion_level(), 0.5);
        assert_eq!(make(2).congestion_level(), 0.5);
        assert_eq!(make(5).congestion_level(), 0.5);
        assert_eq!(make(6).congestion_level(), 0.6);
        assert_eq!(make(10).congestion_level(), 1.0);
        assert_eq!(make(20).congestion_level(), 1.0);

        // exclude missing chunks congestion
        assert_eq!(make(0).info.localized_congestion_level(&config), 0.5);
        assert_eq!(make(1).info.localized_congestion_level(&config), 0.5);
        assert_eq!(make(2).info.localized_congestion_level(&config), 0.5);
        assert_eq!(make(5).info.localized_congestion_level(&config), 0.5);
        assert_eq!(make(6).info.localized_congestion_level(&config), 0.5);
        assert_eq!(make(10).info.localized_congestion_level(&config), 0.5);
        assert_eq!(make(20).info.localized_congestion_level(&config), 0.5);
    }
780
    /// Checks the outgoing gas limit for every sender shard after finalizing
    /// the allowed shard, with no, some and full missed chunks congestion.
    #[test]
    fn test_missed_chunks_finalize() {
        // The default config is quite restricting, allow more missed chunks for
        // this test to check the middle cases.
        let mut config = get_config();
        config.max_congestion_missed_chunks = 10;

        // Setup half congested congestion info.
        let mut info = CongestionInfo::default();
        info.add_buffered_receipt_gas(config.max_congestion_outgoing_gas / 2).unwrap();

        let shard = ShardId::new(2);
        let all_shards = [0, 1, 2, 3, 4].into_iter().map(ShardId::new).collect_vec();

        // Test without missed chunks congestion.

        let missed_chunks_count = 0;
        let mut control = CongestionControl::new(config, info, missed_chunks_count);
        control.info.finalize_allowed_shard(shard, &all_shards, 3);

        // NOTE(review): this expectation is computed by hand and truncated by
        // the cast below, while `outgoing_gas_limit` rounds; the two only
        // agree when the interpolated value is a whole number.
        let expected_outgoing_limit =
            0.5 * config.min_outgoing_gas as f64 + 0.5 * config.max_outgoing_gas as f64;
        for &shard in &all_shards {
            assert_eq!(control.outgoing_gas_limit(shard), expected_outgoing_limit as u64);
        }

        // Test with some missed chunks congestion.

        let missed_chunks_count = 8;
        let mut control = CongestionControl::new(config, info, missed_chunks_count);
        control.info.finalize_allowed_shard(shard, &all_shards, 3);

        // 8 out of 10 missed chunks outweigh the 50% outgoing congestion
        let expected_outgoing_limit =
            mix(config.max_outgoing_gas, config.min_outgoing_gas, 0.8) as f64;
        for &shard in &all_shards {
            assert_eq!(control.outgoing_gas_limit(shard), expected_outgoing_limit as u64);
        }

        // Test with full missed chunks congestion.

        let missed_chunks_count = config.max_congestion_missed_chunks;
        let mut control = CongestionControl::new(config, info, missed_chunks_count);
        control.info.finalize_allowed_shard(shard, &all_shards, 3);

        // Full congestion - only the allowed shard should be able to send something.
        for shard in all_shards {
            if shard == ShardId::from(control.info.allowed_shard()) {
                assert_eq!(control.outgoing_gas_limit(shard), config.allowed_shard_outgoing_gas);
            } else {
                assert_eq!(control.outgoing_gas_limit(shard), 0);
            }
        }
    }
834}