nzb_dispatch/
bandwidth.rs1use std::num::NonZeroU32;
2use std::sync::atomic::{AtomicU32, Ordering};
3
4use arc_swap::ArcSwapOption;
5use governor::DefaultDirectRateLimiter as RateLimiter;
6use governor::Quota;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10#[derive(Default, Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq)]
11pub struct BandwidthConfig {
12 pub download_bps: Option<NonZeroU32>,
14}
15
16struct Limit {
17 limiter: ArcSwapOption<RateLimiter>,
18 current_bps: AtomicU32,
19}
20
21impl Limit {
22 fn new_inner(bps: Option<NonZeroU32>) -> Option<Arc<RateLimiter>> {
23 let bps = bps?;
24 Some(Arc::new(RateLimiter::direct(Quota::per_second(bps))))
25 }
26
27 fn new(bps: Option<NonZeroU32>) -> Self {
28 Self {
29 limiter: ArcSwapOption::new(Self::new_inner(bps)),
30 current_bps: AtomicU32::new(bps.map(|v| v.get()).unwrap_or(0)),
31 }
32 }
33
34 async fn acquire(&self, size: NonZeroU32) -> anyhow::Result<()> {
35 let lim = self.limiter.load().clone();
36 if let Some(rl) = lim.as_ref() {
37 rl.until_n_ready(size).await?;
38 }
39 Ok(())
40 }
41
42 fn set(&self, limit: Option<NonZeroU32>) {
43 let new = Self::new_inner(limit);
44 self.limiter.swap(new);
45 self.current_bps
46 .store(limit.map(|v| v.get()).unwrap_or(0), Ordering::Relaxed);
47 }
48
49 fn get(&self) -> Option<NonZeroU32> {
50 NonZeroU32::new(self.current_bps.load(Ordering::Relaxed))
51 }
52}
53
54pub struct BandwidthLimiter {
55 download: Limit,
56}
57
58impl BandwidthLimiter {
59 pub fn new(config: BandwidthConfig) -> Self {
60 Self {
61 download: Limit::new(config.download_bps),
62 }
63 }
64
65 pub async fn acquire_download(&self, len: NonZeroU32) -> anyhow::Result<()> {
66 self.download.acquire(len).await
67 }
68
69 pub fn set_download_bps(&self, bps: Option<NonZeroU32>) {
70 self.download.set(bps);
71 }
72
73 pub fn get_download_bps(&self) -> Option<NonZeroU32> {
74 self.download.get()
75 }
76
77 pub fn get_config(&self) -> BandwidthConfig {
78 BandwidthConfig {
79 download_bps: self.get_download_bps(),
80 }
81 }
82}