mod error;
mod run;
pub use error::{CaughtError, ErrorKind, ExecutionError, ResourceValidationError};
pub use run::DEFAULT_SWITCH_CASE;
use crate::edge::Edge;
use crate::graph::Graph;
use crate::hooks::HooksAPI;
use crate::hooks::events::{GraphEvent, RunId, RunLabels};
use crate::hooks::schedule::{OnGraphComplete, OnGraphFailure, OnGraphStart, OnSystemStart};
use crate::middleware::{self, MiddlewareAPI};
use crate::node::{ContextMode, Node, NodeId};
use hashbrown::HashSet;
use polaris_system::param::{AccessMode, SystemContext};
use polaris_system::plugin::{Schedule, ScheduleId};
use std::any::{Any, TypeId};
use std::time::Duration;
use tracing::Instrument;
/// Identity of a single graph run.
///
/// Built once per call to `GraphExecutor::execute_with_labels` and handed to
/// the execution body, which copies its contents into the lifecycle events.
#[derive(Clone)]
pub(crate) struct RunContext {
/// Identifier generated for this run via `RunId::new()`.
pub(crate) run_id: RunId,
/// Labels supplied by the caller of `execute_with_labels`.
pub(crate) labels: RunLabels,
}
/// Outcome of a successful graph run, assembled at the end of execution.
#[derive(Default)]
pub struct ExecutionResult {
/// Identifier of the run that produced this result.
pub(crate) run_id: RunId,
/// Node count reported by the graph traversal for this run.
pub(crate) nodes_executed: usize,
/// Elapsed time measured from just before the `OnGraphStart` hook until the
/// traversal returned.
pub(crate) duration: Duration,
/// Type-erased value taken from the context's output storage (`take_last`)
/// once the traversal succeeded; `None` when nothing was available.
pub(crate) final_output: Option<Box<dyn Any + Send + Sync>>,
}
// Hand-written `Debug`: the boxed `dyn Any` payload is never printed, only a
// fixed marker telling whether an output is present.
impl std::fmt::Debug for ExecutionResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The marker is a string slice, so it is rendered with `str`'s own
        // `Debug` formatting (i.e. quoted), exactly like any other `&str` field.
        let output_marker: &str = match &self.final_output {
            Some(_) => "Some(<output>)",
            None => "None",
        };

        let mut builder = f.debug_struct("ExecutionResult");
        builder.field("run_id", &self.run_id);
        builder.field("nodes_executed", &self.nodes_executed);
        builder.field("duration", &self.duration);
        builder.field("final_output", &output_marker);
        builder.finish()
    }
}
impl ExecutionResult {
    /// Returns the identifier of the run that produced this result.
    #[must_use]
    pub fn run_id(&self) -> &RunId {
        &self.run_id
    }

    /// Returns the node count recorded for the run.
    #[must_use]
    pub fn nodes_executed(&self) -> usize {
        self.nodes_executed
    }

    /// Returns the elapsed time recorded for the run.
    #[must_use]
    pub fn duration(&self) -> Duration {
        self.duration
    }

