use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::path::PathBuf;
use thiserror::Error;
/// Error type returned by the configuration loading, validation, and accessor
/// functions in this file.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute
/// (derived through `thiserror::Error`). All payloads are owned strings.
#[derive(Debug, Error)]
pub enum ConfigError {
/// A value the caller asked for is not set; `key` names the missing setting
/// (for example `API_AUTH_TOKEN`, `DATABASE_URL`, or the secret-key variable).
#[error("Missing required configuration: {key}")]
MissingRequired { key: String },
/// A value is present but was rejected; `key` identifies the setting and
/// `reason` explains why it was rejected.
#[error("Invalid configuration value for {key}: {reason}")]
InvalidValue { key: String, reason: String },
/// In this file, produced only by the keychain branch of
/// `Config::get_secret_key`: either a keyring failure or keychain support
/// not being compiled in.
#[error("Environment variable error: {message}")]
EnvError { message: String },
/// A file (config file or secret-key file) could not be read; `message` is
/// the underlying I/O error rendered as text.
#[error("IO error reading config file: {message}")]
IoError { message: String },
/// The config file was read but could not be deserialized from TOML;
/// `message` is the parser's error rendered as text.
#[error("Configuration parsing error: {message}")]
ParseError { message: String },
}
/// Top-level application configuration, grouping the per-area sections below.
///
/// The derived `Default` uses each section's own `Default` impl and leaves the
/// optional `slm` and `routing` sections as `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
/// API settings (port, host, auth token, timeout, body size).
pub api: ApiConfig,
/// Database and vector-store connection settings.
pub database: DatabaseConfig,
/// Logging level and output format.
pub logging: LoggingConfig,
/// Secret-key source and feature switches.
pub security: SecurityConfig,
/// Filesystem locations and size limit.
pub storage: StorageConfig,
/// Optional SLM section. `Config::validate` checks it only when it is
/// present and its `enabled` flag is true.
pub slm: Option<Slm>,
/// Optional routing section. Its type lives in `crate::routing`; nothing in
/// this file reads or validates it.
pub routing: Option<crate::routing::RoutingConfig>,
}
/// API section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiConfig {
/// Port number. Defaults to 8080; `Config::from_env` overrides it from
/// `API_PORT`, and `Config::validate` rejects 0.
pub port: u16,
/// Host string. Defaults to `"127.0.0.1"`; overridden by `API_HOST`.
pub host: String,
/// Optional auth token. Defaults to `None`; `Config::from_env` sets it from
/// `API_AUTH_TOKEN` only when that variable is non-empty. Excluded from
/// serialized output by `skip_serializing` (it is still deserialized).
#[serde(skip_serializing)]
pub auth_token: Option<String>,
/// Timeout in seconds (per the field name). Defaults to 60.
pub timeout_seconds: u64,
/// Maximum body size. Defaults to `16 * 1024 * 1024`
/// (presumably bytes, i.e. 16 MiB — confirm against the consumer).
pub max_body_size: usize,
}
/// Database section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
/// Optional database URL. Defaults to `None`; `Config::from_env` sets it
/// from `DATABASE_URL`. Excluded from serialized output.
#[serde(skip_serializing)]
pub url: Option<String>,
/// Optional Redis URL. Defaults to `None`; `Config::from_env` sets it from
/// `REDIS_URL`. Excluded from serialized output.
#[serde(skip_serializing)]
pub redis_url: Option<String>,
/// Qdrant URL. Defaults to `"http://localhost:6333"`; overridden by
/// `QDRANT_URL`.
pub qdrant_url: String,
/// Qdrant collection name. Defaults to `"agent_knowledge"`; there is no
/// environment override for it in this file.
pub qdrant_collection: String,
/// Vector dimension. Defaults to 1536; `Config::validate` rejects 0.
pub vector_dimension: usize,
}
/// Logging section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
/// Log level name. Defaults to `"info"`; overridden by `LOG_LEVEL`.
/// `Config::validate` accepts exactly `error`, `warn`, `info`, `debug`, or
/// `trace` (exact, case-sensitive string match).
pub level: String,
/// Output format. Defaults to `LogFormat::Pretty`.
pub format: LogFormat,
/// Defaults to `false`. How this flag is consumed is not shown in this file.
pub structured: bool,
}
/// Log output format selected by `LoggingConfig::format`.
///
/// `LoggingConfig::default` uses `Pretty`. How each variant is rendered is
/// not defined in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LogFormat {
Json,
Pretty,
Compact,
}
/// Security section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
/// Where `Config::get_secret_key` obtains the secret key. Defaults to the
/// environment variable `SYMBIONT_SECRET_KEY`; `Config::from_env` switches
/// the variable name to the value of `SYMBIONT_SECRET_KEY_VAR` when set.
pub key_provider: KeyProvider,
/// Defaults to `true`. Not read anywhere in this file.
pub enable_compression: bool,
/// Defaults to `true`. Not read anywhere in this file.
pub enable_backups: bool,
/// Defaults to `true`. Not read anywhere in this file.
pub enable_safety_checks: bool,
}
/// Source of the secret key returned by `Config::get_secret_key`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum KeyProvider {
/// Read the key from the environment variable named `var_name`.
Environment { var_name: String },
/// Read the key from the file at `path`; surrounding whitespace is trimmed.
File { path: PathBuf },
/// Fetch the key through the `keyring` crate for `service`/`account`.
/// Only available when the `keychain` cargo feature is enabled; otherwise
/// `Config::get_secret_key` returns an error for this variant.
Keychain { service: String, account: String },
}
/// Storage section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
/// Defaults to `./agent_storage`; overridden by `CONTEXT_STORAGE_PATH`.
pub context_path: PathBuf,
/// Defaults to `./temp_repos`; overridden by `GIT_CLONE_BASE_PATH`.
pub git_clone_path: PathBuf,
/// Defaults to `./backups`; overridden by `BACKUP_DIRECTORY`.
pub backup_path: PathBuf,
/// Size limit in megabytes (per the field name). Defaults to 100.
pub max_context_size_mb: u64,
}
/// Optional SLM section: model allow lists plus named sandbox profiles.
///
/// NOTE(review): the file never expands "SLM"; presumably "small language
/// model" — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Slm {
/// Defaults to `false`. `Config::validate` runs `Slm::validate` only when
/// this is `true`.
pub enabled: bool,
/// Global model list and per-agent restrictions.
pub model_allow_lists: ModelAllowListConfig,
/// Sandbox profiles keyed by name. The default contains `"secure"` and
/// `"standard"`. Every profile is checked by `Slm::validate`.
pub sandbox_profiles: HashMap<String, SandboxProfile>,
/// Name of the default profile; must be a key of `sandbox_profiles` to pass
/// `Slm::validate`. Defaults to `"secure"`.
pub default_sandbox_profile: String,
}
/// Model allow lists: the global model list plus per-agent restrictions.
///
/// The derived `Default` yields empty collections and `false`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModelAllowListConfig {
/// All known models. `Slm::validate` requires their `id`s to be unique.
pub global_models: Vec<Model>,
/// Agent id -> model ids that agent may use. `Slm::validate` requires every
/// listed id to exist in `global_models`. An agent with no entry here gets
/// all global models from `Slm::get_allowed_models`.
pub agent_model_maps: HashMap<String, Vec<String>>,
/// Not read anywhere in this file.
pub allow_runtime_overrides: bool,
}
/// One entry of `ModelAllowListConfig::global_models`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
/// Identifier used for uniqueness checks and for the per-agent allow lists.
pub id: String,
/// Model name; not used by any logic in this file.
pub name: String,
/// Where the model comes from.
pub provider: ModelProvider,
/// Capabilities declared for this model.
pub capabilities: Vec<ModelCapability>,
/// Declared resource needs.
pub resource_requirements: ModelResourceRequirements,
}
/// Source of a `Model`, with the locator each kind of source needs.
///
/// This file only declares the variants; how each one is resolved or loaded
/// is not shown here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelProvider {
HuggingFace { model_path: String },
LocalFile { file_path: PathBuf },
OpenAI { model_name: String },
Anthropic { model_name: String },
Custom { endpoint_url: String },
}
/// Capability tag listed in `Model::capabilities`.
///
/// Derives `PartialEq`, `Eq`, and `Hash`, so values can be compared and used
/// as set or map keys.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ModelCapability {
TextGeneration,
CodeGeneration,
Reasoning,
ToolUse,
FunctionCalling,
Embeddings,
}
/// Resource needs declared for a `Model`. Nothing in this file validates or
/// enforces these values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResourceRequirements {
/// Minimum memory in megabytes (per the field name).
pub min_memory_mb: u64,
/// Preferred CPU core count; fractional values are representable.
pub preferred_cpu_cores: f32,
/// GPU needs, or `None` when no GPU requirement is declared.
pub gpu_requirements: Option<GpuRequirements>,
}
/// GPU needs declared for a `Model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuRequirements {
/// Minimum video memory in megabytes (per the field name).
pub min_vram_mb: u64,
/// Free-form string; the tests in this file use values such as `"7.5"` and
/// `"8.0"`. NOTE(review): presumably a CUDA compute capability — confirm.
pub compute_capability: String,
}
/// A named sandbox profile stored in `Slm::sandbox_profiles`.
///
/// Two presets are provided by `SandboxProfile::secure_default` and
/// `SandboxProfile::standard_default`; `SandboxProfile::validate` performs
/// basic sanity checks on a profile.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SandboxProfile {
/// Memory, CPU, disk, GPU, and I/O limits.
pub resources: ResourceConstraints,
/// Readable, writable, and denied paths plus file-size limits.
pub filesystem: FilesystemControls,
/// Child-process, runtime, syscall, and priority limits.
pub process_limits: ProcessLimits,
/// Network access mode and allowed destinations.
pub network: NetworkPolicy,
/// Additional security switches.
pub security: SecuritySettings,
}
/// Resource limits of a `SandboxProfile`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceConstraints {
/// Memory limit in megabytes (per the field name). `SandboxProfile::validate`
/// rejects 0.
pub max_memory_mb: u64,
/// CPU core limit. `SandboxProfile::validate` rejects values `<= 0.0`.
pub max_cpu_cores: f32,
/// Disk limit in megabytes (per the field name). Not validated in this file.
pub max_disk_mb: u64,
/// GPU access mode.
pub gpu_access: GpuAccess,
/// Optional I/O bandwidth limit. The meaning of `None` is not defined in
/// this file (presumably "no limit" — confirm against the enforcing code).
pub max_io_bandwidth_mbps: Option<u64>,
}
/// Filesystem rules of a `SandboxProfile`.
///
/// The presets in this file use entries such as `"/tmp/sandbox/*"`; how such
/// patterns are matched is not defined here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FilesystemControls {
/// Readable paths. `SandboxProfile::validate` rejects empty strings here.
pub read_paths: Vec<String>,
/// Writable paths. Not validated in this file.
pub write_paths: Vec<String>,
/// Denied paths. Not validated in this file.
pub denied_paths: Vec<String>,
/// Both presets set this to `true`.
pub allow_temp_files: bool,
/// File size limit in megabytes (per the field name).
pub max_file_size_mb: u64,
}
/// Process limits of a `SandboxProfile`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessLimits {
/// Child process limit; the secure preset uses 0.
pub max_child_processes: u32,
/// Execution time limit in seconds (per the field name).
/// `SandboxProfile::validate` rejects 0.
pub max_execution_time_seconds: u64,
/// Syscall names. The standard preset leaves this empty.
/// NOTE(review): whether an empty list means "allow all" or "allow none" is
/// not determined by this file — confirm against the enforcing code.
pub allowed_syscalls: Vec<String>,
/// The presets use 19 (secure) and 0 (standard).
/// NOTE(review): presumably a Unix nice value — confirm.
pub process_priority: i8,
}
/// Network rules of a `SandboxProfile`. Nothing in this file validates them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkPolicy {
/// Overall access mode.
pub access_mode: NetworkAccessMode,
/// Destination list. The secure preset leaves it empty; the standard preset
/// lists `api.openai.com:443` over HTTPS.
pub allowed_destinations: Vec<NetworkDestination>,
/// Optional bandwidth limit; the secure preset uses `None`.
pub max_bandwidth_mbps: Option<u64>,
}
/// Network access mode of a `NetworkPolicy`.
///
/// The secure preset uses `None` and the standard preset uses `Restricted`.
/// Enforcement of each mode is not implemented in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NetworkAccessMode {
None,
Restricted,
Full,
}
/// One entry of `NetworkPolicy::allowed_destinations`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkDestination {
/// Host string, e.g. `"api.openai.com"` in the standard preset.
pub host: String,
/// Optional port. The meaning of `None` is not defined in this file
/// (presumably "any port" — confirm).
pub port: Option<u16>,
/// Optional protocol. The meaning of `None` is not defined in this file
/// (presumably "any protocol" — confirm).
pub protocol: Option<NetworkProtocol>,
}
/// Protocol tag for a `NetworkDestination`.
///
/// The all-caps variant names are part of the serialized representation
/// produced by the serde derives, so renaming them would change the config
/// format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NetworkProtocol {
TCP,
UDP,
HTTP,
HTTPS,
}
/// GPU access mode of `ResourceConstraints::gpu_access`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GpuAccess {
/// Used by the secure preset.
None,
/// Used by the standard preset with `max_memory_mb: 1024`.
Shared { max_memory_mb: u64 },
/// Not used by either preset in this file.
Exclusive,
}
/// Security switches of a `SandboxProfile`. This file only stores them; the
/// code acting on them is elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecuritySettings {
/// `true` in the secure preset, `false` in the standard preset.
pub strict_syscall_filtering: bool,
/// `true` in the secure preset, `false` in the standard preset.
pub disable_debugging: bool,
/// `true` in both presets.
pub enable_audit_logging: bool,
/// `true` in the secure preset, `false` in the standard preset.
pub require_encryption: bool,
}
impl Default for ApiConfig {
fn default() -> Self {
Self {
port: 8080,
host: "127.0.0.1".to_string(),
auth_token: None,
timeout_seconds: 60,
max_body_size: 16 * 1024 * 1024, }
}
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
url: None,
redis_url: None,
qdrant_url: "http://localhost:6333".to_string(),
qdrant_collection: "agent_knowledge".to_string(),
vector_dimension: 1536,
}
}
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: "info".to_string(),
format: LogFormat::Pretty,
structured: false,
}
}
}
impl Default for SecurityConfig {
fn default() -> Self {
Self {
key_provider: KeyProvider::Environment {
var_name: "SYMBIONT_SECRET_KEY".to_string(),
},
enable_compression: true,
enable_backups: true,
enable_safety_checks: true,
}
}
}
impl Default for StorageConfig {
fn default() -> Self {
Self {
context_path: PathBuf::from("./agent_storage"),
git_clone_path: PathBuf::from("./temp_repos"),
backup_path: PathBuf::from("./backups"),
max_context_size_mb: 100,
}
}
}
impl Default for Slm {
fn default() -> Self {
let mut profiles = HashMap::new();
profiles.insert("secure".to_string(), SandboxProfile::secure_default());
profiles.insert("standard".to_string(), SandboxProfile::standard_default());
Self {
enabled: false,
model_allow_lists: ModelAllowListConfig::default(),
sandbox_profiles: profiles,
default_sandbox_profile: "secure".to_string(),
}
}
}
impl SandboxProfile {
    /// Converts a list of string literals into owned `String`s.
    fn string_list(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| String::from(*item)).collect()
    }

    /// Returns the locked-down preset registered as `"secure"` by
    /// `Slm::default`.
    ///
    /// 512 MB memory, one CPU core, no GPU, no network access, no child
    /// processes, a three-entry syscall list, and every security switch on.
    pub fn secure_default() -> Self {
        Self {
            resources: ResourceConstraints {
                max_memory_mb: 512,
                max_cpu_cores: 1.0,
                max_disk_mb: 100,
                gpu_access: GpuAccess::None,
                max_io_bandwidth_mbps: Some(10),
            },
            filesystem: FilesystemControls {
                read_paths: Self::string_list(&["/tmp/sandbox/*"]),
                write_paths: Self::string_list(&["/tmp/sandbox/output/*"]),
                denied_paths: Self::string_list(&["/etc/*", "/proc/*"]),
                allow_temp_files: true,
                max_file_size_mb: 10,
            },
            process_limits: ProcessLimits {
                max_child_processes: 0,
                max_execution_time_seconds: 300,
                allowed_syscalls: Self::string_list(&["read", "write", "open"]),
                process_priority: 19,
            },
            network: NetworkPolicy {
                access_mode: NetworkAccessMode::None,
                allowed_destinations: Vec::new(),
                max_bandwidth_mbps: None,
            },
            security: SecuritySettings {
                strict_syscall_filtering: true,
                disable_debugging: true,
                enable_audit_logging: true,
                require_encryption: true,
            },
        }
    }

    /// Returns the more permissive preset registered as `"standard"` by
    /// `Slm::default`.
    ///
    /// 1024 MB memory, two CPU cores, shared GPU access, restricted network
    /// access to `api.openai.com:443` over HTTPS, up to five child processes,
    /// an empty syscall list, and only audit logging switched on.
    pub fn standard_default() -> Self {
        Self {
            resources: ResourceConstraints {
                max_memory_mb: 1024,
                max_cpu_cores: 2.0,
                max_disk_mb: 500,
                gpu_access: GpuAccess::Shared { max_memory_mb: 1024 },
                max_io_bandwidth_mbps: Some(50),
            },
            filesystem: FilesystemControls {
                read_paths: Self::string_list(&["/tmp/*", "/home/sandbox/*"]),
                write_paths: Self::string_list(&["/tmp/*", "/home/sandbox/*"]),
                denied_paths: Self::string_list(&["/etc/passwd", "/etc/shadow"]),
                allow_temp_files: true,
                max_file_size_mb: 100,
            },
            process_limits: ProcessLimits {
                max_child_processes: 5,
                max_execution_time_seconds: 600,
                allowed_syscalls: Vec::new(),
                process_priority: 0,
            },
            network: NetworkPolicy {
                access_mode: NetworkAccessMode::Restricted,
                allowed_destinations: vec![NetworkDestination {
                    host: "api.openai.com".to_string(),
                    port: Some(443),
                    protocol: Some(NetworkProtocol::HTTPS),
                }],
                max_bandwidth_mbps: Some(100),
            },
            security: SecuritySettings {
                strict_syscall_filtering: false,
                disable_debugging: false,
                enable_audit_logging: true,
                require_encryption: false,
            },
        }
    }

    /// Performs basic sanity checks on the profile.
    ///
    /// # Errors
    ///
    /// Returns a boxed error carrying a human-readable message when:
    /// - `resources.max_memory_mb` is 0,
    /// - `resources.max_cpu_cores` is NaN or not greater than 0,
    /// - `filesystem.read_paths` contains an empty string, or
    /// - `process_limits.max_execution_time_seconds` is 0.
    ///
    /// Checks run in that order and the first failure is returned.
    pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
        if self.resources.max_memory_mb == 0 {
            return Err("max_memory_mb must be > 0".into());
        }

        // Every comparison with NaN is false, so a plain `<= 0.0` test would
        // let a NaN core count through; reject it explicitly.
        let cpu_cores = self.resources.max_cpu_cores;
        if cpu_cores.is_nan() || cpu_cores <= 0.0 {
            return Err("max_cpu_cores must be > 0".into());
        }

        if self.filesystem.read_paths.iter().any(|path| path.is_empty()) {
            return Err("read_paths cannot contain empty strings".into());
        }

        if self.process_limits.max_execution_time_seconds == 0 {
            return Err("max_execution_time_seconds must be > 0".into());
        }

        Ok(())
    }
}
impl Slm {
    /// Checks the SLM section for internal consistency.
    ///
    /// Checks run in this order and the first failure is returned:
    /// 1. `default_sandbox_profile` must be a key of `sandbox_profiles`.
    /// 2. Model ids in `model_allow_lists.global_models` must be unique.
    /// 3. Every model id in `model_allow_lists.agent_model_maps` must exist in
    ///    `global_models`.
    /// 4. Every sandbox profile must pass `SandboxProfile::validate`.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::InvalidValue` whose `key` is the dotted path of
    /// the offending setting and whose `reason` describes the problem.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if !self.sandbox_profiles.contains_key(&self.default_sandbox_profile) {
            return Err(ConfigError::InvalidValue {
                key: "slm.default_sandbox_profile".to_string(),
                reason: format!(
                    "Profile '{}' not found in sandbox_profiles",
                    self.default_sandbox_profile
                ),
            });
        }

        let allow_lists = &self.model_allow_lists;

        // One pass builds the id set and detects duplicates; the set is then
        // reused for the per-agent lookups instead of re-scanning the list.
        let mut known_ids =
            std::collections::HashSet::with_capacity(allow_lists.global_models.len());
        for model in &allow_lists.global_models {
            if !known_ids.insert(model.id.as_str()) {
                return Err(ConfigError::InvalidValue {
                    key: "slm.model_allow_lists.global_models".to_string(),
                    reason: format!("Duplicate model ID: {}", model.id),
                });
            }
        }

        for (agent_id, agent_models) in &allow_lists.agent_model_maps {
            let unknown = agent_models
                .iter()
                .find(|model_id| !known_ids.contains(model_id.as_str()));
            if let Some(model_id) = unknown {
                return Err(ConfigError::InvalidValue {
                    key: format!("slm.model_allow_lists.agent_model_maps.{}", agent_id),
                    reason: format!("Model ID '{}' not found in global_models", model_id),
                });
            }
        }

        for (profile_name, profile) in &self.sandbox_profiles {
            profile.validate().map_err(|e| ConfigError::InvalidValue {
                key: format!("slm.sandbox_profiles.{}", profile_name),
                reason: e.to_string(),
            })?;
        }

        Ok(())
    }

    /// Returns the models `agent_id` may use, in `global_models` order.
    ///
    /// An agent with an entry in `agent_model_maps` gets only the global
    /// models whose id appears in that entry. An agent without an entry gets
    /// every global model.
    pub fn get_allowed_models(&self, agent_id: &str) -> Vec<&Model> {
        let all_models = self.model_allow_lists.global_models.iter();
        match self.model_allow_lists.agent_model_maps.get(agent_id) {
            Some(allowed_ids) => all_models
                .filter(|model| allowed_ids.contains(&model.id))
                .collect(),
            None => all_models.collect(),
        }
    }
}
impl Config {
    /// Builds a configuration from `Config::default()` overridden by
    /// environment variables.
    ///
    /// Variables read: `API_PORT`, `API_HOST`, `API_AUTH_TOKEN` (ignored when
    /// empty), `DATABASE_URL`, `REDIS_URL`, `QDRANT_URL`, `LOG_LEVEL`,
    /// `SYMBIONT_SECRET_KEY_VAR` (name of the variable holding the secret
    /// key), `CONTEXT_STORAGE_PATH`, `GIT_CLONE_BASE_PATH`, and
    /// `BACKUP_DIRECTORY`. Variables that are unset (or not valid Unicode)
    /// leave the default in place.
    ///
    /// The result is not validated; call [`Config::validate`] afterwards.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::InvalidValue` when `API_PORT` is set but does not
    /// parse as a `u16`.
    pub fn from_env() -> Result<Self, ConfigError> {
        let mut config = Self::default();

        if let Ok(port) = env::var("API_PORT") {
            config.api.port = port.parse().map_err(|_| ConfigError::InvalidValue {
                key: "API_PORT".to_string(),
                reason: "Invalid port number".to_string(),
            })?;
        }
        if let Ok(host) = env::var("API_HOST") {
            config.api.host = host;
        }
        if let Ok(token) = env::var("API_AUTH_TOKEN") {
            // An empty token is treated the same as an unset one.
            if !token.is_empty() {
                config.api.auth_token = Some(token);
            }
        }

        if let Ok(db_url) = env::var("DATABASE_URL") {
            config.database.url = Some(db_url);
        }
        if let Ok(redis_url) = env::var("REDIS_URL") {
            config.database.redis_url = Some(redis_url);
        }
        if let Ok(qdrant_url) = env::var("QDRANT_URL") {
            config.database.qdrant_url = qdrant_url;
        }

        if let Ok(log_level) = env::var("LOG_LEVEL") {
            config.logging.level = log_level;
        }

        if let Ok(key_var) = env::var("SYMBIONT_SECRET_KEY_VAR") {
            config.security.key_provider = KeyProvider::Environment { var_name: key_var };
        }

        if let Ok(context_path) = env::var("CONTEXT_STORAGE_PATH") {
            config.storage.context_path = PathBuf::from(context_path);
        }
        if let Ok(git_path) = env::var("GIT_CLONE_BASE_PATH") {
            config.storage.git_clone_path = PathBuf::from(git_path);
        }
        if let Ok(backup_path) = env::var("BACKUP_DIRECTORY") {
            config.storage.backup_path = PathBuf::from(backup_path);
        }

        Ok(config)
    }

