use crate::InfernoError;
use anyhow::Result;
use argon2::password_hash::{SaltString, rand_core::OsRng};
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::IpAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::fs;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, info, warn};
/// Runtime security configuration: authentication, rate limiting,
/// IP filtering, audit logging, input/output hygiene, and persistence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Master switch for password/JWT authentication.
    pub auth_enabled: bool,
    /// HMAC secret for HS256 JWT signing; length rules enforced by `validate()`.
    pub jwt_secret: String,
    /// Lifetime of issued JWTs, in hours.
    pub token_expiry_hours: i64,
    /// Enables API-key authentication alongside JWTs.
    pub api_key_enabled: bool,
    /// Enables the per-user and per-IP sliding-window rate limiters.
    pub rate_limiting_enabled: bool,
    /// Default per-identifier request budget per minute.
    pub max_requests_per_minute: u32,
    /// Default per-identifier request budget per hour.
    pub max_requests_per_hour: u32,
    /// When set, only IPs in `allowed_ips` may connect.
    pub ip_allowlist_enabled: bool,
    /// String-form IPs compared against `IpAddr::to_string()`.
    pub allowed_ips: Vec<String>,
    /// When set, IPs in `blocked_ips` are rejected.
    pub ip_blocklist_enabled: bool,
    pub blocked_ips: Vec<String>,
    /// Enables the in-memory audit trail (see `log_audit_event`).
    pub audit_logging_enabled: bool,
    /// Upper bound on model size, in gigabytes.
    pub max_model_size_gb: f64,
    /// Enables `validate_input()` checks on incoming text.
    pub input_validation_enabled: bool,
    /// Maximum accepted input length, in bytes (checked via `str::len`).
    pub max_input_length: usize,
    /// Enables `sanitize_output()` redaction of secrets-looking text.
    pub output_sanitization_enabled: bool,
    /// Whether TLS is mandatory for connections.
    pub tls_required: bool,
    /// Minimum accepted TLS version, e.g. "1.2".
    pub min_tls_version: String,
    /// Directory where security state (users.json) is persisted.
    pub data_dir: PathBuf,
}
impl SecurityConfig {
    /// Sanity-checks the configuration and returns a human-readable
    /// description of the first problem found.
    ///
    /// Rules enforced:
    /// * `jwt_secret` must be non-empty whenever `auth_enabled` is set;
    /// * a non-empty `jwt_secret` must be at least 32 characters long.
    pub fn validate(&self) -> Result<(), String> {
        let secret_empty = self.jwt_secret.is_empty();
        if self.auth_enabled && secret_empty {
            return Err("JWT secret is required when authentication is enabled. \
             Set INFERNO_JWT_SECRET environment variable or configure jwt_secret in config."
                .to_string());
        }
        if !secret_empty && self.jwt_secret.len() < 32 {
            return Err("JWT secret must be at least 32 characters for security. \
             Use a strong random secret."
                .to_string());
        }
        Ok(())
    }
}
impl Default for SecurityConfig {
    /// Secure-by-default configuration. The JWT secret is pulled from the
    /// `INFERNO_JWT_SECRET` environment variable (empty when unset, which
    /// `validate()` will then reject if auth is enabled).
    fn default() -> Self {
        Self {
            auth_enabled: true,
            jwt_secret: std::env::var("INFERNO_JWT_SECRET").unwrap_or_default(),
            token_expiry_hours: 24,
            api_key_enabled: true,
            rate_limiting_enabled: true,
            max_requests_per_minute: 60,
            max_requests_per_hour: 1000,
            ip_allowlist_enabled: false,
            allowed_ips: Vec::new(),
            ip_blocklist_enabled: false,
            blocked_ips: Vec::new(),
            audit_logging_enabled: true,
            max_model_size_gb: 5.0,
            input_validation_enabled: true,
            max_input_length: 10000,
            output_sanitization_enabled: true,
            tls_required: false,
            min_tls_version: "1.2".to_string(),
            // Persist under the platform data dir, falling back to CWD.
            data_dir: dirs::data_dir()
                .unwrap_or_else(|| PathBuf::from("."))
                .join("inferno")
                .join("security"),
        }
    }
}
/// Coarse account roles; each maps to a fixed permission set via
/// `has_permission`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserRole {
    /// Full access to every permission.
    Admin,
    /// Interactive user: read/infer/metrics/streaming.
    User,
    /// Minimal access: read models and run inference only.
    Guest,
    /// Machine account: like `User` plus cache management.
    Service,
}
impl UserRole {
pub fn has_permission(&self, permission: &Permission) -> bool {
match self {
UserRole::Admin => true, UserRole::User => matches!(
permission,
Permission::ReadModels
| Permission::RunInference
| Permission::ReadMetrics
| Permission::UseStreaming
),
UserRole::Guest => matches!(
permission,
Permission::ReadModels | Permission::RunInference
),
UserRole::Service => matches!(
permission,
Permission::ReadModels
| Permission::RunInference
| Permission::ReadMetrics
| Permission::UseStreaming
| Permission::ManageCache
),
}
}
}
/// Fine-grained capabilities, checked via `UserRole::has_permission` and
/// via the explicit permission sets on `User` and `ApiKey`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Permission {
    ReadModels,
    WriteModels,
    DeleteModels,
    RunInference,
    ManageCache,
    ReadMetrics,
    WriteConfig,
    ManageUsers,
    ViewAuditLogs,
    UseStreaming,
    UseDistributed,
    ManageQueue,
}
/// A persisted account record (serialized into users.json).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
    /// Stable unique id; also the key in the user map.
    pub id: String,
    /// Login name used by `authenticate_user`.
    pub username: String,
    pub email: Option<String>,
    /// Argon2 PHC-format hash; `None` for password-less machine accounts.
    pub password_hash: Option<String>,
    pub role: UserRole,
    /// API keys owned by this user (hashes only, never plaintext).
    pub api_keys: Vec<ApiKey>,
    pub created_at: DateTime<Utc>,
    pub last_login: Option<DateTime<Utc>>,
    /// Disabled accounts fail every authentication path.
    pub is_active: bool,
    /// Explicit grants in addition to role-based permissions.
    pub permissions: HashSet<Permission>,
    /// Optional per-user limits overriding the global defaults.
    pub rate_limit_override: Option<RateLimitConfig>,
}
/// Metadata for one API key; only the SHA-256 hash of the key is stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKey {
    pub id: String,
    /// Hex-encoded SHA-256 of the plaintext key.
    pub key_hash: String,
    /// Human-readable label chosen at creation time.
    pub name: String,
    pub created_at: DateTime<Utc>,
    /// `None` means the key never expires.
    pub expires_at: Option<DateTime<Utc>>,
    /// Updated in memory on each successful authentication.
    pub last_used: Option<DateTime<Utc>>,
    pub is_active: bool,
    /// Permissions granted to callers of this key.
    pub permissions: HashSet<Permission>,
}
/// Claims embedded in issued JWTs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClaims {
    /// Subject: the user id.
    pub sub: String,
    pub username: String,
    pub role: UserRole,
    /// Expiry, Unix seconds.
    pub exp: i64,
    /// Issued-at, Unix seconds.
    pub iat: i64,
    /// Unique token id; used for revocation lookups.
    pub jti: String,
}
/// Budgets for one `RateLimiter` instance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    pub requests_per_minute: u32,
    pub requests_per_hour: u32,
    /// `None` disables the daily window entirely.
    pub requests_per_day: Option<u32>,
    // NOTE(review): burst_size is stored but not consulted by
    // RateLimiter::check_rate_limit — confirm whether it is used elsewhere.
    pub burst_size: u32,
}
/// Sliding-window rate limiter with independent minute/hour/day windows,
/// each a timestamp queue protected by its own async mutex.
#[derive(Debug)]
pub struct RateLimiter {
    config: RateLimitConfig,
    minute_window: Arc<Mutex<VecDeque<DateTime<Utc>>>>,
    hour_window: Arc<Mutex<VecDeque<DateTime<Utc>>>>,
    // Allocated even when requests_per_day is None (then never used).
    day_window: Arc<Mutex<VecDeque<DateTime<Utc>>>>,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
config,
minute_window: Arc::new(Mutex::new(VecDeque::new())),
hour_window: Arc::new(Mutex::new(VecDeque::new())),
day_window: Arc::new(Mutex::new(VecDeque::new())),
}
}
pub async fn check_rate_limit(&self) -> Result<bool> {
let now = Utc::now();
{
let mut minute_window = self.minute_window.lock().await;
let minute_ago = now - Duration::minutes(1);
while let Some(front) = minute_window.front() {
if *front < minute_ago {
minute_window.pop_front();
} else {
break;
}
}
if minute_window.len() >= self.config.requests_per_minute as usize {
return Ok(false);
}
minute_window.push_back(now);
}
{
let mut hour_window = self.hour_window.lock().await;
let hour_ago = now - Duration::hours(1);
while let Some(front) = hour_window.front() {
if *front < hour_ago {
hour_window.pop_front();
} else {
break;
}
}
if hour_window.len() >= self.config.requests_per_hour as usize {
return Ok(false);
}
hour_window.push_back(now);
}
if let Some(daily_limit) = self.config.requests_per_day {
let mut day_window = self.day_window.lock().await;
let day_ago = now - Duration::days(1);
while let Some(front) = day_window.front() {
if *front < day_ago {
day_window.pop_front();
} else {
break;
}
}
if day_window.len() >= daily_limit as usize {
return Ok(false);
}
day_window.push_back(now);
}
Ok(true)
}
pub async fn get_remaining_quota(&self) -> (u32, u32, Option<u32>) {
let now = Utc::now();
let minute_remaining = {
let minute_window = self.minute_window.lock().await;
let minute_ago = now - Duration::minutes(1);
let recent_count = minute_window.iter().filter(|&&t| t >= minute_ago).count() as u32;
self.config.requests_per_minute.saturating_sub(recent_count)
};
let hour_remaining = {
let hour_window = self.hour_window.lock().await;
let hour_ago = now - Duration::hours(1);
let recent_count = hour_window.iter().filter(|&&t| t >= hour_ago).count() as u32;
self.config.requests_per_hour.saturating_sub(recent_count)
};
let day_remaining = if let Some(daily_limit) = self.config.requests_per_day {
let day_window = self.day_window.lock().await;
let day_ago = now - Duration::days(1);
let recent_count = day_window.iter().filter(|&&t| t >= day_ago).count() as u32;
Some(daily_limit.saturating_sub(recent_count))
} else {
None
};
(minute_remaining, hour_remaining, day_remaining)
}
}
/// Central security service: users, API keys, rate limits, JWT
/// revocation, and the bounded in-memory audit log.
#[derive(Debug)]
pub struct SecurityManager {
    config: SecurityConfig,
    /// User id -> account record; persisted to users.json.
    pub users: Arc<RwLock<HashMap<String, User>>>,
    /// SHA-256 API-key hash -> owning user id (lookup index).
    api_keys: Arc<RwLock<HashMap<String, String>>>,
    /// Lazily-created per-identifier limiters.
    rate_limiters: Arc<RwLock<HashMap<String, RateLimiter>>>,
    /// Lazily-created per-source-IP limiters.
    ip_rate_limiters: Arc<RwLock<HashMap<IpAddr, RateLimiter>>>,
    /// `jti` values of revoked JWTs (in-memory only).
    blocked_tokens: Arc<RwLock<HashSet<String>>>,
    /// Bounded audit trail; see `log_audit_event` for the eviction rule.
    audit_log: Arc<Mutex<Vec<AuditLogEntry>>>,
}
impl SecurityManager {
/// Builds a manager with empty in-memory stores; call `initialize` to
/// load persisted users and bootstrap defaults.
pub fn new(config: SecurityConfig) -> Self {
    Self {
        config,
        users: Default::default(),
        api_keys: Default::default(),
        rate_limiters: Default::default(),
        ip_rate_limiters: Default::default(),
        blocked_tokens: Default::default(),
        audit_log: Default::default(),
    }
}
/// Seeds the store with the built-in `admin` and `service` accounts.
///
/// The admin password comes from `INFERNO_ADMIN_PASSWORD` and must be at
/// least 12 characters. Fails if either account already exists, since
/// `create_user` rejects duplicate ids.
pub async fn initialize_default_users(&self) -> Result<()> {
    info!("Initializing security manager");
    let admin_password = std::env::var("INFERNO_ADMIN_PASSWORD").map_err(|_| {
        InfernoError::Security(
            "INFERNO_ADMIN_PASSWORD environment variable is required to create admin user. \
             Set a strong password (at least 12 characters) before starting the application."
                .to_string(),
        )
    })?;
    if admin_password.len() < 12 {
        return Err(InfernoError::Security(
            "Admin password must be at least 12 characters for security".to_string(),
        )
        .into());
    }
    let admin_password_hash = self.hash_password(&admin_password)?;
    // The admin relies on role-based checks (UserRole::Admin grants every
    // permission), so its explicit permission set is left empty.
    let admin_user = User {
        id: "admin".to_string(),
        username: "admin".to_string(),
        email: Some("admin@inferno.ai".to_string()),
        password_hash: Some(admin_password_hash),
        role: UserRole::Admin,
        api_keys: vec![],
        created_at: Utc::now(),
        last_login: None,
        is_active: true,
        permissions: HashSet::new(),
        rate_limit_override: None,
    };
    self.create_user(admin_user).await?;
    // Headless machine account: no password login, explicit grants, and
    // generous rate-limit overrides.
    let service_user = User {
        id: "service".to_string(),
        username: "service".to_string(),
        email: None,
        password_hash: None,
        role: UserRole::Service,
        api_keys: vec![],
        created_at: Utc::now(),
        last_login: None,
        is_active: true,
        permissions: HashSet::from([
            Permission::ReadModels,
            Permission::RunInference,
            Permission::ReadMetrics,
            Permission::UseStreaming,
            Permission::ManageCache,
        ]),
        rate_limit_override: Some(RateLimitConfig {
            requests_per_minute: 600,
            requests_per_hour: 10000,
            requests_per_day: Some(100000),
            burst_size: 100,
        }),
    };
    self.create_user(service_user).await?;
    Ok(())
}
/// Inserts a new user (rejecting duplicate ids), audits the event, and
/// persists the store.
///
/// The `users` write guard is explicitly dropped before `save_users`,
/// which re-acquires a read lock on the same `RwLock`.
pub async fn create_user(&self, user: User) -> Result<()> {
    let mut users = self.users.write().await;
    if users.contains_key(&user.id) {
        return Err(InfernoError::Security(format!("User {} already exists", user.id)).into());
    }
    info!("Creating user: {} with role {:?}", user.username, user.role);
    let user_id = user.id.clone();
    users.insert(user_id.clone(), user);
    self.log_audit_event(AuditLogEntry {
        timestamp: Utc::now(),
        user_id: Some("system".to_string()),
        action: AuditAction::UserCreated,
        resource: Some(format!("user:{}", user_id)),
        ip_address: None,
        success: true,
        details: None,
    })
    .await;
    // Release the write lock before save_users() takes its read lock.
    drop(users);
    self.save_users().await?;
    Ok(())
}
/// Removes a user, its API-key index entries, and its per-user rate
/// limiter, audits the deletion, and persists the updated store.
///
/// Fix: deletions were previously never written back via `save_users`,
/// so the deleted user reappeared after a restart (`load_users`).
///
/// # Errors
/// Fails when the user does not exist or persisting the store fails.
pub async fn delete_user(&self, user_id: &str) -> Result<()> {
    // Scope each guard tightly so no lock is held across save_users().
    let removed = {
        let mut users = self.users.write().await;
        users
            .remove(user_id)
            .ok_or_else(|| InfernoError::Security(format!("User {} not found", user_id)))?
    };
    info!("Deleting user: {} ({})", removed.username, user_id);
    {
        // Drop every index entry that pointed at this user.
        let mut api_keys = self.api_keys.write().await;
        api_keys.retain(|_, uid| uid.as_str() != user_id);
    }
    {
        let mut rate_limiters = self.rate_limiters.write().await;
        rate_limiters.remove(user_id);
    }
    self.log_audit_event(AuditLogEntry {
        timestamp: Utc::now(),
        user_id: Some("system".to_string()),
        action: AuditAction::UserDeleted,
        resource: Some(format!("user:{}", user_id)),
        ip_address: None,
        success: true,
        details: Some(format!("Deleted user: {}", removed.username)),
    })
    .await;
    self.save_users().await?;
    Ok(())
}
/// Creates a random API key for `user_id`, registers its hash in the
/// lookup index, persists the store, and returns the plaintext key —
/// which is shown only once (only the SHA-256 hash is retained).
///
/// Fix: the new key was previously kept only in memory and vanished on
/// restart; the store is now saved like `create_user` does.
///
/// # Errors
/// Fails when the user does not exist or the store cannot be saved.
pub async fn generate_api_key(
    &self,
    user_id: &str,
    key_name: &str,
    permissions: HashSet<Permission>,
    expires_in_days: Option<i64>,
) -> Result<String> {
    let api_key = Self::generate_random_key();
    let key_hash = Self::hash_api_key(&api_key);
    {
        let mut users = self.users.write().await;
        let user = users
            .get_mut(user_id)
            .ok_or_else(|| InfernoError::Security(format!("User {} not found", user_id)))?;
        user.api_keys.push(ApiKey {
            id: uuid::Uuid::new_v4().to_string(),
            key_hash: key_hash.clone(),
            name: key_name.to_string(),
            created_at: Utc::now(),
            expires_at: expires_in_days.map(|days| Utc::now() + Duration::days(days)),
            last_used: None,
            is_active: true,
            permissions,
        });
    }
    {
        let mut api_keys = self.api_keys.write().await;
        api_keys.insert(key_hash, user_id.to_string());
    }
    info!("Generated API key '{}' for user {}", key_name, user_id);
    self.log_audit_event(AuditLogEntry {
        timestamp: Utc::now(),
        user_id: Some(user_id.to_string()),
        action: AuditAction::ApiKeyCreated,
        resource: Some(format!("api_key:{}", key_name)),
        ip_address: None,
        success: true,
        details: None,
    })
    .await;
    // All guards are out of scope here, so save_users() cannot deadlock.
    self.save_users().await?;
    Ok(api_key)
}
/// Authenticates a plaintext API key: hashes it, resolves the owning
/// user via the index, and enforces account/key active flags and key
/// expiry. Returns a snapshot of the user on success.
///
/// Fix: previously, if the hash existed in the index but the matching
/// `ApiKey` record was missing from the user's key list (e.g. after a
/// partial revocation), the loop fell through and authentication
/// SUCCEEDED with no per-key checks. Such inconsistent state is now
/// rejected as an invalid key.
pub async fn authenticate_api_key(&self, api_key: &str) -> Result<User> {
    let key_hash = Self::hash_api_key(api_key);
    let api_keys = self.api_keys.read().await;
    let user_id = api_keys
        .get(&key_hash)
        .ok_or_else(|| InfernoError::Security("Invalid API key".to_string()))?;
    let mut users = self.users.write().await;
    let user = users
        .get_mut(user_id)
        .ok_or_else(|| InfernoError::Security("User not found".to_string()))?;
    if !user.is_active {
        return Err(InfernoError::Security("User account is disabled".to_string()).into());
    }
    let key_info = user
        .api_keys
        .iter_mut()
        .find(|k| k.key_hash == key_hash)
        .ok_or_else(|| InfernoError::Security("Invalid API key".to_string()))?;
    if !key_info.is_active {
        return Err(InfernoError::Security("API key is disabled".to_string()).into());
    }
    if let Some(expires_at) = key_info.expires_at {
        if expires_at < Utc::now() {
            return Err(InfernoError::Security("API key has expired".to_string()).into());
        }
    }
    // last_used is updated in memory only; save_users() is not called here
    // (it would need the write lock we currently hold).
    key_info.last_used = Some(Utc::now());
    Ok(user.clone())
}
/// Issues an HS256-signed JWT for `user`, valid for
/// `config.token_expiry_hours` and carrying a fresh `jti` so it can be
/// individually revoked later.
pub async fn generate_jwt_token(&self, user: &User) -> Result<String> {
    let expires_at = Utc::now() + Duration::hours(self.config.token_expiry_hours);
    let claims = TokenClaims {
        sub: user.id.clone(),
        username: user.username.clone(),
        role: user.role.clone(),
        exp: expires_at.timestamp(),
        iat: Utc::now().timestamp(),
        jti: uuid::Uuid::new_v4().to_string(),
    };
    encode(
        &Header::new(Algorithm::HS256),
        &claims,
        &EncodingKey::from_secret(self.config.jwt_secret.as_ref()),
    )
    .map_err(|e| InfernoError::Security(format!("JWT encoding failed: {}", e)).into())
}
/// Decodes and validates an HS256 JWT, then rejects tokens whose `jti`
/// was revoked via `revoke_token`.
///
/// NOTE(review): `Validation::new` in jsonwebtoken enables expiry
/// checking by default, so the explicit `exp` re-check below is
/// redundant defense-in-depth — confirm against the crate version used.
pub async fn verify_jwt_token(&self, token: &str) -> Result<TokenClaims> {
    let decoding_key = DecodingKey::from_secret(self.config.jwt_secret.as_ref());
    let validation = Validation::new(Algorithm::HS256);
    let token_data = decode::<TokenClaims>(token, &decoding_key, &validation)
        .map_err(|e| InfernoError::Security(format!("JWT verification failed: {}", e)))?;
    let claims = token_data.claims;
    // Revocation check happens after signature verification so forged
    // tokens never reach the blocklist lookup.
    let blocked_tokens = self.blocked_tokens.read().await;
    if blocked_tokens.contains(&claims.jti) {
        return Err(InfernoError::Security("Token has been revoked".to_string()).into());
    }
    if claims.exp < Utc::now().timestamp() {
        return Err(InfernoError::Security("Token has expired".to_string()).into());
    }
    Ok(claims)
}
/// Hashes `password` with Argon2 (default parameters) and a fresh
/// OS-random salt, returning the PHC-format string.
pub fn hash_password(&self, password: &str) -> Result<String> {
    let salt = SaltString::generate(&mut OsRng);
    let hashed = Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .map_err(|e| InfernoError::Security(format!("Password hashing failed: {}", e)))?;
    Ok(hashed.to_string())
}
/// Checks `password` against a stored PHC-format Argon2 hash.
///
/// Returns `Ok(false)` on a mismatch; only a malformed stored hash is an
/// `Err`.
pub fn verify_password(&self, password: &str, hash: &str) -> Result<bool> {
    let parsed = PasswordHash::new(hash)
        .map_err(|e| InfernoError::Security(format!("Invalid password hash: {}", e)))?;
    let matched = Argon2::default()
        .verify_password(password.as_bytes(), &parsed)
        .is_ok();
    Ok(matched)
}
/// Verifies `username`/`password`; on success stamps `last_login`,
/// persists the store best-effort, and returns the updated record.
///
/// Returns `Ok(None)` for unknown, inactive, or password-less accounts
/// as well as for a wrong password, so callers cannot distinguish the
/// cases (avoids user enumeration).
pub async fn authenticate_user(&self, username: &str, password: &str) -> Result<Option<User>> {
    let users = self.users.read().await;
    let user = users
        .values()
        .find(|u| u.username == username && u.is_active);
    if let Some(user) = user {
        if let Some(ref stored_hash) = user.password_hash {
            if self.verify_password(password, stored_hash)? {
                let mut user_copy = user.clone();
                user_copy.last_login = Some(Utc::now());
                let user_id = user_copy.id.clone();
                // Drop the read guard before taking the write lock below.
                drop(users);
                {
                    // NOTE(review): if the user is deleted between the two
                    // locks, this insert re-creates it — confirm the race
                    // is acceptable.
                    let mut users_write = self.users.write().await;
                    users_write.insert(user_id, user_copy.clone());
                }
                // Never fail a successful login just because the
                // last_login persistence failed.
                if let Err(e) = self.save_users().await {
                    warn!("Failed to save user update after login: {}", e);
                }
                return Ok(Some(user_copy));
            }
        }
    }
    Ok(None)
}
/// Enforces the per-identifier and (optional) per-IP sliding-window rate
/// limits; returns `Ok(true)` when the request is admitted.
///
/// Limiters are created lazily from the global config; IP budgets are
/// twice the per-user budgets. Per-user `rate_limit_override` values on
/// `User` are NOT consulted here.
///
/// NOTE(review): the map-wide write locks are held across the awaited
/// per-limiter checks, serializing all callers — consider cloning the
/// limiter handle out of the map or using finer-grained locking.
pub async fn check_rate_limit(&self, identifier: &str, ip: Option<IpAddr>) -> Result<bool> {
    if !self.config.rate_limiting_enabled {
        return Ok(true);
    }
    let mut rate_limiters = self.rate_limiters.write().await;
    let user_limiter = rate_limiters
        .entry(identifier.to_string())
        .or_insert_with(|| {
            RateLimiter::new(RateLimitConfig {
                requests_per_minute: self.config.max_requests_per_minute,
                requests_per_hour: self.config.max_requests_per_hour,
                requests_per_day: None,
                burst_size: 10,
            })
        });
    if !user_limiter.check_rate_limit().await? {
        warn!("Rate limit exceeded for user: {}", identifier);
        return Ok(false);
    }
    if let Some(ip_addr) = ip {
        let mut ip_limiters = self.ip_rate_limiters.write().await;
        let ip_limiter = ip_limiters.entry(ip_addr).or_insert_with(|| {
            RateLimiter::new(RateLimitConfig {
                requests_per_minute: self.config.max_requests_per_minute * 2,
                requests_per_hour: self.config.max_requests_per_hour * 2,
                requests_per_day: None,
                burst_size: 20,
            })
        });
        if !ip_limiter.check_rate_limit().await? {
            warn!("Rate limit exceeded for IP: {}", ip_addr);
            return Ok(false);
        }
    }
    Ok(true)
}
/// Applies the IP blocklist then the allowlist; returns `true` when the
/// address may connect. Addresses are compared by their string form
/// against the configured lists.
pub fn check_ip_access(&self, ip: &IpAddr) -> bool {
    let ip_str = ip.to_string();
    if self.config.ip_blocklist_enabled && self.config.blocked_ips.contains(&ip_str) {
        warn!("Blocked IP attempted access: {}", ip);
        return false;
    }
    if self.config.ip_allowlist_enabled && !self.config.allowed_ips.contains(&ip_str) {
        warn!("Non-allowlisted IP attempted access: {}", ip);
        return false;
    }
    true
}
/// Rejects over-long input and input containing known-dangerous
/// substrings (XSS, path traversal, SQL keywords, shell references).
///
/// Fix: the SQL patterns were previously uppercase ("DROP TABLE", …)
/// while being matched against the LOWERCASED input, so they could
/// never fire. All patterns are now lowercase.
///
/// # Errors
/// `InfernoError::Security` when the input is too long or matches a
/// dangerous pattern.
pub fn validate_input(&self, input: &str) -> Result<()> {
    if !self.config.input_validation_enabled {
        return Ok(());
    }
    if input.len() > self.config.max_input_length {
        return Err(InfernoError::Security(format!(
            "Input exceeds maximum length of {} characters",
            self.config.max_input_length
        ))
        .into());
    }
    // Every pattern must be lowercase: it is matched against input_lower.
    let dangerous_patterns = [
        "<script",
        "javascript:",
        "onerror=",
        "onclick=",
        "../",
        "..\\",
        "%2e%2e",
        // NOTE(review): "0x" and "\\x" are extremely broad (any hex
        // literal triggers them) — confirm this strictness is intended.
        "0x",
        "\\x",
        "drop table",
        "delete from",
        "insert into",
        "cmd.exe",
        "/bin/sh",
        "powershell",
    ];
    let input_lower = input.to_lowercase();
    for pattern in &dangerous_patterns {
        if input_lower.contains(pattern) {
            warn!("Potentially dangerous input pattern detected: {}", pattern);
            return Err(InfernoError::Security(
                "Input contains potentially dangerous content".to_string(),
            )
            .into());
        }
    }
    Ok(())
}
/// Best-effort redaction of secret-looking text in client-bound output:
/// long alphanumeric runs (API-key-like), e-mail addresses, and IPv4
/// addresses.
///
/// Regexes are compiled once per process via `OnceLock`; a pattern that
/// fails to compile (`.ok()`) silently disables that redaction.
///
/// NOTE(review): the 32+-alphanumeric rule also redacts benign long
/// tokens (hashes, identifiers) — confirm this is acceptable.
pub fn sanitize_output(&self, output: &str) -> String {
    use std::sync::OnceLock;
    static API_KEY_PATTERN: OnceLock<Option<regex::Regex>> = OnceLock::new();
    static EMAIL_PATTERN: OnceLock<Option<regex::Regex>> = OnceLock::new();
    static IP_PATTERN: OnceLock<Option<regex::Regex>> = OnceLock::new();
    if !self.config.output_sanitization_enabled {
        return output.to_string();
    }
    let mut sanitized = output.to_string();
    let api_key_regex =
        API_KEY_PATTERN.get_or_init(|| regex::Regex::new(r"[A-Za-z0-9]{32,}").ok());
    if let Some(pattern) = api_key_regex {
        sanitized = pattern.replace_all(&sanitized, "[REDACTED]").to_string();
    }
    let email_regex = EMAIL_PATTERN.get_or_init(|| {
        regex::Regex::new(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}").ok()
    });
    if let Some(pattern) = email_regex {
        sanitized = pattern.replace_all(&sanitized, "[EMAIL]").to_string();
    }
    let ip_regex =
        IP_PATTERN.get_or_init(|| regex::Regex::new(r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b").ok());
    if let Some(pattern) = ip_regex {
        sanitized = pattern.replace_all(&sanitized, "[IP]").to_string();
    }
    sanitized
}
/// Appends `entry` to the in-memory audit trail (no-op when audit
/// logging is disabled). Memory is bounded: past 10 000 entries the
/// oldest 1 000 are discarded.
pub async fn log_audit_event(&self, entry: AuditLogEntry) {
    if !self.config.audit_logging_enabled {
        return;
    }
    debug!("Audit log: {:?}", entry);
    let mut log = self.audit_log.lock().await;
    log.push(entry);
    if log.len() > 10000 {
        log.drain(0..1000);
    }
}
/// Returns up to `limit` (default 100) audit entries, newest first.
pub async fn get_audit_log(&self, limit: Option<usize>) -> Vec<AuditLogEntry> {
    let log = self.audit_log.lock().await;
    log.iter()
        .rev()
        .take(limit.unwrap_or(100))
        .cloned()
        .collect()
}
/// Produces a 32-character alphanumeric key using the thread-local RNG.
fn generate_random_key() -> String {
    use rand::Rng;
    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    let mut rng = rand::thread_rng();
    std::iter::repeat_with(|| CHARSET[rng.gen_range(0..CHARSET.len())] as char)
        .take(32)
        .collect()
}
/// Hex-encoded SHA-256 of the plaintext key; this is what gets stored
/// and indexed instead of the key itself.
fn hash_api_key(key: &str) -> String {
    format!("{:x}", Sha256::digest(key.as_bytes()))
}
/// Adds a token's `jti` to the in-memory blocklist consulted by
/// `verify_jwt_token`, and audits the revocation. The blocklist is not
/// persisted across restarts.
pub async fn revoke_token(&self, jti: String) -> Result<()> {
    {
        let mut blocked = self.blocked_tokens.write().await;
        blocked.insert(jti.clone());
    }
    self.log_audit_event(AuditLogEntry {
        timestamp: Utc::now(),
        user_id: None,
        action: AuditAction::TokenRevoked,
        resource: Some(format!("token:{}", jti)),
        ip_address: None,
        success: true,
        details: None,
    })
    .await;
    Ok(())
}
/// Serializes the full user map to `<data_dir>/users.json`, creating
/// parent directories as needed. Caller must NOT hold the `users` write
/// lock (a read lock is taken here).
///
/// NOTE(review): the file contains password and API-key hashes; consider
/// restrictive permissions (0600) and an atomic temp-file + rename
/// instead of an in-place write.
async fn save_users(&self) -> Result<()> {
    let users_file = self.config.data_dir.join("users.json");
    if let Some(parent) = users_file.parent() {
        fs::create_dir_all(parent).await?;
    }
    let users = self.users.read().await;
    let serialized = serde_json::to_string_pretty(&*users)?;
    fs::write(&users_file, serialized).await?;
    debug!("Saved {} users to {}", users.len(), users_file.display());
    Ok(())
}
/// Replaces the in-memory user map with the contents of
/// `<data_dir>/users.json`; a missing file leaves the store empty and is
/// not an error.
///
/// Fix: a missing file is now detected from the read error itself
/// instead of the previous blocking `Path::exists()` call (a synchronous
/// metadata syscall inside async code, and racy between check and read).
async fn load_users(&self) -> Result<()> {
    let users_file = self.config.data_dir.join("users.json");
    let content = match fs::read_to_string(&users_file).await {
        Ok(content) => content,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            debug!("Users file does not exist, starting with empty user store");
            return Ok(());
        }
        Err(e) => return Err(e.into()),
    };
    let loaded_users: HashMap<String, User> = serde_json::from_str(&content)?;
    let mut users = self.users.write().await;
    *users = loaded_users;
    info!("Loaded {} users from {}", users.len(), users_file.display());
    Ok(())
}
/// One-shot startup routine: validates the config (warn-only), loads
/// persisted users, and — only when the store is empty — bootstraps a
/// full-permission admin account from `INFERNO_ADMIN_PASSWORD`.
///
/// NOTE(review): overlaps with `initialize_default_users`, which also
/// creates an admin (plus a service) account — consider unifying.
pub async fn initialize(&self) -> Result<()> {
    // Config problems are deliberately non-fatal here.
    if let Err(e) = self.config.validate() {
        warn!("Security configuration validation warning: {}", e);
    }
    if let Err(e) = self.load_users().await {
        warn!(
            "Failed to load users from storage: {}. Will create default admin user if needed.",
            e
        );
    }
    // Read the count in its own scope so the read lock is released before
    // create_user() takes the write lock below.
    let users_count = {
        let users = self.users.read().await;
        users.len()
    };
    if users_count == 0 {
        info!("No users found, creating default admin user");
        let admin_password = std::env::var("INFERNO_ADMIN_PASSWORD").map_err(|_| {
            InfernoError::Security(
                "INFERNO_ADMIN_PASSWORD environment variable is required to create admin user. \
                 Set a strong password (at least 12 characters) before starting the application."
                    .to_string(),
            )
        })?;
        if admin_password.len() < 12 {
            return Err(InfernoError::Security(
                "Admin password must be at least 12 characters for security".to_string(),
            )
            .into());
        }
        let password_hash = self.hash_password(&admin_password)?;
        // Unlike the role-based admin in initialize_default_users, this
        // account is granted every permission explicitly.
        let default_user = User {
            id: "admin".to_string(),
            username: "admin".to_string(),
            email: Some("admin@localhost".to_string()),
            password_hash: Some(password_hash),
            role: UserRole::Admin,
            api_keys: vec![],
            created_at: chrono::Utc::now(),
            last_login: None,
            is_active: true,
            permissions: [
                Permission::ReadModels,
                Permission::WriteModels,
                Permission::DeleteModels,
                Permission::RunInference,
                Permission::ManageCache,
                Permission::ReadMetrics,
                Permission::WriteConfig,
                Permission::ManageUsers,
                Permission::ViewAuditLogs,
                Permission::UseStreaming,
                Permission::UseDistributed,
                Permission::ManageQueue,
            ]
            .into_iter()
            .collect(),
            rate_limit_override: None,
        };
        self.create_user(default_user).await?;
        self.save_users().await?;
    }
    Ok(())
}
/// Returns an owned snapshot of every user record.
pub async fn get_all_users(&self) -> Vec<User> {
    self.users.read().await.values().cloned().collect()
}
/// Returns an owned copy of the user with `user_id`, if any.
pub async fn get_user_by_id(&self, user_id: &str) -> Option<User> {
    self.users.read().await.get(user_id).cloned()
}
pub async fn update_user(&self, user_id: &str, updated_user: User) -> Result<()> {
let mut users = self.users.write().await;
if users.contains_key(user_id) {
users.insert(user_id.to_string(), updated_user);
Ok(())
} else {
Err(anyhow::anyhow!("User not found"))
}
}
}
/// One record in the security audit trail.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLogEntry {
    pub timestamp: DateTime<Utc>,
    /// Acting user id; `None` for anonymous/system-wide events.
    pub user_id: Option<String>,
    pub action: AuditAction,
    /// Affected resource, e.g. "user:<id>" or a file path.
    pub resource: Option<String>,
    pub ip_address: Option<IpAddr>,
    /// Whether the audited operation succeeded.
    pub success: bool,
    pub details: Option<String>,
}
/// Categories of auditable security events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuditAction {
    UserCreated,
    UserDeleted,
    UserModified,
    Login,
    Logout,
    ApiKeyCreated,
    ApiKeyRevoked,
    TokenRevoked,
    InferenceRequested,
    ModelLoaded,
    ModelDeleted,
    ConfigChanged,
    RateLimitExceeded,
    UnauthorizedAccess,
    SecurityViolation,
    ModelVerificationStarted,
    ModelVerificationCompleted,
    ModelVerificationFailed,
    SecurityScanStarted,
    SecurityScanCompleted,
    SecurityScanFailed,
}
/// Scans model files for embedded threats and reports findings to the
/// shared `SecurityManager`'s audit log.
#[derive(Debug)]
pub struct SecurityScanner {
    config: SecurityScanConfig,
    /// Byte/string signatures the content scans match against.
    threat_signatures: ThreatSignatureDatabase,
    /// Used only for its audit-log facility here.
    audit_logger: Arc<SecurityManager>,
}
/// Feature switches and limits for `SecurityScanner`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityScanConfig {
    /// Run the per-format structure validators (gguf/onnx/...).
    pub validate_file_structure: bool,
    /// Scan raw bytes for executables/scripts/credentials.
    pub scan_embedded_content: bool,
    /// Scan model metadata for execution-related keywords.
    pub scan_metadata_threats: bool,
    pub verify_signatures: bool,
    pub verify_checksums: bool,
    /// Files larger than this (bytes) are flagged, not content-scanned.
    pub max_scan_size: u64,
    /// Move Critical/High-risk files into `quarantine_dir`.
    pub quarantine_enabled: bool,
    pub quarantine_dir: PathBuf,
}
impl Default for SecurityScanConfig {
    /// All content scans on, signature verification off, 50 GB scan cap,
    /// quarantine into `./quarantine`.
    fn default() -> Self {
        Self {
            validate_file_structure: true,
            scan_embedded_content: true,
            scan_metadata_threats: true,
            // Signature verification is opt-in; checksums stay on.
            verify_signatures: false,
            verify_checksums: true,
            // 50 GB — model files are large.
            max_scan_size: 50_000_000_000,
            quarantine_enabled: true,
            quarantine_dir: PathBuf::from("./quarantine"),
        }
    }
}
/// Static signature sets consulted by the content/metadata scans.
#[derive(Debug)]
struct ThreatSignatureDatabase {
    /// Binary magic numbers of executable containers.
    executable_patterns: Vec<Vec<u8>>,
    /// Byte sequences typical of embedded scripts.
    script_patterns: Vec<Vec<u8>>,
    /// Credential-ish words searched case-insensitively in file bytes.
    suspicious_strings: Vec<String>,
    /// Execution-related keywords searched in model metadata.
    metadata_threats: Vec<String>,
}
impl Default for ThreatSignatureDatabase {
    /// Built-in signature sets; all values are fixed at compile time.
    fn default() -> Self {
        Self {
            // Well-known executable-container magic numbers.
            executable_patterns: vec![
                b"\x4d\x5a".to_vec(),         // "MZ": DOS/Windows PE
                b"\x7f\x45\x4c\x46".to_vec(), // "\x7fELF": ELF
                b"\xfe\xed\xfa\xce".to_vec(), // Mach-O (32-bit)
                b"\xfe\xed\xfa\xcf".to_vec(), // Mach-O (64-bit)
                b"\xca\xfe\xba\xbe".to_vec(), // Mach-O fat binary / Java class
                b"\x50\x4b\x03\x04".to_vec(), // "PK\x03\x04": ZIP container
            ],
            script_patterns: vec![
                b"#!/bin/sh".to_vec(),
                b"#!/bin/bash".to_vec(),
                b"#!/usr/bin/env".to_vec(),
                b"<script".to_vec(),
                b"javascript:".to_vec(),
                b"data:text/html".to_vec(),
                b"eval(".to_vec(),
                b"exec(".to_vec(),
            ],
            suspicious_strings: vec![
                "password".to_string(),
                "api_key".to_string(),
                "secret".to_string(),
                "token".to_string(),
                "private_key".to_string(),
                "ssh_key".to_string(),
                "credential".to_string(),
                "backdoor".to_string(),
                "exploit".to_string(),
                "payload".to_string(),
            ],
            metadata_threats: vec![
                "exec".to_string(),
                "execute".to_string(),
                "script".to_string(),
                "command".to_string(),
                "shell".to_string(),
                "eval".to_string(),
                "import".to_string(),
                "require".to_string(),
                "load".to_string(),
                "include".to_string(),
                "__import__".to_string(),
                "subprocess".to_string(),
                "os.system".to_string(),
            ],
        }
    }
}
/// Outcome of one `SecurityScanner::scan_file` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityScanResult {
    pub file_path: PathBuf,
    pub scan_timestamp: DateTime<Utc>,
    pub scan_duration_ms: u64,
    pub threats_detected: Vec<ThreatDetection>,
    /// Aggregated risk computed from the detected threats.
    pub overall_risk_level: RiskLevel,
    /// True when the file was moved to the quarantine directory.
    pub file_quarantined: bool,
    /// False when the scan itself could not complete (see error_message).
    pub scan_success: bool,
    pub error_message: Option<String>,
}
/// A single finding produced by one of the scan passes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreatDetection {
    pub threat_type: ThreatType,
    pub severity: ThreatSeverity,
    pub description: String,
    /// Where in the file the threat was found, when known
    /// (e.g. "Byte offset: N").
    pub location: Option<String>,
    pub mitigation_advice: String,
}
/// Classification of scanner findings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ThreatType {
    EmbeddedExecutable,
    SuspiciousScript,
    DataExfiltration,
    MetadataThreats,
    InvalidFileStructure,
    SuspiciousSize,
    UnknownFormat,
    ChecksumMismatch,
    SignatureInvalid,
    PolicyViolation,
}
/// Severity of an individual finding, most severe first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ThreatSeverity {
    Critical,
    High,
    Medium,
    Low,
    Info,
}
/// Aggregated file risk; Critical/High trigger quarantine when enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RiskLevel {
    Critical,
    High,
    Medium,
    Low,
    Safe,
}
impl SecurityScanner {
/// Builds a scanner over the built-in signature database, reporting
/// audit events through the given manager.
pub fn new(config: SecurityScanConfig, audit_logger: Arc<SecurityManager>) -> Self {
    let threat_signatures = ThreatSignatureDatabase::default();
    Self {
        config,
        threat_signatures,
        audit_logger,
    }
}
/// Runs the configured scans over a model file and returns a structured
/// result. Scanner problems are reported inside the result
/// (`scan_success`/`error_message`) rather than as `Err`.
///
/// Pipeline: audit "scan started" -> existence & metadata checks ->
/// (size gate) structure / embedded-content / metadata scans -> risk
/// assessment -> optional quarantine -> audit "completed/failed".
pub async fn scan_file(&self, file_path: &Path) -> Result<SecurityScanResult> {
    let start_time = std::time::Instant::now();
    let scan_timestamp = Utc::now();
    info!("Starting security scan for: {}", file_path.display());
    self.audit_logger
        .log_audit_event(AuditLogEntry {
            timestamp: scan_timestamp,
            user_id: None,
            action: AuditAction::SecurityScanStarted,
            resource: Some(file_path.to_string_lossy().to_string()),
            ip_address: None,
            success: true,
            details: None,
        })
        .await;
    let mut threats = Vec::new();
    let mut scan_success = true;
    let mut error_message = None;
    // Missing file: reported as a failed, High-risk scan.
    // NOTE(review): this early return (and the metadata one below) skips
    // the SecurityScanCompleted/Failed audit event emitted at the end.
    if !file_path.exists() {
        let error = "File does not exist".to_string();
        error_message = Some(error.clone());
        scan_success = false;
        return Ok(SecurityScanResult {
            file_path: file_path.to_path_buf(),
            scan_timestamp,
            scan_duration_ms: start_time.elapsed().as_millis() as u64,
            threats_detected: threats,
            overall_risk_level: RiskLevel::High,
            file_quarantined: false,
            scan_success,
            error_message,
        });
    }
    let metadata = match tokio::fs::metadata(file_path).await {
        Ok(meta) => meta,
        Err(e) => {
            error_message = Some(format!("Failed to read file metadata: {}", e));
            scan_success = false;
            return Ok(SecurityScanResult {
                file_path: file_path.to_path_buf(),
                scan_timestamp,
                scan_duration_ms: start_time.elapsed().as_millis() as u64,
                threats_detected: threats,
                overall_risk_level: RiskLevel::High,
                file_quarantined: false,
                scan_success,
                error_message,
            });
        }
    };
    let file_size = metadata.len();
    if file_size > self.config.max_scan_size {
        // Too large to content-scan: record a Medium finding instead.
        threats.push(ThreatDetection {
            threat_type: ThreatType::SuspiciousSize,
            severity: ThreatSeverity::Medium,
            description: format!("File size ({} bytes) exceeds maximum scan size", file_size),
            location: None,
            mitigation_advice: "Consider scanning with specialized tools for large files"
                .to_string(),
        });
    } else {
        // Individual scan failures are logged but do not abort the scan.
        if self.config.validate_file_structure {
            if let Err(e) = self.scan_file_structure(file_path, &mut threats).await {
                warn!("File structure scan failed: {}", e);
            }
        }
        if self.config.scan_embedded_content {
            if let Err(e) = self.scan_embedded_content(file_path, &mut threats).await {
                warn!("Embedded content scan failed: {}", e);
            }
        }
        if self.config.scan_metadata_threats {
            if let Err(e) = self.scan_metadata_threats(file_path, &mut threats).await {
                warn!("Metadata threat scan failed: {}", e);
            }
        }
    }
    let overall_risk_level = self.assess_risk_level(&threats);
    let mut file_quarantined = false;
    // Quarantine failures are logged and reflected in file_quarantined.
    if self.config.quarantine_enabled
        && matches!(overall_risk_level, RiskLevel::Critical | RiskLevel::High)
    {
        if let Err(e) = self.quarantine_file(file_path).await {
            warn!("Failed to quarantine file: {}", e);
        } else {
            file_quarantined = true;
            info!(
                "File quarantined due to security threats: {}",
                file_path.display()
            );
        }
    }
    let scan_duration_ms = start_time.elapsed().as_millis() as u64;
    self.audit_logger
        .log_audit_event(AuditLogEntry {
            timestamp: Utc::now(),
            user_id: None,
            action: if scan_success {
                AuditAction::SecurityScanCompleted
            } else {
                AuditAction::SecurityScanFailed
            },
            resource: Some(file_path.to_string_lossy().to_string()),
            ip_address: None,
            success: scan_success,
            details: Some(format!(
                "Threats: {}, Risk: {:?}, Duration: {}ms",
                threats.len(),
                overall_risk_level,
                scan_duration_ms
            )),
        })
        .await;
    Ok(SecurityScanResult {
        file_path: file_path.to_path_buf(),
        scan_timestamp,
        scan_duration_ms,
        threats_detected: threats,
        overall_risk_level,
        file_quarantined,
        scan_success,
        error_message,
    })
}
/// Dispatches to the per-format structure validator based on the file
/// extension (compared case-insensitively); unrecognized extensions
/// yield a Low-severity `UnknownFormat` finding instead of an error.
async fn scan_file_structure(
    &self,
    file_path: &Path,
    threats: &mut Vec<ThreatDetection>,
) -> Result<()> {
    let raw_ext = file_path
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or("");
    match raw_ext.to_lowercase().as_str() {
        "gguf" => self.validate_gguf_file(file_path, threats).await?,
        "onnx" => self.validate_onnx_file(file_path, threats).await?,
        "safetensors" => self.validate_safetensors_file(file_path, threats).await?,
        "bin" | "pt" | "pth" => self.validate_pytorch_file(file_path, threats).await?,
        _ => threats.push(ThreatDetection {
            threat_type: ThreatType::UnknownFormat,
            severity: ThreatSeverity::Low,
            description: format!("Unknown file format: {}", raw_ext),
            location: None,
            mitigation_advice: "Verify file type and ensure it's a valid model format"
                .to_string(),
        }),
    }
    Ok(())
}
/// Scans the raw bytes of `file_path` for embedded executables, script
/// fragments, suspiciously text-like content, and credential-like
/// strings, appending findings to `threats`.
///
/// NOTE(review): the entire file is read into memory; with the default
/// `max_scan_size` of 50 GB this can exhaust RAM — consider a streamed
/// or chunked scan.
async fn scan_embedded_content(
    &self,
    file_path: &Path,
    threats: &mut Vec<ThreatDetection>,
) -> Result<()> {
    let file_content = tokio::fs::read(file_path).await?;
    // Executable-container magic numbers anywhere in the file.
    for pattern in &self.threat_signatures.executable_patterns {
        if let Some(position) = self.find_pattern(&file_content, pattern) {
            threats.push(ThreatDetection {
                threat_type: ThreatType::EmbeddedExecutable,
                severity: ThreatSeverity::Critical,
                description: "Embedded executable code detected".to_string(),
                location: Some(format!("Byte offset: {}", position)),
                mitigation_advice: "Remove or verify embedded executable content".to_string(),
            });
        }
    }
    for pattern in &self.threat_signatures.script_patterns {
        if let Some(position) = self.find_pattern(&file_content, pattern) {
            threats.push(ThreatDetection {
                threat_type: ThreatType::SuspiciousScript,
                severity: ThreatSeverity::High,
                description: "Embedded script code detected".to_string(),
                location: Some(format!("Byte offset: {}", position)),
                mitigation_advice: "Review and validate embedded script content".to_string(),
            });
        }
    }
    // A model file should be mostly binary; a high share of printable
    // ASCII hints at hidden text. (For an empty file the ratio is NaN,
    // the comparison is false, and no finding is recorded.)
    let printable_count = file_content
        .iter()
        .filter(|&b| *b >= 32 && *b <= 126)
        .count();
    let printable_ratio = printable_count as f64 / file_content.len() as f64;
    if printable_ratio > 0.7 {
        threats.push(ThreatDetection {
            threat_type: ThreatType::DataExfiltration,
            severity: ThreatSeverity::Medium,
            description: format!(
                "High ratio of printable characters: {:.1}%",
                printable_ratio * 100.0
            ),
            location: None,
            mitigation_advice: "Review file content for hidden data or text".to_string(),
        });
    }
    // Lossy lowercase UTF-8 view for case-insensitive substring search.
    let content_str = String::from_utf8_lossy(&file_content).to_lowercase();
    for suspicious_string in &self.threat_signatures.suspicious_strings {
        if content_str.contains(suspicious_string) {
            threats.push(ThreatDetection {
                threat_type: ThreatType::DataExfiltration,
                severity: ThreatSeverity::Medium,
                description: format!("Suspicious string found: {}", suspicious_string),
                location: None,
                mitigation_advice: "Review file for embedded credentials or sensitive data"
                    .to_string(),
            });
        }
    }
    Ok(())
}
/// Flag known-bad metadata patterns via a case-insensitive substring search
/// over the whole (lossy-decoded) file contents.
async fn scan_metadata_threats(
    &self,
    file_path: &Path,
    threats: &mut Vec<ThreatDetection>,
) -> Result<()> {
    let bytes = tokio::fs::read(file_path).await?;
    let haystack = String::from_utf8_lossy(&bytes).to_lowercase();
    for threat_pattern in &self.threat_signatures.metadata_threats {
        // Skip patterns that don't appear anywhere in the file.
        if !haystack.contains(threat_pattern) {
            continue;
        }
        threats.push(ThreatDetection {
            threat_type: ThreatType::MetadataThreats,
            severity: ThreatSeverity::High,
            description: format!("Suspicious metadata pattern: {}", threat_pattern),
            location: None,
            mitigation_advice: "Review model metadata for malicious content".to_string(),
        });
    }
    Ok(())
}
/// Structural sanity checks for a GGUF model file: minimum size, the
/// `GGUF` magic bytes, and a supported version number.
///
/// Findings are reported through `threats`; `Err` is reserved for I/O
/// failures reading the file.
async fn validate_gguf_file(
    &self,
    file_path: &Path,
    threats: &mut Vec<ThreatDetection>,
) -> Result<()> {
    let file_content = tokio::fs::read(file_path).await?;
    // A minimal GGUF header is 4 magic bytes + 4 version bytes.
    if file_content.len() < 8 {
        threats.push(ThreatDetection {
            threat_type: ThreatType::InvalidFileStructure,
            severity: ThreatSeverity::High,
            description: "GGUF file too small".to_string(),
            location: None,
            mitigation_advice: "Verify file integrity and re-download if necessary".to_string(),
        });
        return Ok(());
    }
    if &file_content[0..4] != b"GGUF" {
        threats.push(ThreatDetection {
            threat_type: ThreatType::InvalidFileStructure,
            severity: ThreatSeverity::High,
            description: "Invalid GGUF magic bytes".to_string(),
            location: Some("File header".to_string()),
            mitigation_advice: "File may be corrupted or not a valid GGUF file".to_string(),
        });
        // Without the magic bytes this is not a GGUF file, so bytes 4..8
        // are not a version field; returning here avoids stacking a
        // misleading "unsupported version" finding parsed from garbage.
        return Ok(());
    }
    let version = u32::from_le_bytes(
        file_content[4..8]
            .try_into()
            .expect("4-byte slice converts to [u8; 4]"),
    );
    // Versions 1 through 3 are the GGUF revisions this scanner accepts.
    if !(1..=3).contains(&version) {
        threats.push(ThreatDetection {
            threat_type: ThreatType::InvalidFileStructure,
            severity: ThreatSeverity::Medium,
            description: format!("Unsupported GGUF version: {}", version),
            location: Some("Version header".to_string()),
            mitigation_advice: "Update to a supported GGUF version".to_string(),
        });
    }
    Ok(())
}
async fn validate_onnx_file(
&self,
file_path: &Path,
_threats: &mut Vec<ThreatDetection>,
) -> Result<()> {
let file_content = tokio::fs::read(file_path).await?;
if file_content.len() < 16 {
return Err(anyhow::anyhow!("ONNX file too small"));
}
Ok(())
}
/// Structural sanity checks for a SafeTensors file: minimum size and a
/// declared JSON-header length that actually fits inside the file.
async fn validate_safetensors_file(
    &self,
    file_path: &Path,
    threats: &mut Vec<ThreatDetection>,
) -> Result<()> {
    let file_content = tokio::fs::read(file_path).await?;
    // SafeTensors begins with an 8-byte little-endian header length.
    let Some(header_bytes) = file_content.get(0..8) else {
        threats.push(ThreatDetection {
            threat_type: ThreatType::InvalidFileStructure,
            severity: ThreatSeverity::High,
            description: "SafeTensors file too small".to_string(),
            location: None,
            mitigation_advice: "Verify file integrity".to_string(),
        });
        return Ok(());
    };
    let header_length = u64::from_le_bytes(
        header_bytes
            .try_into()
            .expect("8-byte slice converts to [u8; 8]"),
    );
    // The declared header must fit within the bytes that follow the
    // length field (len >= 8 is guaranteed by the guard above).
    let available = file_content.len() as u64 - 8;
    if header_length > available {
        threats.push(ThreatDetection {
            threat_type: ThreatType::InvalidFileStructure,
            severity: ThreatSeverity::High,
            description: "Invalid SafeTensors header length".to_string(),
            location: Some("File header".to_string()),
            mitigation_advice: "File may be corrupted".to_string(),
        });
    }
    Ok(())
}
/// Classify a PyTorch checkpoint by its leading bytes and flag
/// pickle-based files, which can execute arbitrary code when loaded.
async fn validate_pytorch_file(
    &self,
    file_path: &Path,
    threats: &mut Vec<ThreatDetection>,
) -> Result<()> {
    let file_content = tokio::fs::read(file_path).await?;
    // Files shorter than 4 bytes carry no classifiable signature.
    let Some(prefix) = file_content.get(0..4) else {
        return Ok(());
    };
    if prefix == b"PK\x03\x04" {
        // ZIP local-file-header magic: modern torch.save archive format.
        debug!("Detected ZIP-based PyTorch file");
    } else if prefix[0] == 0x80 {
        // 0x80 is the pickle PROTO opcode: legacy pickle serialization.
        debug!("Detected pickle-based PyTorch file");
        threats.push(ThreatDetection {
            threat_type: ThreatType::SuspiciousScript,
            severity: ThreatSeverity::High,
            description: "PyTorch pickle file detected - can execute arbitrary code"
                .to_string(),
            location: None,
            mitigation_advice: "Use SafeTensors format instead of pickle for security"
                .to_string(),
        });
    }
    Ok(())
}
/// Return the byte offset of the first occurrence of `needle` in
/// `haystack`, or `None` when it does not occur.
///
/// An empty needle yields `None`: `slice::windows` panics on a zero
/// window size, and "matches everywhere" is never a useful answer for
/// signature scanning.
fn find_pattern(&self, haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return None;
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
/// Collapse individual threat severities into one overall risk rating.
///
/// Any Critical finding rates Critical; one High finding or three
/// Medium findings rate High; any Medium finding rates Medium;
/// otherwise Low. An empty list rates Safe.
fn assess_risk_level(&self, threats: &[ThreatDetection]) -> RiskLevel {
    if threats.is_empty() {
        return RiskLevel::Safe;
    }
    // Tally severities in a single pass instead of scanning three times.
    let mut saw_critical = false;
    let mut saw_high = false;
    let mut medium_count = 0usize;
    for threat in threats {
        match &threat.severity {
            ThreatSeverity::Critical => saw_critical = true,
            ThreatSeverity::High => saw_high = true,
            ThreatSeverity::Medium => medium_count += 1,
            _ => {}
        }
    }
    if saw_critical {
        RiskLevel::Critical
    } else if saw_high || medium_count >= 3 {
        RiskLevel::High
    } else if medium_count > 0 {
        RiskLevel::Medium
    } else {
        RiskLevel::Low
    }
}
/// Move a flagged file into the quarantine directory and write a JSON
/// metadata record describing why it was quarantined.
///
/// # Errors
/// Returns an error if the quarantine directory cannot be created, the
/// file cannot be moved (or copied as a fallback), or the metadata
/// record cannot be written.
async fn quarantine_file(&self, file_path: &Path) -> Result<()> {
    tokio::fs::create_dir_all(&self.config.quarantine_dir).await?;
    // Timestamp prefix keeps repeated quarantines of identically named
    // files from clobbering each other.
    let timestamp = Utc::now().format("%Y%m%d_%H%M%S").to_string();
    let original_name = file_path
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("unknown");
    let quarantine_filename = format!("{}_{}", timestamp, original_name);
    let quarantine_path = self.config.quarantine_dir.join(quarantine_filename);
    // rename() is fast and atomic but fails (EXDEV) when the quarantine
    // directory lives on a different filesystem than the model file;
    // fall back to copy-then-delete so quarantine still succeeds.
    if tokio::fs::rename(file_path, &quarantine_path).await.is_err() {
        tokio::fs::copy(file_path, &quarantine_path).await?;
        tokio::fs::remove_file(file_path).await?;
    }
    // NOTE(review): with_extension replaces the quarantined file's real
    // extension, so two sources differing only by extension would map to
    // the same metadata path — confirm this is acceptable.
    let metadata_path = quarantine_path.with_extension("quarantine_metadata.json");
    let metadata = serde_json::json!({
        "original_path": file_path,
        "quarantined_at": Utc::now(),
        "reason": "Security scan detected threats"
    });
    tokio::fs::write(&metadata_path, serde_json::to_string_pretty(&metadata)?).await?;
    info!(
        "File quarantined: {} -> {}",
        file_path.display(),
        quarantine_path.display()
    );
    Ok(())
}
}