use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Cloud provider selected by the `provider` field of `CloudConfig`.
///
/// `rename_all = "lowercase"` makes serde read and write the variants as
/// `aws`, `gcp`, `azure` and `none`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CloudProvider {
/// Amazon Web Services (`"aws"`).
AWS,
/// Google Cloud Platform (`"gcp"`).
GCP,
/// Microsoft Azure (`"azure"`).
Azure,
/// No provider selected (`"none"`); `is_enabled` returns `false` for this variant only.
None,
}
impl CloudProvider {
pub fn as_str(&self) -> &str {
match self {
CloudProvider::AWS => "aws",
CloudProvider::GCP => "gcp",
CloudProvider::Azure => "azure",
CloudProvider::None => "none",
}
}
pub fn is_enabled(&self) -> bool {
!matches!(self, CloudProvider::None)
}
}
/// Cloud settings: the selected provider plus one settings section per provider.
///
/// Every field carries a serde default, so any of them may be left out of the
/// input document. All three provider sections are always present in the
/// struct, whichever `provider` is selected.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CloudConfig {
/// Selected provider; `default_provider()` supplies the value when the key is absent.
#[serde(default = "default_provider")]
pub provider: CloudProvider,
/// AWS section; `AwsConfig::default()` when the key is absent.
#[serde(default)]
pub aws: AwsConfig,
/// GCP section; `GcpConfig::default()` when the key is absent.
#[serde(default)]
pub gcp: GcpConfig,
/// Azure section; `AzureConfig::default()` when the key is absent.
#[serde(default)]
pub azure: AzureConfig,
}
/// Serde fallback for the `provider` field of `CloudConfig`: no provider selected.
fn default_provider() -> CloudProvider {
    CloudProvider::None
}

impl Default for CloudConfig {
    /// Builds the configuration obtained when nothing is specified: the
    /// fallback provider and the `Default` value of every provider section.
    fn default() -> Self {
        Self {
            // Same source of truth serde uses for a missing `provider` key.
            provider: default_provider(),
            aws: Default::default(),
            gcp: Default::default(),
            azure: Default::default(),
        }
    }
}
/// `aws` section of `CloudConfig`.
///
/// `Default` is implemented by hand rather than derived: a derived impl would
/// set `region` to the empty string, whereas serde fills a missing `region`
/// key with `us-east-1`. Because the enclosing `aws` field is
/// `#[serde(default)]`, an omitted `aws:` section is built from
/// `AwsConfig::default()`, so both paths must agree.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AwsConfig {
    /// Region name; `us-east-1` when omitted.
    #[serde(default = "default_aws_region")]
    pub region: String,
    /// `secrets_manager` subsection; its `Default` when omitted.
    #[serde(default)]
    pub secrets_manager: AwsSecretsManagerConfig,
    /// `s3` subsection; its `Default` when omitted.
    #[serde(default)]
    pub s3: AwsS3Config,
    /// `cloudwatch` subsection; its `Default` when omitted.
    #[serde(default)]
    pub cloudwatch: AwsCloudWatchConfig,
    /// `xray` subsection; its `Default` when omitted.
    #[serde(default)]
    pub xray: AwsXRayConfig,
}

impl Default for AwsConfig {
    /// Returns the same value deserializing an empty `aws` mapping produces.
    fn default() -> Self {
        Self {
            region: default_aws_region(),
            secrets_manager: Default::default(),
            s3: Default::default(),
            cloudwatch: Default::default(),
            xray: Default::default(),
        }
    }
}

