use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
use crate::agent_lifecycle_manager::HotReloadManager;
use crate::domain::hot_reload::ResourceUsageSnapshot;
#[allow(unused_imports)]
use crate::domain::{
AgentVersion, HotReloadConfig, HotReloadError, HotReloadId, HotReloadRequest, HotReloadResult,
HotReloadStatus, HotReloadStrategy, ReloadMetrics, TrafficSplitPercentage, VersionNumber,
VersionSnapshot,
};
use crate::domain_types::AgentId;
use crate::time_provider::{SharedTimeProvider, production_time_provider};
/// Per-reload working state threaded through a hot reload execution.
/// Created in `hot_reload_agent`, mutated by the strategy executors, and
/// mirrored into `active_reloads` so status queries can observe progress.
#[derive(Debug, Clone)]
struct HotReloadContext {
// The validated request that initiated this reload.
pub request: HotReloadRequest,
// Wall-clock start time, reported back in the final `HotReloadResult`.
pub started_at: SystemTime,
// Current lifecycle phase; surfaced via `get_hot_reload_status`.
pub status: HotReloadStatus,
// Rolling metrics sampled from the new instance during the reload.
pub metrics: ReloadMetrics,
// Traffic percentage currently routed to the new version.
pub current_traffic_split: TrafficSplitPercentage,
// Snapshots captured during this reload (also kept in the global map).
pub version_snapshots: Vec<VersionSnapshot>,
// Set once the configured warmup period has elapsed.
pub warmup_completed: bool,
}
/// Bookkeeping record for a single running agent instance.
/// NOTE(review): not referenced anywhere in this file (`dead_code` allowed);
/// presumably reserved for future instance tracking — confirm before removal.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct AgentInstance {
pub agent_id: AgentId,
pub version: AgentVersion,
// Whether this instance currently receives traffic.
pub is_active: bool,
// Bytes of memory currently allocated to the instance.
pub memory_usage: usize,
pub fuel_consumed: u64,
pub requests_handled: u64,
pub created_at: SystemTime,
}
/// Serialized agent state captured before a graceful reload so it can be
/// restored into the new version after it passes its health check.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct PreservedState {
pub agent_id: AgentId,
// Opaque state bytes produced by `RuntimeManager::preserve_state`.
pub state_data: Vec<u8>,
pub preserved_at: SystemTime,
}
/// Routing decision between the old and new version of an agent.
/// NOTE(review): not referenced anywhere in this file (`dead_code` allowed);
/// confirm it is still needed by the traffic-routing design.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct TrafficDecision {
pub route_to_new_version: bool,
// Weights are percentages; callers are expected to keep them consistent.
pub old_version_weight: u8,
pub new_version_weight: u8,
}
/// Abstraction over the WASM runtime used during hot reloads: instance
/// lifecycle, metrics sampling, and state transfer between versions.
#[async_trait::async_trait]
pub trait RuntimeManager {
/// Instantiates `wasm_bytes` as a new instance of `agent_id` at `version`.
async fn create_instance(
&self,
agent_id: AgentId,
version: AgentVersion,
wasm_bytes: &[u8],
) -> Result<(), HotReloadError>;
/// Stops and tears down the instance of `agent_id` running `version`.
async fn stop_instance(
&self,
agent_id: AgentId,
version: AgentVersion,
) -> Result<(), HotReloadError>;
/// Returns `(memory_allocated, fuel_consumed, requests_handled)` for the
/// instance (see how `create_version_snapshot` consumes the tuple).
async fn get_instance_metrics(
&self,
agent_id: AgentId,
version: AgentVersion,
) -> Result<(usize, u64, u64), HotReloadError>;
/// Serializes the instance's state into opaque bytes for later restore.
async fn preserve_state(
&self,
agent_id: AgentId,
version: AgentVersion,
) -> Result<Vec<u8>, HotReloadError>;
/// Restores previously preserved state bytes into the given instance.
async fn restore_state(
&self,
agent_id: AgentId,
version: AgentVersion,
state_data: &[u8],
) -> Result<(), HotReloadError>;
/// Returns `true` when the instance reports healthy.
async fn health_check(
&self,
agent_id: AgentId,
version: AgentVersion,
) -> Result<bool, HotReloadError>;
}
/// Abstraction over the traffic layer that shifts requests between the
/// old and new version of an agent during a reload.
#[async_trait::async_trait]
pub trait TrafficRouter {
/// Routes `split_percentage` of traffic to `new_version`, the rest to
/// `old_version`.
async fn set_traffic_split(
&self,
agent_id: AgentId,
old_version: AgentVersion,
new_version: AgentVersion,
split_percentage: TrafficSplitPercentage,
) -> Result<(), HotReloadError>;
/// Returns the percentage of traffic currently routed to the new version.
async fn get_traffic_split(
&self,
agent_id: AgentId,
) -> Result<TrafficSplitPercentage, HotReloadError>;
/// Atomically routes 100% of traffic to `target_version`.
async fn switch_traffic(
&self,
agent_id: AgentId,
target_version: AgentVersion,
) -> Result<(), HotReloadError>;
}
/// Production hot-reload manager that coordinates a runtime backend and a
/// traffic router across four reload strategies (graceful, immediate,
/// parallel, traffic-splitting).
pub struct CaxtonHotReloadManager {
// Reloads currently executing, keyed by reload id.
active_reloads: Arc<RwLock<HashMap<HotReloadId, HotReloadContext>>>,
// Per-agent history of version snapshots used for rollback.
version_snapshots: Arc<RwLock<HashMap<AgentId, Vec<VersionSnapshot>>>>,
// Last preserved state blob per agent (graceful strategy only).
preserved_states: Arc<Mutex<HashMap<AgentId, PreservedState>>>,
// Runtime backend that creates/stops instances and moves state.
runtime_manager: Arc<dyn RuntimeManager + Send + Sync>,
// Router that shifts traffic between old and new versions.
traffic_router: Arc<dyn TrafficRouter + Send + Sync>,
// Clock/sleep abstraction; injectable for deterministic tests.
time_provider: SharedTimeProvider,
// Maximum number of reloads allowed to run simultaneously.
max_concurrent_reloads: usize,
// Overall per-reload deadline enforced in `hot_reload_agent`.
default_timeout: Duration,
}
impl CaxtonHotReloadManager {
/// Creates a manager with the production time provider and default limits
/// (see `with_time_provider`).
pub fn new(
runtime_manager: Arc<dyn RuntimeManager + Send + Sync>,
traffic_router: Arc<dyn TrafficRouter + Send + Sync>,
) -> Self {
Self::with_time_provider(runtime_manager, traffic_router, production_time_provider())
}
pub fn with_time_provider(
runtime_manager: Arc<dyn RuntimeManager + Send + Sync>,
traffic_router: Arc<dyn TrafficRouter + Send + Sync>,
time_provider: SharedTimeProvider,
) -> Self {
Self {
active_reloads: Arc::new(RwLock::new(HashMap::new())),
version_snapshots: Arc::new(RwLock::new(HashMap::new())),
preserved_states: Arc::new(Mutex::new(HashMap::new())),
runtime_manager,
traffic_router,
time_provider,
max_concurrent_reloads: 5,
default_timeout: Duration::from_secs(300), }
}
/// Creates a manager with custom concurrency and timeout limits, using
/// the production time provider.
pub fn with_limits(
runtime_manager: Arc<dyn RuntimeManager + Send + Sync>,
traffic_router: Arc<dyn TrafficRouter + Send + Sync>,
max_concurrent: usize,
timeout: Duration,
) -> Self {
Self::with_limits_and_time_provider(
runtime_manager,
traffic_router,
max_concurrent,
timeout,
production_time_provider(),
)
}
/// Fully parameterized constructor: custom limits plus an injectable
/// time provider. All other constructors funnel their defaults here.
pub fn with_limits_and_time_provider(
runtime_manager: Arc<dyn RuntimeManager + Send + Sync>,
traffic_router: Arc<dyn TrafficRouter + Send + Sync>,
max_concurrent: usize,
timeout: Duration,
time_provider: SharedTimeProvider,
) -> Self {
Self {
active_reloads: Arc::new(RwLock::new(HashMap::new())),
version_snapshots: Arc::new(RwLock::new(HashMap::new())),
preserved_states: Arc::new(Mutex::new(HashMap::new())),
runtime_manager,
traffic_router,
time_provider,
max_concurrent_reloads: max_concurrent,
default_timeout: timeout,
}
}
/// Rejects a new reload once `max_concurrent_reloads` are already active.
async fn check_reload_limit(&self) -> Result<(), HotReloadError> {
    // Read lock is dropped immediately after taking the count.
    let in_flight = self.active_reloads.read().await.len();
    if in_flight < self.max_concurrent_reloads {
        Ok(())
    } else {
        Err(HotReloadError::InsufficientResources)
    }
}
/// Dispatches the reload to the executor matching the requested strategy.
/// Takes ownership of the context; the per-strategy executors mutate it
/// and produce the final `HotReloadResult`.
async fn execute_hot_reload_strategy(
&self,
mut context: HotReloadContext,
) -> Result<HotReloadResult, HotReloadError> {
match context.request.config.strategy {
HotReloadStrategy::Graceful => self.execute_graceful_reload(&mut context).await,
HotReloadStrategy::Immediate => self.execute_immediate_reload(&mut context).await,
HotReloadStrategy::Parallel => self.execute_parallel_reload(&mut context).await,
HotReloadStrategy::TrafficSplitting => {
self.execute_traffic_splitting_reload(&mut context).await
}
}
}
async fn execute_graceful_reload(
&self,
context: &mut HotReloadContext,
) -> Result<HotReloadResult, HotReloadError> {
info!(
"Executing graceful hot reload for agent {}",
context.request.agent_id
);
let agent_id = context.request.agent_id;
let old_version = context.request.from_version;
let new_version = context.request.to_version;
context.status = HotReloadStatus::Preparing;
self.create_version_snapshot(agent_id, old_version, context)
.await?;
if context.request.preserve_state {
let state_data = self
.runtime_manager
.preserve_state(agent_id, old_version)
.await?;
let preserved_state = PreservedState {
agent_id,
state_data,
preserved_at: SystemTime::now(),
};
let mut states = self.preserved_states.lock().await;
states.insert(agent_id, preserved_state);
}
context.status = HotReloadStatus::Starting;
self.runtime_manager
.create_instance(agent_id, new_version, &context.request.new_wasm_module)
.await?;
if context.request.config.warmup_duration > Duration::from_secs(0) {
info!(
"Warming up new version for {:?}",
context.request.config.warmup_duration
);
self.time_provider
.sleep(context.request.config.warmup_duration)
.await;
}
context.warmup_completed = true;
if !self
.runtime_manager
.health_check(agent_id, new_version)
.await?
{
return Err(HotReloadError::StatePreservationFailed {
reason: "New version failed health check".to_string(),
});
}
if context.request.preserve_state {
let states = self.preserved_states.lock().await;
if let Some(preserved) = states.get(&agent_id) {
self.runtime_manager
.restore_state(agent_id, new_version, &preserved.state_data)
.await?;
}
}
context.status = HotReloadStatus::InProgress;
if self.time_provider.should_skip_delays() {
debug!("Skipping drain wait in test mode");
} else {
info!(
"Draining old version for {:?}",
context.request.config.drain_timeout.as_duration()
);
let drain_start = self.time_provider.instant();
let drain_duration = context.request.config.drain_timeout.as_duration();
while drain_start.elapsed() < drain_duration {
self.time_provider.sleep(Duration::from_millis(100)).await;
}
}
self.traffic_router
.switch_traffic(agent_id, new_version)
.await?;
self.runtime_manager
.stop_instance(agent_id, old_version)
.await?;
context.status = HotReloadStatus::Completed;
self.update_metrics(context).await?;
Ok(HotReloadResult::success(
context.request.reload_id,
agent_id,
old_version,
new_version,
context.started_at,
Some(context.metrics.clone()),
context.version_snapshots.clone(),
))
}
/// Immediate strategy: snapshot the old version, start the new instance,
/// then switch traffic and stop the old instance right away. The health
/// check runs only AFTER the switch, so a failure cannot be rolled back
/// here — it is logged and the reload still completes.
async fn execute_immediate_reload(
&self,
context: &mut HotReloadContext,
) -> Result<HotReloadResult, HotReloadError> {
info!(
"Executing immediate hot reload for agent {}",
context.request.agent_id
);
let agent_id = context.request.agent_id;
let old_version = context.request.from_version;
let new_version = context.request.to_version;
context.status = HotReloadStatus::Preparing;
self.create_version_snapshot(agent_id, old_version, context)
.await?;
context.status = HotReloadStatus::Starting;
self.runtime_manager
.create_instance(agent_id, new_version, &context.request.new_wasm_module)
.await?;
context.status = HotReloadStatus::InProgress;
// Traffic is cut over before any health verification — that is the
// deliberate "immediate" trade-off.
self.traffic_router
.switch_traffic(agent_id, new_version)
.await?;
self.runtime_manager
.stop_instance(agent_id, old_version)
.await?;
if !self
.runtime_manager
.health_check(agent_id, new_version)
.await?
{
// The old instance is already stopped; nothing to fall back to.
warn!("New version failed health check, but immediate reload cannot rollback");
}
context.status = HotReloadStatus::Completed;
self.update_metrics(context).await?;
Ok(HotReloadResult::success(
context.request.reload_id,
agent_id,
old_version,
new_version,
context.started_at,
Some(context.metrics.clone()),
context.version_snapshots.clone(),
))
}
/// Parallel strategy: run both versions side by side, monitor the new
/// version's metrics for a fixed window (rolling back automatically on a
/// threshold breach), then switch traffic and stop the old instance.
async fn execute_parallel_reload(
    &self,
    context: &mut HotReloadContext,
) -> Result<HotReloadResult, HotReloadError> {
    info!(
        "Executing parallel hot reload for agent {}",
        context.request.agent_id
    );
    let agent_id = context.request.agent_id;
    let old_version = context.request.from_version;
    let new_version = context.request.to_version;
    context.status = HotReloadStatus::Preparing;
    self.create_version_snapshot(agent_id, old_version, context)
        .await?;
    context.status = HotReloadStatus::Starting;
    self.runtime_manager
        .create_instance(agent_id, new_version, &context.request.new_wasm_module)
        .await?;
    if context.request.config.warmup_duration > Duration::from_secs(0) {
        self.time_provider
            .sleep(context.request.config.warmup_duration)
            .await;
    }
    context.warmup_completed = true;
    if !self
        .runtime_manager
        .health_check(agent_id, new_version)
        .await?
    {
        // Clean up the unhealthy instance before aborting.
        self.runtime_manager
            .stop_instance(agent_id, new_version)
            .await?;
        return Err(HotReloadError::AutomaticRollback {
            reason: "New version failed health check".to_string(),
        });
    }
    context.status = HotReloadStatus::InProgress;
    let monitoring_duration = if self.time_provider.should_skip_delays() {
        Duration::from_millis(1)
    } else {
        Duration::from_secs(60)
    };
    // Loop-invariant: compute the poll interval once, not every iteration.
    let check_interval = if self.time_provider.should_skip_delays() {
        Duration::from_millis(1)
    } else {
        Duration::from_secs(5)
    };
    // Fix: use the time provider's monotonic instant. The previous
    // SystemTime-based loop was non-monotonic (wall clock can jump) and
    // ignored injected test time providers, unlike the graceful drain loop.
    let monitor_start = self.time_provider.instant();
    while monitor_start.elapsed() < monitoring_duration {
        self.update_metrics(context).await?;
        if context
            .request
            .config
            .rollback_capability
            .should_trigger_rollback(&context.metrics)
        {
            warn!("Automatic rollback triggered during parallel execution");
            self.runtime_manager
                .stop_instance(agent_id, new_version)
                .await?;
            return Err(HotReloadError::AutomaticRollback {
                reason: "Metrics triggered automatic rollback".to_string(),
            });
        }
        self.time_provider.sleep(check_interval).await;
    }
    self.traffic_router
        .switch_traffic(agent_id, new_version)
        .await?;
    self.runtime_manager
        .stop_instance(agent_id, old_version)
        .await?;
    context.status = HotReloadStatus::Completed;
    Ok(HotReloadResult::success(
        context.request.reload_id,
        agent_id,
        old_version,
        new_version,
        context.started_at,
        Some(context.metrics.clone()),
        context.version_snapshots.clone(),
    ))
}
/// Traffic-splitting strategy, decomposed into three phases: prepare the
/// new instance, gradually shift traffic (with monitoring and automatic
/// rollback), then finalize by switching all traffic and stopping the old
/// instance.
async fn execute_traffic_splitting_reload(
&self,
context: &mut HotReloadContext,
) -> Result<HotReloadResult, HotReloadError> {
self.prepare_traffic_splitting_reload(context).await?;
self.execute_gradual_traffic_split(context).await?;
self.finalize_traffic_splitting_reload(context).await
}
/// Phase 1 of traffic splitting: snapshot the old version, start and warm
/// up the new instance, and abort (stopping the new instance) if it fails
/// its initial health check.
async fn prepare_traffic_splitting_reload(
&self,
context: &mut HotReloadContext,
) -> Result<(), HotReloadError> {
info!(
"Executing traffic splitting hot reload for agent {}",
context.request.agent_id
);
let agent_id = context.request.agent_id;
let old_version = context.request.from_version;
let new_version = context.request.to_version;
context.status = HotReloadStatus::Preparing;
self.create_version_snapshot(agent_id, old_version, context)
.await?;
context.status = HotReloadStatus::Starting;
self.runtime_manager
.create_instance(agent_id, new_version, &context.request.new_wasm_module)
.await?;
if context.request.config.warmup_duration > Duration::from_secs(0) {
self.time_provider
.sleep(context.request.config.warmup_duration)
.await;
}
context.warmup_completed = true;
if !self
.runtime_manager
.health_check(agent_id, new_version)
.await?
{
// Clean up the unhealthy instance before aborting.
self.runtime_manager
.stop_instance(agent_id, new_version)
.await?;
return Err(HotReloadError::AutomaticRollback {
reason: "New version failed initial health check".to_string(),
});
}
context.status = HotReloadStatus::InProgress;
Ok(())
}
/// Phase 2 of traffic splitting: walk traffic through each percentage
/// step, monitoring at every step. Progressive rollout ramps through a
/// fixed ladder; otherwise a single jump to the configured split is made.
async fn execute_gradual_traffic_split(
    &self,
    context: &mut HotReloadContext,
) -> Result<(), HotReloadError> {
    let steps: Vec<u8> = if context.request.config.progressive_rollout {
        vec![5, 10, 25, 50, 75, 100]
    } else {
        vec![context.request.config.traffic_split.as_percentage()]
    };
    let total = steps.len();
    for (index, &pct) in steps.iter().enumerate() {
        self.execute_traffic_split_step(context, pct, index, total)
            .await?;
    }
    Ok(())
}
/// Applies a single traffic-split step: routes `percentage` of traffic to
/// the new version, records it on the context, then monitors the step
/// (which may abort the whole reload via automatic rollback).
async fn execute_traffic_split_step(
&self,
context: &mut HotReloadContext,
percentage: u8,
step_index: usize,
total_steps: usize,
) -> Result<(), HotReloadError> {
let agent_id = context.request.agent_id;
let old_version = context.request.from_version;
let new_version = context.request.to_version;
// `try_new` validates the percentage range; out-of-range is a config bug.
let split_percentage = TrafficSplitPercentage::try_new(percentage).map_err(|_| {
HotReloadError::TrafficSplittingFailed {
reason: "Invalid traffic percentage".to_string(),
}
})?;
info!("Setting traffic split to {}% for new version", percentage);
self.traffic_router
.set_traffic_split(agent_id, old_version, new_version, split_percentage)
.await?;
context.current_traffic_split = split_percentage;
self.monitor_traffic_split_step(context, percentage).await?;
// Only log progression for intermediate steps, not the final one.
if step_index < total_steps - 1 {
info!(
"Traffic split at {}% successful, proceeding to next step",
percentage
);
}
Ok(())
}
async fn monitor_traffic_split_step(
&self,
context: &mut HotReloadContext,
percentage: u8,
) -> Result<(), HotReloadError> {
let monitor_duration = if self.time_provider.should_skip_delays() {
Duration::from_millis(1)
} else {
Duration::from_secs(30)
};
let step_start = SystemTime::now();
while step_start.elapsed().unwrap_or_default() < monitor_duration {
self.update_metrics(context).await?;
if context
.request
.config
.rollback_capability
.should_trigger_rollback(&context.metrics)
{
self.handle_automatic_rollback(context, percentage).await?;
}
let check_interval = if self.time_provider.should_skip_delays() {
Duration::from_millis(1)
} else {
Duration::from_secs(5)
};
self.time_provider.sleep(check_interval).await;
}
Ok(())
}
/// Reverts all traffic to the old version and stops the new instance
/// after a rollback trigger during traffic splitting.
///
/// Always returns `Err(AutomaticRollback)` — the caller's `?` relies on
/// this to abort the reload. Router/runtime failures during cleanup are
/// propagated instead.
async fn handle_automatic_rollback(
    &self,
    context: &mut HotReloadContext,
    percentage: u8,
) -> Result<(), HotReloadError> {
    let agent_id = context.request.agent_id;
    let old_version = context.request.from_version;
    let new_version = context.request.to_version;
    warn!("Automatic rollback triggered at {}% traffic", percentage);
    // 0% is always within the valid range, so construction cannot fail;
    // `expect` documents that invariant instead of a bare `unwrap`.
    let zero_split = TrafficSplitPercentage::try_new(0)
        .expect("0 is a valid traffic split percentage");
    self.traffic_router
        .set_traffic_split(agent_id, old_version, new_version, zero_split)
        .await?;
    self.runtime_manager
        .stop_instance(agent_id, new_version)
        .await?;
    Err(HotReloadError::AutomaticRollback {
        reason: format!("Rollback triggered at {percentage}% traffic"),
    })
}
/// Phase 3 of traffic splitting: route 100% of traffic to the new
/// version, stop the old instance, and produce the success result.
async fn finalize_traffic_splitting_reload(
&self,
context: &mut HotReloadContext,
) -> Result<HotReloadResult, HotReloadError> {
let agent_id = context.request.agent_id;
let old_version = context.request.from_version;
let new_version = context.request.to_version;
self.traffic_router
.switch_traffic(agent_id, new_version)
.await?;
self.runtime_manager
.stop_instance(agent_id, old_version)
.await?;
context.status = HotReloadStatus::Completed;
Ok(HotReloadResult::success(
context.request.reload_id,
agent_id,
old_version,
new_version,
context.started_at,
Some(context.metrics.clone()),
context.version_snapshots.clone(),
))
}
/// Captures a resource-usage snapshot for `agent_id` at `version` and
/// records it both on the reload context and in the global per-agent
/// history, trimming that history to the configured retention count.
///
/// NOTE(review): the snapshot pairs the OLD version id with the NEW wasm
/// module bytes and target version number — the old module's bytes are
/// not available here. Confirm this is intended before relying on these
/// snapshots for byte-exact rollback.
async fn create_version_snapshot(
    &self,
    agent_id: AgentId,
    version: AgentVersion,
    context: &mut HotReloadContext,
) -> Result<(), HotReloadError> {
    debug!(
        "Creating version snapshot for agent {} version {}",
        agent_id, version
    );
    let (memory, fuel, requests) = self
        .runtime_manager
        .get_instance_metrics(agent_id, version)
        .await?;
    let snapshot = VersionSnapshot {
        version,
        version_number: context.request.to_version_number,
        wasm_module: context.request.new_wasm_module.clone(),
        created_at: SystemTime::now(),
        resource_usage: ResourceUsageSnapshot {
            memory_allocated: memory,
            fuel_consumed: fuel,
            requests_handled: requests,
            // Placeholder until per-request latency is actually measured.
            average_response_time_ms: 100,
        },
    };
    context.version_snapshots.push(snapshot.clone());
    let mut snapshots = self.version_snapshots.write().await;
    // `or_default()` is the idiomatic form of `or_insert_with(Vec::new)`.
    let agent_snapshots = snapshots.entry(agent_id).or_default();
    agent_snapshots.push(snapshot);
    let max_snapshots = context
        .request
        .config
        .rollback_capability
        .preserve_previous_versions as usize;
    // Retain only the most recent `max_snapshots` entries.
    if agent_snapshots.len() > max_snapshots {
        agent_snapshots.drain(0..agent_snapshots.len() - max_snapshots);
    }
    Ok(())
}
/// Samples memory, request count, and health from the NEW version's
/// instance into the context's metrics. Failures to fetch metrics or run
/// the health check are deliberately swallowed (best effort) so that
/// monitoring can never abort a reload on its own.
async fn update_metrics(&self, context: &mut HotReloadContext) -> Result<(), HotReloadError> {
let agent_id = context.request.agent_id;
let new_version = context.request.to_version;
if let Ok((memory, _fuel, requests)) = self
.runtime_manager
.get_instance_metrics(agent_id, new_version)
.await
{
// Track the high-water mark of memory, not the latest sample.
context.metrics.memory_usage_peak = context.metrics.memory_usage_peak.max(memory);
context.metrics.requests_processed = requests;
if let Ok(healthy) = self
.runtime_manager
.health_check(agent_id, new_version)
.await
{
// Binary health signal; there is no partial success rate available.
context.metrics.health_check_success_rate = if healthy { 100.0 } else { 0.0 };
}
context.metrics.collected_at = SystemTime::now();
}
Ok(())
}
}
#[async_trait::async_trait]
impl HotReloadManager for CaxtonHotReloadManager {
/// Entry point: validates the request, registers the reload, executes the
/// configured strategy under the manager's timeout, and deregisters the
/// reload before returning the outcome.
async fn hot_reload_agent(
    &self,
    request: HotReloadRequest,
) -> std::result::Result<HotReloadResult, HotReloadError> {
    info!(
        "Starting hot reload for agent {} with strategy {:?}",
        request.agent_id, request.config.strategy
    );
    self.check_reload_limit().await?;
    request.validate()?;
    let context = HotReloadContext {
        request: request.clone(),
        started_at: SystemTime::now(),
        status: HotReloadStatus::Pending,
        metrics: ReloadMetrics::new(),
        current_traffic_split: request.config.traffic_split,
        version_snapshots: Vec::new(),
        warmup_completed: false,
    };
    {
        let mut active = self.active_reloads.write().await;
        active.insert(request.reload_id, context.clone());
    }
    // Convert a timeout into a HotReloadError WITHOUT returning early:
    // the previous `?` on timeout skipped the `active_reloads.remove`
    // below, leaking the entry and permanently consuming one of the
    // `max_concurrent_reloads` slots.
    let result = match timeout(
        self.default_timeout,
        self.execute_hot_reload_strategy(context),
    )
    .await
    {
        Ok(inner) => inner,
        Err(_) => Err(HotReloadError::TimeoutExceeded {
            timeout: u64::try_from(self.default_timeout.as_millis()).unwrap_or(u64::MAX),
        }),
    };
    // Always deregister, whether the reload succeeded, failed, or timed out.
    {
        let mut active = self.active_reloads.write().await;
        active.remove(&request.reload_id);
    }
    match &result {
        Ok(reload_result) => {
            info!(
                "Hot reload completed successfully for agent {} in {:?}",
                request.agent_id,
                reload_result.duration().unwrap_or_default()
            );
        }
        Err(e) => {
            error!("Hot reload failed for agent {}: {}", request.agent_id, e);
        }
    }
    result
}
/// Returns the live status of an active reload.
/// NOTE(review): an unknown (or already-finished) reload id reports
/// `Completed`, so callers cannot distinguish "finished" from "never
/// existed" — confirm this is the intended contract.
async fn get_hot_reload_status(
&self,
reload_id: HotReloadId,
) -> std::result::Result<HotReloadStatus, HotReloadError> {
let active = self.active_reloads.read().await;
if let Some(context) = active.get(&reload_id) {
Ok(context.status)
} else {
Ok(HotReloadStatus::Completed)
}
}
/// Cancels an active reload: removes it from tracking, then best-effort
/// stops the new instance and points traffic back at the old version.
/// Cleanup failures are logged, not propagated, so cancellation itself
/// still succeeds.
/// NOTE(review): returns `AlreadyInProgress` when the reload id is NOT
/// found — a misleading variant name, kept for compatibility.
async fn cancel_hot_reload(
&self,
reload_id: HotReloadId,
) -> std::result::Result<(), HotReloadError> {
let mut active = self.active_reloads.write().await;
if let Some(context) = active.remove(&reload_id) {
info!("Cancelling hot reload {}", reload_id);
// Best effort: stop the partially started new instance.
if let Err(e) = self
.runtime_manager
.stop_instance(context.request.agent_id, context.request.to_version)
.await
{
warn!("Failed to stop new instance during cancellation: {}", e);
}
// Best effort: route all traffic back to the original version.
if let Err(e) = self
.traffic_router
.switch_traffic(context.request.agent_id, context.request.from_version)
.await
{
warn!("Failed to reset traffic during cancellation: {}", e);
}
Ok(())
} else {
Err(HotReloadError::AlreadyInProgress { reload_id })
}
}
/// Rolls an active reload back to `target_version` using a previously
/// captured snapshot: stops the current (new) instance, recreates the
/// target version from the snapshot's module bytes, and switches traffic.
async fn rollback_hot_reload(
    &self,
    reload_id: HotReloadId,
    target_version: AgentVersion,
) -> std::result::Result<HotReloadResult, HotReloadError> {
    info!(
        "Rolling back hot reload {} to version {}",
        reload_id, target_version
    );
    // Copy what we need out of the shared maps up front so no lock is
    // held across the runtime/traffic awaits below. The previous version
    // held read locks on `active_reloads` and `version_snapshots` for the
    // entire rollback, blocking all writers for its duration.
    let context = {
        let active = self.active_reloads.read().await;
        // NOTE(review): `AlreadyInProgress` is a misleading variant for
        // "reload not found"; preserved for caller compatibility.
        active
            .get(&reload_id)
            .cloned()
            .ok_or(HotReloadError::AlreadyInProgress { reload_id })?
    };
    let agent_id = context.request.agent_id;
    let wasm_module = {
        let snapshots = self.version_snapshots.read().await;
        let agent_snapshots = snapshots
            .get(&agent_id)
            .ok_or(HotReloadError::VersionNotFound {
                version: target_version,
            })?;
        agent_snapshots
            .iter()
            .find(|s| s.version == target_version)
            .map(|s| s.wasm_module.clone())
            .ok_or(HotReloadError::VersionNotFound {
                version: target_version,
            })?
    };
    self.runtime_manager
        .stop_instance(agent_id, context.request.to_version)
        .await
        .map_err(|e| HotReloadError::RollbackFailed {
            reason: format!("Failed to stop current version: {e}"),
        })?;
    self.runtime_manager
        .create_instance(agent_id, target_version, &wasm_module)
        .await
        .map_err(|e| HotReloadError::RollbackFailed {
            reason: format!("Failed to create target version instance: {e}"),
        })?;
    self.traffic_router
        .switch_traffic(agent_id, target_version)
        .await
        .map_err(|e| HotReloadError::RollbackFailed {
            reason: format!("Failed to switch traffic: {e}"),
        })?;
    Ok(HotReloadResult::rollback(
        reload_id,
        agent_id,
        context.request.from_version,
        context.request.to_version,
        Some(context.started_at),
        format!("Rolled back to version {target_version}"),
        Some(context.metrics.clone()),
    ))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::TrafficSplitPercentage;
use std::sync::atomic::{AtomicBool, Ordering};
// Mock runtime whose create/health operations succeed or fail based on a
// shared flag, allowing failure-path tests without a real WASM runtime.
struct MockRuntimeManager {
should_succeed: Arc<AtomicBool>,
}
#[async_trait::async_trait]
impl RuntimeManager for MockRuntimeManager {
async fn create_instance(
&self,
_: AgentId,
_: AgentVersion,
_: &[u8],
) -> Result<(), HotReloadError> {
if self.should_succeed.load(Ordering::SeqCst) {
Ok(())
} else {
Err(HotReloadError::StatePreservationFailed {
reason: "Mock creation failure".to_string(),
})
}
}
async fn stop_instance(&self, _: AgentId, _: AgentVersion) -> Result<(), HotReloadError> {
Ok(())
}
// Fixed metrics: 1024 bytes memory, 1000 fuel, 100 requests.
async fn get_instance_metrics(
&self,
_: AgentId,
_: AgentVersion,
) -> Result<(usize, u64, u64), HotReloadError> {
Ok((1024, 1000, 100))
}
async fn preserve_state(
&self,
_: AgentId,
_: AgentVersion,
) -> Result<Vec<u8>, HotReloadError> {
Ok(vec![1, 2, 3, 4])
}
async fn restore_state(
&self,
_: AgentId,
_: AgentVersion,
_: &[u8],
) -> Result<(), HotReloadError> {
Ok(())
}
// Health mirrors the same flag as `create_instance`.
async fn health_check(&self, _: AgentId, _: AgentVersion) -> Result<bool, HotReloadError> {
Ok(self.should_succeed.load(Ordering::SeqCst))
}
}
// No-op router that accepts every routing operation.
struct MockTrafficRouter;
#[async_trait::async_trait]
impl TrafficRouter for MockTrafficRouter {
async fn set_traffic_split(
&self,
_: AgentId,
_: AgentVersion,
_: AgentVersion,
_: TrafficSplitPercentage,
) -> Result<(), HotReloadError> {
Ok(())
}
async fn get_traffic_split(
&self,
_: AgentId,
) -> Result<TrafficSplitPercentage, HotReloadError> {
Ok(TrafficSplitPercentage::half())
}
async fn switch_traffic(&self, _: AgentId, _: AgentVersion) -> Result<(), HotReloadError> {
Ok(())
}
}
// Builds a manager wired to always-succeeding mocks with default limits.
fn create_test_hot_reload_manager() -> CaxtonHotReloadManager {
let runtime_manager = Arc::new(MockRuntimeManager {
should_succeed: Arc::new(AtomicBool::new(true)),
});
let traffic_router = Arc::new(MockTrafficRouter);
CaxtonHotReloadManager::new(runtime_manager, traffic_router)
}
// Graceful reload happy path; ignored by default because the production
// time provider performs real sleeps.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "Slow test - involves sleep operations"]
async fn test_graceful_hot_reload() {
let manager = create_test_hot_reload_manager();
let mut config = HotReloadConfig::graceful();
config.warmup_duration = Duration::from_millis(1);
let request = HotReloadRequest::new(
AgentId::generate(),
None,
AgentVersion::generate(),
AgentVersion::generate(),
VersionNumber::first().next().unwrap(),
config,
vec![5, 6, 7, 8],
);
let result =
tokio::time::timeout(Duration::from_secs(1), manager.hot_reload_agent(request)).await;
assert!(result.is_ok());
let inner_result = result.unwrap();
assert!(inner_result.is_ok());
}
// Immediate reload happy path; fast enough to run unconditionally.
#[tokio::test]
async fn test_immediate_hot_reload() {
let manager = create_test_hot_reload_manager();
let mut config = HotReloadConfig::immediate();
config.warmup_duration = Duration::from_millis(1);
let request = HotReloadRequest::new(
AgentId::generate(),
None,
AgentVersion::generate(),
AgentVersion::generate(),
VersionNumber::first().next().unwrap(),
config,
vec![5, 6, 7, 8],
);
let result = manager.hot_reload_agent(request).await;
assert!(result.is_ok());
}
// Traffic-splitting happy path; ignored for the same sleep reason.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "Slow test - involves sleep operations"]
async fn test_traffic_splitting_hot_reload() {
let manager = create_test_hot_reload_manager();
let traffic_split = TrafficSplitPercentage::try_new(25).unwrap();
let mut config = HotReloadConfig::traffic_splitting(traffic_split);
config.warmup_duration = Duration::from_millis(1);
let request = HotReloadRequest::new(
AgentId::generate(),
None,
AgentVersion::generate(),
AgentVersion::generate(),
VersionNumber::first().next().unwrap(),
config,
vec![5, 6, 7, 8],
);
let result =
tokio::time::timeout(Duration::from_secs(1), manager.hot_reload_agent(request)).await;
assert!(result.is_ok());
let inner_result = result.unwrap();
assert!(inner_result.is_ok());
}
// Unknown reload ids report `Completed` (documented quirk of the API).
#[tokio::test]
async fn test_hot_reload_status() {
let manager = create_test_hot_reload_manager();
let reload_id = HotReloadId::generate();
let status = manager.get_hot_reload_status(reload_id).await;
assert!(status.is_ok());
assert_eq!(status.unwrap(), HotReloadStatus::Completed);
}
// Cancelling a reload that was never started yields an error.
#[tokio::test]
async fn test_hot_reload_cancellation() {
let manager = create_test_hot_reload_manager();
let reload_id = HotReloadId::generate();
let result = manager.cancel_hot_reload(reload_id).await;
assert!(result.is_err());
}
}