Skip to main content

tank_tests/
requests.rs

1use std::{
2    borrow::Cow,
3    sync::{
4        LazyLock,
5        atomic::{AtomicUsize, Ordering},
6    },
7    time::{SystemTime, UNIX_EPOCH},
8};
9use tank::{
10    AsValue, Entity, Error, Executor, QueryBuilder, Result, Value, current_timestamp_ms, expr,
11    join,
12    stream::{StreamExt, TryStreamExt},
13};
14use tokio::sync::Mutex;
15
// Async mutex locked at the top of `requests`; presumably meant to keep concurrent
// invocations from interleaving while the shared tables are dropped and recreated.
static MUTEX: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
17
// Request method attached to a `Request` and optionally to a `RequestLimit`.
// Converted to and from a lowercase varchar by the `AsValue` implementation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Method {
    GET,
    POST,
    PUT,
    DELETE,
}
25impl AsValue for Method {
26    fn as_empty_value() -> Value {
27        Value::Varchar(None)
28    }
29    fn as_value(self) -> Value {
30        Value::Varchar(Some(
31            match self {
32                Method::GET => "get",
33                Method::POST => "post",
34                Method::PUT => "put",
35                Method::DELETE => "delete",
36            }
37            .into(),
38        ))
39    }
40    fn try_from_value(value: Value) -> Result<Self>
41    where
42        Self: Sized,
43    {
44        if let Value::Varchar(Some(v)) = value.try_as(&String::as_empty_value())? {
45            match &*v {
46                "get" => return Ok(Method::GET),
47                "post" => return Ok(Method::POST),
48                "put" => return Ok(Method::PUT),
49                "delete" => return Ok(Method::DELETE),
50                _ => {
51                    return Err(Error::msg(format!(
52                        "Unexpected value `{v}` for Method enum"
53                    )));
54                }
55            }
56        }
57        Err(Error::msg("Unexpected value for Method enum"))
58    }
59}
60
// A limit on the number of requests whose target matches `target_pattern`.
#[derive(Default, Entity, PartialEq, Eq)]
#[tank(schema = "api")]
struct RequestLimit {
    #[tank(primary_key)]
    pub id: i32,
    // Pattern matched against request targets with LIKE (e.g. "v1/%").
    pub target_pattern: Cow<'static, str>,
    // Number of matching requests at which the limit counts as violated.
    pub requests: i32,
    // If set, the limit applies only to requests with that method; otherwise it affects all methods.
    pub method: Option<Method>,
    // If set, `requests` is the maximum number of requests within this many milliseconds;
    // otherwise it is the maximum number of concurrent (not yet ended) requests.
    pub interval_ms: Option<i32>,
}
73impl RequestLimit {
74    pub fn new(
75        target_pattern: &'static str,
76        requests: i32,
77        method: Option<Method>,
78        interval_ms: Option<i32>,
79    ) -> Self {
80        let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
81        Self {
82            id,
83            target_pattern: target_pattern.into(),
84            requests,
85            method,
86            interval_ms,
87        }
88    }
89}
90
// A single tracked request; it counts as still running while `end_timestamp_ms` is unset.
#[derive(Entity, PartialEq, Eq)]
#[tank(schema = "api")]
pub struct Request {
    #[tank(primary_key)]
    pub id: i64,
    // Target path, matched against `RequestLimit::target_pattern`.
    pub target: String,
    pub method: Option<Method>,
    // Start time in milliseconds since the Unix epoch.
    // NOTE(review): `beign` looks like a misspelling of `begin`; renaming it would
    // change the public field name, so it is left as is.
    pub beign_timestamp_ms: i64,
    // End time in milliseconds since the Unix epoch, set by `end()`; `None` until then.
    pub end_timestamp_ms: Option<i64>,
}
101
// Source of ids for both `RequestLimit::new` and `Request::new`; starts at 0 and is
// incremented atomically on every construction.
static GLOBAL_COUNTER: AtomicUsize = AtomicUsize::new(0);
103
104impl Request {
105    pub fn new(target: String, method: Option<Method>) -> Self {
106        let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
107        Self {
108            id,
109            target,
110            method,
111            beign_timestamp_ms: SystemTime::now()
112                .duration_since(UNIX_EPOCH)
113                .unwrap()
114                .as_millis() as _,
115            end_timestamp_ms: None,
116        }
117    }
118    pub fn end(&mut self) {
119        self.end_timestamp_ms = Some(
120            SystemTime::now()
121                .duration_since(UNIX_EPOCH)
122                .unwrap()
123                .as_millis() as _,
124        );
125    }
126}
127
128pub async fn requests<E: Executor>(executor: &mut E) {
129    let _lock = MUTEX.lock();
130
131    // Setup
132    RequestLimit::drop_table(executor, true, false)
133        .await
134        .expect("Could not drop the RequestLimit table");
135    Request::drop_table(executor, true, false)
136        .await
137        .expect("Could not drop the Request table");
138
139    RequestLimit::create_table(executor, false, true)
140        .await
141        .expect("Could not create the RequestLimit table");
142    Request::create_table(executor, false, true)
143        .await
144        .expect("Could not create the Request table");
145
146    // Request limits
147    RequestLimit::insert_many(
148        executor,
149        &[
150            // [1]: Max 3 concurrent requests
151            RequestLimit::new("v1/%", 3, None, None),
152            // [2]:  Max 5 data concurrent requests
153            RequestLimit::new("v1/server/data/%", 5, None, None),
154            // [3]:  Max 2 user concurrent put request
155            RequestLimit::new("v1/server/user/%", 2, Method::PUT.into(), None),
156            // [4]:  Max 1 user concurrent delete request
157            RequestLimit::new("v1/server/user/%", 1, Method::DELETE.into(), None),
158            // [5]:  Max 5 requests
159            RequestLimit::new("v2/%", 5, None, 60_000.into()),
160        ],
161    )
162    .await
163    .expect("Could not insert the limits");
164    let limits = RequestLimit::find_many(executor, true, None)
165        .map_err(|e| panic!("{e:#}"))
166        .count()
167        .await;
168    assert_eq!(limits, 5);
169
170    #[cfg(not(feature = "disable-joins"))]
171    {
172        let mut violated_limits = executor
173        .prepare(
174            QueryBuilder::new()
175                .select([
176                    RequestLimit::target_pattern,
177                    RequestLimit::requests,
178                    RequestLimit::method,
179                    RequestLimit::interval_ms,
180                ])
181                .from(join!(RequestLimit CROSS JOIN Request))
182                .where_expr(expr!(
183                    ? == RequestLimit::target_pattern as LIKE
184                        && Request::target == RequestLimit::target_pattern as LIKE
185                        && (RequestLimit::method == NULL
186                            || RequestLimit::method == Request::method)
187                        && (RequestLimit::interval_ms == NULL && Request::end_timestamp_ms == NULL
188                            || RequestLimit::interval_ms != NULL
189                                && Request::end_timestamp_ms
190                                    >= current_timestamp_ms!() - RequestLimit::interval_ms)
191                ))
192                .group_by([
193                    RequestLimit::target_pattern,
194                    RequestLimit::requests,
195                    RequestLimit::method,
196                    RequestLimit::interval_ms,
197                ])
198                .having(expr!(COUNT(Request::id) >= RequestLimit::requests))
199                .build(&executor.driver()),
200        )
201        .await
202        .expect("Failed to prepare the limit query");
203
204        let mut r1 = Request::new("v1/server/user/new/1".into(), Method::PUT.into());
205        let mut r2 = Request::new("v1/server/user/new/2".into(), Method::PUT.into());
206        let mut r3 = Request::new("v1/server/user/new/3".into(), Method::PUT.into());
207        let mut r4 = Request::new("v1/server/articles/new/4".into(), Method::PUT.into());
208        let r5 = Request::new("v1/server/user/new/5".into(), Method::PUT.into());
209
210        violated_limits.bind(r1.target.clone()).unwrap();
211        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
212        r1.save(executor).await.expect("Failed to save r1");
213
214        violated_limits.bind(r2.target.clone()).unwrap();
215        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
216        r2.save(executor).await.expect("Failed to save r2");
217
218        violated_limits.bind(r3.target.clone()).unwrap();
219        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [3]
220
221        // Request 4 fits because it doesn't refer to /user
222        violated_limits.bind(r4.target.clone()).unwrap();
223        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
224        r4.save(executor).await.expect("Failed to save r4");
225
226        violated_limits.bind(r5.target.clone()).unwrap();
227        assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); // Violates [1], [3]
228
229        r1.end();
230        r1.save(executor).await.expect("Could not terminate r1");
231
232        violated_limits.bind(r3.target.clone()).unwrap();
233        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
234        r3.save(executor).await.expect("Failed to save r3");
235
236        // 3 Running requests
237
238        let mut data_reqs = vec![];
239        for i in 0..5 {
240            let req = Request::new(format!("v1/server/data/item/{}", i), None);
241            req.save(executor)
242                .await
243                .expect("Failed to save data request");
244            data_reqs.push(req);
245        }
246
247        // 8 Running requests
248
249        violated_limits
250            .bind("v1/server/data/item/999".to_string())
251            .unwrap();
252        assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); // Violates [1], [2]
253
254        for i in 0..4 {
255            data_reqs[i].end();
256            data_reqs[i]
257                .save(executor)
258                .await
259                .expect(&format!("Failed to save data_reqs[{i}]"));
260        }
261
262        // 4 Running requests
263
264        violated_limits
265            .bind("v1/server/data/item/999".to_string())
266            .unwrap();
267        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [1] still
268
269        r2.end();
270        r2.save(executor).await.expect("Could not terminate r2");
271        data_reqs[4].end();
272        data_reqs[4]
273            .save(executor)
274            .await
275            .expect("Could not terminate data_reqs[4]");
276
277        violated_limits
278            .bind("v1/server/data/item/999".to_string())
279            .unwrap();
280        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
281
282        r3.end();
283        r3.save(executor).await.expect("Could not terminate r3");
284        r4.end();
285        r4.save(executor).await.expect("Could not terminate r4");
286
287        // Check [4]
288
289        let mut d1 = Request::new("v1/server/user/del/1".into(), Method::DELETE.into());
290        let d2 = Request::new("v1/server/user/del/2".into(), Method::DELETE.into());
291
292        violated_limits.bind(d1.target.clone()).unwrap();
293        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
294        d1.save(executor).await.expect("Failed to save d1");
295
296        violated_limits.bind(d2.target.clone()).unwrap();
297        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [4] // Problem here
298
299        d1.end();
300        d1.save(executor).await.expect("Failed to end d1");
301
302        violated_limits.bind(d2.target.clone()).unwrap();
303        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
304
305        let mut v2_reqs = vec![];
306        for i in 0..5 {
307            let mut req = Request::new(format!("v2/resource/{}", i), Method::GET.into());
308            req.save(executor).await.expect("Failed to save v2 req");
309            req.end(); // Must end them to be counted by the interval logic
310            req.save(executor).await.expect("Failed to end v2 req");
311            v2_reqs.push(req);
312        }
313
314        violated_limits.bind("v2/resource/new".to_string()).unwrap();
315        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [5]
316    }
317}