    /// Reads the file at `path` and deserializes it as TOML.
    ///
    /// The result is not validated; call [`Config::validate`] afterwards.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::IoError` when the file cannot be read and
    /// `ConfigError::ParseError` when its contents cannot be deserialized.
    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, ConfigError> {
        let content = std::fs::read_to_string(path).map_err(|e| ConfigError::IoError {
            message: e.to_string(),
        })?;
        let config: Self = toml::from_str(&content).map_err(|e| ConfigError::ParseError {
            message: e.to_string(),
        })?;
        Ok(config)
    }

    /// Validates the configuration.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::InvalidValue` for the first failing check:
    /// - `api.port` is 0,
    /// - `logging.level` is not one of `error`, `warn`, `info`, `debug`,
    ///   `trace` (exact match),
    /// - `database.vector_dimension` is 0,
    /// - the SLM section is present and enabled but fails `Slm::validate`.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.api.port == 0 {
            return Err(ConfigError::InvalidValue {
                key: "api.port".to_string(),
                reason: "Port cannot be 0".to_string(),
            });
        }

        let valid_levels = ["error", "warn", "info", "debug", "trace"];
        if !valid_levels.contains(&self.logging.level.as_str()) {
            return Err(ConfigError::InvalidValue {
                key: "logging.level".to_string(),
                reason: format!("Must be one of: {}", valid_levels.join(", ")),
            });
        }

        if self.database.vector_dimension == 0 {
            return Err(ConfigError::InvalidValue {
                key: "database.vector_dimension".to_string(),
                reason: "Vector dimension must be > 0".to_string(),
            });
        }

        // A disabled SLM section is deliberately left unchecked.
        if let Some(slm) = &self.slm {
            if slm.enabled {
                slm.validate()?;
            }
        }

        Ok(())
    }

    /// Returns a copy of the API auth token.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::MissingRequired` (key `API_AUTH_TOKEN`) when no
    /// token is configured.
    pub fn get_api_auth_token(&self) -> Result<String, ConfigError> {
        self.api
            .auth_token
            .clone()
            .ok_or_else(|| ConfigError::MissingRequired {
                key: "API_AUTH_TOKEN".to_string(),
            })
    }

    /// Returns a copy of the database URL.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::MissingRequired` (key `DATABASE_URL`) when no URL
    /// is configured.
    pub fn get_database_url(&self) -> Result<String, ConfigError> {
        self.database
            .url
            .clone()
            .ok_or_else(|| ConfigError::MissingRequired {
                key: "DATABASE_URL".to_string(),
            })
    }

    /// Resolves the secret key from the configured `KeyProvider`.
    ///
    /// - `Environment`: reads the named variable. An unset or empty variable
    ///   counts as missing, matching how `from_env` treats an empty
    ///   `API_AUTH_TOKEN`.
    /// - `File`: reads the file and trims surrounding whitespace. A file that
    ///   is empty after trimming is rejected.
    /// - `Keychain`: uses the `keyring` crate when the `keychain` feature is
    ///   enabled.
    ///
    /// # Errors
    ///
    /// - `ConfigError::MissingRequired` when the environment variable is unset
    ///   or empty.
    /// - `ConfigError::IoError` when the key file cannot be read.
    /// - `ConfigError::InvalidValue` when the key file contains no key.
    /// - `ConfigError::EnvError` for keyring failures, or when the `keychain`
    ///   feature is not compiled in.
    pub fn get_secret_key(&self) -> Result<String, ConfigError> {
        match &self.security.key_provider {
            KeyProvider::Environment { var_name } => match env::var(var_name) {
                // An empty secret is never a usable key; report it as missing
                // instead of handing callers an empty string.
                Ok(secret) if !secret.is_empty() => Ok(secret),
                _ => Err(ConfigError::MissingRequired {
                    key: var_name.clone(),
                }),
            },
            KeyProvider::File { path } => {
                let content = std::fs::read_to_string(path).map_err(|e| ConfigError::IoError {
                    message: e.to_string(),
                })?;
                let secret = content.trim();
                if secret.is_empty() {
                    return Err(ConfigError::InvalidValue {
                        key: "security.key_provider".to_string(),
                        reason: format!("Secret key file '{}' is empty", path.display()),
                    });
                }
                Ok(secret.to_string())
            }
            KeyProvider::Keychain { service, account } => {
                #[cfg(feature = "keychain")]
                {
                    use keyring::Entry;
                    let entry = Entry::new(service, account).map_err(|e| ConfigError::EnvError {
                        message: e.to_string(),
                    })?;
                    entry.get_password().map_err(|e| ConfigError::EnvError {
                        message: e.to_string(),
                    })
                }
                #[cfg(not(feature = "keychain"))]
                {
                    // The bindings are only consumed by the keychain build;
                    // discard them here to avoid unused-variable warnings.
                    let _ = (service, account);
                    Err(ConfigError::EnvError {
                        message: "Keychain support not enabled".to_string(),
                    })
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for configuration defaults, validation, and accessors.
    //!
    //! Tests that touch process environment variables go through
    //! [`ScopedEnv`], because the test harness runs tests on several threads
    //! and the environment is shared by the whole process.

    use super::*;
    use std::collections::HashMap;
    use std::env;
    use std::fmt::Debug;
    use std::path::PathBuf;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes all tests that read or write environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and removes the variables it set when dropped, so a
    /// failing assertion cannot leak variables into later tests.
    struct ScopedEnv {
        names: Vec<&'static str>,
        _guard: MutexGuard<'static, ()>,
    }

    impl ScopedEnv {
        /// Takes the environment lock and sets every `(name, value)` pair.
        fn set(vars: &[(&'static str, &str)]) -> Self {
            // A test that panicked while holding the lock poisons it; the
            // protected data is `()`, so continuing is safe.
            let guard = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let mut names = Vec::with_capacity(vars.len());
            for &(name, value) in vars {
                env::set_var(name, value);
                names.push(name);
            }
            Self {
                names,
                _guard: guard,
            }
        }
    }

    impl Drop for ScopedEnv {
        fn drop(&mut self) {
            // Runs before `_guard` is released, so cleanup is still exclusive.
            for name in &self.names {
                env::remove_var(name);
            }
        }
    }

    /// Unwraps a `ConfigError::InvalidValue`, panicking on anything else.
    fn expect_invalid_value<T: Debug>(result: Result<T, ConfigError>) -> (String, String) {
        match result {
            Err(ConfigError::InvalidValue { key, reason }) => (key, reason),
            other => panic!("expected ConfigError::InvalidValue, got {:?}", other),
        }
    }

    /// Unwraps a `ConfigError::MissingRequired`, panicking on anything else.
    fn expect_missing_required<T: Debug>(result: Result<T, ConfigError>) -> String {
        match result {
            Err(ConfigError::MissingRequired { key }) => key,
            other => panic!("expected ConfigError::MissingRequired, got {:?}", other),
        }
    }

    /// Builds a local-file model with modest resource requirements.
    fn local_model(id: &str, capabilities: Vec<ModelCapability>) -> Model {
        Model {
            id: id.to_string(),
            name: format!("Model {}", id),
            provider: ModelProvider::LocalFile {
                file_path: PathBuf::from(format!("/tmp/{}.gguf", id)),
            },
            capabilities,
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 512,
                preferred_cpu_cores: 1.0,
                gpu_requirements: None,
            },
        }
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.api.port, 8080);
        assert_eq!(config.api.host, "127.0.0.1");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_from_env() {
        let _env = ScopedEnv::set(&[
            ("API_PORT", "9090"),
            ("API_HOST", "0.0.0.0"),
            ("LOG_LEVEL", "debug"),
        ]);
        let config = Config::from_env().unwrap();
        assert_eq!(config.api.port, 9090);
        assert_eq!(config.api.host, "0.0.0.0");
        assert_eq!(config.logging.level, "debug");
    }

    #[test]
    fn test_invalid_port() {
        let mut config = Config::default();
        config.api.port = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_log_level() {
        let mut config = Config::default();
        config.logging.level = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_slm_default_config() {
        let slm = Slm::default();
        assert!(!slm.enabled);
        assert_eq!(slm.default_sandbox_profile, "secure");
        assert!(slm.sandbox_profiles.contains_key("secure"));
        assert!(slm.sandbox_profiles.contains_key("standard"));
        assert!(slm.validate().is_ok());
    }

    #[test]
    fn test_slm_validation_invalid_default_profile() {
        let mut slm = Slm::default();
        slm.default_sandbox_profile = "nonexistent".to_string();
        let (key, reason) = expect_invalid_value(slm.validate());
        assert_eq!(key, "slm.default_sandbox_profile");
        assert!(reason.contains("nonexistent"));
    }

    #[test]
    fn test_slm_validation_duplicate_model_ids() {
        let mut slm = Slm::default();
        slm.model_allow_lists.global_models = vec![
            local_model("duplicate", vec![ModelCapability::TextGeneration]),
            local_model("duplicate", vec![ModelCapability::CodeGeneration]),
        ];
        let (key, reason) = expect_invalid_value(slm.validate());
        assert_eq!(key, "slm.model_allow_lists.global_models");
        assert!(reason.contains("Duplicate model ID: duplicate"));
    }

    #[test]
    fn test_slm_validation_invalid_agent_model_mapping() {
        let mut slm = Slm::default();
        slm.model_allow_lists.global_models =
            vec![local_model("test_model", vec![ModelCapability::TextGeneration])];
        let mut agent_model_maps = HashMap::new();
        agent_model_maps.insert(
            "test_agent".to_string(),
            vec!["nonexistent_model".to_string()],
        );
        slm.model_allow_lists.agent_model_maps = agent_model_maps;

        let (key, reason) = expect_invalid_value(slm.validate());
        assert_eq!(key, "slm.model_allow_lists.agent_model_maps.test_agent");
        assert!(reason.contains("Model ID 'nonexistent_model' not found"));
    }

    #[test]
    fn test_slm_get_allowed_models_with_agent_mapping() {
        let mut slm = Slm::default();
        slm.model_allow_lists.global_models = vec![
            local_model("model1", vec![ModelCapability::TextGeneration]),
            local_model("model2", vec![ModelCapability::CodeGeneration]),
        ];
        let mut agent_model_maps = HashMap::new();
        agent_model_maps.insert("agent1".to_string(), vec!["model1".to_string()]);
        slm.model_allow_lists.agent_model_maps = agent_model_maps;

        let allowed_models = slm.get_allowed_models("agent1");
        assert_eq!(allowed_models.len(), 1);
        assert_eq!(allowed_models[0].id, "model1");

        // An agent without a mapping falls back to every global model.
        let allowed_models = slm.get_allowed_models("agent2");
        assert_eq!(allowed_models.len(), 2);
    }

    #[test]
    fn test_sandbox_profile_secure_default() {
        let profile = SandboxProfile::secure_default();
        assert_eq!(profile.resources.max_memory_mb, 512);
        assert_eq!(profile.resources.max_cpu_cores, 1.0);
        assert!(matches!(profile.resources.gpu_access, GpuAccess::None));
        assert!(matches!(profile.network.access_mode, NetworkAccessMode::None));
        assert!(profile.security.strict_syscall_filtering);
        assert!(profile.validate().is_ok());
    }

    #[test]
    fn test_sandbox_profile_standard_default() {
        let profile = SandboxProfile::standard_default();
        assert_eq!(profile.resources.max_memory_mb, 1024);
        assert_eq!(profile.resources.max_cpu_cores, 2.0);
        assert!(matches!(profile.resources.gpu_access, GpuAccess::Shared { .. }));
        assert!(matches!(
            profile.network.access_mode,
            NetworkAccessMode::Restricted
        ));
        assert!(!profile.security.strict_syscall_filtering);
        assert!(profile.validate().is_ok());
    }

    #[test]
    fn test_sandbox_profile_validation_zero_memory() {
        let mut profile = SandboxProfile::secure_default();
        profile.resources.max_memory_mb = 0;
        let result = profile.validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("max_memory_mb must be > 0"));
    }

    #[test]
    fn test_sandbox_profile_validation_zero_cpu() {
        let mut profile = SandboxProfile::secure_default();
        profile.resources.max_cpu_cores = 0.0;
        let result = profile.validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("max_cpu_cores must be > 0"));
    }

    #[test]
    fn test_sandbox_profile_validation_empty_read_path() {
        let mut profile = SandboxProfile::secure_default();
        profile.filesystem.read_paths.push("".to_string());
        let result = profile.validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("read_paths cannot contain empty strings"));
    }

    #[test]
    fn test_sandbox_profile_validation_zero_execution_time() {
        let mut profile = SandboxProfile::secure_default();
        profile.process_limits.max_execution_time_seconds = 0;
        let result = profile.validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("max_execution_time_seconds must be > 0"));
    }

    #[test]
    fn test_model_provider_variants() {
        let huggingface_model = Model {
            id: "hf_model".to_string(),
            name: "HuggingFace Model".to_string(),
            provider: ModelProvider::HuggingFace {
                model_path: "microsoft/DialoGPT-medium".to_string(),
            },
            capabilities: vec![ModelCapability::TextGeneration],
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 512,
                preferred_cpu_cores: 1.0,
                gpu_requirements: None,
            },
        };
        let openai_model = Model {
            id: "openai_model".to_string(),
            name: "OpenAI Model".to_string(),
            provider: ModelProvider::OpenAI {
                model_name: "gpt-3.5-turbo".to_string(),
            },
            capabilities: vec![ModelCapability::TextGeneration, ModelCapability::Reasoning],
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 0,
                preferred_cpu_cores: 0.0,
                gpu_requirements: None,
            },
        };
        assert_eq!(huggingface_model.id, "hf_model");
        assert_eq!(openai_model.id, "openai_model");
    }

    #[test]
    fn test_model_capabilities() {
        let all_capabilities = vec![
            ModelCapability::TextGeneration,
            ModelCapability::CodeGeneration,
            ModelCapability::Reasoning,
            ModelCapability::ToolUse,
            ModelCapability::FunctionCalling,
            ModelCapability::Embeddings,
        ];
        let mut model = local_model("full_model", all_capabilities.clone());
        model.resource_requirements = ModelResourceRequirements {
            min_memory_mb: 2048,
            preferred_cpu_cores: 4.0,
            gpu_requirements: Some(GpuRequirements {
                min_vram_mb: 8192,
                compute_capability: "7.5".to_string(),
            }),
        };

        assert_eq!(model.capabilities.len(), 6);
        for capability in &all_capabilities {
            assert!(model.capabilities.contains(capability));
        }
    }

    #[test]
    fn test_config_validation_vector_dimension() {
        let mut config = Config::default();
        config.database.vector_dimension = 0;
        let (key, reason) = expect_invalid_value(config.validate());
        assert_eq!(key, "database.vector_dimension");
        assert!(reason.contains("Vector dimension must be > 0"));
    }

    #[test]
    fn test_config_validation_with_slm() {
        let mut config = Config::default();
        let mut slm = Slm::default();
        slm.enabled = true;
        slm.default_sandbox_profile = "invalid".to_string();
        config.slm = Some(slm);
        let result = config.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_config_secret_key_retrieval() {
        let _env = ScopedEnv::set(&[("TEST_SECRET_KEY", "test_secret_123")]);
        let mut config = Config::default();
        config.security.key_provider = KeyProvider::Environment {
            var_name: "TEST_SECRET_KEY".to_string(),
        };
        let key = config.get_secret_key();
        assert!(key.is_ok());
        assert_eq!(key.unwrap(), "test_secret_123");
    }

    #[test]
    fn test_config_secret_key_missing() {
        // Sets nothing; only takes the lock so no other test mutates the
        // environment while it is being read.
        let _env = ScopedEnv::set(&[]);
        let mut config = Config::default();
        config.security.key_provider = KeyProvider::Environment {
            var_name: "NONEXISTENT_KEY".to_string(),
        };
        let key = expect_missing_required(config.get_secret_key());
        assert_eq!(key, "NONEXISTENT_KEY");
    }

    #[test]
    fn test_network_policy_configurations() {
        let destination = NetworkDestination {
            host: "api.openai.com".to_string(),
            port: Some(443),
            protocol: Some(NetworkProtocol::HTTPS),
        };
        let network_policy = NetworkPolicy {
            access_mode: NetworkAccessMode::Restricted,
            allowed_destinations: vec![destination],
            max_bandwidth_mbps: Some(100),
        };
        let profile = SandboxProfile {
            resources: ResourceConstraints {
                max_memory_mb: 1024,
                max_cpu_cores: 2.0,
                max_disk_mb: 500,
                gpu_access: GpuAccess::None,
                max_io_bandwidth_mbps: Some(50),
            },
            filesystem: FilesystemControls {
                read_paths: vec!["/tmp/*".to_string()],
                write_paths: vec!["/tmp/output/*".to_string()],
                denied_paths: vec!["/etc/*".to_string()],
                allow_temp_files: true,
                max_file_size_mb: 10,
            },
            process_limits: ProcessLimits {
                max_child_processes: 2,
                max_execution_time_seconds: 300,
                allowed_syscalls: vec!["read".to_string(), "write".to_string()],
                process_priority: 0,
            },
            network: network_policy,
            security: SecuritySettings {
                strict_syscall_filtering: true,
                disable_debugging: true,
                enable_audit_logging: true,
                require_encryption: false,
            },
        };
        assert!(profile.validate().is_ok());
        assert!(matches!(
            profile.network.access_mode,
            NetworkAccessMode::Restricted
        ));
        assert_eq!(profile.network.allowed_destinations.len(), 1);
        assert_eq!(profile.network.allowed_destinations[0].host, "api.openai.com");
    }

    #[test]
    fn test_gpu_requirements_configurations() {
        let mut model = local_model("gpu_model", vec![ModelCapability::TextGeneration]);
        model.resource_requirements = ModelResourceRequirements {
            min_memory_mb: 1024,
            preferred_cpu_cores: 2.0,
            gpu_requirements: Some(GpuRequirements {
                min_vram_mb: 4096,
                compute_capability: "8.0".to_string(),
            }),
        };

        let gpu_req = model
            .resource_requirements
            .gpu_requirements
            .expect("GPU requirements were set above");
        assert_eq!(gpu_req.min_vram_mb, 4096);
        assert_eq!(gpu_req.compute_capability, "8.0");
    }

    #[test]
    fn test_config_from_env_invalid_port() {
        let _env = ScopedEnv::set(&[("API_PORT", "invalid")]);
        let (key, reason) = expect_invalid_value(Config::from_env());
        assert_eq!(key, "API_PORT");
        assert!(reason.contains("Invalid port number"));
    }

    #[test]
    fn test_api_auth_token_missing() {
        let config = Config::default();
        let key = expect_missing_required(config.get_api_auth_token());
        assert_eq!(key, "API_AUTH_TOKEN");
    }

    #[test]
    fn test_database_url_missing() {
        let config = Config::default();
        let key = expect_missing_required(config.get_database_url());
        assert_eq!(key, "DATABASE_URL");
    }
}