1use crate::error::{LogicError, LogicResult};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum LogicalOperator {
13 And,
14 Or,
15 Not,
16 Implies,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub enum ComposedConstraint {
22 Single(Constraint),
24 And(Box<ComposedConstraint>, Box<ComposedConstraint>),
26 Or(Box<ComposedConstraint>, Box<ComposedConstraint>),
28 Not(Box<ComposedConstraint>),
30 Implies(Box<ComposedConstraint>, Box<ComposedConstraint>),
32}
33
34impl ComposedConstraint {
35 pub fn single(constraint: Constraint) -> Self {
37 Self::Single(constraint)
38 }
39
40 pub fn and(self, other: ComposedConstraint) -> Self {
42 Self::And(Box::new(self), Box::new(other))
43 }
44
45 pub fn or(self, other: ComposedConstraint) -> Self {
47 Self::Or(Box::new(self), Box::new(other))
48 }
49
50 pub fn negate(self) -> Self {
52 Self::Not(Box::new(self))
53 }
54
55 pub fn implies(self, other: ComposedConstraint) -> Self {
57 Self::Implies(Box::new(self), Box::new(other))
58 }
59
60 pub fn check(&self, value: f32) -> bool {
62 match self {
63 Self::Single(c) => c.check(value),
64 Self::And(a, b) => a.check(value) && b.check(value),
65 Self::Or(a, b) => a.check(value) || b.check(value),
66 Self::Not(c) => !c.check(value),
67 Self::Implies(a, b) => !a.check(value) || b.check(value),
68 }
69 }
70
71 pub fn check_all(&self, values: &[f32]) -> bool {
73 match self {
74 Self::Single(c) => {
75 if let Some(dim) = c.dimension() {
76 values.get(dim).is_some_and(|&v| c.check(v))
77 } else {
78 values.iter().all(|&v| c.check(v))
79 }
80 }
81 Self::And(a, b) => a.check_all(values) && b.check_all(values),
82 Self::Or(a, b) => a.check_all(values) || b.check_all(values),
83 Self::Not(c) => !c.check_all(values),
84 Self::Implies(a, b) => !a.check_all(values) || b.check_all(values),
85 }
86 }
87
88 pub fn violation(&self, value: f32) -> f32 {
90 match self {
91 Self::Single(c) => c.violation(value),
92 Self::And(a, b) => a.violation(value) + b.violation(value),
93 Self::Or(a, b) => a.violation(value).min(b.violation(value)),
94 Self::Not(c) => {
95 if c.check(value) {
97 1.0
98 } else {
99 0.0
100 }
101 }
102 Self::Implies(a, b) => {
103 if a.check(value) && !b.check(value) {
105 b.violation(value)
106 } else {
107 0.0
108 }
109 }
110 }
111 }
112
113 pub fn project(&self, value: f32) -> f32 {
115 match self {
116 Self::Single(c) => c.project(value),
117 Self::And(a, b) => {
118 let v1 = a.project(value);
120 b.project(v1)
121 }
122 Self::Or(a, b) => {
123 let proj_a = a.project(value);
125 let proj_b = b.project(value);
126 let dist_a = (value - proj_a).abs();
127 let dist_b = (value - proj_b).abs();
128 if dist_a <= dist_b {
129 proj_a
130 } else {
131 proj_b
132 }
133 }
134 Self::Not(_) => {
135 value
138 }
139 Self::Implies(a, b) => {
140 if a.check(value) {
142 b.project(value)
143 } else {
144 value
145 }
146 }
147 }
148 }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
153pub enum BoundType {
154 LessThan(f32),
155 LessEq(f32),
156 GreaterThan(f32),
157 GreaterEq(f32),
158 Equal(f32, f32), InRange(f32, f32),
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct Constraint {
165 name: String,
166 dimension: Option<usize>,
167 bound: BoundType,
168 weight: f32,
169}
170
171impl Constraint {
172 pub fn check(&self, value: f32) -> bool {
174 match &self.bound {
175 BoundType::LessThan(b) => value < *b,
176 BoundType::LessEq(b) => value <= *b,
177 BoundType::GreaterThan(b) => value > *b,
178 BoundType::GreaterEq(b) => value >= *b,
179 BoundType::Equal(target, tol) => (value - target).abs() <= *tol,
180 BoundType::InRange(lo, hi) => value >= *lo && value <= *hi,
181 }
182 }
183
184 pub fn violation(&self, value: f32) -> f32 {
186 match &self.bound {
187 BoundType::LessThan(b) | BoundType::LessEq(b) => (value - b).max(0.0),
188 BoundType::GreaterThan(b) | BoundType::GreaterEq(b) => (b - value).max(0.0),
189 BoundType::Equal(target, _) => (value - target).abs(),
190 BoundType::InRange(lo, hi) => {
191 if value < *lo {
192 lo - value
193 } else if value > *hi {
194 value - hi
195 } else {
196 0.0
197 }
198 }
199 }
200 }
201
202 pub fn project(&self, value: f32) -> f32 {
204 match &self.bound {
205 BoundType::LessThan(b) => value.min(*b - f32::EPSILON),
206 BoundType::LessEq(b) => value.min(*b),
207 BoundType::GreaterThan(b) => value.max(*b + f32::EPSILON),
208 BoundType::GreaterEq(b) => value.max(*b),
209 BoundType::Equal(target, _) => *target,
210 BoundType::InRange(lo, hi) => value.clamp(*lo, *hi),
211 }
212 }
213
214 pub fn name(&self) -> &str {
216 &self.name
217 }
218
219 pub fn dimension(&self) -> Option<usize> {
221 self.dimension
222 }
223
224 pub fn weight(&self) -> f32 {
226 self.weight
227 }
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
232pub enum RateType {
233 MaxRate(f32),
235 RateRange { min_rate: f32, max_rate: f32 },
237 MonotonicIncreasing,
239 MonotonicDecreasing,
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct TemporalConstraint {
248 name: String,
249 dimension: Option<usize>,
250 rate_type: RateType,
251 dt: f32,
252 weight: f32,
253}
254
255impl TemporalConstraint {
256 pub fn name(&self) -> &str {
258 &self.name
259 }
260
261 pub fn dimension(&self) -> Option<usize> {
263 self.dimension
264 }
265
266 pub fn dt(&self) -> f32 {
268 self.dt
269 }
270
271 pub fn weight(&self) -> f32 {
273 self.weight
274 }
275
276 pub fn check(&self, prev_value: f32, current_value: f32) -> bool {
278 let rate = (current_value - prev_value) / self.dt;
279 match &self.rate_type {
280 RateType::MaxRate(max) => rate.abs() <= *max,
281 RateType::RateRange { min_rate, max_rate } => rate >= *min_rate && rate <= *max_rate,
282 RateType::MonotonicIncreasing => rate >= 0.0,
283 RateType::MonotonicDecreasing => rate <= 0.0,
284 }
285 }
286
287 pub fn violation(&self, prev_value: f32, current_value: f32) -> f32 {
289 let rate = (current_value - prev_value) / self.dt;
290 match &self.rate_type {
291 RateType::MaxRate(max) => (rate.abs() - max).max(0.0),
292 RateType::RateRange { min_rate, max_rate } => {
293 if rate < *min_rate {
294 min_rate - rate
295 } else if rate > *max_rate {
296 rate - max_rate
297 } else {
298 0.0
299 }
300 }
301 RateType::MonotonicIncreasing => (-rate).max(0.0),
302 RateType::MonotonicDecreasing => rate.max(0.0),
303 }
304 }
305
306 pub fn project(&self, prev_value: f32, current_value: f32) -> f32 {
308 let rate = (current_value - prev_value) / self.dt;
309 match &self.rate_type {
310 RateType::MaxRate(max) => {
311 if rate.abs() <= *max {
312 current_value
313 } else {
314 prev_value + rate.signum() * max * self.dt
315 }
316 }
317 RateType::RateRange { min_rate, max_rate } => {
318 let clamped_rate = rate.clamp(*min_rate, *max_rate);
319 prev_value + clamped_rate * self.dt
320 }
321 RateType::MonotonicIncreasing => {
322 if rate >= 0.0 {
323 current_value
324 } else {
325 prev_value }
327 }
328 RateType::MonotonicDecreasing => {
329 if rate <= 0.0 {
330 current_value
331 } else {
332 prev_value }
334 }
335 }
336 }
337}
338
339#[derive(Default)]
341pub struct TemporalConstraintBuilder {
342 name: Option<String>,
343 dimension: Option<usize>,
344 rate_type: Option<RateType>,
345 dt: Option<f32>,
346 weight: f32,
347}
348
349impl TemporalConstraintBuilder {
350 pub fn new() -> Self {
352 Self {
353 weight: 1.0,
354 ..Default::default()
355 }
356 }
357
358 pub fn name(mut self, name: &str) -> Self {
360 self.name = Some(name.to_string());
361 self
362 }
363
364 pub fn dimension(mut self, dim: usize) -> Self {
366 self.dimension = Some(dim);
367 self
368 }
369
370 pub fn max_rate(mut self, max_rate: f32) -> Self {
372 self.rate_type = Some(RateType::MaxRate(max_rate));
373 self
374 }
375
376 pub fn rate_range(mut self, min_rate: f32, max_rate: f32) -> Self {
378 self.rate_type = Some(RateType::RateRange { min_rate, max_rate });
379 self
380 }
381
382 pub fn monotonic_increasing(mut self) -> Self {
384 self.rate_type = Some(RateType::MonotonicIncreasing);
385 self
386 }
387
388 pub fn monotonic_decreasing(mut self) -> Self {
390 self.rate_type = Some(RateType::MonotonicDecreasing);
391 self
392 }
393
394 pub fn dt(mut self, dt: f32) -> Self {
396 self.dt = Some(dt);
397 self
398 }
399
400 pub fn weight(mut self, w: f32) -> Self {
402 self.weight = w;
403 self
404 }
405
406 pub fn build(self) -> LogicResult<TemporalConstraint> {
408 let name = self
409 .name
410 .ok_or_else(|| LogicError::InvalidConstraint("name is required".into()))?;
411 let rate_type = self
412 .rate_type
413 .ok_or_else(|| LogicError::InvalidConstraint("rate_type is required".into()))?;
414 let dt = self
415 .dt
416 .ok_or_else(|| LogicError::InvalidConstraint("dt (time step) is required".into()))?;
417
418 if dt <= 0.0 {
419 return Err(LogicError::InvalidConstraint("dt must be positive".into()));
420 }
421
422 Ok(TemporalConstraint {
423 name,
424 dimension: self.dimension,
425 rate_type,
426 dt,
427 weight: self.weight,
428 })
429 }
430}
431
432#[derive(Debug, Clone)]
434pub struct TemporalChecker {
435 constraints: Vec<TemporalConstraint>,
436 prev_values: Vec<f32>,
437 initialized: bool,
438}
439
440impl TemporalChecker {
441 pub fn new(constraints: Vec<TemporalConstraint>) -> Self {
443 Self {
444 constraints,
445 prev_values: Vec::new(),
446 initialized: false,
447 }
448 }
449
450 pub fn reset(&mut self) {
452 self.prev_values.clear();
453 self.initialized = false;
454 }
455
456 pub fn check(&mut self, values: &[f32]) -> Vec<(String, bool)> {
458 if !self.initialized {
459 self.prev_values = values.to_vec();
460 self.initialized = true;
461 return self
462 .constraints
463 .iter()
464 .map(|c| (c.name.clone(), true))
465 .collect();
466 }
467
468 let results: Vec<(String, bool)> = self
469 .constraints
470 .iter()
471 .map(|c| {
472 let result = if let Some(dim) = c.dimension() {
473 if dim < values.len() && dim < self.prev_values.len() {
474 c.check(self.prev_values[dim], values[dim])
475 } else {
476 true }
478 } else {
479 values
481 .iter()
482 .zip(self.prev_values.iter())
483 .all(|(&curr, &prev)| c.check(prev, curr))
484 };
485 (c.name.clone(), result)
486 })
487 .collect();
488
489 self.prev_values = values.to_vec();
490 results
491 }
492
493 pub fn total_violation(&mut self, values: &[f32]) -> f32 {
495 if !self.initialized {
496 self.prev_values = values.to_vec();
497 self.initialized = true;
498 return 0.0;
499 }
500
501 let violation: f32 = self
502 .constraints
503 .iter()
504 .map(|c| {
505 let v = if let Some(dim) = c.dimension() {
506 if dim < values.len() && dim < self.prev_values.len() {
507 c.violation(self.prev_values[dim], values[dim])
508 } else {
509 0.0
510 }
511 } else {
512 values
513 .iter()
514 .zip(self.prev_values.iter())
515 .map(|(&curr, &prev)| c.violation(prev, curr))
516 .sum()
517 };
518 v * c.weight()
519 })
520 .sum();
521
522 self.prev_values = values.to_vec();
523 violation
524 }
525
526 pub fn project(&mut self, values: &[f32]) -> Vec<f32> {
528 if !self.initialized {
529 self.prev_values = values.to_vec();
530 self.initialized = true;
531 return values.to_vec();
532 }
533
534 let mut projected = values.to_vec();
535
536 for c in &self.constraints {
537 if let Some(dim) = c.dimension() {
538 if dim < projected.len() && dim < self.prev_values.len() {
539 projected[dim] = c.project(self.prev_values[dim], projected[dim]);
540 }
541 } else {
542 for i in 0..projected.len().min(self.prev_values.len()) {
543 projected[i] = c.project(self.prev_values[i], projected[i]);
544 }
545 }
546 }
547
548 self.prev_values = projected.clone();
549 projected
550 }
551
552 pub fn all_satisfied(&mut self, values: &[f32]) -> bool {
554 self.check(values).iter().all(|(_, sat)| *sat)
555 }
556}
557
558pub struct ConstraintBuilder {
560 name: Option<String>,
561 dimension: Option<usize>,
562 bound: Option<BoundType>,
563 weight: f32,
564}
565
566impl Default for ConstraintBuilder {
567 fn default() -> Self {
568 Self::new()
569 }
570}
571
572impl ConstraintBuilder {
573 pub fn new() -> Self {
575 Self {
576 name: None,
577 dimension: None,
578 bound: None,
579 weight: 1.0,
580 }
581 }
582
583 pub fn name(mut self, name: &str) -> Self {
585 self.name = Some(name.to_string());
586 self
587 }
588
589 pub fn dimension(mut self, dim: usize) -> Self {
591 self.dimension = Some(dim);
592 self
593 }
594
595 pub fn less_than(mut self, value: f32) -> Self {
597 self.bound = Some(BoundType::LessThan(value));
598 self
599 }
600
601 pub fn less_eq(mut self, value: f32) -> Self {
603 self.bound = Some(BoundType::LessEq(value));
604 self
605 }
606
607 pub fn greater_than(mut self, value: f32) -> Self {
609 self.bound = Some(BoundType::GreaterThan(value));
610 self
611 }
612
613 pub fn greater_eq(mut self, value: f32) -> Self {
615 self.bound = Some(BoundType::GreaterEq(value));
616 self
617 }
618
619 pub fn equal(mut self, value: f32, tolerance: f32) -> Self {
621 self.bound = Some(BoundType::Equal(value, tolerance));
622 self
623 }
624
625 pub fn in_range(mut self, lo: f32, hi: f32) -> Self {
627 self.bound = Some(BoundType::InRange(lo, hi));
628 self
629 }
630
631 pub fn weight(mut self, w: f32) -> Self {
633 self.weight = w;
634 self
635 }
636
637 pub fn build(self) -> LogicResult<Constraint> {
639 let name = self
640 .name
641 .ok_or_else(|| LogicError::InvalidConstraint("name is required".into()))?;
642 let bound = self
643 .bound
644 .ok_or_else(|| LogicError::InvalidConstraint("bound is required".into()))?;
645
646 Ok(Constraint {
647 name,
648 dimension: self.dimension,
649 bound,
650 weight: self.weight,
651 })
652 }
653}