ceiling_macros/
lib.rs

1//! Ceiling is a simple, lightweight, and highly configurable library for handling and creating rate limiting rules.
2//!
3//! The main entrypoint to the library is the `rate_limiter!` macro found below.
4mod generic_input;
5mod group_input;
6mod rate_limiter_input;
7
8use group_input::GroupInput;
9use proc_macro2::TokenStream;
10use quote::quote;
11use rand::distributions::DistString;
12use rate_limiter_input::{RateLimiterInput, Rule};
13use syn::{parse_macro_input, Ident, LitStr, Path, Result};
14
15/// This macro is the entrypoint for creating rate limiting rules with ceiling.
16/// The macro takes input corresponding to the inputs to the rate limiter and the rules.
17///
18/// # Example
19/// ```
20/// ceiling::rate_limiter! {
21///     // takes in three inputs named `ip`, `route`, and `method`
22///     // they must implement `std::fmt::Display` so they can be coerced into strings as needed
23///     ip, route, method in {
24///         // the following creates a public (detailed information is meant to be returned to the client) rate limiting rule named main with a limit of 2 requests every 2 seconds (interval) for the key created by concatenating the ip, route, and method inputs together
25///         // when the rate limit is hit, the timeout specified is 3 seconds from the time of the request that emptied the bucket
26///         main = pub 2 requests every 2 seconds for { ip + route + method } timeout 3 seconds;
27///         // the following only contains the required components of a rate limiting rule
28///         // this one crates a private rate limiting rule with a limit of 3 request every 2 minutes (interval) for the key ip + route
29///         // since timeout is not specified, the bucket will reset when the interval is up
30///         burst = 3 requests every 2 minutes for { ip + route };
31///     // `as RateLimiter` tells the macro to name the generated struct RateLimiter
32///     // `async` says the following custom store is asynchronous
33///     // i.e. implements `ceiling::AsyncStore` instead of `ceiling::SyncStore`
34///     // `in crate::MyAsyncStore` tells the macro to use the struct `crate::MyAsyncStore` for the bucket stores
35///     // specifying a bucket store is not required, if none is provided it will use `ceiling::DefaultStore`
36///     } as RateLimiter async in crate::MyAsyncStore
37/// }
38/// ```
39/// ```
40/// let rate_limiter = RateLimiter::new();
41/// // "hits" the rate limiter, what would happen when someone, for example, makes a request
42/// // the return result is a `bool` (`rate_limiter`) of whether the request is being rate limiter (`true` means it is and should not continue)
43/// // and a `RateLimiterHit` (the name of the struct is rate limiter name + "Hit") struct containing detailed metadata on the state of all the rate limiting rules
44/// // rules can be found by using the name of the rule, i.e. `hit.main` corresponds to the rule named `main`
45/// // the value of a rule's metadata is a tuple of type `(u32, u64, bool, String)` corresponding to the requests remaining, the reset time, whether the rule is public or not, and the key of the bucket
46/// let (rate_limiter, hit) = rate_limiter.hit("1.1.1.1", "/example", "GET").await;
47/// // with the crate feature `serde` enabled, the `hit` object implements `serde::Serialize` and can be easily serialized to any format
48/// // the serialized data will only contain the public rules, the various fields can be found below
49/// // as another option, the hit object has a `to_headers` method that will return a Vec<(&str, String)> corresponding to the header and value
50/// // information on the headers can be found below
51/// let headers = hit.to_headers();
52/// for (header, value) in headers {
53///     response.header(header, value);
54/// }
55/// ```
56///
57/// ## Headers/Metadata Attributes
58/// | Header                  | Attribute     | Description                                                                                     |
59/// | ----------------------- | ------------- | ----------------------------------------------------------------------------------------------- |
60/// | X-RateLimit-Limit       | "limit"       | limit of hits per interval seconds                                                              |
61/// | X-RateLimit-Interval    | "interval"    | interval before bucket resets after first hit                                                   |
62/// | X-RateLimit-Timeout     | "timeout"     | timeout before the bucket resets after limit is reached                                         |
63/// | X-RateLimit-Remaining   | "remaining"   | hits remaining in interval                                                                      |
64/// | X-RateLimit-Reset       | "reset"       | timestamp in seconds when the bucket resets                                                     |
65/// | X-RateLimit-Reset-After | "reset_after" | seconds until bucket resets                                                                     |
66/// | X-RateLimit-Key         | "key"         | the bucket key, may be shared between routes and therefore useful for client-side rate limiting |
67#[proc_macro]
68pub fn rate_limiter(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
69    impl_rate_limiter(parse_macro_input!(input as RateLimiterInput))
70        .unwrap()
71        .into()
72}
73
74fn impl_rate_limiter(
75    RateLimiterInput {
76        inputs,
77        rules,
78        name,
79        store,
80        async_store,
81    }: RateLimiterInput,
82) -> Result<TokenStream> {
83    let name = syn::parse_str::<syn::Ident>(&name)?;
84    let store = syn::parse_str::<Path>(&store.unwrap_or_else(|| "ceiling::DefaultStore".into()))?;
85
86    let input_type_params = inputs
87        .iter()
88        .map(|i| syn::parse_str::<syn::Ident>(format!("{}_IN", i.to_uppercase()).as_str()).unwrap())
89        .collect::<Vec<_>>();
90    let inputs = inputs
91        .iter()
92        .map(|i| syn::parse_str::<syn::Ident>(format!("{i}_input").as_str()).unwrap())
93        .collect::<Vec<_>>();
94    let input_params = inputs
95        .iter()
96        .zip(&input_type_params)
97        .map(|(i, t)| quote!(#i: #t))
98        .collect::<Vec<_>>();
99
100    let hit = syn::parse_str::<syn::Ident>(format!("{}Hit", name).as_str())?;
101
102    let rule_names = rules
103        .iter()
104        .map(|r| syn::parse_str::<syn::Ident>(&r.name).unwrap())
105        .collect::<Vec<_>>();
106    let rule_impls = rules
107        .iter()
108        .map(|r| impl_rule(r, async_store))
109        .collect::<Vec<_>>();
110
111    let num_rules = rules.iter().filter(|r| r.public).count();
112    let num_headers = num_rules * 7;
113
114    let rules_serde = rule_names.iter().zip(&rules).map(|(name, r)| {
115        let Rule {
116            name: _,
117            limit,
118            interval,
119            timeout,
120            key: _,
121            public,
122        } = r;
123        if *public {
124            quote! {
125                let mut m: std::collections::HashMap<&str, Val> = std::collections::HashMap::with_capacity(7);
126                m.insert("limit", #limit.into());
127                m.insert("interval", #interval.into());
128                m.insert("timeout", #timeout.into());
129                m.insert("remaining", self.#name.0.into());
130                m.insert("reset", self.#name.1.into());
131                m.insert("reset_after", (self.#name.1).saturating_sub(now).into());
132                m.insert("key", (&self.#name.3).into());
133                map.serialize_entry(stringify!(self.#name), &m)?;
134            }
135        } else {
136            quote!()
137        }
138    });
139    let rules_headers = rule_names.iter().zip(&rules).map(|(name, r)| {
140        let Rule {
141            name: _,
142            limit,
143            interval,
144            timeout,
145            key: _,
146            public,
147        } = r;
148        if *public {
149            quote! {
150                vec.push(("X-RateLimit-Limit", format!("{} {}", stringify!(#name), #limit)));
151                vec.push(("X-RateLimit-Interval", format!("{} {}", stringify!(#name), #interval)));
152                vec.push(("X-RateLimit-Timeout", format!("{} {}", stringify!(#name), #timeout)));
153                vec.push(("X-RateLimit-Remaining", format!("{} {}", stringify!(#name), self.#name.0)));
154                vec.push(("X-RateLimit-Reset", format!("{} {}", stringify!(#name), self.#name.1)));
155                vec.push(("X-RateLimit-Reset-After", format!("{} {}", stringify!(#name), (self.#name.1).saturating_sub(now))));
156                vec.push(("X-RateLimit-Key", format!("{} {}", stringify!(#name), self.#name.3)));
157            }
158        } else {
159            quote!()
160        }
161    });
162
163    let async_hit = if async_store { quote!(async) } else { quote!() };
164    let use_store = if async_store {
165        quote!(
166            use ceiling::AsyncStore;
167        )
168    } else {
169        quote!(
170            use ceiling::SyncStore;
171        )
172    };
173    Ok(quote! {
174        #[derive(Clone, Debug)]
175        pub struct #name {
176            #(#rule_names: std::sync::Arc<#store>),*
177        }
178
179        impl #name {
180            pub fn new() -> Self {
181                Self {
182                    #(#rule_names: std::sync::Arc::new(#store::new())),*
183                }
184            }
185
186            pub #async_hit fn hit<#(#input_type_params),*>(&self, #(#input_params),*) -> (bool, #hit)
187            where
188                #(#input_type_params: std::fmt::Display),*
189                {
190                    #use_store
191
192                    let now = std::time::SystemTime::now()
193                        .duration_since(std::time::UNIX_EPOCH)
194                        .unwrap()
195                        .as_secs();
196                    let mut hit = false;
197                    #(#rule_impls)*
198                    (hit, #hit {
199                        #(#rule_names),*
200                    })
201                }
202        }
203
204        #[derive(Clone, Debug)]
205        pub struct #hit {
206            pub #(#rule_names: (u32, u64, bool, String)),*
207        }
208
209        impl #hit {
210            pub fn to_headers(&self) -> Vec<(&str, String)> {
211                let now = std::time::SystemTime::now()
212                .duration_since(std::time::UNIX_EPOCH)
213                .unwrap()
214                .as_secs();
215                let mut vec = Vec::with_capacity(#num_headers);
216                #(#rules_headers)*
217                vec
218            }
219        }
220
221        #[cfg(feature = "serde")]
222        impl serde::Serialize for #hit {
223            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
224            where
225                S: serde::Serializer,
226            {
227                use serde::ser::SerializeMap;
228
229                let now = std::time::SystemTime::now()
230                .duration_since(std::time::UNIX_EPOCH)
231                .unwrap()
232                .as_secs();
233                let mut map = serializer.serialize_map(Some(#num_rules))?;
234                #(#rules_serde)*
235                map.end()
236            }
237        }
238
239        #[cfg(feature = "serde")]
240        enum Val {
241            Int(u64),
242            Str(String),
243        }
244
245        #[cfg(feature = "serde")]
246        impl From<u32> for Val {
247            fn from(v: u32) -> Val {
248                Val::Int(v as u64)
249            }
250        }
251
252        #[cfg(feature = "serde")]
253        impl From<u64> for Val {
254            fn from(v: u64) -> Val {
255                Val::Int(v)
256            }
257        }
258
259        #[cfg(feature = "serde")]
260        impl From<&String> for Val {
261            fn from(v: &String) -> Val {
262                Val::Str(v.to_string())
263            }
264        }
265
266        #[cfg(feature = "serde")]
267        impl serde::Serialize for Val {
268            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
269            where
270                S: serde::Serializer,
271            {
272                match self {
273                    Self::Int(v) => serializer.serialize_u64(*v),
274                    Self::Str(v) => serializer.serialize_str(v),
275                }
276            }
277        }
278
279    })
280}
281
282fn impl_rule(rule: &Rule, async_store: bool) -> TokenStream {
283    let Rule {
284        name,
285        limit,
286        interval,
287        timeout,
288        key,
289        public,
290    } = rule;
291    let name = syn::parse_str::<syn::Ident>(name).unwrap();
292    let key = key
293        .iter()
294        .map(|k| syn::parse_str::<syn::Ident>(format!("{k}_input").as_str()).unwrap())
295        .collect::<Vec<_>>();
296    let key = if key.is_empty() {
297        quote!("".to_string())
298    } else {
299        let lit = key.iter().map(|_| "{}").collect::<Vec<_>>().join("+");
300        quote!(format!(#lit, #(#key),*))
301    };
302    let get = if async_store {
303        quote!(self.#name.get(&key).await)
304    } else {
305        quote!(self.#name.get(&key))
306    };
307    let set = if async_store {
308        quote!(self.#name.set(&key, #name, reset_updated).await)
309    } else {
310        quote!(self.#name.set(&key, #name, reset_updated))
311    };
312    let prune = if async_store {
313        quote!(self.#name.prune(now).await)
314    } else {
315        quote!(self.#name.prune(now))
316    };
317    quote! {
318        let #name = {
319            let key = #key;
320            let lock = #get;
321            let mut #name = (*lock).unwrap_or((#limit, now + (#interval as u64)));
322            let mut reset_updated = false;
323            if #name.1 < now {
324                #name = (#limit, now + (#interval as u64));
325                reset_updated = true;
326            }
327            if #name.0 > 1 {
328                #name.0 -= 1;
329                #set;
330            } else if #name.0 == 1 {
331                #name = (0, now + (#timeout as u64));
332                reset_updated = true;
333                #set;
334                hit = true;
335            } else {
336                hit = true;
337            }
338            drop(lock);
339            #prune;
340            (#name.0, #name.1, #public, key)
341        };
342    }
343}
344
345/// `group!` is a utility macro for grouping multiple values into a single key
346///
347/// # Example
348/// ```
349/// // this will generate a function called `bucket` that takes an &str and returns an &str
350/// // if the provided value matches any of the values in the macro it will return a shared bucket key
351/// // i.e. `bucket("/help")` will return the same value as `bucket("/help2")`
352/// // if no matches are found, then it will return the value provided
353/// ceiling::group! {
354///     bucket {
355///         "/help", "/help2", "/help3";
356///         "/one", "/two";
357///     }
358/// }
359/// ```
360#[proc_macro]
361pub fn group(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
362    impl_group(parse_macro_input!(input as GroupInput))
363        .unwrap()
364        .into()
365}
366
367fn impl_group(GroupInput { name, groups }: GroupInput) -> Result<TokenStream> {
368    let groups = groups.into_iter().enumerate().map(|(i, g)| {
369        let s = syn::parse_str::<LitStr>(
370            format!(
371                "\"__{}-{}\"",
372                i,
373                rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 20),
374            )
375            .as_str(),
376        )
377        .unwrap();
378        quote! {
379            #(
380                #g => #s,
381            )*
382        }
383    });
384    let name = syn::parse_str::<Ident>(&name)?;
385    let gen = quote! {
386        fn #name(value: &str) -> &str {
387            match value {
388                #( #groups )*
389                _ => value
390            }
391        }
392    };
393    Ok(gen)
394}