cornerstone 0.3.1

A behavior tree library
Documentation
use std::collections::{HashMap, HashSet};

use crate::{NodeStatus, TreeNode, TreeNodeWrapper};

use super::DataProxy;

pub trait CompositeNodeImpl: Send {
    fn tick_status(
        &mut self,
        data_proxy: &mut DataProxy,
        child_nodes: &mut Vec<TreeNodeWrapper>,
    ) -> NodeStatus;
    fn node_info(&self) -> String {
        format!("{}", std::any::type_name::<Self>())
    }
    fn reset_state(&mut self);
}

pub struct CompositeWrapper {
    pub data_proxy: DataProxy,
    node_wrapper: Box<dyn CompositeNodeImpl>,
    pub child_nodes: Vec<TreeNodeWrapper>,
}

impl CompositeWrapper {
    pub fn new(data_proxy: DataProxy, node_wrapper: Box<dyn CompositeNodeImpl>) -> Self {
        Self {
            data_proxy,
            node_wrapper,
            child_nodes: vec![],
        }
    }

    pub fn add_child(&mut self, node: TreeNodeWrapper) {
        self.child_nodes.push(node);
    }

    pub fn reset_children(&mut self) {
        for child_node in &mut self.child_nodes {
            if child_node.status() == NodeStatus::Running {
                child_node.halt();
            }
            child_node.reset_status();
        }
    }
}

impl TreeNode for CompositeWrapper {
    fn tick(&mut self) -> NodeStatus {
        if self.data_proxy.status() == NodeStatus::Idle {
            self.data_proxy.set_status(NodeStatus::Running);
        }

        let tick_status = self
            .node_wrapper
            .tick_status(&mut self.data_proxy, &mut self.child_nodes);

        if tick_status.is_completed() {
            self.halt();
        }

        self.data_proxy.set_status(tick_status);

        tick_status
    }

    fn halt(&mut self) {
        tracing::debug!("halt self: {}", std::any::type_name::<Self>());
        self.node_wrapper.reset_state();
        self.reset_children();
    }
}

#[derive(Default)]
pub struct Sequence {
    current_child_idx: usize,
}

impl CompositeNodeImpl for Sequence {
    fn tick_status(
        &mut self,
        data_proxy: &mut DataProxy,
        child_nodes: &mut Vec<TreeNodeWrapper>,
    ) -> NodeStatus {
        let from = self.current_child_idx;

        for node in child_nodes.iter_mut().skip(from) {
            match node.tick() {
                NodeStatus::Failure => {
                    return NodeStatus::Failure;
                }
                NodeStatus::Running => {
                    return NodeStatus::Running;
                }
                NodeStatus::Success => {
                    self.current_child_idx += 1;
                }
                NodeStatus::Idle => return NodeStatus::Failure,
            }
        }

        NodeStatus::Success
    }

    fn node_info(&self) -> String {
        format!("Sequence: current_child_idx= {}", self.current_child_idx)
    }

    fn reset_state(&mut self) {
        self.current_child_idx = 0;
    }
}

#[derive(Default)]
pub struct Parallel {
    success_threshold: Option<usize>,
    failure_threshold: Option<usize>,
    success_count: usize,
    failure_count: usize,
    completed_list: HashSet<usize>,
}

pub const PARALLEL_SUCCESS_COUNT: &str = "success_count";
pub const PARALLEL_FAILURE_COUNT: &str = "failure_count";

impl CompositeNodeImpl for Parallel {
    fn tick_status(
        &mut self,
        data_proxy: &mut DataProxy,
        child_nodes: &mut Vec<TreeNodeWrapper>,
    ) -> NodeStatus {
        let children_count = child_nodes.len();

        let success_threshold = data_proxy
            .get_input(PARALLEL_SUCCESS_COUNT)
            .unwrap_or(self.success_threshold.unwrap_or(children_count));

        let failure_threshold = data_proxy
            .get_input(PARALLEL_FAILURE_COUNT)
            .unwrap_or(self.failure_threshold.unwrap_or(children_count));

        if children_count == 0 {
            return NodeStatus::Failure;
        }

        for i in 0..children_count {
            if self.completed_list.contains(&i) {
                continue;
            }

            let node = &mut child_nodes[i];

            match node.tick() {
                NodeStatus::Idle => return NodeStatus::Failure,
                NodeStatus::Failure => {
                    self.failure_count += 1;
                }
                NodeStatus::Success => {
                    self.success_count += 1;
                }
                NodeStatus::Running => continue,
            }

            self.completed_list.insert(i);

            if self.success_count >= success_threshold {
                return NodeStatus::Success;
            }

            if self.failure_count >= failure_threshold {
                return NodeStatus::Failure;
            }
        }

        NodeStatus::Running
    }

    fn reset_state(&mut self) {
        std::mem::swap(self, &mut Self::default());
    }
}

#[derive(Default)]
pub struct Selector {
    current_child_idx: usize,
}

impl CompositeNodeImpl for Selector {
    fn tick_status(
        &mut self,
        data_proxy: &mut DataProxy,
        child_nodes: &mut Vec<TreeNodeWrapper>,
    ) -> NodeStatus {
        for node in child_nodes.iter_mut().skip(self.current_child_idx) {
            match node.tick() {
                NodeStatus::Idle => return NodeStatus::Failure,
                NodeStatus::Success => {
                    self.reset_state();
                    return NodeStatus::Success;
                }
                NodeStatus::Running => return NodeStatus::Running,
                NodeStatus::Failure => {
                    self.current_child_idx += 1;
                }
            }
        }

        NodeStatus::Failure
    }

    fn reset_state(&mut self) {
        std::mem::swap(self, &mut Self::default());
    }
}