1use serde::{Deserialize, Serialize};
9
10use crate::abi::InstanceId;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
15#[non_exhaustive]
16pub enum QuotaReductionPolicy {
17 #[default]
19 Reject,
20 GrandfatherExisting,
23 ThrottleProportional,
27}
28
/// Error returned by `apply_quota_reduction` when a policy refuses a quota
/// reduction.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum QuotaReductionError {
    /// The children's combined usage is larger than the requested quota
    /// (raised under the `Reject` policy).
    WouldViolateChildren {
        /// Combined usage of all children at the time of the request.
        current_usage: u64,
        /// The quota the parent was being reduced to.
        new_quota: u64,
    },
}
44
45impl core::fmt::Display for QuotaReductionError {
46 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
47 match self {
48 Self::WouldViolateChildren {
49 current_usage,
50 new_quota,
51 } => write!(
52 f,
53 "quota reduction would violate children: current_usage={}, new_quota={}",
54 current_usage, new_quota
55 ),
56 }
57 }
58}
59
60impl std::error::Error for QuotaReductionError {}
61
62#[must_use = "policy result determines whether the reduction can proceed"]
83pub fn apply_quota_reduction(
84 policy: QuotaReductionPolicy,
85 new_quota: u64,
86 children: &[(InstanceId, u64)],
87) -> Result<Vec<(InstanceId, u64)>, QuotaReductionError> {
88 debug_assert!(
89 children.windows(2).all(|w| w[0].0 <= w[1].0),
90 "apply_quota_reduction: children must be sorted by InstanceId ascending"
91 );
92
93 let current_total: u64 = children
94 .iter()
95 .map(|(_, u)| *u)
96 .fold(0u64, u64::saturating_add);
97
98 match policy {
99 QuotaReductionPolicy::Reject => {
100 if current_total > new_quota {
101 Err(QuotaReductionError::WouldViolateChildren {
102 current_usage: current_total,
103 new_quota,
104 })
105 } else {
106 Ok(children.to_vec())
107 }
108 }
109 QuotaReductionPolicy::GrandfatherExisting => Ok(children.to_vec()),
110 QuotaReductionPolicy::ThrottleProportional => {
111 if children.is_empty() || current_total == 0 {
112 return Ok(children.to_vec());
113 }
114 let total = current_total as u128;
115 let target = new_quota as u128;
116 let mut out: Vec<(InstanceId, u64)> = Vec::with_capacity(children.len());
117 let mut allocated: u128 = 0;
118 for (id, current) in children {
119 let scaled = (*current as u128).saturating_mul(target) / total;
120 let scaled_u64 = scaled.min(u64::MAX as u128) as u64;
121 allocated = allocated.saturating_add(scaled);
122 out.push((*id, scaled_u64));
123 }
124 let mut remainder = target.saturating_sub(allocated);
125 for entry in out.iter_mut() {
126 if remainder == 0 {
127 break;
128 }
129 entry.1 = entry.1.saturating_add(1);
130 remainder -= 1;
131 }
132 Ok(out)
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test `InstanceId`; every id used below is expected to be valid.
    fn id(n: u64) -> InstanceId {
        InstanceId::new(n).unwrap()
    }

    /// Adds up the per-child values of a policy result.
    fn total_of(entries: &[(InstanceId, u64)]) -> u64 {
        entries.iter().fold(0, |acc, &(_, q)| acc + q)
    }

    #[test]
    fn policy_default_is_reject() {
        let default_policy: QuotaReductionPolicy = Default::default();
        assert_eq!(default_policy, QuotaReductionPolicy::Reject);
    }

    #[test]
    fn policy_three_distinct_variants() {
        let variants = [
            QuotaReductionPolicy::Reject,
            QuotaReductionPolicy::GrandfatherExisting,
            QuotaReductionPolicy::ThrottleProportional,
        ];
        // Compare every unordered pair exactly once.
        for (i, lhs) in variants.iter().enumerate() {
            for rhs in &variants[i + 1..] {
                assert_ne!(lhs, rhs);
            }
        }
    }

    #[test]
    fn error_display_includes_numbers() {
        let rendered = QuotaReductionError::WouldViolateChildren {
            current_usage: 100,
            new_quota: 50,
        }
        .to_string();
        for needle in ["current_usage=100", "new_quota=50"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn error_implements_std_error() {
        fn takes_error<E: std::error::Error>() {}
        takes_error::<QuotaReductionError>();
    }

    #[test]
    fn reject_allows_under_quota() {
        let children = vec![(id(1), 25), (id(2), 25)];
        let kept = apply_quota_reduction(QuotaReductionPolicy::Reject, 100, &children).unwrap();
        assert_eq!(kept, children);
    }

    #[test]
    fn reject_denies_over_quota() {
        let children = vec![(id(1), 100), (id(2), 100)];
        let expected = QuotaReductionError::WouldViolateChildren {
            current_usage: 200,
            new_quota: 100,
        };
        let outcome = apply_quota_reduction(QuotaReductionPolicy::Reject, 100, &children);
        assert_eq!(outcome.unwrap_err(), expected);
    }

    #[test]
    fn reject_at_exact_limit_allows() {
        let children = vec![(id(1), 50), (id(2), 50)];
        let kept = apply_quota_reduction(QuotaReductionPolicy::Reject, 100, &children).unwrap();
        assert_eq!(kept, children);
    }

    #[test]
    fn grandfather_returns_children_unchanged() {
        let children = vec![(id(1), 100), (id(2), 100)];
        let kept =
            apply_quota_reduction(QuotaReductionPolicy::GrandfatherExisting, 100, &children)
                .unwrap();
        assert_eq!(kept, children);
    }

    #[test]
    fn throttle_basic() {
        // Going from a combined 600 down to 300 halves every child.
        let children = vec![(id(1), 100), (id(2), 200), (id(3), 300)];
        let expected = vec![(id(1), 50), (id(2), 100), (id(3), 150)];
        let throttled =
            apply_quota_reduction(QuotaReductionPolicy::ThrottleProportional, 300, &children)
                .unwrap();
        assert_eq!(throttled, expected);
    }

    #[test]
    fn throttle_remainder_distributes_ascending() {
        // 100 / 3 floors to 33 each; the single leftover unit goes to the
        // lowest id.
        let children = vec![(id(1), 100), (id(2), 100), (id(3), 100)];
        let throttled =
            apply_quota_reduction(QuotaReductionPolicy::ThrottleProportional, 100, &children)
                .unwrap();
        assert_eq!(throttled, vec![(id(1), 34), (id(2), 33), (id(3), 33)]);
        assert_eq!(total_of(&throttled), 100);
    }

    #[test]
    fn throttle_zero_total_idempotent() {
        let children = vec![(id(1), 0), (id(2), 0)];
        let throttled =
            apply_quota_reduction(QuotaReductionPolicy::ThrottleProportional, 100, &children)
                .unwrap();
        assert_eq!(throttled, children);
    }

    #[test]
    fn throttle_deterministic() {
        let children = vec![(id(1), 100), (id(2), 100), (id(3), 100)];
        let run = || {
            apply_quota_reduction(QuotaReductionPolicy::ThrottleProportional, 100, &children)
                .unwrap()
        };
        assert_eq!(run(), run());
    }

    #[test]
    fn apply_to_empty_children_returns_empty() {
        let none: Vec<(InstanceId, u64)> = Vec::new();
        let policies = [
            QuotaReductionPolicy::Reject,
            QuotaReductionPolicy::GrandfatherExisting,
            QuotaReductionPolicy::ThrottleProportional,
        ];
        for policy in policies {
            let out = apply_quota_reduction(policy, 100, &none).unwrap();
            assert!(out.is_empty(), "policy {:?} should pass empty through", policy);
        }
    }
}