mod error;
mod run;
pub use error::{CaughtError, ErrorKind, ExecutionError, ResourceValidationError};
pub use run::DEFAULT_SWITCH_CASE;
use crate::graph::Graph;
use crate::hooks::HooksAPI;
use crate::hooks::events::GraphEvent;
use crate::hooks::schedule::{OnGraphComplete, OnGraphFailure, OnGraphStart, OnSystemStart};
use crate::middleware::{self, MiddlewareAPI};
use crate::node::{Node, NodeId};
use hashbrown::HashSet;
use polaris_system::param::{AccessMode, SystemContext};
use polaris_system::plugin::{Schedule, ScheduleId};
use std::any::TypeId;
use std::time::Duration;
/// Summary returned by a successful [`GraphExecutor::execute`] run.
///
/// Derives `Clone`, `PartialEq` and `Eq` (all fields support them) so callers
/// and tests can copy and compare results directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ExecutionResult {
    /// Node count reported by the execution loop for this run.
    pub nodes_executed: usize,
    /// Wall-clock time measured around graph execution (see `execute_graph_body`).
    pub duration: Duration,
}
/// Runs graphs and holds the limits applied while doing so.
///
/// Obtain one via [`GraphExecutor::new`] or
/// [`GraphExecutor::without_iteration_limit`], then tune it with the
/// `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct GraphExecutor {
    /// Iteration cap; `None` disables it.
    /// NOTE(review): presumably applied by the run loop to loops that do not
    /// set their own limit — confirm in the `run` submodule.
    pub(crate) default_max_iterations: Option<usize>,
    /// Upper bound on execution nesting depth.
    /// NOTE(review): enforcement lives outside this file (`run`) — confirm.
    pub(crate) max_recursion_depth: usize,
}
impl Default for GraphExecutor {
fn default() -> Self {
Self::new()
}
}
impl GraphExecutor {
    /// Recursion-depth cap used by [`GraphExecutor::new`] and
    /// [`GraphExecutor::without_iteration_limit`].
    const DEFAULT_MAX_RECURSION_DEPTH: usize = 64;

    /// Iteration cap used by [`GraphExecutor::new`]; named so it sits next to
    /// `DEFAULT_MAX_RECURSION_DEPTH` instead of appearing as a bare literal.
    const DEFAULT_MAX_ITERATIONS: usize = 1000;

    /// Creates an executor with an iteration cap of 1000 and a
    /// recursion-depth cap of 64.
    #[must_use]
    pub fn new() -> Self {
        Self {
            default_max_iterations: Some(Self::DEFAULT_MAX_ITERATIONS),
            max_recursion_depth: Self::DEFAULT_MAX_RECURSION_DEPTH,
        }
    }

    /// Creates an executor with no iteration cap (`default_max_iterations`
    /// is `None`) and the default recursion-depth cap of 64.
    #[must_use]
    pub fn without_iteration_limit() -> Self {
        Self {
            default_max_iterations: None,
            max_recursion_depth: Self::DEFAULT_MAX_RECURSION_DEPTH,
        }
    }

    /// Sets the iteration cap to `max`. This also re-enables a cap on an
    /// executor built with [`GraphExecutor::without_iteration_limit`].
    #[must_use]
    pub fn with_default_max_iterations(mut self, max: usize) -> Self {
        self.default_max_iterations = Some(max);
        self
    }

    /// Sets the recursion-depth cap to `max`.
    #[must_use]
    pub fn with_max_recursion_depth(mut self, max: usize) -> Self {
        self.max_recursion_depth = max;
        self
    }

