ceiling_macros/lib.rs
1//! Ceiling is a simple, lightweight, and highly configurable library for handling and creating rate limiting rules.
2//!
3//! The main entrypoint to the library is the `rate_limiter!` macro found below.
4mod generic_input;
5mod group_input;
6mod rate_limiter_input;
7
8use group_input::GroupInput;
9use proc_macro2::TokenStream;
10use quote::quote;
11use rand::distributions::DistString;
12use rate_limiter_input::{RateLimiterInput, Rule};
13use syn::{parse_macro_input, Ident, LitStr, Path, Result};
14
15/// This macro is the entrypoint for creating rate limiting rules with ceiling.
16/// The macro takes input corresponding to the inputs to the rate limiter and the rules.
17///
18/// # Example
19/// ```
20/// ceiling::rate_limiter! {
21/// // takes in three inputs named `ip`, `route`, and `method`
22/// // they must implement `std::fmt::Display` so they can be coerced into strings as needed
23/// ip, route, method in {
24/// // the following creates a public (detailed information is meant to be returned to the client) rate limiting rule named main with a limit of 2 requests every 2 seconds (interval) for the key created by concatenating the ip, route, and method inputs together
25/// // when the rate limit is hit, the timeout specified is 3 seconds from the time of the request that emptied the bucket
26/// main = pub 2 requests every 2 seconds for { ip + route + method } timeout 3 seconds;
27/// // the following only contains the required components of a rate limiting rule
28/// // this one crates a private rate limiting rule with a limit of 3 request every 2 minutes (interval) for the key ip + route
29/// // since timeout is not specified, the bucket will reset when the interval is up
30/// burst = 3 requests every 2 minutes for { ip + route };
31/// // `as RateLimiter` tells the macro to name the generated struct RateLimiter
32/// // `async` says the following custom store is asynchronous
33/// // i.e. implements `ceiling::AsyncStore` instead of `ceiling::SyncStore`
34/// // `in crate::MyAsyncStore` tells the macro to use the struct `crate::MyAsyncStore` for the bucket stores
35/// // specifying a bucket store is not required, if none is provided it will use `ceiling::DefaultStore`
36/// } as RateLimiter async in crate::MyAsyncStore
37/// }
38/// ```
39/// ```
40/// let rate_limiter = RateLimiter::new();
41/// // "hits" the rate limiter, what would happen when someone, for example, makes a request
42/// // the return result is a `bool` (`rate_limiter`) of whether the request is being rate limiter (`true` means it is and should not continue)
43/// // and a `RateLimiterHit` (the name of the struct is rate limiter name + "Hit") struct containing detailed metadata on the state of all the rate limiting rules
44/// // rules can be found by using the name of the rule, i.e. `hit.main` corresponds to the rule named `main`
45/// // the value of a rule's metadata is a tuple of type `(u32, u64, bool, String)` corresponding to the requests remaining, the reset time, whether the rule is public or not, and the key of the bucket
46/// let (rate_limiter, hit) = rate_limiter.hit("1.1.1.1", "/example", "GET").await;
47/// // with the crate feature `serde` enabled, the `hit` object implements `serde::Serialize` and can be easily serialized to any format
48/// // the serialized data will only contain the public rules, the various fields can be found below
49/// // as another option, the hit object has a `to_headers` method that will return a Vec<(&str, String)> corresponding to the header and value
50/// // information on the headers can be found below
51/// let headers = hit.to_headers();
52/// for (header, value) in headers {
53/// response.header(header, value);
54/// }
55/// ```
56///
57/// ## Headers/Metadata Attributes
58/// | Header | Attribute | Description |
59/// | ----------------------- | ------------- | ----------------------------------------------------------------------------------------------- |
60/// | X-RateLimit-Limit | "limit" | limit of hits per interval seconds |
61/// | X-RateLimit-Interval | "interval" | interval before bucket resets after first hit |
62/// | X-RateLimit-Timeout | "timeout" | timeout before the bucket resets after limit is reached |
63/// | X-RateLimit-Remaining | "remaining" | hits remaining in interval |
64/// | X-RateLimit-Reset | "reset" | timestamp in seconds when the bucket resets |
65/// | X-RateLimit-Reset-After | "reset_after" | seconds until bucket resets |
66/// | X-RateLimit-Key | "key" | the bucket key, may be shared between routes and therefore useful for client-side rate limiting |
67#[proc_macro]
68pub fn rate_limiter(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
69 impl_rate_limiter(parse_macro_input!(input as RateLimiterInput))
70 .unwrap()
71 .into()
72}
73
74fn impl_rate_limiter(
75 RateLimiterInput {
76 inputs,
77 rules,
78 name,
79 store,
80 async_store,
81 }: RateLimiterInput,
82) -> Result<TokenStream> {
83 let name = syn::parse_str::<syn::Ident>(&name)?;
84 let store = syn::parse_str::<Path>(&store.unwrap_or_else(|| "ceiling::DefaultStore".into()))?;
85
86 let input_type_params = inputs
87 .iter()
88 .map(|i| syn::parse_str::<syn::Ident>(format!("{}_IN", i.to_uppercase()).as_str()).unwrap())
89 .collect::<Vec<_>>();
90 let inputs = inputs
91 .iter()
92 .map(|i| syn::parse_str::<syn::Ident>(format!("{i}_input").as_str()).unwrap())
93 .collect::<Vec<_>>();
94 let input_params = inputs
95 .iter()
96 .zip(&input_type_params)
97 .map(|(i, t)| quote!(#i: #t))
98 .collect::<Vec<_>>();
99
100 let hit = syn::parse_str::<syn::Ident>(format!("{}Hit", name).as_str())?;
101
102 let rule_names = rules
103 .iter()
104 .map(|r| syn::parse_str::<syn::Ident>(&r.name).unwrap())
105 .collect::<Vec<_>>();
106 let rule_impls = rules
107 .iter()
108 .map(|r| impl_rule(r, async_store))
109 .collect::<Vec<_>>();
110
111 let num_rules = rules.iter().filter(|r| r.public).count();
112 let num_headers = num_rules * 7;
113
114 let rules_serde = rule_names.iter().zip(&rules).map(|(name, r)| {
115 let Rule {
116 name: _,
117 limit,
118 interval,
119 timeout,
120 key: _,
121 public,
122 } = r;
123 if *public {
124 quote! {
125 let mut m: std::collections::HashMap<&str, Val> = std::collections::HashMap::with_capacity(7);
126 m.insert("limit", #limit.into());
127 m.insert("interval", #interval.into());
128 m.insert("timeout", #timeout.into());
129 m.insert("remaining", self.#name.0.into());
130 m.insert("reset", self.#name.1.into());
131 m.insert("reset_after", (self.#name.1).saturating_sub(now).into());
132 m.insert("key", (&self.#name.3).into());
133 map.serialize_entry(stringify!(self.#name), &m)?;
134 }
135 } else {
136 quote!()
137 }
138 });
139 let rules_headers = rule_names.iter().zip(&rules).map(|(name, r)| {
140 let Rule {
141 name: _,
142 limit,
143 interval,
144 timeout,
145 key: _,
146 public,
147 } = r;
148 if *public {
149 quote! {
150 vec.push(("X-RateLimit-Limit", format!("{} {}", stringify!(#name), #limit)));
151 vec.push(("X-RateLimit-Interval", format!("{} {}", stringify!(#name), #interval)));
152 vec.push(("X-RateLimit-Timeout", format!("{} {}", stringify!(#name), #timeout)));
153 vec.push(("X-RateLimit-Remaining", format!("{} {}", stringify!(#name), self.#name.0)));
154 vec.push(("X-RateLimit-Reset", format!("{} {}", stringify!(#name), self.#name.1)));
155 vec.push(("X-RateLimit-Reset-After", format!("{} {}", stringify!(#name), (self.#name.1).saturating_sub(now))));
156 vec.push(("X-RateLimit-Key", format!("{} {}", stringify!(#name), self.#name.3)));
157 }
158 } else {
159 quote!()
160 }
161 });
162
163 let async_hit = if async_store { quote!(async) } else { quote!() };
164 let use_store = if async_store {
165 quote!(
166 use ceiling::AsyncStore;
167 )
168 } else {
169 quote!(
170 use ceiling::SyncStore;
171 )
172 };
173 Ok(quote! {
174 #[derive(Clone, Debug)]
175 pub struct #name {
176 #(#rule_names: std::sync::Arc<#store>),*
177 }
178
179 impl #name {
180 pub fn new() -> Self {
181 Self {
182 #(#rule_names: std::sync::Arc::new(#store::new())),*
183 }
184 }
185
186 pub #async_hit fn hit<#(#input_type_params),*>(&self, #(#input_params),*) -> (bool, #hit)
187 where
188 #(#input_type_params: std::fmt::Display),*
189 {
190 #use_store
191
192 let now = std::time::SystemTime::now()
193 .duration_since(std::time::UNIX_EPOCH)
194 .unwrap()
195 .as_secs();
196 let mut hit = false;
197 #(#rule_impls)*
198 (hit, #hit {
199 #(#rule_names),*
200 })
201 }
202 }
203
204 #[derive(Clone, Debug)]
205 pub struct #hit {
206 pub #(#rule_names: (u32, u64, bool, String)),*
207 }
208
209 impl #hit {
210 pub fn to_headers(&self) -> Vec<(&str, String)> {
211 let now = std::time::SystemTime::now()
212 .duration_since(std::time::UNIX_EPOCH)
213 .unwrap()
214 .as_secs();
215 let mut vec = Vec::with_capacity(#num_headers);
216 #(#rules_headers)*
217 vec
218 }
219 }
220
221 #[cfg(feature = "serde")]
222 impl serde::Serialize for #hit {
223 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
224 where
225 S: serde::Serializer,
226 {
227 use serde::ser::SerializeMap;
228
229 let now = std::time::SystemTime::now()
230 .duration_since(std::time::UNIX_EPOCH)
231 .unwrap()
232 .as_secs();
233 let mut map = serializer.serialize_map(Some(#num_rules))?;
234 #(#rules_serde)*
235 map.end()
236 }
237 }
238
239 #[cfg(feature = "serde")]
240 enum Val {
241 Int(u64),
242 Str(String),
243 }
244
245 #[cfg(feature = "serde")]
246 impl From<u32> for Val {
247 fn from(v: u32) -> Val {
248 Val::Int(v as u64)
249 }
250 }
251
252 #[cfg(feature = "serde")]
253 impl From<u64> for Val {
254 fn from(v: u64) -> Val {
255 Val::Int(v)
256 }
257 }
258
259 #[cfg(feature = "serde")]
260 impl From<&String> for Val {
261 fn from(v: &String) -> Val {
262 Val::Str(v.to_string())
263 }
264 }
265
266 #[cfg(feature = "serde")]
267 impl serde::Serialize for Val {
268 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
269 where
270 S: serde::Serializer,
271 {
272 match self {
273 Self::Int(v) => serializer.serialize_u64(*v),
274 Self::Str(v) => serializer.serialize_str(v),
275 }
276 }
277 }
278
279 })
280}
281
282fn impl_rule(rule: &Rule, async_store: bool) -> TokenStream {
283 let Rule {
284 name,
285 limit,
286 interval,
287 timeout,
288 key,
289 public,
290 } = rule;
291 let name = syn::parse_str::<syn::Ident>(name).unwrap();
292 let key = key
293 .iter()
294 .map(|k| syn::parse_str::<syn::Ident>(format!("{k}_input").as_str()).unwrap())
295 .collect::<Vec<_>>();
296 let key = if key.is_empty() {
297 quote!("".to_string())
298 } else {
299 let lit = key.iter().map(|_| "{}").collect::<Vec<_>>().join("+");
300 quote!(format!(#lit, #(#key),*))
301 };
302 let get = if async_store {
303 quote!(self.#name.get(&key).await)
304 } else {
305 quote!(self.#name.get(&key))
306 };
307 let set = if async_store {
308 quote!(self.#name.set(&key, #name, reset_updated).await)
309 } else {
310 quote!(self.#name.set(&key, #name, reset_updated))
311 };
312 let prune = if async_store {
313 quote!(self.#name.prune(now).await)
314 } else {
315 quote!(self.#name.prune(now))
316 };
317 quote! {
318 let #name = {
319 let key = #key;
320 let lock = #get;
321 let mut #name = (*lock).unwrap_or((#limit, now + (#interval as u64)));
322 let mut reset_updated = false;
323 if #name.1 < now {
324 #name = (#limit, now + (#interval as u64));
325 reset_updated = true;
326 }
327 if #name.0 > 1 {
328 #name.0 -= 1;
329 #set;
330 } else if #name.0 == 1 {
331 #name = (0, now + (#timeout as u64));
332 reset_updated = true;
333 #set;
334 hit = true;
335 } else {
336 hit = true;
337 }
338 drop(lock);
339 #prune;
340 (#name.0, #name.1, #public, key)
341 };
342 }
343}
344
345/// `group!` is a utility macro for grouping multiple values into a single key
346///
347/// # Example
348/// ```
349/// // this will generate a function called `bucket` that takes an &str and returns an &str
350/// // if the provided value matches any of the values in the macro it will return a shared bucket key
351/// // i.e. `bucket("/help")` will return the same value as `bucket("/help2")`
352/// // if no matches are found, then it will return the value provided
353/// ceiling::group! {
354/// bucket {
355/// "/help", "/help2", "/help3";
356/// "/one", "/two";
357/// }
358/// }
359/// ```
360#[proc_macro]
361pub fn group(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
362 impl_group(parse_macro_input!(input as GroupInput))
363 .unwrap()
364 .into()
365}
366
367fn impl_group(GroupInput { name, groups }: GroupInput) -> Result<TokenStream> {
368 let groups = groups.into_iter().enumerate().map(|(i, g)| {
369 let s = syn::parse_str::<LitStr>(
370 format!(
371 "\"__{}-{}\"",
372 i,
373 rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 20),
374 )
375 .as_str(),
376 )
377 .unwrap();
378 quote! {
379 #(
380 #g => #s,
381 )*
382 }
383 });
384 let name = syn::parse_str::<Ident>(&name)?;
385 let gen = quote! {
386 fn #name(value: &str) -> &str {
387 match value {
388 #( #groups )*
389 _ => value
390 }
391 }
392 };
393 Ok(gen)
394}