architect_api/utils/rate_limit.rs

1use crate::{utils::duration::parse_duration, NonZeroDurationAsStr};
2use anyhow::{anyhow, Result};
3use governor::Quota;
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use serde_with::{serde_as, serde_conv};
7use std::{num::NonZeroU32, str::FromStr, time::Duration};
8
9#[serde_as]
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
11pub struct RateLimit {
12    pub max: NonZeroU32,
13    #[serde_as(as = "NonZeroDurationAsStr")]
14    #[schemars(with = "NonZeroDurationAsStr")]
15    pub per: Duration,
16}
17
18impl RateLimit {
19    pub fn as_quota(&self) -> governor::Quota {
20        governor::Quota::with_period(self.per)
21            .unwrap() // NonZeroDurationAsStr ensures this is non-zero
22            .allow_burst(self.max)
23    }
24}
25
26impl From<&Quota> for RateLimit {
27    fn from(quota: &Quota) -> Self {
28        RateLimit { max: quota.burst_size(), per: quota.replenish_interval() }
29    }
30}
31
32impl FromStr for RateLimit {
33    type Err = anyhow::Error;
34
35    fn from_str(s: &str) -> Result<Self, Self::Err> {
36        let (max, per) =
37            s.split_once('/').ok_or_else(|| anyhow!("invalid rate limit"))?;
38        Ok(RateLimit {
39            max: max.trim().parse()?,
40            per: parse_duration(per.trim())?.to_std()?,
41        })
42    }
43}
44
45serde_conv!(
46    pub QuotaAsRateLimit,
47    Quota,
48    RateLimit::from,
49    try_into_quota
50);
51
52fn try_into_quota(rate_limit: RateLimit) -> Result<Quota, std::convert::Infallible> {
53    Ok(rate_limit.as_quota())
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON input for a rate limit of 100 with a 300-second period.
    const RATE_LIMIT_JSON: &str = r#"
        {
          "max": 100,
          "per": "300.000000000s"
        }
        "#;

    #[test]
    fn test_serde_rate_limit() {
        let expected = RateLimit {
            max: NonZeroU32::new(100).unwrap(),
            per: Duration::from_secs(300),
        };

        // Serialized form matches the inline snapshot.
        insta::assert_json_snapshot!(expected, @r#"
        {
          "max": 100,
          "per": "300.000000000s"
        }
        "#);

        // Deserializing from a JSON string yields the same value.
        let from_text: RateLimit = serde_json::from_str(RATE_LIMIT_JSON).unwrap();
        assert_eq!(expected, from_text);

        // Round-trip through an in-memory `serde_json::Value`.
        let value = serde_json::to_value(expected).unwrap();
        let from_value: RateLimit = serde_json::from_value(value).unwrap();
        assert_eq!(expected, from_value);
    }
}