Skip to main content

cu_ratelimit/
lib.rs

1use bincode::de::Decoder;
2use bincode::enc::Encoder;
3use bincode::error::{DecodeError, EncodeError};
4use bincode::{Decode, Encode};
5use cu29::prelude::*;
6use std::marker::PhantomData;
7
8#[derive(Reflect)]
9#[reflect(no_field_bounds, from_reflect = false, type_path = false)]
10pub struct CuRateLimit<T>
11where
12    T: for<'a> CuMsgPayload + 'static,
13{
14    #[reflect(ignore)]
15    _marker: PhantomData<fn() -> T>,
16    interval: CuDuration,
17    last_tov: Option<CuTime>,
18}
19
20impl<T> TypePath for CuRateLimit<T>
21where
22    T: CuMsgPayload + 'static,
23{
24    fn type_path() -> &'static str {
25        "cu_ratelimit::CuRateLimit"
26    }
27
28    fn short_type_path() -> &'static str {
29        "CuRateLimit"
30    }
31
32    fn type_ident() -> Option<&'static str> {
33        Some("CuRateLimit")
34    }
35
36    fn crate_name() -> Option<&'static str> {
37        Some("cu_ratelimit")
38    }
39
40    fn module_path() -> Option<&'static str> {
41        Some("")
42    }
43}
44
45impl<T> Freezable for CuRateLimit<T>
46where
47    T: CuMsgPayload,
48{
49    fn freeze<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
50        Encode::encode(&self.last_tov, encoder)
51    }
52
53    fn thaw<D: Decoder>(&mut self, decoder: &mut D) -> Result<(), DecodeError> {
54        self.last_tov = Decode::decode(decoder)?;
55        Ok(())
56    }
57}
58
59impl<T> CuTask for CuRateLimit<T>
60where
61    T: CuMsgPayload,
62{
63    type Resources<'r> = ();
64    type Input<'m> = input_msg!(T);
65    type Output<'m> = output_msg!(T);
66
67    fn new(config: Option<&ComponentConfig>, _resources: Self::Resources<'_>) -> CuResult<Self> {
68        let hz = match config {
69            Some(cfg) => cfg
70                .get::<f64>("rate")?
71                .ok_or("Missing required 'rate' config for CuRateLimiter")?,
72            None => return Err("Missing required 'rate' config for CuRateLimiter".into()),
73        };
74        let interval_ns = (1e9 / hz) as u64;
75        Ok(Self {
76            _marker: PhantomData,
77            interval: CuDuration::from(interval_ns),
78            last_tov: None,
79        })
80    }
81
82    fn process<'m>(
83        &mut self,
84        _clock: &RobotClock,
85        input: &Self::Input<'m>,
86        output: &mut Self::Output<'m>,
87    ) -> CuResult<()> {
88        let tov = match input.tov {
89            Tov::Time(ts) => ts,
90            _ => return Err("Expected single timestamp TOV".into()),
91        };
92
93        let allow = match self.last_tov {
94            None => true,
95            Some(last) => (tov - last) >= self.interval,
96        };
97
98        if allow {
99            self.last_tov = Some(tov);
100            if let Some(payload) = input.payload() {
101                output.set_payload(payload.clone());
102            } else {
103                output.clear_payload();
104            }
105        } else {
106            output.clear_payload();
107        }
108
109        Ok(())
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `i32` rate limiter configured with `rate` Hz.
    fn limiter_at(rate: f64) -> CuRateLimit<i32> {
        let mut config = ComponentConfig::new();
        config.set("rate", rate);
        CuRateLimit::new(Some(&config), ()).unwrap()
    }

    #[test]
    fn test_rate_limiting() {
        let (clock, _) = RobotClock::mock();
        let mut limiter = limiter_at(10.0); // 10 Hz -> 100 ms minimum spacing
        let mut input = CuMsg::<i32>::new(Some(42));
        let mut output = CuMsg::<i32>::new(None);

        // (time of validity in ns, payload expected on the output)
        let steps: [(u64, Option<i32>); 3] = [
            (0, Some(42)),           // first message always goes through
            (50_000_000, None),      // 50 ms: still inside the interval
            (100_000_000, Some(42)), // 100 ms: interval elapsed
        ];

        for (tov_ns, expected) in steps {
            input.tov = Tov::Time(CuTime::from(tov_ns));
            limiter.process(&clock, &input, &mut output).unwrap();
            assert_eq!(output.payload(), expected.as_ref());
        }
    }

    #[test]
    fn test_payload_propagation() {
        let (clock, _) = RobotClock::mock();
        let mut limiter = limiter_at(10.0);
        let mut input = CuMsg::<i32>::new(None);
        let mut output = CuMsg::<i32>::new(None);

        // Both messages are spaced by a full interval, so each one is let through and
        // the output must mirror the input payload (present, then absent).
        let steps: [(Option<i32>, u64); 2] = [(Some(123), 0), (None, 100_000_000)];

        for (payload, tov_ns) in steps {
            if let Some(value) = payload {
                input.set_payload(value);
            } else {
                input.clear_payload();
            }
            input.tov = Tov::Time(CuTime::from(tov_ns));
            limiter.process(&clock, &input, &mut output).unwrap();
            assert_eq!(output.payload(), payload.as_ref());
        }
    }
}