use crate::cgroups_stats::ContainerStats;
use crate::runtime::{ContainerId, Runtime};
use crate::service::ServiceManager;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use zlayer_scheduler::metrics::{
ContainerStatsProvider, MetricsContainerId, RawContainerStats, ServiceContainerProvider,
};
/// Adapter exposing the containers tracked by a [`ServiceManager`] to the
/// scheduler's metrics subsystem via [`ServiceContainerProvider`].
pub struct ServiceManagerContainerProvider {
    /// Shared service manager queried for service and container listings.
    manager: Arc<ServiceManager>,
}

impl ServiceManagerContainerProvider {
    /// Creates a provider backed by the given shared [`ServiceManager`].
    pub fn new(manager: Arc<ServiceManager>) -> Self {
        ServiceManagerContainerProvider { manager }
    }
}
#[async_trait]
impl ServiceContainerProvider for ServiceManagerContainerProvider {
    /// Returns the ids of every container the service manager reports for
    /// `service_name`, converted into scheduler-side [`MetricsContainerId`]s.
    ///
    /// A service the manager does not know about yields an empty vector.
    async fn get_container_ids(&self, service_name: &str) -> Vec<MetricsContainerId> {
        self.manager
            .get_service_containers(service_name)
            .await
            .into_iter()
            .map(|container| MetricsContainerId {
                service: container.service,
                replica: container.replica,
            })
            .collect()
    }

    /// Builds a map from service name to that service's container ids.
    ///
    /// Services that currently report no containers are left out of the map.
    async fn get_all_services(&self) -> HashMap<String, Vec<MetricsContainerId>> {
        // Bind the listing to a local first so no temporary from the loop
        // header is kept alive across the awaits inside the loop body.
        let names = self.manager.list_services().await;
        let mut by_service = HashMap::new();
        for name in names {
            let ids = self.get_container_ids(&name).await;
            if ids.is_empty() {
                continue;
            }
            by_service.insert(name, ids);
        }
        by_service
    }
}
/// Adapter serving per-container resource statistics from a container
/// [`Runtime`] through the scheduler's [`ContainerStatsProvider`] trait.
pub struct RuntimeStatsProvider {
    /// Shared runtime that is asked for each container's statistics.
    runtime: Arc<dyn Runtime + Send + Sync>,
}

