1use std::{
2 fmt::Debug,
3 sync::{Arc, RwLock, RwLockWriteGuard},
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use anyhow::{Context as _, anyhow};
9use async_trait::async_trait;
10use ic_bn_lib_common::{
11 traits::shed::GetsSystemInfo,
12 types::shed::{ShedReason, ShedResponse, SystemOptions},
13};
14use systemstat::{Platform, System};
15use tower::{Layer, Service, ServiceExt};
16use tracing::{debug, error};
17
18use super::{BoxFuture, ewma::EWMA};
19use crate::Error;
20
21#[derive(Clone)]
22pub struct SystemInfo(Arc<System>);
23
24impl SystemInfo {
25 pub fn new() -> Self {
26 Self(Arc::new(System::new()))
27 }
28}
29
30impl Default for SystemInfo {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36#[async_trait]
37impl GetsSystemInfo for SystemInfo {
38 async fn cpu_usage(&self) -> Result<f64, Error> {
39 let cpu = self
40 .0
41 .cpu_load_aggregate()
42 .context("unable to measure CPU load")?;
43 tokio::time::sleep(Duration::from_millis(900)).await;
44 let cpu = cpu.done().context("unable to measure CPU load")?;
45
46 Ok(1.0 - cpu.idle as f64)
47 }
48
49 fn memory_usage(&self) -> Result<f64, Error> {
50 let mem = self.0.memory().context("unable to measure memory usage")?;
51 if mem.total.as_u64() == 0 {
52 return Err(anyhow!("total memory is zero").into());
53 }
54
55 Ok(1.0 - mem.free.as_u64() as f64 / mem.total.as_u64() as f64)
56 }
57
58 fn load_avg(&self) -> Result<(f64, f64, f64), Error> {
59 let la = self
60 .0
61 .load_average()
62 .context("unable to measure load average")?;
63
64 Ok((la.one as f64, la.five as f64, la.fifteen as f64))
65 }
66}
67
68#[derive(Debug)]
69struct StateInner {
70 cpu: EWMA,
71 memory: EWMA,
72 load_avg: (EWMA, EWMA, EWMA),
73 shed_reason: Option<ShedReason>,
74}
75
76impl StateInner {
77 fn new(alpha: f64) -> Self {
78 Self {
79 cpu: EWMA::new(alpha),
80 memory: EWMA::new(alpha),
81 load_avg: (EWMA::new(alpha), EWMA::new(alpha), EWMA::new(alpha)),
82 shed_reason: None,
83 }
84 }
85}
86
87#[derive(Debug)]
89pub struct State<S: GetsSystemInfo> {
90 opts: SystemOptions,
91 sys_info: S,
92 inner: RwLock<StateInner>,
93}
94
95impl<S: GetsSystemInfo> State<S> {
96 pub fn new(alpha: f64, opts: SystemOptions, sys_info: S) -> Self {
97 Self {
98 opts,
99 sys_info,
100 inner: RwLock::new(StateInner::new(alpha)),
101 }
102 }
103
104 async fn measure(&self) -> Result<(), Error> {
106 let cpu = self.sys_info.cpu_usage().await?;
107 let mem = self.sys_info.memory_usage()?;
108 let (l1, l5, l15) = self.sys_info.load_avg()?;
109
110 let mut inner = self.inner.write().unwrap();
111 inner.cpu.add(cpu);
112 inner.memory.add(mem);
113 inner.load_avg.0.add(l1);
114 inner.load_avg.1.add(l5);
115 inner.load_avg.2.add(l15);
116
117 inner.shed_reason = self.evaluate(&inner);
119 debug!(
120 "System load: CPU {cpu}, MEM {mem}, LAVG1: {l1}, LAVG5: {l5}, LAVG15: {l15}, Overload: {:?}",
121 inner.shed_reason
122 );
123
124 drop(inner); Ok(())
126 }
127
128 fn evaluate(&self, state: &RwLockWriteGuard<'_, StateInner>) -> Option<ShedReason> {
129 if self
130 .opts
131 .cpu
132 .map(|x| state.cpu.get().unwrap_or(0.0) > x)
133 .unwrap_or(false)
134 {
135 return Some(ShedReason::CPU);
136 }
137
138 if self
139 .opts
140 .memory
141 .map(|x| state.memory.get().unwrap_or(0.0) > x)
142 .unwrap_or(false)
143 {
144 return Some(ShedReason::Memory);
145 }
146
147 if self
148 .opts
149 .loadavg_1
150 .map(|x| state.load_avg.0.get().unwrap_or(0.0) > x)
151 .unwrap_or(false)
152 {
153 return Some(ShedReason::LoadAvg);
154 }
155
156 if self
157 .opts
158 .loadavg_5
159 .map(|x| state.load_avg.1.get().unwrap_or(0.0) > x)
160 .unwrap_or(false)
161 {
162 return Some(ShedReason::LoadAvg);
163 }
164
165 if self
166 .opts
167 .loadavg_15
168 .map(|x| state.load_avg.2.get().unwrap_or(0.0) > x)
169 .unwrap_or(false)
170 {
171 return Some(ShedReason::LoadAvg);
172 }
173
174 None
175 }
176
177 fn is_overloaded(&self) -> Option<ShedReason> {
178 self.inner.read().unwrap().shed_reason
179 }
180
181 async fn run(&self) {
183 let mut interval = tokio::time::interval(Duration::from_secs(1));
185 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
186
187 loop {
188 interval.tick().await;
189
190 if let Err(e) = self.measure().await {
191 error!("SystemLoadShedder: error: {e:#}");
192 }
193 }
194 }
195}
196
197#[derive(Debug, Clone)]
199pub struct SystemLoadShedder<S: GetsSystemInfo, I> {
200 state: Arc<State<S>>,
201 inner: I,
202}
203
204impl<S: GetsSystemInfo, I> SystemLoadShedder<S, I> {
205 pub const fn new(inner: I, state: Arc<State<S>>) -> Self {
206 Self { state, inner }
207 }
208}
209
210impl<S: GetsSystemInfo, R, I> Service<R> for SystemLoadShedder<S, I>
212where
213 R: Send + 'static,
214 I: Service<R> + Clone + Send + Sync + 'static,
215 I::Future: Send,
216{
217 type Response = ShedResponse<I::Response>;
218 type Error = I::Error;
219 type Future = BoxFuture<Result<Self::Response, Self::Error>>;
220
221 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
222 Poll::Ready(Ok(()))
223 }
224
225 fn call(&mut self, req: R) -> Self::Future {
226 let shed_reason = self.state.is_overloaded();
228 if let Some(v) = shed_reason {
229 return Box::pin(async move { Ok(ShedResponse::Overload(v)) });
230 }
231
232 let inner = self.inner.clone();
233 Box::pin(async move {
234 let response = inner.oneshot(req).await;
235 Ok(ShedResponse::Inner(response?))
236 })
237 }
238}
239
240#[derive(Debug, Clone)]
242pub struct SystemLoadShedderLayer<S: GetsSystemInfo>(Arc<State<S>>);
243
244impl<S: GetsSystemInfo> SystemLoadShedderLayer<S> {
245 pub fn new(ewma_alpha: f64, opts: SystemOptions, sys_info: S) -> Self {
246 let state = Arc::new(State::new(ewma_alpha, opts, sys_info));
248
249 let state_bg = state.clone();
251 tokio::spawn(async move { state_bg.run().await });
252
253 Self(state)
254 }
255}
256
257impl<S: GetsSystemInfo, I: Clone + Send + Sync + 'static> Layer<I> for SystemLoadShedderLayer<S> {
258 type Service = SystemLoadShedder<S, I>;
259
260 fn layer(&self, inner: I) -> Self::Service {
261 SystemLoadShedder::new(inner, self.0.clone())
262 }
263}
264
#[cfg(test)]
mod test {
    use std::sync::Mutex;

