use std::collections::HashMap;
use std::time::Instant;
use actix_web::web::Bytes;
use actix_web::{web, HttpResponse};
use paste::paste;
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt;
use crate::devices::gpu::power::{GpuPowerBroadcast, GpuPowerSnapshot};
use crate::devices::gpu::{GpuCommand, GpuManagementTasks, GpuResponse};
use crate::error::ZeusdError;
/// Query string accepted by the power-reading endpoints (`/get_power`, `/stream_power`).
///
/// Unknown query keys are rejected (`deny_unknown_fields`).
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct GpuReadQuery {
    /// Optional comma-separated list of GPU indices (e.g. `"0,2"`). When absent, the handlers
    /// return readings for every GPU in the snapshot.
    pub gpu_ids: Option<String>,
}
/// Parses a comma-separated list of GPU indices such as `"0, 2,3"`.
///
/// Whitespace around each entry is ignored. Entries that do not parse as a `usize` (empty
/// entries, negative numbers, non-numeric text) are silently skipped, so callers must treat an
/// empty result as "no valid index was given".
///
/// Repeated indices are collapsed to their first occurrence, keeping the original order. This
/// prevents a request like `gpu_ids=0,0` from dispatching the same command to one GPU twice.
fn parse_gpu_ids(raw: &str) -> Vec<usize> {
    let mut seen = std::collections::HashSet::new();
    raw.split(',')
        .filter_map(|part| part.trim().parse::<usize>().ok())
        // `insert` returns false for an index that was already kept.
        .filter(|&id| seen.insert(id))
        .collect()
}
/// Generates everything needed for one GPU control endpoint:
///
/// * a query struct named after `$api` (camel-cased) with `gpu_ids`, the listed fields, and
///   `block`;
/// * a `From` conversion from that struct into the same-named `GpuCommand` variant;
/// * a `POST $path` actix handler named `<$api>_handler`.
///
/// The handler validates `gpu_ids` and then sends the command to every listed GPU. When `block`
/// is true it waits for all GPUs concurrently; otherwise it enqueues without waiting. Per-GPU
/// failures are reported as a JSON map under `"errors"` with status 500.
macro_rules! impl_handler_for_gpu_command {
    ($api:ident, $path:literal, $($field:ident: $ftype:ty,)*) => {
        paste! {
            /// Query parameters for a generated GPU control endpoint.
            #[derive(Serialize, Deserialize, Debug)]
            #[serde(deny_unknown_fields)]
            pub struct [<$api:camel>] {
                pub gpu_ids: String,
                $(pub $field: $ftype,)*
                pub block: bool,
            }

            impl From<[<$api:camel>]> for GpuCommand {
                // Underscore prefix: commands without extra fields never read the parameter.
                fn from(_req: [<$api:camel>]) -> Self {
                    GpuCommand::[<$api:camel>] {
                        $($field: _req.$field),*
                    }
                }
            }

            #[actix_web::post($path)]
            #[tracing::instrument(
                skip(query, device_tasks),
                fields(
                    gpu_ids = %query.gpu_ids,
                    block = %query.block,
                    $($field = %query.$field),*
                )
            )]
            async fn [<$api:snake _handler>](
                query: web::Query<[<$api:camel>]>,
                device_tasks: web::Data<GpuManagementTasks>,
            ) -> Result<HttpResponse, ZeusdError> {
                let received_at = Instant::now();
                tracing::info!("Received request");

                let targets = parse_gpu_ids(&query.gpu_ids);
                if targets.is_empty() {
                    return Ok(HttpResponse::BadRequest().json(serde_json::json!({
                        "error": "gpu_ids must contain at least one GPU index"
                    })));
                }

                // Reject the first index that is out of range before touching any GPU.
                let device_count = device_tasks.device_count();
                if let Some(&bad) = targets.iter().find(|&&id| id >= device_count) {
                    return Err(ZeusdError::GpuNotFoundError(bad));
                }

                let request = query.into_inner();
                let wait_for_completion = request.block;
                let command: GpuCommand = request.into();

                // Map of GPU index -> error message for every GPU whose command failed.
                let failures: HashMap<usize, String> = if wait_for_completion {
                    let pending = targets.iter().map(|&gpu_id| {
                        let cmd = command.clone();
                        let tasks = device_tasks.clone();
                        async move {
                            let outcome =
                                tasks.send_command_blocking(gpu_id, cmd, received_at).await;
                            (gpu_id, outcome)
                        }
                    });
                    futures::future::join_all(pending)
                        .await
                        .into_iter()
                        .filter_map(|(gpu_id, outcome)| {
                            outcome.err().map(|e| (gpu_id, e.to_string()))
                        })
                        .collect()
                } else {
                    targets
                        .iter()
                        .filter_map(|&gpu_id| {
                            device_tasks
                                .send_command_nonblocking(gpu_id, command.clone(), received_at)
                                .err()
                                .map(|e| (gpu_id, e.to_string()))
                        })
                        .collect()
                };

                if failures.is_empty() {
                    Ok(HttpResponse::Ok().finish())
                } else {
                    Ok(HttpResponse::InternalServerError().json(serde_json::json!({
                        "errors": failures
                    })))
                }
            }
        }
    };
}
// Persistence mode: `POST /set_persistence_mode?gpu_ids=..&enabled=..&block=..`
impl_handler_for_gpu_command!(
    set_persistence_mode,
    "/set_persistence_mode",
    enabled: bool,
);

