Skip to main content

nzb_dispatch/
bandwidth.rs

1use std::num::NonZeroU32;
2use std::sync::atomic::{AtomicU32, Ordering};
3
4use arc_swap::ArcSwapOption;
5use governor::DefaultDirectRateLimiter as RateLimiter;
6use governor::Quota;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10#[derive(Default, Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
11pub struct BandwidthConfig {
12    /// Download speed limit in bytes per second (None = unlimited)
13    pub download_bps: Option<NonZeroU32>,
14}
15
16struct Limit {
17    limiter: ArcSwapOption<RateLimiter>,
18    current_bps: AtomicU32,
19}
20
21impl Limit {
22    fn new_inner(bps: Option<NonZeroU32>) -> Option<Arc<RateLimiter>> {
23        let bps = bps?;
24        Some(Arc::new(RateLimiter::direct(Quota::per_second(bps))))
25    }
26
27    fn new(bps: Option<NonZeroU32>) -> Self {
28        Self {
29            limiter: ArcSwapOption::new(Self::new_inner(bps)),
30            current_bps: AtomicU32::new(bps.map(|v| v.get()).unwrap_or(0)),
31        }
32    }
33
34    async fn acquire(&self, size: NonZeroU32) -> anyhow::Result<()> {
35        let lim = self.limiter.load().clone();
36        if let Some(rl) = lim.as_ref() {
37            rl.until_n_ready(size).await?;
38        }
39        Ok(())
40    }
41
42    fn set(&self, limit: Option<NonZeroU32>) {
43        let new = Self::new_inner(limit);
44        self.limiter.swap(new);
45        self.current_bps
46            .store(limit.map(|v| v.get()).unwrap_or(0), Ordering::Relaxed);
47    }
48
49    fn get(&self) -> Option<NonZeroU32> {
50        NonZeroU32::new(self.current_bps.load(Ordering::Relaxed))
51    }
52}
53
54pub struct BandwidthLimiter {
55    download: Limit,
56}
57
58impl BandwidthLimiter {
59    pub fn new(config: BandwidthConfig) -> Self {
60        Self {
61            download: Limit::new(config.download_bps),
62        }
63    }
64
65    pub async fn acquire_download(&self, len: NonZeroU32) -> anyhow::Result<()> {
66        self.download.acquire(len).await
67    }
68
69    pub fn set_download_bps(&self, bps: Option<NonZeroU32>) {
70        self.download.set(bps);
71    }
72
73    pub fn get_download_bps(&self) -> Option<NonZeroU32> {
74        self.download.get()
75    }
76
77    pub fn get_config(&self) -> BandwidthConfig {
78        BandwidthConfig {
79            download_bps: self.get_download_bps(),
80        }
81    }
82}