use std::{
ops::DerefMut,
sync::{Arc, Mutex},
};
use rand::{Rng, distr::Uniform, prelude::Distribution};
use super::sorted::{pick_max, pick_sorted};
use super::{ScoredSequence, Scorer, result_and, result_forced, result_last, result_or};
use crate as bevior_tree;
use crate::node::prelude::*;
/// Re-exports of this module's public items: the four random node types and the two
/// random picker functions.
pub mod prelude {
pub use super::{
RandomForcedSelector, RandomOrderedForcedSequence, RandomOrderedSequentialAnd,
RandomOrderedSequentialOr, pick_random_one, pick_random_sorted,
};
}
/// Turns scores into random keys and lets `pick_sorted` order the indices by those keys.
///
/// Each score is used as a sampling weight: one uniform sample `u` is drawn from the
/// `0.0..1.0` distribution per score and converted into the key `u^(1 / score)`, so larger
/// scores tend to produce larger keys.
///
/// Scores that are not strictly positive (zero, negative, `-0.0`, NaN) receive the key
/// `0.0`, the smallest key a positive score can produce. Without this guard a negative
/// score or `-0.0` yields a key of at least `1.0` (possibly infinite), which would rank
/// that entry above every positively scored one, and a NaN score would yield a NaN key.
///
/// Exactly one sample is drawn from `rng` per score, in input order, whatever the score's
/// value, so the amount of randomness consumed depends only on `scores.len()`.
///
/// # Panics
///
/// Panics if the uniform distribution over `0.0..1.0` cannot be constructed.
pub fn pick_random_sorted(scores: Vec<f32>, rng: &mut impl Rng) -> Vec<usize> {
    let dist = Uniform::<f32>::new(0.0, 1.0).expect("Failed to init uniform distribution.");
    let keys = scores
        .into_iter()
        .map(|score| {
            // Draw before inspecting the score so the RNG stream stays aligned with the
            // input no matter which scores are unusable.
            let sample = dist.sample(rng);
            if score > 0.0 {
                sample.powf(1.0 / score)
            } else {
                // Covers 0.0, -0.0, negatives and NaN (every comparison with NaN is false).
                0.0
            }
        })
        .collect();
    pick_sorted(keys)
}
/// Turns scores into random keys and lets `pick_max` choose from those keys.
///
/// Each score is used as a sampling weight: one uniform sample `u` is drawn from the
/// `0.0..1.0` distribution per score and converted into the key `u^(1 / score)`, so larger
/// scores tend to produce larger keys.
///
/// Scores that are not strictly positive (zero, negative, `-0.0`, NaN) receive the key
/// `0.0`, the smallest key a positive score can produce. Without this guard a negative
/// score or `-0.0` yields a key of at least `1.0` (possibly infinite), which would beat
/// every positively scored entry, and a NaN score would yield a NaN key.
///
/// Exactly one sample is drawn from `rng` per score, in input order, whatever the score's
/// value, so the amount of randomness consumed depends only on `scores.len()`.
///
/// # Panics
///
/// Panics if the uniform distribution over `0.0..1.0` cannot be constructed.
pub fn pick_random_one(scores: Vec<f32>, rng: &mut impl Rng) -> Vec<usize> {
    let dist = Uniform::<f32>::new(0.0, 1.0).expect("Failed to init uniform distribution.");
    let keys = scores
        .into_iter()
        .map(|score| {
            // Draw before inspecting the score so the RNG stream stays aligned with the
            // input no matter which scores are unusable.
            let sample = dist.sample(rng);
            if score > 0.0 {
                sample.powf(1.0 / score)
            } else {
                // Covers 0.0, -0.0, negatives and NaN (every comparison with NaN is false).
                0.0
            }
        })
        .collect();
    pick_max(keys)
}
/// Node wrapping a [`ScoredSequence`] configured with `pick_random_sorted` as its picker
/// and `result_and` as its result function.
///
/// NOTE(review): `delegate_node` presumably forwards the node implementation to the
/// `delegate` field — confirm against the macro definition.
#[delegate_node(delegate)]
pub struct RandomOrderedSequentialAnd {
    delegate: ScoredSequence,
}

impl RandomOrderedSequentialAnd {
    /// Creates the node from `(child, scorer)` pairs and a shared random generator.
    ///
    /// `rng` is moved into the picker closure; the closure locks the mutex for the
    /// duration of each `pick_random_sorted` call.
    ///
    /// # Panics
    ///
    /// The picker closure panics when the `rng` mutex is poisoned.
    pub fn new<R: Rng + 'static + Send + Sync>(
        nodes: Vec<(Box<dyn Node>, Mutex<Box<dyn Scorer>>)>,
        rng: Arc<Mutex<R>>,
    ) -> Self {
        let picker = move |scores: Vec<f32>| {
            let mut guard = rng.lock().unwrap();
            pick_random_sorted(scores, guard.deref_mut())
        };
        let delegate = ScoredSequence::new(nodes, picker, result_and);
        Self { delegate }
    }
}
/// Node wrapping a [`ScoredSequence`] configured with `pick_random_sorted` as its picker
/// and `result_or` as its result function.
///
/// NOTE(review): `delegate_node` presumably forwards the node implementation to the
/// `delegate` field — confirm against the macro definition.
#[delegate_node(delegate)]
pub struct RandomOrderedSequentialOr {
    delegate: ScoredSequence,
}

