matrixcode_core/matrixrpc/callback/
security.rs1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use tokio::sync::RwLock;
11
12use crate::matrixrpc::ServiceId;
13
14#[derive(Debug, thiserror::Error)]
16pub enum SecurityError {
17 #[error("Invalid or expired token")]
19 InvalidToken,
20
21 #[error("Token expired at {0}")]
23 TokenExpired(String),
24
25 #[error("Service '{0}' is not authorized for this callback")]
27 ServiceNotAuthorized(String),
28
29 #[error("Request ID '{0}' does not match the token")]
31 RequestIdMismatch(String),
32
33 #[error("Missing required field: {0}")]
35 MissingField(String),
36
37 #[error("Token generation failed: {0}")]
39 TokenGenerationFailed(String),
40
41 #[error("Rate limit exceeded for service '{0}'")]
43 RateLimitExceeded(String),
44
45 #[error("Internal security error: {0}")]
47 Internal(String),
48}
49
50#[derive(Debug, Clone)]
52pub struct TokenInfo {
53 pub token: String,
55
56 pub service_id: ServiceId,
58
59 pub request_id: String,
61
62 pub created_at: Instant,
64
65 pub expires_at: Instant,
67
68 pub allowed_types: Vec<String>,
70
71 pub usage_count: u32,
73
74 pub max_uses: u32,
76}
77
78impl TokenInfo {
79 pub fn new(
81 token: String,
82 service_id: ServiceId,
83 request_id: String,
84 lifetime_secs: u64,
85 ) -> Self {
86 let now = Instant::now();
87 Self {
88 token,
89 service_id,
90 request_id,
91 created_at: now,
92 expires_at: now + Duration::from_secs(lifetime_secs),
93 allowed_types: vec![
94 "ai".to_string(), "tool".to_string(), "context".to_string(),
95 ],
96 usage_count: 0,
97 max_uses: 10,
98 }
99 }
100
101 pub fn with_allowed_types(mut self, types: Vec<String>) -> Self {
103 self.allowed_types = types;
104 self
105 }
106
107 pub fn with_max_uses(mut self, max: u32) -> Self {
109 self.max_uses = max;
110 self
111 }
112
113 pub fn is_expired(&self) -> bool {
115 Instant::now() > self.expires_at
116 }
117
118 pub fn has_remaining_uses(&self) -> bool {
120 self.usage_count < self.max_uses
121 }
122
123 pub fn is_type_allowed(&self, callback_type: &str) -> bool {
125 self.allowed_types.contains(&callback_type.to_string())
126 }
127
128 pub fn increment_usage(&mut self) {
130 self.usage_count += 1;
131 }
132}
133
134#[derive(Debug, Clone)]
136pub struct ValidationResult {
137 pub is_valid: bool,
139
140 pub token_info: Option<TokenInfo>,
142
143 pub error: Option<String>,
145}
146
147impl ValidationResult {
148 pub fn success(token_info: TokenInfo) -> Self {
150 Self {
151 is_valid: true,
152 token_info: Some(token_info),
153 error: None,
154 }
155 }
156
157 pub fn failure(error: impl Into<String>) -> Self {
159 Self {
160 is_valid: false,
161 token_info: None,
162 error: Some(error.into()),
163 }
164 }
165}
166
/// Tunable limits for the callback token validator.
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// Seconds a freshly issued token stays valid.
    pub token_lifetime_secs: u64,

    /// Uses each newly issued token permits.
    pub max_token_uses: u32,

    /// Cap on the number of tokens tracked for one service when a new one is requested.
    pub max_tokens_per_service: u32,

    /// Validation attempts accepted per service within a 60-second window.
    pub rate_limit_per_minute: u32,

    /// When `true`, the request ID presented at validation must equal the token's.
    pub strict_validation: bool,
}

impl Default for SecurityConfig {
    /// Defaults: 5-minute tokens with 10 uses each, 100 tokens per service,
    /// 60 validations per minute, and strict request-ID checking.
    fn default() -> Self {
        let five_minutes = 5 * 60;
        Self {
            strict_validation: true,
            rate_limit_per_minute: 60,
            max_tokens_per_service: 100,
            max_token_uses: 10,
            token_lifetime_secs: five_minutes,
        }
    }
}
197
/// Sliding-window log of requests made by one service.
#[derive(Debug, Clone)]
struct RateLimitEntry {
    /// Instants of recorded requests. Entries older than the window linger until the
    /// next periodic cleanup, but are never counted.
    timestamps: Vec<Instant>,

    /// When `timestamps` was last pruned.
    last_cleanup: Instant,
}

impl RateLimitEntry {
    /// Length of the sliding window requests are counted over.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Creates an empty log.
    fn new() -> Self {
        Self {
            timestamps: Vec::new(),
            last_cleanup: Instant::now(),
        }
    }

