Skip to main content

bevy_htn/htn/
task_compound.rs

1use crate::{error::HtnErr, HtnStateTrait};
2
3use super::*;
4use bevy::prelude::*;
5use std::marker::PhantomData;
6
/// One way of decomposing a [`CompoundTask`]: a list of preconditions that must all hold,
/// plus the names of the subtasks the method expands into.
#[derive(Clone, Debug, Reflect)]
pub struct Method<T: Reflect> {
    /// Optional identifying name for this method.
    pub name: Option<String>,
    /// Conditions evaluated against the state; the method is applicable only if all pass.
    pub preconditions: Vec<HtnCondition>,
    /// Subtasks referenced by name only (not by value), in the order they were added.
    /// NOTE(review): presumably resolved to concrete tasks elsewhere — confirm.
    pub subtasks: Vec<String>,
    // Marker for the state type `T`; no `T` value is stored.
    _phantom: PhantomData<T>,
}
14
/// A named HTN task that is decomposed by choosing one of several alternative [`Method`]s.
///
/// [`CompoundTask::find_method`] scans `methods` in declaration order and picks the first
/// one whose preconditions all hold.
#[derive(Clone, Debug, Reflect)]
pub struct CompoundTask<T: HtnStateTrait> {
    /// Name identifying this task.
    pub name: String,
    /// Alternative decompositions, in priority order (earlier methods are tried first).
    pub methods: Vec<Method<T>>,
    // Marker for the state type `T`; no `T` value is stored.
    _phantom: PhantomData<T>,
}
21
22impl<T: HtnStateTrait> CompoundTask<T> {
23    /// Finds the first method with passing preconditions, skipping the first `skip` methods.
24    pub fn find_method(
25        &self,
26        state: &T,
27        skip: usize,
28        atr: &AppTypeRegistry,
29    ) -> Option<(&Method<T>, usize)> {
30        self.methods
31            .iter()
32            .enumerate()
33            .skip(skip)
34            .find(|(_, method)| {
35                method
36                    .preconditions
37                    .iter()
38                    .all(|cond| cond.evaluate(state, atr))
39            })
40            .map(|(i, method)| (method, i))
41    }
42    pub fn verify_conditions(&self, state: &T, atr: &AppTypeRegistry) -> Result<(), HtnErr> {
43        for method in self.methods.iter() {
44            for cond in method.preconditions.iter() {
45                cond.verify_types(state, atr)?;
46            }
47        }
48        Ok(())
49    }
50}
51
/// Incrementally assembles a [`CompoundTask`] by appending [`Method`]s one at a time.
pub struct CompoundTaskBuilder<T: HtnStateTrait> {
    // Task name supplied to `new`, moved into the built task.
    name: String,
    // Methods in the order they were added via `method`.
    methods: Vec<Method<T>>,
    // Marker for the state type `T`; no `T` value is stored.
    _phantom: PhantomData<T>,
}
57
58impl<T: HtnStateTrait> CompoundTaskBuilder<T> {
59    pub fn new(name: impl Into<String>) -> Self {
60        CompoundTaskBuilder {
61            name: name.into(),
62            methods: Vec::new(),
63            _phantom: PhantomData,
64        }
65    }
66
67    pub fn method(mut self, method: Method<T>) -> Self {
68        self.methods.push(method);
69        self
70    }
71
72    pub fn build(self) -> CompoundTask<T> {
73        CompoundTask {
74            name: self.name,
75            methods: self.methods,
76            _phantom: PhantomData,
77        }
78    }
79}
80
/// Incrementally assembles a [`Method`] from an optional name, preconditions and subtask names.
pub struct MethodBuilder<T: Reflect> {
    // Accumulated preconditions, in insertion order.
    preconditions: Vec<HtnCondition>,
    // Subtask names only, not task values.
    subtasks: Vec<String>,
    // Optional method name; stays `None` unless `name` is called.
    name: Option<String>,
    // Marker for the state type `T`; no `T` value is stored.
    _phantom: PhantomData<T>,
}
88
89impl<T: Reflect> MethodBuilder<T> {
90    #[allow(clippy::new_without_default)]
91    pub fn new() -> Self {
92        MethodBuilder {
93            preconditions: Vec::new(),
94            subtasks: Vec::new(),
95            name: None,
96            _phantom: PhantomData,
97        }
98    }
99
100    pub fn name(mut self, name: String) -> Self {
101        self.name = Some(name);
102        self
103    }
104
105    pub fn precondition(mut self, cond: HtnCondition) -> Self {
106        self.preconditions.push(cond);
107        self
108    }
109
110    pub fn subtask(mut self, task_name: impl Into<String>) -> Self {
111        self.subtasks.push(task_name.into());
112        self
113    }
114
115    pub fn build(self) -> Method<T> {
116        Method {
117            preconditions: self.preconditions,
118            subtasks: self.subtasks,
119            name: self.name,
120            _phantom: PhantomData,
121        }
122    }
123}