Skip to main content

rmcp_memex/security/
mod.rs

1//! Security module for namespace access control.
2//!
3//! This module provides token-based access control for namespaces.
4//! Each namespace can have an associated access token that must be provided
5//! when reading or writing data to that namespace.
6
use anyhow::{Context, Result, anyhow};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use uuid::Uuid;
15
/// Prefix that every generated namespace access token starts with.
const TOKEN_PREFIX: &str = "ns_";
18
19/// Configuration for namespace security
20#[derive(Debug, Clone, Default, Serialize, Deserialize)]
21pub struct NamespaceSecurityConfig {
22    /// Whether token-based access control is enabled
23    #[serde(default)]
24    pub enabled: bool,
25    /// Path to the token store file
26    #[serde(default)]
27    pub token_store_path: Option<String>,
28}
29
30/// Stored token information for a namespace
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct NamespaceToken {
33    /// The namespace this token grants access to
34    pub namespace: String,
35    /// The actual token value
36    pub token: String,
37    /// When the token was created (Unix timestamp)
38    pub created_at: u64,
39    /// Optional description/label for the token
40    pub description: Option<String>,
41}
42
43/// Token store for managing namespace access tokens
44#[derive(Debug)]
45pub struct TokenStore {
46    /// Map of namespace -> token
47    tokens: Arc<RwLock<HashMap<String, NamespaceToken>>>,
48    /// Path to persist tokens (if any)
49    store_path: Option<String>,
50}
51
52impl TokenStore {
53    /// Create a new token store
54    pub fn new(store_path: Option<String>) -> Self {
55        Self {
56            tokens: Arc::new(RwLock::new(HashMap::new())),
57            store_path,
58        }
59    }
60
61    /// Load tokens from persistent storage
62    pub async fn load(&self) -> Result<()> {
63        if let Some(path) = &self.store_path {
64            let expanded = shellexpand::tilde(path).to_string();
65            let path = Path::new(&expanded);
66
67            if path.exists() {
68                let contents = tokio::fs::read_to_string(path).await?;
69                let loaded: HashMap<String, NamespaceToken> = serde_json::from_str(&contents)?;
70                let mut tokens = self.tokens.write().await;
71                *tokens = loaded;
72                info!("Loaded {} namespace tokens from {}", tokens.len(), expanded);
73            }
74        }
75        Ok(())
76    }
77
78    /// Save tokens to persistent storage
79    pub async fn save(&self) -> Result<()> {
80        if let Some(path) = &self.store_path {
81            let expanded = shellexpand::tilde(path).to_string();
82            let path = Path::new(&expanded);
83
84            // Ensure parent directory exists
85            if let Some(parent) = path.parent() {
86                tokio::fs::create_dir_all(parent).await?;
87            }
88
89            let tokens = self.tokens.read().await;
90            let contents = serde_json::to_string_pretty(&*tokens)?;
91            tokio::fs::write(path, contents).await?;
92            debug!("Saved {} namespace tokens to {}", tokens.len(), expanded);
93        }
94        Ok(())
95    }
96
97    /// Generate a new token for a namespace
98    pub fn generate_token() -> String {
99        format!(
100            "{}{}",
101            TOKEN_PREFIX,
102            Uuid::new_v4().to_string().replace("-", "")
103        )
104    }
105
106    /// Create or update a token for a namespace
107    pub async fn create_token(
108        &self,
109        namespace: &str,
110        description: Option<String>,
111    ) -> Result<String> {
112        let token = Self::generate_token();
113        let namespace_token = NamespaceToken {
114            namespace: namespace.to_string(),
115            token: token.clone(),
116            created_at: std::time::SystemTime::now()
117                .duration_since(std::time::UNIX_EPOCH)
118                .unwrap_or_default()
119                .as_secs(),
120            description,
121        };
122
123        {
124            let mut tokens = self.tokens.write().await;
125            tokens.insert(namespace.to_string(), namespace_token);
126        }
127
128        self.save().await?;
129        info!("Created token for namespace '{}'", namespace);
130        Ok(token)
131    }
132
133    /// Verify a token for a namespace
134    pub async fn verify_token(&self, namespace: &str, token: &str) -> bool {
135        let tokens = self.tokens.read().await;
136        if let Some(stored) = tokens.get(namespace) {
137            stored.token == token
138        } else {
139            // If no token is set for this namespace, access is allowed
140            // (backward compatibility - namespaces without tokens are open)
141            true
142        }
143    }
144
145    /// Check if a namespace has a token set
146    pub async fn has_token(&self, namespace: &str) -> bool {
147        let tokens = self.tokens.read().await;
148        tokens.contains_key(namespace)
149    }
150
151    /// Get token info for a namespace (without revealing the actual token)
152    pub async fn get_token_info(&self, namespace: &str) -> Option<(u64, Option<String>)> {
153        let tokens = self.tokens.read().await;
154        tokens
155            .get(namespace)
156            .map(|t| (t.created_at, t.description.clone()))
157    }
158
159    /// Revoke (delete) a token for a namespace
160    pub async fn revoke_token(&self, namespace: &str) -> Result<bool> {
161        let removed = {
162            let mut tokens = self.tokens.write().await;
163            tokens.remove(namespace).is_some()
164        };
165
166        if removed {
167            self.save().await?;
168            info!("Revoked token for namespace '{}'", namespace);
169        }
170
171        Ok(removed)
172    }
173
174    /// List all namespaces that have tokens (without revealing tokens)
175    pub async fn list_protected_namespaces(&self) -> Vec<(String, u64, Option<String>)> {
176        let tokens = self.tokens.read().await;
177        tokens
178            .values()
179            .map(|t| (t.namespace.clone(), t.created_at, t.description.clone()))
180            .collect()
181    }
182}
183
184/// Namespace access manager that combines token verification with access control
185#[derive(Debug)]
186pub struct NamespaceAccessManager {
187    /// Token store for managing tokens
188    token_store: TokenStore,
189    /// Whether token-based access control is enabled
190    enabled: bool,
191}
192
193impl NamespaceAccessManager {
194    /// Create a new namespace access manager
195    pub fn new(config: NamespaceSecurityConfig) -> Self {
196        let store_path = config.token_store_path.or_else(|| {
197            if config.enabled {
198                Some("~/.rmcp-servers/rmcp-memex/tokens.json".to_string())
199            } else {
200                None
201            }
202        });
203
204        Self {
205            token_store: TokenStore::new(store_path),
206            enabled: config.enabled,
207        }
208    }
209
210    /// Initialize the access manager (load tokens from storage)
211    pub async fn init(&self) -> Result<()> {
212        if self.enabled {
213            self.token_store.load().await?;
214        }
215        Ok(())
216    }
217
218    /// Check if access control is enabled
219    pub fn is_enabled(&self) -> bool {
220        self.enabled
221    }
222
223    /// Verify access to a namespace
224    /// Returns Ok(()) if access is granted, Err if denied
225    pub async fn verify_access(&self, namespace: &str, token: Option<&str>) -> Result<()> {
226        if !self.enabled {
227            return Ok(());
228        }
229
230        // Check if namespace has a token
231        if !self.token_store.has_token(namespace).await {
232            // No token set for this namespace - allow access
233            return Ok(());
234        }
235
236        // Namespace has a token - verify it
237        match token {
238            Some(t) => {
239                if self.token_store.verify_token(namespace, t).await {
240                    Ok(())
241                } else {
242                    warn!("Invalid token provided for namespace '{}'", namespace);
243                    Err(anyhow!(
244                        "Access denied: invalid token for namespace '{}'",
245                        namespace
246                    ))
247                }
248            }
249            None => {
250                warn!("No token provided for protected namespace '{}'", namespace);
251                Err(anyhow!(
252                    "Access denied: namespace '{}' requires a token. Use namespace_create_token to generate one.",
253                    namespace
254                ))
255            }
256        }
257    }
258
259    /// Create a token for a namespace
260    pub async fn create_token(
261        &self,
262        namespace: &str,
263        description: Option<String>,
264    ) -> Result<String> {
265        self.token_store.create_token(namespace, description).await
266    }
267
268    /// Revoke a token for a namespace
269    pub async fn revoke_token(&self, namespace: &str) -> Result<bool> {
270        self.token_store.revoke_token(namespace).await
271    }
272
273    /// List protected namespaces
274    pub async fn list_protected_namespaces(&self) -> Vec<(String, u64, Option<String>)> {
275        self.token_store.list_protected_namespaces().await
276    }
277
278    /// Get token info for a namespace
279    pub async fn get_token_info(&self, namespace: &str) -> Option<(u64, Option<String>)> {
280        self.token_store.get_token_info(namespace).await
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Unique file path in the system temp dir, deleted on drop.
    ///
    /// Tests must never fall back to the default store path under
    /// `~/.rmcp-servers`; that would overwrite the developer's real tokens.
    struct TempStorePath(std::path::PathBuf);

    impl TempStorePath {
        fn new() -> Self {
            Self(std::env::temp_dir().join(format!("rmcp-memex-tokens-{}.json", Uuid::new_v4())))
        }

        fn as_string(&self) -> String {
            self.0.to_string_lossy().into_owned()
        }
    }

    impl Drop for TempStorePath {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_token_generation() {
        let token = TokenStore::generate_token();
        let suffix = token
            .strip_prefix(TOKEN_PREFIX)
            .expect("generated token must start with TOKEN_PREFIX");
        // A v4 UUID without dashes is 32 hex digits.
        assert_eq!(suffix.len(), 32);
        assert!(suffix.chars().all(|c| c.is_ascii_hexdigit()));
        assert_ne!(token, TokenStore::generate_token());
    }

    #[tokio::test]
    async fn test_token_store_create_and_verify() {
        let store = TokenStore::new(None);

        let token = store
            .create_token("test_namespace", Some("Test token".to_string()))
            .await
            .unwrap();

        assert!(store.verify_token("test_namespace", &token).await);
        assert!(!store.verify_token("test_namespace", "wrong_token").await);
        assert!(store.verify_token("other_namespace", "any_token").await); // No token set
    }

    #[tokio::test]
    async fn test_access_manager_disabled() {
        let config = NamespaceSecurityConfig::default();
        let manager = NamespaceAccessManager::new(config);

        // When disabled, all access should be allowed
        assert!(manager.verify_access("any_namespace", None).await.is_ok());
    }

    #[tokio::test]
    async fn test_access_manager_enabled() {
        // Explicit temp path: with `None` the manager would persist to the
        // user's real `~/.rmcp-servers/rmcp-memex/tokens.json`.
        let store_path = TempStorePath::new();
        let config = NamespaceSecurityConfig {
            enabled: true,
            token_store_path: Some(store_path.as_string()),
        };
        let manager = NamespaceAccessManager::new(config);

        // Create a token for a namespace
        let token = manager
            .create_token("protected", Some("Test".to_string()))
            .await
            .unwrap();

        // Access without token should fail
        assert!(manager.verify_access("protected", None).await.is_err());

        // Access with wrong token should fail
        assert!(
            manager
                .verify_access("protected", Some("wrong"))
                .await
                .is_err()
        );

        // Access with correct token should succeed
        assert!(
            manager
                .verify_access("protected", Some(&token))
                .await
                .is_ok()
        );

        // Unprotected namespace should allow access without token
        assert!(manager.verify_access("unprotected", None).await.is_ok());
    }

    #[tokio::test]
    async fn test_tokens_persist_across_stores() {
        let store_path = TempStorePath::new();
        let writer = TokenStore::new(Some(store_path.as_string()));
        let token = writer
            .create_token("persisted", Some("Persist".to_string()))
            .await
            .unwrap();

        let reader = TokenStore::new(Some(store_path.as_string()));
        reader.load().await.unwrap();
        assert!(reader.has_token("persisted").await);
        assert!(reader.verify_token("persisted", &token).await);
        assert!(!reader.verify_token("persisted", "wrong").await);
        assert_eq!(
            reader
                .get_token_info("persisted")
                .await
                .map(|(_, description)| description),
            Some(Some("Persist".to_string()))
        );
    }

    #[tokio::test]
    async fn test_load_missing_file_is_ok() {
        let store_path = TempStorePath::new();
        let store = TokenStore::new(Some(store_path.as_string()));
        assert!(store.load().await.is_ok());
        assert!(store.list_protected_namespaces().await.is_empty());
    }

    #[tokio::test]
    async fn test_revoke_token() {
        let store = TokenStore::new(None);
        store.create_token("ns", None).await.unwrap();

        assert!(store.revoke_token("ns").await.unwrap());
        assert!(!store.has_token("ns").await);
        // Revoking again reports that nothing was removed.
        assert!(!store.revoke_token("ns").await.unwrap());
    }
}