use serde::{Deserialize, Serialize};
use std::sync::Arc;
use super::blocklist::BlocklistCache;
use super::client::{HorizonClient, MetricsProvider};
use super::config::HorizonConfig;
use super::error::HorizonError;
use super::types::{ConnectionState, ThreatSignal};
use crate::config_manager::ConfigManager;
use crate::utils::circuit_breaker::CircuitBreaker;
/// Full statistics view of a Horizon connection, as assembled by `HorizonManager::stats`.
///
/// Counters are copied from the client's own stats; blocklist figures come from the
/// client's blocklist cache at the moment of collection.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HorizonStats {
    /// String form of the client's `ConnectionState` (via `ConnectionState::as_str`).
    pub connection_state: String,
    /// Number of signals the client reports as sent.
    pub signals_sent: u64,
    /// Number of signals the client reports as acknowledged.
    pub signals_acked: u64,
    /// Number of signal batches the client reports as sent.
    pub batches_sent: u64,
    /// Total number of entries in the blocklist cache (`BlocklistCache::size`).
    pub blocklist_size: usize,
    /// Number of IP entries in the blocklist cache (`BlocklistCache::ip_count`).
    pub blocked_ips: usize,
    /// Number of fingerprint entries in the blocklist cache (`BlocklistCache::fingerprint_count`).
    pub blocked_fingerprints: usize,
    /// Milliseconds since the Unix epoch.
    ///
    /// NOTE(review): `HorizonManager::stats` fills this with the wall-clock time at which
    /// the stats were collected, not the time of the last heartbeat actually sent —
    /// confirm whether the client tracks a real heartbeat timestamp that should be used.
    pub last_heartbeat: i64,
    /// Number of heartbeats the client reports as sent.
    pub heartbeats_sent: u64,
    /// Number of heartbeats the client reports as failed.
    pub heartbeat_failures: u64,
    /// Number of reconnect attempts the client reports.
    pub reconnect_attempts: u32,
    /// Tenant identifier reported by the client, if any (presumably `None` until the hub
    /// assigns one — TODO confirm).
    pub tenant_id: Option<String>,
    /// Capability names reported by the client.
    pub capabilities: Vec<String>,
}
/// Compact statistics view produced by `HorizonManager::stats_snapshot`.
///
/// Mirrors most of `HorizonStats` but drops `last_heartbeat`, `tenant_id` and
/// `capabilities`, and adds the `enabled` / `connected` flags.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HorizonStatsSnapshot {
    /// Copy of `HorizonConfig::enabled` for the manager's configuration.
    pub enabled: bool,
    /// `true` exactly when `connection_state` equals the string `"connected"`.
    pub connected: bool,
    /// String form of the client's `ConnectionState`.
    pub connection_state: String,
    /// Number of signals the client reports as sent.
    pub signals_sent: u64,
    /// Number of signals the client reports as acknowledged.
    pub signals_acked: u64,
    /// Number of signal batches the client reports as sent.
    pub batches_sent: u64,
    /// Total number of entries in the blocklist cache.
    pub blocklist_size: usize,
    /// Number of IP entries in the blocklist cache.
    pub blocked_ips: usize,
    /// Number of fingerprint entries in the blocklist cache.
    pub blocked_fingerprints: usize,
    /// Number of heartbeats the client reports as sent.
    pub heartbeats_sent: u64,
    /// Number of heartbeats the client reports as failed.
    pub heartbeat_failures: u64,
    /// Number of reconnect attempts the client reports.
    pub reconnect_attempts: u32,
}
/// Facade over a shared [`HorizonClient`]; nearly every method delegates to the client.
pub struct HorizonManager {
    /// Client that all signal, blocklist and connection operations are forwarded to.
    client: Arc<HorizonClient>,
    /// Configuration the manager was built with (the client received a clone of it).
    config: HorizonConfig,
}
impl HorizonManager {
pub async fn new(config: HorizonConfig) -> Result<Self, HorizonError> {
config.validate()?;
let client = HorizonClient::new(config.clone());
Ok(Self {
client: Arc::new(client),
config,
})
}
pub async fn with_metrics_provider(
config: HorizonConfig,
metrics_provider: Arc<dyn MetricsProvider>,
) -> Result<Self, HorizonError> {
config.validate()?;
let client = HorizonClient::new(config.clone()).with_metrics_provider(metrics_provider);
Ok(Self {
client: Arc::new(client),
config,
})
}
pub fn set_config_manager(&self, config_manager: Arc<ConfigManager>) {
self.client.set_config_manager(config_manager);
}
pub async fn start(&self) -> Result<(), HorizonError> {
self.client.start().await
}
pub async fn stop(&self) {
self.client.stop().await;
}
pub fn report_signal(&self, signal: ThreatSignal) {
self.client.report_signal(signal);
}
pub async fn flush_signals(&self) {
self.client.flush_signals().await;
}
#[inline]
pub fn is_ip_blocked(&self, ip: &str) -> bool {
self.client.is_ip_blocked(ip)
}
#[inline]
pub fn is_fingerprint_blocked(&self, fingerprint: &str) -> bool {
self.client.is_fingerprint_blocked(fingerprint)
}
pub fn is_blocked(&self, ip: Option<&str>, fingerprint: Option<&str>) -> bool {
self.client.is_blocked(ip, fingerprint)
}
pub async fn connection_state(&self) -> ConnectionState {
self.client.connection_state().await
}
pub async fn is_connected(&self) -> bool {
self.client.is_connected().await
}
pub fn blocklist_size(&self) -> usize {
self.client.blocklist_size()
}
pub fn blocklist(&self) -> Arc<BlocklistCache> {
Arc::clone(self.client.blocklist())
}
pub fn circuit_breaker(&self) -> Arc<CircuitBreaker> {
self.client.circuit_breaker()
}
pub async fn stats(&self) -> HorizonStats {
let client_stats = self.client.stats();
let state = self.client.connection_state().await;
let blocklist = self.client.blocklist();
let tenant_id = self.client.tenant_id().await;
let capabilities = self.client.capabilities().await;
HorizonStats {
connection_state: state.as_str().to_string(),
signals_sent: client_stats.signals_sent,
signals_acked: client_stats.signals_acked,
batches_sent: client_stats.batches_sent,
blocklist_size: blocklist.size(),
blocked_ips: blocklist.ip_count(),
blocked_fingerprints: blocklist.fingerprint_count(),
last_heartbeat: chrono::Utc::now().timestamp_millis(),
heartbeats_sent: client_stats.heartbeats_sent,
heartbeat_failures: client_stats.heartbeat_failures,
reconnect_attempts: client_stats.reconnect_attempts,
tenant_id,
capabilities,
}
}
pub async fn stats_snapshot(&self) -> HorizonStatsSnapshot {
let stats = self.stats().await;
HorizonStatsSnapshot {
enabled: self.config.enabled,
connected: stats.connection_state == "connected",
connection_state: stats.connection_state,
signals_sent: stats.signals_sent,
signals_acked: stats.signals_acked,
batches_sent: stats.batches_sent,
blocklist_size: stats.blocklist_size,
blocked_ips: stats.blocked_ips,
blocked_fingerprints: stats.blocked_fingerprints,
heartbeats_sent: stats.heartbeats_sent,
heartbeat_failures: stats.heartbeat_failures,
reconnect_attempts: stats.reconnect_attempts,
}
}
pub fn config(&self) -> &HorizonConfig {
&self.config
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
}
// NOTE(review): dead-code warnings are silenced presumably because the builder is only
// exercised by tests so far — confirm and drop the allow once production code uses it.
#[allow(dead_code)]
/// Step-by-step constructor for [`HorizonManager`] with optional client attachments.
pub struct HorizonManagerBuilder {
    /// Configuration handed to the client and stored on the built manager.
    config: HorizonConfig,
    /// Optional metrics provider attached to the client in `build`.
    metrics_provider: Option<Arc<dyn MetricsProvider>>,
    /// Optional config manager attached to the client in `build`.
    config_manager: Option<Arc<ConfigManager>>,
}
// NOTE(review): dead-code warnings are silenced presumably because the builder is only
// exercised by tests so far — confirm and drop the allow once production code uses it.
#[allow(dead_code)]
impl HorizonManagerBuilder {
    /// Starts a builder for `config` with no metrics provider and no config manager.
    pub fn new(config: HorizonConfig) -> Self {
        Self {
            config,
            metrics_provider: None,
            config_manager: None,
        }
    }

