1#![cfg_attr(not(feature = "std"), no_std)]
16#![allow(clippy::unused_unit)]
17
18use frame_support::{pallet_prelude::*, traits::UnixTime, transactional, BoundedVec};
19use frame_system::pallet_prelude::*;
20use orml_traits::{RateLimiter, RateLimiterError};
21use orml_utilities::OrderedSet;
22use parity_scale_codec::MaxEncodedLen;
23use scale_info::TypeInfo;
24use sp_runtime::traits::{BlockNumberProvider, SaturatedConversion, Zero};
25use sp_std::{prelude::*, vec::Vec};
26
27pub use module::*;
28pub use weights::WeightInfo;
29
30mod mock;
31mod tests;
32pub mod weights;
33
34#[frame_support::pallet]
35pub mod module {
36 use super::*;
37
    /// Length of a rate-limit period.
    ///
    /// Rules set through `update_rate_limit_rule` must use a non-zero length.
    // Variant order is part of the SCALE encoding; do not reorder.
    #[derive(PartialEq, Eq, Clone, Encode, Decode, RuntimeDebug, TypeInfo, DecodeWithMemTracking)]
    pub enum Period {
        /// A number of blocks; "now" is read from `Config::BlockNumberProvider`.
        Blocks(u64),
        /// A number of seconds; "now" is read from `Config::UnixTime` (`as_secs`).
        Seconds(u64),
    }
44
    /// A rate-limit rule attached to a `(rate_limiter_id, encoded_key)` pair.
    // Variant order is part of the SCALE encoding; do not reorder.
    #[derive(PartialEq, Eq, Clone, Encode, Decode, RuntimeDebug, TypeInfo, DecodeWithMemTracking)]
    pub enum RateLimitRule {
        /// Once at least one `period` has elapsed since the last reset, the remaining quota is
        /// reset to `quota`. Both `period` and `quota` must be non-zero.
        PerPeriod { period: Period, quota: u128 },
        /// Token bucket: for each whole `period` elapsed since the last refill,
        /// `quota_increment` is added to the remaining quota, which is capped at `max_quota`.
        /// The period length, `quota_increment` and `max_quota` must all be non-zero.
        TokenBucket {
            period: Period,
            quota_increment: u128,
            max_quota: u128,
        },
        /// Any amount may be consumed; no quota is tracked.
        Unlimited,
        /// Every consumption is rejected.
        NotAllowed,
    }
66
    /// Maximum number of bytes in a single `KeyFilter` pattern.
    pub const MAX_FILTER_KEY_LENGTH: u32 = 256;

    /// A whitelist filter, tested against the SCALE-encoded key passed to the rate limiter.
    // Variant order is part of the SCALE encoding and of the `Ord` used by the whitelist set.
    #[derive(
        PartialOrd,
        Ord,
        PartialEq,
        Eq,
        Clone,
        Encode,
        Decode,
        RuntimeDebug,
        TypeInfo,
        MaxEncodedLen,
        DecodeWithMemTracking,
    )]
    pub enum KeyFilter {
        /// The encoded key equals these bytes exactly.
        Match(BoundedVec<u8, ConstU32<MAX_FILTER_KEY_LENGTH>>),
        /// The encoded key starts with these bytes (an empty prefix matches every key).
        StartsWith(BoundedVec<u8, ConstU32<MAX_FILTER_KEY_LENGTH>>),
        /// The encoded key ends with these bytes (an empty suffix matches every key).
        EndsWith(BoundedVec<u8, ConstU32<MAX_FILTER_KEY_LENGTH>>),
    }
92
    /// Configuration trait of the rate-limit pallet.
    #[pallet::config]
    pub trait Config: frame_system::Config {
        /// The overarching event type.
        type RuntimeEvent: From<Event<Self>> + IsType<<Self as frame_system::Config>::RuntimeEvent>;

        /// Origin allowed to update rate-limit rules and whitelists.
        type GovernanceOrigin: EnsureOrigin<Self::RuntimeOrigin>;

        /// Identifier selecting an independent set of rules, quotas and whitelist filters.
        type RateLimiterId: Parameter + Member + Copy + TypeInfo;

        /// Maximum number of `KeyFilter`s in one limiter's whitelist.
        #[pallet::constant]
        type MaxWhitelistFilterCount: Get<u32>;

        /// Time source for `Period::Seconds`.
        type UnixTime: UnixTime;

        /// Block number source for `Period::Blocks`.
        type BlockNumberProvider: BlockNumberProvider<BlockNumber = BlockNumberFor<Self>>;

        /// Weight information for the dispatchables.
        type WeightInfo: WeightInfo;
    }
