kora_lib/usage_limit/
usage_tracker.rs

1use std::{collections::HashSet, sync::Arc};
2
3use deadpool_redis::Runtime;
4use redis::AsyncCommands;
5use solana_sdk::{pubkey::Pubkey, transaction::VersionedTransaction};
6use tokio::sync::OnceCell;
7
8use super::usage_store::{RedisUsageStore, UsageStore};
9use crate::{error::KoraError, sanitize_error, state::get_signer_pool};
10
11#[cfg(not(test))]
12use crate::state::get_config;
13
14#[cfg(test)]
15use crate::tests::config_mock::mock_state::get_config;
16
/// Prefix of every per-wallet usage counter key; full keys look like
/// `kora:usage_limit:<wallet pubkey>`.
const USAGE_CACHE_KEY: &str = "kora:usage_limit";

/// Global usage limiter instance.
///
/// Set once by `UsageTracker::init_usage_limiter`. Holds `None` when usage limiting is
/// disabled or when it is enabled without a `cache_url`.
static USAGE_LIMITER: OnceCell<Option<UsageTracker>> = OnceCell::const_new();
21
/// Per-wallet transaction counter that enforces a configurable usage cap.
pub struct UsageTracker {
    /// Backing store holding one counter per wallet (read with `get`, bumped with `increment`).
    store: Arc<dyn UsageStore>,
    /// Maximum transactions allowed per wallet; `0` means unlimited.
    max_transactions: u64,
    /// Kora-owned signer keys; these are skipped when picking the transaction sender.
    kora_signers: HashSet<Pubkey>,
    /// When `true`, store errors allow the transaction instead of rejecting it.
    fallback_if_unavailable: bool,
}
28
29impl UsageTracker {
30    pub fn new(
31        store: Arc<dyn UsageStore>,
32        max_transactions: u64,
33        kora_signers: HashSet<Pubkey>,
34        fallback_if_unavailable: bool,
35    ) -> Self {
36        Self { store, max_transactions, kora_signers, fallback_if_unavailable }
37    }
38
39    fn get_usage_key(&self, wallet: &Pubkey) -> String {
40        format!("{USAGE_CACHE_KEY}:{wallet}")
41    }
42
43    /// Handle store errors according to fallback configuration
44    fn handle_store_error(
45        &self,
46        error: KoraError,
47        operation: &str,
48        wallet: &Pubkey,
49    ) -> Result<(), KoraError> {
50        log::error!("Failed to {operation} for {wallet}: {error}");
51        if self.fallback_if_unavailable {
52            log::error!("Fallback enabled - allowing transaction due to store error");
53            Ok(()) // Allow transaction when fallback is enabled
54        } else {
55            Err(KoraError::InternalServerError(format!(
56                "Usage limit store unavailable and fallback disabled: {error}"
57            )))
58        }
59    }
60
61    async fn check_usage_limit(&self, wallet: &Pubkey) -> Result<(), KoraError> {
62        // Skip check if unlimited (0)
63        if self.max_transactions == 0 {
64            return Ok(());
65        }
66
67        // Check current count first, then increment only if allowed
68        let key = self.get_usage_key(wallet);
69
70        // Handle store.get() errors using helper
71        let current_count = match self.store.get(&key).await {
72            Ok(count) => count,
73            Err(e) => {
74                return self.handle_store_error(e, "get usage count", wallet);
75            }
76        };
77
78        if current_count >= self.max_transactions as u32 {
79            return Err(KoraError::UsageLimitExceeded(format!(
80                "Wallet {wallet} exceeded limit: {}/{}",
81                current_count + 1,
82                self.max_transactions
83            )));
84        }
85
86        // Handle store.increment() errors using helper
87        let new_count = match self.store.increment(&key).await {
88            Ok(count) => count,
89            Err(e) => {
90                return self.handle_store_error(e, "increment usage count", wallet);
91            }
92        };
93
94        log::debug!("Usage check passed for {wallet}: {new_count}/{}", self.max_transactions);
95
96        Ok(())
97    }
98
99    fn get_usage_limiter() -> Result<Option<&'static UsageTracker>, KoraError> {
100        match USAGE_LIMITER.get() {
101            Some(limiter) => Ok(limiter.as_ref()),
102            None => {
103                Err(KoraError::InternalServerError("Usage limiter not initialized".to_string()))
104            }
105        }
106    }
107
108    /// Extract sender from transaction
109    fn extract_transaction_sender(
110        &self,
111        transaction: &VersionedTransaction,
112    ) -> Result<Option<Pubkey>, KoraError> {
113        let account_keys = transaction.message.static_account_keys();
114
115        if account_keys.is_empty() {
116            return Err(KoraError::InvalidTransaction(
117                "Transaction has no account keys".to_string(),
118            ));
119        }
120
121        let signers = account_keys
122            .iter()
123            .take(transaction.message.header().num_required_signatures as usize)
124            .collect::<Vec<_>>();
125
126        for signer in &signers {
127            if !self.kora_signers.contains(signer) {
128                return Ok(Some(**signer));
129            }
130        }
131
132        log::debug!(
133            "No user signers found when extracting transaction sender for usage limit: {signers:?}",
134        );
135
136        Ok(None)
137    }
138
139    /// Initialize the global usage limiter
140    pub async fn init_usage_limiter() -> Result<(), KoraError> {
141        let config = get_config()?;
142
143        if !config.kora.usage_limit.enabled {
144            log::info!("Usage limiting disabled");
145            USAGE_LIMITER.set(None).map_err(|_| {
146                KoraError::InternalServerError("Usage limiter already initialized".to_string())
147            })?;
148            return Ok(());
149        }
150
151        let usage_limiter = if let Some(cache_url) = &config.kora.usage_limit.cache_url {
152            let cfg = deadpool_redis::Config::from_url(cache_url);
153            let pool = cfg.create_pool(Some(Runtime::Tokio1)).map_err(|e| {
154                KoraError::InternalServerError(format!(
155                    "Failed to create Redis pool: {}",
156                    sanitize_error!(e)
157                ))
158            })?;
159
160            // Test Redis connection
161            let mut conn = pool.get().await.map_err(|e| {
162                KoraError::InternalServerError(format!(
163                    "Failed to connect to Redis: {}",
164                    sanitize_error!(e)
165                ))
166            })?;
167
168            // Simple connection test
169            let _: Option<String> = conn.get("__usage_limiter_test__").await.map_err(|e| {
170                KoraError::InternalServerError(format!(
171                    "Redis connection test failed: {}",
172                    sanitize_error!(e)
173                ))
174            })?;
175
176            log::info!(
177                "Usage limiter initialized with max {} transactions",
178                config.kora.usage_limit.max_transactions
179            );
180
181            let kora_signers = get_signer_pool()?
182                .get_signers_info()
183                .iter()
184                .filter_map(|info| info.public_key.parse().ok())
185                .collect();
186
187            let store = Arc::new(RedisUsageStore::new(pool));
188            Some(UsageTracker::new(
189                store,
190                config.kora.usage_limit.max_transactions,
191                kora_signers,
192                config.kora.usage_limit.fallback_if_unavailable,
193            ))
194        } else {
195            log::info!("Usage limiting enabled but no cache_url configured - disabled");
196            None
197        };
198
199        USAGE_LIMITER.set(usage_limiter).map_err(|_| {
200            KoraError::InternalServerError("Usage limiter already initialized".to_string())
201        })?;
202
203        Ok(())
204    }
205
206    /// Check usage limit for transaction sender
207    pub async fn check_transaction_usage_limit(
208        transaction: &VersionedTransaction,
209    ) -> Result<(), KoraError> {
210        let config = get_config()?;
211
212        if let Some(limiter) = Self::get_usage_limiter()? {
213            let sender = limiter.extract_transaction_sender(transaction)?;
214            if let Some(sender) = sender {
215                limiter.check_usage_limit(&sender).await?;
216            }
217            Ok(())
218        } else if config.kora.usage_limit.enabled
219            && !config.kora.usage_limit.fallback_if_unavailable
220        {
221            // Usage limiting enabled but limiter unavailable and fallback disabled
222            Err(KoraError::InternalServerError(
223                "Usage limiter unavailable and fallback disabled".to_string(),
224            ))
225        } else {
226            // Usage limiting disabled or fallback allowed
227            Ok(())
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        tests::{config_mock::ConfigMockBuilder, transaction_mock::create_mock_transaction},
        usage_limit::{usage_store::ErrorUsageStore, InMemoryUsageStore},
    };

    /// Exercises the real key builder. The previous version only compared a `format!`
    /// with an identical `format!`.
    #[tokio::test]
    async fn test_get_usage_key_format() {
        let tracker =
            UsageTracker::new(Arc::new(InMemoryUsageStore::new()), 1, HashSet::new(), true);
        let wallet = Pubkey::new_unique();

        assert_eq!(tracker.get_usage_key(&wallet), format!("kora:usage_limit:{wallet}"));
    }

    #[tokio::test]
    async fn test_usage_limit_enforcement() {
        let store = Arc::new(InMemoryUsageStore::new());
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, true);

        let wallet = Pubkey::new_unique();

        // First transaction should succeed
        assert!(tracker.check_usage_limit(&wallet).await.is_ok());

        // Second transaction should succeed (at limit)
        assert!(tracker.check_usage_limit(&wallet).await.is_ok());

        // Third transaction should fail (over limit)
        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeded limit"));
    }

    #[tokio::test]
    async fn test_independent_wallet_limits() {
        let store = Arc::new(InMemoryUsageStore::new());
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, true);

        let wallet1 = Pubkey::new_unique();
        let wallet2 = Pubkey::new_unique();

        // Use up wallet1's limit
        assert!(tracker.check_usage_limit(&wallet1).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet1).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet1).await.is_err());

        // Wallet2 should still be able to make transactions
        assert!(tracker.check_usage_limit(&wallet2).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet2).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet2).await.is_err());
    }

    #[tokio::test]
    async fn test_unlimited_usage() {
        let store = Arc::new(InMemoryUsageStore::new());
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 0, kora_signers, true); // 0 = unlimited

        let wallet = Pubkey::new_unique();

        // Should allow many transactions when unlimited
        for _ in 0..10 {
            assert!(tracker.check_usage_limit(&wallet).await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_usage_limiter_disabled_fallback() {
        // Test that when usage limiting is disabled, transactions are allowed
        let _m = ConfigMockBuilder::new().with_usage_limit_enabled(false).build_and_setup();

        // Initialize the usage limiter - it should set to None when disabled
        let _ = UsageTracker::init_usage_limiter().await;

        let result = UsageTracker::check_transaction_usage_limit(&create_mock_transaction()).await;
        assert!(result.is_ok(), "Test failed with error: {}", result.as_ref().unwrap_err());
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_allowed() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(true)
            .build_and_setup();

        // Initialize with no cache_url - should set limiter to None
        let _ = UsageTracker::init_usage_limiter().await;

        let result = UsageTracker::check_transaction_usage_limit(&create_mock_transaction()).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_denied() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(false)
            .build_and_setup();

        // Initialize with no cache_url - should set limiter to None
        let _ = UsageTracker::init_usage_limiter().await;

        let result = UsageTracker::check_transaction_usage_limit(&create_mock_transaction()).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Usage limiter unavailable and fallback disabled"));
    }

    #[tokio::test]
    async fn test_usage_limit_store_get_error_fallback_enabled() {
        let store = Arc::new(ErrorUsageStore::new(true, false)); // get() will error
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, true); // fallback enabled

        let wallet = Pubkey::new_unique();

        // Should succeed because fallback is enabled
        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limit_store_get_error_fallback_disabled() {
        let store = Arc::new(ErrorUsageStore::new(true, false)); // get() will error
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, false); // fallback disabled

        let wallet = Pubkey::new_unique();

        // Should fail because fallback is disabled
        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Usage limit store unavailable and fallback disabled"));
    }

    #[tokio::test]
    async fn test_usage_limit_store_increment_error_fallback_enabled() {
        let store = Arc::new(ErrorUsageStore::new(false, true)); // increment() will error
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, true); // fallback enabled

        let wallet = Pubkey::new_unique();

        // Should succeed because fallback is enabled (get() succeeds, increment() fails but fallback allows)
        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limit_store_increment_error_fallback_disabled() {
        let store = Arc::new(ErrorUsageStore::new(false, true)); // increment() will error
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, false); // fallback disabled

        let wallet = Pubkey::new_unique();

        // Should fail because fallback is disabled
        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Usage limit store unavailable and fallback disabled"));
    }
}