cornerstone 0.1.2

A behavior tree library
Documentation
use std::collections::HashSet;

use crate::{Context, NodeStatus, TreeNode};

pub trait ControlNode {
    fn add_child(&mut self, node: Box<dyn TreeNode>);
}

#[derive(Default)]
pub struct ControlNodeHandle {
    child_nodes: Vec<Box<dyn TreeNode>>,
}

impl ControlNode for ControlNodeHandle {
    fn add_child(&mut self, node: Box<dyn TreeNode>) {
        self.child_nodes.push(node);
    }
}

#[derive(Default)]
pub struct SequenceNode {
    current_child_idx: usize,
    handle: ControlNodeHandle,
}

impl ControlNode for SequenceNode {
    fn add_child(&mut self, node: Box<dyn TreeNode>) {
        self.handle.add_child(node);
    }
}

impl TreeNode for SequenceNode {
    fn tick(&mut self, ctx: &mut Context) -> NodeStatus {
        let from = self.current_child_idx;

        for node in self.handle.child_nodes.iter_mut().skip(from) {
            match node.tick(ctx) {
                NodeStatus::Failure => {
                    return NodeStatus::Failure;
                }
                NodeStatus::Running => {
                    return NodeStatus::Running;
                }
                NodeStatus::Success => (),
            }
        }

        NodeStatus::Success
    }
}

pub struct ParallelNode {
    success_threshold: Option<usize>,
    failure_threshold: Option<usize>,
    success_count: usize,
    failure_count: usize,
    completed_list: HashSet<usize>,
    handle: ControlNodeHandle,
}

impl ParallelNode {
    pub fn new(success_threshold: Option<usize>, failure_threshold: Option<usize>) -> Self {
        Self {
            success_threshold,
            failure_threshold,
            success_count: 0,
            failure_count: 0,
            completed_list: HashSet::new(),
            handle: ControlNodeHandle::default(),
        }
    }
}

impl ControlNode for ParallelNode {
    fn add_child(&mut self, node: Box<dyn TreeNode>) {
        self.handle.add_child(node);
    }
}

impl TreeNode for ParallelNode {
    fn tick(&mut self, ctx: &mut Context) -> NodeStatus {
        let children_count = self.handle.child_nodes.len();

        if children_count == 0 {
            return NodeStatus::Failure;
        }

        for i in 0..children_count {
            if self.completed_list.contains(&i) {
                continue;
            }

            let node = &mut self.handle.child_nodes[i];

            match node.tick(ctx) {
                NodeStatus::Failure => {
                    self.failure_count += 1;
                }
                NodeStatus::Success => {
                    self.success_count += 1;
                }
                NodeStatus::Running => continue,
            }

            self.completed_list.insert(i);

            if self.success_count >= self.success_threshold.unwrap_or(children_count) {
                return NodeStatus::Success;
            }

            if self.failure_count >= self.failure_threshold.unwrap_or(children_count) {
                return NodeStatus::Failure;
            }
        }

        NodeStatus::Running
    }
}