Skip to main content

loader/
security.rs

1//! Security utilities: key sanitization, rate limiting, etc.
2
3use anyhow::{bail, Result};
4use std::time::{Duration, Instant};
5use tracing::warn;
6
/// Rate limiter for network requests.
///
/// Enforces a minimum time gap between accepted requests by remembering when
/// the most recent accepted request happened.
#[derive(Debug)]
pub struct RateLimiter {
    /// Minimum time that must pass between two accepted requests.
    min_interval: Duration,
    /// Instant of the most recently accepted request; `None` until the first
    /// request has been accepted.
    last_request: Option<Instant>,
}
12
13impl RateLimiter {
14    /// Create a new rate limiter with minimum interval between requests
15    pub fn new(min_interval: Duration) -> Self {
16        Self {
17            min_interval,
18            last_request: None,
19        }
20    }
21
22    /// Check if request is allowed; returns true if allowed, false if rate-limited
23    pub fn check_and_update(&mut self) -> bool {
24        match self.last_request {
25            None => {
26                self.last_request = Some(Instant::now());
27                true
28            }
29            Some(last) => {
30                if last.elapsed() >= self.min_interval {
31                    self.last_request = Some(Instant::now());
32                    true
33                } else {
34                    false
35                }
36            }
37        }
38    }
39
40    /// Wait until next request is allowed
41    pub async fn wait_if_needed(&mut self) {
42        if !self.check_and_update() {
43            if let Some(last) = self.last_request {
44                let elapsed = last.elapsed();
45                let wait_time = self.min_interval.saturating_sub(elapsed);
46                if wait_time.as_millis() > 0 {
47                    tokio::time::sleep(wait_time).await;
48                }
49            }
50        }
51    }
52}
53
/// Namespace type for helpers that validate sensitive operations and sanitize
/// values for logging.
///
/// The type carries no data; its functionality is exposed through associated
/// functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct SensitiveData;
56
57impl SensitiveData {
58    /// Mask a key for logging (shows only first 4 and last 4 chars)
59    pub fn mask_key(key_b64: &str) -> String {
60        if key_b64.len() <= 8 {
61            "***".to_string()
62        } else {
63            format!(
64                "{}...{}",
65                &key_b64[..4],
66                &key_b64[key_b64.len() - 4..]
67            )
68        }
69    }
70
71    /// Check if URL looks like HTTPS (strict security)
72    pub fn validate_url_scheme(url: &str) -> Result<()> {
73        if !url.starts_with("https://") {
74            warn!("URL does not use HTTPS: {}", url);
75            // In production, this could be a hard failure:
76            // bail!("only HTTPS URLs are allowed");
77        }
78        Ok(())
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Long keys keep four characters at each end; short keys are fully hidden.
    #[test]
    fn test_mask_key() {
        assert_eq!(SensitiveData::mask_key("abcdefghijklmnop"), "abcd...mnop");
        assert_eq!(SensitiveData::mask_key("abc"), "***");
    }

    /// The first check passes; an immediate second check is rejected.
    #[test]
    fn test_rate_limiter() {
        let window = Duration::from_millis(100);
        let mut rl = RateLimiter::new(window);

        assert!(rl.check_and_update());
        // Called again well inside the 100 ms window.
        assert!(!rl.check_and_update());
    }

    /// After waiting out the interval, a check passes again.
    #[tokio::test]
    async fn test_rate_limiter_async() {
        let window = Duration::from_millis(50);
        let mut rl = RateLimiter::new(window);

        assert!(rl.check_and_update());
        assert!(!rl.check_and_update());

        // Expected to sleep for roughly the remaining part of the 50 ms window.
        rl.wait_if_needed().await;
        assert!(rl.check_and_update());
    }

    /// Both HTTPS and plain HTTP URLs are accepted; HTTP only triggers a warning.
    #[test]
    fn test_validate_url() {
        let https_outcome = SensitiveData::validate_url_scheme("https://example.com");
        assert!(https_outcome.is_ok());

        let http_outcome = SensitiveData::validate_url_scheme("http://example.com");
        assert!(http_outcome.is_ok());
    }
}