use crate::{
self as behaviortree, Control,
behavior::{Behavior, BehaviorData, BehaviorError, BehaviorResult, BehaviorState},
input_port,
port::PortList,
port_list,
tree::BehaviorTreeElementList,
};
use alloc::boxed::Box;
use alloc::collections::btree_set::BTreeSet;
use tinyscript::SharedRuntime;
/// Control node that, on every tick, ticks (one after another) each child that
/// has not finished yet, and decides its own result only once every child has
/// either finished (`Success`/`Failure`) or reported `Skipped`.
///
/// The success and failure thresholds are read from the `success_count` and
/// `failure_count` input ports.
#[derive(Control, Debug, Default)]
#[behavior(groot2)]
pub struct Parallel {
    /// Number of children that returned `Success` in the current run.
    success_count: i32,
    /// Number of children that returned `Failure` in the current run.
    failure_count: i32,
    /// Indices of children that returned `Success` or `Failure` in the current
    /// run; these children are not ticked again until the run is reset.
    completed_list: BTreeSet<usize>,
}
/// Name of the `i32` input port holding the success threshold.
const SUCCESS_COUNT: &str = "success_count";
/// Name of the `i32` input port holding the failure threshold.
const FAILURE_COUNT: &str = "failure_count";
impl Parallel {
    /// Forgets which children have completed and zeroes both outcome tallies,
    /// so the next run starts from scratch.
    fn reset_progress(&mut self) {
        self.completed_list.clear();
        self.success_count = 0;
        self.failure_count = 0;
    }

    /// Converts a child count to `i32` for comparison with the `i32` thresholds
    /// and tallies. Saturates at `i32::MAX` instead of wrapping, as an `as` cast
    /// would for counts above `i32::MAX`.
    fn count_as_i32(count: usize) -> i32 {
        i32::try_from(count).unwrap_or(i32::MAX)
    }
}

#[async_trait::async_trait]
impl Behavior for Parallel {
    /// Discards the progress of the current run.
    fn on_halt(&mut self) -> Result<(), BehaviorError> {
        self.reset_progress();
        Ok(())
    }

    /// Validates the thresholds against the number of children and marks the
    /// node `Running`.
    ///
    /// A threshold that cannot be read from its port is treated as `-1`.
    ///
    /// # Errors
    /// [`BehaviorError::Composition`] if either threshold is larger than the
    /// number of children, because it could then never be reached.
    fn on_start(
        &mut self,
        behavior: &mut BehaviorData,
        children: &mut BehaviorTreeElementList,
        _runtime: &SharedRuntime,
    ) -> Result<(), BehaviorError> {
        let success_threshold: i32 = behavior.get(SUCCESS_COUNT).unwrap_or(-1);
        let failure_threshold: i32 = behavior.get(FAILURE_COUNT).unwrap_or(-1);
        let children_count = Self::count_as_i32(children.len());
        if children_count < success_threshold {
            return Err(BehaviorError::Composition {
                txt: "Number of children is less than the threshold. Can never succeed.".into(),
            });
        }
        if children_count < failure_threshold {
            return Err(BehaviorError::Composition {
                txt: "Number of children is less than the threshold. Can never fail.".into(),
            });
        }
        behavior.set_state(BehaviorState::Running);
        Ok(())
    }

    /// Ticks every child that has not completed yet and, once no child is still
    /// running, decides the overall outcome.
    ///
    /// Children returning `Success`/`Failure` are remembered and not ticked again
    /// in this run; `Skipped` children are ticked again on the next tick. While any
    /// child is still running, `Running` is returned. Otherwise the result is:
    /// - `Skipped` if every child was skipped (this includes having no children),
    /// - `Success` if neither threshold is positive,
    /// - with only a positive success threshold: `Success` iff at least that many
    ///   children succeeded, else `Failure`,
    /// - with a positive failure threshold: `Failure` if at least that many children
    ///   failed or the success threshold was not reached, else `Success`.
    ///
    /// After deciding, the run's progress is reset and the children are halted.
    ///
    /// # Errors
    /// Propagates errors from ticking or halting children, and returns
    /// [`BehaviorError::State`] if a child reports `Idle`.
    #[allow(clippy::set_contains_or_insert)]
    async fn tick(
        &mut self,
        behavior: &mut BehaviorData,
        children: &mut BehaviorTreeElementList,
        runtime: &SharedRuntime,
    ) -> BehaviorResult {
        let success_threshold: i32 = behavior.get(SUCCESS_COUNT).unwrap_or(-1);
        let failure_threshold: i32 = behavior.get(FAILURE_COUNT).unwrap_or(-1);
        let children_count = children.len();
        let mut skipped_count: i32 = 0;

        for i in 0..children_count {
            // Finished children keep their result until the run is reset.
            if self.completed_list.contains(&i) {
                continue;
            }
            let child = &mut children[i];
            match child.tick(runtime).await? {
                BehaviorState::Skipped => skipped_count += 1,
                BehaviorState::Success => {
                    self.completed_list.insert(i);
                    self.success_count += 1;
                }
                BehaviorState::Failure => {
                    self.completed_list.insert(i);
                    self.failure_count += 1;
                }
                BehaviorState::Running => {}
                BehaviorState::Idle => {
                    return Err(BehaviorError::State {
                        behavior: "Parallel".into(),
                        state: BehaviorState::Idle,
                    });
                }
            }
        }

        // Completed and skipped children are disjoint sets, so the sum only reaches
        // the child count once no child is still running. Checking after the loop
        // (instead of inside it) also lets an empty child list terminate.
        let children_total = Self::count_as_i32(children_count);
        if self.success_count + self.failure_count + skipped_count < children_total {
            return Ok(BehaviorState::Running);
        }

        let state = if skipped_count == children_total {
            BehaviorState::Skipped
        } else if failure_threshold <= 0 && success_threshold <= 0 {
            BehaviorState::Success
        } else if failure_threshold <= 0 {
            if self.success_count >= success_threshold {
                BehaviorState::Success
            } else {
                BehaviorState::Failure
            }
        } else if self.failure_count >= failure_threshold
            || self.success_count < success_threshold
        {
            // Reaching the failure threshold fails the node. `on_start` accepts a
            // threshold equal to the child count, so a strict `>` could never fire there.
            BehaviorState::Failure
        } else {
            BehaviorState::Success
        };

        self.reset_progress();
        children.halt(runtime)?;
        Ok(state)
    }

    /// Declares the `success_count` and `failure_count` `i32` input ports.
    fn provided_ports() -> PortList {
        port_list![
            input_port!(i32, SUCCESS_COUNT),
            input_port!(i32, FAILURE_COUNT)
        ]
    }
}