use serde::{Deserialize, Serialize};
use crate::error::PolicyError;
use crate::provider::FileProvider;
use crate::registry::{PolicyRegistry, ProviderId};
/// One entry of a provider list, selected by its `"type"` field.
///
/// Serde uses internal tagging: the `"type"` key sits next to the variant's own
/// fields, and its value is the lowercased variant name (`"file"`, `"http"`,
/// `"grpc"`). For example, `{"type": "file", "id": "local", "path": "policies.json"}`
/// deserializes to [`ProviderConfig::File`].
///
/// The `Http` and `Grpc` variants only exist when the matching Cargo feature
/// (`http` / `grpc`) is enabled.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase", tag = "type")]
pub enum ProviderConfig {
    /// Provider backed by a filesystem path (`"type": "file"`).
    File(FileProviderConfig),
    /// Provider fetched from an HTTP endpoint (`"type": "http"`).
    #[cfg(feature = "http")]
    Http(HttpProviderConfig),
    /// Provider fetched from a gRPC endpoint (`"type": "grpc"`).
    #[cfg(feature = "grpc")]
    Grpc(GrpcProviderConfig),
}
/// Settings for a file-backed provider (`"type": "file"`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FileProviderConfig {
    /// Caller-chosen identifier for this entry, returned by [`ProviderConfig::id`].
    pub id: String,
    /// Path handed to `FileProvider::new` when the entry is registered.
    pub path: String,
}
/// Settings for an HTTP-backed provider (`"type": "http"`); requires the `http` feature.
///
/// Only `id` and `url` are mandatory; every other field may be omitted from the
/// serialized form.
#[cfg(feature = "http")]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HttpProviderConfig {
    /// Caller-chosen identifier for this entry, returned by [`ProviderConfig::id`].
    pub id: String,
    /// Endpoint URL passed to the HTTP provider builder.
    pub url: String,
    /// Extra headers applied to the builder in list order; empty when omitted.
    #[serde(default)]
    pub headers: Vec<Header>,
    /// Poll interval in whole seconds; when `None` the builder's default is kept.
    #[serde(default)]
    pub poll_interval_secs: Option<u64>,
    /// Payload format. `"json"` (any letter case) selects JSON; any other value
    /// selects Protobuf. When `None` the builder's default is kept.
    #[serde(default)]
    pub content_type: Option<String>,
}
/// Settings for a gRPC-backed provider (`"type": "grpc"`); requires the `grpc` feature.
///
/// Only `id` and `url` are mandatory; `headers` and `poll_interval_secs` may be
/// omitted from the serialized form.
#[cfg(feature = "grpc")]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GrpcProviderConfig {
    /// Caller-chosen identifier for this entry, returned by [`ProviderConfig::id`].
    pub id: String,
    /// Endpoint URL passed to the gRPC provider builder.
    pub url: String,
    /// Extra headers applied to the builder in list order; empty when omitted.
    #[serde(default)]
    pub headers: Vec<Header>,
    /// Poll interval in whole seconds; when `None` the builder's default is kept.
    #[serde(default)]
    pub poll_interval_secs: Option<u64>,
}
/// A name/value header entry used by the HTTP and gRPC provider configs.
///
/// `Debug` is implemented by hand rather than derived. Header values commonly
/// carry credentials (e.g. `Authorization: Bearer …`), and a derived `Debug` would
/// print them verbatim whenever a provider config is formatted with `{:?}`. The
/// value is redacted in `Debug` output only. `Serialize`/`Deserialize` still
/// carry it in full, so configs round-trip unchanged.
#[derive(Clone, Deserialize, Serialize)]
pub struct Header {
    /// Header name.
    pub name: String,
    /// Header value; shown as `<redacted>` in `Debug` output.
    pub value: String,
}

impl std::fmt::Debug for Header {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `format_args!` renders without surrounding quotes, making it obvious
        // that the text is a placeholder rather than the real value.
        f.debug_struct("Header")
            .field("name", &self.name)
            .field("value", &format_args!("<redacted>"))
            .finish()
    }
}
impl ProviderConfig {
    /// Returns the `id` string configured for this entry, whichever variant it is.
    pub fn id(&self) -> &str {
        let id: &String = match self {
            Self::File(file) => &file.id,
            #[cfg(feature = "http")]
            Self::Http(http) => &http.id,
            #[cfg(feature = "grpc")]
            Self::Grpc(grpc) => &grpc.id,
        };
        id.as_str()
    }

