use crate::config::TargetConfig;
use crate::logging::{log_error_with_context, log_resource_cleanup, ssh_span};
use anyhow::{Context, Result};
use async_ssh2_tokio::client::Client;
use async_trait::async_trait;
use deadpool::managed::{Manager, Pool};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::time::{interval, sleep};
use tracing::{Instrument, debug, error, info, warn};
/// Registry of per-target connection pools, the SSH managers of SSH-enabled targets,
/// and the shared health checker that records each target's status.
pub struct ConnectionManager {
    /// Connection pool for each initialized target, keyed by target name.
    pools: Arc<RwLock<HashMap<String, Arc<ConnectionPool>>>>,
    /// Health-status table shared with the background health-check task.
    health_checker: Arc<HealthChecker>,
    /// SSH manager for each target that has SSH enabled, keyed by target name.
    ssh_managers: Arc<RwLock<HashMap<String, Arc<SshConnectionManager>>>>,
}
/// Lifecycle state of an SSH session managed by `SshConnectionManager`.
///
/// Equality is total for this fieldless enum, so `Eq` is derived alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SshConnectionState {
    /// A client has been stored by a successful reconnection.
    Connected,
    /// Initial state, and the state entered when a `Connected` session has no client.
    Disconnected,
    /// A reconnection loop is in progress; other callers wait for it to finish.
    Reconnecting,
    /// All reconnection attempts were exhausted; callers get an error without a retry.
    Failed,
}
/// Owns the SSH session used to reach one target, together with its lifecycle state
/// and reconnection bookkeeping.
///
/// The `Arc`-wrapped fields are shared between clones of this manager and with the
/// background monitoring task.
pub struct SshConnectionManager {
    /// Current SSH client; `None` until a reconnection stores one.
    ssh_client: Arc<RwLock<Option<Client>>>,
    /// Current lifecycle state of the session.
    connection_state: Arc<RwLock<SshConnectionState>>,
    /// SSH settings for this target (host, user, port, key file, reconnect policy).
    config: crate::config::SshConfig,
    /// Number of the most recent failed reconnection attempt; reset to 0 on success.
    reconnect_attempts: Arc<AtomicU32>,
    /// Name of the target this session serves; used in log messages and errors.
    target_name: String,
}
/// Pool object recording which path was used to reach the target when it was created.
///
/// Neither variant owns a socket: the stream opened during creation is dropped
/// immediately, so holding a `PooledConnection` does not keep a connection open.
pub enum PooledConnection {
    /// Created by a plain TCP connect to the target.
    Direct,
    /// Created through the SSH manager's tunneled-connection path.
    SshTunneled,
}
/// A deadpool pool for one target, plus the target's name for log messages.
pub struct ConnectionPool {
    /// Underlying pool whose objects are produced by `ConnectionPoolManager`.
    pool: Pool<ConnectionPoolManager>,
    /// Name of the target this pool serves.
    target_name: String,
}
/// deadpool manager that creates `PooledConnection`s for a single target.
pub struct ConnectionPoolManager {
    /// Host, port and optional SSH settings of the target.
    target_config: TargetConfig,
    /// SSH manager used when SSH is enabled for the target; `None` otherwise.
    ssh_manager: Option<Arc<SshConnectionManager>>,
}
#[async_trait]
impl Manager for ConnectionPoolManager {
    type Type = PooledConnection;
    type Error = anyhow::Error;

    /// Verifies the target is reachable and returns a marker for the path used.
    ///
    /// Uses the SSH path when the target has an SSH section with `enabled = true`,
    /// and a direct TCP connect otherwise.
    ///
    /// # Errors
    /// Propagates any connection failure from the chosen path.
    async fn create(&self) -> Result<Self::Type, Self::Error> {
        let ssh_enabled = matches!(&self.target_config.ssh, Some(ssh) if ssh.enabled);
        if ssh_enabled {
            self.create_ssh_connection().await
        } else {
            self.create_direct_connection().await
        }
    }