    /// Records a request made now, pruning stale entries at most once per window.
    fn add_request(&mut self) {
        self.timestamps.push(Instant::now());
        if self.last_cleanup.elapsed() > Self::WINDOW {
            self.cleanup();
            self.last_cleanup = Instant::now();
        }
    }

    /// Drops timestamps that fall outside the window.
    fn cleanup(&mut self) {
        let now = Instant::now();
        // Compare ages rather than computing `now - WINDOW`: that subtraction panics when
        // the monotonic clock is younger than the window (e.g. shortly after boot).
        self.timestamps
            .retain(|t| now.saturating_duration_since(*t) < Self::WINDOW);
    }

    /// Number of recorded requests younger than the window, saturating at `u32::MAX`.
    fn count_last_minute(&self) -> u32 {
        let now = Instant::now();
        let recent = self
            .timestamps
            .iter()
            .filter(|t| now.saturating_duration_since(**t) < Self::WINDOW)
            .count();
        u32::try_from(recent).unwrap_or(u32::MAX)
    }
}
234
235pub struct SecurityValidator {
240 config: SecurityConfig,
242
243 tokens: Arc<RwLock<HashMap<String, TokenInfo>>>,
245
246 service_tokens: Arc<RwLock<HashMap<ServiceId, Vec<String>>>>,
248
249 rate_limits: Arc<RwLock<HashMap<ServiceId, RateLimitEntry>>>,
251}
252
253impl SecurityValidator {
254 pub fn new() -> Self {
256 Self::with_config(SecurityConfig::default())
257 }
258
259 pub fn with_config(config: SecurityConfig) -> Self {
261 Self {
262 config,
263 tokens: Arc::new(RwLock::new(HashMap::new())),
264 service_tokens: Arc::new(RwLock::new(HashMap::new())),
265 rate_limits: Arc::new(RwLock::new(HashMap::new())),
266 }
267 }
268
269 pub async fn generate_token(
271 &self,
272 service_id: ServiceId,
273 request_id: String,
274 allowed_types: Vec<String>,
275 ) -> Result<String, SecurityError> {
276 {
278 let service_tokens = self.service_tokens.read().await;
279 if let Some(tokens) = service_tokens.get(&service_id) {
280 if tokens.len() >= self.config.max_tokens_per_service as usize {
281 return Err(SecurityError::RateLimitExceeded(service_id.to_string()));
282 }
283 }
284 }
285
286 let token = format!(
288 "cb_{}_{}",
289 uuid::Uuid::new_v4().to_string(),
290 &request_id[..8.min(request_id.len())]
291 );
292
293 let token_info = TokenInfo::new(
295 token.clone(),
296 service_id.clone(),
297 request_id,
298 self.config.token_lifetime_secs,
299 )
300 .with_allowed_types(allowed_types)
301 .with_max_uses(self.config.max_token_uses);
302
303 {
305 let mut tokens = self.tokens.write().await;
306 tokens.insert(token.clone(), token_info);
307 }
308
309 {
311 let mut service_tokens = self.service_tokens.write().await;
312 service_tokens
313 .entry(service_id)
314 .or_insert_with(Vec::new)
315 .push(token.clone());
316 }
317
318 Ok(token)
319 }
320
321 pub async fn validate(
323 &self,
324 token: &str,
325 service_id: &ServiceId,
326 request_id: &str,
327 callback_type: &str,
328 ) -> ValidationResult {
329 {
331 let mut rate_limits = self.rate_limits.write().await;
332 let entry = rate_limits
333 .entry(service_id.clone())
334 .or_insert_with(RateLimitEntry::new);
335
336 if entry.count_last_minute() >= self.config.rate_limit_per_minute {
337 return ValidationResult::failure(SecurityError::RateLimitExceeded(
338 service_id.to_string(),
339 ).to_string());
340 }
341
342 entry.add_request();
343 }
344
345 let mut tokens = self.tokens.write().await;
347 let token_info = match tokens.get_mut(token) {
348 Some(info) => info,
349 None => return ValidationResult::failure(SecurityError::InvalidToken.to_string()),
350 };
351
352 if token_info.is_expired() {
354 tokens.remove(token);
355 return ValidationResult::failure(
356 SecurityError::TokenExpired("token has expired".to_string()).to_string(),
357 );
358 }
359
360 if !token_info.has_remaining_uses() {
362 return ValidationResult::failure("Token usage limit exceeded".to_string());
363 }
364
365 if token_info.service_id != *service_id {
367 return ValidationResult::failure(
368 SecurityError::ServiceNotAuthorized(service_id.to_string()).to_string(),
369 );
370 }
371
372 if self.config.strict_validation && token_info.request_id != request_id {
374 return ValidationResult::failure(
375 SecurityError::RequestIdMismatch(request_id.to_string()).to_string(),
376 );
377 }
378
379 if !token_info.is_type_allowed(callback_type) {
381 return ValidationResult::failure(format!(
382 "Callback type '{}' is not allowed for this token",
383 callback_type
384 ));
385 }
386
387 token_info.increment_usage();
389
390 ValidationResult::success(token_info.clone())
391 }
392
393 pub async fn invalidate_token(&self, token: &str) -> Result<(), SecurityError> {
395 let token_info = {
396 let mut tokens = self.tokens.write().await;
397 tokens.remove(token)
398 };
399
400 if let Some(info) = token_info {
401 let mut service_tokens = self.service_tokens.write().await;
403 if let Some(tokens) = service_tokens.get_mut(&info.service_id) {
404 tokens.retain(|t| t != token);
405 }
406 }
407
408 Ok(())
409 }
410
411 pub async fn invalidate_service_tokens(&self, service_id: &ServiceId) {
413 let tokens_to_remove = {
414 let service_tokens = self.service_tokens.read().await;
415 service_tokens.get(service_id).cloned().unwrap_or_default()
416 };
417
418 {
419 let mut tokens = self.tokens.write().await;
420 for token in &tokens_to_remove {
421 tokens.remove(token);
422 }
423 }
424
425 {
426 let mut service_tokens = self.service_tokens.write().await;
427 service_tokens.remove(service_id);
428 }
429 }
430
431 pub async fn cleanup_expired(&self) -> usize {
433 let expired_tokens: Vec<String> = {
434 let tokens = self.tokens.read().await;
435 tokens
436 .iter()
437 .filter(|(_, info)| info.is_expired())
438 .map(|(token, _)| token.clone())
439 .collect()
440 };
441
442 let count = expired_tokens.len();
443
444 for token in &expired_tokens {
445 self.invalidate_token(token).await.ok();
446 }
447
448 count
449 }
450
451 pub async fn token_count(&self) -> usize {
453 self.tokens.read().await.len()
454 }
455
456 pub async fn get_token_info(&self, token: &str) -> Option<TokenInfo> {
458 self.tokens.read().await.get(token).cloned()
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Issues a token for `service` / `request` limited to `types`, panicking on failure.
    async fn issue(
        validator: &SecurityValidator,
        service: &ServiceId,
        request: &str,
        types: &[&str],
    ) -> String {
        let allowed = types.iter().map(|t| t.to_string()).collect();
        validator
            .generate_token(service.clone(), request.to_string(), allowed)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_generate_token() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");

        let token = issue(&validator, &service, "req-001", &["ai", "tool"]).await;

        assert!(token.starts_with("cb_"));
        assert_eq!(validator.token_count().await, 1);
    }

    #[tokio::test]
    async fn test_validate_token() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");
        let request = "req-001";
        let token = issue(&validator, &service, request, &["ai"]).await;

        let outcome = validator.validate(&token, &service, request, "ai").await;

        assert!(outcome.is_valid);
        assert!(outcome.token_info.is_some());
    }

    #[tokio::test]
    async fn test_validate_invalid_token() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");

        let outcome = validator
            .validate("invalid_token", &service, "req-001", "ai")
            .await;

        assert!(!outcome.is_valid);
        assert!(outcome.error.is_some());
    }

    #[tokio::test]
    async fn test_validate_wrong_service() {
        let validator = SecurityValidator::new();
        let owner = ServiceId::new("service1");
        let intruder = ServiceId::new("service2");
        let request = "req-001";
        let token = issue(&validator, &owner, request, &["ai"]).await;

        let outcome = validator.validate(&token, &intruder, request, "ai").await;

        assert!(!outcome.is_valid);
    }

    #[tokio::test]
    async fn test_validate_wrong_callback_type() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");
        let request = "req-001";
        let token = issue(&validator, &service, request, &["ai"]).await;

        let outcome = validator.validate(&token, &service, request, "tool").await;

        assert!(!outcome.is_valid);
    }

    #[tokio::test]
    async fn test_invalidate_token() {
        let validator = SecurityValidator::new();
        let service = ServiceId::new("test-service");
        let token = issue(&validator, &service, "req-001", &["ai"]).await;

        validator.invalidate_token(&token).await.unwrap();

        assert_eq!(validator.token_count().await, 0);
    }

    #[tokio::test]
    async fn test_token_usage_limit() {
        let validator = SecurityValidator::with_config(SecurityConfig {
            max_token_uses: 2,
            ..Default::default()
        });
        let service = ServiceId::new("test-service");
        let request = "req-001";
        let token = issue(&validator, &service, request, &["ai"]).await;

        // The first two uses succeed; the third exceeds the limit.
        for attempt in 1..=3 {
            let outcome = validator.validate(&token, &service, request, "ai").await;
            assert_eq!(outcome.is_valid, attempt <= 2, "attempt {}", attempt);
        }
    }
}
592}