Skip to main content

matrixcode_core/matrixrpc/callback/
security.rs

1//! Security Validator
2//!
3//! Validates callback requests from external services.
4//! Ensures request_id + service_id + token combination is valid.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use tokio::sync::RwLock;
11
12use crate::matrixrpc::ServiceId;
13
14/// Security error types
15#[derive(Debug, thiserror::Error)]
16pub enum SecurityError {
17    /// Invalid token
18    #[error("Invalid or expired token")]
19    InvalidToken,
20
21    /// Token expired
22    #[error("Token expired at {0}")]
23    TokenExpired(String),
24
25    /// Service not authorized
26    #[error("Service '{0}' is not authorized for this callback")]
27    ServiceNotAuthorized(String),
28
29    /// Request ID mismatch
30    #[error("Request ID '{0}' does not match the token")]
31    RequestIdMismatch(String),
32
33    /// Missing required field
34    #[error("Missing required field: {0}")]
35    MissingField(String),
36
37    /// Token generation failed
38    #[error("Token generation failed: {0}")]
39    TokenGenerationFailed(String),
40
41    /// Rate limit exceeded
42    #[error("Rate limit exceeded for service '{0}'")]
43    RateLimitExceeded(String),
44
45    /// Internal error
46    #[error("Internal security error: {0}")]
47    Internal(String),
48}
49
50/// Token information
51#[derive(Debug, Clone)]
52pub struct TokenInfo {
53    /// The token value
54    pub token: String,
55
56    /// Service ID that owns this token
57    pub service_id: ServiceId,
58
59    /// Request ID this token is associated with
60    pub request_id: String,
61
62    /// When the token was created
63    pub created_at: Instant,
64
65    /// When the token expires
66    pub expires_at: Instant,
67
68    /// Allowed callback types
69    pub allowed_types: Vec<String>,
70
71    /// Usage count
72    pub usage_count: u32,
73
74    /// Maximum allowed uses
75    pub max_uses: u32,
76}
77
78impl TokenInfo {
79    /// Create a new token info
80    pub fn new(
81        token: String,
82        service_id: ServiceId,
83        request_id: String,
84        lifetime_secs: u64,
85    ) -> Self {
86        let now = Instant::now();
87        Self {
88            token,
89            service_id,
90            request_id,
91            created_at: now,
92            expires_at: now + Duration::from_secs(lifetime_secs),
93            allowed_types: vec![
94                "ai".to_string(), "tool".to_string(), "context".to_string(),
95            ],
96            usage_count: 0,
97            max_uses: 10,
98        }
99    }
100
101    /// Set allowed callback types
102    pub fn with_allowed_types(mut self, types: Vec<String>) -> Self {
103        self.allowed_types = types;
104        self
105    }
106
107    /// Set maximum uses
108    pub fn with_max_uses(mut self, max: u32) -> Self {
109        self.max_uses = max;
110        self
111    }
112
113    /// Check if token is expired
114    pub fn is_expired(&self) -> bool {
115        Instant::now() > self.expires_at
116    }
117
118    /// Check if token has remaining uses
119    pub fn has_remaining_uses(&self) -> bool {
120        self.usage_count < self.max_uses
121    }
122
123    /// Check if callback type is allowed
124    pub fn is_type_allowed(&self, callback_type: &str) -> bool {
125        self.allowed_types.contains(&callback_type.to_string())
126    }
127
128    /// Increment usage count
129    pub fn increment_usage(&mut self) {
130        self.usage_count += 1;
131    }
132}
133
134/// Validation result
135#[derive(Debug, Clone)]
136pub struct ValidationResult {
137    /// Whether validation passed
138    pub is_valid: bool,
139
140    /// The validated token info (if valid)
141    pub token_info: Option<TokenInfo>,
142
143    /// Error message (if invalid)
144    pub error: Option<String>,
145}
146
147impl ValidationResult {
148    /// Create a successful validation result
149    pub fn success(token_info: TokenInfo) -> Self {
150        Self {
151            is_valid: true,
152            token_info: Some(token_info),
153            error: None,
154        }
155    }
156
157    /// Create a failed validation result
158    pub fn failure(error: impl Into<String>) -> Self {
159        Self {
160            is_valid: false,
161            token_info: None,
162            error: Some(error.into()),
163        }
164    }
165}
166
/// Tunable limits for the security validator.
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// How long an issued token stays valid, in seconds.
    pub token_lifetime_secs: u64,

    /// How many times a single token may be used.
    pub max_token_uses: u32,

    /// How many live tokens one service may hold at once.
    pub max_tokens_per_service: u32,

    /// How many validation requests a service may make per minute.
    pub rate_limit_per_minute: u32,

    /// When `true`, the request ID must match the one bound to the token.
    pub strict_validation: bool,
}

