kizzasi_logic/constraint/
basic.rs

1//! Constraint definitions and builders
2//!
3//! Provides constraint types for signal bounds and composition operators
4//! for building complex constraint expressions, including temporal constraints
5//! for rate-of-change limits.
6
7use crate::error::{LogicError, LogicResult};
8use serde::{Deserialize, Serialize};
9
10/// Logical operators for combining constraints
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum LogicalOperator {
13    And,
14    Or,
15    Not,
16    Implies,
17}
18
19/// Composed constraint from multiple constraints with logical operators
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub enum ComposedConstraint {
22    /// Single constraint
23    Single(Constraint),
24    /// AND of two constraints (both must be satisfied)
25    And(Box<ComposedConstraint>, Box<ComposedConstraint>),
26    /// OR of two constraints (at least one must be satisfied)
27    Or(Box<ComposedConstraint>, Box<ComposedConstraint>),
28    /// NOT constraint (must not be satisfied)
29    Not(Box<ComposedConstraint>),
30    /// Implies: if A then B (equivalent to !A OR B)
31    Implies(Box<ComposedConstraint>, Box<ComposedConstraint>),
32}
33
34impl ComposedConstraint {
35    /// Create a single constraint
36    pub fn single(constraint: Constraint) -> Self {
37        Self::Single(constraint)
38    }
39
40    /// Combine with AND operator
41    pub fn and(self, other: ComposedConstraint) -> Self {
42        Self::And(Box::new(self), Box::new(other))
43    }
44
45    /// Combine with OR operator
46    pub fn or(self, other: ComposedConstraint) -> Self {
47        Self::Or(Box::new(self), Box::new(other))
48    }
49
50    /// Negate this constraint
51    pub fn negate(self) -> Self {
52        Self::Not(Box::new(self))
53    }
54
55    /// Create implication: if self then other
56    pub fn implies(self, other: ComposedConstraint) -> Self {
57        Self::Implies(Box::new(self), Box::new(other))
58    }
59
60    /// Check if a value satisfies this composed constraint
61    pub fn check(&self, value: f32) -> bool {
62        match self {
63            Self::Single(c) => c.check(value),
64            Self::And(a, b) => a.check(value) && b.check(value),
65            Self::Or(a, b) => a.check(value) || b.check(value),
66            Self::Not(c) => !c.check(value),
67            Self::Implies(a, b) => !a.check(value) || b.check(value),
68        }
69    }
70
71    /// Check all dimensions of a multi-dimensional value
72    pub fn check_all(&self, values: &[f32]) -> bool {
73        match self {
74            Self::Single(c) => {
75                if let Some(dim) = c.dimension() {
76                    values.get(dim).is_some_and(|&v| c.check(v))
77                } else {
78                    values.iter().all(|&v| c.check(v))
79                }
80            }
81            Self::And(a, b) => a.check_all(values) && b.check_all(values),
82            Self::Or(a, b) => a.check_all(values) || b.check_all(values),
83            Self::Not(c) => !c.check_all(values),
84            Self::Implies(a, b) => !a.check_all(values) || b.check_all(values),
85        }
86    }
87
88    /// Compute total violation (used for loss computation)
89    pub fn violation(&self, value: f32) -> f32 {
90        match self {
91            Self::Single(c) => c.violation(value),
92            Self::And(a, b) => a.violation(value) + b.violation(value),
93            Self::Or(a, b) => a.violation(value).min(b.violation(value)),
94            Self::Not(c) => {
95                // If the constraint is satisfied, it's a violation (inverted)
96                if c.check(value) {
97                    1.0
98                } else {
99                    0.0
100                }
101            }
102            Self::Implies(a, b) => {
103                // If A is true and B is false, it's a violation
104                if a.check(value) && !b.check(value) {
105                    b.violation(value)
106                } else {
107                    0.0
108                }
109            }
110        }
111    }
112
113    /// Project a value to satisfy this constraint (best effort)
114    pub fn project(&self, value: f32) -> f32 {
115        match self {
116            Self::Single(c) => c.project(value),
117            Self::And(a, b) => {
118                // Apply both projections sequentially
119                let v1 = a.project(value);
120                b.project(v1)
121            }
122            Self::Or(a, b) => {
123                // Choose the projection with smaller change
124                let proj_a = a.project(value);
125                let proj_b = b.project(value);
126                let dist_a = (value - proj_a).abs();
127                let dist_b = (value - proj_b).abs();
128                if dist_a <= dist_b {
129                    proj_a
130                } else {
131                    proj_b
132                }
133            }
134            Self::Not(_) => {
135                // Cannot project onto "not satisfied" region in general
136                // Return value unchanged
137                value
138            }
139            Self::Implies(a, b) => {
140                // If A is satisfied, must also satisfy B
141                if a.check(value) {
142                    b.project(value)
143                } else {
144                    value
145                }
146            }
147        }
148    }
149}
150
151/// Type of bound constraint
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub enum BoundType {
154    LessThan(f32),
155    LessEq(f32),
156    GreaterThan(f32),
157    GreaterEq(f32),
158    Equal(f32, f32), // value, tolerance
159    InRange(f32, f32),
160}
161
162/// A single constraint on signal values
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct Constraint {
165    name: String,
166    dimension: Option<usize>,
167    bound: BoundType,
168    weight: f32,
169}
170
171impl Constraint {
172    /// Check if a value satisfies this constraint
173    pub fn check(&self, value: f32) -> bool {
174        match &self.bound {
175            BoundType::LessThan(b) => value < *b,
176            BoundType::LessEq(b) => value <= *b,
177            BoundType::GreaterThan(b) => value > *b,
178            BoundType::GreaterEq(b) => value >= *b,
179            BoundType::Equal(target, tol) => (value - target).abs() <= *tol,
180            BoundType::InRange(lo, hi) => value >= *lo && value <= *hi,
181        }
182    }
183
184    /// Compute the violation amount (0 if satisfied)
185    pub fn violation(&self, value: f32) -> f32 {
186        match &self.bound {
187            BoundType::LessThan(b) | BoundType::LessEq(b) => (value - b).max(0.0),
188            BoundType::GreaterThan(b) | BoundType::GreaterEq(b) => (b - value).max(0.0),
189            BoundType::Equal(target, _) => (value - target).abs(),
190            BoundType::InRange(lo, hi) => {
191                if value < *lo {
192                    lo - value
193                } else if value > *hi {
194                    value - hi
195                } else {
196                    0.0
197                }
198            }
199        }
200    }
201
202    /// Project a value onto the valid region
203    pub fn project(&self, value: f32) -> f32 {
204        match &self.bound {
205            BoundType::LessThan(b) => value.min(*b - f32::EPSILON),
206            BoundType::LessEq(b) => value.min(*b),
207            BoundType::GreaterThan(b) => value.max(*b + f32::EPSILON),
208            BoundType::GreaterEq(b) => value.max(*b),
209            BoundType::Equal(target, _) => *target,
210            BoundType::InRange(lo, hi) => value.clamp(*lo, *hi),
211        }
212    }
213
214    /// Get the constraint name
215    pub fn name(&self) -> &str {
216        &self.name
217    }
218
219    /// Get the target dimension (if specific)
220    pub fn dimension(&self) -> Option<usize> {
221        self.dimension
222    }
223
224    /// Get the constraint weight for loss computation
225    pub fn weight(&self) -> f32 {
226        self.weight
227    }
228}
229
230/// Rate-of-change constraint type
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub enum RateType {
233    /// Maximum rate of change: |dx/dt| <= max_rate
234    MaxRate(f32),
235    /// Rate must be in range: min_rate <= dx/dt <= max_rate
236    RateRange { min_rate: f32, max_rate: f32 },
237    /// Rate must be non-negative (monotonic increasing): dx/dt >= 0
238    MonotonicIncreasing,
239    /// Rate must be non-positive (monotonic decreasing): dx/dt <= 0
240    MonotonicDecreasing,
241}
242
243/// Temporal constraint for rate-of-change limits
244///
245/// Tracks previous values to compute derivatives and enforce rate constraints.
246#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct TemporalConstraint {
248    name: String,
249    dimension: Option<usize>,
250    rate_type: RateType,
251    dt: f32,
252    weight: f32,
253}
254
255impl TemporalConstraint {
256    /// Get constraint name
257    pub fn name(&self) -> &str {
258        &self.name
259    }
260
261    /// Get the target dimension (if specific)
262    pub fn dimension(&self) -> Option<usize> {
263        self.dimension
264    }
265
266    /// Get the time step
267    pub fn dt(&self) -> f32 {
268        self.dt
269    }
270
271    /// Get the constraint weight
272    pub fn weight(&self) -> f32 {
273        self.weight
274    }
275
276    /// Check if a value transition satisfies this constraint
277    pub fn check(&self, prev_value: f32, current_value: f32) -> bool {
278        let rate = (current_value - prev_value) / self.dt;
279        match &self.rate_type {
280            RateType::MaxRate(max) => rate.abs() <= *max,
281            RateType::RateRange { min_rate, max_rate } => rate >= *min_rate && rate <= *max_rate,
282            RateType::MonotonicIncreasing => rate >= 0.0,
283            RateType::MonotonicDecreasing => rate <= 0.0,
284        }
285    }
286
287    /// Compute violation amount for rate constraint
288    pub fn violation(&self, prev_value: f32, current_value: f32) -> f32 {
289        let rate = (current_value - prev_value) / self.dt;
290        match &self.rate_type {
291            RateType::MaxRate(max) => (rate.abs() - max).max(0.0),
292            RateType::RateRange { min_rate, max_rate } => {
293                if rate < *min_rate {
294                    min_rate - rate
295                } else if rate > *max_rate {
296                    rate - max_rate
297                } else {
298                    0.0
299                }
300            }
301            RateType::MonotonicIncreasing => (-rate).max(0.0),
302            RateType::MonotonicDecreasing => rate.max(0.0),
303        }
304    }
305
306    /// Project a value to satisfy rate constraint given previous value
307    pub fn project(&self, prev_value: f32, current_value: f32) -> f32 {
308        let rate = (current_value - prev_value) / self.dt;
309        match &self.rate_type {
310            RateType::MaxRate(max) => {
311                if rate.abs() <= *max {
312                    current_value
313                } else {
314                    prev_value + rate.signum() * max * self.dt
315                }
316            }
317            RateType::RateRange { min_rate, max_rate } => {
318                let clamped_rate = rate.clamp(*min_rate, *max_rate);
319                prev_value + clamped_rate * self.dt
320            }
321            RateType::MonotonicIncreasing => {
322                if rate >= 0.0 {
323                    current_value
324                } else {
325                    prev_value // Stay at previous value
326                }
327            }
328            RateType::MonotonicDecreasing => {
329                if rate <= 0.0 {
330                    current_value
331                } else {
332                    prev_value // Stay at previous value
333                }
334            }
335        }
336    }
337}
338
339/// Builder for temporal constraints
340#[derive(Default)]
341pub struct TemporalConstraintBuilder {
342    name: Option<String>,
343    dimension: Option<usize>,
344    rate_type: Option<RateType>,
345    dt: Option<f32>,
346    weight: f32,
347}
348
349impl TemporalConstraintBuilder {
350    /// Create a new temporal constraint builder
351    pub fn new() -> Self {
352        Self {
353            weight: 1.0,
354            ..Default::default()
355        }
356    }
357
358    /// Set the constraint name
359    pub fn name(mut self, name: &str) -> Self {
360        self.name = Some(name.to_string());
361        self
362    }
363
364    /// Set the target dimension
365    pub fn dimension(mut self, dim: usize) -> Self {
366        self.dimension = Some(dim);
367        self
368    }
369
370    /// Set maximum rate of change constraint: |dx/dt| <= max_rate
371    pub fn max_rate(mut self, max_rate: f32) -> Self {
372        self.rate_type = Some(RateType::MaxRate(max_rate));
373        self
374    }
375
376    /// Set rate range constraint: min_rate <= dx/dt <= max_rate
377    pub fn rate_range(mut self, min_rate: f32, max_rate: f32) -> Self {
378        self.rate_type = Some(RateType::RateRange { min_rate, max_rate });
379        self
380    }
381
382    /// Set monotonic increasing constraint: dx/dt >= 0
383    pub fn monotonic_increasing(mut self) -> Self {
384        self.rate_type = Some(RateType::MonotonicIncreasing);
385        self
386    }
387
388    /// Set monotonic decreasing constraint: dx/dt <= 0
389    pub fn monotonic_decreasing(mut self) -> Self {
390        self.rate_type = Some(RateType::MonotonicDecreasing);
391        self
392    }
393
394    /// Set the time step (dt)
395    pub fn dt(mut self, dt: f32) -> Self {
396        self.dt = Some(dt);
397        self
398    }
399
400    /// Set the constraint weight
401    pub fn weight(mut self, w: f32) -> Self {
402        self.weight = w;
403        self
404    }
405
406    /// Build the temporal constraint
407    pub fn build(self) -> LogicResult<TemporalConstraint> {
408        let name = self
409            .name
410            .ok_or_else(|| LogicError::InvalidConstraint("name is required".into()))?;
411        let rate_type = self
412            .rate_type
413            .ok_or_else(|| LogicError::InvalidConstraint("rate_type is required".into()))?;
414        let dt = self
415            .dt
416            .ok_or_else(|| LogicError::InvalidConstraint("dt (time step) is required".into()))?;
417
418        if dt <= 0.0 {
419            return Err(LogicError::InvalidConstraint("dt must be positive".into()));
420        }
421
422        Ok(TemporalConstraint {
423            name,
424            dimension: self.dimension,
425            rate_type,
426            dt,
427            weight: self.weight,
428        })
429    }
430}
431
432/// Temporal constraint checker that maintains state
433#[derive(Debug, Clone)]
434pub struct TemporalChecker {
435    constraints: Vec<TemporalConstraint>,
436    prev_values: Vec<f32>,
437    initialized: bool,
438}
439
440impl TemporalChecker {
441    /// Create a new temporal checker with given constraints
442    pub fn new(constraints: Vec<TemporalConstraint>) -> Self {
443        Self {
444            constraints,
445            prev_values: Vec::new(),
446            initialized: false,
447        }
448    }
449
450    /// Reset the checker state
451    pub fn reset(&mut self) {
452        self.prev_values.clear();
453        self.initialized = false;
454    }
455
456    /// Check all temporal constraints for new values
457    pub fn check(&mut self, values: &[f32]) -> Vec<(String, bool)> {
458        if !self.initialized {
459            self.prev_values = values.to_vec();
460            self.initialized = true;
461            return self
462                .constraints
463                .iter()
464                .map(|c| (c.name.clone(), true))
465                .collect();
466        }
467
468        let results: Vec<(String, bool)> = self
469            .constraints
470            .iter()
471            .map(|c| {
472                let result = if let Some(dim) = c.dimension() {
473                    if dim < values.len() && dim < self.prev_values.len() {
474                        c.check(self.prev_values[dim], values[dim])
475                    } else {
476                        true // Dimension out of bounds, consider satisfied
477                    }
478                } else {
479                    // Check all dimensions
480                    values
481                        .iter()
482                        .zip(self.prev_values.iter())
483                        .all(|(&curr, &prev)| c.check(prev, curr))
484                };
485                (c.name.clone(), result)
486            })
487            .collect();
488
489        self.prev_values = values.to_vec();
490        results
491    }
492
493    /// Check all constraints and return total violation
494    pub fn total_violation(&mut self, values: &[f32]) -> f32 {
495        if !self.initialized {
496            self.prev_values = values.to_vec();
497            self.initialized = true;
498            return 0.0;
499        }
500
501        let violation: f32 = self
502            .constraints
503            .iter()
504            .map(|c| {
505                let v = if let Some(dim) = c.dimension() {
506                    if dim < values.len() && dim < self.prev_values.len() {
507                        c.violation(self.prev_values[dim], values[dim])
508                    } else {
509                        0.0
510                    }
511                } else {
512                    values
513                        .iter()
514                        .zip(self.prev_values.iter())
515                        .map(|(&curr, &prev)| c.violation(prev, curr))
516                        .sum()
517                };
518                v * c.weight()
519            })
520            .sum();
521
522        self.prev_values = values.to_vec();
523        violation
524    }
525
526    /// Project values to satisfy rate constraints
527    pub fn project(&mut self, values: &[f32]) -> Vec<f32> {
528        if !self.initialized {
529            self.prev_values = values.to_vec();
530            self.initialized = true;
531            return values.to_vec();
532        }
533
534        let mut projected = values.to_vec();
535
536        for c in &self.constraints {
537            if let Some(dim) = c.dimension() {
538                if dim < projected.len() && dim < self.prev_values.len() {
539                    projected[dim] = c.project(self.prev_values[dim], projected[dim]);
540                }
541            } else {
542                for i in 0..projected.len().min(self.prev_values.len()) {
543                    projected[i] = c.project(self.prev_values[i], projected[i]);
544                }
545            }
546        }
547
548        self.prev_values = projected.clone();
549        projected
550    }
551
552    /// Check if all constraints are satisfied
553    pub fn all_satisfied(&mut self, values: &[f32]) -> bool {
554        self.check(values).iter().all(|(_, sat)| *sat)
555    }
556}
557
558/// Builder for constructing constraints
559pub struct ConstraintBuilder {
560    name: Option<String>,
561    dimension: Option<usize>,
562    bound: Option<BoundType>,
563    weight: f32,
564}
565
566impl Default for ConstraintBuilder {
567    fn default() -> Self {
568        Self::new()
569    }
570}
571
572impl ConstraintBuilder {
573    /// Create a new constraint builder
574    pub fn new() -> Self {
575        Self {
576            name: None,
577            dimension: None,
578            bound: None,
579            weight: 1.0,
580        }
581    }
582
583    /// Set the constraint name
584    pub fn name(mut self, name: &str) -> Self {
585        self.name = Some(name.to_string());
586        self
587    }
588
589    /// Set the target dimension
590    pub fn dimension(mut self, dim: usize) -> Self {
591        self.dimension = Some(dim);
592        self
593    }
594
595    /// Set less-than bound
596    pub fn less_than(mut self, value: f32) -> Self {
597        self.bound = Some(BoundType::LessThan(value));
598        self
599    }
600
601    /// Set less-than-or-equal bound
602    pub fn less_eq(mut self, value: f32) -> Self {
603        self.bound = Some(BoundType::LessEq(value));
604        self
605    }
606
607    /// Set greater-than bound
608    pub fn greater_than(mut self, value: f32) -> Self {
609        self.bound = Some(BoundType::GreaterThan(value));
610        self
611    }
612
613    /// Set greater-than-or-equal bound
614    pub fn greater_eq(mut self, value: f32) -> Self {
615        self.bound = Some(BoundType::GreaterEq(value));
616        self
617    }
618
619    /// Set equality constraint with tolerance
620    pub fn equal(mut self, value: f32, tolerance: f32) -> Self {
621        self.bound = Some(BoundType::Equal(value, tolerance));
622        self
623    }
624
625    /// Set range constraint
626    pub fn in_range(mut self, lo: f32, hi: f32) -> Self {
627        self.bound = Some(BoundType::InRange(lo, hi));
628        self
629    }
630
631    /// Set the constraint weight
632    pub fn weight(mut self, w: f32) -> Self {
633        self.weight = w;
634        self
635    }
636
637    /// Build the constraint
638    pub fn build(self) -> LogicResult<Constraint> {
639        let name = self
640            .name
641            .ok_or_else(|| LogicError::InvalidConstraint("name is required".into()))?;
642        let bound = self
643            .bound
644            .ok_or_else(|| LogicError::InvalidConstraint("bound is required".into()))?;
645
646        Ok(Constraint {
647            name,
648            dimension: self.dimension,
649            bound,
650            weight: self.weight,
651        })
652    }
653}