    /// Checks, ahead of execution, that every resource accessed by a
    /// [`Node::System`] node in `graph` is available. Other node kinds are
    /// not inspected.
    ///
    /// A resource counts as available when either:
    /// - a hook registered on [`OnGraphStart`] or [`OnSystemStart`] reports
    ///   providing it (via `provided_resources_for`), or
    /// - `ctx` already holds it. Reads are checked with
    ///   `contains_resource_by_type_id`; writes with
    ///   `contains_local_resource_by_type_id`.
    ///
    /// # Errors
    ///
    /// Returns every problem found, not just the first: one
    /// [`ResourceValidationError::MissingResource`] per (system node,
    /// unavailable resource) pair.
    pub fn validate_resources(
        &self,
        graph: &Graph,
        ctx: &SystemContext<'_>,
        hooks: Option<&HooksAPI>,
    ) -> Result<(), Vec<ResourceValidationError>> {
        let mut errors = Vec::new();
        // Resources that hooks inject at graph start or before each system run
        // exist by the time systems execute, so they are exempt from the check.
        let hook_provided: HashSet<TypeId> = hooks
            .map(|h| {
                let mut resources = HashSet::new();
                resources.extend(h.provided_resources_for(OnGraphStart::schedule_id()));
                resources.extend(h.provided_resources_for(OnSystemStart::schedule_id()));
                resources
            })
            .unwrap_or_default();
        for node in graph.nodes() {
            if let Node::System(sys) = node {
                let access = sys.system.access();
                self.validate_system_access(
                    &sys.id,
                    sys.system.name(),
                    &access,
                    ctx,
                    &hook_provided,
                    &mut errors,
                );
            }
        }
        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }

    /// Appends a `MissingResource` error to `errors` for each resource in
    /// `access` that is neither hook-provided nor present in `ctx`.
    fn validate_system_access(
        &self,
        node_id: &NodeId,
        system_name: &'static str,
        access: &polaris_system::param::SystemAccess,
        ctx: &SystemContext<'_>,
        hook_provided: &HashSet<TypeId>,
        errors: &mut Vec<ResourceValidationError>,
    ) {
        for res_access in &access.resources {
            if hook_provided.contains(&res_access.type_id) {
                continue;
            }
            // Writes must target a resource local to `ctx`. Reads accept
            // whatever `contains_resource_by_type_id` finds.
            // NOTE(review): presumably that includes inherited/global
            // resources — confirm against `SystemContext`.
            let exists = match res_access.mode {
                AccessMode::Read => ctx.contains_resource_by_type_id(res_access.type_id),
                AccessMode::Write => ctx.contains_local_resource_by_type_id(res_access.type_id),
            };
            if !exists {
                errors.push(ResourceValidationError::MissingResource {
                    node: node_id.clone(),
                    system_name,
                    resource_type: res_access.type_name,
                    type_id: res_access.type_id,
                    access_mode: res_access.mode,
                });
            }
        }
    }

    /// Executes `graph` from its entry node, wrapped in the graph-execution
    /// middleware chain.
    ///
    /// If `middleware` is `None`, a `MiddlewareAPI::default()` instance is
    /// used instead. Graph lifecycle hooks (`OnGraphStart`, `OnGraphComplete`,
    /// `OnGraphFailure`) fire inside the middleware-wrapped body.
    ///
    /// # Errors
    ///
    /// - Returns [`ExecutionError::EmptyGraph`] if `graph` has no entry node.
    ///   This is checked before any middleware runs.
    /// - Otherwise propagates whatever error the middleware chain or node
    ///   execution produces.
    pub async fn execute(
        &self,
        graph: &Graph,
        ctx: &mut SystemContext<'_>,
        hooks: Option<&HooksAPI>,
        middleware: Option<&MiddlewareAPI>,
    ) -> Result<ExecutionResult, ExecutionError> {
        let default_mw = MiddlewareAPI::default();
        let mw = middleware.unwrap_or(&default_mw);
        let entry = graph.entry().ok_or(ExecutionError::EmptyGraph)?;
        let node_count = graph.node_count();
        let middleware_info = middleware::info::GraphInfo { node_count };
        mw.inner
            .graph_execution
            .execute(middleware_info, ctx, |ctx| {
                // The closure gets its own copy of the entry id so each
                // invocation builds an independent future.
                let entry = entry.clone();
                Box::pin(self.execute_graph_body(graph, ctx, entry, node_count, hooks, mw))
            })
            .await
    }

