1use super::{MessageId, ProtocolError};
2
3#[derive(Clone, Debug, PartialEq, Eq)]
5pub struct AcceptPayload {
6 pub referenced_message_id: MessageId,
8}
9
10#[derive(Clone, Debug, PartialEq, Eq)]
12pub struct DeferPayload {
13 pub referenced_message_id: MessageId,
15 pub reason: Option<String>,
17}
18
19#[derive(Clone, Debug, PartialEq, Eq)]
21pub struct RejectPayload {
22 pub referenced_message_id: MessageId,
24 pub reason: Option<String>,
26}
27
/// Backpressure state reported by [`StreamPressure`].
///
/// Variant order is significant (it fixes the discriminants), so new variants
/// must not be inserted between existing ones without review.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PressureState {
    /// Outstanding messages are strictly below `max_in_flight`.
    Normal,
    /// Outstanding messages have reached `max_in_flight` but do not exceed
    /// `max_in_flight + buffer_capacity`.
    Deferred,
    /// Outstanding messages exceed `max_in_flight + buffer_capacity`.
    Rejecting,
}
38
39#[derive(Clone, Debug, PartialEq, Eq)]
41pub struct StreamPressure {
42 outstanding_count: u32,
44 max_in_flight: u32,
46 state: PressureState,
48}
49
50impl StreamPressure {
51 pub fn new(max_in_flight: u32) -> Result<Self, ProtocolError> {
57 if max_in_flight == 0 {
58 return Err(ProtocolError::codec(
59 "max_in_flight must be greater than zero",
60 ));
61 }
62
63 Ok(Self {
64 outstanding_count: 0,
65 max_in_flight,
66 state: PressureState::Normal,
67 })
68 }
69
70 #[must_use]
72 pub const fn outstanding_count(&self) -> u32 {
73 self.outstanding_count
74 }
75
76 #[must_use]
78 pub const fn max_in_flight(&self) -> u32 {
79 self.max_in_flight
80 }
81
82 #[must_use]
84 pub const fn state(&self) -> PressureState {
85 self.state
86 }
87
88 pub fn record_delivery(
98 &mut self,
99 buffer_capacity: u32,
100 ) -> Result<PressureState, ProtocolError> {
101 let next_outstanding = self
102 .outstanding_count
103 .checked_add(1)
104 .ok_or_else(|| ProtocolError::codec("outstanding message count overflowed"))?;
105 let next_state = Self::state_for(next_outstanding, self.max_in_flight, buffer_capacity)?;
106
107 self.outstanding_count = next_outstanding;
108 self.state = next_state;
109 Ok(next_state)
110 }
111
112 pub fn record_accept(&mut self, buffer_capacity: u32) -> Result<PressureState, ProtocolError> {
122 let next_outstanding = self
123 .outstanding_count
124 .checked_sub(1)
125 .ok_or_else(|| ProtocolError::codec("cannot accept with zero outstanding messages"))?;
126 let next_state = Self::state_for(next_outstanding, self.max_in_flight, buffer_capacity)?;
127
128 self.outstanding_count = next_outstanding;
129 self.state = next_state;
130 Ok(next_state)
131 }
132
133 fn state_for(
134 outstanding_count: u32,
135 max_in_flight: u32,
136 buffer_capacity: u32,
137 ) -> Result<PressureState, ProtocolError> {
138 let reject_threshold = max_in_flight
139 .checked_add(buffer_capacity)
140 .ok_or_else(|| ProtocolError::codec("pressure buffer threshold overflowed"))?;
141
142 Ok(if outstanding_count < max_in_flight {
143 PressureState::Normal
144 } else if outstanding_count > reject_threshold {
145 PressureState::Rejecting
146 } else {
147 PressureState::Deferred
148 })
149 }
150}
151
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use super::{AcceptPayload, DeferPayload, PressureState, RejectPayload, StreamPressure};
    use crate::protocol::{Frame, MessageId, ProtocolError};

    /// Pins the set of `PressureState` variants.
    #[test]
    fn pressure_state_has_exact_required_variants() {
        // Exhaustive on purpose: adding or removing a variant stops this from compiling.
        fn state_name(state: PressureState) -> &'static str {
            match state {
                PressureState::Normal => "normal",
                PressureState::Deferred => "deferred",
                PressureState::Rejecting => "rejecting",
            }
        }

        let expected = [
            (PressureState::Normal, "normal"),
            (PressureState::Deferred, "deferred"),
            (PressureState::Rejecting, "rejecting"),
        ];

        assert_eq!(expected.len(), 3);
        for (state, name) in expected {
            assert_eq!(state_name(state), name);
        }
    }

    /// Compile-time check that every public backpressure type is `Debug`.
    #[test]
    fn public_backpressure_types_implement_debug() {
        fn requires_debug<T>()
        where
            T: Debug,
        {
        }

        requires_debug::<AcceptPayload>();
        requires_debug::<DeferPayload>();
        requires_debug::<RejectPayload>();
        requires_debug::<PressureState>();
        requires_debug::<StreamPressure>();
    }

    /// The payload structs expose the referenced id and the optional reason.
    #[test]
    fn payload_structs_carry_referenced_message_ids_and_reasons() {
        let accept = AcceptPayload {
            referenced_message_id: MessageId::from("message-1"),
        };
        assert_eq!(accept.referenced_message_id.as_str(), "message-1");

        let defer = DeferPayload {
            referenced_message_id: MessageId::from("message-2"),
            reason: Some("buffered".to_owned()),
        };
        assert_eq!(defer.reason.as_deref(), Some("buffered"));

        let reject = RejectPayload {
            referenced_message_id: MessageId::from("message-3"),
            reason: None,
        };
        assert_eq!(reject.reason, None);
    }

    /// A zero `max_in_flight` is refused with a codec error.
    #[test]
    fn stream_pressure_rejects_zero_capacity() {
        let result = StreamPressure::new(0);

        assert!(matches!(result, Err(ProtocolError::CodecError { .. })));
    }

    /// Reaching exactly `max_in_flight` flips the state to `Deferred`.
    #[test]
    fn stream_pressure_transitions_to_deferred_at_max_in_flight() -> Result<(), ProtocolError> {
        let mut pressure = StreamPressure::new(10)?;

        // Deliveries one through nine stay strictly below the limit of ten.
        for _ in 1..10 {
            let state = pressure.record_delivery(0)?;
            assert_eq!(state, PressureState::Normal);
        }

        let at_limit = pressure.record_delivery(0)?;
        assert_eq!(at_limit, PressureState::Deferred);
        assert_eq!(pressure.outstanding_count(), 10);
        assert_eq!(pressure.max_in_flight(), 10);
        assert_eq!(pressure.state(), PressureState::Deferred);
        Ok(())
    }

    /// Going past `max_in_flight + buffer_capacity` flips the state to `Rejecting`.
    #[test]
    fn stream_pressure_transitions_to_rejecting_beyond_buffer() -> Result<(), ProtocolError> {
        let mut pressure = StreamPressure::new(2)?;

        // Expected state after each successive delivery with a buffer of one.
        let expected_states = [
            PressureState::Normal,
            PressureState::Deferred,
            PressureState::Deferred,
            PressureState::Rejecting,
        ];

        for expected in expected_states {
            assert_eq!(pressure.record_delivery(1)?, expected);
        }
        Ok(())
    }

    /// One accept at the limit drops back below it and restores `Normal`.
    #[test]
    fn accept_decrements_outstanding_and_returns_to_normal() -> Result<(), ProtocolError> {
        let mut pressure = StreamPressure::new(10)?;

        for _ in 1..=10 {
            pressure.record_delivery(0)?;
        }

        let after_accept = pressure.record_accept(0)?;
        assert_eq!(after_accept, PressureState::Normal);
        assert_eq!(pressure.outstanding_count(), 9);
        assert_eq!(pressure.state(), PressureState::Normal);
        Ok(())
    }

    /// An accept that still leaves the count beyond the buffer keeps `Rejecting`.
    #[test]
    fn accept_preserves_rejecting_when_still_beyond_buffer() -> Result<(), ProtocolError> {
        let mut pressure = StreamPressure::new(2)?;

        for _ in 1..=5 {
            pressure.record_delivery(1)?;
        }
        assert_eq!(pressure.state(), PressureState::Rejecting);

        let after_accept = pressure.record_accept(1)?;
        assert_eq!(after_accept, PressureState::Rejecting);
        assert_eq!(pressure.outstanding_count(), 4);
        assert_eq!(pressure.state(), PressureState::Rejecting);
        Ok(())
    }

    /// Accepting with nothing outstanding is a codec error.
    #[test]
    fn accept_with_zero_outstanding_returns_protocol_error() -> Result<(), ProtocolError> {
        let mut pressure = StreamPressure::new(10)?;

        let result = pressure.record_accept(0);

        assert!(matches!(result, Err(ProtocolError::CodecError { .. })));
        Ok(())
    }

    /// The `max_in_flight` carried by a `Subscribe` frame can seed a tracker.
    #[test]
    fn subscribe_capacity_can_create_stream_pressure() -> Result<(), ProtocolError> {
        let subscribe = Frame::Subscribe {
            flags: 0,
            stream_id: 1,
            channel: String::from("orders"),
            accepted_schemas: vec![],
            max_in_flight: 100,
        };

        let max_in_flight = match subscribe {
            Frame::Subscribe { max_in_flight, .. } => max_in_flight,
            _ => return Err(ProtocolError::codec("test frame was not Subscribe")),
        };
        let pressure = StreamPressure::new(max_in_flight)?;

        assert_eq!(pressure.max_in_flight(), 100);
        assert_eq!(pressure.outstanding_count(), 0);
        assert_eq!(pressure.state(), PressureState::Normal);
        Ok(())
    }
}