use crate::client::apis::configuration::{Configuration, TlsConfig};
use crate::client::apis::default_api;
use crate::client::commands::get_env_user_name;
use crate::client::commands::select_workflow_interactively;
use crate::client::job_runner::JobRunner;
use crate::client::log_paths::get_job_runner_log_file;
use crate::client::utils::detect_nvidia_gpus;
use crate::client::workflow_manager::WorkflowManager;
use crate::config::TorcConfig;
use crate::models;
use chrono::{DateTime, Utc};
use clap::Parser;
use env_logger::Builder;
use log::{LevelFilter, error, info};
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use sysinfo::{CpuRefreshKind, RefreshKind, System, SystemExt};
/// Tee-style writer that duplicates every byte to both stdout and a log file.
///
/// Used as the `env_logger` pipe target so log output is visible on the
/// console and persisted to the per-run log file at the same time.
struct MultiWriter {
    stdout: std::io::Stdout,
    file: File,
}

impl Write for MultiWriter {
    /// Writes the entire buffer to both sinks.
    ///
    /// Both destinations receive the full buffer via `write_all`, and the
    /// full length is reported. Returning a partial count from the file
    /// (as the previous implementation did) would make callers such as
    /// `write_all` retry the unwritten tail — duplicating those bytes on
    /// stdout, which had already consumed the whole buffer.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.stdout.write_all(buf)?;
        self.file.write_all(buf)?;
        Ok(buf.len())
    }

    /// Flushes both sinks; stdout first, then the file.
    fn flush(&mut self) -> std::io::Result<()> {
        self.stdout.flush()?;
        self.file.flush()
    }
}
// CLI arguments for the local job-runner command.
//
// NOTE(review): comments here are deliberately `//` rather than `///` —
// clap turns doc comments on derive fields into help text, which would
// change the CLI's observable output.
#[derive(Parser, Debug)]
#[command(about = "Run jobs locally on the current node", long_about = None)]
pub struct Args {
    // Positional workflow id; when omitted, `run` prompts interactively.
    #[arg()]
    pub workflow_id: Option<i64>,
    // Base URL of the torc service API.
    #[arg(short, long, default_value = "http://localhost:8080/torc-service/v1")]
    pub url: String,
    // Directory for job output and the runner log file; created if missing.
    #[arg(short, long, default_value = "torc_output")]
    pub output_dir: PathBuf,
    // Seconds between polls for new jobs (parsed from the "5.0" default).
    #[arg(short, long, default_value = "5.0")]
    pub poll_interval: f64,
    // Cap on concurrently running jobs; unlimited when absent.
    #[arg(long)]
    pub max_parallel_jobs: Option<i64>,
    // Wall-clock limit passed through to the job runner as a string.
    #[arg(long)]
    pub time_limit: Option<String>,
    // Absolute end time; parsed as a chrono `DateTime<Utc>` in `run`.
    #[arg(long)]
    pub end_time: Option<String>,
    // Resource overrides; when absent, values are detected from the host.
    #[arg(long)]
    pub num_cpus: Option<i64>,
    #[arg(long)]
    pub memory_gb: Option<f64>,
    #[arg(long)]
    pub num_gpus: Option<i64>,
    #[arg(long)]
    pub num_nodes: Option<i64>,
    // Optional scheduler configuration to associate with this runner.
    #[arg(long)]
    pub scheduler_config_id: Option<i64>,
    // Prefix applied to job log file names by the runner.
    #[arg(long)]
    pub log_prefix: Option<String>,
    // CPUs pinned per job when CPU affinity is enabled.
    #[arg(long)]
    pub cpu_affinity_cpus_per_job: Option<i64>,
    // Log verbosity: error|warn|info|debug|trace (invalid falls back to info).
    #[arg(long, default_value = "info")]
    pub log_level: String,
    // Basic-auth password; also readable from TORC_PASSWORD, hidden in help.
    #[arg(long, env = "TORC_PASSWORD", hide_env_values = true)]
    pub password: Option<String>,
    // Path to a CA certificate for TLS verification.
    #[arg(long, env = "TORC_TLS_CA_CERT")]
    pub tls_ca_cert: Option<String>,
    // Skip TLS certificate verification (insecure; for testing).
    #[arg(long, env = "TORC_TLS_INSECURE")]
    pub tls_insecure: bool,
}
/// Registers this host as a compute node for the chosen workflow and runs
/// its jobs locally until the runner finishes.
///
/// Side effects: may prompt interactively for a workflow, may initialize an
/// uninitialized workflow, creates the output directory and a combined
/// stdout+file log, and registers an active compute-node record with the
/// server. Any unrecoverable setup failure prints a diagnostic and exits
/// the process with code 1.
pub fn run(args: &Args) {
    let hostname = hostname::get()
        .expect("Failed to get hostname")
        .into_string()
        .expect("Hostname is not valid UTF-8");

    // API client configuration, with optional TLS overrides from CLI/env.
    let tls = TlsConfig {
        ca_cert_path: args.tls_ca_cert.as_ref().map(std::path::PathBuf::from),
        insecure: args.tls_insecure,
    };
    let mut config = Configuration::with_tls(tls);
    config.base_path = args.url.clone();

    // Resolve the user once: it serves both as the basic-auth username and
    // as the owner used for interactive workflow selection.
    let user = get_env_user_name();
    if let Some(ref password) = args.password {
        config.basic_auth = Some((user.clone(), Some(password.clone())));
    }

    // Use the workflow id from the command line, or prompt for one.
    let workflow_id = args.workflow_id.unwrap_or_else(|| {
        select_workflow_interactively(&config, &user).unwrap_or_else(|e| {
            eprintln!("Error selecting workflow: {}", e);
            std::process::exit(1);
        })
    });
    let workflow = match default_api::get_workflow(&config, workflow_id) {
        Ok(workflow) => workflow,
        Err(e) => {
            eprintln!("Error getting workflow: {}", e);
            std::process::exit(1);
        }
    };

    // If every job is still uninitialized, initialize the workflow first so
    // the runner has schedulable work to pick up.
    match default_api::is_workflow_uninitialized(&config, workflow_id) {
        Ok(response) => {
            if let Some(is_uninitialized) =
                response.get("is_uninitialized").and_then(|v| v.as_bool())
                && is_uninitialized
            {
                eprintln!(
                    "Workflow {} has all jobs uninitialized. Initializing workflow...",
                    workflow_id
                );
                let torc_config = TorcConfig::load().unwrap_or_default();
                let workflow_manager =
                    WorkflowManager::new(config.clone(), torc_config, workflow.clone());
                match workflow_manager.initialize(false) {
                    Ok(()) => {
                        eprintln!("Successfully initialized workflow {}", workflow_id);
                    }
                    Err(e) => {
                        eprintln!("Error initializing workflow: {}", e);
                        std::process::exit(1);
                    }
                }
            }
        }
        Err(e) => {
            eprintln!("Error checking if workflow is uninitialized: {}", e);
            std::process::exit(1);
        }
    }

    // The current run id distinguishes restarts of the same workflow and is
    // baked into the log file name and compute-node label below.
    let run_id = match default_api::get_workflow_status(&config, workflow_id) {
        Ok(status) => status.run_id,
        Err(e) => {
            eprintln!("Error getting workflow status: {}", e);
            std::process::exit(1);
        }
    };

    if let Err(e) = std::fs::create_dir_all(&args.output_dir) {
        eprintln!(
            "Error creating output directory {}: {}",
            args.output_dir.display(),
            e
        );
        std::process::exit(1);
    }

    // Logging is teed to stdout and a per-host/per-run log file.
    let log_file_path =
        get_job_runner_log_file(args.output_dir.clone(), &hostname, workflow_id, run_id);
    let log_file = match File::create(&log_file_path) {
        Ok(file) => file,
        Err(e) => {
            eprintln!("Error creating log file {}: {}", log_file_path, e);
            std::process::exit(1);
        }
    };
    let multi_writer = MultiWriter {
        stdout: std::io::stdout(),
        file: log_file,
    };
    let log_level_filter = match args.log_level.to_lowercase().as_str() {
        "error" => LevelFilter::Error,
        "warn" => LevelFilter::Warn,
        "info" => LevelFilter::Info,
        "debug" => LevelFilter::Debug,
        "trace" => LevelFilter::Trace,
        _ => {
            eprintln!(
                "Invalid log level '{}', defaulting to 'info'",
                args.log_level
            );
            LevelFilter::Info
        }
    };
    let mut builder = Builder::from_default_env();
    // try_init().ok(): a second initialization (e.g. when called twice in
    // one process) is harmless, so the error is deliberately ignored.
    builder
        .target(env_logger::Target::Pipe(Box::new(multi_writer)))
        .filter_level(log_level_filter)
        .try_init()
        .ok();
    info!("Starting job runner");
    info!("Hostname: {}", hostname);
    info!("Output directory: {}", args.output_dir.display());
    info!("Log file: {}", log_file_path);

    // Optional end time; must parse as a chrono `DateTime<Utc>`.
    let parsed_end_time = if let Some(end_time_str) = &args.end_time {
        match end_time_str.parse::<DateTime<Utc>>() {
            Ok(dt) => Some(dt),
            Err(e) => {
                error!("Error parsing end_time: {}", e);
                std::process::exit(1);
            }
        }
    } else {
        None
    };

    // Detect local resources; explicit CLI flags override detected values.
    let refresh_kind = RefreshKind::new()
        .with_cpu(CpuRefreshKind::everything())
        .with_memory();
    let mut system = System::new_with_specifics(refresh_kind);
    system.refresh_cpu();
    system.refresh_memory();
    let system_cpus = system.cpus().len() as i64;
    let system_memory_gb = (system.total_memory() as f64) / (1024.0 * 1024.0 * 1024.0);
    let system_gpus = detect_nvidia_gpus();
    let resources = models::ComputeNodesResources::new(
        args.num_cpus.unwrap_or(system_cpus),
        args.memory_gb.unwrap_or(system_memory_gb),
        args.num_gpus.unwrap_or(system_gpus),
        args.num_nodes.unwrap_or(1),
    );

    // Register this process as an active compute node. Record the real PID:
    // the previous hard-coded placeholder of 1 identified init, not this
    // runner, making the server-side record useless for diagnostics.
    let pid = std::process::id() as i64;
    let unique_label = format!("wf{}_h{}_r{}", workflow_id, hostname, run_id);
    let mut compute_node_model = models::ComputeNodeModel::new(
        workflow_id,
        hostname.clone(),
        pid,
        Utc::now().to_rfc3339(),
        resources.num_cpus,
        resources.memory_gb,
        resources.num_gpus,
        resources.num_nodes,
        "local".to_string(),
        None,
    );
    compute_node_model.is_active = Some(true);
    let compute_node = match default_api::create_compute_node(&config, compute_node_model) {
        Ok(node) => node,
        Err(e) => {
            error!("Error creating compute node: {}", e);
            std::process::exit(1);
        }
    };

    // Hand off to the job runner, which polls for and executes jobs until
    // the workflow completes or the time limit / end time is reached.
    let mut job_runner = JobRunner::new(
        config.clone(),
        workflow,
        run_id,
        compute_node.id.expect("Compute node ID should be set"),
        args.output_dir.clone(),
        args.poll_interval,
        args.max_parallel_jobs,
        args.time_limit.clone(),
        parsed_end_time,
        resources,
        args.scheduler_config_id,
        args.log_prefix.clone(),
        args.cpu_affinity_cpus_per_job,
        false,
        unique_label,
        None,
    );
    match job_runner.run_worker() {
        Ok(result) => {
            info!(
                "Job runner completed successfully (had_failures={}, had_terminations={})",
                result.had_failures, result.had_terminations
            );
        }
        Err(e) => {
            error!("Job runner failed: {}", e);
            std::process::exit(1);
        }
    }
}