/// Fallback region shared by serde and `AwsConfig::default()`.
fn default_aws_region() -> String {
    "us-east-1".to_string()
}
/// `secrets_manager` subsection of `AwsConfig`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AwsSecretsManagerConfig {
/// On/off switch; `false` when omitted.
#[serde(default)]
pub enabled: bool,
/// Cache time-to-live in seconds; `default_cache_ttl()` supplies the value when omitted.
/// NOTE(review): what is cached is not visible in this file — confirm against the consumer.
#[serde(default = "default_cache_ttl")]
pub cache_ttl_seconds: u64,
}
impl Default for AwsSecretsManagerConfig {
    /// Disabled, with the same cache TTL serde applies to a missing
    /// `cache_ttl_seconds` key.
    fn default() -> Self {
        Self {
            enabled: false,
            cache_ttl_seconds: default_cache_ttl(),
        }
    }
}

/// Fallback for `cache_ttl_seconds` fields: 300 seconds.
fn default_cache_ttl() -> u64 {
    // Five minutes, expressed in seconds.
    5 * 60
}
/// `s3` subsection of `AwsConfig`.
///
/// `Default` is implemented by hand rather than derived: a derived impl would
/// leave both prefixes empty, whereas serde fills missing prefix keys with
/// `models/` and `scan-results/`. Because the enclosing `s3` field is
/// `#[serde(default)]`, an omitted `s3:` section is built from
/// `AwsS3Config::default()`, so both paths must agree.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AwsS3Config {
    /// Bucket name; empty when omitted.
    #[serde(default)]
    pub bucket: String,
    /// Prefix for models; `models/` when omitted.
    #[serde(default = "default_models_prefix")]
    pub models_prefix: String,
    /// Prefix for results; `scan-results/` when omitted.
    #[serde(default = "default_results_prefix")]
    pub results_prefix: String,
}

impl Default for AwsS3Config {
    /// Returns the same value deserializing an empty `s3` mapping produces.
    fn default() -> Self {
        Self {
            bucket: String::new(),
            models_prefix: default_models_prefix(),
            results_prefix: default_results_prefix(),
        }
    }
}

/// Fallback for `models_prefix`, shared by serde and `AwsS3Config::default()`.
fn default_models_prefix() -> String {
    "models/".to_string()
}

/// Fallback for `results_prefix`, shared by serde and `AwsS3Config::default()`.
fn default_results_prefix() -> String {
    "scan-results/".to_string()
}
/// `cloudwatch` subsection of `AwsConfig`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AwsCloudWatchConfig {
/// On/off switch; `false` when omitted.
#[serde(default)]
pub enabled: bool,
/// Namespace; `default_cw_namespace()` supplies the value when omitted.
#[serde(default = "default_cw_namespace")]
pub namespace: String,
/// Log group; `default_cw_log_group()` supplies the value when omitted.
#[serde(default = "default_cw_log_group")]
pub log_group: String,
}
impl Default for AwsCloudWatchConfig {
    /// Disabled, with the same namespace and log group serde applies to
    /// missing keys.
    fn default() -> Self {
        Self {
            enabled: false,
            namespace: default_cw_namespace(),
            log_group: default_cw_log_group(),
        }
    }
}

/// Fallback for the `namespace` field.
fn default_cw_namespace() -> String {
    String::from("LLMShield")
}