    /// Re-verifies that the target is still reachable before an idle object is reused.
    ///
    /// A `PooledConnection` carries no live socket, so `_conn` cannot be inspected.
    /// Returning `Ok(())` unconditionally would make every `Pool::get` after the first
    /// succeed without touching the network, and the health checker, which probes
    /// via `Pool::get`, would never see the target go down.
    ///
    /// Instead, the creation handshake is repeated and its result discarded. On
    /// failure, deadpool discards this object rather than handing it out.
    ///
    /// Note: this costs one connection attempt per reuse of an idle object.
    async fn recycle(
        &self,
        _conn: &mut Self::Type,
        _metrics: &deadpool::managed::Metrics,
    ) -> deadpool::managed::RecycleResult<Self::Error> {
        self.create()
            .await
            .map(drop)
            .map_err(deadpool::managed::RecycleError::Backend)
    }
}
impl ConnectionPoolManager {
async fn create_direct_connection(&self) -> Result<PooledConnection> {
let target_addr = format!("{}:{}", self.target_config.host, self.target_config.port);
let _stream = TcpStream::connect(&target_addr)
.await
.with_context(|| format!("Failed to connect to {}", target_addr))?;
debug!("Created direct connection to {}", target_addr);
Ok(PooledConnection::Direct)
}
async fn create_ssh_connection(&self) -> Result<PooledConnection> {
let ssh_manager = self
.ssh_manager
.as_ref()
.ok_or_else(|| anyhow::anyhow!("SSH manager not initialized"))?;
ssh_manager.ensure_connected().await?;
let _stream = ssh_manager
.create_tunneled_connection(&self.target_config)
.await?;
debug!(
"Created SSH tunneled connection to {}:{}",
self.target_config.host, self.target_config.port
);
Ok(PooledConnection::SshTunneled)
}
}
impl SshConnectionManager {
    /// Delay after the first failed reconnection attempt.
    const RECONNECT_BASE_DELAY: Duration = Duration::from_secs(1);
    /// Cap on the delay between reconnection attempts.
    const RECONNECT_MAX_DELAY: Duration = Duration::from_secs(60);

    /// Creates a manager in the `Disconnected` state with no SSH client.
    ///
    /// No network activity happens here.
    pub fn new(config: crate::config::SshConfig, target_name: String) -> Self {
        Self {
            ssh_client: Arc::new(RwLock::new(None)),
            connection_state: Arc::new(RwLock::new(SshConnectionState::Disconnected)),
            config,
            reconnect_attempts: Arc::new(AtomicU32::new(0)),
            target_name,
        }
    }

    /// Makes sure an SSH session is available, reconnecting when necessary.
    ///
    /// Behavior depends on the current state:
    /// - `Connected`: returns `Ok` if a client is stored. Otherwise it moves to
    ///   `Disconnected` and reconnects.
    /// - `Reconnecting`: waits up to 60s for the in-flight reconnection to finish.
    /// - `Failed`: returns an error immediately; nothing here retries.
    /// - `Disconnected`: calls `reconnect`.
    ///
    /// # Errors
    /// Returns an error if the session is `Failed`, or if waiting or reconnecting fails.
    pub async fn ensure_connected(&self) -> Result<()> {
        let current_state = self.connection_state.read().clone();
        match current_state {
            SshConnectionState::Connected => {
                if self.is_connection_alive().await {
                    return Ok(());
                }
                warn!(
                    "SSH connection for {} appears to be dead, reconnecting",
                    self.target_name
                );
                self.set_state(SshConnectionState::Disconnected);
            }
            SshConnectionState::Reconnecting => {
                return self.wait_for_reconnection().await;
            }
            SshConnectionState::Failed => {
                return Err(anyhow::anyhow!(
                    "SSH connection failed for {}",
                    self.target_name
                ));
            }
            SshConnectionState::Disconnected => {}
        }
        // NOTE(review): the state read above and the transition inside `reconnect` are
        // not atomic. Two concurrent callers that both observe `Disconnected` can both
        // start a reconnection loop.
        self.reconnect().await
    }