impl RandomOrderedSequentialOr {
    /// Creates the node from `(child, scorer)` pairs and a shared random generator.
    ///
    /// `rng` is moved into the picker closure; the closure locks the mutex for the
    /// duration of each `pick_random_sorted` call.
    ///
    /// # Panics
    ///
    /// The picker closure panics when the `rng` mutex is poisoned.
    pub fn new<R: Rng + 'static + Send + Sync>(
        nodes: Vec<(Box<dyn Node>, Mutex<Box<dyn Scorer>>)>,
        rng: Arc<Mutex<R>>,
    ) -> Self {
        let picker = move |scores: Vec<f32>| {
            let mut guard = rng.lock().unwrap();
            pick_random_sorted(scores, guard.deref_mut())
        };
        let delegate = ScoredSequence::new(nodes, picker, result_or);
        Self { delegate }
    }
}
/// Node wrapping a [`ScoredSequence`] configured with `pick_random_sorted` as its picker
/// and `result_last` as its result function.
///
/// NOTE(review): `delegate_node` presumably forwards the node implementation to the
/// `delegate` field — confirm against the macro definition.
#[delegate_node(delegate)]
pub struct RandomOrderedForcedSequence {
    delegate: ScoredSequence,
}

impl RandomOrderedForcedSequence {
    /// Creates the node from `(child, scorer)` pairs and a shared random generator.
    ///
    /// `rng` is moved into the picker closure; the closure locks the mutex for the
    /// duration of each `pick_random_sorted` call.
    ///
    /// # Panics
    ///
    /// The picker closure panics when the `rng` mutex is poisoned.
    pub fn new<R: Rng + 'static + Send + Sync>(
        nodes: Vec<(Box<dyn Node>, Mutex<Box<dyn Scorer>>)>,
        rng: Arc<Mutex<R>>,
    ) -> Self {
        let picker = move |scores: Vec<f32>| {
            let mut guard = rng.lock().unwrap();
            pick_random_sorted(scores, guard.deref_mut())
        };
        let delegate = ScoredSequence::new(nodes, picker, result_last);
        Self { delegate }
    }
}
/// Node wrapping a [`ScoredSequence`] configured with `pick_random_one` as its picker
/// and `result_forced` as its result function.
///
/// NOTE(review): `delegate_node` presumably forwards the node implementation to the
/// `delegate` field — confirm against the macro definition.
#[delegate_node(delegate)]
pub struct RandomForcedSelector {
    delegate: ScoredSequence,
}

