1use bincode::de::Decoder;
2use bincode::enc::Encoder;
3use bincode::error::{DecodeError, EncodeError};
4use bincode::{Decode, Encode};
5use cu29::prelude::*;
6use std::marker::PhantomData;
7
8#[derive(Reflect)]
9#[reflect(no_field_bounds, from_reflect = false, type_path = false)]
10pub struct CuRateLimit<T>
11where
12 T: for<'a> CuMsgPayload + 'static,
13{
14 #[reflect(ignore)]
15 _marker: PhantomData<fn() -> T>,
16 interval: CuDuration,
17 last_tov: Option<CuTime>,
18}
19
20impl<T> TypePath for CuRateLimit<T>
21where
22 T: CuMsgPayload + 'static,
23{
24 fn type_path() -> &'static str {
25 "cu_ratelimit::CuRateLimit"
26 }
27
28 fn short_type_path() -> &'static str {
29 "CuRateLimit"
30 }
31
32 fn type_ident() -> Option<&'static str> {
33 Some("CuRateLimit")
34 }
35
36 fn crate_name() -> Option<&'static str> {
37 Some("cu_ratelimit")
38 }
39
40 fn module_path() -> Option<&'static str> {
41 Some("")
42 }
43}
44
45impl<T> Freezable for CuRateLimit<T>
46where
47 T: CuMsgPayload,
48{
49 fn freeze<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
50 Encode::encode(&self.last_tov, encoder)
51 }
52
53 fn thaw<D: Decoder>(&mut self, decoder: &mut D) -> Result<(), DecodeError> {
54 self.last_tov = Decode::decode(decoder)?;
55 Ok(())
56 }
57}
58
59impl<T> CuTask for CuRateLimit<T>
60where
61 T: CuMsgPayload,
62{
63 type Resources<'r> = ();
64 type Input<'m> = input_msg!(T);
65 type Output<'m> = output_msg!(T);
66
67 fn new(config: Option<&ComponentConfig>, _resources: Self::Resources<'_>) -> CuResult<Self> {
68 let hz = match config {
69 Some(cfg) => cfg
70 .get::<f64>("rate")?
71 .ok_or("Missing required 'rate' config for CuRateLimiter")?,
72 None => return Err("Missing required 'rate' config for CuRateLimiter".into()),
73 };
74 let interval_ns = (1e9 / hz) as u64;
75 Ok(Self {
76 _marker: PhantomData,
77 interval: CuDuration::from(interval_ns),
78 last_tov: None,
79 })
80 }
81
82 fn process<'m>(
83 &mut self,
84 _clock: &RobotClock,
85 input: &Self::Input<'m>,
86 output: &mut Self::Output<'m>,
87 ) -> CuResult<()> {
88 let tov = match input.tov {
89 Tov::Time(ts) => ts,
90 _ => return Err("Expected single timestamp TOV".into()),
91 };
92
93 let allow = match self.last_tov {
94 None => true,
95 Some(last) => (tov - last) >= self.interval,
96 };
97
98 if allow {
99 self.last_tov = Some(tov);
100 if let Some(payload) = input.payload() {
101 output.set_payload(payload.clone());
102 } else {
103 output.clear_payload();
104 }
105 } else {
106 output.clear_payload();
107 }
108
109 Ok(())
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `i32` limiter whose `rate` config entry is `rate` (Hz).
    fn limiter_with_rate(rate: f64) -> CuRateLimit<i32> {
        let mut config = ComponentConfig::new();
        config.set("rate", rate);
        CuRateLimit::new(Some(&config), ()).unwrap()
    }

    /// Stamps `input` with a `nanos` time of validity and runs one `process` cycle.
    fn step(
        limiter: &mut CuRateLimit<i32>,
        clock: &RobotClock,
        input: &mut CuMsg<i32>,
        output: &mut CuMsg<i32>,
        nanos: u64,
    ) {
        input.tov = Tov::Time(CuTime::from(nanos));
        limiter.process(clock, &*input, &mut *output).unwrap();
    }

    #[test]
    fn test_rate_limiting() {
        let (clock, _) = RobotClock::mock();
        // 10 Hz => at least 100 ms between forwarded messages.
        let mut limiter = limiter_with_rate(10.0);
        let mut input = CuMsg::<i32>::new(Some(42));
        let mut output = CuMsg::<i32>::new(None);

        // (time of validity in ns, expected output payload)
        let schedule = [(0, Some(&42)), (50_000_000, None), (100_000_000, Some(&42))];
        for (nanos, expected) in schedule {
            step(&mut limiter, &clock, &mut input, &mut output, nanos);
            assert_eq!(output.payload(), expected, "at t = {nanos} ns");
        }
    }

    #[test]
    fn test_payload_propagation() {
        let (clock, _) = RobotClock::mock();
        let mut limiter = limiter_with_rate(10.0);
        let mut input = CuMsg::<i32>::new(None);
        let mut output = CuMsg::<i32>::new(None);

        // The very first message is always let through, carrying its payload.
        input.set_payload(123);
        step(&mut limiter, &clock, &mut input, &mut output, 0);
        assert_eq!(output.payload(), Some(&123));

        // One full interval later, an empty input yields an empty output.
        input.clear_payload();
        step(&mut limiter, &clock, &mut input, &mut output, 100_000_000);
        assert_eq!(output.payload(), None);
    }
}