    /// Tries up to `max_reconnect_attempts` times to create an SSH client, with
    /// exponential backoff between attempts.
    ///
    /// The state goes to `Reconnecting` for the duration. It ends at `Connected`
    /// (with the attempt counter reset to 0) on success, or at `Failed` once all
    /// attempts are exhausted.
    ///
    /// # Errors
    /// Returns an error if auto-reconnect is disabled; the state is left unchanged in
    /// that case. Also returns an error if every attempt fails.
    pub async fn reconnect(&self) -> Result<()> {
        if !self.config.auto_reconnect {
            warn!(
                target = %self.target_name,
                "Auto-reconnect disabled, cannot reconnect SSH connection"
            );
            return Err(anyhow::anyhow!(
                "Auto-reconnect disabled for {}",
                self.target_name
            ));
        }
        info!(
            target = %self.target_name,
            max_attempts = self.config.max_reconnect_attempts,
            "Starting SSH reconnection process"
        );
        self.set_state(SshConnectionState::Reconnecting);
        let max_attempts = self.config.max_reconnect_attempts;
        for attempt in 1..=max_attempts {
            info!(
                "SSH reconnection attempt {} for {}",
                attempt, self.target_name
            );
            match self.create_ssh_client().await {
                Ok(client) => {
                    *self.ssh_client.write() = Some(client);
                    self.set_state(SshConnectionState::Connected);
                    self.reconnect_attempts.store(0, Ordering::SeqCst);
                    info!(
                        "SSH reconnection successful for {} after {} attempts",
                        self.target_name, attempt
                    );
                    return Ok(());
                }
                Err(e) => {
                    self.reconnect_attempts.store(attempt, Ordering::SeqCst);
                    error!(
                        "SSH reconnection attempt {} failed for {}: {}",
                        attempt, self.target_name, e
                    );
                    if attempt < max_attempts {
                        let delay =
                            Self::backoff_delay(self.config.reconnect_backoff_multiplier, attempt);
                        warn!(
                            "Waiting {}s before next SSH reconnection attempt for {}",
                            delay.as_secs(),
                            self.target_name
                        );
                        sleep(delay).await;
                    }
                }
            }
        }
        self.set_state(SshConnectionState::Failed);
        Err(anyhow::anyhow!(
            "SSH reconnection failed for {} after {} attempts",
            self.target_name,
            max_attempts
        ))
    }

    /// Delay to wait after failed attempt number `attempt` (1-based).
    ///
    /// The value is `RECONNECT_BASE_DELAY * multiplier^(attempt - 1)`, capped at
    /// `RECONNECT_MAX_DELAY`.
    ///
    /// The cap is applied in `f64` *before* converting to a `Duration`.
    /// `Duration::mul_f64` and `from_secs_f64` panic when the value overflows
    /// `Duration`, is NaN, or is negative. With a multiplier of 2.0 that overflow
    /// happens after roughly 64 attempts. A NaN or negative result (invalid multiplier)
    /// falls back to the cap instead of panicking.
    fn backoff_delay(multiplier: f64, attempt: u32) -> Duration {
        // Saturating conversion: attempt counts beyond i32::MAX would wrap with `as`.
        let exponent = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
        let secs = Self::RECONNECT_BASE_DELAY.as_secs_f64() * multiplier.powi(exponent);
        if secs.is_nan() || secs < 0.0 {
            Self::RECONNECT_MAX_DELAY
        } else {
            Duration::from_secs_f64(secs.min(Self::RECONNECT_MAX_DELAY.as_secs_f64()))
        }
    }

    /// Opens a TCP stream to `target_config`'s `host:port` if an SSH client is stored.
    ///
    /// NOTE(review): despite the name, the stream comes from a plain
    /// `TcpStream::connect` on this host. It is *not* routed through the SSH session;
    /// the client is only checked for presence. Because `create_ssh_client` currently
    /// always fails, this path is unreachable today. It must be replaced with a real
    /// forwarded channel before SSH support is enabled.
    ///
    /// # Errors
    /// Returns an error if no SSH client is stored or the TCP connect fails.
    pub async fn create_tunneled_connection(
        &self,
        target_config: &TargetConfig,
    ) -> Result<TcpStream> {
        // Read into a local so the lock guard is released before the `.await` below.
        let has_client = self.ssh_client.read().is_some();
        if !has_client {
            return Err(anyhow::anyhow!("SSH client not available"));
        }
        let target_addr = format!("{}:{}", target_config.host, target_config.port);
        let stream = TcpStream::connect(&target_addr)
            .await
            .with_context(|| format!("Failed to create tunneled connection to {}", target_addr))?;
        debug!("Created tunneled connection to {} through SSH", target_addr);
        Ok(stream)
    }

    /// Reports whether an SSH client is stored.
    ///
    /// NOTE(review): this does not probe the session, so a connection that dropped
    /// after the client was stored is not detected here.
    async fn is_connection_alive(&self) -> bool {
        self.ssh_client.read().is_some()
    }