115
    // Variant order is the on-chain error index; do not reorder.
    #[pallet::error]
    pub enum Error<T> {
        /// A `PerPeriod`/`TokenBucket` rule has a zero period length or a zero quota field.
        InvalidRateLimitRule,
        /// The filter is already present in the whitelist.
        FilterExisted,
        /// The filter is not present in the whitelist.
        FilterNotExisted,
        /// The whitelist would exceed `MaxWhitelistFilterCount` filters.
        MaxFilterExceeded,
    }
128
    // Variant order is the on-chain event index; do not reorder.
    #[pallet::event]
    #[pallet::generate_deposit(pub(crate) fn deposit_event)]
    pub enum Event<T: Config> {
        /// The rule for `encoded_key` was set (`Some`) or removed (`None`).
        RateLimitRuleUpdated {
            rate_limiter_id: T::RateLimiterId,
            encoded_key: Vec<u8>,
            update: Option<RateLimitRule>,
        },
        /// A filter was added to the limiter's whitelist.
        WhitelistFilterAdded { rate_limiter_id: T::RateLimiterId },
        /// A filter was removed from the limiter's whitelist.
        WhitelistFilterRemoved { rate_limiter_id: T::RateLimiterId },
        /// The limiter's whitelist was replaced as a whole.
        WhitelistFilterReset { rate_limiter_id: T::RateLimiterId },
    }
145
    /// Rate-limit rule for each `(rate_limiter_id, encoded_key)`.
    ///
    /// A missing entry means the key is not limited.
    #[pallet::storage]
    #[pallet::getter(fn rate_limit_rules)]
    pub type RateLimitRules<T: Config> =
        StorageDoubleMap<_, Twox64Concat, T::RateLimiterId, Blake2_128Concat, Vec<u8>, RateLimitRule, OptionQuery>;
153
    /// Quota state for each `(rate_limiter_id, encoded_key)`: `(last_updated, remainer_quota)`.
    ///
    /// `last_updated` is a block number or unix seconds, depending on the rule's `Period`.
    /// Cleared whenever the key's rule is updated.
    #[pallet::storage]
    #[pallet::getter(fn rate_limit_quota)]
    pub type RateLimitQuota<T: Config> =
        StorageDoubleMap<_, Twox64Concat, T::RateLimiterId, Blake2_128Concat, Vec<u8>, (u64, u128), ValueQuery>;
162
    /// Whitelist filters per limiter; keys matching any filter bypass rate limiting
    /// (as reported by `is_whitelist`).
    #[pallet::storage]
    #[pallet::getter(fn limit_whitelist)]
    pub type LimitWhitelist<T: Config> =
        StorageMap<_, Twox64Concat, T::RateLimiterId, OrderedSet<KeyFilter, T::MaxWhitelistFilterCount>, ValueQuery>;
170
    /// The rate-limit pallet.
    // Storage uses unbounded `Vec<u8>` keys, hence no storage info.
    #[pallet::pallet]
    #[pallet::without_storage_info]
    pub struct Pallet<T>(_);
174
    // No block hooks: quotas are refreshed lazily when a key is checked.
    #[pallet::hooks]
    impl<T: Config> Hooks<BlockNumberFor<T>> for Pallet<T> {}
