use crate::core::info::{IgnoredGpuProcess, SchedulerInfo};
use crate::core::job::{DependencyMode, Job, JobNotifications};
use anyhow::{anyhow, Context};
use reqwest::{Client as ReqwestClient, StatusCode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
fn connection_error_context(err: reqwest::Error) -> anyhow::Error {
if err.is_connect() {
anyhow!(
"Could not connect to gflowd server. Is the server running?\n\
Hint: Start the server with 'gflowd up'"
)
} else {
err.into()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobSubmitResponse {
pub id: u32,
pub run_name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaginatedJobsResponse {
pub jobs: Vec<Job>,
pub total: usize,
pub limit: usize,
pub offset: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct UpdateJobRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub command: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub script: Option<PathBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub gpus: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub conda_env: Option<Option<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<u8>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<HashMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub time_limit: Option<Option<Duration>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub memory_limit_mb: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub gpu_memory_limit_mb: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub depends_on_ids: Option<Vec<u32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dependency_mode: Option<Option<DependencyMode>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub auto_cancel_on_dependency_failure: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_concurrent: Option<Option<usize>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_retries: Option<Option<u32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub notifications: Option<JobNotifications>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateJobResponse {
pub job: Job,
pub updated_fields: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopJob {
pub id: u32,
pub name: Option<String>,
pub runtime_secs: f64,
pub gpus: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
pub user: Option<String>,
pub since: Option<u64>,
pub total_jobs: usize,
pub completed_jobs: usize,
pub failed_jobs: usize,
pub cancelled_jobs: usize,
pub timeout_jobs: usize,
pub running_jobs: usize,
pub queued_jobs: usize,
pub avg_wait_secs: Option<f64>,
pub avg_runtime_secs: Option<f64>,
pub total_gpu_hours: f64,
pub jobs_with_gpus: usize,
pub avg_gpus_per_job: f64,
pub peak_gpu_usage: u32,
pub success_rate: f64,
pub top_jobs: Vec<TopJob>,
}
#[derive(Debug, Clone)]
pub struct Client {
client: ReqwestClient,
base_url: String,
}
impl Client {
pub fn build(config: &crate::config::Config) -> anyhow::Result<Self> {
crate::tls::ensure_rustls_provider_installed();
let host = &config.daemon.host;
let port = config.daemon.port;
let base_url = format!("http://{host}:{port}");
let client = ReqwestClient::new();
Ok(Self { client, base_url })
}
async fn extract_error_message(response: reqwest::Response) -> String {
let error_body = response
.text()
.await
.unwrap_or_else(|_| String::from("Unknown error"));
if let Ok(json_error) = serde_json::from_str::<serde_json::Value>(&error_body) {
if let Some(error_msg) = json_error.get("error").and_then(|e| e.as_str()) {
return error_msg.to_string();
}
}
error_body
}
async fn post_expect_success(&self, path: String, action: &str) -> anyhow::Result<()> {
let response = self
.client
.post(path)
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let status = response.status();
let error_msg = Self::extract_error_message(response).await;
let detail = if error_msg.trim().is_empty() {
status.to_string()
} else {
format!("{status}: {error_msg}")
};
return Err(anyhow!("Failed to {action}: {detail}"));
}
Ok(())
}
pub async fn list_jobs(&self) -> anyhow::Result<Vec<Job>> {
let jobs = self
.client
.get(format!("{}/jobs", self.base_url))
.send()
.await
.map_err(connection_error_context)?
.json::<Vec<Job>>()
.await
.context("Failed to parse jobs from response")?;
Ok(jobs)
}
pub async fn list_jobs_with_query(
&self,
states: Option<String>,
user: Option<String>,
limit: Option<usize>,
offset: Option<usize>,
created_after: Option<i64>,
order: Option<String>,
) -> anyhow::Result<Vec<Job>> {
let mut request = self.client.get(format!("{}/jobs", self.base_url));
let mut params = vec![];
if let Some(s) = states {
params.push(("state", s));
}
if let Some(u) = user {
params.push(("user", u));
}
if let Some(l) = limit {
params.push(("limit", l.to_string()));
}
if let Some(o) = offset {
params.push(("offset", o.to_string()));
}
if let Some(t) = created_after {
params.push(("created_after", t.to_string()));
}
if let Some(order) = order {
params.push(("order", order));
}
if !params.is_empty() {
request = request.query(¶ms);
}
let response = request.send().await.map_err(connection_error_context)?;
let response_text = response.text().await?;
if let Ok(paginated) = serde_json::from_str::<PaginatedJobsResponse>(&response_text) {
Ok(paginated.jobs)
} else {
serde_json::from_str::<Vec<Job>>(&response_text)
.context("Failed to parse jobs from response")
}
}
pub async fn get_job(&self, job_id: u32) -> anyhow::Result<Option<Job>> {
tracing::debug!("Getting job {job_id}");
let response = self
.client
.get(format!("{}/jobs/{}", self.base_url, job_id))
.send()
.await
.map_err(connection_error_context)?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
}
let job = response
.json::<Job>()
.await
.context("Failed to parse job from response")?;
Ok(Some(job))
}
pub async fn add_job(&self, job: Job) -> anyhow::Result<JobSubmitResponse> {
tracing::debug!("Adding job: {job:?}");
let response = self
.client
.post(format!("{}/jobs", self.base_url))
.json(&job)
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow::anyhow!("Failed to add job: {}", error_msg));
}
let job_response: JobSubmitResponse = response
.json()
.await
.context("Failed to parse response json")?;
Ok(job_response)
}
pub async fn add_jobs(&self, jobs: Vec<Job>) -> anyhow::Result<Vec<JobSubmitResponse>> {
if jobs.is_empty() {
return Ok(Vec::new());
}
tracing::debug!("Adding {} jobs in batch", jobs.len());
let response = self
.client
.post(format!("{}/jobs/batch", self.base_url))
.json(&jobs)
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow::anyhow!("Failed to add batch jobs: {}", error_msg));
}
let job_responses: Vec<JobSubmitResponse> = response
.json()
.await
.context("Failed to parse batch response json")?;
Ok(job_responses)
}
pub async fn finish_job(&self, job_id: u32) -> anyhow::Result<()> {
tracing::debug!("Finishing job {job_id}");
self.post_expect_success(
format!("{}/jobs/{}/finish", self.base_url, job_id),
"finish job",
)
.await
}
pub async fn fail_job(&self, job_id: u32) -> anyhow::Result<()> {
tracing::debug!("Failing job {job_id}");
self.post_expect_success(
format!("{}/jobs/{}/fail", self.base_url, job_id),
"fail job",
)
.await
}
pub async fn cancel_job(&self, job_id: u32) -> anyhow::Result<()> {
tracing::debug!("Cancelling job {job_id}");
self.post_expect_success(
format!("{}/jobs/{}/cancel", self.base_url, job_id),
"cancel job",
)
.await
}
pub async fn hold_job(&self, job_id: u32) -> anyhow::Result<()> {
tracing::debug!("Holding job {job_id}");
self.post_expect_success(
format!("{}/jobs/{}/hold", self.base_url, job_id),
"hold job",
)
.await
}
pub async fn release_job(&self, job_id: u32) -> anyhow::Result<()> {
tracing::debug!("Releasing job {job_id}");
self.post_expect_success(
format!("{}/jobs/{}/release", self.base_url, job_id),
"release job",
)
.await
}
pub async fn update_job(
&self,
job_id: u32,
request: UpdateJobRequest,
) -> anyhow::Result<UpdateJobResponse> {
tracing::debug!("Updating job {job_id}");
let response = self
.client
.patch(format!("{}/jobs/{}", self.base_url, job_id))
.json(&request)
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow!("Failed to update job: {}", error_msg));
}
let result: UpdateJobResponse = response
.json()
.await
.context("Failed to parse update job response")?;
Ok(result)
}
pub async fn get_job_log_path(&self, job_id: u32) -> anyhow::Result<Option<String>> {
tracing::debug!("Getting log path for job {job_id}");
let response = self
.client
.get(format!("{}/jobs/{}/log", self.base_url, job_id))
.send()
.await
.map_err(connection_error_context)?;
let status = response.status();
if status == StatusCode::OK {
response
.json::<Option<String>>()
.await
.context("Failed to parse log path from response")
} else if status == StatusCode::NOT_FOUND {
Ok(None)
} else {
let body = response
.text()
.await
.unwrap_or_else(|_| String::from("<failed to read body>"));
Err(anyhow!(
"Failed to get log path for job {} (status {}): {}",
job_id,
status,
body
))
}
}
pub async fn get_stats(
&self,
user: Option<&str>,
since: Option<i64>,
) -> anyhow::Result<UsageStats> {
let mut params = vec![];
if let Some(u) = user {
params.push(("user", u.to_string()));
}
if let Some(s) = since {
params.push(("since", s.to_string()));
}
let mut request = self.client.get(format!("{}/stats", self.base_url));
if !params.is_empty() {
request = request.query(¶ms);
}
let stats = request
.send()
.await
.map_err(connection_error_context)?
.json::<UsageStats>()
.await
.context("Failed to parse stats from response")?;
Ok(stats)
}
pub async fn get_info(&self) -> anyhow::Result<SchedulerInfo> {
tracing::debug!("Getting scheduler info");
let info = self
.client
.get(format!("{}/info", self.base_url))
.send()
.await
.map_err(connection_error_context)?
.json::<SchedulerInfo>()
.await
.context("Failed to parse info from response")?;
Ok(info)
}
pub async fn get_health(&self) -> anyhow::Result<StatusCode> {
tracing::debug!("Getting health status");
let health = self
.client
.get(format!("{}/health", self.base_url))
.send()
.await
.map_err(connection_error_context)?
.status();
Ok(health)
}
pub async fn get_health_with_pid(&self) -> anyhow::Result<Option<u32>> {
tracing::debug!("Getting health status with PID");
let response = self
.client
.get(format!("{}/health", self.base_url))
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
return Ok(None);
}
let health_data: serde_json::Value = response
.json()
.await
.context("Failed to parse health response")?;
let pid = health_data
.get("pid")
.and_then(|p| p.as_u64())
.map(|p| p as u32);
Ok(pid)
}
pub async fn resolve_dependency(&self, username: &str, shorthand: &str) -> anyhow::Result<u32> {
tracing::debug!(
"Resolving dependency '{}' for user '{}'",
shorthand,
username
);
let response = self
.client
.get(format!("{}/jobs/resolve-dependency", self.base_url))
.query(&[("username", username), ("shorthand", shorthand)])
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow::anyhow!(
"Failed to resolve dependency: {}",
error_msg
));
}
let result: serde_json::Value = response
.json()
.await
.context("Failed to parse response json")?;
let job_id = result
.get("job_id")
.and_then(|v| v.as_u64())
.context("Invalid response format: missing or invalid job_id")?
as u32;
Ok(job_id)
}
pub async fn set_allowed_gpus(&self, allowed_indices: Option<Vec<u32>>) -> anyhow::Result<()> {
tracing::debug!("Setting allowed GPU indices: {:?}", allowed_indices);
let request_body = serde_json::json!({
"allowed_indices": allowed_indices
});
let response = self
.client
.post(format!("{}/gpus", self.base_url))
.json(&request_body)
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow!("Failed to set GPU configuration: {}", error_msg));
}
Ok(())
}
pub async fn list_ignored_gpu_processes(&self) -> anyhow::Result<Vec<IgnoredGpuProcess>> {
tracing::debug!("Listing ignored GPU processes");
let response = self
.client
.get(format!("{}/gpu-processes", self.base_url))
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow!(
"Failed to list ignored GPU processes: {}",
error_msg
));
}
let processes = response
.json::<Vec<IgnoredGpuProcess>>()
.await
.context("Failed to parse ignored GPU processes from response")?;
Ok(processes)
}
pub async fn ignore_gpu_process(&self, gpu_index: u32, pid: u32) -> anyhow::Result<()> {
tracing::debug!("Ignoring GPU process pid={} on gpu={}", pid, gpu_index);
self.post_gpu_process_action("ignore", gpu_index, pid).await
}
pub async fn unignore_gpu_process(&self, gpu_index: u32, pid: u32) -> anyhow::Result<()> {
tracing::debug!(
"Removing ignored GPU process pid={} on gpu={}",
pid,
gpu_index
);
self.post_gpu_process_action("unignore", gpu_index, pid)
.await
}
async fn post_gpu_process_action(
&self,
action: &str,
gpu_index: u32,
pid: u32,
) -> anyhow::Result<()> {
let request_body = serde_json::json!({
"gpu_index": gpu_index,
"pid": pid
});
let response = self
.client
.post(format!("{}/gpu-processes/{}", self.base_url, action))
.json(&request_body)
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow!("Failed to {} GPU process: {}", action, error_msg));
}
Ok(())
}
pub async fn set_group_max_concurrency(
&self,
group_id: &str,
max_concurrent: usize,
) -> anyhow::Result<usize> {
tracing::debug!(
"Setting max_concurrency for group '{}' to {}",
group_id,
max_concurrent
);
let request_body = serde_json::json!({
"max_concurrent": max_concurrent
});
let response = self
.client
.post(format!(
"{}/groups/{}/max-concurrency",
self.base_url, group_id
))
.json(&request_body)
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow!(
"Failed to set group max_concurrency: {}",
error_msg
));
}
let result: serde_json::Value = response
.json()
.await
.context("Failed to parse response json")?;
let updated_jobs = result
.get("updated_jobs")
.and_then(|v| v.as_u64())
.context("Invalid response format: missing or invalid updated_jobs")?
as usize;
Ok(updated_jobs)
}
pub async fn create_reservation(
&self,
user: String,
gpu_spec: crate::core::reservation::GpuSpec,
start_time: std::time::SystemTime,
duration_secs: u64,
) -> anyhow::Result<u32> {
use crate::core::reservation::GpuSpec;
let mut request_body = serde_json::json!({
"user": user,
"start_time": start_time,
"duration_secs": duration_secs,
});
match gpu_spec {
GpuSpec::Count(count) => {
request_body["gpu_count"] = serde_json::json!(count);
}
GpuSpec::Indices(indices) => {
request_body["gpu_indices"] = serde_json::json!(indices);
}
}
let response = self
.client
.post(format!("{}/reservations", self.base_url))
.json(&request_body)
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow!("Failed to create reservation: {}", error_msg));
}
let result: serde_json::Value = response
.json()
.await
.context("Failed to parse response json")?;
let reservation_id = result
.get("reservation_id")
.and_then(|v| v.as_u64())
.context("Invalid response format: missing or invalid reservation_id")?
as u32;
Ok(reservation_id)
}
pub async fn list_reservations(
&self,
user: Option<String>,
status: Option<String>,
active_only: bool,
) -> anyhow::Result<Vec<crate::core::reservation::GpuReservation>> {
let mut url = format!("{}/reservations", self.base_url);
let mut query_params = Vec::new();
if let Some(user) = user {
query_params.push(format!("user={}", user));
}
if let Some(status) = status {
query_params.push(format!("status={}", status));
}
if active_only {
query_params.push("active_only=true".to_string());
}
if !query_params.is_empty() {
url.push('?');
url.push_str(&query_params.join("&"));
}
let response = self
.client
.get(&url)
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow!("Failed to list reservations: {}", error_msg));
}
let reservations = response
.json()
.await
.context("Failed to parse response json")?;
Ok(reservations)
}
pub async fn get_reservation(
&self,
id: u32,
) -> anyhow::Result<Option<crate::core::reservation::GpuReservation>> {
let response = self
.client
.get(format!("{}/reservations/{}", self.base_url, id))
.send()
.await
.map_err(connection_error_context)?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
}
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow!("Failed to get reservation: {}", error_msg));
}
let reservation = response
.json()
.await
.context("Failed to parse response json")?;
Ok(Some(reservation))
}
pub async fn cancel_reservation(&self, id: u32) -> anyhow::Result<()> {
let response = self
.client
.delete(format!("{}/reservations/{}", self.base_url, id))
.send()
.await
.map_err(connection_error_context)?;
if !response.status().is_success() {
let error_msg = Self::extract_error_message(response).await;
return Err(anyhow!("Failed to cancel reservation: {}", error_msg));
}
Ok(())
}
}
pub async fn get_job_or_warn(client: &Client, job_id: u32) -> anyhow::Result<Option<Job>> {
match client.get_job(job_id).await? {
Some(job) => Ok(Some(job)),
None => {
eprintln!("Error: Job {} not found", job_id);
Ok(None)
}
}
}