use anyhow::Result;
use clap::Args;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use systemprompt_extension::ExtensionRegistry;
use systemprompt_runtime::AppContext;
use systemprompt_traits::{Job, JobContext};
use super::types::{BatchJobRunOutput, JobRunOutput, JobRunResult};
use crate::shared::CommandResult;
// CLI arguments for `run`. Jobs are selected by exactly one of: positional job names,
// `--all`, or `--tag <tag>`. Combinations are rejected by clap instead of one
// selection silently overriding the others.
#[derive(Debug, Args)]
pub struct RunArgs {
    #[arg(help = "Job name(s) to run", num_args = 1..)]
    pub job_names: Vec<String>,
    // Without the conflict, `--all` would silently ignore `--tag` and any named jobs.
    #[arg(
        long,
        help = "Run all enabled jobs",
        conflicts_with_all = ["tag", "job_names"]
    )]
    pub all: bool,
    // Without the conflict, `--tag` would silently ignore any named jobs.
    #[arg(
        long,
        help = "Run all jobs with the specified tag",
        conflicts_with = "job_names"
    )]
    pub tag: Option<String>,
    // NOTE(review): `execute` in this module never reads this flag; jobs are always
    // awaited one after another.
    #[arg(long, help = "Run jobs sequentially instead of in parallel")]
    pub sequential: bool,
    #[arg(
        long = "param",
        short = 'p',
        value_name = "KEY=VALUE",
        help = "Job parameters (can be specified multiple times)"
    )]
    pub params: Vec<String>,
}
pub async fn execute(args: RunArgs) -> Result<CommandResult<BatchJobRunOutput>> {
let ctx = Arc::new(AppContext::new().await?);
let registry = ExtensionRegistry::discover();
let parameters = parse_params(&args.params)?;
let job_names: Vec<String> = if args.all {
let mut names: Vec<String> = registry
.all_jobs()
.into_iter()
.filter(|j| j.enabled())
.map(|j| j.name().to_string())
.collect();
for job in inventory::iter::<&'static dyn Job> {
if job.enabled() && !names.contains(&job.name().to_string()) {
names.push(job.name().to_string());
}
}
names
} else if let Some(tag) = &args.tag {
let jobs = registry.jobs_by_tag(tag);
if jobs.is_empty() {
anyhow::bail!("No jobs found with tag '{}'", tag);
}
jobs.into_iter()
.filter(|j| j.enabled())
.map(|j| j.name().to_string())
.collect()
} else if args.job_names.is_empty() {
anyhow::bail!("Specify job name(s), use --all, or use --tag <tag> to run jobs");
} else {
args.job_names
};
let mut results = Vec::new();
for job_name in &job_names {
let result = run_single_job(job_name, Arc::clone(&ctx), ®istry, ¶meters).await;
results.push(result);
}
let succeeded = results.iter().filter(|r| r.result.success).count();
let failed = results.len() - succeeded;
let output = BatchJobRunOutput {
total: results.len(),
succeeded,
failed,
jobs_run: results,
};
Ok(CommandResult::table(output).with_title("Job Execution Results"))
}
/// Parses repeated `--param KEY=VALUE` arguments into a key/value map.
///
/// Only the first `=` separates key from value, so a value may itself contain `=`
/// (`query=a=b` yields key `query`, value `a=b`). An empty value (`KEY=`) is allowed.
/// If a key is repeated, the last occurrence wins.
///
/// # Errors
///
/// Returns an error for the first argument that has no `=` or has an empty key
/// (e.g. `=value`).
fn parse_params(params: &[String]) -> Result<HashMap<String, String>> {
    params
        .iter()
        .map(|param| match param.split_once('=') {
            Some((key, value)) if !key.is_empty() => Ok((key.to_string(), value.to_string())),
            // Missing `=` and an empty key are both malformed input.
            _ => Err(anyhow::anyhow!(
                "Invalid parameter format '{}'. Use KEY=VALUE format.",
                param
            )),
        })
        .collect()
}
async fn run_single_job(
job_name: &str,
ctx: Arc<AppContext>,
registry: &ExtensionRegistry,
parameters: &HashMap<String, String>,
) -> JobRunOutput {
let start = Instant::now();
let ext_job = registry.job_by_name(job_name);
let inv_job = inventory::iter::<&'static dyn Job>
.into_iter()
.find(|&j| j.name() == job_name)
.copied();
if ext_job.is_none() && inv_job.is_none() {
return JobRunOutput {
job_name: job_name.to_string(),
status: "failed".to_string(),
duration_ms: start.elapsed().as_millis() as u64,
result: JobRunResult {
success: false,
message: Some(format!("Job '{}' not found", job_name)),
items_processed: None,
items_failed: None,
},
};
}
let db_pool = Arc::clone(ctx.db_pool());
let db_pool_any: Arc<dyn std::any::Any + Send + Sync> = Arc::new(db_pool);
let app_context_any: Arc<dyn std::any::Any + Send + Sync> = Arc::new(Arc::clone(&ctx));
let job_ctx = JobContext::new(db_pool_any, app_context_any).with_parameters(parameters.clone());
let execute_result = if let Some(job) = ext_job {
job.execute(&job_ctx).await
} else if let Some(job) = inv_job {
job.execute(&job_ctx).await
} else {
unreachable!()
};
match execute_result {
Ok(result) => JobRunOutput {
job_name: job_name.to_string(),
status: if result.success { "success" } else { "failed" }.to_string(),
duration_ms: start.elapsed().as_millis() as u64,
result: JobRunResult {
success: result.success,
message: result.message,
items_processed: result.items_processed,
items_failed: result.items_failed,
},
},
Err(e) => JobRunOutput {
job_name: job_name.to_string(),
status: "failed".to_string(),
duration_ms: start.elapsed().as_millis() as u64,
result: JobRunResult {
success: false,
message: Some(e.to_string()),
items_processed: None,
items_failed: None,
},
},
}
}