/// Fallback for the `log_group` field.
fn default_cw_log_group() -> String {
    String::from("/llm-shield/api")
}
/// `xray` subsection of `AwsConfig`.
///
/// `Default` is derived: `enabled` starts as `false`, the same value
/// `#[serde(default)]` uses when the key is absent.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AwsXRayConfig {
    /// On/off switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
}
/// `gcp` section of `CloudConfig`.
///
/// Every field uses plain `#[serde(default)]`, so the derived `Default`
/// matches what deserializing an empty mapping produces.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct GcpConfig {
/// Project identifier; empty when omitted.
#[serde(default)]
pub project_id: String,
/// `secret_manager` subsection; its `Default` when omitted.
#[serde(default)]
pub secret_manager: GcpSecretManagerConfig,
/// `cloud_storage` subsection; its `Default` when omitted.
#[serde(default)]
pub cloud_storage: GcpCloudStorageConfig,
/// `cloud_logging` subsection; its `Default` when omitted.
#[serde(default)]
pub cloud_logging: GcpCloudLoggingConfig,
/// `cloud_trace` subsection; its `Default` when omitted.
#[serde(default)]
pub cloud_trace: GcpCloudTraceConfig,
}
/// `secret_manager` subsection of `GcpConfig`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GcpSecretManagerConfig {
/// On/off switch; `false` when omitted.
#[serde(default)]
pub enabled: bool,
/// Cache time-to-live in seconds; `default_cache_ttl()` supplies the value when omitted.
/// NOTE(review): what is cached is not visible in this file — confirm against the consumer.
#[serde(default = "default_cache_ttl")]
pub cache_ttl_seconds: u64,
}
impl Default for GcpSecretManagerConfig {
    /// Disabled, with a 300-second cache TTL.
    fn default() -> Self {
        let enabled = false;
        // Must stay equal to the serde fallback for `cache_ttl_seconds`.
        let cache_ttl_seconds = 300;
        Self {
            enabled,
            cache_ttl_seconds,
        }
    }
}
/// `cloud_storage` subsection of `GcpConfig`.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct GcpCloudStorageConfig {
/// Bucket name; empty when omitted.
#[serde(default)]
pub bucket: String,
}
/// `cloud_logging` subsection of `GcpConfig`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GcpCloudLoggingConfig {
/// On/off switch; `false` when omitted.
#[serde(default)]
pub enabled: bool,
/// Log name; `default_gcp_log_name()` supplies the value when omitted.
#[serde(default = "default_gcp_log_name")]
pub log_name: String,
}
impl Default for GcpCloudLoggingConfig {
    /// Disabled, with the same log name serde applies to a missing
    /// `log_name` key.
    fn default() -> Self {
        Self {
            enabled: false,
            log_name: default_gcp_log_name(),
        }
    }
}