    /// Polls the state every 500ms for up to 60s while another reconnection is in flight.
    ///
    /// # Errors
    /// Returns an error if the state becomes `Failed` or `Disconnected`, or if it is
    /// still `Reconnecting` after 60s.
    async fn wait_for_reconnection(&self) -> Result<()> {
        let max_wait = Duration::from_secs(60);
        let check_interval = Duration::from_millis(500);
        let start_time = Instant::now();
        while start_time.elapsed() < max_wait {
            let current_state = self.connection_state.read().clone();
            match current_state {
                SshConnectionState::Connected => return Ok(()),
                SshConnectionState::Failed => {
                    return Err(anyhow::anyhow!(
                        "SSH connection failed for {}",
                        self.target_name
                    ));
                }
                SshConnectionState::Reconnecting => {
                    sleep(check_interval).await;
                }
                SshConnectionState::Disconnected => {
                    return Err(anyhow::anyhow!(
                        "SSH connection unexpectedly disconnected for {}",
                        self.target_name
                    ));
                }
            }
        }
        Err(anyhow::anyhow!(
            "Timeout waiting for SSH reconnection for {}",
            self.target_name
        ))
    }

    /// Overwrites the shared lifecycle state.
    fn set_state(&self, state: SshConnectionState) {
        *self.connection_state.write() = state;
    }

    /// Returns a snapshot of the current lifecycle state.
    pub fn get_state(&self) -> SshConnectionState {
        self.connection_state.read().clone()
    }

    /// Returns the number of the last failed reconnection attempt (0 after a success).
    pub fn get_reconnect_attempts(&self) -> u32 {
        self.reconnect_attempts.load(Ordering::SeqCst)
    }

