Skip to main content

ai_session/security/
mod.rs

1//! Security and isolation features for AI sessions
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7
8/// Secure session with isolation and access control
9pub struct SecureSession {
10    /// Namespace isolation
11    pub namespace: Namespace,
12    /// Resource limits
13    pub cgroups: CGroupLimits,
14    /// Security policy
15    pub security_policy: SecurityPolicy,
16    /// Audit log
17    pub audit_log: AuditLog,
18}
19
20impl SecureSession {
21    /// Create a new secure session
22    pub fn new(session_id: &str) -> Self {
23        Self {
24            namespace: Namespace::new(session_id),
25            cgroups: CGroupLimits::default(),
26            security_policy: SecurityPolicy::default(),
27            audit_log: AuditLog::new(),
28        }
29    }
30
31    /// Apply security policy
32    pub fn apply_policy(&mut self, policy: SecurityPolicy) -> Result<()> {
33        self.security_policy = policy;
34        self.audit_log.log(AuditEvent::PolicyApplied {
35            policy_name: "custom".to_string(),
36        })?;
37        Ok(())
38    }
39
40    /// Check if an action is allowed
41    pub fn is_allowed(&self, action: &Action) -> bool {
42        match action {
43            Action::FileAccess { path, mode } => {
44                self.security_policy.fs_permissions.is_allowed(path, mode)
45            }
46            Action::NetworkAccess { host, port } => {
47                self.security_policy.network_policy.is_allowed(host, *port)
48            }
49            Action::SystemCall { syscall } => {
50                self.security_policy.syscall_access.is_allowed(syscall)
51            }
52            Action::APICall { endpoint, method } => {
53                self.security_policy.api_limits.is_allowed(endpoint, method)
54            }
55        }
56    }
57
58    /// Audit an action
59    pub fn audit_action(&mut self, action: Action, allowed: bool) -> Result<()> {
60        self.audit_log.log(AuditEvent::ActionAttempted {
61            action,
62            allowed,
63            timestamp: chrono::Utc::now(),
64        })
65    }
66}
67
/// Namespace-isolation settings for a session.
///
/// Holds an identifier plus one flag per namespace kind. Nothing in this
/// module reads the flags, so enforcing them is left to the caller.
#[derive(Debug, Clone)]
pub struct Namespace {
    /// Identifier: `"ai-session-"` followed by the session id.
    pub id: String,
    /// Whether the PID namespace is isolated.
    pub pid_ns: bool,
    /// Whether the network namespace is isolated.
    pub net_ns: bool,
    /// Whether the mount namespace is isolated.
    pub mnt_ns: bool,
    /// Whether the user namespace is isolated.
    pub user_ns: bool,
}

