rustbac_client/
throttle.rs1use rustbac_datalink::DataLinkAddress;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
6use tokio::time::Instant;
7
/// Throttling limits that apply to one device.
#[derive(Clone, Copy, Debug)]
struct DeviceThrottleConfig {
    /// Upper bound on requests that may hold a permit for the device at the same time.
    max_concurrent: usize,
    /// Minimum spacing between consecutive request starts; `Duration::ZERO` disables spacing.
    min_interval: Duration,
}
13
14#[derive(Debug)]
19pub struct DeviceThrottle {
20 semaphores: Mutex<HashMap<DataLinkAddress, Arc<Semaphore>>>,
21 last_request: Mutex<HashMap<DataLinkAddress, Instant>>,
22 overrides: Mutex<HashMap<DataLinkAddress, DeviceThrottleConfig>>,
23 default_max_concurrent: usize,
24 default_min_interval: Duration,
25}
26
27impl DeviceThrottle {
28 pub fn new(max_concurrent: usize, min_interval: Duration) -> Self {
30 Self {
31 semaphores: Mutex::new(HashMap::new()),
32 last_request: Mutex::new(HashMap::new()),
33 overrides: Mutex::new(HashMap::new()),
34 default_max_concurrent: max_concurrent.max(1),
35 default_min_interval: min_interval,
36 }
37 }
38
39 pub async fn set_device_limit(
41 &self,
42 address: DataLinkAddress,
43 max_concurrent: usize,
44 min_interval: Duration,
45 ) {
46 let max_concurrent = max_concurrent.max(1);
47 self.overrides.lock().await.insert(
48 address,
49 DeviceThrottleConfig {
50 max_concurrent,
51 min_interval,
52 },
53 );
54
55 self.semaphores
57 .lock()
58 .await
59 .insert(address, Arc::new(Semaphore::new(max_concurrent)));
60 }
61
62 pub async fn acquire(&self, address: DataLinkAddress) -> OwnedSemaphorePermit {
65 let config = {
66 let overrides = self.overrides.lock().await;
67 overrides
68 .get(&address)
69 .copied()
70 .unwrap_or(DeviceThrottleConfig {
71 max_concurrent: self.default_max_concurrent,
72 min_interval: self.default_min_interval,
73 })
74 };
75
76 let semaphore = {
77 let mut semaphores = self.semaphores.lock().await;
78 semaphores
79 .entry(address)
80 .or_insert_with(|| Arc::new(Semaphore::new(config.max_concurrent)))
81 .clone()
82 };
83
84 let permit = semaphore
85 .acquire_owned()
86 .await
87 .expect("device throttle semaphore closed unexpectedly");
88
89 if !config.min_interval.is_zero() {
90 let mut last_request = self.last_request.lock().await;
91 if let Some(last) = last_request.get(&address) {
92 let elapsed = last.elapsed();
93 if elapsed < config.min_interval {
94 tokio::time::sleep(config.min_interval - elapsed).await;
95 }
96 }
97 last_request.insert(address, Instant::now());
98 }
99
100 permit
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::DeviceThrottle;
    use rustbac_datalink::DataLinkAddress;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::time::Duration;
    use tokio::time::{timeout, Instant};

    /// Builds a loopback IP address on `port`; each test uses its own port as a device key.
    fn addr(port: u16) -> DataLinkAddress {
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        DataLinkAddress::Ip(SocketAddr::new(loopback, port))
    }

    #[tokio::test]
    async fn enforces_concurrency_limit() {
        let device = addr(47808);
        let throttle = DeviceThrottle::new(1, Duration::ZERO);

        let held = throttle.acquire(device).await;

        // With a limit of one, a second caller must stay parked while `held` is alive.
        let attempt = timeout(Duration::from_millis(40), throttle.acquire(device)).await;
        assert!(attempt.is_err());

        drop(held);
        let next = timeout(Duration::from_millis(200), throttle.acquire(device))
            .await
            .expect("second permit should be acquired");
        drop(next);
    }

    #[tokio::test]
    async fn enforces_minimum_interval() {
        let device = addr(47809);
        let throttle = DeviceThrottle::new(1, Duration::from_millis(80));

        drop(throttle.acquire(device).await);

        let clock = Instant::now();
        let permit = throttle.acquire(device).await;
        let waited = clock.elapsed();
        drop(permit);

        // Allow a little slack below the 80ms interval for timer granularity.
        assert!(
            waited >= Duration::from_millis(70),
            "elapsed {:?} was shorter than expected interval",
            waited
        );
    }

    #[tokio::test]
    async fn applies_per_device_overrides() {
        let throttle = DeviceThrottle::new(1, Duration::from_millis(120));
        let (target, other) = (addr(47810), addr(47811));

        throttle
            .set_device_limit(target, 2, Duration::from_millis(10))
            .await;

        // The override raises the target's limit to two concurrent permits.
        let target_permits = [throttle.acquire(target).await, throttle.acquire(target).await];
        let third = timeout(Duration::from_millis(40), throttle.acquire(target)).await;
        assert!(
            third.is_err(),
            "third permit should block at override limit"
        );
        drop(target_permits);

        // A device without an override keeps the default limit of one.
        let other_permit = throttle.acquire(other).await;
        let other_attempt = timeout(Duration::from_millis(40), throttle.acquire(other)).await;
        assert!(other_attempt.is_err(), "default limit should still be one");
        drop(other_permit);
    }
}