Skip to main content

tank_tests/
requests.rs

1use std::{
2    borrow::Cow,
3    sync::{
4        LazyLock,
5        atomic::{AtomicUsize, Ordering},
6    },
7    time::{SystemTime, UNIX_EPOCH},
8};
9use tank::{
10    AsValue, Entity, Error, Executor, QueryBuilder, Result, Value, current_timestamp_ms, expr,
11    join, stream::StreamExt,
12};
13use tokio::sync::Mutex;
14
// Process-wide async mutex referenced at the start of `requests`.
// NOTE(review): presumably meant to serialize runs that touch the same tables — confirm.
static MUTEX: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
16
/// HTTP-like method attached to a request or to a request limit.
///
/// The enum is fieldless, so it also derives `Copy`: values can be passed
/// around by value without being moved out of their owner.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Method {
    GET,
    POST,
    PUT,
    DELETE,
}
24impl AsValue for Method {
25    fn as_empty_value() -> Value {
26        Value::Varchar(None)
27    }
28    fn as_value(self) -> Value {
29        Value::Varchar(Some(
30            match self {
31                Method::GET => "get",
32                Method::POST => "post",
33                Method::PUT => "put",
34                Method::DELETE => "delete",
35            }
36            .into(),
37        ))
38    }
39    fn try_from_value(value: Value) -> Result<Self>
40    where
41        Self: Sized,
42    {
43        if let Value::Varchar(Some(v)) = value.try_as(&String::as_empty_value())? {
44            match &*v {
45                "get" => return Ok(Method::GET),
46                "post" => return Ok(Method::POST),
47                "put" => return Ok(Method::PUT),
48                "delete" => return Ok(Method::DELETE),
49                _ => {
50                    return Err(Error::msg(format!(
51                        "Unexpected value `{v}` for Method enum"
52                    )));
53                }
54            }
55        }
56        Err(Error::msg("Unexpected value for Method enum"))
57    }
58}
59
60#[derive(Default, Entity, PartialEq, Eq)]
61#[tank(schema = "api")]
62struct RequestLimit {
63    #[tank(primary_key)]
64    pub id: i32,
65    pub target_pattern: Cow<'static, str>,
66    pub requests: i32,
67    // If set it applies only to the requests with that method, otherwise it affets all methods
68    pub method: Option<Method>,
69    // If set, it means maximum request in unit of time, otherwise means maximum concurrent requests
70    pub interval_ms: Option<i32>,
71}
72impl RequestLimit {
73    pub fn new(
74        target_pattern: &'static str,
75        requests: i32,
76        method: Option<Method>,
77        interval_ms: Option<i32>,
78    ) -> Self {
79        let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
80        Self {
81            id,
82            target_pattern: target_pattern.into(),
83            requests,
84            method,
85            interval_ms,
86        }
87    }
88}
89
90#[derive(Entity, PartialEq, Eq)]
91#[tank(schema = "api")]
92pub struct Request {
93    #[tank(primary_key)]
94    pub id: i64,
95    pub target: String,
96    pub method: Option<Method>,
97    pub beign_timestamp_ms: i64,
98    pub end_timestamp_ms: Option<i64>,
99}
100
// Source of ids for both `RequestLimit::new` and `Request::new`; each call takes the
// current value and increments it, starting from 0.
static GLOBAL_COUNTER: AtomicUsize = AtomicUsize::new(0);
102
103impl Request {
104    pub fn new(target: String, method: Option<Method>) -> Self {
105        let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
106        Self {
107            id,
108            target,
109            method,
110            beign_timestamp_ms: SystemTime::now()
111                .duration_since(UNIX_EPOCH)
112                .unwrap()
113                .as_millis() as _,
114            end_timestamp_ms: None,
115        }
116    }
117    pub fn end(&mut self) {
118        self.end_timestamp_ms = Some(
119            SystemTime::now()
120                .duration_since(UNIX_EPOCH)
121                .unwrap()
122                .as_millis() as _,
123        );
124    }
125}
126
127pub async fn requests<E: Executor>(executor: &mut E) {
128    let _lock = MUTEX.lock();
129
130    // Setup
131    RequestLimit::drop_table(executor, true, false)
132        .await
133        .expect("Could not drop the RequestLimit table");
134    Request::drop_table(executor, true, false)
135        .await
136        .expect("Could not drop the Request table");
137
138    RequestLimit::create_table(executor, false, true)
139        .await
140        .expect("Could not create the RequestLimit table");
141    Request::create_table(executor, false, true)
142        .await
143        .expect("Could not create the Request table");
144
145    // Request limits
146    RequestLimit::insert_many(
147        executor,
148        &[
149            // [1]: Max 3 concurrent requests
150            RequestLimit::new("v1/%", 3, None, None),
151            // [2]:  Max 5 data concurrent requests
152            RequestLimit::new("v1/server/data/%", 5, None, None),
153            // [3]:  Max 2 user concurrent put request
154            RequestLimit::new("v1/server/user/%", 2, Method::PUT.into(), None),
155            // [4]:  Max 1 user concurrent delete request
156            RequestLimit::new("v1/server/user/%", 1, Method::DELETE.into(), None),
157            // [5]:  Max 5 requests
158            RequestLimit::new("v2/%", 5, None, 60_000.into()),
159        ],
160    )
161    .await
162    .expect("Could not insert the limits");
163    let limits = RequestLimit::find_many(executor, true, None)
164        .map(|v| v.expect("Found error"))
165        .count()
166        .await;
167    assert_eq!(limits, 5);
168
169    #[cfg(not(feature = "disable-joins"))]
170    {
171        let mut violated_limits = executor
172        .prepare(
173            QueryBuilder::new()
174                .select([
175                    RequestLimit::target_pattern,
176                    RequestLimit::requests,
177                    RequestLimit::method,
178                    RequestLimit::interval_ms,
179                ])
180                .from(join!(RequestLimit CROSS JOIN Request))
181                .where_expr(expr!(
182                    ? == RequestLimit::target_pattern as LIKE
183                        && Request::target == RequestLimit::target_pattern as LIKE
184                        && (RequestLimit::method == NULL
185                            || RequestLimit::method == Request::method)
186                        && (RequestLimit::interval_ms == NULL && Request::end_timestamp_ms == NULL
187                            || RequestLimit::interval_ms != NULL
188                                && Request::end_timestamp_ms
189                                    >= current_timestamp_ms!() - RequestLimit::interval_ms)
190                ))
191                .group_by([
192                    RequestLimit::target_pattern,
193                    RequestLimit::requests,
194                    RequestLimit::method,
195                    RequestLimit::interval_ms,
196                ])
197                .having(expr!(COUNT(Request::id) >= RequestLimit::requests))
198                .build(&executor.driver()),
199        )
200        .await
201        .expect("Failed to prepare the limit query");
202
203        let mut r1 = Request::new("v1/server/user/new/1".into(), Method::PUT.into());
204        let mut r2 = Request::new("v1/server/user/new/2".into(), Method::PUT.into());
205        let r3 = Request::new("v1/server/user/new/3".into(), Method::PUT.into());
206        let r4 = Request::new("v1/server/articles/new/4".into(), Method::PUT.into());
207        let r5 = Request::new("v1/server/user/new/5".into(), Method::PUT.into());
208
209        violated_limits.bind(r1.target.clone()).unwrap();
210        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
211        r1.save(executor).await.expect("Failed to save r1");
212
213        violated_limits.bind(r2.target.clone()).unwrap();
214        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
215        r2.save(executor).await.expect("Failed to save r2");
216
217        violated_limits.bind(r3.target.clone()).unwrap();
218        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [3]
219
220        // Request 4 fits because it doesn't refer to /user
221        violated_limits.bind(r4.target.clone()).unwrap();
222        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
223        r4.save(executor).await.expect("Failed to save r4");
224
225        violated_limits.bind(r5.target.clone()).unwrap();
226        assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); // Violates [1], [3]
227
228        r1.end();
229        r1.save(executor).await.expect("Could not terminate r1");
230
231        violated_limits.bind(r3.target.clone()).unwrap();
232        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
233        r3.save(executor).await.expect("Failed to save r3");
234
235        // 3 Running requests
236
237        let mut data_reqs = vec![];
238        for i in 0..5 {
239            let req = Request::new(format!("v1/server/data/item/{}", i), None);
240            req.save(executor)
241                .await
242                .expect("Failed to save data request");
243            data_reqs.push(req);
244        }
245
246        // 8 Running requests
247
248        violated_limits
249            .bind("v1/server/data/item/999".to_string())
250            .unwrap();
251        assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); // Violates [1], [2]
252
253        for i in 0..4 {
254            data_reqs[i].end();
255            data_reqs[i]
256                .save(executor)
257                .await
258                .expect(&format!("Failed to save data_reqs[{i}]"));
259        }
260
261        // 4 Running requests
262
263        violated_limits
264            .bind("v1/server/data/item/999".to_string())
265            .unwrap();
266        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [1] still
267
268        r2.end();
269        r2.save(executor).await.expect("Could not terminate r2");
270        data_reqs[4].end();
271        data_reqs[4]
272            .save(executor)
273            .await
274            .expect("Could not terminate data_reqs[4]");
275
276        violated_limits
277            .bind("v1/server/data/item/999".to_string())
278            .unwrap();
279        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
280    }
281}