use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, RwLock, Semaphore};
use tokio::time::interval;
use uuid::Uuid;
use crate::agent::persistent::PersistentClaudeAgent;
use crate::agent::{Task, TaskResult};
use crate::config::ClaudeConfig;
use crate::identity::AgentRole;
use crate::session::worktree_session::{WorktreeSessionConfig, WorktreeSessionManager};
/// A single agent session held by the pool, together with its bookkeeping.
///
/// Cloning copies `metadata`, `stats` and `last_health_check`, but the clone
/// shares the same underlying agent because `agent` is an `Arc`.
#[derive(Debug, Clone)]
pub struct PooledSession {
/// The agent backing this session; locked for the duration of a task or health check.
pub agent: Arc<Mutex<PersistentClaudeAgent>>,
/// Identity and creation details recorded when the session was pooled.
pub metadata: SessionPoolMetadata,
/// Usage counters; `current_load` is what session selection reads.
pub stats: SessionUsageStats,
/// Set to `Instant::now()` when the session is created.
// NOTE(review): nothing in the visible code updates this after creation.
pub last_health_check: Instant,
}
/// Descriptive data attached to a pooled session at creation time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionPoolMetadata {
/// Identifier of the pooled session; a freshly generated UUID v4 string.
pub pool_id: String,
/// Agent identifier copied from the agent's identity.
pub agent_id: String,
/// Role the session was created for.
pub role: AgentRole,
/// UTC timestamp taken when the session was added to the pool.
pub created_at: DateTime<Utc>,
/// Generation counter; new sessions are created with `1`.
pub pool_generation: u64,
/// Priority score; new sessions are created with `1.0`.
pub priority_score: f64,
/// Copied from `SessionPoolConfig::max_concurrent_tasks_per_session` at creation.
pub max_concurrent_tasks: usize,
}
/// Per-session usage counters and load figures.
// NOTE(review): the visible pool code reads these values but never writes them
// after creation — confirm where (or whether) they are meant to be updated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionUsageStats {
/// Number of tasks run on the session.
pub total_tasks_executed: usize,
/// Number of tasks that succeeded.
pub successful_tasks: usize,
/// Number of tasks that failed.
pub failed_tasks: usize,
/// Accumulated execution time.
pub total_execution_time: Duration,
/// Mean execution time per task.
pub average_execution_time: Duration,
/// Current load; session selection treats values below `1.0` as having
/// capacity, and pool statistics count values above `0.0` as active.
pub current_load: f64,
/// Highest load observed.
pub peak_load: f64,
/// Time the session has been up.
pub uptime: Duration,
/// Timestamp of the most recent activity.
pub last_activity: DateTime<Utc>,
}
impl Default for SessionUsageStats {
    /// Returns zeroed counters, durations and loads, with `last_activity`
    /// stamped with the current UTC time.
    ///
    /// This is written by hand rather than derived because `last_activity`
    /// is initialised from the clock.
    fn default() -> Self {
        let no_time = Duration::ZERO;
        let idle = 0.0;

        Self {
            // Task counters.
            total_tasks_executed: 0,
            successful_tasks: 0,
            failed_tasks: 0,
            // Timing.
            total_execution_time: no_time,
            average_execution_time: no_time,
            // Load.
            current_load: idle,
            peak_load: idle,
            // Lifetime.
            uptime: no_time,
            last_activity: Utc::now(),
        }
    }
}
/// Tunables for a `SessionPool`.
#[derive(Debug, Clone)]
pub struct SessionPoolConfig {
/// Number of sessions eager warm-up creates for each essential role.
pub min_sessions_per_role: usize,
/// Once a role has this many sessions, callers wait for an existing one
/// instead of creating another. Also sizes the creation semaphore (x4).
pub max_sessions_per_role: usize,
/// Copied into each new session's metadata as `max_concurrent_tasks`.
pub max_concurrent_tasks_per_session: usize,
/// Period of the background health-check loop.
pub health_check_interval: Duration,
/// When sessions get created; only `Eager` triggers warm-up in `start`.
pub warmup_strategy: WarmupStrategy,
/// How an existing session is chosen for a task.
pub load_balancing: LoadBalancingStrategy,
/// Auto-scaling thresholds; `enabled` gates the background scaling task.
pub auto_scaling: AutoScalingConfig,
/// Whether `start` spawns the periodic metrics-logging task.
pub enable_performance_monitoring: bool,
}
/// Controls whether the pool creates sessions up front.
#[derive(Debug, Clone)]
pub enum WarmupStrategy {
    /// No warm-up on start; sessions are created when a task needs one.
    Lazy,
    /// Starting the pool pre-creates sessions for the essential roles.
    Eager,
    /// No dedicated handling is visible in this file; starting the pool
    /// performs no warm-up for this variant.
    Predictive,
}
/// How the pool picks one of a role's sessions that still has capacity.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Rotate through the sessions that currently have capacity.
    RoundRobin,
    /// Pick the session with the lowest current load.
    LeastLoaded,
    /// In the visible selection code this picks the first session with
    /// capacity; no random weighting is applied there.
    WeightedRandom,
    /// Pick the session with the best combined load/success/speed score.
    Adaptive,
}
/// Thresholds and cooldowns for per-role auto-scaling.
#[derive(Debug, Clone)]
pub struct AutoScalingConfig {
    /// Whether the background auto-scaling task is spawned at all.
    pub enabled: bool,
    /// Average role load above which the role is reported as highly loaded.
    pub scale_up_threshold: f64,
    /// Average role load below which the role is reported as lightly loaded.
    pub scale_down_threshold: f64,
    /// Cooldown associated with scaling up.
    pub scale_up_cooldown: Duration,
    /// Cooldown associated with scaling down.
    pub scale_down_cooldown: Duration,
    /// Load level the scaler is meant to aim for.
    pub target_load: f64,
}
impl Default for SessionPoolConfig {
    /// Default pool settings: 1 to 5 sessions per role, 3 concurrent tasks per
    /// session, 30-second health checks, lazy warm-up, least-loaded
    /// balancing, and both auto-scaling and performance monitoring enabled.
    fn default() -> Self {
        // Assemble the nested scaling section first so the outer literal stays flat.
        let auto_scaling = AutoScalingConfig {
            enabled: true,
            scale_up_threshold: 0.8,
            scale_down_threshold: 0.3,
            scale_up_cooldown: Duration::from_secs(60),
            scale_down_cooldown: Duration::from_secs(300),
            target_load: 0.6,
        };

        Self {
            min_sessions_per_role: 1,
            max_sessions_per_role: 5,
            max_concurrent_tasks_per_session: 3,
            health_check_interval: Duration::from_secs(30),
            warmup_strategy: WarmupStrategy::Lazy,
            load_balancing: LoadBalancingStrategy::LeastLoaded,
            auto_scaling,
            enable_performance_monitoring: true,
        }
    }
}
/// Pool of agent sessions grouped by role, with background maintenance tasks.
#[derive(Debug)]
pub struct SessionPool {
/// Sessions keyed by the lower-cased role name.
pools: Arc<RwLock<HashMap<String, Vec<PooledSession>>>>,
/// Source of agents for new sessions; shut down together with the pool.
worktree_manager: Arc<Mutex<WorktreeSessionManager>>,
/// Settings supplied at construction.
config: SessionPoolConfig,
/// Per-role cursor used by round-robin selection.
load_balancer_state: Arc<RwLock<HashMap<String, LoadBalancerState>>>,
/// Pool-wide task count and accumulated execution time.
performance_metrics: Arc<RwLock<PerformanceMetrics>>,
/// Handles of spawned background loops; aborted on shutdown.
background_tasks: Vec<tokio::task::JoinHandle<()>>,
/// A permit is held for the duration of each session creation.
creation_semaphore: Arc<Semaphore>,
/// Per-role scaling bookkeeping.
// NOTE(review): not read or written by the visible code beyond an `Arc` clone.
scaling_state: Arc<RwLock<HashMap<String, ScalingState>>>,
}
/// Per-role state kept between session selections.
#[derive(Debug, Clone)]
struct LoadBalancerState {
/// Round-robin cursor: index to use (modulo the candidate count) on the next pick.
last_selected: usize,
/// Initialised to `1.0` per session; not read by the visible code, hence the allow.
#[allow(dead_code)] selection_weights: Vec<f64>,
/// Starts empty; not read by the visible code, hence the allow.
#[allow(dead_code)] performance_history: VecDeque<f64>,
}
/// Pool-wide performance counters.
#[derive(Debug, Default)]
struct PerformanceMetrics {
/// Incremented once per recorded task result.
total_tasks_processed: usize,
/// Sum of the execution times recorded for those results.
total_execution_time: Duration,
/// Not accessed by the visible code, hence the allow.
#[allow(dead_code)] session_utilization: HashMap<String, f64>,
/// Not accessed by the visible code, hence the allow.
#[allow(dead_code)] error_rates: HashMap<String, f64>,
/// Not accessed by the visible code, hence the allow.
#[allow(dead_code)] throughput_history: VecDeque<(DateTime<Utc>, usize)>,
}
/// Per-role auto-scaling bookkeeping.
// NOTE(review): no value of this type is constructed or read in the visible
// code, which is why every field carries `#[allow(dead_code)]`.
#[derive(Debug, Clone)]
struct ScalingState {
/// Time of the most recent scale-up, if any.
#[allow(dead_code)] last_scale_up: Option<DateTime<Utc>>,
/// Time of the most recent scale-down, if any.
#[allow(dead_code)] last_scale_down: Option<DateTime<Utc>>,
/// Count of scaling operations still pending.
#[allow(dead_code)] pending_scale_operations: usize,
}
impl SessionPool {
/// Builds an empty pool around a freshly started worktree session manager.
///
/// The creation semaphore gets `max_sessions_per_role * 4` permits. No
/// background tasks are spawned here.
///
/// # Errors
/// Fails if the worktree session manager cannot be constructed or started.
pub async fn new(
    worktree_config: WorktreeSessionConfig,
    pool_config: SessionPoolConfig,
) -> Result<Self> {
    let mut manager = WorktreeSessionManager::new(worktree_config)?;
    manager.start().await?;

    // Read the one value needed from the config so the config itself can be
    // moved into the struct below.
    let creation_permits = pool_config.max_sessions_per_role * 4;

    Ok(Self {
        pools: Arc::new(RwLock::new(HashMap::new())),
        worktree_manager: Arc::new(Mutex::new(manager)),
        config: pool_config,
        load_balancer_state: Arc::new(RwLock::new(HashMap::new())),
        performance_metrics: Arc::new(RwLock::new(PerformanceMetrics::default())),
        background_tasks: Vec::new(),
        creation_semaphore: Arc::new(Semaphore::new(creation_permits)),
        scaling_state: Arc::new(RwLock::new(HashMap::new())),
    })
}
/// Spawns the pool's background loops and optionally warms up sessions.
///
/// The health-check loop always starts. The auto-scaling and
/// performance-monitoring loops start only when enabled in the config.
/// Sessions are pre-created only for `WarmupStrategy::Eager`.
///
/// # Errors
/// Propagates an error from the eager warm-up step.
pub async fn start(&mut self) -> Result<()> {
    tracing::info!("Starting session pool");

    let health_handle = self.start_health_check_task().await;
    self.background_tasks.push(health_handle);

    if self.config.auto_scaling.enabled {
        let scaling_handle = self.start_auto_scaling_task().await;
        self.background_tasks.push(scaling_handle);
    }

    if self.config.enable_performance_monitoring {
        let monitoring_handle = self.start_performance_monitoring_task().await;
        self.background_tasks.push(monitoring_handle);
    }

    if let WarmupStrategy::Eager = self.config.warmup_strategy {
        self.warmup_essential_sessions().await?;
    }

    tracing::info!("Session pool started successfully");
    Ok(())
}
/// Runs one task on a pooled session for `role` and records its metrics.
///
/// A session is picked (or created) first, then the agent is locked only
/// while the task runs. The recorded duration is measured from the start of
/// this call, so it includes the time spent obtaining a session.
///
/// # Errors
/// Fails if no session can be obtained or if the agent reports an error. In
/// the latter case no metrics are recorded, since there is no `TaskResult`.
pub async fn execute_task(
    &self,
    role: AgentRole,
    task: Task,
    claude_config: ClaudeConfig,
) -> Result<TaskResult> {
    let start_time = Instant::now();
    let session = self.get_optimal_session(&role, &claude_config).await?;

    let result = {
        let mut agent = session.agent.lock().await;
        // `task` is owned and not used again, so hand it over without cloning.
        agent.execute_task(task).await?
    };

    self.update_session_stats(&session, &result, start_time.elapsed())
        .await;
    Ok(result)
}
pub async fn execute_task_batch(
&self,
role: AgentRole,
tasks: Vec<Task>,
claude_config: ClaudeConfig,
) -> Result<Vec<TaskResult>> {
if tasks.is_empty() {
return Ok(Vec::new());
}
let start_time = Instant::now();
let session = self.get_optimal_session(&role, &claude_config).await?;
let results = {
let mut agent = session.agent.lock().await;
agent.execute_task_batch(tasks).await?
};
for result in &results {
self.update_session_stats(
&session,
result,
start_time.elapsed() / results.len() as u32,
)
.await;
}
tracing::info!("Batch of {} tasks completed in session pool", results.len());
Ok(results)
}
/// Returns a session for `role`: an existing one with capacity if there is
/// one, otherwise whatever `ensure_session_available` provides.
///
/// Pools are keyed by the lower-cased role name.
async fn get_optimal_session(
    &self,
    role: &AgentRole,
    claude_config: &ClaudeConfig,
) -> Result<Arc<PooledSession>> {
    let role_key = role.name().to_lowercase();

    let existing = self.find_available_session(&role_key).await;
    match existing {
        Some(session) => Ok(session),
        None => {
            self.ensure_session_available(&role_key, role.clone(), claude_config.clone())
                .await
        }
    }
}
/// Picks one of the role's sessions whose `current_load` is below `1.0`,
/// using the configured load-balancing strategy.
///
/// Returns `None` when the role has no pool, the pool is empty, or no
/// session has capacity. The returned `Arc` wraps a clone of the stored
/// session: it shares the same agent but carries its own copy of the
/// metadata and stats.
///
/// Strategies:
/// - `LeastLoaded`: lowest `current_load` (first one wins ties).
/// - `RoundRobin`: rotates a per-role cursor over the sessions with capacity.
/// - `Adaptive`: highest `calculate_session_score` (last one wins ties).
/// - `WeightedRandom`: currently the first session with capacity; no
///   randomness is applied.
async fn find_available_session(&self, role_key: &str) -> Option<Arc<PooledSession>> {
    /// Capacity predicate shared by every strategy.
    fn has_capacity(session: &&PooledSession) -> bool {
        session.stats.current_load < 1.0
    }

    let pools = self.pools.read().await;
    let sessions = pools.get(role_key)?;
    if sessions.is_empty() {
        return None;
    }

    let chosen: Option<&PooledSession> = match self.config.load_balancing {
        LoadBalancingStrategy::LeastLoaded => {
            sessions.iter().filter(has_capacity).min_by(|a, b| {
                a.stats
                    .current_load
                    .partial_cmp(&b.stats.current_load)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
        }
        LoadBalancingStrategy::RoundRobin => {
            let candidates: Vec<&PooledSession> =
                sessions.iter().filter(has_capacity).collect();
            if candidates.is_empty() {
                // Nothing to rotate over; skip the write lock entirely.
                return None;
            }

            let mut lb_state = self.load_balancer_state.write().await;
            // Build the initial state lazily, only on a role's first pick.
            let state = lb_state
                .entry(role_key.to_string())
                .or_insert_with(|| LoadBalancerState {
                    last_selected: 0,
                    selection_weights: vec![1.0; sessions.len()],
                    performance_history: VecDeque::new(),
                });

            // The candidate count can change between calls, so the cursor is
            // reduced modulo the current count both when read and when stored.
            let selected_idx = state.last_selected % candidates.len();
            state.last_selected = (state.last_selected + 1) % candidates.len();
            Some(candidates[selected_idx])
        }
        LoadBalancingStrategy::Adaptive => {
            sessions.iter().filter(has_capacity).max_by(|a, b| {
                let score_a = self.calculate_session_score(a);
                let score_b = self.calculate_session_score(b);
                score_a
                    .partial_cmp(&score_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
        }
        LoadBalancingStrategy::WeightedRandom => sessions.iter().find(has_capacity),
    };

    chosen.map(|session| Arc::new(session.clone()))
}
/// Scores a session for adaptive selection; higher is better.
///
/// The score is `0.4 * (1 - current_load) + 0.4 * success_rate +
/// 0.2 * min(speed, 2.0)`. `success_rate` is 1.0 when no tasks have run.
/// `speed` is `1000 / average_execution_time_ms`, or 1.0 when the average
/// rounds down to zero milliseconds.
fn calculate_session_score(&self, session: &PooledSession) -> f64 {
    let stats = &session.stats;

    let headroom = 1.0 - stats.current_load;

    let success_rate = match stats.total_tasks_executed {
        0 => 1.0,
        executed => stats.successful_tasks as f64 / executed as f64,
    };

    let average_millis = stats.average_execution_time.as_millis();
    let speed = if average_millis == 0 {
        1.0
    } else {
        1000.0 / average_millis as f64
    };

    (headroom * 0.4) + (success_rate * 0.4) + (speed.min(2.0) * 0.2)
}
async fn ensure_session_available(
&self,
role_key: &str,
role: AgentRole,
claude_config: ClaudeConfig,
) -> Result<Arc<PooledSession>> {
let current_count = {
let pools = self.pools.read().await;
pools.get(role_key).map(|p| p.len()).unwrap_or(0)
};
if current_count >= self.config.max_sessions_per_role {
tokio::time::timeout(
Duration::from_secs(30),
self.wait_for_available_session(role_key),
)
.await?
} else {
self.create_new_pooled_session(role_key, role, claude_config)
.await
}
}
/// Polls every 100 ms until a session for `role_key` has capacity.
///
/// This loop never gives up on its own, so callers must bound it (for
/// example with a timeout). The returned `Result` is always `Ok`.
async fn wait_for_available_session(&self, role_key: &str) -> Result<Arc<PooledSession>> {
    let mut poll_timer = interval(Duration::from_millis(100));
    loop {
        poll_timer.tick().await;
        let candidate = self.find_available_session(role_key).await;
        if let Some(session) = candidate {
            break Ok(session);
        }
    }
}
/// Creates a session for `role`, stores it in the role's pool and returns it.
///
/// A creation-semaphore permit is held for the whole call. The agent comes
/// from the worktree session manager. The pool keeps its own clone of the
/// session (sharing the same agent) while the caller receives an `Arc`
/// around the original value.
///
/// # Errors
/// Fails if the semaphore is closed or the worktree manager cannot supply an
/// agent.
async fn create_new_pooled_session(
    &self,
    role_key: &str,
    role: AgentRole,
    claude_config: ClaudeConfig,
) -> Result<Arc<PooledSession>> {
    let _permit = self.creation_semaphore.acquire().await?;
    tracing::info!("Creating new pooled session for role: {}", role_key);

    let agent = {
        let manager = self.worktree_manager.lock().await;
        manager
            .get_or_create_worktree_session(role.clone(), claude_config)
            .await?
    };

    // Read the agent id under a short-lived lock.
    let agent_id = agent.lock().await.identity.agent_id.clone();

    let session = PooledSession {
        agent,
        metadata: SessionPoolMetadata {
            pool_id: Uuid::new_v4().to_string(),
            agent_id,
            role,
            created_at: Utc::now(),
            pool_generation: 1,
            priority_score: 1.0,
            max_concurrent_tasks: self.config.max_concurrent_tasks_per_session,
        },
        stats: SessionUsageStats::default(),
        last_health_check: Instant::now(),
    };

    // Register a copy in the role's pool; the write guard is dropped at the
    // end of this statement.
    self.pools
        .write()
        .await
        .entry(role_key.to_string())
        .or_insert_with(Vec::new)
        .push(session.clone());

    let session = Arc::new(session);
    tracing::info!(
        "Created pooled session successfully: {}",
        session.metadata.agent_id
    );
    Ok(session)
}
/// Logs a finished task and adds it to the pool-wide metrics.
///
/// Only the global task count and total execution time are updated here.
/// The `stats` of `session` are not modified.
async fn update_session_stats(
    &self,
    session: &Arc<PooledSession>,
    result: &TaskResult,
    execution_time: Duration,
) {
    let agent_id = &session.metadata.agent_id;
    tracing::debug!(
        "Task completed in session {}: success={}, duration={:?}",
        agent_id,
        result.success,
        execution_time
    );

    let mut totals = self.performance_metrics.write().await;
    totals.total_execution_time += execution_time;
    totals.total_tasks_processed += 1;
}
async fn start_health_check_task(&self) -> tokio::task::JoinHandle<()> {
let pools = Arc::clone(&self.pools);
let interval_duration = self.config.health_check_interval;
tokio::spawn(async move {
let mut interval = interval(interval_duration);
loop {
interval.tick().await;
let pools_guard = pools.read().await;
for (role_key, sessions) in pools_guard.iter() {
for session in sessions {
if let Err(e) = Self::health_check_session(session).await {
tracing::warn!(
"Health check failed for session {} in role {}: {}",
session.metadata.agent_id,
role_key,
e
);
}
}
}
}
})
}
/// Checks that the session's agent reports itself as active.
///
/// This waits for the agent mutex, so it does not complete while the agent
/// is locked elsewhere.
///
/// # Errors
/// Returns "Session is not active" when the agent's stats say so.
async fn health_check_session(session: &PooledSession) -> Result<()> {
    let agent_guard = session.agent.lock().await;
    let session_stats = agent_guard.get_session_stats().await;

    if session_stats.is_active {
        Ok(())
    } else {
        Err(anyhow::anyhow!("Session is not active"))
    }
}
/// Spawns the loop that inspects each role's average load every 30 seconds.
///
/// At present the loop only logs: an info line when a role's load exceeds
/// `scale_up_threshold`, a debug line when it falls below
/// `scale_down_threshold`. It does not add or remove sessions.
async fn start_auto_scaling_task(&self) -> tokio::task::JoinHandle<()> {
    let pools = Arc::clone(&self.pools);
    // Cloned here as in the original code, but not used by the spawned loop.
    let _scaling_state = Arc::clone(&self.scaling_state);
    let config = self.config.clone();

    tokio::spawn(async move {
        let thresholds = &config.auto_scaling;
        let mut ticker = interval(Duration::from_secs(30));
        loop {
            ticker.tick().await;

            let pools_guard = pools.read().await;
            for (role_key, sessions) in pools_guard.iter() {
                let role_load = Self::calculate_role_load(sessions);
                let overloaded = role_load > thresholds.scale_up_threshold;
                let underloaded = role_load < thresholds.scale_down_threshold;

                if overloaded {
                    tracing::info!(
                        "High load detected for role {}: {:.2}",
                        role_key,
                        role_load
                    );
                } else if underloaded {
                    tracing::debug!(
                        "Low load detected for role {}: {:.2}",
                        role_key,
                        role_load
                    );
                }
            }
        }
    })
}
/// Mean `current_load` across `sessions`, or `0.0` for an empty slice.
fn calculate_role_load(sessions: &[PooledSession]) -> f64 {
    match sessions {
        [] => 0.0,
        non_empty => {
            let summed = non_empty
                .iter()
                .map(|session| session.stats.current_load)
                .sum::<f64>();
            summed / non_empty.len() as f64
        }
    }
}
/// Spawns the loop that logs the pool-wide task count and average execution
/// time every 60 seconds.
///
/// The average is reported as zero until at least one task has been
/// recorded.
async fn start_performance_monitoring_task(&self) -> tokio::task::JoinHandle<()> {
    let performance_metrics = Arc::clone(&self.performance_metrics);

    tokio::spawn(async move {
        let mut ticker = interval(Duration::from_secs(60));
        loop {
            ticker.tick().await;

            // Copy the two counters out and release the lock before logging.
            let metrics = performance_metrics.read().await;
            let total_tasks = metrics.total_tasks_processed;
            let total_time = metrics.total_execution_time;
            drop(metrics);

            // Clamp instead of truncating so a huge count cannot wrap to a
            // zero divisor; `checked_div` turns a zero count into `None`.
            let divisor = total_tasks.min(u32::MAX as usize) as u32;
            let average = total_time.checked_div(divisor).unwrap_or(Duration::ZERO);

            tracing::info!(
                "Performance metrics - Total tasks: {}, Avg execution time: {:?}",
                total_tasks,
                average
            );
        }
    })
}
/// Pre-creates `min_sessions_per_role` sessions for the default frontend and
/// backend roles, each with a default `ClaudeConfig`.
///
/// This is best-effort: a failed creation is logged as a warning and
/// skipped, so the function always returns `Ok(())`.
async fn warmup_essential_sessions(&self) -> Result<()> {
    use crate::identity::{default_backend_role, default_frontend_role};

    tracing::info!("Warming up essential sessions");

    for role in [default_frontend_role(), default_backend_role()] {
        let role_key = role.name().to_lowercase();
        for _ in 0..self.config.min_sessions_per_role {
            let outcome = self
                .create_new_pooled_session(&role_key, role.clone(), ClaudeConfig::default())
                .await;
            if let Err(e) = outcome {
                tracing::warn!("Failed to warmup session for role {}: {}", role.name(), e);
            }
        }
    }

    Ok(())
}
/// Produces a snapshot of per-role and pool-wide statistics.
///
/// A session counts as active when its `current_load` is above `0.0`. The
/// global average execution time is zero until at least one task has been
/// recorded.
pub async fn get_pool_statistics(&self) -> PoolStatistics {
    let pools = self.pools.read().await;
    let performance_metrics = self.performance_metrics.read().await;

    let mut role_statistics = HashMap::with_capacity(pools.len());
    let mut total_sessions = 0;
    let mut active_sessions = 0;

    for (role_key, sessions) in pools.iter() {
        let role_active = sessions
            .iter()
            .filter(|s| s.stats.current_load > 0.0)
            .count();

        total_sessions += sessions.len();
        active_sessions += role_active;

        role_statistics.insert(
            role_key.clone(),
            RoleStatistics {
                total_sessions: sessions.len(),
                active_sessions: role_active,
                average_load: Self::calculate_role_load(sessions),
                total_tasks: sessions.iter().map(|s| s.stats.total_tasks_executed).sum(),
            },
        );
    }

    let total_tasks_processed = performance_metrics.total_tasks_processed;
    let total_execution_time = performance_metrics.total_execution_time;
    // Clamp instead of truncating so a huge count cannot wrap to a zero
    // divisor; `checked_div` turns a zero count into `None`.
    let divisor = total_tasks_processed.min(u32::MAX as usize) as u32;
    let average_execution_time = total_execution_time
        .checked_div(divisor)
        .unwrap_or(Duration::ZERO);

    PoolStatistics {
        total_sessions,
        active_sessions,
        role_statistics,
        global_performance: GlobalPerformanceStats {
            total_tasks_processed,
            total_execution_time,
            average_execution_time,
        },
    }
}
/// Stops the pool: aborts the background loops, shuts down the worktree
/// session manager, then drops every pooled session.
///
/// # Errors
/// Propagates a failure from the worktree manager's shutdown. In that case
/// the function returns before the pools are cleared, although the
/// background loops have already been aborted.
pub async fn shutdown(&mut self) -> Result<()> {
    tracing::info!("Shutting down session pool");

    self.background_tasks
        .drain(..)
        .for_each(|handle| handle.abort());

    // Each guard below lives only for its own statement.
    self.worktree_manager.lock().await.shutdown().await?;
    self.pools.write().await.clear();

    tracing::info!("Session pool shutdown complete");
    Ok(())
}
}
/// Snapshot of the whole pool as returned by `SessionPool::get_pool_statistics`.
#[derive(Debug, Serialize, Deserialize)]
pub struct PoolStatistics {
/// Number of sessions across all roles.
pub total_sessions: usize,
/// Number of sessions whose current load is above zero.
pub active_sessions: usize,
/// Per-role breakdown, keyed by the pool's role key.
pub role_statistics: HashMap<String, RoleStatistics>,
/// Pool-wide task totals and average execution time.
pub global_performance: GlobalPerformanceStats,
}
/// Statistics for the sessions of a single role.
#[derive(Debug, Serialize, Deserialize)]
pub struct RoleStatistics {
/// Number of sessions pooled for the role.
pub total_sessions: usize,
/// Number of those sessions whose current load is above zero.
pub active_sessions: usize,
/// Mean current load across the role's sessions (0.0 when there are none).
pub average_load: f64,
/// Sum of `total_tasks_executed` over the role's sessions.
pub total_tasks: usize,
}
/// Pool-wide performance totals.
#[derive(Debug, Serialize, Deserialize)]
pub struct GlobalPerformanceStats {
/// Number of task results recorded by the pool.
pub total_tasks_processed: usize,
/// Sum of the execution times recorded for those results.
pub total_execution_time: Duration,
/// `total_execution_time` divided by `total_tasks_processed`; zero when no tasks were recorded.
pub average_execution_time: Duration,
}