/// Fallback for the `log_name` field.
fn default_gcp_log_name() -> String {
    String::from("llm-shield-api")
}
/// `cloud_trace` subsection of `GcpConfig`.
///
/// `Default` is derived: `enabled` starts as `false`, the same value
/// `#[serde(default)]` uses when the key is absent.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct GcpCloudTraceConfig {
    /// On/off switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
}
/// `azure` section of `CloudConfig`.
///
/// Every field uses plain `#[serde(default)]`, so the derived `Default`
/// matches what deserializing an empty mapping produces.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AzureConfig {
/// Subscription identifier; empty when omitted.
#[serde(default)]
pub subscription_id: String,
/// Resource group name; empty when omitted.
#[serde(default)]
pub resource_group: String,
/// `key_vault` subsection; its `Default` when omitted.
#[serde(default)]
pub key_vault: AzureKeyVaultConfig,
/// `blob_storage` subsection; its `Default` when omitted.
#[serde(default)]
pub blob_storage: AzureBlobStorageConfig,
/// `monitor` subsection; its `Default` when omitted.
#[serde(default)]
pub monitor: AzureMonitorConfig,
/// `application_insights` subsection; its `Default` when omitted.
#[serde(default)]
pub application_insights: AzureApplicationInsightsConfig,
}
/// `key_vault` subsection of `AzureConfig`.
///
/// Unlike the sibling secrets sections, this struct has no `enabled` flag.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AzureKeyVaultConfig {
/// Vault URL; empty when omitted.
#[serde(default)]
pub vault_url: String,
/// Cache time-to-live in seconds; `default_cache_ttl()` supplies the value when omitted.
/// NOTE(review): what is cached is not visible in this file — confirm against the consumer.
#[serde(default = "default_cache_ttl")]
pub cache_ttl_seconds: u64,
}
impl Default for AzureKeyVaultConfig {
    /// Empty vault URL and a 300-second cache TTL.
    fn default() -> Self {
        let vault_url = String::new();
        // Must stay equal to the serde fallback for `cache_ttl_seconds`.
        let cache_ttl_seconds = 300;
        Self {
            vault_url,
            cache_ttl_seconds,
        }
    }
}
/// `blob_storage` subsection of `AzureConfig`.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AzureBlobStorageConfig {
/// Storage account name; empty when omitted.
#[serde(default)]
pub account: String,
/// Container name; empty when omitted.
#[serde(default)]
pub container: String,
}
/// `monitor` subsection of `AzureConfig`.
///
/// `Default` is derived: `enabled` is `false` and `workspace_id` is empty,
/// the same values `#[serde(default)]` uses when the keys are absent.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AzureMonitorConfig {
    /// On/off switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Workspace identifier; empty when omitted.
    #[serde(default)]
    pub workspace_id: String,
}
/// `application_insights` subsection of `AzureConfig`.
///
/// `Default` is derived: `instrumentation_key` is empty and `enabled` is
/// `false`, the same values `#[serde(default)]` uses when the keys are absent.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AzureApplicationInsightsConfig {
    /// Instrumentation key; empty when omitted.
    #[serde(default)]
    pub instrumentation_key: String,
    /// On/off switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
}
// Unit tests for the configuration types defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Each variant maps to its lowercase identifier.
#[test]
fn test_cloud_provider_as_str() {
assert_eq!(CloudProvider::AWS.as_str(), "aws");
assert_eq!(CloudProvider::GCP.as_str(), "gcp");
assert_eq!(CloudProvider::Azure.as_str(), "azure");
assert_eq!(CloudProvider::None.as_str(), "none");
}
// Only `CloudProvider::None` counts as disabled.
#[test]
fn test_cloud_provider_is_enabled() {
assert!(CloudProvider::AWS.is_enabled());
assert!(CloudProvider::GCP.is_enabled());
assert!(CloudProvider::Azure.is_enabled());
assert!(!CloudProvider::None.is_enabled());
}
// The default configuration selects no provider.
#[test]
fn test_cloud_config_default() {
let config = CloudConfig::default();
assert_eq!(config.provider, CloudProvider::None);
assert!(!config.provider.is_enabled());
}
// Pins the region and `secrets_manager` values produced by `AwsConfig::default()`.
#[test]
fn test_aws_config_defaults() {
let config = AwsConfig::default();
assert_eq!(config.region, "us-east-1");
assert!(!config.secrets_manager.enabled);
assert_eq!(config.secrets_manager.cache_ttl_seconds, 300);
}
// Parses a YAML document and checks that explicit values override the defaults.
// NOTE(review): the nesting of this YAML depends on the indentation inside the
// raw string below — confirm it is intact in the checked-in file.
#[test]
fn test_config_deserialization() {
let yaml = r#"
provider: aws
aws:
region: us-west-2
secrets_manager:
enabled: true
cache_ttl_seconds: 600
s3:
bucket: my-bucket
"#;
let config: CloudConfig = serde_yaml::from_str(yaml).unwrap();
assert_eq!(config.provider, CloudProvider::AWS);
assert_eq!(config.aws.region, "us-west-2");
assert!(config.aws.secrets_manager.enabled);
assert_eq!(config.aws.secrets_manager.cache_ttl_seconds, 600);
assert_eq!(config.aws.s3.bucket, "my-bucket");
}
// Serializes a GCP configuration and checks that the provider name and
// project id appear in the YAML output.
#[test]
fn test_config_serialization() {
let config = CloudConfig {
provider: CloudProvider::GCP,
gcp: GcpConfig {
project_id: "my-project".to_string(),
secret_manager: GcpSecretManagerConfig {
enabled: true,
cache_ttl_seconds: 300,
},
// Remaining GCP subsections take their defaults.
..Default::default()
},
// The AWS and Azure sections take their defaults.
..Default::default()
};
let yaml = serde_yaml::to_string(&config).unwrap();
assert!(yaml.contains("provider: gcp"));
assert!(yaml.contains("project_id: my-project"));
}
}