use chrono::{DateTime, Local, NaiveDateTime, TimeZone};
use log::{debug, error, info, warn};
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::process::Command;
use std::thread;
use std::time::{Duration, Instant};
use crate::client::apis::configuration::Configuration;
use crate::client::apis::default_api;
/// Delay, in seconds, between successive server health pings issued by `send_with_retries`
/// while it waits for the server to become reachable again after a network error.
const PING_INTERVAL_SECONDS: u64 = 30;
/// Builds a [`Command`] that runs a single command string through the platform shell.
///
/// On Windows this is `cmd /C`; everywhere else it is `bash -c`. The caller appends the
/// command string as the next argument.
pub fn shell_command() -> Command {
    let (shell, run_flag) = if cfg!(target_os = "windows") {
        ("cmd", "/C")
    } else {
        ("bash", "-c")
    };
    let mut command = Command::new(shell);
    command.arg(run_flag);
    command
}
/// Runs `api_call`. If it fails with what looks like a network error, this waits for the
/// server to become reachable again and then retries the call.
///
/// Errors are classified as network errors by a case-insensitive substring match on their
/// `Display` text ("connection", "timeout", "network", "dns", "resolve", "unreachable").
/// Any other error is returned immediately, without waiting.
///
/// While waiting, the server is pinged with `default_api::ping` every
/// `PING_INTERVAL_SECONDS`. The sleep never runs past the overall deadline. Whenever a ping
/// succeeds, `api_call` is retried:
/// - if the retry succeeds, its result is returned;
/// - if it fails with a non-network error, that error is returned;
/// - if it fails with another network error, waiting resumes until
///   `wait_for_healthy_database_minutes` minutes have passed since the first failure.
///
/// # Errors
///
/// Returns the first non-network error encountered. If the wait budget runs out first, it
/// returns the most recent network error instead.
pub fn send_with_retries<T, E, F>(
    config: &Configuration,
    mut api_call: F,
    wait_for_healthy_database_minutes: u64,
) -> Result<T, E>
where
    F: FnMut() -> Result<T, E>,
    E: std::fmt::Display,
{
    /// Heuristic: does this error's display text describe a network-level failure?
    fn is_network_error(error: &impl std::fmt::Display) -> bool {
        let text = error.to_string().to_lowercase();
        [
            "connection",
            "timeout",
            "network",
            "dns",
            "resolve",
            "unreachable",
        ]
        .iter()
        .any(|needle| text.contains(needle))
    }

    let mut last_error = match api_call() {
        Ok(result) => return Ok(result),
        Err(e) => e,
    };
    if !is_network_error(&last_error) {
        return Err(last_error);
    }
    warn!(
        "Network error detected: {}. Entering retry loop for up to {} minutes.",
        last_error, wait_for_healthy_database_minutes
    );

    let start_time = Instant::now();
    // Saturate instead of overflowing: a huge "wait forever" value would otherwise panic in
    // debug builds or wrap to a tiny timeout in release builds.
    let timeout_duration =
        Duration::from_secs(wait_for_healthy_database_minutes.saturating_mul(60));
    let ping_interval = Duration::from_secs(PING_INTERVAL_SECONDS);

    loop {
        let elapsed = start_time.elapsed();
        if elapsed >= timeout_duration {
            error!(
                "Retry timeout exceeded ({} minutes). Giving up.",
                wait_for_healthy_database_minutes
            );
            return Err(last_error);
        }
        // `elapsed < timeout_duration` here, so the subtraction cannot underflow. Never sleep
        // past the deadline.
        thread::sleep(ping_interval.min(timeout_duration - elapsed));

        match default_api::ping(config) {
            Ok(_) => {
                info!("Server is back online. Retrying original API call.");
                match api_call() {
                    Ok(result) => return Ok(result),
                    // The server answered a ping, but the real call still hit the network.
                    // Keep waiting while budget remains instead of giving up early.
                    Err(e) if is_network_error(&e) => {
                        warn!(
                            "Retried API call failed with network error: {}. Continuing to wait...",
                            e
                        );
                        last_error = e;
                    }
                    Err(e) => return Err(e),
                }
            }
            Err(ping_error) => {
                debug!(
                    "Server still unreachable: {}. Continuing to wait...",
                    ping_error
                );
            }
        }
    }
}
pub fn claim_action(
config: &Configuration,
workflow_id: i64,
action_id: i64,
compute_node_id: Option<i64>,
wait_for_healthy_database_minutes: u64,
) -> Result<bool, Box<dyn std::error::Error>> {
let claimed = send_with_retries(
config,
|| -> Result<bool, Box<dyn std::error::Error>> {
let body = match compute_node_id {
Some(id) => serde_json::json!({ "compute_node_id": id }),
None => serde_json::json!({}),
};
match default_api::claim_action(config, workflow_id, action_id, body) {
Ok(result) => {
let claimed = result
.get("claimed")
.and_then(|v| v.as_bool())
.unwrap_or(false);
Ok(claimed)
}
Err(err) => {
if let crate::client::apis::Error::ResponseError(ref response_content) = err
&& response_content.status == reqwest::StatusCode::CONFLICT
{
return Ok(false);
}
Err(Box::new(err))
}
}
},
wait_for_healthy_database_minutes,
)?;
Ok(claimed)
}
/// Returns the number of NVIDIA GPUs reported by NVML.
///
/// Returns 0 when NVML cannot be initialized or the device count cannot be read. Those
/// failures are logged at debug level only. The init-failure message treats them as the
/// "no NVIDIA GPUs or drivers" case rather than as an error.
pub fn detect_nvidia_gpus() -> i64 {
    let nvml = match nvml_wrapper::Nvml::init() {
        Ok(handle) => handle,
        Err(init_err) => {
            debug!(
                "NVML initialization failed (no NVIDIA GPUs or drivers): {}",
                init_err
            );
            return 0;
        }
    };
    let gpu_count = match nvml.device_count() {
        Ok(found) => {
            info!("Detected {} NVIDIA GPU(s)", found);
            found as i64
        }
        Err(count_err) => {
            debug!("Failed to get NVIDIA GPU count: {}", count_err);
            0
        }
    };
    gpu_count
}
/// Writes every environment variable whose name contains `substring` to `file_path`.
///
/// Each variable becomes one `KEY=VALUE` line, sorted by name. The file is created or
/// truncated.
///
/// This is best-effort diagnostics: failures are logged with `error!` and never returned
/// or panicked on. Names and values that are not valid Unicode are converted lossily
/// (invalid sequences become U+FFFD) instead of aborting the capture.
pub fn capture_env_vars(file_path: &Path, substring: &str) {
    info!(
        "Capturing environment variables containing '{}' to: {}",
        substring,
        file_path.display()
    );
    // `std::env::vars()` panics on the first non-Unicode entry. `vars_os()` does not, so a
    // single odd variable cannot crash the capture.
    let mut env_vars: Vec<(String, String)> = std::env::vars_os()
        .map(|(key, value)| {
            (
                key.to_string_lossy().into_owned(),
                value.to_string_lossy().into_owned(),
            )
        })
        .filter(|(key, _)| key.contains(substring))
        .collect();
    env_vars.sort_by(|a, b| a.0.cmp(&b.0));

    let file = match File::create(file_path) {
        Ok(file) => file,
        Err(e) => {
            error!(
                "Error creating environment variables file {}: {}",
                file_path.display(),
                e
            );
            return;
        }
    };
    // Buffer the writes so there is no syscall per line.
    let mut writer = std::io::BufWriter::new(file);
    for (key, value) in &env_vars {
        if let Err(e) = writeln!(writer, "{}={}", key, value) {
            error!("Error writing environment variable to file: {}", e);
            return;
        }
    }
    // BufWriter's Drop discards flush errors, so flush explicitly to surface them.
    if let Err(e) = writer.flush() {
        error!("Error writing environment variable to file: {}", e);
        return;
    }
    info!(
        "Successfully captured {} environment variables",
        env_vars.len()
    );
}
pub fn capture_dmesg(file_path: &Path, filter_after: Option<DateTime<Local>>) {
info!("Capturing dmesg output to: {}", file_path.display());
if let Some(cutoff) = filter_after {
info!(
"Filtering dmesg to only include messages after: {}",
cutoff.format("%Y-%m-%d %H:%M:%S")
);
}
match Command::new("dmesg").arg("--ctime").output() {
Ok(output) => match File::create(file_path) {
Ok(mut file) => {
let stdout_str = String::from_utf8_lossy(&output.stdout);
let filtered_output = if let Some(cutoff) = filter_after {
filter_dmesg_by_time(&stdout_str, cutoff)
} else {
stdout_str.to_string()
};
if let Err(e) = file.write_all(filtered_output.as_bytes()) {
error!("Error writing dmesg stdout to file: {}", e);
}
if !output.stderr.is_empty() {
if let Err(e) = file.write_all(b"\n--- stderr ---\n") {
error!("Error writing dmesg separator: {}", e);
}
if let Err(e) = file.write_all(&output.stderr) {
error!("Error writing dmesg stderr to file: {}", e);
}
}
info!("Successfully captured dmesg output");
}
Err(e) => {
error!("Error creating dmesg file {}: {}", file_path.display(), e);
}
},
Err(e) => {
error!("Error running dmesg command: {}", e);
}
}
}
/// Keeps only the dmesg lines logged at or after `cutoff`.
///
/// Lines are handled as follows:
/// - A line whose timestamp `parse_dmesg_timestamp` cannot read inherits the keep/drop
///   decision of the most recent timestamped line before it.
/// - Lines before the first timestamped line are dropped.
///
/// Surviving lines are joined with `\n` and the result ends with a trailing newline. If no
/// line survives, a single `#`-prefixed note naming the cutoff is returned instead.
fn filter_dmesg_by_time(dmesg_output: &str, cutoff: DateTime<Local>) -> String {
    let mut keeping = false;
    let kept: Vec<&str> = dmesg_output
        .lines()
        .filter(|line| {
            if let Some(stamp) = parse_dmesg_timestamp(line) {
                keeping = stamp >= cutoff;
            }
            keeping
        })
        .collect();

    if kept.is_empty() {
        return format!(
            "# No dmesg messages found after {}\n",
            cutoff.format("%Y-%m-%d %H:%M:%S")
        );
    }
    let mut joined = kept.join("\n");
    joined.push('\n');
    joined
}
/// Extracts the `--ctime` timestamp from a dmesg line such as
/// `[Tue Nov 25 10:11:08 2025] message` and interprets it in the local time zone.
///
/// The timestamp is taken as the text between the first `[` and the first `]`. Returns
/// `None` in any of these cases:
/// - either bracket is missing, or the `]` comes before the `[`;
/// - the bracketed text does not match `%a %b %e %H:%M:%S %Y`;
/// - the local time does not map to exactly one instant.
fn parse_dmesg_timestamp(line: &str) -> Option<DateTime<Local>> {
    let open = line.find('[')?;
    let close = line.find(']')?;
    // `str::get` yields None for a reversed range, i.e. when the first ']' precedes the first '['.
    let bracketed = line.get(open + 1..close)?;
    NaiveDateTime::parse_from_str(bracketed, "%a %b %e %H:%M:%S %Y")
        .ok()
        .and_then(|naive| Local.from_local_datetime(&naive).single())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Same layout as `dmesg --ctime` timestamps.
    const CTIME_FORMAT: &str = "%a %b %e %H:%M:%S %Y";

    /// Parses `text` with [`CTIME_FORMAT`] into an unambiguous local timestamp.
    fn local_time(text: &str) -> DateTime<Local> {
        let naive = NaiveDateTime::parse_from_str(text, CTIME_FORMAT).unwrap();
        Local.from_local_datetime(&naive).single().unwrap()
    }

    #[test]
    fn test_parse_dmesg_timestamp() {
        let with_timestamp = [
            "[Tue Nov 25 10:11:08 2025] BIOS-e820: some message",
            "[Mon Dec 1 09:05:00 2025] kernel: message",
        ];
        for line in with_timestamp {
            assert!(
                parse_dmesg_timestamp(line).is_some(),
                "expected a timestamp in {line:?}"
            );
        }

        let without_timestamp = ["No timestamp here", "[invalid timestamp] message"];
        for line in without_timestamp {
            assert!(
                parse_dmesg_timestamp(line).is_none(),
                "expected no timestamp in {line:?}"
            );
        }
    }

    #[test]
    fn test_filter_dmesg_by_time() {
        let dmesg = "\
[Tue Nov 25 08:00:00 2025] old message 1
[Tue Nov 25 09:00:00 2025] old message 2
[Tue Nov 25 10:00:00 2025] new message 1
[Tue Nov 25 11:00:00 2025] new message 2
";
        let filtered = filter_dmesg_by_time(dmesg, local_time("Tue Nov 25 09:30:00 2025"));

        for kept in ["new message 1", "new message 2"] {
            assert!(filtered.contains(kept), "{kept:?} should be kept");
        }
        for dropped in ["old message 1", "old message 2"] {
            assert!(!filtered.contains(dropped), "{dropped:?} should be dropped");
        }
    }
}