use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::interval;
use tonic::Request;
use tonic::metadata::MetadataValue;
use tonic::transport::{Channel, Endpoint};
use crate::error::PolicyError;
use crate::policy::Policy;
use crate::proto::tero::policy::v1::policy_service_client::PolicyServiceClient;
use crate::proto::tero::policy::v1::{ClientMetadata, SyncRequest};
use super::sync::collect_policy_statuses;
use super::{PolicyCallback, PolicyProvider, StatsCollector};
/// Configuration for a gRPC policy provider: target URL, extra request
/// headers, polling cadence, and optional client metadata sent to the server.
#[derive(Debug, Clone)]
pub struct GrpcProviderConfig {
    /// gRPC endpoint URL; must be a valid URI accepted by `Endpoint::from_shared`.
    pub url: String,
    /// Extra headers attached to every sync request as gRPC metadata.
    /// Entries that are not valid gRPC metadata keys/values are skipped at send time.
    pub headers: HashMap<String, String>,
    /// Interval between background polls, in nanoseconds (defaults to 60 s).
    pub poll_interval_ns: u64,
    /// Optional client identification included in every sync request.
    pub client_metadata: Option<ClientMetadata>,
}
impl GrpcProviderConfig {
    /// Creates a configuration for `url` with no headers, no client
    /// metadata, and the default 60-second poll interval.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            headers: HashMap::new(),
            poll_interval_ns: Duration::from_secs(60).as_nanos() as u64,
            client_metadata: None,
        }
    }

    /// Adds a single header to send with every request (builder style).
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (k, v) = (key.into(), value.into());
        self.headers.insert(k, v);
        self
    }

    /// Merges `headers` into the existing header map; duplicate keys are
    /// overwritten by the incoming values.
    pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
        for (k, v) in headers {
            self.headers.insert(k, v);
        }
        self
    }

    /// Sets the poll interval from a `Duration`.
    pub fn poll_interval(mut self, interval: Duration) -> Self {
        self.poll_interval_ns = interval.as_nanos() as u64;
        self
    }

    /// Sets the poll interval directly in nanoseconds.
    pub fn poll_interval_ns(mut self, ns: u64) -> Self {
        self.poll_interval_ns = ns;
        self
    }

    /// Attaches client metadata to include in every sync request.
    pub fn client_metadata(mut self, metadata: ClientMetadata) -> Self {
        self.client_metadata = Some(metadata);
        self
    }
}
/// Policy provider that fetches policies from a remote gRPC service, either
/// as a one-shot load or via a background polling task.
pub struct GrpcProvider {
    config: GrpcProviderConfig,
    // Hash of the last successful sync response; echoed back to the server
    // so it can answer incrementally.
    last_hash: RwLock<Option<String>>,
    // Server timestamp (unix nanos) of the last sync response; 0 = never synced.
    last_sync_timestamp: RwLock<u64>,
    // Set true by start_polling() and false by stop().
    // NOTE(review): the spawned polling task does not observe this field, so
    // stop() does not actually terminate polling — confirm intended.
    running: AtomicBool,
    // Optional collector whose per-policy statuses are included in sync requests.
    stats_collector: RwLock<Option<StatsCollector>>,
    // Policies fetched eagerly by new_with_initial_fetch(); consumed (taken)
    // by the first subscribe() call.
    initial_policies: RwLock<Option<Vec<Policy>>>,
}
impl GrpcProvider {
pub fn new(config: GrpcProviderConfig) -> Self {
Self {
config,
last_hash: RwLock::new(None),
last_sync_timestamp: RwLock::new(0),
running: AtomicBool::new(false),
stats_collector: RwLock::new(None),
initial_policies: RwLock::new(None),
}
}
pub async fn new_with_initial_fetch(config: GrpcProviderConfig) -> Result<Self, PolicyError> {
let provider = Self::new(config);
let policies = provider.sync(true).await?;
*provider.initial_policies.write().unwrap() = Some(policies);
Ok(provider)
}
pub async fn load(&self) -> Result<Vec<Policy>, PolicyError> {
self.sync(true).await
}
fn build_sync_request(&self, full_sync: bool) -> SyncRequest {
let last_hash = self.last_hash.read().unwrap().clone().unwrap_or_default();
let last_timestamp = *self.last_sync_timestamp.read().unwrap();
let policy_statuses = collect_policy_statuses(&self.stats_collector.read().unwrap());
SyncRequest {
client_metadata: self.config.client_metadata.clone(),
full_sync,
last_sync_timestamp_unix_nano: last_timestamp,
last_successful_hash: last_hash,
policy_statuses,
}
}
async fn create_channel(&self) -> Result<Channel, PolicyError> {
let endpoint = Endpoint::from_shared(self.config.url.clone())
.map_err(|e| PolicyError::GrpcError(format!("Invalid URL: {}", e)))?;
endpoint
.connect()
.await
.map_err(|e| PolicyError::GrpcError(format!("Connection failed: {}", e)))
}
fn create_request<T>(&self, message: T) -> Request<T> {
let mut request = Request::new(message);
for (key, value) in &self.config.headers {
if let (Ok(key), Ok(value)) = (
key.parse::<tonic::metadata::MetadataKey<_>>(),
value.parse::<MetadataValue<_>>(),
) {
request.metadata_mut().insert(key, value);
}
}
request
}
async fn sync(&self, full_sync: bool) -> Result<Vec<Policy>, PolicyError> {
let channel = self.create_channel().await?;
let mut client = PolicyServiceClient::new(channel);
let sync_request = self.build_sync_request(full_sync);
let request = self.create_request(sync_request);
let response = client
.sync(request)
.await
.map_err(|e| PolicyError::GrpcError(format!("Sync RPC failed: {}", e)))?;
let sync_response = response.into_inner();
if !sync_response.error_message.is_empty() {
return Err(PolicyError::GrpcError(format!(
"Sync error: {}",
sync_response.error_message
)));
}
if !sync_response.hash.is_empty() {
*self.last_hash.write().unwrap() = Some(sync_response.hash);
}
if sync_response.sync_timestamp_unix_nano > 0 {
*self.last_sync_timestamp.write().unwrap() = sync_response.sync_timestamp_unix_nano;
}
let policies = sync_response
.policies
.into_iter()
.map(Policy::new)
.collect();
Ok(policies)
}
pub fn start_polling(
&self,
) -> mpsc::Receiver<Result<(Option<String>, Vec<Policy>), PolicyError>> {
let (tx, rx) = mpsc::channel(16);
self.running.store(true, Ordering::SeqCst);
let config = self.config.clone();
let last_hash = Arc::new(RwLock::new(None::<String>));
let last_sync_timestamp = Arc::new(RwLock::new(0u64));
let stats_collector = self.stats_collector.read().unwrap().clone();
let running = Arc::new(AtomicBool::new(true));
let running_clone = running.clone();
let last_hash_clone = last_hash.clone();
let last_sync_timestamp_clone = last_sync_timestamp.clone();
tokio::spawn(async move {
let poll_duration = Duration::from_nanos(config.poll_interval_ns);
let mut interval_timer = interval(poll_duration);
let mut first = true;
while running_clone.load(Ordering::SeqCst) {
interval_timer.tick().await;
let result = async {
let endpoint = Endpoint::from_shared(config.url.clone())
.map_err(|e| PolicyError::GrpcError(format!("Invalid URL: {}", e)))?;
let channel = endpoint
.connect()
.await
.map_err(|e| PolicyError::GrpcError(format!("Connection failed: {}", e)))?;
let mut client = PolicyServiceClient::new(channel);
let last_hash_val = last_hash_clone.read().unwrap().clone().unwrap_or_default();
let last_timestamp = *last_sync_timestamp_clone.read().unwrap();
let policy_statuses = collect_policy_statuses(&stats_collector);
let sync_request = SyncRequest {
client_metadata: config.client_metadata.clone(),
full_sync: first,
last_sync_timestamp_unix_nano: last_timestamp,
last_successful_hash: last_hash_val,
policy_statuses,
};
let mut request = Request::new(sync_request);
for (key, value) in &config.headers {
if let (Ok(key), Ok(value)) = (
key.parse::<tonic::metadata::MetadataKey<_>>(),
value.parse::<MetadataValue<_>>(),
) {
request.metadata_mut().insert(key, value);
}
}
let response = client
.sync(request)
.await
.map_err(|e| PolicyError::GrpcError(format!("Sync RPC failed: {}", e)))?;
let sync_response = response.into_inner();
if !sync_response.error_message.is_empty() {
return Err(PolicyError::GrpcError(format!(
"Sync error: {}",
sync_response.error_message
)));
}
let new_hash = if !sync_response.hash.is_empty() {
let hash = Some(sync_response.hash);
*last_hash_clone.write().unwrap() = hash.clone();
hash
} else {
None
};
if sync_response.sync_timestamp_unix_nano > 0 {
*last_sync_timestamp_clone.write().unwrap() =
sync_response.sync_timestamp_unix_nano;
}
let policies: Vec<Policy> = sync_response
.policies
.into_iter()
.map(Policy::new)
.collect();
Ok((new_hash, policies))
}
.await;
first = false;
if tx.send(result).await.is_err() {
break; }
}
});
rx
}
pub fn stop(&self) {
self.running.store(false, Ordering::SeqCst);
}
}
impl PolicyProvider for GrpcProvider {
fn set_stats_collector(&self, collector: StatsCollector) {
*self.stats_collector.write().unwrap() = Some(collector);
}
fn subscribe(&self, callback: PolicyCallback) -> Result<(), PolicyError> {
let policies = self
.initial_policies
.write()
.unwrap()
.take()
.expect("GrpcProvider::subscribe() requires new_with_initial_fetch()");
callback(policies);
let initial_hash = self.last_hash.read().unwrap().clone();
let mut rx = self.start_polling();
let callback = callback.clone();
tokio::spawn(async move {
let mut last_known_hash = initial_hash;
while let Some(result) = rx.recv().await {
match result {
Ok((new_hash, policies)) => {
if new_hash != last_known_hash {
last_known_hash = new_hash;
callback(policies);
}
}
Err(e) => {
eprintln!("gRPC provider sync error: {}", e);
}
}
}
});
Ok(())
}
}