1#![cfg_attr(not(feature = "std"), no_std)]
16#![allow(clippy::unused_unit)]
17
18use frame_support::{pallet_prelude::*, traits::UnixTime, transactional, BoundedVec};
19use frame_system::pallet_prelude::*;
20use orml_traits::{RateLimiter, RateLimiterError};
21use orml_utilities::OrderedSet;
22use parity_scale_codec::MaxEncodedLen;
23use scale_info::TypeInfo;
24use sp_runtime::traits::{BlockNumberProvider, SaturatedConversion, Zero};
25use sp_std::{prelude::*, vec::Vec};
26
27pub use module::*;
28pub use weights::WeightInfo;
29
30mod mock;
31mod tests;
32pub mod weights;
33
34#[frame_support::pallet]
35pub mod module {
36 use super::*;
37
38 #[derive(PartialEq, Eq, Clone, Encode, Decode, RuntimeDebug, TypeInfo, DecodeWithMemTracking)]
40 pub enum Period {
41 Blocks(u64),
42 Seconds(u64),
43 }
44
45 #[derive(PartialEq, Eq, Clone, Encode, Decode, RuntimeDebug, TypeInfo, DecodeWithMemTracking)]
47 pub enum RateLimitRule {
48 PerPeriod { period: Period, quota: u128 },
52 TokenBucket {
57 period: Period,
58 quota_increment: u128,
59 max_quota: u128,
60 },
61 Unlimited,
63 NotAllowed,
65 }
66
67 pub const MAX_FILTER_KEY_LENGTH: u32 = 256;
69
70 #[derive(
72 PartialOrd,
73 Ord,
74 PartialEq,
75 Eq,
76 Clone,
77 Encode,
78 Decode,
79 RuntimeDebug,
80 TypeInfo,
81 MaxEncodedLen,
82 DecodeWithMemTracking,
83 )]
84 pub enum KeyFilter {
85 Match(BoundedVec<u8, ConstU32<MAX_FILTER_KEY_LENGTH>>),
87 StartsWith(BoundedVec<u8, ConstU32<MAX_FILTER_KEY_LENGTH>>),
89 EndsWith(BoundedVec<u8, ConstU32<MAX_FILTER_KEY_LENGTH>>),
91 }
92
93 #[pallet::config]
94 pub trait Config: frame_system::Config {
95 type GovernanceOrigin: EnsureOrigin<Self::RuntimeOrigin>;
97
98 type RateLimiterId: Parameter + Member + Copy + TypeInfo;
99
100 #[pallet::constant]
102 type MaxWhitelistFilterCount: Get<u32>;
103
104 type UnixTime: UnixTime;
106
107 type BlockNumberProvider: BlockNumberProvider<BlockNumber = BlockNumberFor<Self>>;
109
110 type WeightInfo: WeightInfo;
112 }
113
114 #[pallet::error]
115 pub enum Error<T> {
116 InvalidRateLimitRule,
118 FilterExisted,
120 FilterNotExisted,
122 MaxFilterExceeded,
125 }
126
127 #[pallet::event]
128 #[pallet::generate_deposit(pub(crate) fn deposit_event)]
129 pub enum Event<T: Config> {
130 RateLimitRuleUpdated {
132 rate_limiter_id: T::RateLimiterId,
133 encoded_key: Vec<u8>,
134 update: Option<RateLimitRule>,
135 },
136 WhitelistFilterAdded { rate_limiter_id: T::RateLimiterId },
138 WhitelistFilterRemoved { rate_limiter_id: T::RateLimiterId },
140 WhitelistFilterReset { rate_limiter_id: T::RateLimiterId },
142 }
143
144 #[pallet::storage]
148 #[pallet::getter(fn rate_limit_rules)]
149 pub type RateLimitRules<T: Config> =
150 StorageDoubleMap<_, Twox64Concat, T::RateLimiterId, Blake2_128Concat, Vec<u8>, RateLimitRule, OptionQuery>;
151
152 #[pallet::storage]
157 #[pallet::getter(fn rate_limit_quota)]
158 pub type RateLimitQuota<T: Config> =
159 StorageDoubleMap<_, Twox64Concat, T::RateLimiterId, Blake2_128Concat, Vec<u8>, (u64, u128), ValueQuery>;
160
161 #[pallet::storage]
165 #[pallet::getter(fn limit_whitelist)]
166 pub type LimitWhitelist<T: Config> =
167 StorageMap<_, Twox64Concat, T::RateLimiterId, OrderedSet<KeyFilter, T::MaxWhitelistFilterCount>, ValueQuery>;
168
169 #[pallet::pallet]
170 #[pallet::without_storage_info]
171 pub struct Pallet<T>(_);
172
173 #[pallet::hooks]
174 impl<T: Config> Hooks<BlockNumberFor<T>> for Pallet<T> {}
175
176 #[pallet::call]
177 impl<T: Config> Pallet<T> {
178 #[pallet::call_index(0)]
188 #[pallet::weight(T::WeightInfo::update_rate_limit_rule())]
189 #[transactional]
190 pub fn update_rate_limit_rule(
191 origin: OriginFor<T>,
192 rate_limiter_id: T::RateLimiterId,
193 encoded_key: Vec<u8>,
194 update: Option<RateLimitRule>,
195 ) -> DispatchResult {
196 T::GovernanceOrigin::ensure_origin(origin)?;
197
198 RateLimitRules::<T>::try_mutate_exists(rate_limiter_id, &encoded_key, |maybe_limit| -> DispatchResult {
199 *maybe_limit = update.clone();
200
201 if let Some(rule) = maybe_limit {
202 match rule {
203 RateLimitRule::PerPeriod { period, quota } => {
204 match period {
205 Period::Blocks(blocks_count) => {
206 ensure!(!blocks_count.is_zero(), Error::<T>::InvalidRateLimitRule);
207 }
208 Period::Seconds(secs_count) => {
209 ensure!(!secs_count.is_zero(), Error::<T>::InvalidRateLimitRule);
210 }
211 }
212
213 ensure!(!quota.is_zero(), Error::<T>::InvalidRateLimitRule);
214 }
215 RateLimitRule::TokenBucket {
216 period,
217 quota_increment,
218 max_quota,
219 } => {
220 match period {
221 Period::Blocks(blocks_count) => {
222 ensure!(!blocks_count.is_zero(), Error::<T>::InvalidRateLimitRule);
223 }
224 Period::Seconds(secs_count) => {
225 ensure!(!secs_count.is_zero(), Error::<T>::InvalidRateLimitRule);
226 }
227 }
228
229 ensure!(
230 !quota_increment.is_zero() && !max_quota.is_zero(),
231 Error::<T>::InvalidRateLimitRule
232 );
233 }
234 RateLimitRule::Unlimited => {}
235 RateLimitRule::NotAllowed => {}
236 }
237 }
238
239 RateLimitQuota::<T>::remove(rate_limiter_id, &encoded_key);
241
242 Self::deposit_event(Event::RateLimitRuleUpdated {
243 rate_limiter_id,
244 encoded_key: encoded_key.clone(),
245 update,
246 });
247
248 Ok(())
249 })
250 }
251
252 #[pallet::call_index(1)]
260 #[pallet::weight(T::WeightInfo::add_whitelist())]
261 #[transactional]
262 pub fn add_whitelist(
263 origin: OriginFor<T>,
264 rate_limiter_id: T::RateLimiterId,
265 key_filter: KeyFilter,
266 ) -> DispatchResult {
267 T::GovernanceOrigin::ensure_origin(origin)?;
268
269 LimitWhitelist::<T>::try_mutate(rate_limiter_id, |whitelist| -> DispatchResult {
270 ensure!(!whitelist.contains(&key_filter), Error::<T>::FilterExisted);
271 let inserted = whitelist.insert(key_filter);
272 ensure!(inserted, Error::<T>::MaxFilterExceeded);
273
274 Self::deposit_event(Event::WhitelistFilterAdded { rate_limiter_id });
275 Ok(())
276 })
277 }
278
279 #[pallet::call_index(2)]
287 #[pallet::weight(T::WeightInfo::remove_whitelist())]
288 #[transactional]
289 pub fn remove_whitelist(
290 origin: OriginFor<T>,
291 rate_limiter_id: T::RateLimiterId,
292 key_filter: KeyFilter,
293 ) -> DispatchResult {
294 T::GovernanceOrigin::ensure_origin(origin)?;
295
296 LimitWhitelist::<T>::try_mutate(rate_limiter_id, |whitelist| -> DispatchResult {
297 ensure!(whitelist.contains(&key_filter), Error::<T>::FilterNotExisted);
298 whitelist.remove(&key_filter);
299
300 Self::deposit_event(Event::WhitelistFilterRemoved { rate_limiter_id });
301 Ok(())
302 })
303 }
304
305 #[pallet::call_index(3)]
313 #[pallet::weight(T::WeightInfo::reset_whitelist())]
314 #[transactional]
315 pub fn reset_whitelist(
316 origin: OriginFor<T>,
317 rate_limiter_id: T::RateLimiterId,
318 new_list: Vec<KeyFilter>,
319 ) -> DispatchResult {
320 T::GovernanceOrigin::ensure_origin(origin)?;
321
322 let whitelist: BoundedVec<KeyFilter, T::MaxWhitelistFilterCount> =
323 BoundedVec::try_from(new_list).map_err(|_| Error::<T>::MaxFilterExceeded)?;
324 let ordered_set: OrderedSet<KeyFilter, T::MaxWhitelistFilterCount> = whitelist.into();
325 LimitWhitelist::<T>::insert(rate_limiter_id, ordered_set);
326
327 Self::deposit_event(Event::WhitelistFilterReset { rate_limiter_id });
328 Ok(())
329 }
330 }
331
332 impl<T: Config> Pallet<T> {
333 pub fn access_remainer_quota_after_update(
336 rate_limit_rule: RateLimitRule,
337 limiter_id: &T::RateLimiterId,
338 encoded_key: &Vec<u8>,
339 ) -> u128 {
340 RateLimitQuota::<T>::mutate(limiter_id, encoded_key, |(last_updated, remainer_quota)| -> u128 {
341 match rate_limit_rule {
342 RateLimitRule::PerPeriod { period, quota } => {
343 let (now, count): (u64, u64) = match period {
344 Period::Blocks(blocks_count) => (
345 T::BlockNumberProvider::current_block_number().saturated_into(),
346 blocks_count,
347 ),
348 Period::Seconds(secs_count) => (T::UnixTime::now().as_secs(), secs_count),
349 };
350
351 let interval: u64 = now.saturating_sub(*last_updated);
352 if interval >= count {
353 *last_updated = now;
354 *remainer_quota = quota;
355 }
356 }
357
358 RateLimitRule::TokenBucket {
359 period,
360 quota_increment,
361 max_quota,
362 } => {
363 let (now, count): (u64, u64) = match period {
364 Period::Blocks(blocks_count) => (
365 T::BlockNumberProvider::current_block_number().saturated_into(),
366 blocks_count,
367 ),
368 Period::Seconds(secs_count) => (T::UnixTime::now().as_secs(), secs_count),
369 };
370
371 let interval: u64 = now.saturating_sub(*last_updated);
372 if !count.is_zero() && interval >= count {
373 let inc_times: u128 = interval
374 .checked_div(count)
375 .expect("already ensure count is not zero; qed")
376 .saturated_into();
377
378 *last_updated = now;
379 *remainer_quota = quota_increment
380 .saturating_mul(inc_times)
381 .saturating_add(*remainer_quota)
382 .min(max_quota);
383 }
384 }
385
386 RateLimitRule::Unlimited | RateLimitRule::NotAllowed => {}
387 }
388
389 *remainer_quota
390 })
391 }
392 }
393
394 impl<T: Config> RateLimiter for Pallet<T> {
395 type RateLimiterId = T::RateLimiterId;
396
397 fn is_whitelist(limiter_id: Self::RateLimiterId, key: impl Encode) -> bool {
398 let encode_key: Vec<u8> = key.encode();
399
400 for key_filter in LimitWhitelist::<T>::get(limiter_id).0 {
401 match key_filter {
402 KeyFilter::Match(bounded_vec) => {
403 if encode_key == bounded_vec.into_inner() {
404 return true;
405 }
406 }
407 KeyFilter::StartsWith(prefix) => {
408 if encode_key.starts_with(&prefix) {
409 return true;
410 }
411 }
412 KeyFilter::EndsWith(postfix) => {
413 if encode_key.ends_with(&postfix) {
414 return true;
415 }
416 }
417 }
418 }
419
420 false
421 }
422
423 fn can_consume(limiter_id: Self::RateLimiterId, key: impl Encode, value: u128) -> Result<(), RateLimiterError> {
424 let encoded_key: Vec<u8> = key.encode();
425
426 let allowed = match RateLimitRules::<T>::get(limiter_id, &encoded_key) {
427 Some(rate_limit_rule @ RateLimitRule::PerPeriod { .. })
428 | Some(rate_limit_rule @ RateLimitRule::TokenBucket { .. }) => {
429 let remainer_quota =
430 Self::access_remainer_quota_after_update(rate_limit_rule, &limiter_id, &encoded_key);
431
432 value <= remainer_quota
433 }
434 Some(RateLimitRule::Unlimited) => true,
435 Some(RateLimitRule::NotAllowed) => {
436 false
438 }
439 None => {
440 true
442 }
443 };
444
445 ensure!(allowed, RateLimiterError::ExceedLimit);
446
447 Ok(())
448 }
449
450 fn consume(limiter_id: Self::RateLimiterId, key: impl Encode, value: u128) {
451 let encoded_key: Vec<u8> = key.encode();
452
453 match RateLimitRules::<T>::get(limiter_id, &encoded_key) {
454 Some(RateLimitRule::PerPeriod { .. }) | Some(RateLimitRule::TokenBucket { .. }) => {
455 RateLimitQuota::<T>::mutate(limiter_id, &encoded_key, |(_, remainer_quota)| {
457 *remainer_quota = (*remainer_quota).saturating_sub(value);
458 });
459 }
460 _ => {}
461 };
462 }
463 }
464}