api_tools/server/axum/layers/
time_limiter.rs1use crate::server::axum::{layers::body_from_parts, response::ApiError};
4use axum::body::Body;
5use axum::http::{Request, StatusCode};
6use axum::response::Response;
7use chrono::Local;
8use futures::future::BoxFuture;
9use std::fmt::Display;
10use std::task::{Context, Poll};
11use tower::{Layer, Service};
12
13#[derive(Debug, Clone, PartialEq)]
16pub struct TimeSlots(Vec<TimeSlot>);
17
18impl TimeSlots {
19 pub fn values(&self) -> &Vec<TimeSlot> {
33 &self.0
34 }
35
36 pub fn contains(&self, time: &str) -> bool {
60 self.0.iter().any(|slot| *slot.start <= *time && *time <= *slot.end)
61 }
62}
63
64impl Display for TimeSlots {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 let mut slots = String::new();
67 for (i, slot) in self.0.iter().enumerate() {
68 slots.push_str(&format!("{} - {}", slot.start, slot.end));
69
70 if i < self.0.len() - 1 {
71 slots.push_str(", ");
72 }
73 }
74 write!(f, "{}", slots)
75 }
76}
77
78impl From<&str> for TimeSlots {
79 fn from(value: &str) -> Self {
80 Self(
81 value
82 .split(',')
83 .filter_map(|part| part.try_into().ok())
84 .collect::<Vec<_>>(),
85 )
86 }
87}
88
89#[derive(Debug, Clone, PartialEq)]
90pub struct TimeSlot {
91 pub start: String,
92 pub end: String,
93}
94
95impl TryFrom<&str> for TimeSlot {
96 type Error = ApiError;
97
98 fn try_from(value: &str) -> Result<Self, Self::Error> {
99 let (start, end) = value.split_once('-').ok_or(ApiError::InternalServerError(
100 "Time slots configuration error".to_string(),
101 ))?;
102
103 if start.len() != 5 || end.len() != 5 {
104 return Err(ApiError::InternalServerError(
105 "Time slots configuration error".to_string(),
106 ));
107 }
108
109 Ok(Self {
110 start: start.to_string(),
111 end: end.to_string(),
112 })
113 }
114}
115
116#[derive(Clone)]
117pub struct TimeLimiterLayer {
118 pub time_slots: TimeSlots,
119}
120
121impl TimeLimiterLayer {
122 pub fn new(time_slots: TimeSlots) -> Self {
124 Self { time_slots }
125 }
126}
127
128impl<S> Layer<S> for TimeLimiterLayer {
129 type Service = TimeLimiterMiddleware<S>;
130
131 fn layer(&self, inner: S) -> Self::Service {
132 TimeLimiterMiddleware {
133 inner,
134 time_slots: self.time_slots.clone(),
135 }
136 }
137}
138
139#[derive(Clone)]
140pub struct TimeLimiterMiddleware<S> {
141 inner: S,
142 time_slots: TimeSlots,
143}
144
145impl<S> Service<Request<Body>> for TimeLimiterMiddleware<S>
146where
147 S: Service<Request<Body>, Response = Response> + Send + 'static,
148 S::Future: Send + 'static,
149{
150 type Response = S::Response;
151 type Error = S::Error;
152 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
154
155 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156 self.inner.poll_ready(cx)
157 }
158
159 fn call(&mut self, request: Request<Body>) -> Self::Future {
160 let now = Local::now().format("%H:%M").to_string();
161 let is_authorized = !self.time_slots.contains(&now);
162 let time_slots = self.time_slots.clone();
163
164 let future = self.inner.call(request);
165 Box::pin(async move {
166 let mut response = Response::default();
167
168 response = match is_authorized {
169 true => future.await?,
170 false => {
171 let (mut parts, _body) = response.into_parts();
172 let msg = body_from_parts(
173 &mut parts,
174 StatusCode::SERVICE_UNAVAILABLE,
175 format!("Service unavailable during these times: {}", time_slots).as_str(),
176 None,
177 );
178 Response::from_parts(parts, Body::from(msg))
179 }
180 };
181
182 Ok(response)
183 })
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TimeSlot` from string bounds.
    fn slot(start: &str, end: &str) -> TimeSlot {
        TimeSlot {
            start: start.to_string(),
            end: end.to_string(),
        }
    }

    #[test]
    fn test_timeslots_from_str() {
        let time_slots = TimeSlots::from("08:00-12:00,13:00-17:00");
        assert_eq!(
            time_slots.values(),
            &vec![slot("08:00", "12:00"), slot("13:00", "17:00")]
        );
    }

    #[test]
    fn test_timeslot_try_from_valid() {
        let parsed = TimeSlot::try_from("10:00-11:00").unwrap();
        assert_eq!(parsed, slot("10:00", "11:00"));
    }

    #[test]
    fn test_timeslot_try_from_invalid_format() {
        for input in &["1000-1100", "10:00/11:00"] {
            assert!(
                TimeSlot::try_from(*input).is_err(),
                "expected {:?} to be rejected",
                input
            );
        }
    }

    #[test]
    fn test_timeslots_display() {
        let time_slots = TimeSlots::from("08:00-12:00,13:00-17:00");
        assert_eq!(time_slots.to_string(), "08:00 - 12:00, 13:00 - 17:00");
    }

    #[test]
    fn test_timeslots_empty_display() {
        assert_eq!(TimeSlots::from("").to_string(), "");
    }
}