Skip to main content

liminal/pressure/
enforce.rs

1use std::{collections::BTreeMap, time::SystemTime};
2
3use crate::pressure::{
4    ChannelPolicyConfig, ChannelPressureSnapshot, ConsumerPressureMetrics, PolicyAction,
5    PressureMonitor,
6};
7
/// Request for an external orchestrator to add consumers to a channel that is under pressure.
///
/// The enforcer builds one of these whenever a channel policy selects the scale-consumer action.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ScaleSignal {
    /// Identifier of the channel whose pressure produced this signal.
    pub channel_id: String,
    /// Consumer count reported by the pressure snapshot that produced this signal.
    pub current_consumer_count: usize,
}
16
17/// Typed pressure policy event emitted by the enforcer.
18#[derive(Clone, Debug, PartialEq)]
19pub enum PolicyEvent {
20    /// A policy action became active for a channel pressure update.
21    Action {
22        /// Channel whose policy action triggered.
23        channel_id: String,
24        /// Action selected by the channel policy configuration.
25        action: PolicyAction,
26        /// Wall-clock time when the enforcer emitted this event.
27        triggered_at: SystemTime,
28    },
29    /// A scale-consumer policy emitted a scaling signal for external orchestration.
30    ScaleConsumer {
31        /// Signal payload for the external orchestrator.
32        signal: ScaleSignal,
33        /// Wall-clock time when the enforcer emitted this event.
34        triggered_at: SystemTime,
35    },
36}
37
38/// Result returned from a monitor update after policy enforcement runs.
39#[derive(Clone, Debug, PartialEq)]
40pub struct EnforcementOutcome {
41    /// Snapshot of the pressure state after the monitor update.
42    pub snapshot: ChannelPressureSnapshot,
43    /// Triggered policy actions for the caller to apply synchronously.
44    pub actions: Vec<PolicyAction>,
45    /// Typed events consumable by routing, dispatch, or observability subsystems.
46    pub events: Vec<PolicyEvent>,
47}
48
49/// Synchronously updates pressure metrics and enforces channel policies.
50#[derive(Debug, Default)]
51pub struct PressureEnforcer {
52    monitor: PressureMonitor,
53    policies: BTreeMap<String, ChannelPolicyConfig>,
54    last_events: Vec<PolicyEvent>,
55    last_snapshot: Option<ChannelPressureSnapshot>,
56}
57
58impl PressureEnforcer {
59    /// Creates an enforcer with an empty monitor and no configured channel policies.
60    #[must_use]
61    pub fn new() -> Self {
62        Self {
63            monitor: PressureMonitor::new(),
64            policies: BTreeMap::new(),
65            last_events: Vec::new(),
66            last_snapshot: None,
67        }
68    }
69
70    /// Creates an enforcer around an existing monitor.
71    #[must_use]
72    pub const fn with_monitor(monitor: PressureMonitor) -> Self {
73        Self {
74            monitor,
75            policies: BTreeMap::new(),
76            last_events: Vec::new(),
77            last_snapshot: None,
78        }
79    }
80
81    /// Returns the monitor used by the enforcer.
82    #[must_use]
83    pub const fn monitor(&self) -> &PressureMonitor {
84        &self.monitor
85    }
86
87    /// Returns the latest typed policy events emitted by an automatic update.
88    #[must_use]
89    pub fn last_events(&self) -> &[PolicyEvent] {
90        &self.last_events
91    }
92
93    /// Returns the latest pressure snapshot observed by an automatic update.
94    #[must_use]
95    pub const fn last_snapshot(&self) -> Option<&ChannelPressureSnapshot> {
96        self.last_snapshot.as_ref()
97    }
98
99    /// Registers or replaces the pressure policy configuration for a channel.
100    pub fn set_channel_policy(
101        &mut self,
102        channel_id: impl Into<String>,
103        config: ChannelPolicyConfig,
104    ) {
105        self.policies.insert(channel_id.into(), config);
106    }
107
108    /// Returns the pressure policy configuration for a channel, if one is registered.
109    #[must_use]
110    pub fn channel_policy(&self, channel_id: &str) -> Option<&ChannelPolicyConfig> {
111        self.policies.get(channel_id)
112    }
113
114    /// Records consumer metrics and immediately evaluates the channel policy.
115    pub fn record_consumer_metrics(
116        &mut self,
117        channel_id: impl Into<String>,
118        consumer_id: impl Into<String>,
119        metrics: ConsumerPressureMetrics,
120    ) -> Vec<PolicyAction> {
121        self.record_consumer_metrics_outcome(channel_id, consumer_id, metrics)
122            .actions
123    }
124
125    /// Records consumer metrics and returns the full enforcement outcome.
126    ///
127    /// The primary update path returns the action vector directly; this helper
128    /// exposes the same automatic enforcement result with its snapshot and
129    /// typed events for observers that need more context.
130    pub fn record_consumer_metrics_outcome(
131        &mut self,
132        channel_id: impl Into<String>,
133        consumer_id: impl Into<String>,
134        metrics: ConsumerPressureMetrics,
135    ) -> EnforcementOutcome {
136        let channel_id = channel_id.into();
137        let snapshot =
138            self.monitor
139                .record_consumer_metrics(channel_id.clone(), consumer_id, metrics);
140        self.enforce_snapshot(&channel_id, snapshot)
141    }
142
143    /// Records an accept decision and immediately evaluates the channel policy.
144    pub fn record_accept(
145        &mut self,
146        channel_id: impl Into<String>,
147        consumer_id: impl Into<String>,
148    ) -> Vec<PolicyAction> {
149        self.record_accept_outcome(channel_id, consumer_id).actions
150    }
151
152    /// Records an accept decision and returns the full enforcement outcome.
153    ///
154    /// Enforcement still runs synchronously as part of this monitor update.
155    pub fn record_accept_outcome(
156        &mut self,
157        channel_id: impl Into<String>,
158        consumer_id: impl Into<String>,
159    ) -> EnforcementOutcome {
160        let channel_id = channel_id.into();
161        let snapshot = self.monitor.record_accept(channel_id.clone(), consumer_id);
162        self.enforce_snapshot(&channel_id, snapshot)
163    }
164
165    /// Records a defer decision and immediately evaluates the channel policy.
166    pub fn record_defer(
167        &mut self,
168        channel_id: impl Into<String>,
169        consumer_id: impl Into<String>,
170    ) -> Vec<PolicyAction> {
171        self.record_defer_outcome(channel_id, consumer_id).actions
172    }
173
174    /// Records a defer decision and returns the full enforcement outcome.
175    ///
176    /// Enforcement still runs synchronously as part of this monitor update.
177    pub fn record_defer_outcome(
178        &mut self,
179        channel_id: impl Into<String>,
180        consumer_id: impl Into<String>,
181    ) -> EnforcementOutcome {
182        let channel_id = channel_id.into();
183        let snapshot = self.monitor.record_defer(channel_id.clone(), consumer_id);
184        self.enforce_snapshot(&channel_id, snapshot)
185    }
186
187    /// Records a reject decision and immediately evaluates the channel policy.
188    pub fn record_reject(
189        &mut self,
190        channel_id: impl Into<String>,
191        consumer_id: impl Into<String>,
192    ) -> Vec<PolicyAction> {
193        self.record_reject_outcome(channel_id, consumer_id).actions
194    }
195
196    /// Records a reject decision and returns the full enforcement outcome.
197    ///
198    /// Enforcement still runs synchronously as part of this monitor update.
199    pub fn record_reject_outcome(
200        &mut self,
201        channel_id: impl Into<String>,
202        consumer_id: impl Into<String>,
203    ) -> EnforcementOutcome {
204        let channel_id = channel_id.into();
205        let snapshot = self.monitor.record_reject(channel_id.clone(), consumer_id);
206        self.enforce_snapshot(&channel_id, snapshot)
207    }
208
209    fn enforce_snapshot(
210        &mut self,
211        channel_id: &str,
212        snapshot: ChannelPressureSnapshot,
213    ) -> EnforcementOutcome {
214        let actions = self
215            .policies
216            .get(channel_id)
217            .map_or_else(Vec::new, |config| {
218                config.actions_for_pressure(snapshot.pressure_score)
219            });
220        let triggered_at = SystemTime::now();
221        let events = Self::events_for_actions(
222            channel_id,
223            snapshot.consumer_count(),
224            &actions,
225            triggered_at,
226        );
227        self.last_events.clone_from(&events);
228        self.last_snapshot = Some(snapshot.clone());
229        EnforcementOutcome {
230            snapshot,
231            actions,
232            events,
233        }
234    }
235
236    fn events_for_actions(
237        channel_id: &str,
238        current_consumer_count: usize,
239        actions: &[PolicyAction],
240        triggered_at: SystemTime,
241    ) -> Vec<PolicyEvent> {
242        let mut events = Vec::with_capacity(actions.len());
243        for action in actions {
244            events.push(PolicyEvent::Action {
245                channel_id: channel_id.to_owned(),
246                action: action.clone(),
247                triggered_at,
248            });
249            if matches!(action, PolicyAction::ScaleConsumer) {
250                events.push(PolicyEvent::ScaleConsumer {
251                    signal: ScaleSignal {
252                        channel_id: channel_id.to_owned(),
253                        current_consumer_count,
254                    },
255                    triggered_at,
256                });
257            }
258        }
259        events
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::{PolicyEvent, PressureEnforcer, ScaleSignal};
    use crate::pressure::{
        ChannelPolicyConfig, ConsumerPressureMetrics, PolicyAction, PressureMonitor, PressurePolicy,
    };

    /// Builds a `SlowProducer` policy (reduction factor 0.5) gated on `threshold`.
    fn slow_policy(threshold: f64) -> PressurePolicy {
        PressurePolicy {
            threshold,
            action: PolicyAction::SlowProducer {
                reduction_factor: 0.5,
            },
        }
    }

    #[test]
    fn policy_enforcement_emits_slow_producer_when_threshold_is_crossed() {
        let mut enforcer = PressureEnforcer::new();
        enforcer.set_channel_policy("orders", ChannelPolicyConfig::new(vec![slow_policy(0.7)]));

        let actions = enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(7, 10, 0),
        );

        assert_eq!(
            actions,
            vec![PolicyAction::SlowProducer {
                reduction_factor: 0.5
            }]
        );
        assert!(matches!(
            enforcer.last_events(),
            [PolicyEvent::Action {
                channel_id,
                action: PolicyAction::SlowProducer { reduction_factor },
                triggered_at: _
            }] if channel_id == "orders" && (*reduction_factor - 0.5).abs() < f64::EPSILON
        ));
    }

    #[test]
    fn enforcement_runs_as_part_of_each_monitor_update() {
        let mut enforcer = PressureEnforcer::new();
        enforcer.set_channel_policy("orders", ChannelPolicyConfig::new(vec![slow_policy(0.7)]));

        let below = enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(6, 10, 0),
        );
        let above = enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(8, 10, 0),
        );

        assert!(below.is_empty());
        assert_eq!(above.len(), 1);
    }

    #[test]
    fn enforcement_returns_no_actions_below_all_thresholds() {
        let mut enforcer = PressureEnforcer::new();
        enforcer.set_channel_policy("orders", ChannelPolicyConfig::new(vec![slow_policy(0.7)]));

        let actions = enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(3, 10, 0),
        );

        assert!(actions.is_empty());
        assert!(enforcer.last_events().is_empty());
    }

    #[test]
    fn scale_consumer_policy_emits_scale_signal_with_channel_and_consumer_count() {
        let mut enforcer = PressureEnforcer::new();
        enforcer.set_channel_policy(
            "orders",
            ChannelPolicyConfig::new(vec![PressurePolicy {
                threshold: 0.7,
                action: PolicyAction::ScaleConsumer,
            }]),
        );
        enforcer.record_consumer_metrics(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(6, 10, 0),
        );

        let actions = enforcer.record_consumer_metrics(
            "orders",
            "consumer-b",
            ConsumerPressureMetrics::new(8, 10, 0),
        );

        assert_eq!(actions, vec![PolicyAction::ScaleConsumer]);
        // A scale-consumer action yields both the generic Action event and the scale signal.
        assert_eq!(enforcer.last_events().len(), 2);
        assert!(enforcer.last_events().iter().any(|event| matches!(
            event,
            PolicyEvent::ScaleConsumer {
                signal: ScaleSignal {
                    channel_id,
                    current_consumer_count: 2,
                },
                triggered_at: _
            } if channel_id == "orders"
        )));
    }

    #[test]
    fn custom_monitor_scores_are_enforced_without_manual_evaluate_call() {
        let monitor = PressureMonitor::with_scoring(|_| 1.0);
        let mut enforcer = PressureEnforcer::with_monitor(monitor);
        enforcer.set_channel_policy("orders", ChannelPolicyConfig::new(vec![slow_policy(0.7)]));

        let actions = enforcer.record_accept("orders", "consumer-a");

        assert_eq!(actions.len(), 1);
    }

    #[test]
    fn channels_without_a_policy_produce_no_actions_or_events() {
        let monitor = PressureMonitor::with_scoring(|_| 1.0);
        let mut enforcer = PressureEnforcer::with_monitor(monitor);

        let outcome = enforcer.record_accept_outcome("orders", "consumer-a");

        assert!(outcome.actions.is_empty());
        assert!(outcome.events.is_empty());
        assert!(enforcer.last_events().is_empty());
    }

    #[test]
    fn outcome_matches_the_recorded_last_state() {
        let mut enforcer = PressureEnforcer::new();
        enforcer.set_channel_policy("orders", ChannelPolicyConfig::new(vec![slow_policy(0.7)]));
        assert!(enforcer.last_snapshot().is_none());

        let outcome = enforcer.record_consumer_metrics_outcome(
            "orders",
            "consumer-a",
            ConsumerPressureMetrics::new(8, 10, 0),
        );

        assert_eq!(outcome.actions.len(), 1);
        assert_eq!(enforcer.last_snapshot(), Some(&outcome.snapshot));
        assert_eq!(enforcer.last_events(), outcome.events.as_slice());
    }

    #[test]
    fn last_events_reflect_the_most_recent_update_across_channels() {
        let monitor = PressureMonitor::with_scoring(|_| 1.0);
        let mut enforcer = PressureEnforcer::with_monitor(monitor);
        enforcer.set_channel_policy("orders", ChannelPolicyConfig::new(vec![slow_policy(0.7)]));

        let triggered = enforcer.record_accept("orders", "consumer-a");
        assert_eq!(triggered.len(), 1);
        assert_eq!(enforcer.last_events().len(), 1);

        // "billing" has no policy, so this update must clear the events left by "orders".
        let quiet = enforcer.record_accept("billing", "consumer-b");
        assert!(quiet.is_empty());
        assert!(enforcer.last_events().is_empty());
    }
}