    use super::*;

    /// Readings reported by [`StubSystemInfo`]; the test mutates them between samples.
    #[derive(Clone, Debug)]
    struct StubSystemInfoVal {
        cpu: f64,
        memory: f64,
        l1: f64,
        l5: f64,
        l15: f64,
    }

    /// Selects one reading of [`StubSystemInfoVal`].
    #[derive(Clone, Copy, Debug)]
    enum Metric {
        Cpu,
        Memory,
        L1,
        L5,
        L15,
    }

    impl StubSystemInfoVal {
        /// Mutable access to the reading selected by `metric`.
        fn slot(&mut self, metric: Metric) -> &mut f64 {
            match metric {
                Metric::Cpu => &mut self.cpu,
                Metric::Memory => &mut self.memory,
                Metric::L1 => &mut self.l1,
                Metric::L5 => &mut self.l5,
                Metric::L15 => &mut self.l15,
            }
        }
    }

    /// Fake system-info source whose readings are shared with the test body.
    #[derive(Clone, Debug)]
    struct StubSystemInfo {
        v: Arc<Mutex<StubSystemInfoVal>>,
    }

    #[async_trait]
    impl GetsSystemInfo for StubSystemInfo {
        async fn cpu_usage(&self) -> Result<f64, Error> {
            Ok(self.v.lock().unwrap().cpu)
        }

