use jamjet_core::node::NodeId;
use jamjet_core::workflow::ExecutionId;
use jamjet_ir::WorkflowIr;
use jamjet_state::backend::{StateBackend, WorkItem};
use jamjet_state::event::EventKind;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tracing::{debug, info, instrument, warn};
use uuid::Uuid;
/// Tunables for the scheduler's polling loop and dispatch limits.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Pause between two consecutive scheduling ticks.
    pub poll_interval: Duration,
    /// Upper bound on nodes of a single execution that may be scheduled at once.
    pub max_concurrent_nodes_per_execution: usize,
    /// Upper bound on nodes enqueued for a single execution during one tick.
    pub max_dispatch_per_tick: usize,
}

impl Default for SchedulerConfig {
    /// Defaults: poll every 500 ms, at most 16 concurrent nodes per execution,
    /// at most 8 dispatches per execution per tick.
    fn default() -> Self {
        const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(500);
        const DEFAULT_MAX_CONCURRENT_NODES: usize = 16;
        const DEFAULT_MAX_DISPATCH_PER_TICK: usize = 8;

        SchedulerConfig {
            poll_interval: DEFAULT_POLL_INTERVAL,
            max_concurrent_nodes_per_execution: DEFAULT_MAX_CONCURRENT_NODES,
            max_dispatch_per_tick: DEFAULT_MAX_DISPATCH_PER_TICK,
        }
    }
}
/// Polls the state backend for running executions and enqueues work items for
/// workflow nodes whose predecessors have all completed.
pub struct Scheduler {
/// State backend used for events, executions, workflow definitions and work items.
backend: Arc<dyn StateBackend>,
/// Polling interval and dispatch limits.
config: SchedulerConfig,
/// Parsed workflow IR, keyed by `(workflow_id, workflow_version)`.
ir_cache: Mutex<HashMap<(String, String), Arc<WorkflowIr>>>,
/// Per-execution progress folded from the event log; the entry is removed
/// once the execution reaches a terminal state.
progress: Mutex<HashMap<ExecutionId, ExecProgress>>,
}
impl Scheduler {
pub fn new(backend: Arc<dyn StateBackend>) -> Self {
Self {
backend,
config: SchedulerConfig::default(),
ir_cache: Mutex::new(HashMap::new()),
progress: Mutex::new(HashMap::new()),
}
}
pub fn with_config(mut self, config: SchedulerConfig) -> Self {
self.config = config;
self
}
pub fn with_poll_interval(mut self, interval: Duration) -> Self {
self.config.poll_interval = interval;
self
}
/// Runs the scheduling loop forever: one tick, then a sleep of
/// `poll_interval`. A failing tick is logged and the loop carries on.
pub async fn run(&self) {
    let cfg = &self.config;
    info!(
        "Scheduler started (poll_interval={:?}, max_concurrent={}, max_dispatch_per_tick={})",
        cfg.poll_interval,
        cfg.max_concurrent_nodes_per_execution,
        cfg.max_dispatch_per_tick,
    );
    loop {
        match self.tick().await {
            Ok(()) => {}
            Err(e) => {
                warn!("Scheduler tick error: {e}");
            }
        }
        tokio::time::sleep(cfg.poll_interval).await;
    }
}
#[instrument(skip(self))]
async fn tick(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self.reclaim_expired_leases().await.unwrap_or_else(|e| {
warn!("Failed to reclaim expired leases: {e}");
});
let running = self
.backend
.list_executions(Some(jamjet_core::workflow::WorkflowStatus::Running), 100, 0)
.await?;
for execution in running {
self.schedule_runnable_nodes(
&execution.execution_id,
&execution.workflow_id,
&execution.workflow_version,
)
.await
.unwrap_or_else(|e| {
warn!(
execution_id = %execution.execution_id,
"Failed to schedule runnable nodes: {e}"
);
});
}
Ok(())
}
async fn reclaim_expired_leases(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let reclaimed = self.backend.reclaim_expired_leases().await?;
for item in &reclaimed.retryable {
let seq = self.backend.latest_sequence(&item.execution_id).await? + 1;
let failed_event = jamjet_state::Event::new(
item.execution_id.clone(),
seq,
jamjet_state::event::EventKind::NodeFailed {
node_id: item.node_id.clone(),
error: "lease expired: worker presumed dead".into(),
attempt: item.attempt.saturating_sub(1),
retryable: true,
},
);
self.backend.append_event(failed_event).await?;
let seq = self.backend.latest_sequence(&item.execution_id).await? + 1;
let retry_event = jamjet_state::Event::new(
item.execution_id.clone(),
seq,
jamjet_state::event::EventKind::RetryScheduled {
node_id: item.node_id.clone(),
attempt: item.attempt,
delay_ms: (1u64 << item.attempt.min(6)) * 1000,
},
);
self.backend.append_event(retry_event).await?;
warn!(
execution_id = %item.execution_id,
node_id = %item.node_id,
attempt = item.attempt,
"Lease expired — requeueing for retry"
);
}
for item in &reclaimed.exhausted {
let seq = self.backend.latest_sequence(&item.execution_id).await? + 1;
let failed_event = jamjet_state::Event::new(
item.execution_id.clone(),
seq,
jamjet_state::event::EventKind::NodeFailed {
node_id: item.node_id.clone(),
error: format!("exhausted {} attempts: lease expired", item.max_attempts),
attempt: item.attempt,
retryable: false,
},
);
self.backend.append_event(failed_event).await?;
warn!(
execution_id = %item.execution_id,
node_id = %item.node_id,
attempts = item.attempt,
"Node exhausted retries — moved to dead-letter queue"
);
}
Ok(())
}
#[instrument(skip(self), fields(execution_id = %execution_id))]
async fn schedule_runnable_nodes(
&self,
execution_id: &ExecutionId,
workflow_id: &str,
workflow_version: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let cache_key = (workflow_id.to_string(), workflow_version.to_string());
let cached = self.ir_cache.lock().unwrap().get(&cache_key).cloned();
let ir: Arc<WorkflowIr> = match cached {
Some(ir) => ir,
None => {
let def = self
.backend
.get_workflow(workflow_id, workflow_version)
.await?;
let Some(def) = def else {
warn!(%workflow_id, %workflow_version, "Workflow definition not found; cannot schedule");
return Ok(());
};
let ir = Arc::new(serde_json::from_value::<WorkflowIr>(def.ir)?);
self.ir_cache
.lock()
.unwrap()
.insert(cache_key, Arc::clone(&ir));
ir
}
};
let mut progress = self
.progress
.lock()
.unwrap()
.get(execution_id)
.cloned()
.unwrap_or_default();
let new_events = self
.backend
.get_events_since(execution_id, progress.last_sequence)
.await?;
for event in &new_events {
progress.apply(event);
}
self.progress
.lock()
.unwrap()
.insert(execution_id.clone(), progress.clone());
let completed = &progress.completed;
let scheduled = &progress.scheduled;
let terminal_failed = &progress.terminal_failed;
debug!(
execution_id = %execution_id,
completed_nodes = completed.len(),
scheduled_nodes = scheduled.len(),
terminal_failed_nodes = terminal_failed.len(),
"Checking for runnable nodes"
);
if scheduled.len() >= self.config.max_concurrent_nodes_per_execution {
debug!(
execution_id = %execution_id,
active = scheduled.len(),
limit = self.config.max_concurrent_nodes_per_execution,
"Concurrency limit reached — skipping dispatch"
);
return Ok(());
}
let mut enqueued = 0usize;
for (node_id, node) in &ir.nodes {
if enqueued >= self.config.max_dispatch_per_tick {
debug!(
execution_id = %execution_id,
dispatched = enqueued,
"Per-tick dispatch limit reached — deferring remaining nodes"
);
break;
}
if scheduled.len() + enqueued >= self.config.max_concurrent_nodes_per_execution {
break;
}
if terminal_failed.contains(node_id.as_str()) {
continue; }
if is_runnable(node_id, &ir, completed, scheduled) {
let queue_type = serde_json::to_value(node.kind.queue_type())
.ok()
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "general".to_string());
let seq = self.backend.latest_sequence(execution_id).await? + 1;
let sched_event = jamjet_state::Event::new(
execution_id.clone(),
seq,
EventKind::NodeScheduled {
node_id: node_id.clone(),
queue_type: queue_type.clone(),
},
);
self.backend.append_event(sched_event).await?;
let max_attempts: u32 = match node.retry_policy.as_deref() {
Some("no_retry") => 1,
Some("io_default") => 5,
Some("llm_default") => 3,
_ => 3,
};
let item = WorkItem {
id: Uuid::new_v4(),
execution_id: execution_id.clone(),
node_id: node_id.clone(),
queue_type,
payload: serde_json::json!({
"workflow_id": workflow_id,
"workflow_version": workflow_version,
"node_id": node_id,
}),
attempt: 0,
max_attempts,
created_at: chrono::Utc::now(),
lease_expires_at: None,
worker_id: None,
tenant_id: jamjet_state::DEFAULT_TENANT.to_string(),
};
self.backend.enqueue_work_item(item).await?;
enqueued += 1;
info!(
execution_id = %execution_id,
node_id = %node_id,
"Enqueued node for execution"
);
}
}
if enqueued > 0 {
debug!(execution_id = %execution_id, enqueued, "Dispatch complete");
}
if enqueued == 0 && scheduled.is_empty() {
let (status, event_kind) = if terminal_failed.is_empty() {
(
jamjet_core::workflow::WorkflowStatus::Completed,
EventKind::WorkflowCompleted {
final_state: serde_json::Value::Object(progress.final_state.clone()),
},
)
} else {
let mut failed: Vec<String> =
terminal_failed.iter().map(|n| n.to_string()).collect();
failed.sort();
(
jamjet_core::workflow::WorkflowStatus::Failed,
EventKind::WorkflowFailed {
error: format!("node(s) failed terminally: {}", failed.join(", ")),
},
)
};
let seq = self.backend.latest_sequence(execution_id).await? + 1;
self.backend
.append_event(jamjet_state::Event::new(
execution_id.clone(),
seq,
event_kind,
))
.await?;
info!(execution_id = %execution_id, ?status, "Execution reached terminal state");
self.backend
.update_execution_status(execution_id, status)
.await?;
self.progress.lock().unwrap().remove(execution_id);
}
Ok(())
}
}
/// Scheduling-relevant view of one execution, built by folding its events
/// with `ExecProgress::apply`.
#[derive(Default, Clone)]
struct ExecProgress {
/// Nodes that completed, were skipped, or were cancelled.
completed: HashSet<NodeId>,
/// Nodes currently scheduled, started, or awaiting a scheduled retry.
scheduled: HashSet<NodeId>,
/// Nodes that failed with `retryable: false`.
terminal_failed: HashSet<NodeId>,
/// Merge of every object-valued `state_patch` from completed nodes; later
/// patches overwrite earlier keys.
final_state: serde_json::Map<String, serde_json::Value>,
/// Highest event sequence applied so far; used as the cursor for
/// fetching newer events.
last_sequence: jamjet_state::EventSequence,
}
impl ExecProgress {
fn apply(&mut self, event: &jamjet_state::Event) {
match &event.kind {
EventKind::NodeCompleted {
node_id,
state_patch,
..
} => {
self.completed.insert(node_id.clone());
self.scheduled.remove(node_id);
if let serde_json::Value::Object(patch) = state_patch {
for (k, v) in patch {
self.final_state.insert(k.clone(), v.clone());
}
}
}
EventKind::NodeSkipped { node_id, .. } => {
self.completed.insert(node_id.clone());
self.scheduled.remove(node_id);
}
EventKind::NodeScheduled { node_id, .. } | EventKind::NodeStarted { node_id, .. } => {
self.scheduled.insert(node_id.clone());
}
EventKind::NodeCancelled { node_id } => {
self.completed.insert(node_id.clone());
self.scheduled.remove(node_id);
}
EventKind::NodeFailed {
node_id,
retryable: false,
..
} => {
self.terminal_failed.insert(node_id.clone());
self.scheduled.remove(node_id);
}
EventKind::NodeFailed {
node_id,
retryable: true,
..
} => {
self.scheduled.remove(node_id);
}
EventKind::RetryScheduled { node_id, .. } => {
self.scheduled.insert(node_id.clone());
}
_ => {}
}
self.last_sequence = self.last_sequence.max(event.sequence);
}
}
/// Decides whether `node_id` may be dispatched now.
///
/// A node is runnable when it is neither completed nor already scheduled and
/// every edge pointing at it originates from a completed node. A node with no
/// incoming edges therefore qualifies as soon as it is unscheduled and
/// incomplete.
fn is_runnable(
    node_id: &str,
    ir: &WorkflowIr,
    completed: &HashSet<NodeId>,
    scheduled: &HashSet<NodeId>,
) -> bool {
    let already_handled = completed.contains(node_id) || scheduled.contains(node_id);
    if already_handled {
        return false;
    }

    let has_pending_predecessor = ir
        .edges
        .iter()
        .any(|edge| edge.to == node_id && !completed.contains(&edge.from));
    !has_pending_predecessor
}
// Unit tests driving `schedule_runnable_nodes` directly against the
// in-memory state backend.
#[cfg(test)]
mod tests {
use super::*;
use jamjet_core::workflow::{WorkflowExecution, WorkflowStatus};
use jamjet_state::backend::WorkflowDefinition;
use jamjet_state::{Event, InMemoryBackend, DEFAULT_TENANT};
// Serialized IR for a two-node linear workflow: a -> b -> "end".
// Only "a" and "b" are declared as nodes; "end" appears solely as an edge target.
fn linear_ir() -> serde_json::Value {
let node = |id: &str| {
serde_json::json!({
"id": id,
"kind": { "type": "condition", "branches": [] },
"retry_policy": null,
"node_timeout_secs": null,
"description": null,
"labels": {}
})
};
serde_json::json!({
"workflow_id": "wf",
"version": "0.1.0",
"name": null,
"description": null,
"state_schema": "",
"start_node": "a",
"nodes": { "a": node("a"), "b": node("b") },
"edges": [
{ "from": "a", "to": "b", "condition": null },
{ "from": "b", "to": "end", "condition": null }
],
"retry_policies": {},
"timeouts": {},
"models": {},
"tools": {},
"mcp_servers": {},
"remote_agents": {},
"labels": {}
})
}
// Stores `ir` as workflow "wf" version "0.1.0" in a fresh in-memory backend,
// creates one execution in `Running` status, and returns a default-configured
// scheduler together with the backend and the execution id.
async fn setup(ir: serde_json::Value) -> (Scheduler, Arc<dyn StateBackend>, ExecutionId) {
let backend: Arc<dyn StateBackend> = Arc::new(InMemoryBackend::new());
backend
.store_workflow(WorkflowDefinition {
workflow_id: "wf".into(),
version: "0.1.0".into(),
ir,
created_at: chrono::Utc::now(),
tenant_id: DEFAULT_TENANT.into(),
})
.await
.unwrap();
let exec_id = ExecutionId::new();
let now = chrono::Utc::now();
backend
.create_execution(WorkflowExecution {
execution_id: exec_id.clone(),
workflow_id: "wf".into(),
workflow_version: "0.1.0".into(),
status: WorkflowStatus::Running,
initial_input: serde_json::json!({}),
current_state: serde_json::json!({}),
started_at: now,
updated_at: now,
completed_at: None,
session_type: None,
})
.await
.unwrap();
(Scheduler::new(backend.clone()), backend, exec_id)
}
// Runs one scheduling step for the test execution, panicking on error.
// This calls `schedule_runnable_nodes` directly rather than `Scheduler::tick`.
async fn tick(s: &Scheduler, e: &ExecutionId) {
s.schedule_runnable_nodes(e, "wf", "0.1.0").await.unwrap();
}
// Appends `kind` to the execution's event log at the next sequence number.
async fn append(b: &Arc<dyn StateBackend>, e: &ExecutionId, kind: EventKind) {
let seq = b.latest_sequence(e).await.unwrap() + 1;
b.append_event(Event::new(e.clone(), seq, kind))
.await
.unwrap();
}
// Reads the execution's current status from the backend.
async fn status(b: &Arc<dyn StateBackend>, e: &ExecutionId) -> WorkflowStatus {
b.get_execution(e).await.unwrap().unwrap().status
}
// Collects the node ids of all `NodeScheduled` events, in log order.
fn scheduled_nodes(events: &[Event]) -> Vec<String> {
events
.iter()
.filter_map(|ev| match &ev.kind {
EventKind::NodeScheduled { node_id, .. } => Some(node_id.to_string()),
_ => None,
})
.collect()
}
// Builds a `NodeCompleted` event for `node_id` carrying `patch` as its state
// patch; every optional telemetry field is left as `None`.
fn node_completed(node_id: &str, patch: serde_json::Value) -> EventKind {
EventKind::NodeCompleted {
node_id: node_id.into(),
output: serde_json::json!({}),
state_patch: patch,
duration_ms: 1,
gen_ai_system: None,
gen_ai_model: None,
input_tokens: None,
output_tokens: None,
finish_reason: None,
cost_usd: None,
provenance: None,
}
}
// Happy path: "a" is scheduled first, "b" only after "a" completes, and the
// workflow completes with both state patches merged once "b" completes.
#[tokio::test]
async fn schedules_in_dependency_order_and_completes() {
let (s, b, e) = setup(linear_ir()).await;
// First step: only the root node may be scheduled.
tick(&s, &e).await;
let evs = b.get_events(&e).await.unwrap();
assert!(
scheduled_nodes(&evs).contains(&"a".to_string()),
"a should be scheduled first"
);
assert!(
!scheduled_nodes(&evs).contains(&"b".to_string()),
"b must wait for its predecessor a"
);
assert_eq!(status(&b, &e).await, WorkflowStatus::Running);
// Completing "a" unblocks "b" on the next step.
append(&b, &e, node_completed("a", serde_json::json!({ "x": 1 }))).await;
tick(&s, &e).await;
let evs = b.get_events(&e).await.unwrap();
assert!(
scheduled_nodes(&evs).contains(&"b".to_string()),
"b should be scheduled once a completes"
);
assert_eq!(status(&b, &e).await, WorkflowStatus::Running);
assert!(
s.progress.lock().unwrap().contains_key(&e),
"progress should be cached while the execution is running"
);
// Completing "b" leaves nothing to run: the execution becomes terminal.
append(&b, &e, node_completed("b", serde_json::json!({ "y": 2 }))).await;
tick(&s, &e).await;
assert_eq!(status(&b, &e).await, WorkflowStatus::Completed);
assert!(
!s.progress.lock().unwrap().contains_key(&e),
"progress should be dropped once the execution is terminal"
);
let evs = b.get_events(&e).await.unwrap();
let final_state = evs
.iter()
.find_map(|ev| match &ev.kind {
EventKind::WorkflowCompleted { final_state } => Some(final_state.clone()),
_ => None,
})
.expect("WorkflowCompleted should be emitted");
assert_eq!(final_state, serde_json::json!({ "x": 1, "y": 2 }));
}
// A non-retryable failure of "a" leaves "b" permanently blocked, so the next
// step must mark the whole workflow as failed.
#[tokio::test]
async fn terminal_node_failure_fails_the_workflow() {
let (s, b, e) = setup(linear_ir()).await;
tick(&s, &e).await; append(
&b,
&e,
EventKind::NodeFailed {
node_id: "a".into(),
error: "boom".into(),
attempt: 1,
retryable: false,
},
)
.await;
tick(&s, &e).await;
assert_eq!(status(&b, &e).await, WorkflowStatus::Failed);
}
}