impl Default for SecurityConfig {
    /// Defaults: 5-minute tokens, 10 uses per token, 100 tokens per
    /// service, 60 requests per minute, strict validation enabled.
    fn default() -> Self {
        let five_minutes_secs: u64 = 5 * 60;

        SecurityConfig {
            strict_validation: true,
            rate_limit_per_minute: 60,
            max_tokens_per_service: 100,
            max_token_uses: 10,
            token_lifetime_secs: five_minutes_secs,
        }
    }
}
197
/// Sliding-window request log for a single service.
#[derive(Debug, Clone)]
struct RateLimitEntry {
    /// Arrival time of every recorded request, in insertion order.
    timestamps: Vec<Instant>,
    /// When stale timestamps were last pruned.
    last_cleanup: Instant,
}

impl RateLimitEntry {
    /// Length of the window used both for counting and for pruning.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Creates an empty log whose pruning clock starts now.
    fn new() -> Self {
        Self {
            timestamps: Vec::new(),
            last_cleanup: Instant::now(),
        }
    }

    /// Records a request arriving now and prunes stale timestamps if more
    /// than one window has passed since the last prune.
    fn add_request(&mut self) {
        self.timestamps.push(Instant::now());
        if self.last_cleanup.elapsed() > Self::WINDOW {
            self.cleanup();
            self.last_cleanup = Instant::now();
        }
    }

    /// Drops every timestamp that is at least one window old.
    fn cleanup(&mut self) {
        // Compare ages instead of computing `now - WINDOW`: subtracting a
        // Duration from an Instant panics when the result would precede the
        // clock's origin (e.g. a monotonic clock that started < 60s ago).
        let now = Instant::now();
        self.timestamps
            .retain(|t| now.saturating_duration_since(*t) < Self::WINDOW);
    }

