use crate::{error::HtnErr, HtnStateTrait};
use super::*;
use bevy::prelude::*;
/// Schema metadata attached to an [`HTN`].
///
/// Currently this only carries a version string; `Default` yields an empty one.
#[derive(Debug, Reflect, Clone, Default)]
pub struct HtnSchema {
/// Version string of the schema, exposed through `HTN::version`.
pub version: String,
}
/// A task network: a list of [`Task`]s together with its [`HtnSchema`].
///
/// The type parameter `T` is the planner state type the tasks are verified against.
#[derive(Debug, Reflect, Clone)]
pub struct HTN<T: HtnStateTrait> {
/// Tasks belonging to this network; `root_task` returns the first entry.
pub tasks: Vec<Task<T>>,
/// Schema metadata (version) for this network.
pub schema: HtnSchema,
}
impl<T: HtnStateTrait> HTN<T> {
    /// Starts an [`HTNBuilder`] with no tasks and a default [`HtnSchema`].
    pub fn builder() -> HTNBuilder<T> {
        let tasks = Vec::new();
        let schema = HtnSchema::default();
        HTNBuilder { tasks, schema }
    }

    /// Returns the version string stored in this network's schema.
    pub fn version(&self) -> &str {
        self.schema.version.as_str()
    }

    /// Looks up the first task whose name equals `name`.
    ///
    /// Returns `None` when no task carries that name.
    pub fn get_task_by_name(&self, name: &str) -> Option<&Task<T>> {
        self.tasks.iter().find(|candidate| candidate.name() == name)
    }

    /// Returns the root task, which is the first task in `tasks`.
    ///
    /// # Panics
    ///
    /// Panics with "No root task found" when the network contains no tasks.
    pub fn root_task(&self) -> &Task<T> {
        let root = self.tasks.first();
        root.expect("No root task found")
    }

    /// Verifies conditions, then effects, then operators of every task.
    ///
    /// # Errors
    ///
    /// Returns the first [`HtnErr`] reported by any of the three passes.
    pub fn verify_all(&self, state: &T, atr: &AppTypeRegistry) -> Result<(), HtnErr> {
        // Conditions and effects first (in that order), operators last.
        self.verify_without_operators(state, atr)?;
        self.verify_operators(state, atr)
    }

    /// Verifies conditions and then effects of every task, skipping operators.
    ///
    /// # Errors
    ///
    /// Returns the first [`HtnErr`] reported by either pass.
    pub fn verify_without_operators(&self, state: &T, atr: &AppTypeRegistry) -> Result<(), HtnErr> {
        self.verify_conditions(state, atr)?;
        self.verify_effects(state, atr)
    }

    /// Verifies the operator of every primitive task; compound tasks are skipped.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first operator verification failure.
    pub fn verify_operators(&self, state: &T, atr: &AppTypeRegistry) -> Result<(), HtnErr> {
        for entry in &self.tasks {
            if let Task::Primitive(primitive) = entry {
                primitive.verify_operator(state, atr)?;
            }
        }
        Ok(())
    }

    /// Verifies the effects of every task, logging each task name at debug level.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first effect verification failure.
    pub fn verify_effects(&self, state: &T, atr: &AppTypeRegistry) -> Result<(), HtnErr> {
        self.tasks.iter().try_for_each(|task| {
            debug!("Verifying effects for task: {}", task.name());
            task.verify_effects(state, atr)
        })
    }

    /// Verifies the conditions of every task, logging each task name at debug level.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first condition verification failure.
    pub fn verify_conditions(&self, state: &T, atr: &AppTypeRegistry) -> Result<(), HtnErr> {
        self.tasks.iter().try_for_each(|task| {
            debug!("Verifying conditions for task: {}", task.name());
            task.verify_conditions(state, atr)
        })
    }
}
/// Builder that accumulates tasks and schema metadata before producing an [`HTN`].
///
/// Obtain one through `HTN::builder` and finish with `build`.
pub struct HTNBuilder<T: HtnStateTrait> {
// Tasks collected so far, in the order they were added.
tasks: Vec<Task<T>>,
// Schema that will be handed to the built network.
schema: HtnSchema,
}
impl<T: HtnStateTrait> HTNBuilder<T> {
    /// Appends a primitive task to the end of the task list.
    pub fn primitive_task(mut self, task: PrimitiveTask<T>) -> Self {
        let wrapped = Task::Primitive(task);
        self.tasks.push(wrapped);
        self
    }

    /// Appends a compound task to the end of the task list.
    pub fn compound_task(mut self, task: CompoundTask<T>) -> Self {
        let wrapped = Task::Compound(task);
        self.tasks.push(wrapped);
        self
    }

    /// Replaces the builder's schema with `meta`.
    pub fn schema(self, meta: HtnSchema) -> Self {
        Self {
            schema: meta,
            ..self
        }
    }

    /// Verifies the operator of every primitive task added so far.
    ///
    /// Compound tasks are skipped. On success the builder is handed back so
    /// the call can sit in the middle of a builder chain.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first operator verification failure.
    pub fn verify_operators(self, state: &T, atr: &AppTypeRegistry) -> Result<Self, HtnErr> {
        for entry in &self.tasks {
            if let Task::Primitive(primitive) = entry {
                primitive.verify_operator(state, atr)?;
            }
        }
        Ok(self)
    }

    /// Consumes the builder and produces the [`HTN`] with the collected tasks and schema.
    pub fn build(self) -> HTN<T> {
        let Self { tasks, schema } = self;
        HTN { tasks, schema }
    }
}
/// A single entry of an [`HTN`]: either a primitive task or a compound task.
#[derive(Clone, Debug, Reflect)]
pub enum Task<T: HtnStateTrait> {
/// A primitive task; these are the only tasks whose operators and effects get verified.
Primitive(PrimitiveTask<T>),
/// A compound task; only its conditions are verified.
Compound(CompoundTask<T>),
}
impl<T: HtnStateTrait> Task<T> {
    /// Returns the name of the wrapped task, whichever variant this is.
    pub fn name(&self) -> &str {
        match self {
            Self::Compound(inner) => &inner.name,
            Self::Primitive(inner) => &inner.name,
        }
    }

    /// Verifies the task's effects against `state`.
    ///
    /// Only primitive tasks are checked; a compound task always yields `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the primitive task's effect verification reports.
    pub fn verify_effects(&self, state: &T, atr: &AppTypeRegistry) -> Result<(), HtnErr> {
        match self {
            Self::Compound(_) => Ok(()),
            Self::Primitive(inner) => inner.verify_effects(state, atr),
        }
    }

    /// Verifies the task's conditions against `state`, delegating to the wrapped task.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped task's condition verification reports.
    pub fn verify_conditions(&self, state: &T, atr: &AppTypeRegistry) -> Result<(), HtnErr> {
        match self {
            Self::Compound(inner) => inner.verify_conditions(state, atr),
            Self::Primitive(inner) => inner.verify_conditions(state, atr),
        }
    }
}