// Power limit: `POST /set_power_limit?gpu_ids=..&power_limit_mw=..&block=..`
impl_handler_for_gpu_command!(
    set_power_limit,
    "/set_power_limit",
    power_limit_mw: u32,
);

// Graphics clock locking and its reset.
impl_handler_for_gpu_command!(
    set_gpu_locked_clocks,
    "/set_gpu_locked_clocks",
    min_clock_mhz: u32,
    max_clock_mhz: u32,
);
impl_handler_for_gpu_command!(
    reset_gpu_locked_clocks,
    "/reset_gpu_locked_clocks",
);

// Memory clock locking and its reset.
impl_handler_for_gpu_command!(
    set_mem_locked_clocks,
    "/set_mem_locked_clocks",
    min_clock_mhz: u32,
    max_clock_mhz: u32,
);
impl_handler_for_gpu_command!(
    reset_mem_locked_clocks,
    "/reset_mem_locked_clocks",
);
/// Query string accepted by `GET /get_cumulative_energy`.
///
/// Unknown query keys are rejected (`deny_unknown_fields`).
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct GpuGetCumulativeEnergyQuery {
    /// Optional comma-separated list of GPU indices. When absent, every GPU managed by
    /// `GpuManagementTasks` (indices `0..device_count`) is queried.
    pub gpu_ids: Option<String>,
}
/// Per-GPU entry in the `/get_cumulative_energy` response body.
#[derive(Serialize)]
struct GpuEnergyResponse {
    /// Total energy counter as returned in `GpuResponse::Energy`.
    /// NOTE(review): the `_mj` suffix suggests millijoules; confirm against the device layer.
    energy_mj: u64,
}
/// Handler for `GET /get_cumulative_energy`.
///
/// Queries the total energy consumption of each requested GPU, or of all GPUs when `gpu_ids` is
/// omitted. The queries run concurrently and the handler waits for all of them.
///
/// Responses:
/// * 200 with a JSON object mapping each GPU index (as a string) to `{ "energy_mj": .. }`;
/// * 400 when `gpu_ids` is given but contains no parseable index;
/// * `ZeusdError::GpuNotFoundError` for the first index that is `>= device_count`;
/// * 500 with `{ "errors": { index: message } }` if any GPU failed or answered with an
///   unexpected response variant. Successful readings are omitted in that case.
#[actix_web::get("/get_cumulative_energy")]
#[tracing::instrument(skip(query, device_tasks), fields(gpu_ids = ?query.gpu_ids))]
async fn get_cumulative_energy_handler(
    query: web::Query<GpuGetCumulativeEnergyQuery>,
    device_tasks: web::Data<GpuManagementTasks>,
) -> Result<HttpResponse, ZeusdError> {
    let received_at = Instant::now();
    tracing::info!("Received request");
    let device_count = device_tasks.device_count();

    let targets: Vec<usize> = if let Some(raw) = query.gpu_ids.as_deref() {
        let requested = parse_gpu_ids(raw);
        if requested.is_empty() {
            return Ok(HttpResponse::BadRequest().json(serde_json::json!({
                "error": "gpu_ids must contain at least one GPU index"
            })));
        }
        if let Some(&bad) = requested.iter().find(|&&id| id >= device_count) {
            return Err(ZeusdError::GpuNotFoundError(bad));
        }
        requested
    } else {
        (0..device_count).collect()
    };

    let readings = futures::future::join_all(targets.iter().map(|&gpu_id| {
        let tasks = device_tasks.clone();
        async move {
            let outcome = tasks
                .send_command_blocking(gpu_id, GpuCommand::GetTotalEnergyConsumption, received_at)
                .await;
            (gpu_id, outcome)
        }
    }))
    .await;

    let mut energies: HashMap<String, GpuEnergyResponse> = HashMap::new();
    let mut failures: HashMap<String, String> = HashMap::new();
    for (gpu_id, outcome) in readings {
        let key = gpu_id.to_string();
        match outcome {
            Ok(GpuResponse::Energy { energy_mj }) => {
                energies.insert(key, GpuEnergyResponse { energy_mj });
            }
            Ok(_) => {
                failures.insert(key, "Unexpected response type".to_string());
            }
            Err(e) => {
                failures.insert(key, e.to_string());
            }
        }
    }

    if failures.is_empty() {
        Ok(HttpResponse::Ok().json(energies))
    } else {
        Ok(HttpResponse::InternalServerError().json(serde_json::json!({
            "errors": failures
        })))
    }
}
/// Returns a copy of `snapshot` restricted to the GPU indices in `gpu_ids`.
///
/// With `None`, the whole snapshot is cloned unchanged. With `Some(ids)`, the timestamp is kept
/// and only the `power_mw` entries whose GPU index appears in `ids` are retained.
fn filter_snapshot(snapshot: &GpuPowerSnapshot, gpu_ids: &Option<Vec<usize>>) -> GpuPowerSnapshot {
    let wanted = match gpu_ids {
        Some(ids) => ids,
        None => return snapshot.clone(),
    };
    let power_mw = snapshot
        .power_mw
        .iter()
        .map(|(&gpu, &mw)| (gpu, mw))
        .filter(|(gpu, _)| wanted.contains(gpu))
        .collect();
    GpuPowerSnapshot {
        timestamp_ms: snapshot.timestamp_ms,
        power_mw,
    }
}
#[actix_web::get("/get_power")]
#[tracing::instrument(skip(broadcast), fields(gpu_ids = ?query.gpu_ids))]
async fn get_power_handler(
query: web::Query<GpuReadQuery>,
broadcast: web::Data<GpuPowerBroadcast>,
) -> HttpResponse {
tracing::info!("Received request");
let gpu_ids = query.gpu_ids.as_ref().map(|s| parse_gpu_ids(s));
if let Some(ref ids) = gpu_ids {
if let Err(unknown) = broadcast.validate_ids(ids) {
return HttpResponse::BadRequest().json(serde_json::json!({
"error": format!(
"Unknown GPU indices: {:?}. Available: {:?}",
unknown,
broadcast.valid_ids(),
)
}));
}
}
let _guard = broadcast.add_subscriber();
let snapshot = broadcast.wait_for_fresh().await.unwrap_or_default();
let filtered = filter_snapshot(&snapshot, &gpu_ids);
HttpResponse::Ok().json(filtered)
}
#[actix_web::get("/stream_power")]
#[tracing::instrument(skip(broadcast), fields(gpu_ids = ?query.gpu_ids))]
async fn stream_power_handler(
query: web::Query<GpuReadQuery>,
broadcast: web::Data<GpuPowerBroadcast>,
) -> HttpResponse {
tracing::info!("Received request");
let gpu_ids = query.gpu_ids.as_ref().map(|s| parse_gpu_ids(s));
if let Some(ref ids) = gpu_ids {
if let Err(unknown) = broadcast.validate_ids(ids) {
return HttpResponse::BadRequest().json(serde_json::json!({
"error": format!(
"Unknown GPU indices: {:?}. Available: {:?}",
unknown,
broadcast.valid_ids(),
)
}));
}
}
let guard = broadcast.add_subscriber();
let stream = broadcast.stream().map(move |snapshot| {
let _ = &guard; let filtered = filter_snapshot(&snapshot, &gpu_ids);
let json = serde_json::to_string(&filtered).unwrap_or_default();
Ok::<_, actix_web::Error>(Bytes::from(format!("data: {json}\n\n")))
});
HttpResponse::Ok()
.insert_header(("Content-Type", "text/event-stream"))
.insert_header(("Cache-Control", "no-cache"))
.streaming(stream)
}
/// Registers the read-only GPU endpoints: cumulative energy, a single power snapshot, and the
/// power SSE stream.
pub fn gpu_read_routes(cfg: &mut web::ServiceConfig) {
    cfg.service(get_cumulative_energy_handler);
    cfg.service(get_power_handler);
    cfg.service(stream_power_handler);
}
/// Registers the state-changing GPU endpoints generated by `impl_handler_for_gpu_command!`:
/// persistence mode, power limit, and GPU/memory clock locking and resetting.
pub fn gpu_control_routes(cfg: &mut web::ServiceConfig) {
    cfg.service(set_persistence_mode_handler);
    cfg.service(set_power_limit_handler);
    cfg.service(set_gpu_locked_clocks_handler);
    cfg.service(reset_gpu_locked_clocks_handler);
    cfg.service(set_mem_locked_clocks_handler);
    cfg.service(reset_mem_locked_clocks_handler);
}