use serde_json::json;
use sqlx::postgres::PgPoolOptions;
use std::sync::Arc;
use stormchaser_engine::handler;
use stormchaser_model::auth::OpaClient;
use stormchaser_model::step::{StepInstance, StepStatus};
use stormchaser_model::RunId;
use stormchaser_tls::TlsConfig;
use stormchaser_tls::TlsReloader;
#[tokio::test]
async fn test_dynamic_parallelism_with_batching() {
let db_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
dotenvy::dotenv().ok();
format!(
"postgres://stormchaser:{}@localhost:5432/stormchaser",
std::env::var("STORMCHASER_DEV_PASSWORD")
.expect("STORMCHASER_DEV_PASSWORD must be set if DATABASE_URL is not set")
)
});
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(&db_url)
.await
.unwrap();
let nats_url = std::env::var("NATS_URL").unwrap_or_else(|_| "nats://localhost:4222".into());
let nats_client = async_nats::connect(nats_url).await.unwrap();
let opa_client = Arc::new(OpaClient::new(None, None));
let run_id = RunId::new_v4();
let dsl = r#"
stormchaser_dsl_version = "v1"
workflow "parallel-test" {
steps {
step "generate" "RunContainer" {
image = "alpine"
next = ["process"]
}
step "process" "RunContainer" {
iterate = "[\"a\", \"b\", \"c\", \"d\"]"
iterate_as = "item"
strategy {
max_parallel = 2
}
params = {
val = "${item}"
}
image = "alpine"
}
}
}
"#;
let payload = json!({
"run_id": run_id,
"dsl": dsl,
"inputs": {},
"initiating_user": "test"
});
handler::handle_workflow_direct(
payload,
pool.clone(),
opa_client.clone(),
nats_client.clone(),
)
.await
.unwrap();
handler::handle_workflow_start_pending(
run_id,
pool.clone(),
nats_client.clone(),
Arc::new(TlsReloader::new(TlsConfig::default()).await.unwrap()),
)
.await
.unwrap();
let instances: Vec<StepInstance> = sqlx::query_as(r#"SELECT id, run_id, step_name, step_type, status as "status", iteration_index, runner_id, affinity_context, started_at, finished_at, exit_code, error, spec, params, created_at FROM step_instances WHERE run_id = $1"#)
.bind(run_id)
.fetch_all(&pool)
.await
.unwrap();
assert_eq!(instances.len(), 1);
assert_eq!(instances[0].step_name, "generate");
assert!(
instances[0].status == StepStatus::Pending || instances[0].status == StepStatus::Running,
"Status was {:?}",
instances[0].status
);
let generate_id = instances[0].id;
let completed_payload = json!({
"run_id": run_id,
"step_id": generate_id,"event_type": "StepCompletedEvent",
"timestamp": chrono::Utc::now(),
"exit_code": 0
});
let log_backend = Arc::new(None);
handler::handle_step_completed(
serde_json::from_value(completed_payload).unwrap(),
pool.clone(),
nats_client.clone(),
log_backend.clone(),
Arc::new(TlsReloader::new(TlsConfig::default()).await.unwrap()),
)
.await
.unwrap();
let instances: Vec<StepInstance> = sqlx::query_as(r#"SELECT id, run_id, step_name, step_type, status as "status", iteration_index, runner_id, affinity_context, started_at, finished_at, exit_code, error, spec, params, created_at FROM step_instances WHERE run_id = $1 ORDER BY iteration_index ASC"#)
.bind(run_id)
.fetch_all(&pool)
.await
.unwrap();
assert_eq!(instances.len(), 5);
let process_instances: Vec<&StepInstance> = instances
.iter()
.filter(|i| i.step_name == "process")
.collect();
assert_eq!(process_instances.len(), 4);
assert!(
process_instances[0].status == StepStatus::Pending
|| process_instances[0].status == StepStatus::Running
);
assert!(
process_instances[1].status == StepStatus::Pending
|| process_instances[1].status == StepStatus::Running
);
assert_eq!(process_instances[2].status, StepStatus::WaitingForEvent);
assert_eq!(process_instances[3].status, StepStatus::WaitingForEvent);
let completed_payload = json!({
"run_id": run_id,
"step_id": process_instances[0].id,"event_type": "StepCompletedEvent",
"timestamp": chrono::Utc::now(),
"exit_code": 0
});
handler::handle_step_completed(
serde_json::from_value(completed_payload).unwrap(),
pool.clone(),
nats_client.clone(),
log_backend.clone(),
Arc::new(TlsReloader::new(TlsConfig::default()).await.unwrap()),
)
.await
.unwrap();
let it2: StepInstance = sqlx::query_as(r#"SELECT id, run_id, step_name, step_type, status as "status", iteration_index, runner_id, affinity_context, started_at, finished_at, exit_code, error, spec, params, created_at FROM step_instances WHERE id = $1"#)
.bind(process_instances[2].id)
.fetch_one(&pool)
.await
.unwrap();
assert!(it2.status == StepStatus::Pending || it2.status == StepStatus::Running);
let it3: StepInstance = sqlx::query_as(r#"SELECT id, run_id, step_name, step_type, status as "status", iteration_index, runner_id, affinity_context, started_at, finished_at, exit_code, error, spec, params, created_at FROM step_instances WHERE id = $1"#)
.bind(process_instances[3].id)
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(it3.status, StepStatus::WaitingForEvent);
}