use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::Instant;
use turbomcp_protocol::{Error, Result};
use turbomcp_transport::Transport;
use super::core::Client;
/// Health/lifecycle state of a managed server connection.
///
/// The background monitoring loop moves connections between `Healthy`,
/// `Degraded`, and `Unhealthy` based on ping results.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// Connection is responding to health checks.
    Healthy,
    /// Recent health-check failures, but still below the unhealthy threshold.
    Degraded,
    /// Consecutive failed health checks reached the configured threshold.
    Unhealthy,
    /// Connection is being (re)established.
    /// NOTE(review): not set anywhere in this module — confirm used by callers.
    Connecting,
    /// Connection is not established.
    /// NOTE(review): not set anywhere in this module — confirm used by callers.
    Disconnected,
}
/// Snapshot of a managed connection's identity, health state, and counters.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    /// Identifier the connection was registered under.
    pub id: String,
    /// Current health state as judged by the monitoring loop.
    pub state: ConnectionState,
    /// When the connection was added to the manager.
    pub established_at: Instant,
    /// Most recent successful health check (initialized at registration time).
    pub last_health_check: Option<Instant>,
    /// Consecutive failed health checks; reset to 0 on a successful ping.
    pub failed_health_checks: usize,
    /// Successful request count.
    /// NOTE(review): never incremented in this module — confirm updated elsewhere.
    pub successful_requests: usize,
    /// Failed request count.
    /// NOTE(review): never incremented in this module — confirm updated elsewhere.
    pub failed_requests: usize,
}
/// Tunable limits and timings for a [`SessionManager`].
#[derive(Debug, Clone)]
pub struct ManagerConfig {
    /// Upper bound on concurrently managed connections.
    pub max_connections: usize,
    /// How often the background monitor pings each connection.
    pub health_check_interval: Duration,
    /// Consecutive failed pings before a connection is marked unhealthy.
    pub health_check_threshold: usize,
    /// Per-ping deadline applied during health checks.
    pub health_check_timeout: Duration,
    /// Enables automatic reconnection.
    /// NOTE(review): not consulted by this module's visible logic — confirm.
    pub auto_reconnect: bool,
    /// Initial delay before a reconnect attempt.
    pub reconnect_delay: Duration,
    /// Ceiling for the exponential reconnect backoff.
    pub max_reconnect_delay: Duration,
    /// Multiplier applied to the reconnect delay after each failure.
    pub reconnect_backoff_multiplier: f64,
}

impl Default for ManagerConfig {
    /// Conservative defaults: 10 connections, 30s check interval, 5s ping
    /// timeout, 3 strikes before unhealthy, 1s..60s reconnect backoff at 2x.
    fn default() -> Self {
        ManagerConfig {
            max_connections: 10,
            health_check_interval: Duration::from_secs(30),
            health_check_timeout: Duration::from_secs(5),
            health_check_threshold: 3,
            reconnect_delay: Duration::from_secs(1),
            max_reconnect_delay: Duration::from_secs(60),
            reconnect_backoff_multiplier: 2.0,
            auto_reconnect: true,
        }
    }
}

impl ManagerConfig {
    /// Create a configuration populated with the default values.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }
}
/// Internal pairing of a protocol client with its bookkeeping metadata.
struct ManagedConnection<T: Transport + 'static> {
    /// The underlying protocol client for this server.
    client: Client<T>,
    /// Health state and usage statistics for this connection.
    info: ConnectionInfo,
}
/// A primary server plus an ordered list of fallbacks for failover.
#[derive(Debug, Clone)]
pub struct ServerGroup {
    /// Preferred server identifier, tried first.
    pub primary: String,
    /// Fallback server identifiers, tried in declaration order.
    pub backups: Vec<String>,
    /// Failure count tolerated before failing over.
    /// NOTE(review): not consulted by this module's visible logic — confirm.
    pub failover_threshold: usize,
}

impl ServerGroup {
    /// Build a group from a primary server and its backups.
    /// The failover threshold defaults to 3.
    pub fn new(primary: impl Into<String>, backups: Vec<String>) -> Self {
        ServerGroup {
            primary: primary.into(),
            backups,
            failover_threshold: 3,
        }
    }