impl RandomForcedSelector {
    /// Creates the node from `(child, scorer)` pairs and a shared random generator.
    ///
    /// `rng` is moved into the picker closure; the closure locks the mutex for the
    /// duration of each `pick_random_one` call.
    ///
    /// # Panics
    ///
    /// The picker closure panics when the `rng` mutex is poisoned.
    pub fn new<R: Rng + 'static + Send + Sync>(
        nodes: Vec<(Box<dyn Node>, Mutex<Box<dyn Scorer>>)>,
        rng: Arc<Mutex<R>>,
    ) -> Self {
        let picker = move |scores: Vec<f32>| {
            let mut guard = rng.lock().unwrap();
            pick_random_one(scores, guard.deref_mut())
        };
        let delegate = ScoredSequence::new(nodes, picker, result_forced);
        Self { delegate }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tester_util::prelude::*;
use rand::SeedableRng;
/// Tasks 0–2 succeed and task 3 fails; with seed 224 the log is expected to contain
/// tasks 1, 2 and 3 on frames 1–3 and no entry for task 0.
#[test]
fn test_random_ordered_sequential_and() {
    let mut app = App::new();
    app.add_plugins((BehaviorTreePlugin::default(), TesterPlugin));

    let children = vec![
        pair_node_scorer_fn(TesterTask::<0>::new(1, NodeResult::Success), |In(_)| 0.1),
        pair_node_scorer_fn(TesterTask::<1>::new(1, NodeResult::Success), |In(_)| 0.3),
        pair_node_scorer_fn(TesterTask::<2>::new(1, NodeResult::Success), |In(_)| 0.2),
        pair_node_scorer_fn(TesterTask::<3>::new(1, NodeResult::Failure), |In(_)| 0.4),
    ];
    let rng = Arc::new(Mutex::new(rand::rngs::StdRng::seed_from_u64(224)));
    let root = BehaviorTreeBundle::from_root(RandomOrderedSequentialAnd::new(children, rng));
    let _entity = app.world_mut().spawn(root).id();

    for _ in 0..5 {
        app.update();
    }

    // Every expected entry has `updated_count: 0`; only task id and frame vary.
    let entry = |task_id, frame| TestLogEntry {
        task_id,
        updated_count: 0,
        frame,
    };
    let expected = TestLog {
        log: vec![entry(1, 1), entry(2, 2), entry(3, 3)],
    };

    let found = app.world().get_resource::<TestLog>().unwrap();
    assert!(
        found == &expected,
        "RandomOrderedSequentialAnd should match result. found: {:?}",
        found
    );
}
/// Tasks 0–2 fail and task 3 succeeds; with seed 224 the log is expected to contain
/// tasks 1, 2 and 3 on frames 1–3 and no entry for task 0.
#[test]
fn test_random_ordered_sequential_or() {
    let mut app = App::new();
    app.add_plugins((BehaviorTreePlugin::default(), TesterPlugin));

    let children = vec![
        pair_node_scorer_fn(TesterTask::<0>::new(1, NodeResult::Failure), |In(_)| 0.1),
        pair_node_scorer_fn(TesterTask::<1>::new(1, NodeResult::Failure), |In(_)| 0.3),
        pair_node_scorer_fn(TesterTask::<2>::new(1, NodeResult::Failure), |In(_)| 0.2),
        pair_node_scorer_fn(TesterTask::<3>::new(1, NodeResult::Success), |In(_)| 0.4),
    ];
    let rng = Arc::new(Mutex::new(rand::rngs::StdRng::seed_from_u64(224)));
    let root = BehaviorTreeBundle::from_root(RandomOrderedSequentialOr::new(children, rng));
    let _entity = app.world_mut().spawn(root).id();

    for _ in 0..5 {
        app.update();
    }

    // Every expected entry has `updated_count: 0`; only task id and frame vary.
    let entry = |task_id, frame| TestLogEntry {
        task_id,
        updated_count: 0,
        frame,
    };
    let expected = TestLog {
        log: vec![entry(1, 1), entry(2, 2), entry(3, 3)],
    };

    let found = app.world().get_resource::<TestLog>().unwrap();
    assert!(
        found == &expected,
        "RandomOrderedSequentialOr should match result. found: {:?}",
        found
    );
}
/// Mixed child results; with seed 224 the log is expected to contain all four tasks in
/// the order 1, 2, 3, 0 on frames 1–4.
#[test]
fn test_random_ordered_forced_sequence() {
    let mut app = App::new();
    app.add_plugins((BehaviorTreePlugin::default(), TesterPlugin));

    let children = vec![
        pair_node_scorer_fn(TesterTask::<0>::new(1, NodeResult::Failure), |In(_)| 0.1),
        pair_node_scorer_fn(TesterTask::<1>::new(1, NodeResult::Failure), |In(_)| 0.3),
        pair_node_scorer_fn(TesterTask::<2>::new(1, NodeResult::Success), |In(_)| 0.2),
        pair_node_scorer_fn(TesterTask::<3>::new(1, NodeResult::Failure), |In(_)| 0.4),
    ];
    let rng = Arc::new(Mutex::new(rand::rngs::StdRng::seed_from_u64(224)));
    let root = BehaviorTreeBundle::from_root(RandomOrderedForcedSequence::new(children, rng));
    let _entity = app.world_mut().spawn(root).id();

    for _ in 0..6 {
        app.update();
    }

    // Every expected entry has `updated_count: 0`; only task id and frame vary.
    let entry = |task_id, frame| TestLogEntry {
        task_id,
        updated_count: 0,
        frame,
    };
    let expected = TestLog {
        log: vec![entry(1, 1), entry(2, 2), entry(3, 3), entry(0, 4)],
    };

    let found = app.world().get_resource::<TestLog>().unwrap();
    assert!(
        found == &expected,
        "RandomOrderedForcedSequence should match result. found: {:?}",
        found
    );
}
/// With seed 224 the log is expected to contain a single entry: task 1 on frame 1.
#[test]
fn test_random_forced_selector() {
    let mut app = App::new();
    app.add_plugins((BehaviorTreePlugin::default(), TesterPlugin));

    let children = vec![
        pair_node_scorer_fn(TesterTask::<0>::new(1, NodeResult::Failure), |In(_)| 0.1),
        pair_node_scorer_fn(TesterTask::<1>::new(1, NodeResult::Failure), |In(_)| 0.3),
        pair_node_scorer_fn(TesterTask::<2>::new(1, NodeResult::Success), |In(_)| 0.2),
        pair_node_scorer_fn(TesterTask::<3>::new(1, NodeResult::Failure), |In(_)| 0.4),
    ];
    let rng = Arc::new(Mutex::new(rand::rngs::StdRng::seed_from_u64(224)));
    let root = BehaviorTreeBundle::from_root(RandomForcedSelector::new(children, rng));
    let _entity = app.world_mut().spawn(root).id();

    for _ in 0..3 {
        app.update();
    }

    let only_entry = TestLogEntry {
        task_id: 1,
        updated_count: 0,
        frame: 1,
    };
    let expected = TestLog {
        log: vec![only_entry],
    };

    let found = app.world().get_resource::<TestLog>().unwrap();
    assert!(
        found == &expected,
        "RandomForcedSelector should match result. found: {:?}",
        found
    );
}
}