use std::{cmp::Reverse, sync::Mutex};
use ordered_float::OrderedFloat;
use super::{ScoredSequence, Scorer, result_and, result_forced, result_last, result_or};
use crate as bevior_tree;
use crate::node::prelude::*;
pub mod prelude {
pub use super::{
ScoreOrderedForcedSequence, ScoreOrderedSequentialAnd, ScoreOrderedSequentialOr,
ScoredForcedSelector, pick_max, pick_sorted,
};
}
/// Returns every index of `scores`, ordered from the highest score to the lowest.
///
/// The sort is stable, so indices whose scores compare equal keep their original
/// relative order. Scores are compared through `OrderedFloat`, which orders NaN above
/// every other value, so NaN scores end up at the front of the result.
///
/// An empty `scores` yields an empty vector.
pub fn pick_sorted(scores: Vec<f32>) -> Vec<usize> {
    // Start from the identity permutation and reorder it by descending score.
    let mut order: Vec<usize> = (0..scores.len()).collect();
    order.sort_by_key(|&index| Reverse(OrderedFloat(scores[index])));
    order
}
/// Returns a vector holding only the index of the highest score.
///
/// The result is empty when `scores` is empty. Scores are compared through
/// `OrderedFloat`. When several scores tie for the maximum, the index of the *last*
/// one is returned (this differs from `pick_sorted`, whose first entry is the
/// earliest of the tied indices).
pub fn pick_max(scores: Vec<f32>) -> Vec<usize> {
    let best = scores
        .iter()
        .enumerate()
        .max_by_key(|&(_, &score)| OrderedFloat(score));

    match best {
        Some((index, _)) => vec![index],
        None => Vec::new(),
    }
}
/// Composite node that visits its children from the highest score to the lowest.
///
/// It wraps a `ScoredSequence` configured with `pick_sorted` and `result_and`; the
/// `delegate_node` attribute presumably forwards the node behaviour to that field.
/// In this file's test the run stops right after the first child that reports
/// failure, so lower-scored children are never started.
#[delegate_node(delegate)]
pub struct ScoreOrderedSequentialAnd {
    delegate: ScoredSequence,
}

impl ScoreOrderedSequentialAnd {
    /// Builds the node from `(child node, scorer)` pairs.
    pub fn new(nodes: Vec<(Box<dyn Node>, Mutex<Box<dyn Scorer>>)>) -> Self {
        let delegate = ScoredSequence::new(nodes, pick_sorted, result_and);
        Self { delegate }
    }
}
/// Composite node that visits its children from the highest score to the lowest.
///
/// It wraps a `ScoredSequence` configured with `pick_sorted` and `result_or`; the
/// `delegate_node` attribute presumably forwards the node behaviour to that field.
/// In this file's test the run stops right after the first child that reports
/// success, so lower-scored children are never started.
#[delegate_node(delegate)]
pub struct ScoreOrderedSequentialOr {
    delegate: ScoredSequence,
}

impl ScoreOrderedSequentialOr {
    /// Builds the node from `(child node, scorer)` pairs.
    pub fn new(nodes: Vec<(Box<dyn Node>, Mutex<Box<dyn Scorer>>)>) -> Self {
        let delegate = ScoredSequence::new(nodes, pick_sorted, result_or);
        Self { delegate }
    }
}
/// Composite node that visits its children from the highest score to the lowest.
///
/// It wraps a `ScoredSequence` configured with `pick_sorted` and `result_last`; the
/// `delegate_node` attribute presumably forwards the node behaviour to that field.
/// In this file's test every child is run, whatever the earlier children reported.
#[delegate_node(delegate)]
pub struct ScoreOrderedForcedSequence {
    delegate: ScoredSequence,
}

impl ScoreOrderedForcedSequence {
    /// Builds the node from `(child node, scorer)` pairs.
    pub fn new(nodes: Vec<(Box<dyn Node>, Mutex<Box<dyn Scorer>>)>) -> Self {
        let delegate = ScoredSequence::new(nodes, pick_sorted, result_last);
        Self { delegate }
    }
}
/// Composite node that runs only the child with the highest score.
///
/// It wraps a `ScoredSequence` configured with `pick_max` and `result_forced`; the
/// `delegate_node` attribute presumably forwards the node behaviour to that field.
/// In this file's test only the top-scored child is ever started.
#[delegate_node(delegate)]
pub struct ScoredForcedSelector {
    delegate: ScoredSequence,
}

