// kora_lib/usage_limit/usage_tracker.rs

1use std::{collections::HashSet, sync::Arc, time::SystemTime};
2
3use super::{
4    limiter::{LimiterContext, LimiterResult},
5    rules::{InstructionRule, UsageRule},
6    usage_store::{InMemoryUsageStore, RedisUsageStore},
7    UsageStore,
8};
9use crate::{
10    cache::CacheUtil,
11    config::Config,
12    error::KoraError,
13    sanitize_error,
14    state::get_signer_pool,
15    token::token::TokenType,
16    transaction::{
17        ParsedSPLInstructionData, ParsedSPLInstructionType, VersionedTransactionResolved,
18    },
19};
20use deadpool_redis::Runtime;
21use redis::AsyncCommands;
22use solana_client::nonblocking::rpc_client::RpcClient;
23use solana_sdk::pubkey::Pubkey;
24use tokio::sync::OnceCell;
25
26#[cfg(not(test))]
27use crate::state::get_config;
28
29#[cfg(test)]
30use crate::tests::config_mock::mock_state::get_config;
31
32/// Global usage limiter instance
33static USAGE_LIMITER: OnceCell<Option<UsageTracker>> = OnceCell::const_new();
34
35pub struct UsageTracker {
36    enabled: bool,
37    store: Arc<dyn UsageStore>,
38    rules: Vec<UsageRule>,
39    instruction_rule_indices: Vec<usize>,
40    kora_signers: HashSet<Pubkey>,
41    fallback_if_unavailable: bool,
42}
43
44impl UsageTracker {
45    pub fn new(
46        enabled: bool,
47        store: Arc<dyn UsageStore>,
48        rules: Vec<UsageRule>,
49        kora_signers: HashSet<Pubkey>,
50        fallback_if_unavailable: bool,
51    ) -> Self {
52        // Pre-compute instruction rule indices at initialization
53        let instruction_rule_indices: Vec<usize> =
54            rules
55                .iter()
56                .enumerate()
57                .filter_map(|(idx, rule)| {
58                    if matches!(rule, UsageRule::Instruction(_)) {
59                        Some(idx)
60                    } else {
61                        None
62                    }
63                })
64                .collect();
65
66        Self {
67            enabled,
68            store,
69            rules,
70            instruction_rule_indices,
71            kora_signers,
72            fallback_if_unavailable,
73        }
74    }
75
76    fn get_usage_limiter() -> Result<Option<&'static UsageTracker>, KoraError> {
77        match USAGE_LIMITER.get() {
78            Some(limiter) => Ok(limiter.as_ref()),
79            None => {
80                Err(KoraError::InternalServerError("Usage limiter not initialized".to_string()))
81            }
82        }
83    }
84
85    fn is_enabled(&self) -> bool {
86        self.enabled && !self.rules.is_empty()
87    }
88
89    fn has_instruction_rules(&self) -> bool {
90        !self.instruction_rule_indices.is_empty()
91    }
92
93    async fn extract_user_from_payment_instruction(
94        &self,
95        transaction: &mut VersionedTransactionResolved,
96        config: &Config,
97        fee_payer: &Pubkey,
98        rpc_client: &RpcClient,
99    ) -> Result<Option<Pubkey>, KoraError> {
100        let payment_destination = config.kora.get_payment_address(fee_payer)?;
101        let parsed_spl_instructions = transaction.get_or_parse_spl_instructions()?;
102
103        for instruction in parsed_spl_instructions
104            .get(&ParsedSPLInstructionType::SplTokenTransfer)
105            .unwrap_or(&vec![])
106        {
107            if let ParsedSPLInstructionData::SplTokenTransfer {
108                destination_address, owner, ..
109            } = instruction
110            {
111                // Check if this is a payment to Kora by verifying the destination token account owner
112                // matches the payment destination
113                let destination_account =
114                    match CacheUtil::get_account(config, rpc_client, destination_address, true)
115                        .await
116                    {
117                        Ok(account) => account,
118                        Err(KoraError::AccountNotFound(_)) => continue,
119                        Err(e) => return Err(e),
120                    };
121
122                let token_program =
123                    TokenType::get_token_program_from_owner(&destination_account.owner)?;
124                let token_account =
125                    token_program.unpack_token_account(&destination_account.data)?;
126
127                // Check if this is a payment to Kora
128                if token_account.owner() == payment_destination {
129                    return Ok(Some(*owner));
130                }
131            }
132        }
133
134        Ok(None)
135    }
136
137    /// Extract kora signer from transaction signers
138    fn extract_kora_signer(&self, transaction: &VersionedTransactionResolved) -> Option<Pubkey> {
139        let account_keys = transaction.message.static_account_keys();
140        let num_signers = transaction.message.header().num_required_signatures as usize;
141
142        account_keys
143            .iter()
144            .take(num_signers)
145            .find(|signer| self.kora_signers.contains(signer))
146            .copied()
147    }
148
149    fn current_timestamp() -> u64 {
150        SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).map(|d| d.as_secs()).unwrap_or(0)
151    }
152
153    /// Check and record usage for a transaction
154    /// Uses two-phase commit: check all rules first, then increment only if all pass
155    async fn check_and_record(
156        &self,
157        ctx: &mut LimiterContext<'_>,
158    ) -> Result<LimiterResult, KoraError> {
159        if !self.is_enabled() {
160            return Ok(LimiterResult::Allowed);
161        }
162
163        // Extract instruction rules using pre-computed indices (no per-request separation)
164        let instruction_rules: Vec<&InstructionRule> = self
165            .instruction_rule_indices
166            .iter()
167            .filter_map(|&idx| self.rules[idx].as_instruction())
168            .collect();
169
170        // Batch count instruction rules in single pass
171        let instruction_counts = if !instruction_rules.is_empty() {
172            InstructionRule::count_all_rules(&instruction_rules, ctx)
173        } else {
174            Vec::new()
175        };
176
177        // Build HashSet for O(1) lookup instead of Vec::contains O(n)
178        let ix_idx_set: HashSet<usize> = self.instruction_rule_indices.iter().copied().collect();
179
180        // Phase 1: Check all rules first (no incrementing yet)
181        // Collect rule checks: (key, increment_count, window_seconds)
182        let mut pending_increments: Vec<(String, u64, Option<u64>)> = Vec::new();
183        let mut instruction_count_idx = 0;
184
185        for (idx, rule) in self.rules.iter().enumerate() {
186            let increment_count = if ix_idx_set.contains(&idx) {
187                // Use pre-computed count for instruction rule
188                let count = instruction_counts[instruction_count_idx];
189                instruction_count_idx += 1;
190                count
191            } else {
192                // Transaction rules always increment by 1
193                1
194            };
195
196            if increment_count == 0 {
197                continue;
198            }
199
200            let key = rule.storage_key(&ctx.user_id, ctx.timestamp);
201
202            let current = self.store.get(&key).await?;
203            let new_count = current as u64 + increment_count;
204
205            if new_count > rule.max() {
206                return Ok(LimiterResult::Denied {
207                    reason: format!(
208                        "User {} exceeded {} limit: {}/{}",
209                        ctx.user_id,
210                        rule.description(),
211                        new_count,
212                        rule.max()
213                    ),
214                });
215            }
216
217            // Queue for increment (don't increment yet)
218            pending_increments.push((key, increment_count, rule.window_seconds()));
219
220            log::debug!(
221                "[rule] User {} {}: {}/{} ({})",
222                ctx.user_id,
223                rule.description(),
224                new_count,
225                rule.max(),
226                rule.window_seconds().map_or("lifetime".to_string(), |w| format!("{}s window", w))
227            );
228        }
229
230        for (key, increment_count, window_seconds) in pending_increments {
231            if let Some(window) = window_seconds.filter(|&w| w > 0) {
232                // Calculate bucket boundary: key expires at end of current bucket
233                // bucket = timestamp / window, so bucket_end = (bucket + 1) * window
234                let expires_at = (ctx.timestamp / window + 1) * window;
235                // First increment with expiry
236                self.store.increment_with_expiry(&key, expires_at).await?;
237                // Subsequent increments without resetting expiry
238                for _ in 1..increment_count {
239                    self.store.increment(&key).await?;
240                }
241            } else {
242                for _ in 0..increment_count {
243                    self.store.increment(&key).await?;
244                }
245            }
246        }
247
248        Ok(LimiterResult::Allowed)
249    }
250
251    pub async fn init_usage_limiter() -> Result<(), KoraError> {
252        let config = get_config()?;
253        let usage_config = &config.kora.usage_limit;
254
255        let set_limiter = |limiter| {
256            USAGE_LIMITER.set(limiter).map_err(|_| {
257                KoraError::InternalServerError("Usage limiter already initialized".to_string())
258            })
259        };
260
261        if !usage_config.enabled {
262            log::info!("Usage limiting disabled");
263            return set_limiter(None);
264        }
265
266        let rules = usage_config.build_rules()?;
267        if rules.is_empty() {
268            log::info!("Usage limiting enabled but no rules configured - disabled");
269            return set_limiter(None);
270        }
271
272        let kora_signers = get_signer_pool()?
273            .get_signers_info()
274            .iter()
275            .filter_map(|info| info.public_key.parse().ok())
276            .collect();
277
278        let (store, backend): (Arc<dyn UsageStore>, &str) =
279            if let Some(cache_url) = &usage_config.cache_url {
280                let cfg = deadpool_redis::Config::from_url(cache_url);
281                let pool = cfg.create_pool(Some(Runtime::Tokio1)).map_err(|e| {
282                    KoraError::InternalServerError(format!(
283                        "Failed to create Redis pool: {}",
284                        sanitize_error!(e)
285                    ))
286                })?;
287
288                let mut conn = pool.get().await.map_err(|e| {
289                    KoraError::InternalServerError(format!(
290                        "Failed to connect to Redis: {}",
291                        sanitize_error!(e)
292                    ))
293                })?;
294
295                let _: Option<String> = conn.get("__usage_limiter_test__").await.map_err(|e| {
296                    KoraError::InternalServerError(format!(
297                        "Redis connection test failed: {}",
298                        sanitize_error!(e)
299                    ))
300                })?;
301
302                (Arc::new(RedisUsageStore::new(pool)), "Redis")
303            } else {
304                log::warn!(
305                    "Usage limiting configured with in-memory store. \
306                     Limits will NOT be shared across instances and will reset on restart. \
307                     Configure 'cache_url' in [kora.usage_limit] for production deployments."
308                );
309                (Arc::new(InMemoryUsageStore::new()), "in-memory")
310            };
311
312        log::info!("Usage limiting initialized with {} rules ({backend})", rules.len());
313
314        set_limiter(Some(UsageTracker::new(
315            usage_config.enabled,
316            store,
317            rules,
318            kora_signers,
319            usage_config.fallback_if_unavailable,
320        )))
321    }
322
323    pub async fn check_transaction_usage_limit(
324        config: &Config,
325        transaction: &mut VersionedTransactionResolved,
326        user_id: Option<&str>,
327        fee_payer: &Pubkey,
328        rpc_client: &RpcClient,
329    ) -> Result<(), KoraError> {
330        // Validate user_id is provided when required
331        if config.kora.usage_limit.enabled
332            && matches!(&config.validation.price.model, crate::fee::price::PriceModel::Free)
333            && user_id.is_none()
334        {
335            return Err(KoraError::ValidationError(
336                "user_id is required when usage tracking is enabled and pricing is free"
337                    .to_string(),
338            ));
339        }
340
341        let Some(tracker) = Self::get_usage_limiter()? else {
342            if config.kora.usage_limit.enabled && !config.kora.usage_limit.fallback_if_unavailable {
343                return Err(KoraError::InternalServerError(
344                    "Usage limiter unavailable and fallback disabled".to_string(),
345                ));
346            }
347            return Ok(());
348        };
349
350        if tracker.has_instruction_rules() {
351            transaction.get_or_parse_system_instructions()?;
352            transaction.get_or_parse_spl_instructions()?;
353        }
354
355        // Resolve user_id for usage tracking:
356        // - If user_id is provided, use it directly (works for both free and paid modes)
357        // - Otherwise (paid mode), extract payer from payment instruction
358        let resolved_user_id = if let Some(user_id_str) = user_id {
359            user_id_str.to_string()
360        } else {
361            // Paid mode: extract payer from payment instruction
362            tracker
363                .extract_user_from_payment_instruction(transaction, config, fee_payer, rpc_client)
364                .await?
365                .ok_or_else(|| {
366                    KoraError::ValidationError(
367                        "Could not resolve user_id: no payment instruction found".to_string(),
368                    )
369                })?
370                .to_string()
371        };
372
373        let kora_signer = tracker.extract_kora_signer(transaction);
374
375        let mut ctx = LimiterContext {
376            transaction,
377            user_id: resolved_user_id,
378            kora_signer,
379            timestamp: Self::current_timestamp(),
380        };
381
382        match tracker.check_and_record(&mut ctx).await {
383            Ok(LimiterResult::Allowed) => Ok(()),
384            Ok(LimiterResult::Denied { reason }) => Err(KoraError::UsageLimitExceeded(reason)),
385            Err(e)
386                if tracker.fallback_if_unavailable
387                    && matches!(e, KoraError::InternalServerError(_)) =>
388            {
389                log::warn!("Usage limiter error (fallback enabled): {e}");
390                Ok(())
391            }
392            Err(e) => Err(e),
393        }
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        tests::{
            config_mock::ConfigMockBuilder, rpc_mock::RpcMockBuilder,
            transaction_mock::create_mock_resolved_transaction,
        },
        usage_limit::{InMemoryUsageStore, UsageLimitConfig, UsageLimitRuleConfig},
    };
    use std::sync::Arc;

    /// Timestamp used by the tests whose outcome does not depend on real time.
    const FIXED_TIMESTAMP: u64 = 1000000;

    /// Builds an enabled tracker over a fresh in-memory store with the given rule configs.
    fn tracker_with_rules(rule_configs: Vec<UsageLimitRuleConfig>) -> UsageTracker {
        let store: Arc<dyn UsageStore> = Arc::new(InMemoryUsageStore::new());
        let usage_config = UsageLimitConfig {
            enabled: true,
            cache_url: None,
            fallback_if_unavailable: false,
            rules: rule_configs,
        };
        let built_rules = usage_config.build_rules().unwrap();
        UsageTracker::new(true, store, built_rules, HashSet::new(), false)
    }

    /// Tracker with a single lifetime limit of `max_transactions` transactions.
    fn create_test_tracker(max_transactions: u64) -> UsageTracker {
        tracker_with_rules(vec![UsageLimitRuleConfig::Transaction {
            max: max_transactions,
            window_seconds: None,
        }])
    }

    /// Runs one `check_and_record` for `user_id` against a fresh mock transaction.
    async fn record(tracker: &UsageTracker, user_id: &str, timestamp: u64) -> LimiterResult {
        let mut mock_tx = create_mock_resolved_transaction();
        let mut limiter_ctx = LimiterContext {
            transaction: &mut mock_tx,
            user_id: user_id.to_string(),
            kora_signer: None,
            timestamp,
        };
        tracker.check_and_record(&mut limiter_ctx).await.unwrap()
    }

    /// Calls `check_transaction_usage_limit` with the active mock config, no `user_id`,
    /// a fresh mock transaction and a random fee payer.
    async fn check_without_user_id() -> Result<(), KoraError> {
        let config = get_config().unwrap();
        let mut mock_tx = create_mock_resolved_transaction();
        let rpc_client = Arc::new(RpcMockBuilder::new().build());
        let fee_payer = Pubkey::new_unique();
        UsageTracker::check_transaction_usage_limit(
            &config,
            &mut mock_tx,
            None,
            &fee_payer,
            &rpc_client,
        )
        .await
    }

    #[tokio::test]
    async fn test_usage_limit_enforcement() {
        let tracker = create_test_tracker(2);
        let user_id = "test-user-enforcement";

        // The first two transactions fit within the limit of 2
        for _ in 0..2 {
            assert!(matches!(
                record(&tracker, user_id, FIXED_TIMESTAMP).await,
                LimiterResult::Allowed
            ));
        }

        // The third one goes over the limit
        assert!(matches!(
            record(&tracker, user_id, FIXED_TIMESTAMP).await,
            LimiterResult::Denied { .. }
        ));
    }

    #[tokio::test]
    async fn test_independent_user_limits() {
        let tracker = create_test_tracker(2);

        // Each user gets two allowed transactions followed by a denial; exhausting the
        // first user's limit must not affect the second user.
        for user_id in ["test-user-1", "test-user-2"] {
            for _ in 0..2 {
                assert!(matches!(
                    record(&tracker, user_id, FIXED_TIMESTAMP).await,
                    LimiterResult::Allowed
                ));
            }
            assert!(matches!(
                record(&tracker, user_id, FIXED_TIMESTAMP).await,
                LimiterResult::Denied { .. }
            ));
        }
    }

    #[tokio::test]
    async fn test_unlimited_usage() {
        // No rules = unlimited (the tracker reports itself as not enabled)
        let tracker = tracker_with_rules(vec![]);

        for _ in 0..10 {
            assert!(matches!(
                record(&tracker, "test-user-unlimited", FIXED_TIMESTAMP).await,
                LimiterResult::Allowed
            ));
        }
    }

    #[tokio::test]
    async fn test_multiple_rules() {
        let tracker = tracker_with_rules(vec![
            // Lifetime limit: 10 transactions
            UsageLimitRuleConfig::Transaction { max: 10, window_seconds: None },
            // Time bucket limit: 2 per 100 seconds
            UsageLimitRuleConfig::Transaction { max: 2, window_seconds: Some(100) },
        ]);
        let user_id = "test-user-multiple-rules";
        // Real current time, so the bucket expiry calculations are meaningful
        let now = UsageTracker::current_timestamp();

        // The time-bucket limit of 2 admits the first two transactions...
        for _ in 0..2 {
            assert!(matches!(record(&tracker, user_id, now).await, LimiterResult::Allowed));
        }

        // ...and rejects the third
        assert!(matches!(record(&tracker, user_id, now).await, LimiterResult::Denied { .. }));
    }

    #[tokio::test]
    async fn test_usage_limiter_disabled_fallback() {
        // With usage limiting disabled, transactions are allowed
        let _m = ConfigMockBuilder::new().with_usage_limit_enabled(false).build_and_setup();

        // Initialization stores `None` when limiting is disabled
        let _ = UsageTracker::init_usage_limiter().await;

        let result = check_without_user_id().await;
        if let Err(e) = &result {
            println!("Test failed with error: {e}");
        }
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_allowed() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(true)
            .build_and_setup();

        // No cache_url and no rules: the limiter ends up disabled
        let _ = UsageTracker::init_usage_limiter().await;

        assert!(check_without_user_id().await.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_denied() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(false)
            .build_and_setup();

        // No cache_url and no rules: the limiter is set to None
        let _ = UsageTracker::init_usage_limiter().await;

        let err = check_without_user_id().await.unwrap_err();
        assert!(err.to_string().contains("Usage limiter unavailable and fallback disabled"));
    }
}