Skip to main content

qae_kernel/
declarative.rs

// SPDX-License-Identifier: BUSL-1.1
//! Declarative constraint channels — JSON/TOML-defined constraints evaluated at runtime.
//!
//! A `DeclarativeChannel` wraps a `ConstraintDefinition` (a serde-deserializable type, so it
//! can be loaded from JSON or TOML) and evaluates it against a state vector using a generic
//! interpreter. This lets constraints be added or changed without recompilation.

8use crate::constraint::ConstraintChannel;
9use crate::KernelResult;
10use serde::{Deserialize, Serialize};
11
12/// A constraint defined declaratively (JSON/TOML), not compiled Rust.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ConstraintDefinition {
15    /// Unique identifier for this constraint.
16    pub id: String,
17    /// Human-readable name.
18    pub name: String,
19    /// Description shown in UI.
20    pub description: String,
21    /// Domain tag (e.g., "finance", "agentic", "custom").
22    pub domain_tag: String,
23    /// Names of state dimensions this constraint operates on.
24    pub dimensions: Vec<String>,
25    /// The margin computation rule.
26    pub rule: MarginRule,
27    /// Optional per-channel threshold overrides.
28    #[serde(default)]
29    pub thresholds: ThresholdOverrides,
30}
31
32/// How to compute the margin from a state vector.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34#[serde(tag = "type")]
35pub enum MarginRule {
36    /// Margin = value / budget. Produces 1.0 when value == 0, 0.0 when value >= budget.
37    /// `dimension_index` selects which state dimension to read.
38    #[serde(rename = "budget_ratio")]
39    BudgetRatio {
40        dimension_index: usize,
41        budget: f64,
42    },
43
44    /// Margin = how far value is from the [min, max] boundary, normalized to [0, 1].
45    /// Inside range → margin based on distance to nearest edge.
46    #[serde(rename = "range_bound")]
47    RangeBound {
48        dimension_index: usize,
49        min: f64,
50        max: f64,
51    },
52
53    /// Margin = 1.0 if value matches pattern, else fallback.
54    /// Useful for categorical checks (e.g., "is this in an approved list?").
55    #[serde(rename = "pattern_match")]
56    PatternMatch {
57        dimension_index: usize,
58        /// Values that produce margin = 1.0
59        approved_values: Vec<f64>,
60        /// Tolerance for floating-point comparison
61        #[serde(default = "default_tolerance")]
62        tolerance: f64,
63        /// Margin when value is not approved
64        #[serde(default = "default_fallback_margin")]
65        fallback_margin: f64,
66    },
67
68    /// Margin computed from a simple arithmetic expression over state dimensions.
69    /// Supports: weighted sum clamped to [0, 1].
70    #[serde(rename = "weighted_sum")]
71    WeightedSum {
72        /// (dimension_index, weight) pairs
73        weights: Vec<(usize, f64)>,
74        /// Offset added after weighted sum
75        #[serde(default)]
76        offset: f64,
77    },
78
79    /// Constant margin — useful for testing or placeholder channels.
80    #[serde(rename = "constant")]
81    Constant { margin: f64 },
82}
83
/// Serde default for `PatternMatch::tolerance`: the absolute slack allowed when comparing a
/// state value against an approved value.
fn default_tolerance() -> f64 {
    const DEFAULT_TOLERANCE: f64 = 1.0e-10;
    DEFAULT_TOLERANCE
}
87
/// Serde default for `PatternMatch::fallback_margin`: an unapproved value gets zero margin
/// unless the definition specifies otherwise.
fn default_fallback_margin() -> f64 {
    f64::default()
}
91
92/// Optional per-channel threshold overrides.
93#[derive(Debug, Clone, Default, Serialize, Deserialize)]
94pub struct ThresholdOverrides {
95    pub safe_threshold: Option<f64>,
96    pub caution_threshold: Option<f64>,
97    pub block_threshold: Option<f64>,
98}
99
100/// A constraint channel backed by a declarative definition.
101pub struct DeclarativeChannel {
102    definition: ConstraintDefinition,
103}
104
105impl DeclarativeChannel {
106    /// Create a new declarative channel from a definition.
107    pub fn new(definition: ConstraintDefinition) -> Self {
108        Self { definition }
109    }
110
111    /// Get the underlying definition.
112    pub fn definition(&self) -> &ConstraintDefinition {
113        &self.definition
114    }
115
116    /// Parse a constraint definition from JSON.
117    pub fn from_json(json: &str) -> KernelResult<Self> {
118        let definition: ConstraintDefinition =
119            serde_json::from_str(json).map_err(|e| crate::KernelError::DeclarativeError(e.to_string()))?;
120        Ok(Self::new(definition))
121    }
122}
123
124impl ConstraintChannel for DeclarativeChannel {
125    fn name(&self) -> &str {
126        &self.definition.name
127    }
128
129    fn evaluate(&self, state: &[f64]) -> KernelResult<f64> {
130        let margin = evaluate_rule(&self.definition.rule, state)?;
131        Ok(margin.clamp(0.0, 1.0))
132    }
133
134    fn dimension_names(&self) -> Vec<String> {
135        self.definition.dimensions.clone()
136    }
137}
138
139/// Evaluate a margin rule against a state vector.
140fn evaluate_rule(rule: &MarginRule, state: &[f64]) -> KernelResult<f64> {
141    match rule {
142        MarginRule::BudgetRatio {
143            dimension_index,
144            budget,
145        } => {
146            let value = get_state_value(state, *dimension_index)?;
147            if *budget <= 0.0 {
148                return Ok(0.0);
149            }
150            // margin = 1 - (value / budget), clamped
151            Ok(1.0 - (value / budget))
152        }
153
154        MarginRule::RangeBound {
155            dimension_index,
156            min,
157            max,
158        } => {
159            let value = get_state_value(state, *dimension_index)?;
160            if max <= min {
161                return Ok(0.0);
162            }
163            let range = max - min;
164            let midpoint = (min + max) / 2.0;
165            let distance_to_edge = (range / 2.0) - (value - midpoint).abs();
166            Ok(distance_to_edge / (range / 2.0))
167        }
168
169        MarginRule::PatternMatch {
170            dimension_index,
171            approved_values,
172            tolerance,
173            fallback_margin,
174        } => {
175            let value = get_state_value(state, *dimension_index)?;
176            for approved in approved_values {
177                if (value - approved).abs() <= *tolerance {
178                    return Ok(1.0);
179                }
180            }
181            Ok(*fallback_margin)
182        }
183
184        MarginRule::WeightedSum { weights, offset } => {
185            let mut sum = *offset;
186            for (dim_index, weight) in weights {
187                let value = get_state_value(state, *dim_index)?;
188                sum += value * weight;
189            }
190            Ok(sum)
191        }
192
193        MarginRule::Constant { margin } => Ok(*margin),
194    }
195}
196
197/// Safely get a state value by dimension index.
198fn get_state_value(state: &[f64], index: usize) -> KernelResult<f64> {
199    state
200        .get(index)
201        .copied()
202        .ok_or_else(|| crate::KernelError::DeclarativeError(
203            format!("Dimension index {} out of bounds (state has {} dimensions)", index, state.len()),
204        ))
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `rule` in a channel with placeholder metadata. Evaluation only looks at the
    /// rule, and `dimension_names` only at `dimensions`.
    fn channel_with(dimensions: &[&str], rule: MarginRule) -> DeclarativeChannel {
        DeclarativeChannel::new(ConstraintDefinition {
            id: "test".into(),
            name: "Test".into(),
            description: String::new(),
            domain_tag: "test".into(),
            dimensions: dimensions.iter().map(|d| (*d).to_string()).collect(),
            rule,
            thresholds: ThresholdOverrides::default(),
        })
    }

    /// Assert that a margin equals `expected` to within machine epsilon.
    fn assert_margin(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "expected margin {}, got {}",
            expected,
            actual
        );
    }

    fn budget_channel() -> DeclarativeChannel {
        channel_with(
            &["spend"],
            MarginRule::BudgetRatio {
                dimension_index: 0,
                budget: 1000.0,
            },
        )
    }

    fn range_channel() -> DeclarativeChannel {
        channel_with(
            &["temperature"],
            MarginRule::RangeBound {
                dimension_index: 0,
                min: 0.0,
                max: 100.0,
            },
        )
    }

    fn pattern_channel(fallback_margin: f64) -> DeclarativeChannel {
        channel_with(
            &["code"],
            MarginRule::PatternMatch {
                dimension_index: 0,
                approved_values: vec![1.0, 2.0, 3.0],
                tolerance: 1e-10,
                fallback_margin,
            },
        )
    }

    // --- budget_ratio ---

    #[test]
    fn budget_ratio_zero_usage() {
        assert_margin(budget_channel().evaluate(&[0.0]).unwrap(), 1.0);
    }

    #[test]
    fn budget_ratio_half_usage() {
        assert_margin(budget_channel().evaluate(&[500.0]).unwrap(), 0.5);
    }

    #[test]
    fn budget_ratio_full_usage() {
        assert_margin(budget_channel().evaluate(&[1000.0]).unwrap(), 0.0);
    }

    #[test]
    fn budget_ratio_over_usage_clamped() {
        assert_margin(budget_channel().evaluate(&[1500.0]).unwrap(), 0.0);
    }

    #[test]
    fn budget_ratio_negative_usage_clamped() {
        // Raw margin 1 - (-500 / 1000) = 1.5 is clamped down to 1.0.
        assert_margin(budget_channel().evaluate(&[-500.0]).unwrap(), 1.0);
    }

    #[test]
    fn budget_ratio_non_positive_budget_fails_closed() {
        for &budget in &[0.0, -10.0] {
            let ch = channel_with(
                &["spend"],
                MarginRule::BudgetRatio {
                    dimension_index: 0,
                    budget,
                },
            );
            assert_margin(ch.evaluate(&[0.0]).unwrap(), 0.0);
        }
    }

    // --- range_bound ---

    #[test]
    fn range_bound_at_center() {
        assert_margin(range_channel().evaluate(&[50.0]).unwrap(), 1.0);
    }

    #[test]
    fn range_bound_at_edge() {
        let ch = range_channel();
        assert_margin(ch.evaluate(&[0.0]).unwrap(), 0.0);
        assert_margin(ch.evaluate(&[100.0]).unwrap(), 0.0);
    }

    #[test]
    fn range_bound_quarter_points() {
        let ch = range_channel();
        assert_margin(ch.evaluate(&[25.0]).unwrap(), 0.5);
        assert_margin(ch.evaluate(&[75.0]).unwrap(), 0.5);
    }

    #[test]
    fn range_bound_outside_range_clamped() {
        let ch = range_channel();
        assert_margin(ch.evaluate(&[-10.0]).unwrap(), 0.0);
        assert_margin(ch.evaluate(&[150.0]).unwrap(), 0.0);
    }

    #[test]
    fn range_bound_degenerate_bounds_fail_closed() {
        let inverted = channel_with(
            &["x"],
            MarginRule::RangeBound {
                dimension_index: 0,
                min: 100.0,
                max: 0.0,
            },
        );
        assert_margin(inverted.evaluate(&[50.0]).unwrap(), 0.0);

        let empty = channel_with(
            &["x"],
            MarginRule::RangeBound {
                dimension_index: 0,
                min: 5.0,
                max: 5.0,
            },
        );
        assert_margin(empty.evaluate(&[5.0]).unwrap(), 0.0);
    }

    // --- pattern_match ---

    #[test]
    fn pattern_match_approved() {
        let ch = pattern_channel(0.0);
        assert_margin(ch.evaluate(&[2.0]).unwrap(), 1.0);
        assert_margin(ch.evaluate(&[5.0]).unwrap(), 0.0);
    }

    #[test]
    fn pattern_match_within_tolerance() {
        assert_margin(pattern_channel(0.0).evaluate(&[2.0 + 1e-12]).unwrap(), 1.0);
    }

    #[test]
    fn pattern_match_custom_fallback() {
        assert_margin(pattern_channel(0.25).evaluate(&[5.0]).unwrap(), 0.25);
    }

    // --- weighted_sum ---

    #[test]
    fn weighted_sum_basic() {
        let ch = channel_with(
            &["a", "b"],
            MarginRule::WeightedSum {
                weights: vec![(0, 0.3), (1, 0.7)],
                offset: 0.0,
            },
        );
        assert_margin(ch.evaluate(&[1.0, 1.0]).unwrap(), 1.0);
    }

    #[test]
    fn weighted_sum_with_offset() {
        let ch = channel_with(
            &["a"],
            MarginRule::WeightedSum {
                weights: vec![(0, 0.5)],
                offset: 0.25,
            },
        );
        assert_margin(ch.evaluate(&[0.5]).unwrap(), 0.5);

        // With no weights, the margin is just the offset.
        let offset_only = channel_with(
            &[],
            MarginRule::WeightedSum {
                weights: vec![],
                offset: 0.4,
            },
        );
        assert_margin(offset_only.evaluate(&[]).unwrap(), 0.4);
    }

    #[test]
    fn weighted_sum_clamped_to_unit_interval() {
        let above = channel_with(
            &["a"],
            MarginRule::WeightedSum {
                weights: vec![(0, 2.0)],
                offset: 0.0,
            },
        );
        assert_margin(above.evaluate(&[1.0]).unwrap(), 1.0);

        let below = channel_with(
            &["a"],
            MarginRule::WeightedSum {
                weights: vec![(0, 1.0)],
                offset: -0.5,
            },
        );
        assert_margin(below.evaluate(&[0.25]).unwrap(), 0.0);
    }

    #[test]
    fn weighted_sum_out_of_bounds() {
        let ch = channel_with(
            &["a"],
            MarginRule::WeightedSum {
                weights: vec![(0, 1.0), (3, 1.0)],
                offset: 0.0,
            },
        );
        assert!(ch.evaluate(&[1.0]).is_err());
    }

    // --- constant ---

    #[test]
    fn constant_margin() {
        let ch = channel_with(&[], MarginRule::Constant { margin: 0.42 });
        assert_margin(ch.evaluate(&[]).unwrap(), 0.42);
    }

    // --- channel metadata & errors ---

    #[test]
    fn dimension_names_match_definition() {
        let ch = channel_with(&["a", "b"], MarginRule::Constant { margin: 0.5 });
        assert_eq!(ch.dimension_names(), vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn dimension_out_of_bounds() {
        let ch = channel_with(
            &["x"],
            MarginRule::BudgetRatio {
                dimension_index: 5,
                budget: 100.0,
            },
        );
        assert!(ch.evaluate(&[1.0]).is_err());
    }

    // --- JSON / serde ---

    #[test]
    fn from_json_roundtrip() {
        let json = r#"{
            "id": "json_test",
            "name": "JSON Channel",
            "description": "Loaded from JSON",
            "domain_tag": "agentic",
            "dimensions": ["budget_used"],
            "rule": {
                "type": "budget_ratio",
                "dimension_index": 0,
                "budget": 500.0
            }
        }"#;

        let ch = DeclarativeChannel::from_json(json).unwrap();
        assert_eq!(ch.name(), "JSON Channel");
        assert_eq!(ch.definition().domain_tag, "agentic");

        // `thresholds` was omitted, so it defaults to no overrides.
        let thresholds = &ch.definition().thresholds;
        assert!(thresholds.safe_threshold.is_none());
        assert!(thresholds.caution_threshold.is_none());
        assert!(thresholds.block_threshold.is_none());

        assert_margin(ch.evaluate(&[250.0]).unwrap(), 0.5);
    }

    #[test]
    fn from_json_applies_pattern_match_defaults() {
        let json = r#"{
            "id": "pm",
            "name": "Pattern",
            "description": "",
            "domain_tag": "test",
            "dimensions": ["code"],
            "rule": {
                "type": "pattern_match",
                "dimension_index": 0,
                "approved_values": [1.0, 2.0]
            }
        }"#;

        let ch = DeclarativeChannel::from_json(json).unwrap();
        match &ch.definition().rule {
            MarginRule::PatternMatch {
                tolerance,
                fallback_margin,
                ..
            } => {
                assert_margin(*tolerance, 1e-10);
                assert_margin(*fallback_margin, 0.0);
            }
            other => panic!("expected pattern_match rule, got {:?}", other),
        }
        assert_margin(ch.evaluate(&[2.0]).unwrap(), 1.0);
        assert_margin(ch.evaluate(&[3.0]).unwrap(), 0.0);
    }

    #[test]
    fn from_json_rejects_unknown_rule_type() {
        let json = r#"{
            "id": "bad",
            "name": "Bad",
            "description": "",
            "domain_tag": "test",
            "dimensions": [],
            "rule": { "type": "bogus" }
        }"#;
        assert!(DeclarativeChannel::from_json(json).is_err());
    }

    #[test]
    fn serialization_roundtrip() {
        let def = ConstraintDefinition {
            id: "ser_test".into(),
            name: "Serializable".into(),
            description: "Test serialization".into(),
            domain_tag: "test".into(),
            dimensions: vec!["a".into(), "b".into()],
            rule: MarginRule::WeightedSum {
                weights: vec![(0, 0.5), (1, 0.5)],
                offset: 0.0,
            },
            thresholds: ThresholdOverrides {
                safe_threshold: Some(0.7),
                caution_threshold: None,
                block_threshold: Some(0.05),
            },
        };

        let json = serde_json::to_string(&def).unwrap();
        // The rule is internally tagged by its `type` field.
        assert!(json.contains(r#""type":"weighted_sum""#));

        let deserialized: ConstraintDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.id, "ser_test");
        assert_eq!(deserialized.dimensions, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(deserialized.thresholds.safe_threshold, Some(0.7));
        assert_eq!(deserialized.thresholds.caution_threshold, None);
        assert_eq!(deserialized.thresholds.block_threshold, Some(0.05));
        match &deserialized.rule {
            MarginRule::WeightedSum { weights, offset } => {
                assert_eq!(weights, &vec![(0, 0.5), (1, 0.5)]);
                assert_margin(*offset, 0.0);
            }
            other => panic!("expected weighted_sum rule, got {:?}", other),
        }
    }
}
409}