    /// Sets the metrics provider that `build` attaches to the client.
    #[must_use]
    pub fn with_metrics_provider(mut self, provider: Arc<dyn MetricsProvider>) -> Self {
        self.metrics_provider = Some(provider);
        self
    }

    /// Sets the config manager that `build` attaches to the client.
    #[must_use]
    pub fn with_config_manager(mut self, manager: Arc<ConfigManager>) -> Self {
        self.config_manager = Some(manager);
        self
    }

    /// Validates the configuration and assembles the [`HorizonManager`].
    ///
    /// # Errors
    /// Returns the error from `HorizonConfig::validate` when the configuration is
    /// rejected — the same check `HorizonManager::new` and
    /// `HorizonManager::with_metrics_provider` perform.
    pub async fn build(self) -> Result<HorizonManager, HorizonError> {
        // Validate first so the builder is not a way around the config checks that the
        // direct constructors enforce.
        self.config.validate()?;

        let mut client = HorizonClient::new(self.config.clone());
        if let Some(provider) = self.metrics_provider {
            client = client.with_metrics_provider(provider);
        }
        if let Some(config_manager) = self.config_manager {
            client = client.with_config_manager(config_manager);
        }
        Ok(HorizonManager {
            client: Arc::new(client),
            config: self.config,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::horizon::{BlockType, BlocklistEntry, Severity, SignalType};

    /// Builds a manager from `HorizonConfig::default()`, panicking if construction fails.
    async fn default_manager() -> HorizonManager {
        HorizonManager::new(HorizonConfig::default()).await.unwrap()
    }

    /// Blocklist entry attributed to source `"test"`, with `expires_at`, `reason` and
    /// `created_at` left unset.
    fn entry(block_type: BlockType, indicator: &str) -> BlocklistEntry {
        BlocklistEntry {
            block_type,
            indicator: indicator.to_string(),
            expires_at: None,
            source: "test".to_string(),
            reason: None,
            created_at: None,
        }
    }

    #[tokio::test]
    async fn test_manager_disabled() {
        let manager = default_manager().await;
        assert!(!manager.is_enabled());
        assert!(!manager.is_ip_blocked("192.168.1.1"));
        assert!(!manager.is_fingerprint_blocked("abc123"));
    }

    #[tokio::test]
    async fn test_manager_stats() {
        let manager = default_manager().await;
        let stats = manager.stats().await;
        assert_eq!(stats.signals_sent, 0);
        assert_eq!(stats.blocklist_size, 0);
        assert_eq!(stats.connection_state, "disconnected");
    }

    #[tokio::test]
    async fn test_manager_stats_snapshot() {
        let manager = default_manager().await;
        let snapshot = manager.stats_snapshot().await;
        assert!(!snapshot.enabled);
        assert!(!snapshot.connected);
        assert_eq!(snapshot.signals_sent, 0);
    }

    #[tokio::test]
    async fn test_manager_blocklist() {
        let manager = default_manager().await;
        manager.blocklist().add(entry(BlockType::Ip, "192.168.1.100"));
        assert!(manager.is_ip_blocked("192.168.1.100"));
        assert!(!manager.is_ip_blocked("192.168.1.101"));
    }

    #[tokio::test]
    async fn test_manager_is_blocked() {
        let (blocked_ip, blocked_fp) = ("192.168.1.100", "fp123");
        let manager = default_manager().await;
        let shared_blocklist = manager.blocklist();
        shared_blocklist.add(entry(BlockType::Ip, blocked_ip));
        shared_blocklist.add(entry(BlockType::Fingerprint, blocked_fp));

        assert!(manager.is_blocked(Some(blocked_ip), None));
        assert!(manager.is_blocked(None, Some(blocked_fp)));
        assert!(manager.is_blocked(Some(blocked_ip), Some(blocked_fp)));
        assert!(!manager.is_blocked(Some("192.168.1.101"), Some("fp456")));
    }

    #[tokio::test]
    async fn test_builder() {
        let config = HorizonConfig::default()
            .with_hub_url("wss://example.com/ws")
            .with_api_key("test")
            .with_sensor_id("sensor");
        let manager = HorizonManagerBuilder::new(config).build().await.unwrap();
        assert!(manager.is_enabled());
    }

    /// Calls `report_signal` 1000 times from plain synchronous code (outside `block_on`);
    /// the test passes as long as none of the calls panic or hang.
    #[test]
    fn test_report_signal_non_blocking() {
        let config = HorizonConfig::default();
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let manager = runtime.block_on(HorizonManager::new(config)).unwrap();
        (0..1000).for_each(|_| {
            manager.report_signal(ThreatSignal::new(SignalType::IpThreat, Severity::High));
        });
    }
}