Skip to main content

beetry_core/tree/
ticker.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4    time::Duration,
5};
6
7use futures::{Stream, future::poll_fn};
8use thiserror::Error as ThisError;
9use tokio::time::MissedTickBehavior;
10
11use crate::{Node, TickStatus};
12
13/// Ticks a behavior tree using an external tick source.
14///
15/// `Ticker` is built from any `Stream<Item = TickSignal>`, so callers can
16/// define their own ticking mechanism. A tick source can be periodic with
17/// [`PeriodicTick`], event-driven from an external stream, or a hybrid of
18/// time-based and signal-based wakeups.
19pub struct Ticker<S> {
20    stream: Pin<Box<S>>,
21}
22
23pub type TickSignal = ();
24
25/// Errors that can occur while ticking a tree from a tick source.
26#[derive(Debug, ThisError)]
27pub enum Error {
28    /// The tick source ended before the tree reached a terminal status.
29    #[error("tick source was exhausted")]
30    SourceExhausted,
31}
32
33impl<S> Ticker<S>
34where
35    S: Stream<Item = TickSignal>,
36{
37    pub fn new(stream: S) -> Self {
38        Self {
39            stream: Box::pin(stream),
40        }
41    }
42
43    pub async fn tick_till_terminal(&mut self, tree: &mut impl Node) -> Result<TickStatus, Error> {
44        while poll_fn(|cx| self.stream.as_mut().poll_next(cx))
45            .await
46            .is_some()
47        {
48            match tree.tick() {
49                TickStatus::Running => {}
50                s @ (TickStatus::Success | TickStatus::Failure) => return Ok(s),
51            }
52        }
53        Err(Error::SourceExhausted)
54    }
55}
56
57/// Built-in periodic tick source. Useful as the default in most applications.
58pub struct PeriodicTick {
59    interval: tokio::time::Interval,
60}
61
62impl PeriodicTick {
63    #[must_use]
64    pub fn new(period: Duration) -> Self {
65        Self {
66            interval: tokio::time::interval(period),
67        }
68    }
69
70    pub fn with_missed_tick_behavior(&mut self, behavior: MissedTickBehavior) {
71        self.interval.set_missed_tick_behavior(behavior);
72    }
73}
74
75impl Stream for PeriodicTick {
76    type Item = TickSignal;
77
78    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
79        self.get_mut().interval.poll_tick(cx).map(|_| Some(()))
80    }
81}
82
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::stream;

    use super::*;
    use crate::{MockNode, TickStatus};

    #[tokio::test]
    async fn periodic_tick_yields_terminal_status() {
        let mut node = MockNode::new();
        node.expect_tick().once().return_const(TickStatus::Success);

        let mut ticker = Ticker::new(PeriodicTick::new(Duration::from_millis(1)));
        let status = ticker
            .tick_till_terminal(&mut node)
            .await
            .expect("periodic stream should tick at least once");

        assert_eq!(status, TickStatus::Success);
    }

    #[tokio::test]
    async fn terminal_status_precedes_source_exhaustion() {
        let mut node = MockNode::new();
        let mut statuses = [TickStatus::Running, TickStatus::Success].into_iter();
        node.expect_tick()
            .times(2)
            .returning(move || statuses.next().expect("status sequence configured"));

        let mut ticker = Ticker::new(stream::iter([(), (), ()]));
        let status = ticker
            .tick_till_terminal(&mut node)
            .await
            .expect("tree reaches terminal status before source exhaustion");

        assert_eq!(status, TickStatus::Success);
    }

    /// `Failure` is terminal just like `Success`; ticking must stop there.
    #[tokio::test]
    async fn failure_status_is_terminal() {
        let mut node = MockNode::new();
        node.expect_tick().once().return_const(TickStatus::Failure);

        let mut ticker = Ticker::new(stream::iter([(), ()]));
        let status = ticker
            .tick_till_terminal(&mut node)
            .await
            .expect("failure is a terminal status");

        assert_eq!(status, TickStatus::Failure);
    }

    #[tokio::test]
    async fn tick_source_exhausted() {
        let mut node = MockNode::new();
        node.expect_tick()
            .times(2)
            .return_const(TickStatus::Running);

        let mut ticker = Ticker::new(stream::iter([(), ()]));
        let result = ticker.tick_till_terminal(&mut node).await;
        assert!(matches!(result, Err(Error::SourceExhausted)));
    }

    /// A source that yields nothing must report exhaustion without ever
    /// ticking the tree.
    #[tokio::test]
    async fn empty_source_never_ticks_tree() {
        let mut node = MockNode::new();
        node.expect_tick().never();

        let mut ticker = Ticker::new(stream::empty::<TickSignal>());
        let result = ticker.tick_till_terminal(&mut node).await;
        assert!(matches!(result, Err(Error::SourceExhausted)));
    }
}