1use std::{collections::BTreeMap, time::SystemTime};
2
3use crate::pressure::{
4 ChannelPolicyConfig, ChannelPressureSnapshot, ConsumerPressureMetrics, PolicyAction,
5 PressureMonitor,
6};
7
/// Signal that a channel's consumer count should be adjusted.
///
/// It is emitted alongside the matching `PolicyEvent::Action` event whenever a
/// `PolicyAction::ScaleConsumer` policy fires for the channel.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ScaleSignal {
    /// Channel whose pressure triggered the signal.
    pub channel_id: String,
    /// Number of consumers in the snapshot that triggered the signal.
    pub current_consumer_count: usize,
}
16
17#[derive(Clone, Debug, PartialEq)]
19pub enum PolicyEvent {
20 Action {
22 channel_id: String,
24 action: PolicyAction,
26 triggered_at: SystemTime,
28 },
29 ScaleConsumer {
31 signal: ScaleSignal,
33 triggered_at: SystemTime,
35 },
36}
37
38#[derive(Clone, Debug, PartialEq)]
40pub struct EnforcementOutcome {
41 pub snapshot: ChannelPressureSnapshot,
43 pub actions: Vec<PolicyAction>,
45 pub events: Vec<PolicyEvent>,
47}
48
49#[derive(Debug, Default)]
51pub struct PressureEnforcer {
52 monitor: PressureMonitor,
53 policies: BTreeMap<String, ChannelPolicyConfig>,
54 last_events: Vec<PolicyEvent>,
55 last_snapshot: Option<ChannelPressureSnapshot>,
56}
57
58impl PressureEnforcer {
59 #[must_use]
61 pub fn new() -> Self {
62 Self {
63 monitor: PressureMonitor::new(),
64 policies: BTreeMap::new(),
65 last_events: Vec::new(),
66 last_snapshot: None,
67 }
68 }
69
70 #[must_use]
72 pub const fn with_monitor(monitor: PressureMonitor) -> Self {
73 Self {
74 monitor,
75 policies: BTreeMap::new(),
76 last_events: Vec::new(),
77 last_snapshot: None,
78 }
79 }
80
81 #[must_use]
83 pub const fn monitor(&self) -> &PressureMonitor {
84 &self.monitor
85 }
86
87 #[must_use]
89 pub fn last_events(&self) -> &[PolicyEvent] {
90 &self.last_events
91 }
92
93 #[must_use]
95 pub const fn last_snapshot(&self) -> Option<&ChannelPressureSnapshot> {
96 self.last_snapshot.as_ref()
97 }
98
99 pub fn set_channel_policy(
101 &mut self,
102 channel_id: impl Into<String>,
103 config: ChannelPolicyConfig,
104 ) {
105 self.policies.insert(channel_id.into(), config);
106 }
107
108 #[must_use]
110 pub fn channel_policy(&self, channel_id: &str) -> Option<&ChannelPolicyConfig> {
111 self.policies.get(channel_id)
112 }
113
114 pub fn record_consumer_metrics(
116 &mut self,
117 channel_id: impl Into<String>,
118 consumer_id: impl Into<String>,
119 metrics: ConsumerPressureMetrics,
120 ) -> Vec<PolicyAction> {
121 self.record_consumer_metrics_outcome(channel_id, consumer_id, metrics)
122 .actions
123 }
124
125 pub fn record_consumer_metrics_outcome(
131 &mut self,
132 channel_id: impl Into<String>,
133 consumer_id: impl Into<String>,
134 metrics: ConsumerPressureMetrics,
135 ) -> EnforcementOutcome {
136 let channel_id = channel_id.into();
137 let snapshot =
138 self.monitor
139 .record_consumer_metrics(channel_id.clone(), consumer_id, metrics);
140 self.enforce_snapshot(&channel_id, snapshot)
141 }
142
143 pub fn record_accept(
145 &mut self,
146 channel_id: impl Into<String>,
147 consumer_id: impl Into<String>,
148 ) -> Vec<PolicyAction> {
149 self.record_accept_outcome(channel_id, consumer_id).actions
150 }
151
152 pub fn record_accept_outcome(
156 &mut self,
157 channel_id: impl Into<String>,
158 consumer_id: impl Into<String>,
159 ) -> EnforcementOutcome {
160 let channel_id = channel_id.into();
161 let snapshot = self.monitor.record_accept(channel_id.clone(), consumer_id);
162 self.enforce_snapshot(&channel_id, snapshot)
163 }
164
165 pub fn record_defer(
167 &mut self,
168 channel_id: impl Into<String>,
169 consumer_id: impl Into<String>,
170 ) -> Vec<PolicyAction> {
171 self.record_defer_outcome(channel_id, consumer_id).actions
172 }
173
174 pub fn record_defer_outcome(
178 &mut self,
179 channel_id: impl Into<String>,
180 consumer_id: impl Into<String>,
181 ) -> EnforcementOutcome {
182 let channel_id = channel_id.into();
183 let snapshot = self.monitor.record_defer(channel_id.clone(), consumer_id);
184 self.enforce_snapshot(&channel_id, snapshot)
185 }
186
187 pub fn record_reject(
189 &mut self,
190 channel_id: impl Into<String>,
191 consumer_id: impl Into<String>,
192 ) -> Vec<PolicyAction> {
193 self.record_reject_outcome(channel_id, consumer_id).actions
194 }
195
196 pub fn record_reject_outcome(
200 &mut self,
201 channel_id: impl Into<String>,
202 consumer_id: impl Into<String>,
203 ) -> EnforcementOutcome {
204 let channel_id = channel_id.into();
205 let snapshot = self.monitor.record_reject(channel_id.clone(), consumer_id);
206 self.enforce_snapshot(&channel_id, snapshot)
207 }
208
209 fn enforce_snapshot(
210 &mut self,
211 channel_id: &str,
212 snapshot: ChannelPressureSnapshot,
213 ) -> EnforcementOutcome {
214 let actions = self
215 .policies
216 .get(channel_id)
217 .map_or_else(Vec::new, |config| {
218 config.actions_for_pressure(snapshot.pressure_score)
219 });
220 let triggered_at = SystemTime::now();
221 let events = Self::events_for_actions(
222 channel_id,
223 snapshot.consumer_count(),
224 &actions,
225 triggered_at,
226 );
227 self.last_events.clone_from(&events);
228 self.last_snapshot = Some(snapshot.clone());
229 EnforcementOutcome {
230 snapshot,
231 actions,
232 events,
233 }
234 }
235
236 fn events_for_actions(
237 channel_id: &str,
238 current_consumer_count: usize,
239 actions: &[PolicyAction],
240 triggered_at: SystemTime,
241 ) -> Vec<PolicyEvent> {
242 let mut events = Vec::with_capacity(actions.len());
243 for action in actions {
244 events.push(PolicyEvent::Action {
245 channel_id: channel_id.to_owned(),
246 action: action.clone(),
247 triggered_at,
248 });
249 if matches!(action, PolicyAction::ScaleConsumer) {
250 events.push(PolicyEvent::ScaleConsumer {
251 signal: ScaleSignal {
252 channel_id: channel_id.to_owned(),
253 current_consumer_count,
254 },
255 triggered_at,
256 });
257 }
258 }
259 events
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::{PolicyEvent, PressureEnforcer, ScaleSignal};
    use crate::pressure::{
        ChannelPolicyConfig, ConsumerPressureMetrics, PolicyAction, PressureMonitor, PressurePolicy,
    };

    /// A policy whose action is [`slow_action`], configured with the given `threshold`.
    fn slow_policy(threshold: f64) -> PressurePolicy {
        PressurePolicy {
            threshold,
            action: slow_action(),
        }
    }

    /// The action configured by [`slow_policy`], used for exact comparisons.
    fn slow_action() -> PolicyAction {
        PolicyAction::SlowProducer {
            reduction_factor: 0.5,
        }
    }

    /// An enforcer over `monitor` with `slow_policy(0.7)` registered for `orders`.
    fn orders_enforcer(monitor: PressureMonitor) -> PressureEnforcer {
        let mut enforcer = PressureEnforcer::with_monitor(monitor);
        enforcer.set_channel_policy("orders", ChannelPolicyConfig::new(vec![slow_policy(0.7)]));
        enforcer
    }

    #[test]
    fn policy_enforcement_emits_slow_producer_when_threshold_is_crossed() {
        let mut enforcer = orders_enforcer(PressureMonitor::new());

        let actions = enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(7, 10, 0),
        );

        assert_eq!(actions, vec![slow_action()]);
        assert!(matches!(
            enforcer.last_events(),
            [PolicyEvent::Action {
                channel_id,
                action: PolicyAction::SlowProducer { reduction_factor },
                triggered_at: _
            }] if channel_id == "orders" && (*reduction_factor - 0.5).abs() < f64::EPSILON
        ));
    }

    #[test]
    fn enforcement_runs_as_part_of_each_monitor_update() {
        let mut enforcer = orders_enforcer(PressureMonitor::new());

        let below = enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(6, 10, 0),
        );
        let above = enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(8, 10, 0),
        );

        assert!(below.is_empty());
        assert_eq!(above, vec![slow_action()]);
        assert!(enforcer.last_snapshot().is_some());
    }

    #[test]
    fn enforcement_returns_no_actions_below_all_thresholds() {
        let mut enforcer = orders_enforcer(PressureMonitor::new());

        let actions = enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(3, 10, 0),
        );

        assert!(actions.is_empty());
        assert!(enforcer.last_events().is_empty());
    }

    #[test]
    fn scale_consumer_policy_emits_scale_signal_with_channel_and_consumer_count() {
        let mut enforcer = PressureEnforcer::new();
        enforcer.set_channel_policy(
            "orders",
            ChannelPolicyConfig::new(vec![PressurePolicy {
                threshold: 0.7,
                action: PolicyAction::ScaleConsumer,
            }]),
        );
        enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(6, 10, 0),
        );

        let actions = enforcer.record_consumer_metrics(
            "orders",
            "consumer-b",
            ConsumerPressureMetrics::new(8, 10, 0),
        );

        assert_eq!(actions, vec![PolicyAction::ScaleConsumer]);
        // The Action event comes first, followed by its scale signal; both share one timestamp.
        assert!(matches!(
            enforcer.last_events(),
            [
                PolicyEvent::Action {
                    channel_id: action_channel,
                    action: PolicyAction::ScaleConsumer,
                    triggered_at: action_time,
                },
                PolicyEvent::ScaleConsumer {
                    signal: ScaleSignal {
                        channel_id: signal_channel,
                        current_consumer_count: 2,
                    },
                    triggered_at: signal_time,
                },
            ] if action_channel == "orders"
                && signal_channel == "orders"
                && action_time == signal_time
        ));
    }

    #[test]
    fn custom_monitor_scores_are_enforced_without_manual_evaluate_call() {
        let mut enforcer = orders_enforcer(PressureMonitor::with_scoring(|_| 1.0));

        let actions = enforcer.record_accept("orders", "consumer-a");

        assert_eq!(actions, vec![slow_action()]);
    }

    #[test]
    fn every_monitor_entry_point_runs_enforcement() {
        let mut enforcer = orders_enforcer(PressureMonitor::with_scoring(|_| 1.0));

        assert_eq!(enforcer.record_accept("orders", "consumer-a"), vec![slow_action()]);
        assert_eq!(enforcer.record_defer("orders", "consumer-a"), vec![slow_action()]);
        assert_eq!(enforcer.record_reject("orders", "consumer-a"), vec![slow_action()]);
    }

    #[test]
    fn channel_without_policy_triggers_nothing_but_updates_snapshot() {
        let mut enforcer = orders_enforcer(PressureMonitor::with_scoring(|_| 1.0));

        let outcome = enforcer.record_accept_outcome("payments", "consumer-a");

        assert!(outcome.actions.is_empty());
        assert!(outcome.events.is_empty());
        assert!(enforcer.last_events().is_empty());
        assert_eq!(enforcer.last_snapshot(), Some(&outcome.snapshot));
    }
}