Skip to main content

tank_tests/
requests.rs

1#![allow(unused_imports)]
2use std::{
3    borrow::Cow,
4    sync::{
5        LazyLock,
6        atomic::{AtomicUsize, Ordering},
7    },
8    time::{SystemTime, UNIX_EPOCH},
9};
10use tank::{
11    AsValue, Entity, Error, Executor, QueryBuilder, Result, Value, current_timestamp_ms, expr,
12    join,
13    stream::{StreamExt, TryStreamExt},
14};
15use tokio::sync::Mutex;
16
17static MUTEX: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
18
/// HTTP method of a [`Request`], or the method a [`RequestLimit`] is restricted to.
///
/// Persisted as a lowercase `VARCHAR` by its `AsValue` implementation.
// Fieldless enum: `Copy` and `Hash` are free and let callers pass it by value and use it
// as a map/set key without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Method {
    GET,
    POST,
    PUT,
    DELETE,
}
26impl AsValue for Method {
27    fn as_empty_value() -> Value {
28        Value::Varchar(None)
29    }
30    fn as_value(self) -> Value {
31        Value::Varchar(Some(
32            match self {
33                Method::GET => "get",
34                Method::POST => "post",
35                Method::PUT => "put",
36                Method::DELETE => "delete",
37            }
38            .into(),
39        ))
40    }
41    fn try_from_value(value: Value) -> Result<Self>
42    where
43        Self: Sized,
44    {
45        if let Value::Varchar(Some(v)) = value.try_as(&String::as_empty_value())? {
46            match &*v {
47                "get" => return Ok(Method::GET),
48                "post" => return Ok(Method::POST),
49                "put" => return Ok(Method::PUT),
50                "delete" => return Ok(Method::DELETE),
51                _ => {
52                    return Err(Error::msg(format!(
53                        "Unexpected value `{v}` for Method enum"
54                    )));
55                }
56            }
57        }
58        Err(Error::msg("Unexpected value for Method enum"))
59    }
60}
61
/// A limit on the requests whose target matches `target_pattern`.
///
/// Stored in the `api` schema.
#[derive(Default, Entity, PartialEq, Eq)]
#[tank(schema = "api")]
struct RequestLimit {
    /// Primary key; `RequestLimit::new` draws it from `GLOBAL_COUNTER`.
    #[tank(primary_key)]
    pub id: i32,
    /// SQL `LIKE` pattern matched against `Request::target` (e.g. `"v1/%"`).
    pub target_pattern: Cow<'static, str>,
    /// Maximum number of matching requests (concurrent, or per `interval_ms` window).
    pub requests: i32,
    /// If set, the limit applies only to requests with that method; otherwise it affects all methods.
    pub method: Option<Method>,
    /// If set, the limit is a maximum number of requests per `interval_ms` milliseconds;
    /// otherwise it is a maximum number of concurrent requests.
    pub interval_ms: Option<i32>,
}
74impl RequestLimit {
75    pub fn new(
76        target_pattern: &'static str,
77        requests: i32,
78        method: Option<Method>,
79        interval_ms: Option<i32>,
80    ) -> Self {
81        let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
82        Self {
83            id,
84            target_pattern: target_pattern.into(),
85            requests,
86            method,
87            interval_ms,
88        }
89    }
90}
91
/// A request tracked against the limits stored in `RequestLimit`.
///
/// A request counts as running while `end_timestamp_ms` is `None` (NULL).
#[derive(Entity, PartialEq, Eq)]
#[tank(schema = "api")]
pub struct Request {
    /// Primary key; `Request::new` draws it from `GLOBAL_COUNTER`.
    #[tank(primary_key)]
    pub id: i64,
    /// Requested path, matched against `RequestLimit::target_pattern` with `LIKE`.
    pub target: String,
    /// Method the request was issued with, if any.
    pub method: Option<Method>,
    /// Start time in milliseconds since the Unix epoch.
    // NOTE(review): "beign" looks like a typo for "begin"; renaming the field would
    // presumably rename the generated column as well, so it is left unchanged.
    pub beign_timestamp_ms: i64,
    /// End time in milliseconds since the Unix epoch; `None` while the request is running.
    pub end_timestamp_ms: Option<i64>,
}
102
/// Process-wide id source shared by `RequestLimit::new` and `Request::new`, so every row
/// they create gets a distinct primary key.
static GLOBAL_COUNTER: AtomicUsize = AtomicUsize::new(0usize);
104
105impl Request {
106    pub fn new(target: String, method: Option<Method>) -> Self {
107        let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
108        Self {
109            id,
110            target,
111            method,
112            beign_timestamp_ms: SystemTime::now()
113                .duration_since(UNIX_EPOCH)
114                .unwrap()
115                .as_millis() as _,
116            end_timestamp_ms: None,
117        }
118    }
119    pub fn end(&mut self) {
120        self.end_timestamp_ms = Some(
121            SystemTime::now()
122                .duration_since(UNIX_EPOCH)
123                .unwrap()
124                .as_millis() as _,
125        );
126    }
127}
128
129pub async fn requests(executor: &mut impl Executor) {
130    let _lock = MUTEX.lock();
131
132    // Setup
133    RequestLimit::drop_table(executor, true, false)
134        .await
135        .expect("Could not drop the RequestLimit table");
136    Request::drop_table(executor, true, false)
137        .await
138        .expect("Could not drop the Request table");
139
140    RequestLimit::create_table(executor, false, true)
141        .await
142        .expect("Could not create the RequestLimit table");
143    Request::create_table(executor, false, true)
144        .await
145        .expect("Could not create the Request table");
146
147    // Request limits
148    RequestLimit::insert_many(
149        executor,
150        &[
151            // [1]: Max 3 concurrent requests
152            RequestLimit::new("v1/%", 3, None, None),
153            // [2]:  Max 5 data concurrent requests
154            RequestLimit::new("v1/server/data/%", 5, None, None),
155            // [3]:  Max 2 user concurrent put request
156            RequestLimit::new("v1/server/user/%", 2, Method::PUT.into(), None),
157            // [4]:  Max 1 user concurrent delete request
158            RequestLimit::new("v1/server/user/%", 1, Method::DELETE.into(), None),
159            // [5]:  Max 5 requests
160            RequestLimit::new("v2/%", 5, None, 60_000.into()),
161        ],
162    )
163    .await
164    .expect("Could not insert the limits");
165    let limits = RequestLimit::find_many(executor, true, None)
166        .map_err(|e| panic!("{e:#}"))
167        .count()
168        .await;
169    assert_eq!(limits, 5);
170
171    #[cfg(not(feature = "disable-joins"))]
172    {
173        let mut violated_limits = executor
174        .prepare(
175            QueryBuilder::new()
176                .select([
177                    RequestLimit::target_pattern,
178                    RequestLimit::requests,
179                    RequestLimit::method,
180                    RequestLimit::interval_ms,
181                ])
182                .from(join!(RequestLimit CROSS JOIN Request))
183                .where_expr(expr!(
184                    ? == RequestLimit::target_pattern as LIKE
185                        && Request::target == RequestLimit::target_pattern as LIKE
186                        && (RequestLimit::method == NULL
187                            || RequestLimit::method == Request::method)
188                        && (RequestLimit::interval_ms == NULL && Request::end_timestamp_ms == NULL
189                            || RequestLimit::interval_ms != NULL
190                                && Request::end_timestamp_ms
191                                    >= current_timestamp_ms!() - RequestLimit::interval_ms)
192                ))
193                .group_by([
194                    RequestLimit::target_pattern,
195                    RequestLimit::requests,
196                    RequestLimit::method,
197                    RequestLimit::interval_ms,
198                ])
199                .having(expr!(COUNT(Request::id) >= RequestLimit::requests))
200                .build(&executor.driver()),
201        )
202        .await
203        .expect("Failed to prepare the limit query");
204
205        let mut r1 = Request::new("v1/server/user/new/1".into(), Method::PUT.into());
206        let mut r2 = Request::new("v1/server/user/new/2".into(), Method::PUT.into());
207        let mut r3 = Request::new("v1/server/user/new/3".into(), Method::PUT.into());
208        let mut r4 = Request::new("v1/server/articles/new/4".into(), Method::PUT.into());
209        let r5 = Request::new("v1/server/user/new/5".into(), Method::PUT.into());
210
211        violated_limits.bind(r1.target.clone()).unwrap();
212        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
213        r1.save(executor).await.expect("Failed to save r1");
214
215        violated_limits.bind(r2.target.clone()).unwrap();
216        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
217        r2.save(executor).await.expect("Failed to save r2");
218
219        violated_limits.bind(r3.target.clone()).unwrap();
220        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [3]
221
222        // Request 4 fits because it doesn't refer to /user
223        violated_limits.bind(r4.target.clone()).unwrap();
224        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
225        r4.save(executor).await.expect("Failed to save r4");
226
227        violated_limits.bind(r5.target.clone()).unwrap();
228        assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); // Violates [1], [3]
229
230        r1.end();
231        r1.save(executor).await.expect("Could not terminate r1");
232
233        violated_limits.bind(r3.target.clone()).unwrap();
234        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
235        r3.save(executor).await.expect("Failed to save r3");
236
237        // 3 Running requests
238
239        let mut data_reqs = vec![];
240        for i in 0..5 {
241            let req = Request::new(format!("v1/server/data/item/{}", i), None);
242            req.save(executor)
243                .await
244                .expect("Failed to save data request");
245            data_reqs.push(req);
246        }
247
248        // 8 Running requests
249
250        violated_limits
251            .bind("v1/server/data/item/999".to_string())
252            .unwrap();
253        assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); // Violates [1], [2]
254
255        for i in 0..4 {
256            data_reqs[i].end();
257            data_reqs[i]
258                .save(executor)
259                .await
260                .expect(&format!("Failed to save data_reqs[{i}]"));
261        }
262
263        // 4 Running requests
264
265        violated_limits
266            .bind("v1/server/data/item/999".to_string())
267            .unwrap();
268        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [1] still
269
270        r2.end();
271        r2.save(executor).await.expect("Could not terminate r2");
272        data_reqs[4].end();
273        data_reqs[4]
274            .save(executor)
275            .await
276            .expect("Could not terminate data_reqs[4]");
277
278        violated_limits
279            .bind("v1/server/data/item/999".to_string())
280            .unwrap();
281        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
282
283        r3.end();
284        r3.save(executor).await.expect("Could not terminate r3");
285        r4.end();
286        r4.save(executor).await.expect("Could not terminate r4");
287
288        // Check [4]
289
290        let mut d1 = Request::new("v1/server/user/del/1".into(), Method::DELETE.into());
291        let d2 = Request::new("v1/server/user/del/2".into(), Method::DELETE.into());
292
293        violated_limits.bind(d1.target.clone()).unwrap();
294        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
295        d1.save(executor).await.expect("Failed to save d1");
296
297        violated_limits.bind(d2.target.clone()).unwrap();
298        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [4] // Problem here
299
300        d1.end();
301        d1.save(executor).await.expect("Failed to end d1");
302
303        violated_limits.bind(d2.target.clone()).unwrap();
304        assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
305
306        let mut v2_reqs = vec![];
307        for i in 0..5 {
308            let mut req = Request::new(format!("v2/resource/{}", i), Method::GET.into());
309            req.save(executor).await.expect("Failed to save v2 req");
310            req.end(); // Must end them to be counted by the interval logic
311            req.save(executor).await.expect("Failed to end v2 req");
312            v2_reqs.push(req);
313        }
314
315        violated_limits.bind("v2/resource/new".to_string()).unwrap();
316        assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); // Violates [5]
317    }
318}