    /// Builder-style override of the failover threshold.
    #[must_use]
    pub fn with_failover_threshold(mut self, threshold: usize) -> Self {
        self.failover_threshold = threshold;
        self
    }

    /// All server ids in failover order: the primary first, then each backup.
    #[must_use]
    pub fn all_servers(&self) -> Vec<&str> {
        let mut ordered = Vec::with_capacity(1 + self.backups.len());
        ordered.push(self.primary.as_str());
        ordered.extend(self.backups.iter().map(String::as_str));
        ordered
    }

    /// The server that follows `current` in failover order, or `None` when
    /// `current` is the last server or is not part of this group.
    #[must_use]
    pub fn next_server(&self, current: &str) -> Option<&str> {
        let mut ordered = self.all_servers().into_iter();
        // Advance past `current`; bail out if it is not in the group.
        ordered.by_ref().find(|&s| s == current)?;
        ordered.next()
    }
}
/// Manages a pool of client connections with optional background health
/// monitoring over a shared, async-locked connection map.
pub struct SessionManager<T: Transport + 'static> {
    /// Limits and timing knobs for this manager.
    config: ManagerConfig,
    /// Shared map of connection id -> managed client; the `Arc` is cloned
    /// into the spawned health-check task.
    connections: Arc<RwLock<HashMap<String, ManagedConnection<T>>>>,
    /// Handle of the spawned health-check loop, if monitoring is running.
    health_check_task: Option<tokio::task::JoinHandle<()>>,
}
impl<T: Transport + Send + 'static> SessionManager<T> {
    /// Create an empty manager with the given configuration.
    ///
    /// No connections are registered and health monitoring is not started;
    /// call [`Self::start_health_monitoring`] to begin background checks.
    #[must_use]
    pub fn new(config: ManagerConfig) -> Self {
        Self {
            config,
            connections: Arc::new(RwLock::new(HashMap::new())),
            health_check_task: None,
        }
    }

    /// Create a manager using `ManagerConfig::default()`.
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(ManagerConfig::default())
    }

    /// Register and initialize a new server connection under `id`.
    ///
    /// The new connection starts in the `Healthy` state.
    ///
    /// # Errors
    /// Fails when the `max_connections` limit is reached, when `id` is
    /// already in use, or when the client handshake (`initialize`) fails.
    ///
    /// NOTE(review): the write lock is held across `initialize().await`, so a
    /// slow handshake blocks every other manager operation — confirm this is
    /// acceptable for the expected transports.
    pub async fn add_server(&mut self, id: impl Into<String>, transport: T) -> Result<()> {
        let id = id.into();
        let mut connections = self.connections.write().await;
        // Enforce capacity before doing any connection work.
        if connections.len() >= self.config.max_connections {
            return Err(Error::invalid_request(format!(
                "Maximum connections limit ({}) reached",
                self.config.max_connections
            )));
        }
        // Reject duplicate ids instead of silently replacing an entry.
        if connections.contains_key(&id) {
            return Err(Error::invalid_request(format!(
                "Connection with ID '{}' already exists",
                id
            )));
        }
        let client = Client::new(transport);
        // Complete the protocol handshake before exposing the connection.
        client.initialize().await?;
        let info = ConnectionInfo {
            id: id.clone(),
            state: ConnectionState::Healthy,
            established_at: Instant::now(),
            last_health_check: Some(Instant::now()),
            failed_health_checks: 0,
            successful_requests: 0,
            failed_requests: 0,
        };
        connections.insert(id, ManagedConnection { client, info });
        Ok(())
    }

    /// Remove the connection registered under `id`.
    ///
    /// Returns `true` if a connection was removed, `false` if the id was
    /// unknown. Dropping the entry also drops its client.
    pub async fn remove_server(&mut self, id: &str) -> bool {
        let mut connections = self.connections.write().await;
        connections.remove(id).is_some()
    }

    /// Snapshot of the bookkeeping info for `id`, if registered.
    pub async fn get_session_info(&self, id: &str) -> Option<ConnectionInfo> {
        let connections = self.connections.read().await;
        connections.get(id).map(|conn| conn.info.clone())
    }

    /// Snapshots of every registered connection (map iteration order, i.e.
    /// unspecified).
    pub async fn list_sessions(&self) -> Vec<ConnectionInfo> {
        let connections = self.connections.read().await;
        connections.values().map(|conn| conn.info.clone()).collect()
    }

    /// Id of a `Healthy` connection with the fewest total requests
    /// (successful + failed), or `None` when no connection is healthy.
    /// A simple least-used selection for spreading load.
    pub async fn get_healthy_connection(&self) -> Option<String> {
        let connections = self.connections.read().await;
        connections
            .iter()
            .filter(|(_, conn)| conn.info.state == ConnectionState::Healthy)
            .min_by_key(|(_, conn)| conn.info.successful_requests + conn.info.failed_requests)
            .map(|(id, _)| id.clone())
    }

    /// Count of connections per state. States with no connections are absent
    /// from the returned map.
    pub async fn session_stats(&self) -> HashMap<ConnectionState, usize> {
        let connections = self.connections.read().await;
        let mut stats = HashMap::new();
        for conn in connections.values() {
            *stats.entry(conn.info.state.clone()).or_insert(0) += 1;
        }
        stats
    }

    /// Spawn the background health-check loop; no-op if one is already running.
    ///
    /// Every `health_check_interval`, the task pings each connection with a
    /// `health_check_timeout` deadline: success resets the failure counter and
    /// restores `Healthy`; a failure (error or timeout) increments the counter,
    /// downgrading a `Healthy` connection to `Degraded` below the threshold and
    /// marking it `Unhealthy` once `health_check_threshold` is reached.
    ///
    /// NOTE(review): each sweep holds the write lock while pinging every
    /// connection sequentially, so other manager operations can block for up
    /// to (timeout x connection count) per tick — confirm acceptable.
    pub async fn start_health_monitoring(&mut self) {
        // Idempotent: keep the existing task if monitoring is already active.
        if self.health_check_task.is_some() {
            return;
        }
        let connections = Arc::clone(&self.connections);
        // Copy config values out so the 'static task does not borrow `self`.
        let interval = self.config.health_check_interval;
        let threshold = self.config.health_check_threshold;
        let timeout = self.config.health_check_timeout;
        let task = tokio::spawn(async move {
            let mut interval_timer = tokio::time::interval(interval);
            loop {
                interval_timer.tick().await;
                let mut connections = connections.write().await;
                for (id, managed) in connections.iter_mut() {
                    // A timed-out ping counts the same as an explicit failure.
                    let health_result = tokio::time::timeout(timeout, managed.client.ping()).await;
                    match health_result {
                        Ok(Ok(_)) => {
                            managed.info.last_health_check = Some(Instant::now());
                            managed.info.failed_health_checks = 0;
                            // Log only on an actual state transition.
                            if managed.info.state != ConnectionState::Healthy {
                                tracing::info!(
                                    connection_id = %id,
                                    "Connection recovered and is now healthy"
                                );
                                managed.info.state = ConnectionState::Healthy;
                            }
                        }
                        Ok(Err(_)) | Err(_) => {
                            managed.info.failed_health_checks += 1;
                            if managed.info.failed_health_checks >= threshold {
                                if managed.info.state != ConnectionState::Unhealthy {
                                    tracing::warn!(
                                        connection_id = %id,
                                        failed_checks = managed.info.failed_health_checks,
                                        "Connection marked as unhealthy"
                                    );
                                    managed.info.state = ConnectionState::Unhealthy;
                                }
                            } else if managed.info.state == ConnectionState::Healthy {
                                // Below the threshold only a Healthy connection is
                                // downgraded; Degraded/Unhealthy keep their state
                                // until recovery or the threshold is crossed.
                                tracing::debug!(
                                    connection_id = %id,
                                    failed_checks = managed.info.failed_health_checks,
                                    "Connection degraded"
                                );
                                managed.info.state = ConnectionState::Degraded;
                            }
                        }
                    }
                }
            }
        });
        self.health_check_task = Some(task);
    }

    /// Abort the background health-check loop, if one is running.
    /// Connection states are left as they were at the last completed sweep.
    pub fn stop_health_monitoring(&mut self) {
        if let Some(task) = self.health_check_task.take() {
            task.abort();
        }
    }

    /// Number of currently registered connections.
    pub async fn session_count(&self) -> usize {
        let connections = self.connections.read().await;
        connections.len()
    }
}
impl<T: Transport + 'static> Drop for SessionManager<T> {
    /// Abort any running health-monitoring task when the manager is dropped,
    /// so the spawned loop cannot outlive its owner.
    fn drop(&mut self) {
        if let Some(handle) = self.health_check_task.take() {
            handle.abort();
        }
    }
}
impl SessionManager<turbomcp_transport::resilience::TurboTransport> {
    /// Wrap `transport` in a resilient `TurboTransport` (retry, circuit
    /// breaker, and health-check layers) and register it under `id`.
    ///
    /// # Errors
    /// Propagates any failure from [`SessionManager::add_server`] (capacity,
    /// duplicate id, or handshake failure).
    pub async fn add_resilient_server<BaseT>(
        &mut self,
        id: impl Into<String>,
        transport: BaseT,
        retry_config: turbomcp_transport::resilience::RetryConfig,
        circuit_config: turbomcp_transport::resilience::CircuitBreakerConfig,
        health_config: turbomcp_transport::resilience::HealthCheckConfig,
    ) -> Result<()>
    where
        BaseT: Transport + 'static,
    {
        use turbomcp_transport::resilience::TurboTransport;
        // The base transport is boxed so TurboTransport can own any Transport.
        let robust = TurboTransport::new(
            Box::new(transport),
            retry_config,
            circuit_config,
            health_config,
        );
        self.add_server(id, robust).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All documented `ManagerConfig` defaults stay in sync with the impl.
    #[test]
    fn test_manager_config_defaults() {
        let config = ManagerConfig::default();
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.health_check_interval, Duration::from_secs(30));
        assert_eq!(config.health_check_threshold, 3);
        assert_eq!(config.health_check_timeout, Duration::from_secs(5));
        assert!(config.auto_reconnect);
        assert_eq!(config.reconnect_delay, Duration::from_secs(1));
        assert_eq!(config.max_reconnect_delay, Duration::from_secs(60));
        assert!((config.reconnect_backoff_multiplier - 2.0).abs() < f64::EPSILON);
    }

