1use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7
8pub struct SecureSession {
10 pub namespace: Namespace,
12 pub cgroups: CGroupLimits,
14 pub security_policy: SecurityPolicy,
16 pub audit_log: AuditLog,
18}
19
20impl SecureSession {
21 pub fn new(session_id: &str) -> Self {
23 Self {
24 namespace: Namespace::new(session_id),
25 cgroups: CGroupLimits::default(),
26 security_policy: SecurityPolicy::default(),
27 audit_log: AuditLog::new(),
28 }
29 }
30
31 pub fn apply_policy(&mut self, policy: SecurityPolicy) -> Result<()> {
33 self.security_policy = policy;
34 self.audit_log.log(AuditEvent::PolicyApplied {
35 policy_name: "custom".to_string(),
36 })?;
37 Ok(())
38 }
39
40 pub fn is_allowed(&self, action: &Action) -> bool {
42 match action {
43 Action::FileAccess { path, mode } => {
44 self.security_policy.fs_permissions.is_allowed(path, mode)
45 }
46 Action::NetworkAccess { host, port } => {
47 self.security_policy.network_policy.is_allowed(host, *port)
48 }
49 Action::SystemCall { syscall } => {
50 self.security_policy.syscall_access.is_allowed(syscall)
51 }
52 Action::APICall { endpoint, method } => {
53 self.security_policy.api_limits.is_allowed(endpoint, method)
54 }
55 }
56 }
57
58 pub fn audit_action(&mut self, action: Action, allowed: bool) -> Result<()> {
60 self.audit_log.log(AuditEvent::ActionAttempted {
61 action,
62 allowed,
63 timestamp: chrono::Utc::now(),
64 })
65 }
66}
67
/// Namespace isolation flags for a session. This type only records the
/// settings; nothing in it applies them.
#[derive(Debug, Clone)]
pub struct Namespace {
    /// Identifier of the form `ai-session-<id>`.
    pub id: String,
    /// PID namespace flag.
    pub pid_ns: bool,
    /// Network namespace flag.
    pub net_ns: bool,
    /// Mount namespace flag.
    pub mnt_ns: bool,
    /// User namespace flag.
    pub user_ns: bool,
}

impl Namespace {
    /// Builds the namespace `ai-session-{id}` with every flag set to `true`.
    pub fn new(id: &str) -> Self {
        let enabled = true;
        Self {
            id: ["ai-session-", id].concat(),
            pid_ns: enabled,
            net_ns: enabled,
            mnt_ns: enabled,
            user_ns: enabled,
        }
    }
}