impl Namespace {
    /// Creates settings for session `id` with every namespace kind enabled.
    pub fn new(id: &str) -> Self {
        const ENABLED: bool = true;
        Namespace {
            id: ["ai-session-", id].concat(),
            pid_ns: ENABLED,
            net_ns: ENABLED,
            mnt_ns: ENABLED,
            user_ns: ENABLED,
        }
    }
}
95
96/// Resource limits using cgroups
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct CGroupLimits {
99    /// CPU limit (percentage)
100    pub cpu_limit: f64,
101    /// Memory limit (bytes)
102    pub memory_limit: usize,
103    /// Disk I/O limit (bytes/sec)
104    pub io_limit: usize,
105    /// Maximum processes
106    pub pids_limit: usize,
107}
108
109impl Default for CGroupLimits {
110    fn default() -> Self {
111        Self {
112            cpu_limit: 50.0,       // 50% CPU
113            memory_limit: 1 << 30, // 1GB
114            io_limit: 100 << 20,   // 100MB/s
115            pids_limit: 100,       // 100 processes
116        }
117    }
118}
119
120/// Security policy
121#[derive(Debug, Clone, Default)]
122pub struct SecurityPolicy {
123    /// File system access control
124    pub fs_permissions: FileSystemPermissions,
125    /// Network access control
126    pub network_policy: NetworkPolicy,
127    /// System call access
128    pub syscall_access: SyscallAccess,
129    /// API call limits
130    pub api_limits: APILimits,
131    /// Secret manager
132    pub secret_manager: SecretManager,
133}
134
135/// File system permissions
136#[derive(Debug, Clone, Default)]
137pub struct FileSystemPermissions {
138    /// Allowed paths
139    allowed_paths: Vec<PathBuf>,
140    /// Denied paths
141    denied_paths: Vec<PathBuf>,
142    /// Read-only paths
143    readonly_paths: Vec<PathBuf>,
144}
145
146impl FileSystemPermissions {
147    /// Check if file access is allowed
148    pub fn is_allowed(&self, path: &Path, mode: &FileAccessMode) -> bool {
149        // Check denied paths first
150        for denied in &self.denied_paths {
151            if path.starts_with(denied) {
152                return false;
153            }
154        }
155
156        // Check if in allowed paths
157        let in_allowed = self
158            .allowed_paths
159            .iter()
160            .any(|allowed| path.starts_with(allowed));
161        if !in_allowed && !self.allowed_paths.is_empty() {
162            return false;
163        }
164
165        // Check read-only restrictions
166        if matches!(mode, FileAccessMode::Write | FileAccessMode::Execute) {
167            for readonly in &self.readonly_paths {
168                if path.starts_with(readonly) {
169                    return false;
170                }
171            }
172        }
173
174        true
175    }
176
177    /// Add an allowed path
178    pub fn allow_path(&mut self, path: PathBuf) {
179        self.allowed_paths.push(path);
180    }
181
182    /// Add a denied path
183    pub fn deny_path(&mut self, path: PathBuf) {
184        self.denied_paths.push(path);
185    }
186
187    /// Add a read-only path
188    pub fn readonly_path(&mut self, path: PathBuf) {
189        self.readonly_paths.push(path);
190    }
191}
192
193/// File access mode
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub enum FileAccessMode {
196    Read,
197    Write,
198    Execute,
199}
200
201/// Network access policy
202#[derive(Debug, Clone, Default)]
203pub struct NetworkPolicy {
204    /// Allowed hosts
205    allowed_hosts: Vec<String>,
206    /// Blocked hosts
207    blocked_hosts: Vec<String>,
208    /// Allowed ports
209    allowed_ports: Vec<u16>,
210    /// Rate limits per host
211    rate_limits: HashMap<String, RateLimit>,
212}
213
214impl NetworkPolicy {
215    /// Check if network access is allowed
216    pub fn is_allowed(&self, host: &str, port: u16) -> bool {
217        // Check blocked hosts
218        if self.blocked_hosts.iter().any(|h| h == host) {
219            return false;
220        }
221
222        // Check allowed hosts
223        if !self.allowed_hosts.is_empty() && !self.allowed_hosts.iter().any(|h| h == host) {
224            return false;
225        }
226
227        // Check allowed ports
228        if !self.allowed_ports.is_empty() && !self.allowed_ports.contains(&port) {
229            return false;
230        }
231
232        // Check rate limits
233        if let Some(limit) = self.rate_limits.get(host) {
234            return limit.check();
235        }
236
237        true
238    }
239}
240
/// Allow/deny rules for system calls, matched by exact (case-sensitive) name.
#[derive(Debug, Clone, Default)]
pub struct SyscallAccess {
    /// Allowed syscall names; when empty, every syscall that is not blocked is allowed.
    allowed_syscalls: Vec<String>,
    /// Blocked syscall names; checked first, so they win over the allow-list.
    blocked_syscalls: Vec<String>,
}

impl SyscallAccess {
    /// Returns whether `syscall` may be invoked.
    ///
    /// A blocked name is always refused. Otherwise the call is allowed when
    /// the allow-list is empty or contains `syscall`.
    pub fn is_allowed(&self, syscall: &str) -> bool {
        // Compare borrowed strings: `Vec<String>::contains` would force a
        // `String` allocation on every lookup.
        if self.blocked_syscalls.iter().any(|name| name == syscall) {
            return false;
        }
        self.allowed_syscalls.is_empty()
            || self.allowed_syscalls.iter().any(|name| name == syscall)
    }

    /// Adds `syscall` to the allow-list. Once the list is non-empty, only
    /// listed syscalls are permitted.
    pub fn allow_syscall(&mut self, syscall: impl Into<String>) {
        self.allowed_syscalls.push(syscall.into());
    }