    /// Validates the SSH settings needed to build a client.
    ///
    /// Client construction itself is not implemented yet, so this always returns an
    /// error after validation.
    ///
    /// # Errors
    /// Returns an error if the host, user or key file is missing. Otherwise it
    /// returns the "not yet implemented" placeholder error.
    async fn create_ssh_client(&self) -> Result<Client> {
        let ssh_host = self
            .config
            .host
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("SSH host not configured"))?;
        let ssh_user = self
            .config
            .user
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("SSH user not configured"))?;
        let ssh_port = self.config.port.unwrap_or(22);
        if self.config.key_file.is_none() {
            return Err(anyhow::anyhow!("SSH key file not configured"));
        }
        info!(
            target = %self.target_name,
            ssh_host = %ssh_host,
            ssh_user = %ssh_user,
            ssh_port = ssh_port,
            "Creating SSH client"
        );
        Err(anyhow::anyhow!(
            "SSH client creation not yet implemented - this is a placeholder for future development"
        ))
    }

    /// Spawns a background task that periodically logs the session's state and attempt count.
    ///
    /// The period is `reconnect_interval_seconds`. A value of 0 is raised to 1s,
    /// because `tokio::time::interval` panics on a zero period and would kill the task.
    ///
    /// NOTE(review): the task only logs. Nothing in it triggers a reconnection, despite
    /// the "triggering reconnection" / "will retry" messages. It also runs forever with
    /// no cancellation handle.
    pub async fn start_monitoring(&self) {
        let target_name = self.target_name.clone();
        let connection_state = Arc::clone(&self.connection_state);
        let reconnect_attempts = Arc::clone(&self.reconnect_attempts);
        let config = self.config.clone();
        let ssh_host = config.host.clone().unwrap_or_default();
        let span = ssh_span(&target_name, &ssh_host);
        let configured = Duration::from_secs(config.reconnect_interval_seconds);
        let period = if configured.is_zero() {
            warn!(
                target = %target_name,
                "reconnect_interval_seconds is 0; using a 1s SSH monitoring interval"
            );
            Duration::from_secs(1)
        } else {
            configured
        };
        tokio::spawn(
            async move {
                let mut ticker = interval(period);
                loop {
                    ticker.tick().await;
                    let current_state = connection_state.read().clone();
                    let attempts = reconnect_attempts.load(Ordering::SeqCst);
                    debug!(
                        target = %target_name,
                        state = ?current_state,
                        reconnect_attempts = attempts,
                        "SSH health monitoring check"
                    );
                    match current_state {
                        SshConnectionState::Connected => {}
                        SshConnectionState::Disconnected => {
                            if config.auto_reconnect {
                                info!(
                                    target = %target_name,
                                    "SSH connection disconnected, triggering reconnection"
                                );
                            }
                        }
                        SshConnectionState::Reconnecting => {}
                        SshConnectionState::Failed => {
                            if config.auto_reconnect && attempts < config.max_reconnect_attempts {
                                info!(
                                    target = %target_name,
                                    attempts = attempts,
                                    max_attempts = config.max_reconnect_attempts,
                                    "SSH connection failed, will retry"
                                );
                            }
                        }
                    }
                }
            }
            .instrument(span),
        );
    }
}
impl Clone for SshConnectionManager {
    /// Produces another handle to the same session.
    ///
    /// The client slot, lifecycle state and attempt counter are shared through `Arc`.
    /// The config and target name are copied.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field becomes a compile error here
        // instead of a silently un-cloned field.
        let Self {
            ssh_client,
            connection_state,
            config,
            reconnect_attempts,
            target_name,
        } = self;
        Self {
            ssh_client: Arc::clone(ssh_client),
            connection_state: Arc::clone(connection_state),
            config: config.clone(),
            reconnect_attempts: Arc::clone(reconnect_attempts),
            target_name: target_name.clone(),
        }
    }
}
/// Holds the latest health result for every registered target.
pub struct HealthChecker {
    /// Latest `HealthStatus` for each target, keyed by target name.
    health_status: Arc<RwLock<HashMap<String, HealthStatus>>>,
}
/// Most recent health snapshot for one target.
#[derive(Debug, Clone)]
pub struct HealthStatus {
    /// Cleared after 3 consecutive failed checks; set again by the next successful check.
    pub is_healthy: bool,
    /// When this status was last updated.
    pub last_check: Instant,
    /// Number of failed checks since the last success.
    pub consecutive_failures: u32,
    /// Message of the most recent failure; cleared on success.
    pub last_error: Option<String>,
    /// SSH state observed at the last check; `None` when the target has no SSH manager.
    pub ssh_state: Option<SshConnectionState>,
    /// SSH reconnect-attempt counter observed at the last check; `None` without an SSH manager.
    pub ssh_reconnect_attempts: Option<u32>,
}
impl Default for HealthStatus {
    /// Status for a newly registered target.
    ///
    /// The target is presumed healthy, with no failures or errors, no SSH
    /// information, and a last-check time of "now".
    fn default() -> Self {
        let last_check = Instant::now();
        Self {
            is_healthy: true,
            consecutive_failures: 0,
            last_error: None,
            last_check,
            ssh_state: None,
            ssh_reconnect_attempts: None,
        }
    }
}
impl ConnectionManager {
pub fn new() -> Self {
Self {
pools: Arc::new(RwLock::new(HashMap::new())),
health_checker: Arc::new(HealthChecker::new()),
ssh_managers: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn initialize_target(
&self,
target_name: String,
target_config: TargetConfig,
) -> Result<()> {
info!(
target = %target_name,
host = %target_config.host,
port = target_config.port,
ssh_enabled = target_config.ssh.as_ref().map(|s| s.enabled).unwrap_or(false),
"Initializing connection pool for target"
);
if target_config.host.is_empty() {
let error = anyhow::anyhow!("Target host cannot be empty for '{}'", target_name);
log_error_with_context(&error, "Target configuration validation failed");
return Err(error);
}
if target_config.port == 0 {
let error = anyhow::anyhow!("Target port cannot be zero for '{}'", target_name);
log_error_with_context(&error, "Target configuration validation failed");
return Err(error);
}
let ssh_manager = if let Some(ssh) = &target_config.ssh {
if ssh.enabled {
info!(
target = %target_name,
ssh_host = %ssh.host.as_deref().unwrap_or("not configured"),
ssh_user = %ssh.user.as_deref().unwrap_or("not configured"),
"Initializing SSH manager for target"
);
let manager = Arc::new(SshConnectionManager::new(ssh.clone(), target_name.clone()));
manager.start_monitoring().await;
self.ssh_managers
.write()
.insert(target_name.clone(), Arc::clone(&manager));
log_resource_cleanup("ssh_manager", &target_name, true);
Some(manager)
} else {
debug!(
target = %target_name,
"SSH configured but disabled for target"
);
None
}
} else {
debug!(
target = %target_name,
"No SSH configuration for target"
);
None
};
let pool_manager = ConnectionPoolManager {
target_config: target_config.clone(),
ssh_manager: ssh_manager.clone(),
};
let pool_config = deadpool::managed::PoolConfig::new(10);
let pool = Pool::builder(pool_manager)
.config(pool_config)
.build()
.with_context(|| {
format!(
"Failed to create connection pool for target '{}'",
target_name
)
})?;
let connection_pool = Arc::new(ConnectionPool {
pool,
target_name: target_name.clone(),
});
self.pools
.write()
.insert(target_name.clone(), connection_pool);
self.health_checker
.health_status
.write()
.insert(target_name.clone(), HealthStatus::default());
info!("Connection pool initialized for target: {}", target_name);
Ok(())
}
pub async fn is_target_healthy(&self, target_name: &str) -> bool {
self.health_checker
.health_status
.read()
.get(target_name)
.map(|status| status.is_healthy)
.unwrap_or(false)
}
pub async fn start_health_checking(
&self,
targets: Vec<String>,
config: &crate::config::ConnectionManagementConfig,
) {
let health_checker = Arc::clone(&self.health_checker);
let pools = Arc::clone(&self.pools);
let ssh_managers = Arc::clone(&self.ssh_managers);
let check_interval = Duration::from_secs(config.health_check_interval_seconds);
tokio::spawn(async move {
let mut interval = interval(check_interval);
loop {
interval.tick().await;
for target_name in &targets {
let health_checker = Arc::clone(&health_checker);
let pools = Arc::clone(&pools);
let ssh_managers = Arc::clone(&ssh_managers);
let target_name = target_name.clone();
tokio::spawn(async move {
if let Err(e) = health_checker
.check_target_health(&target_name, &pools, &ssh_managers)
.await
{
error!("Health check failed for target {}: {}", target_name, e);
}
});
}
}
});
}
}
impl HealthChecker {
    /// Creates a checker with an empty status table.
    pub fn new() -> Self {
        let health_status = Arc::new(RwLock::new(HashMap::new()));
        Self { health_status }
    }

    /// Runs one probe for `target_name` and folds the outcome into its `HealthStatus`.
    ///
    /// - A target with no pool is logged and skipped.
    /// - On success, the failure counter and last error are reset and the target is
    ///   marked healthy.
    /// - On failure, the counter is incremented and the error recorded. The target is
    ///   marked unhealthy once 3 consecutive failures accumulate.
    ///
    /// The SSH state and attempt count are sampled after the probe and stored
    /// alongside. Probe failures are recorded, not returned, so this currently always
    /// yields `Ok(())`.
    async fn check_target_health(
        &self,
        target_name: &str,
        pools: &Arc<RwLock<HashMap<String, Arc<ConnectionPool>>>>,
        ssh_managers: &Arc<RwLock<HashMap<String, Arc<SshConnectionManager>>>>,
    ) -> Result<()> {
        debug!("Checking health for target: {}", target_name);
        // Each lookup copies the Arc out so no lock guard survives into an `.await`.
        let maybe_pool = pools.read().get(target_name).cloned();
        let pool = if let Some(found) = maybe_pool {
            found
        } else {
            warn!("Target {} not found in pools", target_name);
            return Ok(());
        };
        let ssh_manager = ssh_managers.read().get(target_name).cloned();

        let health_result = self.perform_health_check(&pool).await;

        let ssh_state = ssh_manager.as_ref().map(|mgr| mgr.get_state());
        let ssh_reconnect_attempts = ssh_manager
            .as_ref()
            .map(|mgr| mgr.get_reconnect_attempts());

        let mut table = self.health_status.write();
        let status = table.entry(target_name.to_string()).or_default();
        if let Err(e) = health_result {
            status.consecutive_failures += 1;
            status.last_error = Some(e.to_string());
            if status.is_healthy && status.consecutive_failures >= 3 {
                warn!(
                    "Target {} marked as unhealthy after {} failures",
                    target_name, status.consecutive_failures
                );
                status.is_healthy = false;
            }
        } else {
            if !status.is_healthy {
                info!("Target {} is now healthy", target_name);
            }
            status.is_healthy = true;
            status.consecutive_failures = 0;
            status.last_error = None;
        }
        status.ssh_state = ssh_state;
        status.ssh_reconnect_attempts = ssh_reconnect_attempts;
        status.last_check = Instant::now();
        Ok(())
    }

    /// Checks out one object from the pool, giving up after 5 seconds.
    ///
    /// # Errors
    /// Returns an error on timeout, or if the pool reports an error.
    async fn perform_health_check(&self, pool: &ConnectionPool) -> Result<()> {
        let checkout = tokio::time::timeout(Duration::from_secs(5), pool.pool.get()).await;
        // Bound to a named variable so the object is returned to the pool only at the end.
        let _connection = match checkout {
            Err(_elapsed) => return Err(anyhow::anyhow!("Health check timeout")),
            Ok(Err(e)) => {
                return Err(anyhow::anyhow!(
                    "Failed to get connection from pool: {}",
                    e
                ))
            }
            Ok(Ok(conn)) => conn,
        };
        debug!("Health check successful for target: {}", pool.target_name);
        Ok(())
    }
}