1#![allow(unused_imports)]
2use std::{
3 borrow::Cow,
4 sync::{
5 LazyLock,
6 atomic::{AtomicUsize, Ordering},
7 },
8 time::{SystemTime, UNIX_EPOCH},
9};
10use tank::{
11 AsValue, Entity, Error, Executor, QueryBuilder, Result, Value, current_timestamp_ms, expr,
12 join,
13 stream::{StreamExt, TryStreamExt},
14};
15use tokio::sync::Mutex;
16
17static MUTEX: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
18
/// HTTP method attached to a request or a request limit.
///
/// Variant names are spelled in upper case to mirror the HTTP verbs; they are part of
/// the public API, so the `upper_case_acronyms` lint is silenced rather than renamed.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Method {
    GET,
    POST,
    PUT,
    DELETE,
}
26impl AsValue for Method {
27 fn as_empty_value() -> Value {
28 Value::Varchar(None)
29 }
30 fn as_value(self) -> Value {
31 Value::Varchar(Some(
32 match self {
33 Method::GET => "get",
34 Method::POST => "post",
35 Method::PUT => "put",
36 Method::DELETE => "delete",
37 }
38 .into(),
39 ))
40 }
41 fn try_from_value(value: Value) -> Result<Self>
42 where
43 Self: Sized,
44 {
45 if let Value::Varchar(Some(v)) = value.try_as(&String::as_empty_value())? {
46 match &*v {
47 "get" => return Ok(Method::GET),
48 "post" => return Ok(Method::POST),
49 "put" => return Ok(Method::PUT),
50 "delete" => return Ok(Method::DELETE),
51 _ => {
52 return Err(Error::msg(format!(
53 "Unexpected value `{v}` for Method enum"
54 )));
55 }
56 }
57 }
58 Err(Error::msg("Unexpected value for Method enum"))
59 }
60}
61
// A rate limit stored in the `api` schema. The limit query in `requests` matches
// both the candidate target and stored request targets against `target_pattern`
// with `LIKE`.
// Plain `//` comments are used on purpose, so the `Entity` derive sees no extra attributes.
#[derive(Default, Entity, PartialEq, Eq)]
#[tank(schema = "api")]
struct RequestLimit {
    #[tank(primary_key)]
    pub id: i32,
    // `LIKE` pattern such as "v1/server/user/%".
    pub target_pattern: Cow<'static, str>,
    // The limit counts as reached once this many matching requests are counted
    // (the limit query uses `COUNT(Request::id) >= requests`).
    pub requests: i32,
    // When `Some`, the limit applies only to requests with this method; `None` applies it to all.
    pub method: Option<Method>,
    // `None`: count requests still in flight (no end timestamp).
    // `Some(ms)`: count requests whose end timestamp falls within the last `ms` milliseconds.
    pub interval_ms: Option<i32>,
}
74impl RequestLimit {
75 pub fn new(
76 target_pattern: &'static str,
77 requests: i32,
78 method: Option<Method>,
79 interval_ms: Option<i32>,
80 ) -> Self {
81 let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
82 Self {
83 id,
84 target_pattern: target_pattern.into(),
85 requests,
86 method,
87 interval_ms,
88 }
89 }
90}
91
// A single observed API request, stored in the `api` schema.
// Plain `//` comments are used on purpose, so the `Entity` derive sees no extra attributes.
#[derive(Entity, PartialEq, Eq)]
#[tank(schema = "api")]
pub struct Request {
    #[tank(primary_key)]
    pub id: i64,
    pub target: String,
    // `None` for requests that carry no method.
    pub method: Option<Method>,
    // Start time in milliseconds since the Unix epoch.
    // NOTE(review): `beign` looks like a typo of `begin`. Renaming the field would
    // presumably also rename the stored column, so it is left unchanged here.
    pub beign_timestamp_ms: i64,
    // End time in milliseconds since the Unix epoch; `None` while the request is in flight.
    pub end_timestamp_ms: Option<i64>,
}
102
/// Id source shared by `RequestLimit::new` and `Request::new`.
///
/// Each call to either constructor takes the next value, so ids are not reused
/// between the two entities within one process.
static GLOBAL_COUNTER: AtomicUsize = AtomicUsize::new(0);
104
105impl Request {
106 pub fn new(target: String, method: Option<Method>) -> Self {
107 let id = GLOBAL_COUNTER.fetch_add(1, Ordering::Relaxed) as _;
108 Self {
109 id,
110 target,
111 method,
112 beign_timestamp_ms: SystemTime::now()
113 .duration_since(UNIX_EPOCH)
114 .unwrap()
115 .as_millis() as _,
116 end_timestamp_ms: None,
117 }
118 }
119 pub fn end(&mut self) {
120 self.end_timestamp_ms = Some(
121 SystemTime::now()
122 .duration_since(UNIX_EPOCH)
123 .unwrap()
124 .as_millis() as _,
125 );
126 }
127}
128
129pub async fn requests(executor: &mut impl Executor) {
130 let _lock = MUTEX.lock();
131
132 RequestLimit::drop_table(executor, true, false)
134 .await
135 .expect("Could not drop the RequestLimit table");
136 Request::drop_table(executor, true, false)
137 .await
138 .expect("Could not drop the Request table");
139
140 RequestLimit::create_table(executor, false, true)
141 .await
142 .expect("Could not create the RequestLimit table");
143 Request::create_table(executor, false, true)
144 .await
145 .expect("Could not create the Request table");
146
147 RequestLimit::insert_many(
149 executor,
150 &[
151 RequestLimit::new("v1/%", 3, None, None),
153 RequestLimit::new("v1/server/data/%", 5, None, None),
155 RequestLimit::new("v1/server/user/%", 2, Method::PUT.into(), None),
157 RequestLimit::new("v1/server/user/%", 1, Method::DELETE.into(), None),
159 RequestLimit::new("v2/%", 5, None, 60_000.into()),
161 ],
162 )
163 .await
164 .expect("Could not insert the limits");
165 let limits = RequestLimit::find_many(executor, true, None)
166 .map_err(|e| panic!("{e:#}"))
167 .count()
168 .await;
169 assert_eq!(limits, 5);
170
171 #[cfg(not(feature = "disable-joins"))]
172 {
173 let mut violated_limits = executor
174 .prepare(
175 QueryBuilder::new()
176 .select([
177 RequestLimit::target_pattern,
178 RequestLimit::requests,
179 RequestLimit::method,
180 RequestLimit::interval_ms,
181 ])
182 .from(join!(RequestLimit CROSS JOIN Request))
183 .where_expr(expr!(
184 ? == RequestLimit::target_pattern as LIKE
185 && Request::target == RequestLimit::target_pattern as LIKE
186 && (RequestLimit::method == NULL
187 || RequestLimit::method == Request::method)
188 && (RequestLimit::interval_ms == NULL && Request::end_timestamp_ms == NULL
189 || RequestLimit::interval_ms != NULL
190 && Request::end_timestamp_ms
191 >= current_timestamp_ms!() - RequestLimit::interval_ms)
192 ))
193 .group_by([
194 RequestLimit::target_pattern,
195 RequestLimit::requests,
196 RequestLimit::method,
197 RequestLimit::interval_ms,
198 ])
199 .having(expr!(COUNT(Request::id) >= RequestLimit::requests))
200 .build(&executor.driver()),
201 )
202 .await
203 .expect("Failed to prepare the limit query");
204
205 let mut r1 = Request::new("v1/server/user/new/1".into(), Method::PUT.into());
206 let mut r2 = Request::new("v1/server/user/new/2".into(), Method::PUT.into());
207 let mut r3 = Request::new("v1/server/user/new/3".into(), Method::PUT.into());
208 let mut r4 = Request::new("v1/server/articles/new/4".into(), Method::PUT.into());
209 let r5 = Request::new("v1/server/user/new/5".into(), Method::PUT.into());
210
211 violated_limits.bind(r1.target.clone()).unwrap();
212 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
213 r1.save(executor).await.expect("Failed to save r1");
214
215 violated_limits.bind(r2.target.clone()).unwrap();
216 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
217 r2.save(executor).await.expect("Failed to save r2");
218
219 violated_limits.bind(r3.target.clone()).unwrap();
220 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); violated_limits.bind(r4.target.clone()).unwrap();
224 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
225 r4.save(executor).await.expect("Failed to save r4");
226
227 violated_limits.bind(r5.target.clone()).unwrap();
228 assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); r1.end();
231 r1.save(executor).await.expect("Could not terminate r1");
232
233 violated_limits.bind(r3.target.clone()).unwrap();
234 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
235 r3.save(executor).await.expect("Failed to save r3");
236
237 let mut data_reqs = vec![];
240 for i in 0..5 {
241 let req = Request::new(format!("v1/server/data/item/{}", i), None);
242 req.save(executor)
243 .await
244 .expect("Failed to save data request");
245 data_reqs.push(req);
246 }
247
248 violated_limits
251 .bind("v1/server/data/item/999".to_string())
252 .unwrap();
253 assert_eq!(executor.fetch(&mut violated_limits).count().await, 2); for i in 0..4 {
256 data_reqs[i].end();
257 data_reqs[i]
258 .save(executor)
259 .await
260 .expect(&format!("Failed to save data_reqs[{i}]"));
261 }
262
263 violated_limits
266 .bind("v1/server/data/item/999".to_string())
267 .unwrap();
268 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); r2.end();
271 r2.save(executor).await.expect("Could not terminate r2");
272 data_reqs[4].end();
273 data_reqs[4]
274 .save(executor)
275 .await
276 .expect("Could not terminate data_reqs[4]");
277
278 violated_limits
279 .bind("v1/server/data/item/999".to_string())
280 .unwrap();
281 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
282
283 r3.end();
284 r3.save(executor).await.expect("Could not terminate r3");
285 r4.end();
286 r4.save(executor).await.expect("Could not terminate r4");
287
288 let mut d1 = Request::new("v1/server/user/del/1".into(), Method::DELETE.into());
291 let d2 = Request::new("v1/server/user/del/2".into(), Method::DELETE.into());
292
293 violated_limits.bind(d1.target.clone()).unwrap();
294 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
295 d1.save(executor).await.expect("Failed to save d1");
296
297 violated_limits.bind(d2.target.clone()).unwrap();
298 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); d1.end();
301 d1.save(executor).await.expect("Failed to end d1");
302
303 violated_limits.bind(d2.target.clone()).unwrap();
304 assert_eq!(executor.fetch(&mut violated_limits).count().await, 0);
305
306 let mut v2_reqs = vec![];
307 for i in 0..5 {
308 let mut req = Request::new(format!("v2/resource/{}", i), Method::GET.into());
309 req.save(executor).await.expect("Failed to save v2 req");
310 req.end(); req.save(executor).await.expect("Failed to end v2 req");
312 v2_reqs.push(req);
313 }
314
315 violated_limits.bind("v2/resource/new".to_string()).unwrap();
316 assert_eq!(executor.fetch(&mut violated_limits).count().await, 1); }
318}