    /// Builds the provider described by this entry and subscribes it to `registry`.
    ///
    /// * `File`: constructs a [`FileProvider`] for `path` and subscribes it.
    /// * `Http`: feeds the headers, poll interval and content type into the
    ///   crate's HTTP provider builder. It awaits
    ///   `HttpProvider::new_with_initial_fetch`, then subscribes the result.
    /// * `Grpc`: same as `Http`, without a content type.
    ///
    /// The entry's `id` is not passed to the registry. The returned
    /// [`ProviderId`] is whatever `registry.subscribe` yields.
    ///
    /// # Errors
    ///
    /// Propagates any [`PolicyError`] from a remote provider's
    /// `new_with_initial_fetch` or from `registry.subscribe`.
    pub async fn register(&self, registry: &PolicyRegistry) -> Result<ProviderId, PolicyError> {
        match self {
            Self::File(file_cfg) => {
                let file_provider = FileProvider::new(&file_cfg.path);
                registry.subscribe(&file_provider)
            }
            #[cfg(feature = "http")]
            Self::Http(http_cfg) => {
                use crate::provider::{ContentType, HttpProvider, HttpProviderConfig as HttpConfig};
                use std::time::Duration;

                // Layer every configured header onto a builder for the base URL.
                let builder = http_cfg
                    .headers
                    .iter()
                    .fold(HttpConfig::new(&http_cfg.url), |acc, hdr| {
                        acc.header(&hdr.name, &hdr.value)
                    });
                let builder = match http_cfg.poll_interval_secs {
                    Some(secs) => builder.poll_interval(Duration::from_secs(secs)),
                    None => builder,
                };
                // NOTE(review): only "json" (case-insensitive) is recognised; every other
                // string, including typos, silently selects Protobuf.
                let builder = match http_cfg.content_type.as_deref() {
                    Some(raw) => {
                        let wire_format = if raw.to_lowercase() == "json" {
                            ContentType::Json
                        } else {
                            ContentType::Protobuf
                        };
                        builder.content_type(wire_format)
                    }
                    None => builder,
                };

                let http_provider = HttpProvider::new_with_initial_fetch(builder).await?;
                registry.subscribe(&http_provider)
            }
            #[cfg(feature = "grpc")]
            Self::Grpc(grpc_cfg) => {
                use crate::provider::{GrpcProvider, GrpcProviderConfig as GrpcConfig};
                use std::time::Duration;

                // Layer every configured header onto a builder for the base URL.
                let builder = grpc_cfg
                    .headers
                    .iter()
                    .fold(GrpcConfig::new(&grpc_cfg.url), |acc, hdr| {
                        acc.header(&hdr.name, &hdr.value)
                    });
                let builder = match grpc_cfg.poll_interval_secs {
                    Some(secs) => builder.poll_interval(Duration::from_secs(secs)),
                    None => builder,
                };

                let grpc_provider = GrpcProvider::new_with_initial_fetch(builder).await?;
                registry.subscribe(&grpc_provider)
            }
        }
    }
}
/// Registers every entry of `configs` with `registry`, sequentially and in order.
///
/// On success, returns one [`ProviderId`] per entry, in the same order as `configs`.
///
/// # Errors
///
/// Stops at the first entry whose [`ProviderConfig::register`] fails and returns
/// that error unchanged. Entries registered before the failure are not rolled
/// back here, and their ids are not returned. NOTE(review): they presumably stay
/// subscribed in `registry`; confirm whether callers need to clean up.
pub async fn register_providers(
    configs: &[ProviderConfig],
    registry: &PolicyRegistry,
) -> Result<Vec<ProviderId>, PolicyError> {
    let mut registered = Vec::with_capacity(configs.len());
    for entry in configs.iter() {
        registered.push(entry.register(registry).await?);
    }
    Ok(registered)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_file_provider_config() {
        let json = r#"{"id": "local", "type": "file", "path": "policies.json"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        match config {
            ProviderConfig::File(c) => {
                assert_eq!(c.id, "local");
                assert_eq!(c.path, "policies.json");
            }
            // Without the `http` and `grpc` features, `File` is the only variant and this
            // arm is unreachable. With either feature enabled, it is reachable.
            #[allow(unreachable_patterns)]
            _ => panic!("Expected File provider config"),
        }
    }

    #[test]
    fn parse_provider_config_list() {
        let json = r#"[
            {"id": "local", "type": "file", "path": "policies.json"},
            {"id": "backup", "type": "file", "path": "backup.json"}
        ]"#;
        let configs: Vec<ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(configs.len(), 2);
        assert_eq!(configs[0].id(), "local");
        assert_eq!(configs[1].id(), "backup");
    }

    #[cfg(feature = "http")]
    #[test]
    fn parse_http_provider_config() {
        let json = r#"{
            "id": "remote",
            "type": "http",
            "url": "https://api.example.com/policies",
            "headers": [
                {"name": "Authorization", "value": "Bearer token123"}
            ],
            "poll_interval_secs": 30,
            "content_type": "json"
        }"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        match config {
            ProviderConfig::Http(c) => {
                assert_eq!(c.id, "remote");
                assert_eq!(c.url, "https://api.example.com/policies");
                assert_eq!(c.headers.len(), 1);
                assert_eq!(c.headers[0].name, "Authorization");
                assert_eq!(c.headers[0].value, "Bearer token123");
                assert_eq!(c.poll_interval_secs, Some(30));
                assert_eq!(c.content_type, Some("json".to_string()));
            }
            _ => panic!("Expected Http provider config"),
        }
    }

    #[cfg(feature = "http")]
    #[test]
    fn parse_http_provider_config_minimal() {
        let json = r#"{"id": "remote", "type": "http", "url": "https://api.example.com/policies"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        match config {
            ProviderConfig::Http(c) => {
                assert_eq!(c.id, "remote");
                assert_eq!(c.url, "https://api.example.com/policies");
                assert!(c.headers.is_empty());
                assert!(c.poll_interval_secs.is_none());
                assert!(c.content_type.is_none());
            }
            _ => panic!("Expected Http provider config"),
        }
    }

    #[cfg(feature = "grpc")]
    #[test]
    fn parse_grpc_provider_config() {
        let json = r#"{
            "id": "grpc-remote",
            "type": "grpc",
            "url": "https://grpc.example.com:50051",
            "headers": [
                {"name": "authorization", "value": "Bearer token123"}
            ],
            "poll_interval_secs": 120
        }"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        match config {
            ProviderConfig::Grpc(c) => {
                assert_eq!(c.id, "grpc-remote");
                assert_eq!(c.url, "https://grpc.example.com:50051");
                assert_eq!(c.headers.len(), 1);
                assert_eq!(c.poll_interval_secs, Some(120));
            }
            _ => panic!("Expected Grpc provider config"),
        }
    }

    #[test]
    fn provider_config_id() {
        let json = r#"{"id": "test-id", "type": "file", "path": "test.json"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.id(), "test-id");
    }

    #[tokio::test]
    async fn register_file_provider() {
        let config = ProviderConfig::File(FileProviderConfig {
            id: "test".to_string(),
            path: "testdata/policies.json".to_string(),
        });
        let registry = PolicyRegistry::new();
        // The id itself is opaque; only successful registration is checked here.
        let _provider_id = config.register(&registry).await.unwrap();
        assert_eq!(registry.provider_count(), 1);
        let snapshot = registry.snapshot();
        assert!(!snapshot.is_empty());
    }

    #[tokio::test]
    async fn register_multiple_providers() {
        let configs = vec![
            ProviderConfig::File(FileProviderConfig {
                id: "provider1".to_string(),
                path: "testdata/policies.json".to_string(),
            }),
            ProviderConfig::File(FileProviderConfig {
                id: "provider2".to_string(),
                path: "testdata/policies.json".to_string(),
            }),
        ];
        let registry = PolicyRegistry::new();
        let provider_ids = register_providers(&configs, &registry).await.unwrap();
        assert_eq!(provider_ids.len(), 2);
        assert_eq!(registry.provider_count(), 2);
    }

    #[test]
    fn serialize_provider_config() {
        let config = ProviderConfig::File(FileProviderConfig {
            id: "test".to_string(),
            path: "policies.json".to_string(),
        });
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"file\""));
        assert!(json.contains("\"id\":\"test\""));
        assert!(json.contains("\"path\":\"policies.json\""));
    }
}