impl RuntimeStatsProvider {
    /// Creates a provider backed by the given shared runtime handle.
    pub fn new(runtime: Arc<dyn Runtime + Send + Sync>) -> Self {
        RuntimeStatsProvider { runtime }
    }
}
#[async_trait]
impl ContainerStatsProvider for RuntimeStatsProvider {
    /// Fetches raw resource statistics for a single replica from the runtime.
    ///
    /// # Errors
    ///
    /// Returns the runtime's error rendered via `to_string()` when the runtime
    /// cannot produce statistics (for example, an unknown container).
    async fn get_stats(&self, id: &MetricsContainerId) -> Result<RawContainerStats, String> {
        // The runtime is addressed by its own id type, so rebuild one from the
        // scheduler-side id.
        let target = ContainerId {
            service: id.service.clone(),
            replica: id.replica,
        };
        self.runtime
            .get_container_stats(&target)
            .await
            .map(|stats| container_stats_to_raw(&stats))
            .map_err(|err| err.to_string())
    }
}
fn container_stats_to_raw(stats: &ContainerStats) -> RawContainerStats {
RawContainerStats {
cpu_usage_usec: stats.cpu_usage_usec,
memory_bytes: stats.memory_bytes,
memory_limit: stats.memory_limit,
timestamp: stats.timestamp,
}
}
#[cfg(test)]
// NOTE(review): presumably silences deprecation warnings from APIs exercised
// below (e.g. `ServiceManager::new`) — confirm which, and narrow if possible.
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::runtime::MockRuntime;
    use std::sync::Arc;

    /// Builds the `test` service spec from a minimal deployment manifest.
    ///
    /// The YAML nesting matters: `test` must sit under `services`, and
    /// `rtype`/`image`/`endpoints`/`scale` must sit under `test`. A flattened
    /// manifest leaves `services` without a `test` entry and the lookup panics.
    fn mock_spec() -> zlayer_spec::ServiceSpec {
        serde_yaml::from_str::<zlayer_spec::DeploymentSpec>(
            r"
version: v1
deployment: test
services:
  test:
    rtype: service
    image:
      name: test:latest
    endpoints:
      - name: http
        protocol: http
        port: 8080
    scale:
      mode: fixed
      replicas: 1
",
        )
        .expect("mock deployment YAML should parse")
        .services
        .remove("test")
        .expect("mock deployment should define a `test` service")
    }

    #[tokio::test]
    async fn test_service_manager_container_provider_empty() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let manager = Arc::new(ServiceManager::new(runtime));
        let provider = ServiceManagerContainerProvider::new(manager);

        let containers = provider.get_container_ids("nonexistent").await;
        assert!(containers.is_empty());

        let all = provider.get_all_services().await;
        assert!(all.is_empty());
    }

    #[tokio::test]
    async fn test_service_manager_container_provider_with_service() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let manager = Arc::new(ServiceManager::new(runtime));
        // NOTE(review): presumably boxed to keep a large future off the stack.
        Box::pin(manager.upsert_service("api".to_string(), mock_spec()))
            .await
            .unwrap();
        manager.scale_service("api", 3).await.unwrap();
        let provider = ServiceManagerContainerProvider::new(manager);

        let containers = provider.get_container_ids("api").await;
        assert_eq!(containers.len(), 3);
        for c in &containers {
            assert_eq!(c.service, "api");
            assert!((1..=3).contains(&c.replica));
        }

        let all = provider.get_all_services().await;
        assert_eq!(all.len(), 1);
        assert!(all.contains_key("api"));
        assert_eq!(all["api"].len(), 3);
    }

    #[tokio::test]
    async fn test_runtime_stats_provider() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let manager = Arc::new(ServiceManager::new(Arc::clone(&runtime)));
        Box::pin(manager.upsert_service("test".to_string(), mock_spec()))
            .await
            .unwrap();
        manager.scale_service("test", 1).await.unwrap();

        let stats_provider = RuntimeStatsProvider::new(runtime);
        let id = MetricsContainerId {
            service: "test".to_string(),
            replica: 1,
        };
        let stats = stats_provider.get_stats(&id).await.unwrap();
        // NOTE(review): these look like MockRuntime's canned figures — keep in
        // sync with MockRuntime if it changes.
        assert_eq!(stats.cpu_usage_usec, 1_000_000);
        assert_eq!(stats.memory_bytes, 50 * 1024 * 1024);
        assert_eq!(stats.memory_limit, 256 * 1024 * 1024);
    }

    #[tokio::test]
    async fn test_runtime_stats_provider_not_found() {
        let runtime: Arc<dyn Runtime + Send + Sync> = Arc::new(MockRuntime::new());
        let stats_provider = RuntimeStatsProvider::new(runtime);
        let id = MetricsContainerId {
            service: "nonexistent".to_string(),
            replica: 1,
        };
        let result = stats_provider.get_stats(&id).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_container_stats_to_raw() {
        use std::time::Instant;
        let stats = ContainerStats {
            cpu_usage_usec: 1_000_000,
            memory_bytes: 100 * 1024 * 1024,
            memory_limit: 256 * 1024 * 1024,
            timestamp: Instant::now(),
        };
        let raw = container_stats_to_raw(&stats);
        assert_eq!(raw.cpu_usage_usec, stats.cpu_usage_usec);
        assert_eq!(raw.memory_bytes, stats.memory_bytes);
        assert_eq!(raw.memory_limit, stats.memory_limit);
        // The sample time must be carried over unchanged as well.
        assert_eq!(raw.timestamp, stats.timestamp);
    }
}