    /// Adds `syscall` to the block-list; blocked names are refused even if also allowed.
    pub fn block_syscall(&mut self, syscall: impl Into<String>) {
        self.blocked_syscalls.push(syscall.into());
    }
}
264
265/// API call limits
266#[derive(Debug, Clone, Default)]
267pub struct APILimits {
268    /// Rate limits per endpoint
269    endpoint_limits: HashMap<String, RateLimit>,
270    /// Global rate limit
271    global_limit: Option<RateLimit>,
272    /// Token limits
273    _token_limits: TokenLimits,
274}
275
276impl APILimits {
277    /// Check if API call is allowed
278    pub fn is_allowed(&self, endpoint: &str, _method: &str) -> bool {
279        // Check endpoint-specific limit
280        if let Some(limit) = self.endpoint_limits.get(endpoint)
281            && !limit.check()
282        {
283            return false;
284        }
285
286        // Check global limit
287        if let Some(ref limit) = self.global_limit
288            && !limit.check()
289        {
290            return false;
291        }
292
293        true
294    }
295}
296
/// Caps on token usage per request, per hour and per day.
///
/// Every field defaults to `0`. Nothing in this module enforces these caps yet.
#[derive(Debug, Clone, Default)]
pub struct TokenLimits {
    /// Largest number of tokens allowed in a single request.
    pub max_tokens_per_request: usize,
    /// Token budget per hour.
    pub max_tokens_per_hour: usize,
    /// Token budget per day.
    pub max_tokens_per_day: usize,
}
307
/// Fixed-window rate limiter.
///
/// Admits at most `requests_per_minute` calls to [`RateLimit::check`] per
/// 60-second window. The first window starts at construction. A new window
/// starts on the first `check` made 60 s or more after the current one began.
/// The state sits behind a `Mutex`, so `check` only needs `&self`.
#[derive(Debug)]
pub struct RateLimit {
    /// Maximum requests admitted per window.
    pub requests_per_minute: usize,
    /// Start of the current window and the number of requests admitted in it.
    window: std::sync::Mutex<(std::time::Instant, usize)>,
}

impl RateLimit {
    /// Length of one rate-limiting window.
    const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

    /// Creates a limiter whose first (empty) window starts now.
    pub fn new(requests_per_minute: usize) -> Self {
        Self {
            requests_per_minute,
            window: std::sync::Mutex::new((std::time::Instant::now(), 0)),
        }
    }

    /// Checks the current window.
    ///
    /// Returns `true` and counts the request if the window still has capacity.
    /// Otherwise returns `false` without counting.
    pub fn check(&self) -> bool {
        self.check_at(std::time::Instant::now())
    }

    /// [`RateLimit::check`] evaluated against the clock reading `now`.
    fn check_at(&self, now: std::time::Instant) -> bool {
        // The critical sections below cannot panic, so a poisoned lock still
        // holds consistent data; recover it instead of panicking.
        let mut guard = self
            .window
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let (started, admitted) = &mut *guard;

        // Start a fresh window once a full minute has passed. Without this,
        // the count would be a lifetime cap rather than a per-minute rate.
        if now.saturating_duration_since(*started) >= Self::WINDOW {
            *started = now;
            *admitted = 0;
        }

        if *admitted < self.requests_per_minute {
            *admitted += 1;
            true
        } else {
            false
        }
    }
}

