cu_ratelimit/
lib.rs

1use bincode::de::Decoder;
2use bincode::enc::Encoder;
3use bincode::error::{DecodeError, EncodeError};
4use bincode::{Decode, Encode};
5use cu29::prelude::*;
6use std::marker::PhantomData;
7
8pub struct CuRateLimit<T>
9where
10    T: for<'a> CuMsgPayload + 'static,
11{
12    _marker: PhantomData<T>,
13    interval: CuDuration,
14    last_tov: Option<CuTime>,
15}
16
17impl<T> Freezable for CuRateLimit<T>
18where
19    T: CuMsgPayload,
20{
21    fn freeze<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
22        Encode::encode(&self.last_tov, encoder)
23    }
24
25    fn thaw<D: Decoder>(&mut self, decoder: &mut D) -> Result<(), DecodeError> {
26        self.last_tov = Decode::decode(decoder)?;
27        Ok(())
28    }
29}
30
31impl<T> CuTask for CuRateLimit<T>
32where
33    T: CuMsgPayload,
34{
35    type Resources<'r> = ();
36    type Input<'m> = input_msg!(T);
37    type Output<'m> = output_msg!(T);
38
39    fn new(config: Option<&ComponentConfig>, _resources: Self::Resources<'_>) -> CuResult<Self> {
40        let hz = config
41            .and_then(|cfg| cfg.get::<f64>("rate"))
42            .ok_or("Missing required 'rate' config for CuRateLimiter")?;
43        let interval_ns = (1e9 / hz) as u64;
44        Ok(Self {
45            _marker: PhantomData,
46            interval: CuDuration::from(interval_ns),
47            last_tov: None,
48        })
49    }
50
51    fn process<'m>(
52        &mut self,
53        _clock: &RobotClock,
54        input: &Self::Input<'m>,
55        output: &mut Self::Output<'m>,
56    ) -> CuResult<()> {
57        let tov = match input.tov {
58            Tov::Time(ts) => ts,
59            _ => return Err("Expected single timestamp TOV".into()),
60        };
61
62        let allow = match self.last_tov {
63            None => true,
64            Some(last) => (tov - last) >= self.interval,
65        };
66
67        if allow {
68            self.last_tov = Some(tov);
69            if let Some(payload) = input.payload() {
70                output.set_payload(payload.clone());
71            } else {
72                output.clear_payload();
73            }
74        } else {
75            output.clear_payload();
76        }
77
78        Ok(())
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter over `i32` payloads configured with `rate` (in Hz).
    fn create_test_ratelimiter(rate: f64) -> CuRateLimit<i32> {
        let mut config = ComponentConfig::new();
        config.set("rate", rate);
        CuRateLimit::new(Some(&config), ()).unwrap()
    }

    /// Drives a 10 Hz limiter with timestamps at 0 ms, 50 ms and 100 ms and
    /// checks that only the first and the last are forwarded.
    #[test]
    fn test_rate_limiting() {
        let (clock, _) = RobotClock::mock();
        // 10 Hz corresponds to a 100 ms interval.
        let mut limiter = create_test_ratelimiter(10.0);
        let mut input = CuMsg::<i32>::new(Some(42));
        let mut output = CuMsg::<i32>::new(None);

        // (time of validity in ns, payload expected on the output)
        let schedule: [(u64, Option<i32>); 3] = [
            (0, Some(42)),           // first message passes
            (50_000_000, None),      // 50 ms: inside the interval, blocked
            (100_000_000, Some(42)), // 100 ms: interval elapsed, passes
        ];

        for (tov_ns, expected) in schedule {
            input.tov = Tov::Time(CuTime::from(tov_ns));
            limiter.process(&clock, &input, &mut output).unwrap();
            assert_eq!(output.payload(), expected.as_ref());
        }
    }

    /// Checks that a present payload is copied to the output and that an
    /// empty input leaves the output empty even when the message is accepted.
    #[test]
    fn test_payload_propagation() {
        let (clock, _) = RobotClock::mock();
        let mut limiter = create_test_ratelimiter(10.0);
        let mut input = CuMsg::<i32>::new(None);
        let mut output = CuMsg::<i32>::new(None);

        // A payload on an accepted message reaches the output.
        input.set_payload(123);
        input.tov = Tov::Time(CuTime::from(0));
        limiter.process(&clock, &input, &mut output).unwrap();
        assert_eq!(output.payload(), Some(&123));

        // An accepted message without payload yields an empty output.
        input.clear_payload();
        input.tov = Tov::Time(CuTime::from(100_000_000));
        limiter.process(&clock, &input, &mut output).unwrap();
        assert_eq!(output.payload(), None);
    }
}