junobuild_shared/rate/
utils.rs

1use crate::rate::types::{RateConfig, RateTokenStore, RateTokens};
2use crate::types::state::Timestamp;
3use ic_cdk::api::time;
4use std::cmp::min;
5
6pub fn increment_and_assert_rate_store(
7    key: &String,
8    config: &Option<RateConfig>,
9    rate_tokens: &mut RateTokenStore,
10) -> Result<(), String> {
11    let config = match config {
12        Some(config) => config,
13        None => return Ok(()),
14    };
15
16    if let Some(tokens) = rate_tokens.get_mut(key) {
17        increment_and_assert_rate(config, tokens)?;
18    } else {
19        rate_tokens.insert(key.clone(), RateTokens::default());
20    }
21
22    Ok(())
23}
24
25pub fn increment_and_assert_rate(
26    config: &RateConfig,
27    tokens: &mut RateTokens,
28) -> Result<(), String> {
29    increment_and_assert_rate_at(config, tokens, time())
30}
31
32fn increment_and_assert_rate_at(
33    config: &RateConfig,
34    tokens: &mut RateTokens,
35    now: Timestamp,
36) -> Result<(), String> {
37    if config.time_per_token_ns == 0 {
38        return Err("Invalid rate configuration: time_per_token_ns cannot be zero.".to_string());
39    }
40
41    let elapsed = now.saturating_sub(tokens.updated_at);
42    let new_tokens = elapsed / config.time_per_token_ns;
43
44    if new_tokens > 0 {
45        // The number of tokens is capped otherwise tokens might accumulate
46        tokens.tokens = min(config.max_tokens, tokens.tokens.saturating_add(new_tokens));
47        tokens.updated_at += config.time_per_token_ns * new_tokens;
48    }
49
50    // deduct a token for the current call
51    if tokens.tokens > 0 {
52        tokens.tokens -= 1;
53        Ok(())
54    } else {
55        Err("Rate limit reached, try again later.".to_string())
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `RateConfig` from its two fields.
    fn cfg(time_per_token_ns: u64, max_tokens: u64) -> RateConfig {
        RateConfig {
            time_per_token_ns,
            max_tokens,
        }
    }

    #[test]
    fn refills_tokens_up_to_max_and_consumes_one() {
        let config = cfg(10, 5);

        let mut tokens = RateTokens {
            tokens: 0,
            updated_at: 0,
        };

        // now = 100 → new_tokens = (100 - 0) / 10 = 10 → capped to 5
        let result = increment_and_assert_rate_at(&config, &mut tokens, 100);

        assert!(result.is_ok());
        // 5 tokens refilled - 1 consumed
        assert_eq!(tokens.tokens, 4);
        // updated_at advanced by 10 * 10 = 100 (uses new_tokens, not capped)
        assert_eq!(tokens.updated_at, 100);
    }

    #[test]
    fn returns_error_when_no_tokens_available() {
        let config = cfg(1000, 1);

        // updated_at == now → no refill
        let mut tokens = RateTokens {
            tokens: 0,
            updated_at: 5000,
        };

        let result = increment_and_assert_rate_at(&config, &mut tokens, 5000);

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Rate limit reached, try again later.");
        assert_eq!(tokens.tokens, 0);
    }

    #[test]
    fn consumes_one_token_when_available_without_refill() {
        let config = cfg(1000, 10);

        let mut tokens = RateTokens {
            tokens: 3,
            updated_at: 10_000,
        };

        // now - updated_at < time_per_token_ns → no refill
        let result = increment_and_assert_rate_at(&config, &mut tokens, 10_999);

        assert!(result.is_ok());
        assert_eq!(tokens.tokens, 2);
        assert_eq!(tokens.updated_at, 10_000);
    }

    #[test]
    fn no_refill_when_not_enough_time_passed() {
        let config = cfg(100, 10);

        let mut tokens = RateTokens {
            tokens: 5,
            updated_at: 1000,
        };

        // Not enough time to generate a new token
        let result = increment_and_assert_rate_at(&config, &mut tokens, 1099);

        assert!(result.is_ok());
        assert_eq!(tokens.tokens, 4); // just consumed
        assert_eq!(tokens.updated_at, 1000);
    }

    #[test]
    fn partial_interval_is_carried_over_to_next_call() {
        let config = cfg(10, 5);

        let mut tokens = RateTokens {
            tokens: 0,
            updated_at: 0,
        };

        // now = 25 → 2 whole intervals credited (updated_at = 20), 5 ns left over
        assert!(increment_and_assert_rate_at(&config, &mut tokens, 25).is_ok());
        assert_eq!(tokens.tokens, 1);
        assert_eq!(tokens.updated_at, 20);

        // now = 29 → only 9 ns since updated_at → no refill, spend the last token
        assert!(increment_and_assert_rate_at(&config, &mut tokens, 29).is_ok());
        assert_eq!(tokens.tokens, 0);
        assert_eq!(tokens.updated_at, 20);

        // now = 30 → 10 ns since updated_at (including the carried-over 5 ns) → 1 new token
        assert!(increment_and_assert_rate_at(&config, &mut tokens, 30).is_ok());
        assert_eq!(tokens.tokens, 0);
        assert_eq!(tokens.updated_at, 30);
    }

    #[test]
    fn repeated_calls_drain_bucket_then_reject() {
        let config = cfg(1_000, 3);

        let mut tokens = RateTokens {
            tokens: 3,
            updated_at: 0,
        };

        // All calls happen within the same interval, so nothing is refilled.
        for expected_remaining in [2, 1, 0] {
            assert!(increment_and_assert_rate_at(&config, &mut tokens, 500).is_ok());
            assert_eq!(tokens.tokens, expected_remaining);
        }

        let result = increment_and_assert_rate_at(&config, &mut tokens, 500);
        assert_eq!(result.unwrap_err(), "Rate limit reached, try again later.");
        assert_eq!(tokens.tokens, 0);
        assert_eq!(tokens.updated_at, 0);
    }

    // ---------- Edge cases ----------

    #[test]
    fn max_tokens_zero_never_allows_call() {
        // Even with huge elapsed time, max_tokens = 0 means no usable tokens.
        let config = cfg(1, 0);

        let mut tokens = RateTokens {
            tokens: 0,
            updated_at: 0,
        };

        let result = increment_and_assert_rate_at(&config, &mut tokens, 1_000_000);

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Rate limit reached, try again later.");
        assert_eq!(tokens.tokens, 0);
        // updated_at advanced by the credited intervals, but that doesn't give us usable tokens
        assert_eq!(tokens.updated_at, 1_000_000);
    }

    #[test]
    fn huge_elapsed_time_saturates_at_max_tokens_and_consumes_one() {
        let config = cfg(1, 5);

        let mut tokens = RateTokens {
            tokens: 0,
            updated_at: 0,
        };

        // Very large now; new_tokens will be huge but capped by max_tokens.
        let now = u64::MAX;

        let result = increment_and_assert_rate_at(&config, &mut tokens, now);

        assert!(result.is_ok());
        // Saturated to 5, then consumed 1
        assert_eq!(tokens.tokens, 4);
        // updated_at advanced by new_tokens * time_per_token_ns = u64::MAX
        assert_eq!(tokens.updated_at, now);
    }

    #[test]
    fn clock_going_backwards_does_not_refill_or_panic() {
        let config = cfg(10, 5);

        // updated_at is in the "future" relative to now
        let mut tokens = RateTokens {
            tokens: 2,
            updated_at: 1_000,
        };

        let result = increment_and_assert_rate_at(&config, &mut tokens, 500);

        assert!(result.is_ok());
        // Elapsed time saturates to 0 → no refill, one token spent
        assert_eq!(tokens.tokens, 1);
        assert_eq!(tokens.updated_at, 1_000);
    }

    #[test]
    fn returns_error_when_time_per_token_is_zero() {
        let config = cfg(0, 10); // invalid configuration

        let mut tokens = RateTokens {
            tokens: 5,
            updated_at: 1000,
        };

        let result = increment_and_assert_rate_at(&config, &mut tokens, 2000);

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            "Invalid rate configuration: time_per_token_ns cannot be zero."
        );

        // Ensure no mutation happened
        assert_eq!(tokens.tokens, 5);
        assert_eq!(tokens.updated_at, 1000);
    }
}