    /// Returns the final output as a `&T`.
    ///
    /// Yields `None` when no output was captured, or when the captured value
    /// is not of type `T`.
    #[must_use]
    pub fn output<T: 'static>(&self) -> Option<&T> {
        let erased = self.final_output.as_deref()?;
        erased.downcast_ref::<T>()
    }

    /// Returns `true` when an output of any type was captured.
    #[must_use]
    pub fn has_output(&self) -> bool {
        self.final_output.is_some()
    }
}
/// Executes and validates graphs under configurable iteration, recursion and
/// duration limits.
#[derive(Debug, Clone)]
pub struct GraphExecutor {
/// Default iteration cap; `None` means no default cap is configured.
/// NOTE(review): consumed outside this section (presumably by loop
/// execution) — confirm.
pub(crate) default_max_iterations: Option<usize>,
/// Maximum nesting depth. Resource validation stops descending into nested
/// scopes once this depth is exceeded.
pub(crate) max_recursion_depth: usize,
/// Executor-wide time limit for a run; only used when the graph itself does
/// not define a maximum duration.
pub(crate) max_duration: Option<Duration>,
}
// The default executor is exactly the one built by `GraphExecutor::new`.
impl Default for GraphExecutor {
fn default() -> Self {
Self::new()
}
}
impl GraphExecutor {
/// Recursion depth limit applied by `new` and `without_iteration_limit`.
const DEFAULT_MAX_RECURSION_DEPTH: usize = 64;
/// Creates an executor with the default limits: an iteration cap of 1000,
/// the default recursion depth, and no time limit.
#[must_use]
pub fn new() -> Self {
    let max_recursion_depth = Self::DEFAULT_MAX_RECURSION_DEPTH;
    Self {
        max_duration: None,
        max_recursion_depth,
        default_max_iterations: Some(1000),
    }
}
/// Creates an executor like [`new`](Self::new) but with no default
/// iteration cap.
#[must_use]
pub fn without_iteration_limit() -> Self {
    // Every other setting keeps the value chosen by `new`.
    Self {
        default_max_iterations: None,
        ..Self::new()
    }
}
#[must_use]
pub fn with_default_max_iterations(mut self, max: usize) -> Self {
self.default_max_iterations = Some(max);
self
}
#[must_use]
pub fn with_max_recursion_depth(mut self, max: usize) -> Self {
self.max_recursion_depth = max;
self
}
#[must_use]
pub fn with_max_duration(mut self, duration: Duration) -> Self {
self.max_duration = Some(duration);
self
}
/// Checks that the resources and outputs required by `graph` are available.
///
/// Resource types that hooks registered for the `OnGraphStart` and
/// `OnSystemStart` schedules declare as provided are treated as available.
///
/// # Errors
///
/// Returns every [`ResourceValidationError`] found, collected in one list.
pub fn validate_resources(
    &self,
    graph: &Graph,
    ctx: &SystemContext<'_>,
    hooks: Option<&HooksAPI>,
) -> Result<(), Vec<ResourceValidationError>> {
    // Without hooks the set simply stays empty.
    let mut hook_provided: HashSet<TypeId> = HashSet::new();
    if let Some(api) = hooks {
        hook_provided.extend(api.provided_resources_for(OnGraphStart::schedule_id()));
        hook_provided.extend(api.provided_resources_for(OnSystemStart::schedule_id()));
    }

    let mut errors = Vec::new();
    self.validate_graph_resources(graph, ctx, &hook_provided, &mut errors, 0);

    if !errors.is_empty() {
        return Err(errors);
    }
    Ok(())
}
/// Validates `graph` and, recursively, the graphs of its scope nodes.
///
/// * System nodes have their declared resource access checked against `ctx`.
/// * Scope nodes are validated against a context that mirrors their context
///   policy: the same context (`Shared`), a child of it (`Inherit`), or a
///   fresh one that only keeps the globals (`Isolated`). Forwarded resources
///   are registered in the child as unit placeholders so that presence checks
///   succeed.
/// * All other node kinds carry no resource access that is checked here.
///
/// After the node pass, output reachability is validated for this graph.
/// Problems are appended to `errors`. Graphs nested deeper than
/// `max_recursion_depth` are skipped without reporting anything.
fn validate_graph_resources(
    &self,
    graph: &Graph,
    ctx: &SystemContext<'_>,
    hook_provided: &HashSet<TypeId>,
    errors: &mut Vec<ResourceValidationError>,
    depth: usize,
) {
    if depth > self.max_recursion_depth {
        return;
    }
    for node in graph.nodes() {
        match node {
            Node::System(sys) => {
                let access = sys.system.access();
                self.validate_system_access(
                    &sys.id,
                    sys.system.name(),
                    &access,
                    ctx,
                    hook_provided,
                    errors,
                );
            }
            Node::Scope(scope) => match scope.context_policy.mode {
                ContextMode::Shared => {
                    self.validate_graph_resources(
                        &scope.graph,
                        ctx,
                        hook_provided,
                        errors,
                        depth + 1,
                    );
                }
                ContextMode::Inherit => {
                    let mut child = ctx.child();
                    // Placeholder values: only the presence of the type matters.
                    for fwd in &scope.context_policy.forward_resources {
                        child.insert_boxed(fwd.type_id, Box::new(()));
                    }
                    self.validate_graph_resources(
                        &scope.graph,
                        &child,
                        hook_provided,
                        errors,
                        depth + 1,
                    );
                }
                ContextMode::Isolated => {
                    let mut child = match ctx.globals_arc() {
                        Some(globals) => SystemContext::with_globals(globals),
                        None => SystemContext::new(),
                    };
                    // Placeholder values: only the presence of the type matters.
                    for fwd in &scope.context_policy.forward_resources {
                        child.insert_boxed(fwd.type_id, Box::new(()));
                    }
                    self.validate_graph_resources(
                        &scope.graph,
                        &child,
                        hook_provided,
                        errors,
                        depth + 1,
                    );
                }
            },
            // Listed explicitly (no `_` arm) so that adding a `Node` variant
            // forces a decision about how it must be validated.
            Node::Decision(_) | Node::Switch(_) | Node::Loop(_) | Node::Parallel(_) => {}
        }
    }
    self.validate_output_reachability(graph, hook_provided, errors);
}
/// Checks that every output type a system declares in its access has been
/// made available before that system is reached.
///
/// Only the chain of sequentially linked nodes starting at the graph entry is
/// walked. An output type counts as available when an earlier system in the
/// chain produces it, when any branch of an earlier decision, switch, loop or
/// parallel node can produce it, or when it is listed in `hook_provided`.
/// Scope nodes contribute nothing. Each violation is appended to `errors` as
/// `ResourceValidationError::MissingOutput`.
fn validate_output_reachability(
    &self,
    graph: &Graph,
    hook_provided: &HashSet<TypeId>,
    errors: &mut Vec<ResourceValidationError>,
) {
    let mut available: HashSet<TypeId> = HashSet::new();
    for node_id in self.build_linear_chain(graph) {
        let Some(node) = graph.get_node(node_id) else {
            continue;
        };
        match node {
            Node::System(sys) => {
                let access = sys.system.access();
                for wanted in &access.outputs {
                    if available.contains(&wanted.type_id)
                        || hook_provided.contains(&wanted.type_id)
                    {
                        continue;
                    }
                    errors.push(ResourceValidationError::MissingOutput {
                        node: sys.id.clone(),
                        system_name: sys.system.name(),
                        output_type: wanted.type_name,
                        type_id: wanted.type_id,
                    });
                }
                // The system's own output only becomes available to later nodes.
                available.insert(sys.system.output_type_id());
            }
            Node::Decision(dec) => {
                for branch in [&dec.true_branch, &dec.false_branch].into_iter().flatten() {
                    available.extend(
                        graph
                            .collect_branch_output_types(branch)
                            .into_iter()
                            .map(|(type_id, _)| type_id),
                    );
                }
            }
            Node::Switch(sw) => {
                for (_, target) in &sw.cases {
                    available.extend(
                        graph
                            .collect_branch_output_types(target)
                            .into_iter()
                            .map(|(type_id, _)| type_id),
                    );
                }
                if let Some(default) = &sw.default {
                    available.extend(
                        graph
                            .collect_branch_output_types(default)
                            .into_iter()
                            .map(|(type_id, _)| type_id),
                    );
                }
            }
            Node::Loop(lp) => {
                if let Some(body) = &lp.body_entry {
                    available.extend(
                        graph
                            .collect_branch_output_types(body)
                            .into_iter()
                            .map(|(type_id, _)| type_id),
                    );
                }
            }
            Node::Parallel(par) => {
                for branch in &par.branches {
                    available.extend(
                        graph
                            .collect_branch_output_types(branch)
                            .into_iter()
                            .map(|(type_id, _)| type_id),
                    );
                }
            }
            Node::Scope(_) => {}
        }
    }
}
fn build_linear_chain(&self, graph: &Graph) -> Vec<NodeId> {
let mut chain = Vec::new();
let Some(entry) = graph.entry() else {
return chain;
};
let seq_edges: hashbrown::HashMap<NodeId, NodeId> = graph
.edges()
.iter()
.filter_map(|edge| {
if let Edge::Sequential(seq) = edge {
Some((seq.from.clone(), seq.to.clone()))
} else {
None
}
})
.collect();
let mut current = Some(entry);
let mut visited = HashSet::new();
while let Some(node_id) = current {
if !visited.insert(node_id.clone()) {
break;
}
chain.push(node_id.clone());
current = seq_edges.get(&node_id).cloned();
}
chain
}
/// Checks each resource a system declares in `access` and records a
/// `ResourceValidationError::MissingResource` for every one that is absent.
///
/// A resource counts as present when it is in `hook_provided`; otherwise read
/// access is checked with `contains_resource_by_type_id` and write access
/// with `contains_local_resource_by_type_id`.
fn validate_system_access(
    &self,
    node_id: &NodeId,
    system_name: &'static str,
    access: &polaris_system::param::SystemAccess,
    ctx: &SystemContext<'_>,
    hook_provided: &HashSet<TypeId>,
    errors: &mut Vec<ResourceValidationError>,
) {
    for wanted in &access.resources {
        let type_id = wanted.type_id;
        // Short-circuit: the context is only consulted when no hook provides the type.
        let satisfied = hook_provided.contains(&type_id)
            || match wanted.mode {
                AccessMode::Read => ctx.contains_resource_by_type_id(type_id),
                AccessMode::Write => ctx.contains_local_resource_by_type_id(type_id),
            };
        if satisfied {
            continue;
        }
        errors.push(ResourceValidationError::MissingResource {
            node: node_id.clone(),
            system_name,
            resource_type: wanted.type_name,
            type_id,
            access_mode: wanted.mode,
        });
    }
}
/// Executes `graph` without run labels.
///
/// Equivalent to calling `execute_with_labels` with `RunLabels::empty()`.
///
/// # Errors
///
/// Returns whatever `execute_with_labels` returns for the same arguments.
pub async fn execute(
    &self,
    graph: &Graph,
    ctx: &mut SystemContext<'_>,
    hooks: Option<&HooksAPI>,
    middleware: Option<&MiddlewareAPI>,
) -> Result<ExecutionResult, ExecutionError> {
    let no_labels = RunLabels::empty();
    self.execute_with_labels(graph, ctx, hooks, middleware, no_labels)
        .await
}
/// Executes `graph`, tagging the run with `labels`.
///
/// A fresh `RunId` is generated for the run and the whole execution is
/// instrumented with a `polaris.run` tracing span carrying that id. The
/// graph body is handed to the graph-execution middleware as a closure.
///
/// # Errors
///
/// Returns `ExecutionError::EmptyGraph` when the graph has no entry node;
/// otherwise returns whatever the middleware-wrapped graph body returns.
pub async fn execute_with_labels(
&self,
graph: &Graph,
ctx: &mut SystemContext<'_>,
hooks: Option<&HooksAPI>,
middleware: Option<&MiddlewareAPI>,
labels: RunLabels,
) -> Result<ExecutionResult, ExecutionError> {
// Fall back to a default middleware API when the caller supplies none.
let default_mw = MiddlewareAPI::default();
let mw = middleware.unwrap_or(&default_mw);
// A graph without an entry node cannot be run.
let entry = graph.entry().ok_or(ExecutionError::EmptyGraph)?;
let node_count = graph.node_count();
let run_ctx = RunContext {
run_id: RunId::new(),
labels,
};
let run_span =
tracing::info_span!("polaris.run", polaris.run.id = run_ctx.run_id.as_str(),);
let middleware_info = middleware::info::GraphInfo { node_count };
mw.inner
.graph_execution
.execute(middleware_info, ctx, |ctx| {
// Clone inside the closure so the captured `entry` / `run_ctx` are not
// moved out of it.
let entry = entry.clone();
let run_ctx = run_ctx.clone();
Box::pin(self.execute_graph_body(graph, ctx, entry, node_count, hooks, mw, run_ctx))
})
.instrument(run_span)
.await
}
/// Runs the graph from `entry` and reports the run through lifecycle hooks.
///
/// Sequence:
/// 1. invoke `OnGraphStart` with the node count and an id/name entry for every node;
/// 2. traverse the graph, bounded by the graph's own maximum duration or,
///    if it has none, by the executor's;
/// 3. on success take the last output from the context, invoke
///    `OnGraphComplete` and return the `ExecutionResult`;
///    on failure invoke `OnGraphFailure` and return the error.
///
/// # Errors
///
/// Returns `ExecutionError::GraphTimeout` when the effective time limit
/// elapses, or the error produced by the traversal itself.
async fn execute_graph_body(
    &self,
    graph: &Graph,
    ctx: &mut SystemContext<'_>,
    entry: NodeId,
    node_count: usize,
    hooks: Option<&HooksAPI>,
    middleware: &MiddlewareAPI,
    run_ctx: RunContext,
) -> Result<ExecutionResult, ExecutionError> {
    let start = std::time::Instant::now();

    let node_map: Vec<_> = graph
        .nodes()
        .iter()
        .map(|node| (node.id(), node.name()))
        .collect();
    Self::invoke_hook::<OnGraphStart>(
        hooks,
        ctx,
        &GraphEvent::GraphStart {
            run_id: run_ctx.run_id.clone(),
            labels: run_ctx.labels.clone(),
            node_count,
            node_map,
        },
    );

    // The graph-level limit wins over the executor-level one.
    let result = match graph.max_duration.or(self.max_duration) {
        None => {
            self.execute_from(graph, ctx, entry, 0, hooks, middleware, &run_ctx)
                .await
        }
        Some(max) => tokio::time::timeout(
            max,
            self.execute_from(graph, ctx, entry, 0, hooks, middleware, &run_ctx),
        )
        .await
        .unwrap_or_else(|_timeout| {
            Err(ExecutionError::GraphTimeout {
                elapsed: start.elapsed(),
                max,
            })
        }),
    };
    let duration = start.elapsed();

    let nodes_executed = match result {
        Ok(count) => count,
        Err(err) => {
            Self::invoke_hook::<OnGraphFailure>(
                hooks,
                ctx,
                &GraphEvent::GraphFailure {
                    run_id: run_ctx.run_id.clone(),
                    labels: run_ctx.labels.clone(),
                    error: err.clone(),
                },
            );
            return Err(err);
        }
    };

    // The output is taken before the completion hook runs.
    let final_output = ctx.outputs_mut().take_last();
    Self::invoke_hook::<OnGraphComplete>(
        hooks,
        ctx,
        &GraphEvent::GraphComplete {
            run_id: run_ctx.run_id.clone(),
            labels: run_ctx.labels.clone(),
            nodes_executed,
            duration,
        },
    );
    Ok(ExecutionResult {
        run_id: run_ctx.run_id,
        nodes_executed,
        duration,
        final_output,
    })
}
/// Invokes the hooks registered for schedule `S` with `event`.
///
/// Does nothing when `hooks` is `None`.
pub(crate) fn invoke_hook<S: Schedule>(
    hooks: Option<&HooksAPI>,
    ctx: &mut SystemContext<'_>,
    event: &GraphEvent,
) {
    let Some(api) = hooks else {
        return;
    };
    api.invoke(S::schedule_id(), ctx, event);
}
/// Invokes the hooks of every schedule in `schedules`, in slice order, with
/// the same `event`.
///
/// Does nothing when `hooks` is `None`.
pub(crate) fn invoke_custom_schedules(
    hooks: Option<&HooksAPI>,
    ctx: &mut SystemContext<'_>,
    schedules: &[ScheduleId],
    event: &GraphEvent,
) {
    let Some(api) = hooks else {
        return;
    };
    for &schedule in schedules {
        api.invoke(schedule, ctx, event);
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` applies the default limits.
    #[test]
    fn executor_creation() {
        let GraphExecutor {
            default_max_iterations,
            max_recursion_depth,
            max_duration,
        } = GraphExecutor::new();
        assert_eq!(default_max_iterations, Some(1000));
        assert_eq!(max_recursion_depth, 64);
        assert_eq!(max_duration, None);
    }

    /// `without_iteration_limit()` drops the iteration cap and sets no time limit.
    #[test]
    fn executor_without_limit() {
        let unlimited = GraphExecutor::without_iteration_limit();
        assert_eq!(unlimited.default_max_iterations, None);
        assert_eq!(unlimited.max_duration, None);
    }

    /// The iteration cap can be overridden through the builder.
    #[test]
    fn executor_with_custom_limit() {
        let custom = GraphExecutor::new().with_default_max_iterations(500);
        assert_eq!(custom.default_max_iterations, Some(500));
    }

    /// The recursion depth can be overridden through the builder.
    #[test]
    fn executor_with_custom_recursion_depth() {
        let custom = GraphExecutor::new().with_max_recursion_depth(128);
        assert_eq!(custom.max_recursion_depth, 128);
    }

    /// The executor-wide time limit can be set through the builder.
    #[test]
    fn executor_with_max_duration() {
        let limit = Duration::from_secs(30);
        let custom = GraphExecutor::new().with_max_duration(limit);
        assert_eq!(custom.max_duration, Some(limit));
    }

    /// A graph starts without a time limit and reports the one it is given.
    #[test]
    fn graph_max_duration() {
        let limit = Duration::from_secs(10);
        let mut graph = Graph::new();
        assert_eq!(graph.max_duration(), None);
        graph.with_max_duration(limit);
        assert_eq!(graph.max_duration(), Some(limit));
    }

    /// `with_max_duration` on a graph can be chained with further builder calls.
    #[test]
    fn graph_max_duration_chains() {
        async fn step() {}
        let limit = Duration::from_secs(10);
        let mut graph = Graph::new();
        graph.with_max_duration(limit).add_system(step);
        assert_eq!(graph.max_duration(), Some(limit));
        assert_eq!(graph.node_count(), 1);
    }
}