1use crate::constraint::ConstraintChannel;
9use crate::KernelResult;
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ConstraintDefinition {
15 pub id: String,
17 pub name: String,
19 pub description: String,
21 pub domain_tag: String,
23 pub dimensions: Vec<String>,
25 pub rule: MarginRule,
27 #[serde(default)]
29 pub thresholds: ThresholdOverrides,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34#[serde(tag = "type")]
35pub enum MarginRule {
36 #[serde(rename = "budget_ratio")]
39 BudgetRatio {
40 dimension_index: usize,
41 budget: f64,
42 },
43
44 #[serde(rename = "range_bound")]
47 RangeBound {
48 dimension_index: usize,
49 min: f64,
50 max: f64,
51 },
52
53 #[serde(rename = "pattern_match")]
56 PatternMatch {
57 dimension_index: usize,
58 approved_values: Vec<f64>,
60 #[serde(default = "default_tolerance")]
62 tolerance: f64,
63 #[serde(default = "default_fallback_margin")]
65 fallback_margin: f64,
66 },
67
68 #[serde(rename = "weighted_sum")]
71 WeightedSum {
72 weights: Vec<(usize, f64)>,
74 #[serde(default)]
76 offset: f64,
77 },
78
79 #[serde(rename = "constant")]
81 Constant { margin: f64 },
82}
83
/// Serde default for `MarginRule::PatternMatch::tolerance`: an absolute
/// tolerance of `1e-10`.
fn default_tolerance() -> f64 {
    1.0e-10
}

/// Serde default for `MarginRule::PatternMatch::fallback_margin`: unmatched
/// values get zero margin.
fn default_fallback_margin() -> f64 {
    f64::default()
}
91
92#[derive(Debug, Clone, Default, Serialize, Deserialize)]
94pub struct ThresholdOverrides {
95 pub safe_threshold: Option<f64>,
96 pub caution_threshold: Option<f64>,
97 pub block_threshold: Option<f64>,
98}
99
100pub struct DeclarativeChannel {
102 definition: ConstraintDefinition,
103}
104
105impl DeclarativeChannel {
106 pub fn new(definition: ConstraintDefinition) -> Self {
108 Self { definition }
109 }
110
111 pub fn definition(&self) -> &ConstraintDefinition {
113 &self.definition
114 }
115
116 pub fn from_json(json: &str) -> KernelResult<Self> {
118 let definition: ConstraintDefinition =
119 serde_json::from_str(json).map_err(|e| crate::KernelError::DeclarativeError(e.to_string()))?;
120 Ok(Self::new(definition))
121 }
122}
123
124impl ConstraintChannel for DeclarativeChannel {
125 fn name(&self) -> &str {
126 &self.definition.name
127 }
128
129 fn evaluate(&self, state: &[f64]) -> KernelResult<f64> {
130 let margin = evaluate_rule(&self.definition.rule, state)?;
131 Ok(margin.clamp(0.0, 1.0))
132 }
133
134 fn dimension_names(&self) -> Vec<String> {
135 self.definition.dimensions.clone()
136 }
137}
138
139fn evaluate_rule(rule: &MarginRule, state: &[f64]) -> KernelResult<f64> {
141 match rule {
142 MarginRule::BudgetRatio {
143 dimension_index,
144 budget,
145 } => {
146 let value = get_state_value(state, *dimension_index)?;
147 if *budget <= 0.0 {
148 return Ok(0.0);
149 }
150 Ok(1.0 - (value / budget))
152 }
153
154 MarginRule::RangeBound {
155 dimension_index,
156 min,
157 max,
158 } => {
159 let value = get_state_value(state, *dimension_index)?;
160 if max <= min {
161 return Ok(0.0);
162 }
163 let range = max - min;
164 let midpoint = (min + max) / 2.0;
165 let distance_to_edge = (range / 2.0) - (value - midpoint).abs();
166 Ok(distance_to_edge / (range / 2.0))
167 }
168
169 MarginRule::PatternMatch {
170 dimension_index,
171 approved_values,
172 tolerance,
173 fallback_margin,
174 } => {
175 let value = get_state_value(state, *dimension_index)?;
176 for approved in approved_values {
177 if (value - approved).abs() <= *tolerance {
178 return Ok(1.0);
179 }
180 }
181 Ok(*fallback_margin)
182 }
183
184 MarginRule::WeightedSum { weights, offset } => {
185 let mut sum = *offset;
186 for (dim_index, weight) in weights {
187 let value = get_state_value(state, *dim_index)?;
188 sum += value * weight;
189 }
190 Ok(sum)
191 }
192
193 MarginRule::Constant { margin } => Ok(*margin),
194 }
195}
196
197fn get_state_value(state: &[f64], index: usize) -> KernelResult<f64> {
199 state
200 .get(index)
201 .copied()
202 .ok_or_else(|| crate::KernelError::DeclarativeError(
203 format!("Dimension index {} out of bounds (state has {} dimensions)", index, state.len()),
204 ))
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two margins agree to within one machine epsilon.
    fn assert_margin(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "expected margin {}, got {}",
            expected,
            actual
        );
    }

    /// Builds a channel around `rule` with placeholder metadata and default thresholds.
    fn channel(id: &str, dimensions: &[&str], rule: MarginRule) -> DeclarativeChannel {
        DeclarativeChannel::new(ConstraintDefinition {
            id: id.to_string(),
            name: format!("{} channel", id),
            description: String::new(),
            domain_tag: "test".to_string(),
            dimensions: dimensions.iter().map(|d| d.to_string()).collect(),
            rule,
            thresholds: ThresholdOverrides::default(),
        })
    }

    /// Budget of 1000 on a single "spend" dimension.
    fn budget_channel() -> DeclarativeChannel {
        channel(
            "budget_test",
            &["spend"],
            MarginRule::BudgetRatio {
                dimension_index: 0,
                budget: 1000.0,
            },
        )
    }

    /// Range [0, 100] on a single dimension.
    fn range_channel() -> DeclarativeChannel {
        channel(
            "range_test",
            &["temperature"],
            MarginRule::RangeBound {
                dimension_index: 0,
                min: 0.0,
                max: 100.0,
            },
        )
    }

    #[test]
    fn budget_ratio_zero_usage() {
        assert_margin(budget_channel().evaluate(&[0.0]).unwrap(), 1.0);
    }

    #[test]
    fn budget_ratio_half_usage() {
        assert_margin(budget_channel().evaluate(&[500.0]).unwrap(), 0.5);
    }

    #[test]
    fn budget_ratio_full_usage() {
        assert_margin(budget_channel().evaluate(&[1000.0]).unwrap(), 0.0);
    }

    #[test]
    fn budget_ratio_over_usage_clamped() {
        assert_margin(budget_channel().evaluate(&[1500.0]).unwrap(), 0.0);
    }

    #[test]
    fn range_bound_at_center() {
        assert_margin(range_channel().evaluate(&[50.0]).unwrap(), 1.0);
    }

    #[test]
    fn range_bound_at_edge() {
        assert_margin(range_channel().evaluate(&[0.0]).unwrap(), 0.0);
    }

    #[test]
    fn pattern_match_approved() {
        let ch = channel(
            "pattern_test",
            &["code"],
            MarginRule::PatternMatch {
                dimension_index: 0,
                approved_values: vec![1.0, 2.0, 3.0],
                tolerance: 1e-10,
                fallback_margin: 0.0,
            },
        );
        assert_margin(ch.evaluate(&[2.0]).unwrap(), 1.0);
        assert_margin(ch.evaluate(&[5.0]).unwrap(), 0.0);
    }

    #[test]
    fn weighted_sum_basic() {
        let ch = channel(
            "ws_test",
            &["a", "b"],
            MarginRule::WeightedSum {
                weights: vec![(0, 0.3), (1, 0.7)],
                offset: 0.0,
            },
        );
        assert_margin(ch.evaluate(&[1.0, 1.0]).unwrap(), 1.0);
    }

    #[test]
    fn constant_margin() {
        let ch = channel("const_test", &[], MarginRule::Constant { margin: 0.42 });
        assert_margin(ch.evaluate(&[]).unwrap(), 0.42);
    }

    #[test]
    fn from_json_roundtrip() {
        let json = r#"{
            "id": "json_test",
            "name": "JSON Channel",
            "description": "Loaded from JSON",
            "domain_tag": "agentic",
            "dimensions": ["budget_used"],
            "rule": {
                "type": "budget_ratio",
                "dimension_index": 0,
                "budget": 500.0
            }
        }"#;

        let ch = DeclarativeChannel::from_json(json).unwrap();
        assert_eq!(ch.name(), "JSON Channel");
        assert_eq!(ch.definition().domain_tag, "agentic");
        assert_margin(ch.evaluate(&[250.0]).unwrap(), 0.5);
    }

    #[test]
    fn dimension_out_of_bounds() {
        let ch = channel(
            "oob_test",
            &["x"],
            MarginRule::BudgetRatio {
                dimension_index: 5,
                budget: 100.0,
            },
        );
        assert!(ch.evaluate(&[1.0]).is_err());
    }

    #[test]
    fn serialization_roundtrip() {
        let original = ConstraintDefinition {
            id: "ser_test".into(),
            name: "Serializable".into(),
            description: "Test serialization".into(),
            domain_tag: "test".into(),
            dimensions: vec!["a".into(), "b".into()],
            rule: MarginRule::WeightedSum {
                weights: vec![(0, 0.5), (1, 0.5)],
                offset: 0.0,
            },
            thresholds: ThresholdOverrides {
                safe_threshold: Some(0.7),
                caution_threshold: None,
                block_threshold: Some(0.05),
            },
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ConstraintDefinition = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, "ser_test");
        assert_eq!(decoded.thresholds.safe_threshold, Some(0.7));
    }
}