use crate::error::Result;
use crate::metrics_providers::{RuntimeStatsProvider, ServiceManagerContainerProvider};
use crate::runtime::Runtime;
use crate::service::ServiceManager;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use zlayer_scheduler::metrics::{CgroupsMetricsSource, MetricsCollector, MetricsSource};
use zlayer_scheduler::Autoscaler;
use zlayer_spec::ScaleSpec;
/// Default pause between autoscale evaluation passes when the caller does
/// not supply an explicit interval.
pub const DEFAULT_AUTOSCALE_INTERVAL: Duration = Duration::from_secs(10);
/// Periodically evaluates collected metrics for registered services and
/// applies scaling decisions through the [`ServiceManager`].
pub struct AutoscaleController {
    /// Executes the actual replica changes (`scale_service`).
    service_manager: Arc<ServiceManager>,
    /// Aggregates per-service metrics from one or more sources.
    metrics: Arc<MetricsCollector>,
    /// Decision engine; write-locked during evaluation and bookkeeping.
    autoscaler: Arc<RwLock<Autoscaler>>,
    /// Service name -> its registered (adaptive) scale spec.
    service_specs: Arc<RwLock<HashMap<String, ScaleSpec>>>,
    /// Service name -> time of the last executed scale action (cooldown anchor).
    last_scale_times: Arc<RwLock<HashMap<String, Instant>>>,
    /// Tick period of the run loop.
    interval: Duration,
    /// Signals `run_loop` to exit; see `shutdown()`.
    shutdown: Arc<tokio::sync::Notify>,
}
impl AutoscaleController {
pub fn new(
service_manager: Arc<ServiceManager>,
runtime: Arc<dyn Runtime + Send + Sync>,
interval: Duration,
) -> Self {
let mut metrics = MetricsCollector::new();
let stats_provider = Arc::new(RuntimeStatsProvider::new(runtime));
let service_provider = Arc::new(ServiceManagerContainerProvider::new(
service_manager.clone(),
));
let source: Arc<dyn MetricsSource> =
Arc::new(CgroupsMetricsSource::new(service_provider, stats_provider));
metrics.add_source(source);
Self {
service_manager,
metrics: Arc::new(metrics),
autoscaler: Arc::new(RwLock::new(Autoscaler::new())),
service_specs: Arc::new(RwLock::new(HashMap::new())),
last_scale_times: Arc::new(RwLock::new(HashMap::new())),
interval,
shutdown: Arc::new(tokio::sync::Notify::new()),
}
}
/// Builds a controller around a caller-supplied [`MetricsCollector`].
///
/// Useful in tests, or when metrics come from something other than the
/// default cgroups source.
pub fn with_custom_metrics(
    service_manager: Arc<ServiceManager>,
    metrics: MetricsCollector,
    interval: Duration,
) -> Self {
    let autoscaler = Arc::new(RwLock::new(Autoscaler::new()));
    let shutdown = Arc::new(tokio::sync::Notify::new());
    Self {
        service_manager,
        metrics: Arc::new(metrics),
        autoscaler,
        service_specs: Arc::new(RwLock::new(HashMap::new())),
        last_scale_times: Arc::new(RwLock::new(HashMap::new())),
        interval,
        shutdown,
    }
}
/// Registers `name` for autoscaling if — and only if — its spec is adaptive.
///
/// Fixed-scale specs are skipped with a debug log. `initial_replicas`
/// seeds the autoscaler's view of the service's current size.
pub async fn register_service(&self, name: &str, spec: &ScaleSpec, initial_replicas: u32) {
    let ScaleSpec::Adaptive { .. } = spec else {
        debug!(
            service = name,
            "Skipping registration for non-adaptive service"
        );
        return;
    };
    // Take each write lock for a single statement so the two locks are
    // never held simultaneously.
    self.autoscaler
        .write()
        .await
        .register_service(name, spec.clone(), initial_replicas);
    self.service_specs
        .write()
        .await
        .insert(name.to_string(), spec.clone());
    info!(
        service = name,
        initial_replicas, "Registered service for autoscaling"
    );
}
/// Drops every piece of autoscaling state held for `name`.
///
/// Safe to call for services that were never registered; the removals are
/// simply no-ops in that case.
pub async fn unregister_service(&self, name: &str) {
    // Each guard is a statement-scoped temporary, so no two locks overlap.
    self.autoscaler.write().await.unregister_service(name);
    self.service_specs.write().await.remove(name);
    self.last_scale_times.write().await.remove(name);
    info!(service = name, "Unregistered service from autoscaling");
}
/// Whether `name` is currently registered for autoscaling.
pub async fn is_registered(&self, name: &str) -> bool {
    self.service_specs.read().await.contains_key(name)
}
/// Returns `true` when `service_name` is registered as adaptive and its
/// cooldown window (measured from the last recorded scale action) has
/// elapsed.
///
/// Unregistered or fixed-scale services always return `false`.
async fn should_scale(&self, service_name: &str) -> bool {
    let cooldown = {
        let specs = self.service_specs.read().await;
        match specs.get(service_name) {
            Some(ScaleSpec::Adaptive { cooldown, .. }) => {
                cooldown.unwrap_or(zlayer_scheduler::DEFAULT_COOLDOWN)
            }
            // Unknown or non-adaptive services are never autoscaled.
            _ => return false,
        }
    };
    let last_scale_times = self.last_scale_times.read().await;
    if let Some(last_time) = last_scale_times.get(service_name) {
        // Sample `elapsed()` exactly once: two separate reads of a
        // monotonic clock could straddle the cooldown boundary and report
        // an inconsistent remaining time.
        let elapsed = last_time.elapsed();
        if elapsed < cooldown {
            let remaining = cooldown.saturating_sub(elapsed);
            debug!(
                service = service_name,
                remaining_secs = remaining.as_secs(),
                "Service in cooldown"
            );
            return false;
        }
    }
    true
}
/// Stamps `service_name` with "now", opening its cooldown window.
async fn record_scale_action(&self, service_name: &str) {
    self.last_scale_times
        .write()
        .await
        .insert(service_name.to_string(), Instant::now());
}
/// Runs the evaluation loop until [`Self::shutdown`] is signalled.
///
/// Each tick evaluates every registered service; per-service failures are
/// logged inside `evaluate_all_services` and never stop the loop.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` is kept for
/// forward-compatible error reporting.
#[allow(clippy::cast_possible_truncation)]
pub async fn run_loop(&self) -> Result<()> {
    let mut ticker = tokio::time::interval(self.interval);
    // The default missed-tick behavior (Burst) fires skipped ticks
    // back-to-back after a slow evaluation pass, hammering the metrics
    // sources; Delay re-anchors the schedule instead.
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    info!(
        interval_ms = self.interval.as_millis() as u64,
        "Starting autoscale controller loop"
    );
    loop {
        tokio::select! {
            _ = ticker.tick() => {
                self.evaluate_all_services().await;
            }
            () = self.shutdown.notified() => {
                info!("Autoscale controller shutting down");
                break;
            }
        }
    }
    Ok(())
}
/// Evaluates every registered service once.
///
/// Service names are snapshotted up front so the read lock is released
/// before any (potentially slow) evaluation runs. A failure for one
/// service is logged and does not prevent evaluating the rest.
async fn evaluate_all_services(&self) {
    let names: Vec<String> = self
        .service_specs
        .read()
        .await
        .keys()
        .cloned()
        .collect();
    for name in &names {
        if let Err(e) = self.evaluate_and_scale(name).await {
            warn!(
                service = %name,
                error = %e,
                "Failed to evaluate/scale service"
            );
        }
    }
}
/// Runs one full autoscale pass for a single service:
/// cooldown gate -> metrics collection -> decision -> (optional) scale.
///
/// Missing metrics and a failed decision evaluation are treated as
/// "nothing to do" (returns `Ok`); only a failed `scale_service` call
/// propagates an `Err` to the caller.
async fn evaluate_and_scale(&self, service_name: &str) -> Result<()> {
    // Gate: unregistered, non-adaptive, or still-cooling-down services
    // are skipped silently.
    if !self.should_scale(service_name).await {
        return Ok(());
    }
    // Absent metrics is a normal condition (e.g. containers just started),
    // so it is logged at debug level and swallowed rather than surfaced.
    let aggregated = match self.metrics.collect(service_name).await {
        Ok(m) => m,
        Err(e) => {
            debug!(
                service = service_name,
                error = %e,
                "No metrics available for service"
            );
            return Ok(());
        }
    };
    // Hold the autoscaler write lock only for the evaluation itself; it is
    // released before any scaling I/O happens below.
    let decision = {
        let mut autoscaler = self.autoscaler.write().await;
        match autoscaler.evaluate(service_name, &aggregated) {
            Ok(d) => d,
            Err(e) => {
                debug!(
                    service = service_name,
                    error = %e,
                    "Failed to evaluate scaling"
                );
                return Ok(());
            }
        }
    };
    debug!(
        service = service_name,
        ?decision,
        cpu = aggregated.avg_cpu_percent,
        memory = aggregated.avg_memory_percent,
        instances = aggregated.instance_count,
        "Autoscale evaluation"
    );
    // A `None` target means "hold steady": nothing to execute or record.
    if let Some(target) = decision.target_replicas() {
        info!(
            service = service_name,
            target_replicas = target,
            decision = ?decision,
            "Executing autoscale"
        );
        if let Err(e) = self
            .service_manager
            .scale_service(service_name, target)
            .await
        {
            error!(
                service = service_name,
                target = target,
                error = %e,
                "Failed to scale service"
            );
            return Err(e);
        }
        // Start the cooldown window only after the scale action succeeded.
        self.record_scale_action(service_name).await;
        // Sync the autoscaler's internal replica bookkeeping; a failure
        // here is logged but deliberately non-fatal.
        {
            let mut autoscaler = self.autoscaler.write().await;
            if let Err(e) = autoscaler.record_scale_action(service_name, target) {
                warn!(
                    service = service_name,
                    error = %e,
                    "Failed to record scale action in autoscaler"
                );
            }
        }
    }
    Ok(())
}
/// Signals the run loop to exit.
///
/// NOTE(review): `notify_one` stores a permit if no task is currently
/// waiting, so calling this before `run_loop` reaches its `select!` should
/// still terminate the loop — confirm against tokio `Notify` semantics.
pub fn shutdown(&self) {
    self.shutdown.notify_one();
}
/// The configured tick period of the evaluation loop.
#[must_use]
pub fn interval(&self) -> Duration {
    self.interval
}
/// Number of services currently registered for autoscaling.
pub async fn registered_service_count(&self) -> usize {
    self.service_specs.read().await.len()
}
}
/// Returns `true` when at least one service in the deployment uses an
/// adaptive scale spec (i.e. the autoscale controller is worth starting).
#[must_use]
#[allow(clippy::implicit_hasher)]
pub fn has_adaptive_scaling(services: &HashMap<String, zlayer_spec::ServiceSpec>) -> bool {
    services
        .values()
        .any(|svc| matches!(svc.scale, ScaleSpec::Adaptive { .. }))
}
#[cfg(test)]
#[allow(deprecated)]
mod tests {
use super::*;
use crate::runtime::MockRuntime;
use zlayer_scheduler::metrics::{MockMetricsSource, ServiceMetrics};
use zlayer_spec::ScaleTargets;
/// Parses a minimal fixed-scale deployment from inline YAML and extracts
/// its single "test" service spec.
///
/// NOTE: the raw string is whitespace-sensitive YAML — keep its interior
/// lines flush-left.
fn mock_spec() -> zlayer_spec::ServiceSpec {
    serde_yaml::from_str::<zlayer_spec::DeploymentSpec>(
        r"
version: v1
deployment: test
services:
test:
rtype: service
image:
name: test:latest
endpoints:
- name: http
protocol: http
port: 8080
scale:
mode: fixed
replicas: 1
",
    )
    .unwrap()
    .services
    .remove("test")
    .unwrap()
}
/// Builds an adaptive `ScaleSpec` with a zero-second cooldown and the
/// given optional CPU / memory utilization targets (RPS unset).
fn adaptive_spec(
    min: u32,
    max: u32,
    cpu_target: Option<u8>,
    memory_target: Option<u8>,
) -> ScaleSpec {
    let targets = ScaleTargets {
        cpu: cpu_target,
        memory: memory_target,
        rps: None,
    };
    ScaleSpec::Adaptive {
        min,
        max,
        cooldown: Some(Duration::from_secs(0)),
        targets,
    }
}
// A freshly built controller reports its interval and has no registrations.
#[tokio::test]
async fn test_autoscale_controller_creation() {
    let rt: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
    let mgr = Arc::new(ServiceManager::new(rt.clone()));
    let ctl = AutoscaleController::new(mgr, rt, Duration::from_secs(10));
    assert_eq!(ctl.interval(), Duration::from_secs(10));
    assert_eq!(ctl.registered_service_count().await, 0);
}
// Registering an adaptive spec makes the service visible to the controller.
#[tokio::test]
async fn test_register_service() {
    let rt: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
    let mgr = Arc::new(ServiceManager::new(rt.clone()));
    let ctl = AutoscaleController::new(mgr, rt, Duration::from_secs(10));
    ctl.register_service("api", &adaptive_spec(1, 10, Some(70), None), 2)
        .await;
    assert!(ctl.is_registered("api").await);
    assert_eq!(ctl.registered_service_count().await, 1);
}
// Fixed-scale specs must be silently ignored by registration.
#[tokio::test]
async fn test_register_fixed_service_ignored() {
    let rt: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
    let mgr = Arc::new(ServiceManager::new(rt.clone()));
    let ctl = AutoscaleController::new(mgr, rt, Duration::from_secs(10));
    let fixed = ScaleSpec::Fixed { replicas: 3 };
    ctl.register_service("api", &fixed, 3).await;
    assert!(!ctl.is_registered("api").await);
    assert_eq!(ctl.registered_service_count().await, 0);
}
// Unregistering removes the service from every controller map.
#[tokio::test]
async fn test_unregister_service() {
    let rt: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
    let mgr = Arc::new(ServiceManager::new(rt.clone()));
    let ctl = AutoscaleController::new(mgr, rt, Duration::from_secs(10));
    ctl.register_service("api", &adaptive_spec(1, 10, Some(70), None), 2)
        .await;
    assert!(ctl.is_registered("api").await);
    ctl.unregister_service("api").await;
    assert!(!ctl.is_registered("api").await);
    assert_eq!(ctl.registered_service_count().await, 0);
}
// `has_adaptive_scaling` flips to true once any service is adaptive.
#[tokio::test]
async fn test_has_adaptive_scaling() {
    let mut services = HashMap::new();
    let mut web = mock_spec();
    web.scale = ScaleSpec::Fixed { replicas: 3 };
    services.insert("web".to_string(), web);
    assert!(!has_adaptive_scaling(&services));
    let mut api = mock_spec();
    api.scale = adaptive_spec(1, 10, Some(70), None);
    services.insert("api".to_string(), api);
    assert!(has_adaptive_scaling(&services));
}
// End-to-end: high CPU metrics through a mock source must scale 2 -> 3.
#[tokio::test]
async fn test_autoscale_controller_with_mock_metrics() {
    let rt: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
    let mgr = Arc::new(ServiceManager::new(rt.clone()));
    // Two samples well above the 70% CPU target.
    let sample_a = ServiceMetrics {
        cpu_percent: 85.0,
        memory_bytes: 100 * 1024 * 1024,
        memory_limit: 512 * 1024 * 1024,
        rps: None,
        timestamp: Some(Instant::now()),
    };
    let sample_b = ServiceMetrics {
        cpu_percent: 90.0,
        memory_bytes: 150 * 1024 * 1024,
        memory_limit: 512 * 1024 * 1024,
        rps: None,
        timestamp: Some(Instant::now()),
    };
    let mock = Arc::new(MockMetricsSource::new());
    mock.set_metrics("api", vec![sample_a, sample_b]).await;
    let mut collector = MetricsCollector::new();
    collector.add_source(mock);
    let ctl = AutoscaleController::with_custom_metrics(
        mgr.clone(),
        collector,
        Duration::from_secs(10),
    );
    Box::pin(mgr.upsert_service("api".to_string(), mock_spec()))
        .await
        .unwrap();
    mgr.scale_service("api", 2).await.unwrap();
    ctl.register_service("api", &adaptive_spec(1, 10, Some(70), None), 2)
        .await;
    ctl.evaluate_and_scale("api").await.unwrap();
    assert_eq!(mgr.service_replica_count("api").await.unwrap(), 3);
}
// A 60s cooldown must block a second scale immediately after the first.
#[tokio::test]
async fn test_autoscale_controller_cooldown() {
    let rt: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
    let mgr = Arc::new(ServiceManager::new(rt.clone()));
    let ctl = AutoscaleController::new(mgr, rt, Duration::from_secs(10));
    let spec = ScaleSpec::Adaptive {
        min: 1,
        max: 10,
        cooldown: Some(Duration::from_secs(60)),
        targets: ScaleTargets {
            cpu: Some(70),
            memory: None,
            rps: None,
        },
    };
    ctl.register_service("api", &spec, 2).await;
    assert!(ctl.should_scale("api").await);
    ctl.record_scale_action("api").await;
    assert!(!ctl.should_scale("api").await);
}
// `shutdown()` must make a running loop exit with Ok(()).
#[tokio::test]
async fn test_autoscale_controller_shutdown() {
    let rt: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
    let mgr = Arc::new(ServiceManager::new(rt.clone()));
    let ctl = Arc::new(AutoscaleController::new(
        mgr,
        rt,
        Duration::from_millis(100),
    ));
    let looper = Arc::clone(&ctl);
    let handle = tokio::spawn(async move { looper.run_loop().await });
    tokio::time::sleep(Duration::from_millis(50)).await;
    ctl.shutdown();
    assert!(handle.await.unwrap().is_ok());
}
}