use actix_web::dev::Server;
use actix_web::{web, App, HttpServer};
use std::collections::HashSet;
use std::fs;
use std::net::TcpListener;
use std::os::unix::fs::{chown, PermissionsExt};
use std::os::unix::net::UnixListener;
use tracing::subscriber::set_global_default;
use tracing_log::LogTracer;
use tracing_subscriber::fmt::MakeWriter;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::{EnvFilter, Registry};
use crate::auth::{AuthMiddleware, SigningKeyData};
use crate::config::ApiGroup;
use crate::devices::cpu::power::{start_cpu_poller, CpuPowerBroadcast, CpuPowerPoller};
use crate::devices::cpu::{CpuManagementTasks, CpuManager, RaplCpu};
use crate::devices::gpu::power::{start_gpu_poller, GpuPowerBroadcast, GpuPowerPoller};
use crate::devices::gpu::{GpuManagementTasks, GpuManager, NvmlGpu};
use crate::routes::cpu_routes;
use crate::routes::{gpu_control_routes, gpu_read_routes, server_routes, DiscoveryInfo};
/// Installs the process-wide `tracing` subscriber, writing formatted events to `sink`.
///
/// Records emitted through the `log` crate are bridged into `tracing` via
/// `LogTracer`. Event filtering comes from the default `EnvFilter` environment
/// variable when it is set and parses; otherwise everything at `info` and above
/// is kept.
///
/// # Errors
///
/// Fails if a `log` logger or a global `tracing` subscriber is already installed.
pub fn init_tracing<S>(sink: S) -> anyhow::Result<()>
where
    S: for<'a> MakeWriter<'a> + Send + Sync + 'static,
{
    LogTracer::init()?;

    let fmt_layer = tracing_subscriber::fmt::layer().with_writer(sink);
    let filter = match EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => EnvFilter::new("info"),
    };

    set_global_default(Registry::default().with(fmt_layer).with(filter))?;
    Ok(())
}
/// Binds a Unix domain socket at `socket_path`, then applies the requested
/// permission bits and (optional) owner/group to the socket file.
///
/// * `permissions` — mode bits passed to `fs::Permissions::from_mode`.
/// * `uid` / `gid` — new owner and group; `None` leaves that id unchanged.
///
/// # Errors
///
/// * Fails if anything already exists at `socket_path`, including a dangling symlink.
/// * Fails if `bind` fails.
/// * Fails if setting permissions or ownership fails. In that case the freshly
///   created socket file is removed first, so a stale file does not block the
///   next start.
pub fn get_unix_listener(
    socket_path: &str,
    permissions: u32,
    uid: Option<u32>,
    gid: Option<u32>,
) -> anyhow::Result<UnixListener> {
    // `symlink_metadata` does not follow symlinks, so a dangling symlink left at
    // the path is reported here with a clear message instead of failing in `bind`.
    if fs::symlink_metadata(socket_path).is_ok() {
        tracing::error!(
            "Socket file {} already exists. Please remove it and restart Zeusd.",
            socket_path,
        );
        anyhow::bail!("Socket file already exists");
    }

    let listener = UnixListener::bind(socket_path)?;

    // NOTE(review): between `bind` and `set_permissions` the socket carries the
    // umask-derived mode; tighten the process umask if that window matters.
    let configured = fs::set_permissions(socket_path, fs::Permissions::from_mode(permissions))
        .and_then(|()| chown(socket_path, uid, gid));
    if let Err(err) = configured {
        // Best-effort cleanup. A leftover socket file would make every later
        // start fail the existence check above.
        let _ = fs::remove_file(socket_path);
        return Err(err.into());
    }

    Ok(listener)
}
/// Initializes NVML for every visible GPU and starts the GPU management tasks.
///
/// GPUs are opened in index order, starting at 0. A success line is logged for
/// each one. Initialization stops at the first GPU that fails.
///
/// # Errors
///
/// Propagates failures from device enumeration, per-GPU NVML initialization,
/// or `GpuManagementTasks::start`.
pub fn start_gpu_device_tasks() -> anyhow::Result<GpuManagementTasks> {
    tracing::info!("Starting NVML and GPU management tasks.");
    let gpu_count = NvmlGpu::device_count()?;
    let initialized = (0..gpu_count)
        .map(|gpu_id| {
            NvmlGpu::init(gpu_id).map(|gpu| {
                tracing::info!("Initialized NVML for GPU {}", gpu_id);
                gpu
            })
        })
        .collect::<Result<Vec<_>, _>>()?;
    let tasks = GpuManagementTasks::start(initialized)?;
    Ok(tasks)
}
/// Opens every visible GPU through NVML and starts a power poller sampling at `poll_hz` Hz.
///
/// The poller receives `(index, device)` pairs, where the index is the GPU's
/// NVML ordinal converted to `usize`.
///
/// # Errors
///
/// Propagates failures from device enumeration or per-GPU NVML initialization.
pub fn start_gpu_power_poller(poll_hz: u32) -> anyhow::Result<GpuPowerPoller> {
    tracing::info!("Starting GPU power poller at {} Hz.", poll_hz);
    let gpu_count = NvmlGpu::device_count()?;
    let indexed_gpus = (0..gpu_count)
        .map(|gpu_id| NvmlGpu::init(gpu_id).map(|gpu| (gpu_id as usize, gpu)))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(start_gpu_poller(indexed_gpus, poll_hz))
}
/// Initializes RAPL for every detected CPU and starts the CPU management tasks.
///
/// Returns the running tasks together with one flag per CPU, in index order,
/// recording whether that CPU reported DRAM energy as available.
///
/// # Errors
///
/// Propagates failures from device enumeration, per-CPU RAPL initialization,
/// or `CpuManagementTasks::start`.
pub fn start_cpu_device_tasks() -> anyhow::Result<(CpuManagementTasks, Vec<bool>)> {
    tracing::info!("Starting Rapl and CPU management tasks.");
    let cpu_count = RaplCpu::device_count()?;
    let probed = (0..cpu_count)
        .map(|cpu_id| {
            RaplCpu::init(cpu_id).map(|cpu| {
                let has_dram = cpu.is_dram_available();
                tracing::info!(
                    "Initialized RAPL for CPU {} (DRAM: {})",
                    cpu_id,
                    has_dram,
                );
                (cpu, has_dram)
            })
        })
        .collect::<Result<Vec<_>, _>>()?;
    let (cpus, dram_flags): (Vec<_>, Vec<bool>) = probed.into_iter().unzip();
    let tasks = CpuManagementTasks::start(cpus)?;
    Ok((tasks, dram_flags))
}
/// Opens every detected CPU through RAPL and starts a power poller sampling at `poll_hz` Hz.
///
/// The poller receives `(index, device)` pairs in index order.
///
/// # Errors
///
/// Propagates failures from device enumeration or per-CPU RAPL initialization.
pub fn start_cpu_power_poller(poll_hz: u32) -> anyhow::Result<CpuPowerPoller> {
    tracing::info!("Starting CPU RAPL power poller at {} Hz.", poll_hz);
    let cpu_count = RaplCpu::device_count()?;
    let indexed_cpus = (0..cpu_count)
        .map(|cpu_id| RaplCpu::init(cpu_id).map(|cpu| (cpu_id, cpu)))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(start_cpu_poller(indexed_cpus, poll_hz))
}
/// Verifies that the process has the privileges needed by every enabled API group.
///
/// A process whose effective UID is root always passes. Otherwise the first
/// group (in slice order) whose `requires_root()` is true is logged and
/// reported as an error.
///
/// # Errors
///
/// Returns an error naming the first root-only group when not running as root.
pub fn check_privileges(enabled_groups: &[ApiGroup]) -> anyhow::Result<()> {
    if nix::unistd::geteuid().is_root() {
        return Ok(());
    }

    let Some(group) = enabled_groups
        .iter()
        .copied()
        .find(|candidate| candidate.requires_root())
    else {
        return Ok(());
    };

    tracing::error!(
        "API group '{}' requires root privileges. \
         Either run as root or remove it from --enable.",
        group,
    );
    anyhow::bail!(
        "API group '{}' requires root but Zeusd is not running as root",
        group,
    );
}
/// The set of API groups enabled for this server instance.
///
/// `configure_server!` consults this set to decide which route scopes to
/// mount, and also registers a copy as `web::Data` for handlers.
#[derive(Debug, Clone)]
pub struct EnabledGroups(pub HashSet<ApiGroup>);
/// Everything the HTTP app factory needs to build one worker's `App`.
///
/// It is cloned once per factory invocation (see `configure_server!`), so every
/// field must be cheaply clonable. Each `Option` field is `None` when the
/// corresponding device support was not started; when it is `Some`, its value is
/// registered as `web::Data`.
#[derive(Clone)]
pub struct ServerState {
    /// Handle to the running GPU management tasks, if GPU support was started.
    pub gpu_device_tasks: Option<GpuManagementTasks>,
    /// Handle to the running CPU management tasks, if CPU support was started.
    pub cpu_device_tasks: Option<CpuManagementTasks>,
    /// Broadcast side of the GPU power poller, if one is running.
    pub gpu_power_broadcast: Option<GpuPowerBroadcast>,
    /// Broadcast side of the CPU power poller, if one is running.
    pub cpu_power_broadcast: Option<CpuPowerBroadcast>,
    /// Data served to clients by the discovery routes (registered unconditionally).
    pub discovery_info: DiscoveryInfo,
    /// Which API groups are enabled; controls which route scopes are mounted.
    pub enabled_groups: EnabledGroups,
    /// Key material made available to handlers when present.
    /// NOTE(review): presumably consumed by `AuthMiddleware` — confirm.
    pub signing_key: Option<SigningKeyData>,
}
/// Builds the actix-web `HttpServer` for Zeusd. It does not bind a listener or
/// start the server.
///
/// What the app factory mounts:
/// * `server_routes`, always.
/// * A `/gpu` scope, when `GpuRead` and/or `GpuControl` is enabled.
/// * A `/cpu` scope, when `CpuRead` is enabled.
///
/// Every optional piece of `ServerState` that is present is registered as
/// `web::Data`.
///
/// This is a macro rather than a function because the returned `HttpServer`
/// is generic over the unnameable factory-closure type.
///
/// * `$state` — a `ServerState`. It is moved into the factory and cloned each
///   time the factory runs.
/// * `$workers` — worker count forwarded to `HttpServer::workers`.
macro_rules! configure_server {
    ($state:expr, $workers:expr) => {
        HttpServer::new(move || {
            let shared = $state.clone();
            let groups = &shared.enabled_groups.0;
            let want_gpu_read = groups.contains(&ApiGroup::GpuRead);
            let want_gpu_control = groups.contains(&ApiGroup::GpuControl);
            let want_cpu_read = groups.contains(&ApiGroup::CpuRead);

            // actix-web runs the most recently `wrap`ped middleware first, so
            // TracingLogger observes each request before AuthMiddleware does.
            let mut app = App::new()
                .wrap(AuthMiddleware)
                .wrap(tracing_actix_web::TracingLogger::default())
                .configure(server_routes)
                .app_data(web::Data::new(shared.discovery_info.clone()))
                .app_data(web::Data::new(shared.enabled_groups.clone()));
            if let Some(key) = &shared.signing_key {
                app = app.app_data(web::Data::new(key.clone()));
            }

            // GPU: one `/gpu` scope carrying the read and/or control routes.
            if want_gpu_read || want_gpu_control {
                let mut gpu_scope = web::scope("/gpu");
                if want_gpu_read {
                    gpu_scope = gpu_scope.configure(gpu_read_routes);
                }
                if want_gpu_control {
                    gpu_scope = gpu_scope.configure(gpu_control_routes);
                }
                app = app.service(gpu_scope);
            }
            if let Some(tasks) = &shared.gpu_device_tasks {
                app = app.app_data(web::Data::new(tasks.clone()));
            }
            if let Some(broadcast) = &shared.gpu_power_broadcast {
                app = app.app_data(web::Data::new(broadcast.clone()));
            }

            // CPU: read-only routes under `/cpu`.
            if want_cpu_read {
                app = app.service(web::scope("/cpu").configure(cpu_routes));
            }
            if let Some(tasks) = &shared.cpu_device_tasks {
                app = app.app_data(web::Data::new(tasks.clone()));
            }
            if let Some(broadcast) = &shared.cpu_power_broadcast {
                app = app.app_data(web::Data::new(broadcast.clone()));
            }

            app
        })
        .workers($workers)
    };
}
/// Starts the Zeusd HTTP server on an already-bound Unix domain socket.
///
/// The returned `Server` must be awaited or spawned to actually serve requests.
///
/// # Errors
///
/// Propagates I/O errors from attaching `listener` to the server.
pub fn start_server_uds(
    listener: UnixListener,
    state: ServerState,
    num_workers: usize,
) -> std::io::Result<Server> {
    let bound = configure_server!(state, num_workers).listen_uds(listener)?;
    Ok(bound.run())
}
/// Starts the Zeusd HTTP server on an already-bound TCP listener.
///
/// The returned `Server` must be awaited or spawned to actually serve requests.
///
/// # Errors
///
/// Propagates I/O errors from attaching `listener` to the server.
pub fn start_server_tcp(
    listener: TcpListener,
    state: ServerState,
    num_workers: usize,
) -> std::io::Result<Server> {
    let bound = configure_server!(state, num_workers).listen(listener)?;
    Ok(bound.run())
}