    /// Returns how many recorded requests are younger than one window.
    fn count_last_minute(&self) -> u32 {
        let now = Instant::now();
        let recent = self
            .timestamps
            .iter()
            .filter(|t| now.saturating_duration_since(**t) < Self::WINDOW)
            .count();
        // Clamp rather than silently truncate if the count ever exceeds u32.
        recent.min(u32::MAX as usize) as u32
    }
}
234
235/// Security Validator
236///
237/// Validates callback requests from external services.
238/// Manages token generation, validation, and rate limiting.
239pub struct SecurityValidator {
240    /// Configuration
241    config: SecurityConfig,
242
243    /// Active tokens
244    tokens: Arc<RwLock<HashMap<String, TokenInfo>>>,
245
246    /// Service to token mapping
247    service_tokens: Arc<RwLock<HashMap<ServiceId, Vec<String>>>>,
248
249    /// Rate limit tracking
250    rate_limits: Arc<RwLock<HashMap<ServiceId, RateLimitEntry>>>,
251}
252
253impl SecurityValidator {
254    /// Create a new security validator
255    pub fn new() -> Self {
256        Self::with_config(SecurityConfig::default())
257    }
258
259    /// Create a new security validator with configuration
260    pub fn with_config(config: SecurityConfig) -> Self {
261        Self {
262            config,
263            tokens: Arc::new(RwLock::new(HashMap::new())),
264            service_tokens: Arc::new(RwLock::new(HashMap::new())),
265            rate_limits: Arc::new(RwLock::new(HashMap::new())),
266        }
267    }
268
269    /// Generate a new token for a callback request
270    pub async fn generate_token(
271        &self,
272        service_id: ServiceId,
273        request_id: String,
274        allowed_types: Vec<String>,
275    ) -> Result<String, SecurityError> {
276        // Check if service has too many tokens
277        {
278            let service_tokens = self.service_tokens.read().await;
279            if let Some(tokens) = service_tokens.get(&service_id) {
280                if tokens.len() >= self.config.max_tokens_per_service as usize {
281                    return Err(SecurityError::RateLimitExceeded(service_id.to_string()));
282                }
283            }
284        }
285
286        // Generate unique token
287        let token = format!(
288            "cb_{}_{}",
289            uuid::Uuid::new_v4().to_string(),
290            &request_id[..8.min(request_id.len())]
291        );
292
293        // Create token info
294        let token_info = TokenInfo::new(
295            token.clone(),
296            service_id.clone(),
297            request_id,
298            self.config.token_lifetime_secs,
299        )
300        .with_allowed_types(allowed_types)
301        .with_max_uses(self.config.max_token_uses);
302
303        // Store token
304        {
305            let mut tokens = self.tokens.write().await;
306            tokens.insert(token.clone(), token_info);
307        }
308
309        // Update service token mapping
310        {
311            let mut service_tokens = self.service_tokens.write().await;
312            service_tokens
313                .entry(service_id)
314                .or_insert_with(Vec::new)
315                .push(token.clone());
316        }
317
318        Ok(token)
319    }
320
321    /// Validate a callback request
322    pub async fn validate(
323        &self,
324        token: &str,
325        service_id: &ServiceId,
326        request_id: &str,
327        callback_type: &str,
328    ) -> ValidationResult {
329        // Check rate limit first
330        {
331            let mut rate_limits = self.rate_limits.write().await;
332            let entry = rate_limits
333                .entry(service_id.clone())
334                .or_insert_with(RateLimitEntry::new);
335
336            if entry.count_last_minute() >= self.config.rate_limit_per_minute {
337                return ValidationResult::failure(SecurityError::RateLimitExceeded(
338                    service_id.to_string(),
339                ).to_string());
340            }
341
342            entry.add_request();
343        }
344
345        // Look up token
346        let mut tokens = self.tokens.write().await;
347        let token_info = match tokens.get_mut(token) {
348            Some(info) => info,
349            None => return ValidationResult::failure(SecurityError::InvalidToken.to_string()),
350        };
351
352        // Check expiration
353        if token_info.is_expired() {
354            tokens.remove(token);
355            return ValidationResult::failure(
356                SecurityError::TokenExpired("token has expired".to_string()).to_string(),
357            );
358        }
359
360        // Check remaining uses
361        if !token_info.has_remaining_uses() {
362            return ValidationResult::failure("Token usage limit exceeded".to_string());
363        }
364
365        // Validate service ID
366        if token_info.service_id != *service_id {
367            return ValidationResult::failure(
368                SecurityError::ServiceNotAuthorized(service_id.to_string()).to_string(),
369            );
370        }
371
372        // Validate request ID
373        if self.config.strict_validation && token_info.request_id != request_id {
374            return ValidationResult::failure(
375                SecurityError::RequestIdMismatch(request_id.to_string()).to_string(),
376            );
377        }
378
379        // Validate callback type
380        if !token_info.is_type_allowed(callback_type) {
381            return ValidationResult::failure(format!(
382                "Callback type '{}' is not allowed for this token",
383                callback_type
384            ));
385        }
386
387        // Increment usage
388        token_info.increment_usage();
389
390        ValidationResult::success(token_info.clone())
391    }
392
393    /// Invalidate a token
394    pub async fn invalidate_token(&self, token: &str) -> Result<(), SecurityError> {
395        let token_info = {
396            let mut tokens = self.tokens.write().await;
397            tokens.remove(token)
398        };
399
400        if let Some(info) = token_info {
401            // Remove from service mapping
402            let mut service_tokens = self.service_tokens.write().await;
403            if let Some(tokens) = service_tokens.get_mut(&info.service_id) {
404                tokens.retain(|t| t != token);
405            }
406        }
407
408        Ok(())
409    }
410
411    /// Invalidate all tokens for a service
412    pub async fn invalidate_service_tokens(&self, service_id: &ServiceId) {
413        let tokens_to_remove = {
414            let service_tokens = self.service_tokens.read().await;
415            service_tokens.get(service_id).cloned().unwrap_or_default()
416        };
417
418        {
419            let mut tokens = self.tokens.write().await;
420            for token in &tokens_to_remove {
421                tokens.remove(token);
422            }
423        }
424
425        {
426            let mut service_tokens = self.service_tokens.write().await;
427            service_tokens.remove(service_id);
428        }
429    }
430
431    /// Clean up expired tokens
432    pub async fn cleanup_expired(&self) -> usize {
433        let expired_tokens: Vec<String> = {
434            let tokens = self.tokens.read().await;
435            tokens
436                .iter()
437                .filter(|(_, info)| info.is_expired())
438                .map(|(token, _)| token.clone())
439                .collect()
440        };
441
442        let count = expired_tokens.len();
443
444        for token in &expired_tokens {
445            self.invalidate_token(token).await.ok();
446        }
447
448        count
449    }
450
451    /// Get active token count
452    pub async fn token_count(&self) -> usize {
453        self.tokens.read().await.len()
454    }
455
456    /// Get token info
457    pub async fn get_token_info(&self, token: &str) -> Option<TokenInfo> {
458        self.tokens.read().await.get(token).cloned()
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Request ID shared by every test.
    const REQUEST_ID: &str = "req-001";

    /// Builds an owned list of callback types from string literals.
    fn types(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Issues a token for `service` bound to `REQUEST_ID`.
    async fn issue(validator: &SecurityValidator, service: &ServiceId, allowed: &[&str]) -> String {
        validator
            .generate_token(service.clone(), REQUEST_ID.to_string(), types(allowed))
            .await
            .unwrap()
    }

    /// A generated token carries the `cb_` prefix and is stored.
    #[tokio::test]
    async fn test_generate_token() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");

        let token = issue(&validator, &service, &["ai", "tool"]).await;

        assert!(token.starts_with("cb_"));
        assert_eq!(validator.token_count().await, 1);
    }

    /// A token validates for its own service, request and an allowed type.
    #[tokio::test]
    async fn test_validate_token() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");

        let token = issue(&validator, &service, &["ai"]).await;
        let outcome = validator.validate(&token, &service, REQUEST_ID, "ai").await;

        assert!(outcome.is_valid);
        assert!(outcome.token_info.is_some());
    }

    /// An unknown token is rejected with an error message.
    #[tokio::test]
    async fn test_validate_invalid_token() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");

        let outcome = validator
            .validate("invalid_token", &service, REQUEST_ID, "ai")
            .await;

        assert!(!outcome.is_valid);
        assert!(outcome.error.is_some());
    }

    /// A token presented by a different service is rejected.
    #[tokio::test]
    async fn test_validate_wrong_service() {
        let validator = SecurityValidator::new();
        let owner = ServiceId::new("service1");
        let intruder = ServiceId::new("service2");

        let token = issue(&validator, &owner, &["ai"]).await;
        let outcome = validator
            .validate(&token, &intruder, REQUEST_ID, "ai")
            .await;

        assert!(!outcome.is_valid);
    }

    /// A callback type outside the token's allowed list is rejected.
    #[tokio::test]
    async fn test_validate_wrong_callback_type() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");

        let token = issue(&validator, &service, &["ai"]).await;
        let outcome = validator
            .validate(&token, &service, REQUEST_ID, "tool")
            .await;

        assert!(!outcome.is_valid);
    }

    /// Invalidating a token removes it from the store.
    #[tokio::test]
    async fn test_invalidate_token() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");

        let token = issue(&validator, &service, &["ai"]).await;
        validator.invalidate_token(&token).await.unwrap();

        assert_eq!(validator.token_count().await, 0);
    }

    /// With `max_token_uses = 2`, the third validation fails.
    #[tokio::test]
    async fn test_token_usage_limit() {
        let validator = SecurityValidator::with_config(SecurityConfig {
            max_token_uses: 2,
            ..Default::default()
        });
        let service = ServiceId::new("test-service");

        let token = issue(&validator, &service, &["ai"]).await;

        let first = validator.validate(&token, &service, REQUEST_ID, "ai").await;
        assert!(first.is_valid);

        let second = validator.validate(&token, &service, REQUEST_ID, "ai").await;
        assert!(second.is_valid);

        let third = validator.validate(&token, &service, REQUEST_ID, "ai").await;
        assert!(!third.is_valid);
    }
}