api_tools/server/axum/layers/
time_limiter.rs

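//! Time limiter layer: a Tower/Axum middleware that answers requests with
//! `503 Service Unavailable` while the local time falls inside a configured time slot.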
2
3use crate::server::axum::{layers::body_from_parts, response::ApiError};
4use axum::body::Body;
5use axum::http::{Request, StatusCode};
6use axum::response::Response;
7use chrono::Local;
8use futures::future::BoxFuture;
9use std::fmt::Display;
10use std::task::{Context, Poll};
11use tower::{Layer, Service};
12
13/// TimeSlots represents a collection of time intervals
14/// where each interval is defined by a start and end time.
15#[derive(Debug, Clone, PartialEq)]
16pub struct TimeSlots(Vec<TimeSlot>);
17
18impl TimeSlots {
19    /// Get time slots vector
20    ///
21    /// # Example
22    /// ```
23    /// use api_tools::server::axum::layers::time_limiter::TimeSlots;
24    ///
25    /// let time_slots: TimeSlots = "08:00-12:00,13:00-17:00".into();
26    /// assert_eq!(time_slots.values().len(), 2);
27    /// assert_eq!(time_slots.values()[0].start, "08:00");
28    /// assert_eq!(time_slots.values()[0].end, "12:00");
29    /// assert_eq!(time_slots.values()[1].start, "13:00");
30    /// assert_eq!(time_slots.values()[1].end, "17:00");
31    /// ```
32    pub fn values(&self) -> &Vec<TimeSlot> {
33        &self.0
34    }
35
36    /// Check if a time is in the time slots list
37    ///
38    /// # Example
39    /// ```
40    /// use api_tools::server::axum::layers::time_limiter::TimeSlots;
41    ///
42    /// let time_slots: TimeSlots = "08:00-12:00,13:00-17:00".into();
43    /// let now = "09:00";
44    /// assert_eq!(time_slots.contains(now), true);
45    ///
46    /// let now = "08:00";
47    /// assert_eq!(time_slots.contains(now), true);
48    ///
49    /// let now = "17:00";
50    /// assert_eq!(time_slots.contains(now), true);
51    ///
52    /// let now = "12:30";
53    /// assert_eq!(time_slots.contains(now), false);
54    ///
55    /// let time_slots: TimeSlots = "".into();
56    /// let now = "09:00";
57    /// assert_eq!(time_slots.contains(now), false);
58    /// ```
59    pub fn contains(&self, time: &str) -> bool {
60        self.0.iter().any(|slot| *slot.start <= *time && *time <= *slot.end)
61    }
62}
63
64impl Display for TimeSlots {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        let mut slots = String::new();
67        for (i, slot) in self.0.iter().enumerate() {
68            slots.push_str(&format!("{} - {}", slot.start, slot.end));
69
70            if i < self.0.len() - 1 {
71                slots.push_str(", ");
72            }
73        }
74        write!(f, "{}", slots)
75    }
76}
77
78impl From<&str> for TimeSlots {
79    fn from(value: &str) -> Self {
80        Self(
81            value
82                .split(',')
83                .filter_map(|part| part.try_into().ok())
84                .collect::<Vec<_>>(),
85        )
86    }
87}
88
/// A single interval, written as `start-end` in configuration strings
/// (for example `08:00-12:00`).
#[derive(Clone, Debug, PartialEq)]
pub struct TimeSlot {
    /// Start of the interval, e.g. `08:00`.
    pub start: String,
    /// End of the interval, e.g. `12:00`.
    pub end: String,
}
94
95impl TryFrom<&str> for TimeSlot {
96    type Error = ApiError;
97
98    fn try_from(value: &str) -> Result<Self, Self::Error> {
99        let (start, end) = value.split_once('-').ok_or(ApiError::InternalServerError(
100            "Time slots configuration error".to_string(),
101        ))?;
102
103        if start.len() != 5 || end.len() != 5 {
104            return Err(ApiError::InternalServerError(
105                "Time slots configuration error".to_string(),
106            ));
107        }
108
109        Ok(Self {
110            start: start.to_string(),
111            end: end.to_string(),
112        })
113    }
114}
115
116#[derive(Clone)]
117pub struct TimeLimiterLayer {
118    pub time_slots: TimeSlots,
119}
120
121impl TimeLimiterLayer {
122    /// Create a new `TimeLimiterLayer`
123    pub fn new(time_slots: TimeSlots) -> Self {
124        Self { time_slots }
125    }
126}
127
128impl<S> Layer<S> for TimeLimiterLayer {
129    type Service = TimeLimiterMiddleware<S>;
130
131    fn layer(&self, inner: S) -> Self::Service {
132        TimeLimiterMiddleware {
133            inner,
134            time_slots: self.time_slots.clone(),
135        }
136    }
137}
138
139#[derive(Clone)]
140pub struct TimeLimiterMiddleware<S> {
141    inner: S,
142    time_slots: TimeSlots,
143}
144
145impl<S> Service<Request<Body>> for TimeLimiterMiddleware<S>
146where
147    S: Service<Request<Body>, Response = Response> + Send + 'static,
148    S::Future: Send + 'static,
149{
150    type Response = S::Response;
151    type Error = S::Error;
152    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
153    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
154
155    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156        self.inner.poll_ready(cx)
157    }
158
159    fn call(&mut self, request: Request<Body>) -> Self::Future {
160        let now = Local::now().format("%H:%M").to_string();
161        let is_authorized = !self.time_slots.contains(&now);
162        let time_slots = self.time_slots.clone();
163
164        let future = self.inner.call(request);
165        Box::pin(async move {
166            let mut response = Response::default();
167
168            response = match is_authorized {
169                true => future.await?,
170                false => {
171                    let (mut parts, _body) = response.into_parts();
172                    let msg = body_from_parts(
173                        &mut parts,
174                        StatusCode::SERVICE_UNAVAILABLE,
175                        format!("Service unavailable during these times: {}", time_slots).as_str(),
176                        None,
177                    );
178                    Response::from_parts(parts, Body::from(msg))
179                }
180            };
181
182            Ok(response)
183        })
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the `(start, end)` bounds of every slot for compact assertions.
    fn bounds(slots: &TimeSlots) -> Vec<(&str, &str)> {
        slots
            .values()
            .iter()
            .map(|slot| (slot.start.as_str(), slot.end.as_str()))
            .collect()
    }

    #[test]
    fn test_timeslots_from_str() {
        let slots = TimeSlots::from("08:00-12:00,13:00-17:00");
        assert_eq!(bounds(&slots), vec![("08:00", "12:00"), ("13:00", "17:00")]);
    }

    #[test]
    fn test_timeslot_try_from_valid() {
        let slot = TimeSlot::try_from("10:00-11:00").expect("valid slot should parse");
        assert_eq!((slot.start.as_str(), slot.end.as_str()), ("10:00", "11:00"));
    }

    #[test]
    fn test_timeslot_try_from_invalid_format() {
        for input in &["1000-1100", "10:00/11:00"] {
            assert!(TimeSlot::try_from(*input).is_err(), "{:?} should be rejected", input);
        }
    }

    #[test]
    fn test_timeslots_display() {
        let slots = TimeSlots::from("08:00-12:00,13:00-17:00");
        assert_eq!(slots.to_string(), "08:00 - 12:00, 13:00 - 17:00");
    }

    #[test]
    fn test_timeslots_empty_display() {
        assert_eq!(TimeSlots::from("").to_string(), "");
    }
}