use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::errors::{Error, Result};
use crate::session::{Session, SessionConfig, SessionId, SessionManager, SessionState};
use crate::shutdown::ShutdownSignal;
/// Configuration for a [`SessionPool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of sessions the pool holds at once; `0` disables the limit.
    pub max_sessions: usize,
    /// When `false`, starting a session for a target that already has a pooled
    /// session fails with `Error::Config`.
    pub allow_duplicate_targets: bool,
    /// Template cloned by `SessionPool::start_session`; its `target` field is
    /// overwritten with the requested target.
    pub default_session_config: SessionConfig,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_sessions: 100,
allow_duplicate_targets: true,
default_session_config: SessionConfig::default(),
}
}
}
/// Snapshot of pool counters, returned by `SessionPool::stats`.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of sessions in the pool as of the last admission or termination.
    pub active_sessions: usize,
    /// Sessions successfully admitted to the pool. Sessions that are started but
    /// then rejected by the pool's admission re-check are not counted.
    pub total_started: u64,
    /// Sessions removed from the pool via `SessionPool::terminate`.
    pub total_terminated: u64,
    /// Per-target count of pooled sessions: incremented on admission and
    /// decremented on termination.
    pub sessions_per_target: HashMap<String, usize>,
}
/// A bounded collection of sessions started through a [`SessionManager`],
/// indexed by session id and by target.
pub struct SessionPool {
    /// Admission limits and the default session template.
    config: PoolConfig,
    /// Used to start new sessions.
    manager: SessionManager,
    /// Pooled sessions keyed by session id.
    sessions: Arc<RwLock<HashMap<SessionId, SessionEntry>>>,
    /// Secondary index: target -> ids of pooled sessions for that target.
    by_target: Arc<RwLock<HashMap<String, Vec<SessionId>>>>,
    /// Counters returned (cloned) by `SessionPool::stats`.
    stats: Arc<RwLock<PoolStats>>,
    /// Triggered by `SessionPool::shutdown`; clones are handed out by
    /// `SessionPool::shutdown_signal`.
    shutdown: ShutdownSignal,
}
/// A pooled session plus the target it was started for. The target is kept so
/// termination can update the `by_target` index and the per-target stats.
struct SessionEntry {
    session: Session,
    target: String,
}
impl SessionPool {
pub async fn new(config: PoolConfig) -> Result<Self> {
let manager = SessionManager::new().await?;
Ok(Self {
config,
manager,
sessions: Arc::new(RwLock::new(HashMap::new())),
by_target: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(PoolStats::default())),
shutdown: ShutdownSignal::new(),
})
}
pub fn with_manager(config: PoolConfig, manager: SessionManager) -> Self {
Self {
config,
manager,
sessions: Arc::new(RwLock::new(HashMap::new())),
by_target: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(PoolStats::default())),
shutdown: ShutdownSignal::new(),
}
}
pub async fn start_session(&self, target: &str) -> Result<SessionHandle<'_>> {
let mut config = self.config.default_session_config.clone();
config.target = target.to_string();
self.start_session_with_config(config).await
}
pub async fn start_session_with_config(
&self,
config: SessionConfig,
) -> Result<SessionHandle<'_>> {
let target = config.target.clone();
{
let sessions = self.sessions.read().await;
if self.config.max_sessions > 0 && sessions.len() >= self.config.max_sessions {
return Err(Error::Config(format!(
"Pool limit reached: {} sessions (max: {})",
sessions.len(),
self.config.max_sessions
)));
}
}
if !self.config.allow_duplicate_targets {
let by_target = self.by_target.read().await;
if let Some(existing) = by_target.get(&target) {
if !existing.is_empty() {
return Err(Error::Config(format!(
"Session already exists for target: {}",
target
)));
}
}
}
let session = self.manager.start_session(config).await?;
let session_id = session.id().to_string();
{
let mut sessions = self.sessions.write().await;
if self.config.max_sessions > 0 && sessions.len() >= self.config.max_sessions {
drop(sessions); let mut session = session;
let _ = session.terminate().await;
return Err(Error::Config(
"Pool limit reached while starting session (race condition)".to_string(),
));
}
let mut by_target = self.by_target.write().await;
if !self.config.allow_duplicate_targets {
if let Some(existing) = by_target.get(&target) {
if !existing.is_empty() {
drop(sessions);
drop(by_target);
let mut session = session;
let _ = session.terminate().await;
return Err(Error::Config(format!(
"Session already exists for target: {} (race condition)",
target
)));
}
}
}
let mut stats = self.stats.write().await;
sessions.insert(
session_id.clone(),
SessionEntry {
session,
target: target.clone(),
},
);
by_target
.entry(target.clone())
.or_default()
.push(session_id.clone());
stats.active_sessions = sessions.len();
stats.total_started += 1;
*stats.sessions_per_target.entry(target).or_insert(0) += 1;
}
info!(session_id = %session_id, "Session added to pool");
Ok(SessionHandle {
session_id,
pool: self,
})
}
pub async fn get(&self, session_id: &str) -> Option<SessionRef<'_>> {
let sessions = self.sessions.read().await;
if sessions.contains_key(session_id) {
Some(SessionRef {
session_id: session_id.to_string(),
pool: self,
})
} else {
None
}
}
pub async fn get_by_target(&self, target: &str) -> Vec<String> {
let by_target = self.by_target.read().await;
by_target.get(target).cloned().unwrap_or_default()
}
pub async fn list_sessions(&self) -> Vec<String> {
let sessions = self.sessions.read().await;
sessions.keys().cloned().collect()
}
pub async fn stats(&self) -> PoolStats {
self.stats.read().await.clone()
}
pub async fn terminate(&self, session_id: &str) -> Result<()> {
let entry = {
let mut sessions = self.sessions.write().await;
sessions.remove(session_id)
};
if let Some(mut entry) = entry {
{
let mut by_target = self.by_target.write().await;
if let Some(ids) = by_target.get_mut(&entry.target) {
ids.retain(|id| id != session_id);
}
}
{
let mut stats = self.stats.write().await;
stats.active_sessions = stats.active_sessions.saturating_sub(1);
stats.total_terminated += 1;
if let Some(count) = stats.sessions_per_target.get_mut(&entry.target) {
*count = count.saturating_sub(1);
}
}
entry.session.terminate().await?;
info!(session_id = %session_id, "Session removed from pool");
} else {
warn!(session_id = %session_id, "Session not found in pool");
}
Ok(())
}
pub async fn terminate_target(&self, target: &str) -> Result<()> {
let session_ids = self.get_by_target(target).await;
for session_id in session_ids {
self.terminate(&session_id).await?;
}
Ok(())
}
pub async fn shutdown(&self) {
info!("Shutting down session pool");
self.shutdown.shutdown();
let session_ids: Vec<String> = {
let sessions = self.sessions.read().await;
sessions.keys().cloned().collect()
};
for session_id in session_ids {
if let Err(e) = self.terminate(&session_id).await {
warn!(session_id = %session_id, error = ?e, "Failed to terminate session during shutdown");
}
}
info!("Session pool shutdown complete");
}
pub fn shutdown_signal(&self) -> ShutdownSignal {
self.shutdown.clone()
}
pub async fn cleanup(&self) {
let mut to_remove = Vec::new();
{
let sessions = self.sessions.read().await;
for (id, entry) in sessions.iter() {
if entry.session.state().await == SessionState::Terminated {
to_remove.push(id.clone());
}
}
}
for session_id in to_remove {
debug!(session_id = %session_id, "Cleaning up terminated session");
let _ = self.terminate(&session_id).await;
}
}
}
/// Handle returned by `SessionPool::start_session*`.
///
/// It borrows the pool and refers to the session by id. Once the session has left
/// the pool, `send` returns `Error::InvalidState` and `output` returns `None`.
pub struct SessionHandle<'a> {
    session_id: String,
    pool: &'a SessionPool,
}
impl<'a> SessionHandle<'a> {
pub fn id(&self) -> &str {
&self.session_id
}
pub async fn send(&self, data: bytes::Bytes) -> Result<()> {
let sessions = self.pool.sessions.read().await;
if let Some(entry) = sessions.get(&self.session_id) {
entry.session.send(data).await
} else {
Err(Error::InvalidState("Session no longer in pool".to_string()))
}
}
pub async fn output(&self) -> Option<crate::channels::OutputStream> {
let sessions = self.pool.sessions.read().await;
sessions.get(&self.session_id).map(|e| e.session.output())
}
pub async fn terminate(self) -> Result<()> {
self.pool.terminate(&self.session_id).await
}
}
/// Handle returned by `SessionPool::get`.
///
/// Like `SessionHandle`, it refers to the session by id and looks it up in the
/// pool on each call. It has no `terminate`, but adds `is_ready` and `state` queries.
pub struct SessionRef<'a> {
    session_id: String,
    pool: &'a SessionPool,
}
impl<'a> SessionRef<'a> {
pub fn id(&self) -> &str {
&self.session_id
}
pub async fn send(&self, data: bytes::Bytes) -> Result<()> {
let sessions = self.pool.sessions.read().await;
if let Some(entry) = sessions.get(&self.session_id) {
entry.session.send(data).await
} else {
Err(Error::InvalidState("Session no longer in pool".to_string()))
}
}
pub async fn output(&self) -> Option<crate::channels::OutputStream> {
let sessions = self.pool.sessions.read().await;
sessions.get(&self.session_id).map(|e| e.session.output())
}
pub async fn is_ready(&self) -> bool {
let sessions = self.pool.sessions.read().await;
sessions
.get(&self.session_id)
.map(|e| e.session.is_ready())
.unwrap_or(false)
}
pub async fn state(&self) -> Option<SessionState> {
let sessions = self.pool.sessions.read().await;
if let Some(entry) = sessions.get(&self.session_id) {
Some(entry.session.state().await)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default pool config allows 100 sessions and duplicate targets.
    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_sessions, 100);
        assert!(config.allow_duplicate_targets);
    }

    /// Default stats start with every counter at zero and no per-target entries.
    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();
        assert_eq!(stats.active_sessions, 0);
        assert_eq!(stats.total_started, 0);
        assert_eq!(stats.total_terminated, 0);
        assert!(stats.sessions_per_target.is_empty());
    }
}