impl Clone for RateLimit {
    /// Copies the limit together with the current window state.
    fn clone(&self) -> Self {
        let snapshot = *self
            .window
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        Self {
            requests_per_minute: self.requests_per_minute,
            window: std::sync::Mutex::new(snapshot),
        }
    }
}
351
352/// Secret manager
353#[derive(Debug, Clone, Default)]
354pub struct SecretManager {
355    /// Encrypted secrets
356    secrets: HashMap<String, Vec<u8>>,
357}
358
359impl SecretManager {
360    /// Store a secret
361    pub fn store_secret(&mut self, key: &str, value: &[u8]) -> Result<()> {
362        // In a real implementation, this would encrypt the value
363        self.secrets.insert(key.to_string(), value.to_vec());
364        Ok(())
365    }
366
367    /// Retrieve a secret
368    pub fn get_secret(&self, key: &str) -> Option<Vec<u8>> {
369        self.secrets.get(key).cloned()
370    }
371}
372
373/// Audit log
374pub struct AuditLog {
375    /// Log entries
376    entries: Vec<AuditEvent>,
377}
378
379impl AuditLog {
380    /// Create a new audit log
381    pub fn new() -> Self {
382        Self {
383            entries: Vec::new(),
384        }
385    }
386
387    /// Log an event
388    pub fn log(&mut self, event: AuditEvent) -> Result<()> {
389        self.entries.push(event);
390        Ok(())
391    }
392
393    /// Get all entries
394    pub fn entries(&self) -> &[AuditEvent] {
395        &self.entries
396    }
397}
398
399impl Default for AuditLog {
400    fn default() -> Self {
401        Self::new()
402    }
403}
404
405/// Audit event
406#[derive(Debug, Clone, Serialize, Deserialize)]
407pub enum AuditEvent {
408    /// Session created
409    SessionCreated {
410        session_id: String,
411        timestamp: chrono::DateTime<chrono::Utc>,
412    },
413    /// Policy applied
414    PolicyApplied { policy_name: String },
415    /// Action attempted
416    ActionAttempted {
417        action: Action,
418        allowed: bool,
419        timestamp: chrono::DateTime<chrono::Utc>,
420    },
421    /// Security violation
422    SecurityViolation {
423        description: String,
424        severity: Severity,
425        timestamp: chrono::DateTime<chrono::Utc>,
426    },
427}
428
429/// Security action
430#[derive(Debug, Clone, Serialize, Deserialize)]
431pub enum Action {
432    /// File system access
433    FileAccess { path: PathBuf, mode: FileAccessMode },
434    /// Network access
435    NetworkAccess { host: String, port: u16 },
436    /// System call
437    SystemCall { syscall: String },
438    /// API call
439    APICall { endpoint: String, method: String },
440}
441
442/// Security severity
443#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
444pub enum Severity {
445    Low,
446    Medium,
447    High,
448    Critical,
449}
450
451/// Capability-based security
452pub trait Capabilities {
453    /// Request a capability
454    fn request_capability(&self, capability: Capability) -> Result<CapabilityToken>;
455
456    /// Check if capability is granted
457    fn has_capability(&self, capability: &Capability) -> bool;
458
459    /// Revoke a capability
460    fn revoke_capability(&mut self, token: CapabilityToken) -> Result<()>;
461}
462
/// A single permission that can be granted through [`Capabilities`].
///
/// Derives `Hash` and `Eq`, so capabilities can be stored in sets and used as map keys.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum Capability {
    /// Read access to a path.
    FileRead(PathBuf),
    /// Write access to a path.
    FileWrite(PathBuf),
    /// Network access to a host and port.
    NetworkAccess(String, u16),
    /// Permission to spawn the named process.
    ProcessSpawn(String),
    /// Permission to make the named system call.
    SystemCall(String),
}
472
473/// Capability token
474#[derive(Debug, Clone, Hash, Eq, PartialEq)]
475pub struct CapabilityToken {
476    /// Token ID
477    pub id: uuid::Uuid,
478    /// Capability
479    pub capability: Capability,
480    /// Expiration
481    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_file_permissions() {
        let mut perms = FileSystemPermissions::default();
        perms.allow_path(PathBuf::from("/tmp"));
        perms.deny_path(PathBuf::from("/tmp/secret"));
        perms.readonly_path(PathBuf::from("/tmp/readonly"));

        // (path, mode, expected decision)
        let cases = [
            ("/tmp/test.txt", FileAccessMode::Read, true),
            ("/tmp/secret/key.txt", FileAccessMode::Read, false),
            ("/tmp/readonly/file.txt", FileAccessMode::Write, false),
        ];
        for (path, mode, expected) in cases {
            assert_eq!(
                perms.is_allowed(Path::new(path), &mode),
                expected,
                "{path} ({mode:?})"
            );
        }
    }

    #[test]
    fn test_audit_log() {
        let mut log = AuditLog::new();
        let event = AuditEvent::SessionCreated {
            session_id: String::from("test-123"),
            timestamp: chrono::Utc::now(),
        };
        log.log(event).expect("in-memory audit log accepts events");
        assert_eq!(log.entries().len(), 1);
    }
}