use super::{ConfigError, Result};
use serde::{Deserialize, Serialize};
/// Top-level cloud integration settings.
///
/// Every field is marked `#[serde(default)]`, so each one may be omitted when
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CloudConfig {
/// Master switch; when `false`, `validate` returns `Ok(())` without
/// inspecting any provider section.
#[serde(default)]
pub enabled: bool,
/// Provider whose section `validate` checks when `enabled` is `true`.
#[serde(default)]
pub provider: CloudProvider,
/// AWS settings; validated only when `provider` is `CloudProvider::Aws`.
#[serde(default)]
pub aws: AwsConfig,
/// GCP settings; validated only when `provider` is `CloudProvider::Gcp`.
#[serde(default)]
pub gcp: GcpConfig,
/// Azure settings; validated only when `provider` is `CloudProvider::Azure`.
#[serde(default)]
pub azure: AzureConfig,
}
impl CloudConfig {
    /// Validates the section belonging to the selected provider.
    ///
    /// A disabled configuration is always valid; the provider sections are
    /// not looked at in that case.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] when cloud support is enabled
    /// but `provider` is [`CloudProvider::None`], or forwards the error
    /// produced by the selected provider's own validation.
    pub fn validate(&self) -> Result<()> {
        if self.enabled {
            match self.provider {
                CloudProvider::None => Err(ConfigError::ValidationError(
                    "Cloud enabled but no provider specified".to_string(),
                )),
                CloudProvider::Aws => self.aws.validate(),
                CloudProvider::Gcp => self.gcp.validate(),
                CloudProvider::Azure => self.azure.validate(),
            }
        } else {
            Ok(())
        }
    }
}
impl Default for CloudConfig {
    /// Builds a disabled configuration with no provider selected and every
    /// provider section at its own default.
    fn default() -> Self {
        let aws = AwsConfig::default();
        let gcp = GcpConfig::default();
        let azure = AzureConfig::default();
        Self {
            enabled: false,
            provider: CloudProvider::None,
            aws,
            gcp,
            azure,
        }
    }
}
/// Cloud provider selector.
///
/// `rename_all = "lowercase"` makes the serialized names `none`, `aws`,
/// `gcp`, and `azure`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum CloudProvider {
/// No provider chosen; this is the `Default` variant.
#[default]
None,
/// Amazon Web Services.
Aws,
/// Google Cloud Platform.
Gcp,
/// Microsoft Azure.
Azure,
}
/// AWS-specific settings, grouped by feature area.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AwsConfig {
/// AWS region name. Optional; `AwsConfig::validate` does not check it.
pub region: Option<String>,
/// Secrets settings; falls back to `AwsSecretsConfig::default()` when omitted.
#[serde(default)]
pub secrets: AwsSecretsConfig,
/// Object storage settings; falls back to `AwsStorageConfig::default()` when omitted.
#[serde(default)]
pub storage: AwsStorageConfig,
/// Metrics and logging settings; falls back to
/// `AwsObservabilityConfig::default()` when omitted.
#[serde(default)]
pub observability: AwsObservabilityConfig,
}
impl AwsConfig {
    /// Validates the secrets, storage and observability sections in that
    /// order, stopping at the first failure.
    ///
    /// # Errors
    ///
    /// Forwards the first error reported by a sub-section.
    fn validate(&self) -> Result<()> {
        self.secrets
            .validate()
            .and_then(|()| self.storage.validate())
            .and_then(|()| self.observability.validate())
    }
}
impl Default for AwsConfig {
fn default() -> Self {
Self {
region: None,
secrets: AwsSecretsConfig::default(),
storage: AwsStorageConfig::default(),
observability: AwsObservabilityConfig::default(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AwsSecretsConfig {
#[serde(default)]
pub enabled: bool,
#[serde(default = "default_cache_ttl")]
pub cache_ttl_seconds: u64,
}
impl AwsSecretsConfig {
    /// Checks that an enabled secrets section has a non-zero cache TTL.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] when `enabled` is `true` and
    /// `cache_ttl_seconds` is `0`.
    fn validate(&self) -> Result<()> {
        let ttl_usable = !self.enabled || self.cache_ttl_seconds > 0;
        if ttl_usable {
            Ok(())
        } else {
            Err(ConfigError::ValidationError(
                "AWS Secrets cache TTL must be greater than 0".to_string(),
            ))
        }
    }
}
/// Settings for AWS S3 object storage.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AwsStorageConfig {
/// Turns S3 storage on; `validate` requires `bucket` when this is `true`.
#[serde(default)]
pub enabled: bool,
/// S3 bucket name; required by `validate` when `enabled` is `true`.
pub bucket: Option<String>,
/// Optional prefix; not checked by `validate`.
/// NOTE(review): presumably prepended to object keys — confirm against the consumer.
pub prefix: Option<String>,
}
impl AwsStorageConfig {
    /// Checks that a bucket is configured whenever S3 storage is enabled.
    ///
    /// A bucket that is absent, empty, or whitespace-only counts as
    /// unspecified.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] when `enabled` is `true` and
    /// no usable `bucket` is set.
    fn validate(&self) -> Result<()> {
        // `Some("")` would pass a plain `is_none()` check yet can never name
        // a bucket, so blank values are treated like a missing one.
        let bucket_missing = self
            .bucket
            .as_deref()
            .map_or(true, |bucket| bucket.trim().is_empty());
        if self.enabled && bucket_missing {
            return Err(ConfigError::ValidationError(
                "AWS S3 bucket must be specified when enabled".to_string(),
            ));
        }
        Ok(())
    }
}
/// Settings for AWS CloudWatch metrics and logs.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AwsObservabilityConfig {
/// Turns metrics on; `validate` requires `namespace` when this is `true`.
#[serde(default)]
pub metrics_enabled: bool,
/// Turns logs on; `validate` requires `log_group` and `log_stream` when
/// this is `true`.
#[serde(default)]
pub logs_enabled: bool,
/// CloudWatch namespace; required when `metrics_enabled` is `true`.
pub namespace: Option<String>,
/// CloudWatch log group; required when `logs_enabled` is `true`.
pub log_group: Option<String>,
/// CloudWatch log stream; required when `logs_enabled` is `true`.
pub log_stream: Option<String>,
}
impl AwsObservabilityConfig {
    /// Checks that every value needed by the enabled CloudWatch features is
    /// present.
    ///
    /// Metrics need `namespace`; logs need `log_group` and then `log_stream`.
    /// A value that is absent, empty, or whitespace-only counts as
    /// unspecified.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] naming the first missing
    /// value.
    fn validate(&self) -> Result<()> {
        if self.metrics_enabled {
            Self::require(
                self.namespace.as_deref(),
                "AWS CloudWatch namespace must be specified when metrics enabled",
            )?;
        }
        if self.logs_enabled {
            Self::require(
                self.log_group.as_deref(),
                "AWS CloudWatch log group must be specified when logs enabled",
            )?;
            Self::require(
                self.log_stream.as_deref(),
                "AWS CloudWatch log stream must be specified when logs enabled",
            )?;
        }
        Ok(())
    }

    /// Fails with `message` unless `value` holds non-blank text.
    fn require(value: Option<&str>, message: &str) -> Result<()> {
        match value {
            Some(text) if !text.trim().is_empty() => Ok(()),
            _ => Err(ConfigError::ValidationError(message.to_string())),
        }
    }
}
/// GCP-specific settings, grouped by feature area.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GcpConfig {
/// GCP project identifier. Optional; `GcpConfig::validate` does not check it.
pub project_id: Option<String>,
/// Secrets settings; falls back to `GcpSecretsConfig::default()` when omitted.
#[serde(default)]
pub secrets: GcpSecretsConfig,
/// Object storage settings; falls back to `GcpStorageConfig::default()` when omitted.
#[serde(default)]
pub storage: GcpStorageConfig,
/// Metrics and logging settings; falls back to
/// `GcpObservabilityConfig::default()` when omitted.
#[serde(default)]
pub observability: GcpObservabilityConfig,
}
impl GcpConfig {
    /// Runs the secrets, storage and observability checks in that order and
    /// reports the first failure.
    ///
    /// # Errors
    ///
    /// Forwards the first error reported by a sub-section.
    fn validate(&self) -> Result<()> {
        self.secrets.validate()?;
        self.storage.validate()?;
        // The last check's outcome is the outcome of the whole section.
        self.observability.validate()
    }
}
impl Default for GcpConfig {
    /// Builds a GCP section with no project id and every sub-section at its
    /// own default.
    fn default() -> Self {
        let secrets = GcpSecretsConfig::default();
        let storage = GcpStorageConfig::default();
        let observability = GcpObservabilityConfig::default();
        Self {
            project_id: None,
            secrets,
            storage,
            observability,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GcpSecretsConfig {
#[serde(default)]
pub enabled: bool,
#[serde(default = "default_cache_ttl")]
pub cache_ttl_seconds: u64,
}
impl GcpSecretsConfig {
    /// Checks that an enabled secrets section has a non-zero cache TTL.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] when `enabled` is `true` and
    /// `cache_ttl_seconds` is `0`.
    fn validate(&self) -> Result<()> {
        match (self.enabled, self.cache_ttl_seconds) {
            (true, 0) => Err(ConfigError::ValidationError(
                "GCP Secrets cache TTL must be greater than 0".to_string(),
            )),
            _ => Ok(()),
        }
    }
}
/// Settings for GCP Cloud Storage.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GcpStorageConfig {
/// Turns Cloud Storage on; `validate` requires `bucket` when this is `true`.
#[serde(default)]
pub enabled: bool,
/// Cloud Storage bucket name; required by `validate` when `enabled` is `true`.
pub bucket: Option<String>,
/// Optional prefix; not checked by `validate`.
/// NOTE(review): presumably prepended to object names — confirm against the consumer.
pub prefix: Option<String>,
}
impl GcpStorageConfig {
    /// Checks that a bucket is configured whenever Cloud Storage is enabled.
    ///
    /// A bucket that is absent, empty, or whitespace-only counts as
    /// unspecified.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] when `enabled` is `true` and
    /// no usable `bucket` is set.
    fn validate(&self) -> Result<()> {
        if !self.enabled {
            return Ok(());
        }
        // A blank string is no more usable as a bucket name than `None`.
        match self.bucket.as_deref().map(str::trim) {
            Some(bucket) if !bucket.is_empty() => Ok(()),
            _ => Err(ConfigError::ValidationError(
                "GCP Cloud Storage bucket must be specified when enabled".to_string(),
            )),
        }
    }
}
/// Settings for GCP metrics and Cloud Logging.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GcpObservabilityConfig {
/// Turns metrics on; `validate` attaches no requirement to this flag.
#[serde(default)]
pub metrics_enabled: bool,
/// Turns logs on; `validate` requires `log_name` when this is `true`.
#[serde(default)]
pub logs_enabled: bool,
/// Cloud Logging log name; required when `logs_enabled` is `true`.
pub log_name: Option<String>,
}
impl GcpObservabilityConfig {
    /// Checks that a log name is configured whenever logs are enabled.
    ///
    /// A log name that is absent, empty, or whitespace-only counts as
    /// unspecified. `metrics_enabled` carries no requirement.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] when `logs_enabled` is `true`
    /// and no usable `log_name` is set.
    fn validate(&self) -> Result<()> {
        // Blank text is treated like `None`: it cannot identify a log.
        let log_name_missing = self
            .log_name
            .as_deref()
            .map_or(true, |name| name.trim().is_empty());
        if self.logs_enabled && log_name_missing {
            return Err(ConfigError::ValidationError(
                "GCP Cloud Logging log name must be specified when logs enabled".to_string(),
            ));
        }
        Ok(())
    }
}
/// Azure-specific settings, grouped by feature area.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AzureConfig {
/// Key Vault settings; falls back to `AzureSecretsConfig::default()` when omitted.
#[serde(default)]
pub secrets: AzureSecretsConfig,
/// Storage settings; falls back to `AzureStorageConfig::default()` when omitted.
#[serde(default)]
pub storage: AzureStorageConfig,
/// Metrics and logging settings; falls back to
/// `AzureObservabilityConfig::default()` when omitted.
#[serde(default)]
pub observability: AzureObservabilityConfig,
}
impl AzureConfig {
    /// Validates the secrets, storage and observability sections in that
    /// order, returning the first error encountered.
    ///
    /// # Errors
    ///
    /// Forwards the first error reported by a sub-section.
    fn validate(&self) -> Result<()> {
        // Exhaustive destructuring: a newly added section cannot be
        // forgotten here without a compile error.
        let Self {
            secrets,
            storage,
            observability,
        } = self;
        secrets.validate()?;
        storage.validate()?;
        observability.validate()
    }
}
impl Default for AzureConfig {
fn default() -> Self {
Self {
secrets: AzureSecretsConfig::default(),
storage: AzureStorageConfig::default(),
observability: AzureObservabilityConfig::default(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AzureSecretsConfig {
#[serde(default)]
pub enabled: bool,
pub vault_url: Option<String>,
#[serde(default = "default_cache_ttl")]
pub cache_ttl_seconds: u64,
}
impl AzureSecretsConfig {
    /// Checks that an enabled Key Vault section has a vault URL and a
    /// non-zero cache TTL, in that order.
    ///
    /// A vault URL that is absent, empty, or whitespace-only counts as
    /// unspecified.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] for the first rule that is
    /// violated while `enabled` is `true`.
    fn validate(&self) -> Result<()> {
        if !self.enabled {
            return Ok(());
        }
        // `Some("")` would pass a plain `is_none()` check yet cannot address
        // a vault, so blank values are treated like a missing one.
        let vault_url_missing = self
            .vault_url
            .as_deref()
            .map_or(true, |url| url.trim().is_empty());
        if vault_url_missing {
            return Err(ConfigError::ValidationError(
                "Azure Key Vault URL must be specified when enabled".to_string(),
            ));
        }
        if self.cache_ttl_seconds == 0 {
            return Err(ConfigError::ValidationError(
                "Azure Key Vault cache TTL must be greater than 0".to_string(),
            ));
        }
        Ok(())
    }
}
/// Settings for Azure storage.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AzureStorageConfig {
/// Turns Azure storage on; `validate` requires `account_name` and
/// `container_name` when this is `true`.
#[serde(default)]
pub enabled: bool,
/// Storage account name; required by `validate` when `enabled` is `true`.
pub account_name: Option<String>,
/// Container name; required by `validate` when `enabled` is `true`.
pub container_name: Option<String>,
}
impl AzureStorageConfig {
    /// Checks that an enabled storage section names both an account and a
    /// container, in that order.
    ///
    /// A value that is absent, empty, or whitespace-only counts as
    /// unspecified.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] naming the first missing
    /// value while `enabled` is `true`.
    fn validate(&self) -> Result<()> {
        if self.enabled {
            Self::require(
                self.account_name.as_deref(),
                "Azure storage account name must be specified when enabled",
            )?;
            Self::require(
                self.container_name.as_deref(),
                "Azure container name must be specified when enabled",
            )?;
        }
        Ok(())
    }

    /// Fails with `message` unless `value` holds non-blank text.
    fn require(value: Option<&str>, message: &str) -> Result<()> {
        match value {
            Some(text) if !text.trim().is_empty() => Ok(()),
            _ => Err(ConfigError::ValidationError(message.to_string())),
        }
    }
}
/// Settings for Azure Monitor metrics and Log Analytics logs.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AzureObservabilityConfig {
/// Turns metrics on; `validate` requires `resource_id` and `region` when
/// this is `true`.
#[serde(default)]
pub metrics_enabled: bool,
/// Turns logs on; `validate` requires `workspace_id`, `shared_key` and
/// `log_type` when this is `true`.
#[serde(default)]
pub logs_enabled: bool,
/// Azure Monitor resource ID; required when `metrics_enabled` is `true`.
pub resource_id: Option<String>,
/// Azure region; required when `metrics_enabled` is `true`.
pub region: Option<String>,
/// Log Analytics workspace ID; required when `logs_enabled` is `true`.
pub workspace_id: Option<String>,
/// Log Analytics shared key; required when `logs_enabled` is `true`.
/// NOTE(review): this struct derives `Debug` and `Serialize`, so the key
/// appears in debug output and serialized config — confirm that is acceptable.
pub shared_key: Option<String>,
/// Log type; required when `logs_enabled` is `true`.
pub log_type: Option<String>,
}
impl AzureObservabilityConfig {
    /// Checks that every value needed by the enabled Azure observability
    /// features is present.
    ///
    /// Metrics need `resource_id` then `region`; logs need `workspace_id`,
    /// `shared_key`, then `log_type`. A value that is absent, empty, or
    /// whitespace-only counts as unspecified.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ValidationError`] naming the first missing
    /// value.
    fn validate(&self) -> Result<()> {
        if self.metrics_enabled {
            Self::require(
                self.resource_id.as_deref(),
                "Azure Monitor resource ID must be specified when metrics enabled",
            )?;
            Self::require(
                self.region.as_deref(),
                "Azure region must be specified when metrics enabled",
            )?;
        }
        if self.logs_enabled {
            Self::require(
                self.workspace_id.as_deref(),
                "Azure Log Analytics workspace ID must be specified when logs enabled",
            )?;
            Self::require(
                self.shared_key.as_deref(),
                "Azure Log Analytics shared key must be specified when logs enabled",
            )?;
            Self::require(
                self.log_type.as_deref(),
                "Azure log type must be specified when logs enabled",
            )?;
        }
        Ok(())
    }

    /// Fails with `message` unless `value` holds non-blank text.
    fn require(value: Option<&str>, message: &str) -> Result<()> {
        match value {
            Some(text) if !text.trim().is_empty() => Ok(()),
            _ => Err(ConfigError::ValidationError(message.to_string())),
        }
    }
}
/// Returns the cache TTL, in seconds, used when a secrets section omits
/// `cache_ttl_seconds`.
fn default_cache_ttl() -> u64 {
    const FIVE_MINUTES_IN_SECONDS: u64 = 5 * 60;
    FIVE_MINUTES_IN_SECONDS
}
/// Unit tests for the cloud configuration validation rules.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cloud_config_default() {
        let config = CloudConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.provider, CloudProvider::None);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cloud_config_validation_no_provider() {
        let mut config = CloudConfig::default();
        config.enabled = true;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_cloud_config_validation_delegates_to_provider() {
        let mut config = CloudConfig::default();
        config.enabled = true;
        config.provider = CloudProvider::Aws;
        // Every AWS feature is off by default, so the section is valid.
        assert!(config.validate().is_ok());

        config.aws.storage.enabled = true;
        assert!(config.validate().is_err());

        // Sections of providers that are not selected are ignored.
        config.provider = CloudProvider::Gcp;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_aws_config_validation() {
        let mut config = AwsConfig::default();

        config.storage.enabled = true;
        assert!(config.validate().is_err());
        config.storage.bucket = Some("my-bucket".to_string());
        assert!(config.validate().is_ok());

        config.observability.metrics_enabled = true;
        assert!(config.validate().is_err());
        config.observability.namespace = Some("LLMShield".to_string());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_gcp_config_validation() {
        let mut config = GcpConfig::default();

        config.storage.enabled = true;
        assert!(config.validate().is_err());
        config.storage.bucket = Some("my-bucket".to_string());
        assert!(config.validate().is_ok());

        config.observability.logs_enabled = true;
        assert!(config.validate().is_err());
        config.observability.log_name = Some("llm-shield-logs".to_string());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_azure_config_validation() {
        let mut config = AzureConfig::default();

        // Pin the TTL explicitly so the vault-URL assertions below do not
        // depend on whatever TTL `Default` happens to produce.
        config.secrets.cache_ttl_seconds = 300;
        config.secrets.enabled = true;
        assert!(config.validate().is_err());
        config.secrets.vault_url = Some("https://my-vault.vault.azure.net".to_string());
        assert!(config.validate().is_ok());

        config.storage.enabled = true;
        assert!(config.validate().is_err());
        config.storage.account_name = Some("myaccount".to_string());
        assert!(config.validate().is_err());
        config.storage.container_name = Some("models".to_string());
        assert!(config.validate().is_ok());

        // A zero TTL is rejected once Key Vault is enabled.
        config.secrets.cache_ttl_seconds = 0;
        assert!(config.validate().is_err());
    }
}