use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use std::sync::Arc;
use tracing::{debug, info, warn};
#[cfg(feature = "nats")]
use async_nats::{Client, ConnectOptions};
#[derive(Debug, Clone)]
pub struct ConnectionWrapper {
#[cfg(feature = "nats")]
pub client: Arc<Client>,
#[cfg(not(feature = "nats"))]
pub client: Arc<()>,
pub url: String,
pub is_healthy: bool,
pub last_health_check: DateTime<Utc>,
pub connection_attempts: u32,
pub last_error: Option<String>,
}
#[derive(Debug)]
pub struct ConnectionPool {
pub connections: Vec<ConnectionWrapper>,
pub active_index: usize,
pub max_connections: usize,
pub round_robin_counter: usize,
pub health_checks_enabled: bool,
}
impl ConnectionPool {
pub fn new(max_connections: usize) -> Self {
Self {
connections: Vec::new(),
active_index: 0,
max_connections,
round_robin_counter: 0,
health_checks_enabled: true,
}
}
#[cfg(feature = "nats")]
pub async fn add_connection(&mut self, url: String) -> Result<()> {
if self.connections.len() >= self.max_connections {
return Err(anyhow!("Connection pool at maximum capacity"));
}
let options = ConnectOptions::new();
let client = options
.connect(&url)
.await
.map_err(|e| anyhow!("Failed to connect to {}: {}", url, e))?;
let wrapper = ConnectionWrapper {
client: Arc::new(client),
url,
is_healthy: true,
last_health_check: Utc::now(),
connection_attempts: 1,
last_error: None,
};
let url = wrapper.url.clone();
self.connections.push(wrapper);
info!("Added NATS connection to pool: {}", url);
Ok(())
}
#[cfg(not(feature = "nats"))]
pub async fn add_connection(&mut self, url: String) -> Result<()> {
let wrapper = ConnectionWrapper {
client: Arc::new(()),
url,
is_healthy: true,
last_health_check: Utc::now(),
connection_attempts: 1,
last_error: None,
};
self.connections.push(wrapper);
info!("Added mock NATS connection to pool: {}", wrapper.url);
Ok(())
}
pub fn get_next_connection(&mut self) -> Option<&ConnectionWrapper> {
if self.connections.is_empty() {
return None;
}
let healthy_connections: Vec<(usize, &ConnectionWrapper)> = self
.connections
.iter()
.enumerate()
.filter(|(_, conn)| conn.is_healthy)
.collect();
if healthy_connections.is_empty() {
warn!("No healthy connections available in pool");
return None;
}
let selected_index = self.round_robin_counter % healthy_connections.len();
self.round_robin_counter = self.round_robin_counter.wrapping_add(1);
let (original_index, connection) = healthy_connections[selected_index];
self.active_index = original_index;
debug!(
"Selected connection: {} (index: {})",
connection.url, original_index
);
Some(connection)
}
pub async fn health_check_all(&mut self) {
if !self.health_checks_enabled {
return;
}
for connection in self.connections.iter_mut() {
ConnectionPool::health_check_connection_static(connection).await;
}
}
async fn health_check_connection_static(connection: &mut ConnectionWrapper) {
#[cfg(feature = "nats")]
{
let _info = connection.client.server_info();
{
connection.is_healthy = true;
connection.last_health_check = Utc::now();
connection.last_error = None;
debug!("Health check passed for: {}", connection.url);
}
}
#[cfg(not(feature = "nats"))]
{
connection.is_healthy = true;
connection.last_health_check = Utc::now();
}
}
async fn health_check_connection(&self, connection: &mut ConnectionWrapper) {
#[cfg(feature = "nats")]
{
let _info = connection.client.server_info();
{
connection.is_healthy = true;
connection.last_health_check = Utc::now();
connection.last_error = None;
debug!("Health check passed for: {}", connection.url);
}
}
#[cfg(not(feature = "nats"))]
{
connection.is_healthy = true;
connection.last_health_check = Utc::now();
}
}
pub fn mark_unhealthy(&mut self, index: usize, error: String) {
if let Some(connection) = self.connections.get_mut(index) {
connection.is_healthy = false;
connection.last_error = Some(error.clone());
warn!(
"Marked connection as unhealthy: {} - {}",
connection.url, error
);
}
}
pub fn get_stats(&self) -> ConnectionPoolStats {
let healthy_count = self.connections.iter().filter(|c| c.is_healthy).count();
ConnectionPoolStats {
total_connections: self.connections.len(),
healthy_connections: healthy_count,
unhealthy_connections: self.connections.len() - healthy_count,
max_connections: self.max_connections,
active_index: self.active_index,
health_checks_enabled: self.health_checks_enabled,
}
}
pub fn cleanup_unhealthy(&mut self) {
let initial_count = self.connections.len();
self.connections.retain(|conn| conn.is_healthy);
let removed_count = initial_count - self.connections.len();
if removed_count > 0 {
info!("Removed {} unhealthy connections from pool", removed_count);
if self.active_index >= self.connections.len() {
self.active_index = 0;
}
}
}
}
#[derive(Debug, Clone)]
pub struct ConnectionPoolStats {
pub total_connections: usize,
pub healthy_connections: usize,
pub unhealthy_connections: usize,
pub max_connections: usize,
pub active_index: usize,
pub health_checks_enabled: bool,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_connection_pool_creation() {
let pool = ConnectionPool::new(5);
assert_eq!(pool.max_connections, 5);
assert_eq!(pool.connections.len(), 0);
}
#[tokio::test]
async fn test_connection_pool_stats() {
let pool = ConnectionPool::new(10);
let stats = pool.get_stats();
assert_eq!(stats.total_connections, 0);
assert_eq!(stats.healthy_connections, 0);
assert_eq!(stats.max_connections, 10);
}
#[tokio::test]
async fn test_get_next_connection_empty_pool() {
let mut pool = ConnectionPool::new(5);
assert!(pool.get_next_connection().is_none());
}
}