    /// `ManagerConfig::new()` must be equivalent to `default()`.
    #[test]
    fn test_manager_config_new_matches_default() {
        let a = ManagerConfig::new();
        let b = ManagerConfig::default();
        assert_eq!(a.max_connections, b.max_connections);
        assert_eq!(a.health_check_interval, b.health_check_interval);
        assert_eq!(a.auto_reconnect, b.auto_reconnect);
    }

    #[test]
    fn test_connection_state_equality() {
        assert_eq!(ConnectionState::Healthy, ConnectionState::Healthy);
        assert_ne!(ConnectionState::Healthy, ConnectionState::Unhealthy);
    }

    /// Failover ordering: primary first, then backups in declaration order.
    #[test]
    fn test_server_group_failover_order() {
        let group = ServerGroup::new("primary", vec!["b1".into(), "b2".into()]);
        assert_eq!(group.all_servers(), vec!["primary", "b1", "b2"]);
        assert_eq!(group.next_server("primary"), Some("b1"));
        assert_eq!(group.next_server("b1"), Some("b2"));
        // Last server and unknown servers have no successor.
        assert_eq!(group.next_server("b2"), None);
        assert_eq!(group.next_server("unknown"), None);
    }

    /// Builder default and override for the failover threshold.
    #[test]
    fn test_server_group_threshold_builder() {
        let group = ServerGroup::new("p", vec![]);
        assert_eq!(group.failover_threshold, 3);
        assert_eq!(group.with_failover_threshold(7).failover_threshold, 7);
    }
}