        fn memory_usage(&self) -> Result<f64, Error> {
            Ok(self.v.lock().unwrap().memory)
        }

        fn load_avg(&self) -> Result<(f64, f64, f64), Error> {
            let StubSystemInfoVal { l1, l5, l15, .. } = *self.v.lock().unwrap();
            Ok((l1, l5, l15))
        }
    }

    /// Inner service that sleeps for the requested duration and then succeeds.
    #[derive(Debug, Clone)]
    struct StubService;

    impl Service<Duration> for StubService {
        type Response = ();
        type Error = Error;
        type Future = BoxFuture<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, delay: Duration) -> Self::Future {
            Box::pin(async move {
                tokio::time::sleep(delay).await;
                Ok(())
            })
        }
    }

    #[tokio::test]
    async fn test_system_shedder() {
        let sys_info = StubSystemInfo {
            v: Arc::new(Mutex::new(StubSystemInfoVal {
                cpu: 0.0,
                memory: 0.0,
                l1: 0.0,
                l5: 0.0,
                l15: 0.0,
            })),
        };
        // Every metric shares the same threshold, so spiking any one of them trips it.
        let opts = SystemOptions {
            cpu: Some(0.5),
            memory: Some(0.5),
            loadavg_1: Some(0.5),
            loadavg_5: Some(0.5),
            loadavg_15: Some(0.5),
        };

        let state = Arc::new(State::new(0.8, opts, sys_info.clone()));
        let mut shedder = SystemLoadShedder::new(StubService, Arc::clone(&state));

        // Baseline: all readings are zero, so the request passes through.
        let _ = state.measure().await;
        let resp = shedder.call(Duration::ZERO).await.unwrap();
        assert!(matches!(resp, ShedResponse::Inner(_)));

        // Spike one metric at a time, expect the matching verdict, then reset it.
        let cases = [
            (Metric::Cpu, ShedReason::CPU),
            (Metric::Memory, ShedReason::Memory),
            (Metric::L1, ShedReason::LoadAvg),
            (Metric::L5, ShedReason::LoadAvg),
            (Metric::L15, ShedReason::LoadAvg),
        ];
        for (metric, expected) in cases {
            *sys_info.v.lock().unwrap().slot(metric) = 1.0;
            let _ = state.measure().await;
            let resp = shedder.call(Duration::ZERO).await.unwrap();
            assert_eq!(resp, ShedResponse::Overload(expected));
            *sys_info.v.lock().unwrap().slot(metric) = 0.0;
        }

        // With everything back at zero, traffic flows again.
        let _ = state.measure().await;
        let resp = shedder.call(Duration::ZERO).await.unwrap();
        assert!(matches!(resp, ShedResponse::Inner(_)));
    }
}