96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct CGroupLimits {
99 pub cpu_limit: f64,
101 pub memory_limit: usize,
103 pub io_limit: usize,
105 pub pids_limit: usize,
107}
108
109impl Default for CGroupLimits {
110 fn default() -> Self {
111 Self {
112 cpu_limit: 50.0, memory_limit: 1 << 30, io_limit: 100 << 20, pids_limit: 100, }
117 }
118}
119
120#[derive(Debug, Clone, Default)]
122pub struct SecurityPolicy {
123 pub fs_permissions: FileSystemPermissions,
125 pub network_policy: NetworkPolicy,
127 pub syscall_access: SyscallAccess,
129 pub api_limits: APILimits,
131 pub secret_manager: SecretManager,
133}
134
135#[derive(Debug, Clone, Default)]
137pub struct FileSystemPermissions {
138 allowed_paths: Vec<PathBuf>,
140 denied_paths: Vec<PathBuf>,
142 readonly_paths: Vec<PathBuf>,
144}
145
146impl FileSystemPermissions {
147 pub fn is_allowed(&self, path: &Path, mode: &FileAccessMode) -> bool {
149 for denied in &self.denied_paths {
151 if path.starts_with(denied) {
152 return false;
153 }
154 }
155
156 let in_allowed = self
158 .allowed_paths
159 .iter()
160 .any(|allowed| path.starts_with(allowed));
161 if !in_allowed && !self.allowed_paths.is_empty() {
162 return false;
163 }
164
165 if matches!(mode, FileAccessMode::Write | FileAccessMode::Execute) {
167 for readonly in &self.readonly_paths {
168 if path.starts_with(readonly) {
169 return false;
170 }
171 }
172 }
173
174 true
175 }
176
177 pub fn allow_path(&mut self, path: PathBuf) {
179 self.allowed_paths.push(path);
180 }
181
182 pub fn deny_path(&mut self, path: PathBuf) {
184 self.denied_paths.push(path);
185 }
186
187 pub fn readonly_path(&mut self, path: PathBuf) {
189 self.readonly_paths.push(path);
190 }
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
195pub enum FileAccessMode {
196 Read,
197 Write,
198 Execute,
199}
200
201#[derive(Debug, Clone, Default)]
203pub struct NetworkPolicy {
204 allowed_hosts: Vec<String>,
206 blocked_hosts: Vec<String>,
208 allowed_ports: Vec<u16>,
210 rate_limits: HashMap<String, RateLimit>,
212}
213
214impl NetworkPolicy {
215 pub fn is_allowed(&self, host: &str, port: u16) -> bool {
217 if self.blocked_hosts.iter().any(|h| h == host) {
219 return false;
220 }
221
222 if !self.allowed_hosts.is_empty() && !self.allowed_hosts.iter().any(|h| h == host) {
224 return false;
225 }
226
227 if !self.allowed_ports.is_empty() && !self.allowed_ports.contains(&port) {
229 return false;
230 }
231
232 if let Some(limit) = self.rate_limits.get(host) {
234 return limit.check();
235 }
236
237 true
238 }
239}
240
/// Block-list / allow-list of system-call names.
#[derive(Debug, Clone, Default)]
pub struct SyscallAccess {
    /// Permitted names; empty means any name that is not blocked.
    allowed_syscalls: Vec<String>,
    /// Names that are always rejected; takes precedence over the allow-list.
    blocked_syscalls: Vec<String>,
}

impl SyscallAccess {
    /// Returns whether `syscall` is permitted. Names match exactly (case-sensitive).
    pub fn is_allowed(&self, syscall: &str) -> bool {
        // Compare as `&str`; the previous `contains(&syscall.to_string())`
        // allocated a fresh `String` for every list consulted.
        let listed = |names: &[String]| names.iter().any(|name| name == syscall);
        !listed(&self.blocked_syscalls)
            && (self.allowed_syscalls.is_empty() || listed(&self.allowed_syscalls))
    }
}

265#[derive(Debug, Clone, Default)]
267pub struct APILimits {
268 endpoint_limits: HashMap<String, RateLimit>,
270 global_limit: Option<RateLimit>,
272 _token_limits: TokenLimits,
274}
275
276impl APILimits {
277 pub fn is_allowed(&self, endpoint: &str, _method: &str) -> bool {
279 if let Some(limit) = self.endpoint_limits.get(endpoint)
281 && !limit.check()
282 {
283 return false;
284 }
285
286 if let Some(ref limit) = self.global_limit
288 && !limit.check()
289 {
290 return false;
291 }
292
293 true
294 }
295}
296
/// Token budgets per request, per hour and per day (presumably model tokens —
/// confirm). All default to `0`. `APILimits` stores this struct but does not
/// consult it in `is_allowed`.
#[derive(Default, Clone, Debug)]
pub struct TokenLimits {
    /// Maximum tokens in a single request.
    pub max_tokens_per_request: usize,
    /// Maximum tokens per hour.
    pub max_tokens_per_hour: usize,
    /// Maximum tokens per day.
    pub max_tokens_per_day: usize,
}

/// Fixed-window request limiter: at most `requests_per_minute` successful
/// [`RateLimit::check`] calls per 60-second window.
///
/// The first window opens at construction. A check made after the window has
/// expired starts a new window with a fresh count. All state lives behind a
/// mutex so `check` works through `&self`.
#[derive(Debug)]
pub struct RateLimit {
    /// Maximum number of granted checks per window.
    pub requests_per_minute: usize,
    /// `(window start, checks granted in this window)`.
    window: std::sync::Mutex<(std::time::Instant, usize)>,
}

impl RateLimit {
    /// Length of one accounting window.
    const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

    /// Creates a limiter whose first window starts now.
    pub fn new(requests_per_minute: usize) -> Self {
        Self {
            requests_per_minute,
            window: std::sync::Mutex::new((std::time::Instant::now(), 0)),
        }
    }

    /// Grants one request from the current window. Returns `false` once the
    /// window's budget is exhausted; the budget refills when the window expires.
    pub fn check(&self) -> bool {
        self.check_at(std::time::Instant::now())
    }

    /// [`RateLimit::check`] against an explicit clock reading.
    fn check_at(&self, now: std::time::Instant) -> bool {
        // The guarded tuple is never left half-updated, so a poisoned lock
        // still holds consistent data; recover it instead of panicking.
        let mut guard = self.window.lock().unwrap_or_else(|p| p.into_inner());
        let (start, count) = &mut *guard;
        // Without this reset the counter only ever grows, and the limiter
        // blocks permanently after its first `requests_per_minute` calls.
        if now.saturating_duration_since(*start) >= Self::WINDOW {
            *start = now;
            *count = 0;
        }
        if *count < self.requests_per_minute {
            *count += 1;
            true
        } else {
            false
        }
    }
}

impl Clone for RateLimit {
    /// Copies the limit together with its current window start and count.
    fn clone(&self) -> Self {
        let snapshot = *self.window.lock().unwrap_or_else(|p| p.into_inner());
        Self {
            requests_per_minute: self.requests_per_minute,
            window: std::sync::Mutex::new(snapshot),
        }
    }
}

352#[derive(Debug, Clone, Default)]
354pub struct SecretManager {
355 secrets: HashMap<String, Vec<u8>>,
357}
358
359impl SecretManager {
360 pub fn store_secret(&mut self, key: &str, value: &[u8]) -> Result<()> {
362 self.secrets.insert(key.to_string(), value.to_vec());
364 Ok(())
365 }
366
367 pub fn get_secret(&self, key: &str) -> Option<Vec<u8>> {
369 self.secrets.get(key).cloned()
370 }
371}
372
373pub struct AuditLog {
375 entries: Vec<AuditEvent>,
377}
378
379impl AuditLog {
380 pub fn new() -> Self {
382 Self {
383 entries: Vec::new(),
384 }
385 }
386
387 pub fn log(&mut self, event: AuditEvent) -> Result<()> {
389 self.entries.push(event);
390 Ok(())
391 }
392
393 pub fn entries(&self) -> &[AuditEvent] {
395 &self.entries
396 }
397}
398
399impl Default for AuditLog {
400 fn default() -> Self {
401 Self::new()
402 }
403}
404
405#[derive(Debug, Clone, Serialize, Deserialize)]
407pub enum AuditEvent {
408 SessionCreated {
410 session_id: String,
411 timestamp: chrono::DateTime<chrono::Utc>,
412 },
413 PolicyApplied { policy_name: String },
415 ActionAttempted {
417 action: Action,
418 allowed: bool,
419 timestamp: chrono::DateTime<chrono::Utc>,
420 },
421 SecurityViolation {
423 description: String,
424 severity: Severity,
425 timestamp: chrono::DateTime<chrono::Utc>,
426 },
427}
428
429#[derive(Debug, Clone, Serialize, Deserialize)]
431pub enum Action {
432 FileAccess { path: PathBuf, mode: FileAccessMode },
434 NetworkAccess { host: String, port: u16 },
436 SystemCall { syscall: String },
438 APICall { endpoint: String, method: String },
440}
441
442#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
444pub enum Severity {
445 Low,
446 Medium,
447 High,
448 Critical,
449}
450
451pub trait Capabilities {
453 fn request_capability(&self, capability: Capability) -> Result<CapabilityToken>;
455
456 fn has_capability(&self, capability: &Capability) -> bool;
458
459 fn revoke_capability(&mut self, token: CapabilityToken) -> Result<()>;
461}
462
/// A single permission that can be granted through [`Capabilities`].
#[derive(Debug, Clone)]
#[derive(Hash, Eq, PartialEq)]
pub enum Capability {
    /// Read access to a path.
    FileRead(PathBuf),
    /// Write access to a path.
    FileWrite(PathBuf),
    /// Network access, presumably `(host, port)`.
    NetworkAccess(String, u16),
    /// Spawning the named process.
    ProcessSpawn(String),
    /// Invoking the named system call.
    SystemCall(String),
}

473#[derive(Debug, Clone, Hash, Eq, PartialEq)]
475pub struct CapabilityToken {
476 pub id: uuid::Uuid,
478 pub capability: Capability,
480 pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Deny and read-only prefixes take precedence within the `/tmp` allow-list.
    #[test]
    fn test_file_permissions() {
        let mut perms = FileSystemPermissions::default();
        perms.allow_path("/tmp".into());
        perms.deny_path("/tmp/secret".into());
        perms.readonly_path("/tmp/readonly".into());

        let check = |p: &str, mode: FileAccessMode| perms.is_allowed(Path::new(p), &mode);
        assert!(check("/tmp/test.txt", FileAccessMode::Read));
        assert!(!check("/tmp/secret/key.txt", FileAccessMode::Read));
        assert!(!check("/tmp/readonly/file.txt", FileAccessMode::Write));
    }

    /// Logging a single event yields exactly one entry.
    #[test]
    fn test_audit_log() {
        let mut log = AuditLog::new();
        let created = AuditEvent::SessionCreated {
            session_id: String::from("test-123"),
            timestamp: chrono::Utc::now(),
        };
        log.log(created).unwrap();

        assert_eq!(log.entries().len(), 1);
    }
}