    /// Graph body run inside the middleware chain.
    ///
    /// Steps:
    /// 1. Fires `OnGraphStart` with the node count and an (id, name) list of
    ///    all nodes.
    /// 2. Runs the graph from `entry`.
    /// 3. Fires `OnGraphComplete` on success or `OnGraphFailure` on error.
    ///
    /// The reported duration starts before the `OnGraphStart` hook and ends
    /// when node execution returns. It excludes the completion/failure hooks
    /// and the surrounding middleware.
    async fn execute_graph_body(
        &self,
        graph: &Graph,
        ctx: &mut SystemContext<'_>,
        entry: NodeId,
        node_count: usize,
        hooks: Option<&HooksAPI>,
        middleware: &MiddlewareAPI,
    ) -> Result<ExecutionResult, ExecutionError> {
        let start = std::time::Instant::now();
        let node_map: Vec<_> = graph
            .nodes()
            .iter()
            .map(|node| (node.id(), node.name()))
            .collect();
        Self::invoke_hook::<OnGraphStart>(
            hooks,
            ctx,
            &GraphEvent::GraphStart {
                node_count,
                node_map,
            },
        );
        // NOTE(review): the `0` is presumably the starting recursion depth —
        // confirm against `execute_from` in the `run` submodule.
        let result = self
            .execute_from(graph, ctx, entry, 0, hooks, middleware)
            .await;
        let duration = start.elapsed();
        match result {
            Ok(nodes_executed) => {
                Self::invoke_hook::<OnGraphComplete>(
                    hooks,
                    ctx,
                    &GraphEvent::GraphComplete {
                        nodes_executed,
                        duration,
                    },
                );
                Ok(ExecutionResult {
                    nodes_executed,
                    duration,
                })
            }
            Err(err) => {
                Self::invoke_hook::<OnGraphFailure>(
                    hooks,
                    ctx,
                    &GraphEvent::GraphFailure { error: err.clone() },
                );
                Err(err)
            }
        }
    }

    /// Invokes the hooks registered for schedule `S` with `event`.
    /// Does nothing when `hooks` is `None`.
    pub(crate) fn invoke_hook<S: Schedule>(
        hooks: Option<&HooksAPI>,
        ctx: &mut SystemContext<'_>,
        event: &GraphEvent,
    ) {
        if let Some(api) = hooks {
            api.invoke(S::schedule_id(), ctx, event);
        }
    }

    /// Invokes the hooks for each schedule id in `schedules`, in slice order,
    /// with the same `event`. Does nothing when `hooks` is `None`.
    pub(crate) fn invoke_custom_schedules(
        hooks: Option<&HooksAPI>,
        ctx: &mut SystemContext<'_>,
        schedules: &[ScheduleId],
        event: &GraphEvent,
    ) {
        if let Some(api) = hooks {
            for schedule in schedules {
                api.invoke(*schedule, ctx, event);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn executor_creation() {
        let executor = GraphExecutor::new();
        assert_eq!(executor.default_max_iterations, Some(1000));
        assert_eq!(executor.max_recursion_depth, 64);
    }

    /// `Default` must stay in lockstep with `new()`.
    #[test]
    fn executor_default_matches_new() {
        let from_default = GraphExecutor::default();
        let from_new = GraphExecutor::new();
        assert_eq!(
            from_default.default_max_iterations,
            from_new.default_max_iterations
        );
        assert_eq!(from_default.max_recursion_depth, from_new.max_recursion_depth);
    }

    #[test]
    fn executor_without_limit() {
        let executor = GraphExecutor::without_iteration_limit();
        assert_eq!(executor.default_max_iterations, None);
        // Dropping the iteration cap must not touch the recursion cap.
        assert_eq!(executor.max_recursion_depth, 64);
    }

    #[test]
    fn executor_with_custom_limit() {
        let executor = GraphExecutor::new().with_default_max_iterations(500);
        assert_eq!(executor.default_max_iterations, Some(500));
    }

    #[test]
    fn executor_with_custom_recursion_depth() {
        let executor = GraphExecutor::new().with_max_recursion_depth(128);
        assert_eq!(executor.max_recursion_depth, 128);
    }

    /// Builder methods compose and re-enable a cap on an unlimited executor.
    #[test]
    fn executor_builders_compose() {
        let executor = GraphExecutor::without_iteration_limit()
            .with_default_max_iterations(10)
            .with_max_recursion_depth(8);
        assert_eq!(executor.default_max_iterations, Some(10));
        assert_eq!(executor.max_recursion_depth, 8);
    }

    #[test]
    fn execution_result_default_is_zeroed() {
        let result = ExecutionResult::default();
        assert_eq!(result.nodes_executed, 0);
        assert_eq!(result.duration, std::time::Duration::ZERO);
    }
}