impl ScoredForcedSelector {
    /// Builds the node from `(child node, scorer)` pairs.
    pub fn new(nodes: Vec<(Box<dyn Node>, Mutex<Box<dyn Scorer>>)>) -> Self {
        let delegate = ScoredSequence::new(nodes, pick_max, result_forced);
        Self { delegate }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tester_util::prelude::*;
// Tasks are scored 3 > 1 > 2 > 0. Task 2 fails, so task 0 must never be logged.
#[test]
fn test_score_ordered_sequential_and() {
    let mut app = App::new();
    app.add_plugins((BehaviorTreePlugin::default(), TesterPlugin));

    let root = ScoreOrderedSequentialAnd::new(vec![
        pair_node_scorer_fn(TesterTask::<0>::new(1, NodeResult::Success), |In(_)| 0.1),
        pair_node_scorer_fn(TesterTask::<1>::new(1, NodeResult::Success), |In(_)| 0.3),
        pair_node_scorer_fn(TesterTask::<2>::new(1, NodeResult::Failure), |In(_)| 0.2),
        pair_node_scorer_fn(TesterTask::<3>::new(1, NodeResult::Success), |In(_)| 0.4),
    ]);
    let _entity = app
        .world_mut()
        .spawn(BehaviorTreeBundle::from_root(root))
        .id();

    // Five frames in total, more than the three the tree needs to finish.
    for _ in 0..5 {
        app.update();
    }

    // Every expected entry is logged with `updated_count == 0`.
    let entry = |task_id, frame| TestLogEntry {
        task_id,
        updated_count: 0,
        frame,
    };
    let expected = TestLog {
        log: vec![entry(3, 1), entry(1, 2), entry(2, 3)],
    };

    let found = app.world().get_resource::<TestLog>().unwrap();
    assert!(found == &expected, "Result mismatch. found: {:?}", found);
}
// Tasks are scored 3 > 1 > 2 > 0. Task 2 succeeds, so task 0 must never be logged.
#[test]
fn test_score_ordered_sequential_or() {
    let mut app = App::new();
    app.add_plugins((BehaviorTreePlugin::default(), TesterPlugin));

    let root = ScoreOrderedSequentialOr::new(vec![
        pair_node_scorer_fn(TesterTask::<0>::new(1, NodeResult::Failure), |In(_)| 0.1),
        pair_node_scorer_fn(TesterTask::<1>::new(1, NodeResult::Failure), |In(_)| 0.3),
        pair_node_scorer_fn(TesterTask::<2>::new(1, NodeResult::Success), |In(_)| 0.2),
        pair_node_scorer_fn(TesterTask::<3>::new(1, NodeResult::Failure), |In(_)| 0.4),
    ]);
    let _entity = app
        .world_mut()
        .spawn(BehaviorTreeBundle::from_root(root))
        .id();

    // Five frames in total, more than the three the tree needs to finish.
    for _ in 0..5 {
        app.update();
    }

    // Every expected entry is logged with `updated_count == 0`.
    let entry = |task_id, frame| TestLogEntry {
        task_id,
        updated_count: 0,
        frame,
    };
    let expected = TestLog {
        log: vec![entry(3, 1), entry(1, 2), entry(2, 3)],
    };

    let found = app.world().get_resource::<TestLog>().unwrap();
    assert!(found == &expected, "Result mismatch. found: {:?}", found);
}
// Tasks are scored 3 > 1 > 2 > 0. All four must be logged in that order,
// regardless of the result each one reports.
#[test]
fn test_score_ordered_forced_sequence() {
    let mut app = App::new();
    app.add_plugins((BehaviorTreePlugin::default(), TesterPlugin));

    let root = ScoreOrderedForcedSequence::new(vec![
        pair_node_scorer_fn(TesterTask::<0>::new(1, NodeResult::Failure), |In(_)| 0.1),
        pair_node_scorer_fn(TesterTask::<1>::new(1, NodeResult::Failure), |In(_)| 0.3),
        pair_node_scorer_fn(TesterTask::<2>::new(1, NodeResult::Success), |In(_)| 0.2),
        pair_node_scorer_fn(TesterTask::<3>::new(1, NodeResult::Failure), |In(_)| 0.4),
    ]);
    let _entity = app
        .world_mut()
        .spawn(BehaviorTreeBundle::from_root(root))
        .id();

    // Six frames in total, more than the four the tree needs to finish.
    for _ in 0..6 {
        app.update();
    }

    // Every expected entry is logged with `updated_count == 0`.
    let entry = |task_id, frame| TestLogEntry {
        task_id,
        updated_count: 0,
        frame,
    };
    let expected = TestLog {
        log: vec![entry(3, 1), entry(1, 2), entry(2, 3), entry(0, 4)],
    };

    let found = app.world().get_resource::<TestLog>().unwrap();
    assert!(found == &expected, "Result mismatch. found: {:?}", found);
}
// Task 3 has the highest score, so it must be the only task that is ever logged.
#[test]
fn test_score_ordered_forced_selector() {
    let mut app = App::new();
    app.add_plugins((BehaviorTreePlugin::default(), TesterPlugin));

    let root = ScoredForcedSelector::new(vec![
        pair_node_scorer_fn(TesterTask::<0>::new(1, NodeResult::Failure), |In(_)| 0.1),
        pair_node_scorer_fn(TesterTask::<1>::new(1, NodeResult::Failure), |In(_)| 0.3),
        pair_node_scorer_fn(TesterTask::<2>::new(1, NodeResult::Success), |In(_)| 0.2),
        pair_node_scorer_fn(TesterTask::<3>::new(1, NodeResult::Failure), |In(_)| 0.4),
    ]);
    let _entity = app
        .world_mut()
        .spawn(BehaviorTreeBundle::from_root(root))
        .id();

    // Three frames in total, more than the single one the tree needs to finish.
    for _ in 0..3 {
        app.update();
    }

    let only_entry = TestLogEntry {
        task_id: 3,
        updated_count: 0,
        frame: 1,
    };
    let expected = TestLog {
        log: vec![only_entry],
    };

    let found = app.world().get_resource::<TestLog>().unwrap();
    assert!(found == &expected, "Result mismatch. found: {:?}", found);
}
}