177
178 #[pallet::call]
179 impl<T: Config> Pallet<T> {
180 #[pallet::call_index(0)]
190 #[pallet::weight(T::WeightInfo::update_rate_limit_rule())]
191 #[transactional]
192 pub fn update_rate_limit_rule(
193 origin: OriginFor<T>,
194 rate_limiter_id: T::RateLimiterId,
195 encoded_key: Vec<u8>,
196 update: Option<RateLimitRule>,
197 ) -> DispatchResult {
198 T::GovernanceOrigin::ensure_origin(origin)?;
199
200 RateLimitRules::<T>::try_mutate_exists(rate_limiter_id, &encoded_key, |maybe_limit| -> DispatchResult {
201 *maybe_limit = update.clone();
202
203 if let Some(rule) = maybe_limit {
204 match rule {
205 RateLimitRule::PerPeriod { period, quota } => {
206 match period {
207 Period::Blocks(blocks_count) => {
208 ensure!(!blocks_count.is_zero(), Error::<T>::InvalidRateLimitRule);
209 }
210 Period::Seconds(secs_count) => {
211 ensure!(!secs_count.is_zero(), Error::<T>::InvalidRateLimitRule);
212 }
213 }
214
215 ensure!(!quota.is_zero(), Error::<T>::InvalidRateLimitRule);
216 }
217 RateLimitRule::TokenBucket {
218 period,
219 quota_increment,
220 max_quota,
221 } => {
222 match period {
223 Period::Blocks(blocks_count) => {
224 ensure!(!blocks_count.is_zero(), Error::<T>::InvalidRateLimitRule);
225 }
226 Period::Seconds(secs_count) => {
227 ensure!(!secs_count.is_zero(), Error::<T>::InvalidRateLimitRule);
228 }
229 }
230
231 ensure!(
232 !quota_increment.is_zero() && !max_quota.is_zero(),
233 Error::<T>::InvalidRateLimitRule
234 );
235 }
236 RateLimitRule::Unlimited => {}
237 RateLimitRule::NotAllowed => {}
238 }
239 }
240
241 RateLimitQuota::<T>::remove(rate_limiter_id, &encoded_key);
243
244 Self::deposit_event(Event::RateLimitRuleUpdated {
245 rate_limiter_id,
246 encoded_key: encoded_key.clone(),
247 update,
248 });
249
250 Ok(())
251 })
252 }
253
254 #[pallet::call_index(1)]
262 #[pallet::weight(T::WeightInfo::add_whitelist())]
263 #[transactional]
264 pub fn add_whitelist(
265 origin: OriginFor<T>,
266 rate_limiter_id: T::RateLimiterId,
267 key_filter: KeyFilter,
268 ) -> DispatchResult {
269 T::GovernanceOrigin::ensure_origin(origin)?;
270
271 LimitWhitelist::<T>::try_mutate(rate_limiter_id, |whitelist| -> DispatchResult {
272 ensure!(!whitelist.contains(&key_filter), Error::<T>::FilterExisted);
273 let inserted = whitelist.insert(key_filter);
274 ensure!(inserted, Error::<T>::MaxFilterExceeded);
275
276 Self::deposit_event(Event::WhitelistFilterAdded { rate_limiter_id });
277 Ok(())
278 })
279 }
280
281 #[pallet::call_index(2)]
289 #[pallet::weight(T::WeightInfo::remove_whitelist())]
290 #[transactional]
291 pub fn remove_whitelist(
292 origin: OriginFor<T>,
293 rate_limiter_id: T::RateLimiterId,
294 key_filter: KeyFilter,
295 ) -> DispatchResult {
296 T::GovernanceOrigin::ensure_origin(origin)?;
297
298 LimitWhitelist::<T>::try_mutate(rate_limiter_id, |whitelist| -> DispatchResult {
299 ensure!(whitelist.contains(&key_filter), Error::<T>::FilterNotExisted);
300 whitelist.remove(&key_filter);
301
302 Self::deposit_event(Event::WhitelistFilterRemoved { rate_limiter_id });
303 Ok(())
304 })
305 }
306
307 #[pallet::call_index(3)]
315 #[pallet::weight(T::WeightInfo::reset_whitelist())]
316 #[transactional]
317 pub fn reset_whitelist(
318 origin: OriginFor<T>,
319 rate_limiter_id: T::RateLimiterId,
320 new_list: Vec<KeyFilter>,
321 ) -> DispatchResult {
322 T::GovernanceOrigin::ensure_origin(origin)?;
323
324 let whitelist: BoundedVec<KeyFilter, T::MaxWhitelistFilterCount> =
325 BoundedVec::try_from(new_list).map_err(|_| Error::<T>::MaxFilterExceeded)?;
326 let ordered_set: OrderedSet<KeyFilter, T::MaxWhitelistFilterCount> = whitelist.into();
327 LimitWhitelist::<T>::insert(rate_limiter_id, ordered_set);
328
329 Self::deposit_event(Event::WhitelistFilterReset { rate_limiter_id });
330 Ok(())
331 }
332 }
333
334 impl<T: Config> Pallet<T> {
335 pub fn access_remainer_quota_after_update(
338 rate_limit_rule: RateLimitRule,
339 limiter_id: &T::RateLimiterId,
340 encoded_key: &Vec<u8>,
341 ) -> u128 {
342 RateLimitQuota::<T>::mutate(limiter_id, encoded_key, |(last_updated, remainer_quota)| -> u128 {
343 match rate_limit_rule {
344 RateLimitRule::PerPeriod { period, quota } => {
345 let (now, count): (u64, u64) = match period {
346 Period::Blocks(blocks_count) => (
347 T::BlockNumberProvider::current_block_number().saturated_into(),
348 blocks_count,
349 ),
350 Period::Seconds(secs_count) => (T::UnixTime::now().as_secs(), secs_count),
351 };
352
353 let interval: u64 = now.saturating_sub(*last_updated);
354 if interval >= count {
355 *last_updated = now;
356 *remainer_quota = quota;
357 }
358 }
359
360 RateLimitRule::TokenBucket {
361 period,
362 quota_increment,
363 max_quota,
364 } => {
365 let (now, count): (u64, u64) = match period {
366 Period::Blocks(blocks_count) => (
367 T::BlockNumberProvider::current_block_number().saturated_into(),
368 blocks_count,
369 ),
370 Period::Seconds(secs_count) => (T::UnixTime::now().as_secs(), secs_count),
371 };
372
373 let interval: u64 = now.saturating_sub(*last_updated);
374 if !count.is_zero() && interval >= count {
375 let inc_times: u128 = interval
376 .checked_div(count)
377 .expect("already ensure count is not zero; qed")
378 .saturated_into();
379
380 *last_updated = now;
381 *remainer_quota = quota_increment
382 .saturating_mul(inc_times)
383 .saturating_add(*remainer_quota)
384 .min(max_quota);
385 }
386 }
387
388 RateLimitRule::Unlimited | RateLimitRule::NotAllowed => {}
389 }
390
391 *remainer_quota
392 })
393 }
394 }
395
396 impl<T: Config> RateLimiter for Pallet<T> {
397 type RateLimiterId = T::RateLimiterId;
398
399 fn is_whitelist(limiter_id: Self::RateLimiterId, key: impl Encode) -> bool {
400 let encode_key: Vec<u8> = key.encode();
401
402 for key_filter in LimitWhitelist::<T>::get(limiter_id).0 {
403 match key_filter {
404 KeyFilter::Match(bounded_vec) => {
405 if encode_key == bounded_vec.into_inner() {
406 return true;
407 }
408 }
409 KeyFilter::StartsWith(prefix) => {
410 if encode_key.starts_with(&prefix) {
411 return true;
412 }
413 }
414 KeyFilter::EndsWith(postfix) => {
415 if encode_key.ends_with(&postfix) {
416 return true;
417 }
418 }
419 }
420 }
421
422 false
423 }
424
425 fn can_consume(limiter_id: Self::RateLimiterId, key: impl Encode, value: u128) -> Result<(), RateLimiterError> {
426 let encoded_key: Vec<u8> = key.encode();
427
428 let allowed = match RateLimitRules::<T>::get(limiter_id, &encoded_key) {
429 Some(rate_limit_rule @ RateLimitRule::PerPeriod { .. })
430 | Some(rate_limit_rule @ RateLimitRule::TokenBucket { .. }) => {
431 let remainer_quota =
432 Self::access_remainer_quota_after_update(rate_limit_rule, &limiter_id, &encoded_key);
433
434 value <= remainer_quota
435 }
436 Some(RateLimitRule::Unlimited) => true,
437 Some(RateLimitRule::NotAllowed) => {
438 false
440 }
441 None => {
442 true
444 }
445 };
446
447 ensure!(allowed, RateLimiterError::ExceedLimit);
448
449 Ok(())
450 }
451
452 fn consume(limiter_id: Self::RateLimiterId, key: impl Encode, value: u128) {
453 let encoded_key: Vec<u8> = key.encode();
454
455 match RateLimitRules::<T>::get(limiter_id, &encoded_key) {
456 Some(RateLimitRule::PerPeriod { .. }) | Some(RateLimitRule::TokenBucket { .. }) => {
457 RateLimitQuota::<T>::mutate(limiter_id, &encoded_key, |(_, remainer_quota)| {
459 *remainer_quota = (*remainer_quota).saturating_sub(value);
460 });
461 }
462 _ => {}
463 };
464 }
465 }
466}