1use std::fmt::{Display, Formatter};
2use std::ops::{Bound, RangeBounds};
3
4use thiserror::Error;
5
6use crate::Formula;
7use crate::metrics::{Top, Bottom, Meet, Join};
8use crate::trace::{Range, Trace};
9use super::BinaryOperatorError;
10
11struct ForwardIter<'a, T, F> {
12 rest: Range<'a, T>,
13 state: Option<(f64, T)>,
14 combine: F,
15}
16
17impl<'a, T, F> ForwardIter<'a, T, F>
18where
19 F: Fn(&T, &T) -> T,
20{
21 fn new(mut range: Range<'a, T>, init: T, combine: F) -> Self {
22 let state = match range.next_back() {
26 Some((time, value)) => Some((time, combine(&init, value))),
27 None => Some((0.0, init)),
28 };
29
30 Self {
31 rest: range,
32 state,
33 combine,
34 }
35 }
36}
37
38impl<'a, T, F> Iterator for ForwardIter<'a, T, F>
39where
40 F: Fn(&T, &T) -> T,
41{
42 type Item = (f64, T);
43
44 fn next(&mut self) -> Option<Self::Item> {
45 let state = self.state.take()?;
46 self.state = self
47 .rest
48 .next_back()
49 .map(|(time, value)| (time, (self.combine)(&state.1, value)));
50
51 Some(state)
52 }
53}
54
/// One end of an interval: `Open` excludes its value, `Closed` includes it.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
enum Endpoint {
    /// Boundary value that is not part of the interval.
    Open(f64),
    /// Boundary value that is part of the interval.
    Closed(f64),
}

impl Endpoint {
    /// Returns the boundary value, ignoring whether it is open or closed.
    fn value(&self) -> f64 {
        match *self {
            Self::Open(bound) | Self::Closed(bound) => bound,
        }
    }

    /// Applies `f` to the boundary value while keeping the endpoint kind.
    fn map<F>(self, f: F) -> Endpoint
    where
        F: Fn(f64) -> f64,
    {
        match self {
            Self::Open(bound) => Self::Open(f(bound)),
            Self::Closed(bound) => Self::Closed(f(bound)),
        }
    }
}
79
80#[derive(Debug, Clone, PartialEq, PartialOrd)]
81pub struct Interval {
82 start: Endpoint,
83 end: Endpoint,
84}
85
86impl Interval {
87 fn shift(&self, amount: f64) -> Interval {
88 Self {
89 start: self.start.map(|start| start + amount),
90 end: self.end.map(|end| end + amount),
91 }
92 }
93}
94
95impl Display for Interval {
96 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
97 let opening = match &self.start {
98 Endpoint::Open(_) => '(',
99 Endpoint::Closed(_) => '[',
100 };
101
102 let closing = match &self.end {
103 Endpoint::Open(_) => ')',
104 Endpoint::Closed(_) => ']',
105 };
106
107 write!(f, "{}{},{}{}", opening, self.start.value(), self.end.value(), closing)
108 }
109}
110
111impl<T> From<std::ops::Range<T>> for Interval
112where
113 T: Into<f64>,
114{
115 fn from(std::ops::Range { start, end }: std::ops::Range<T>) -> Self {
116 Self {
117 start: Endpoint::Closed(start.into()),
118 end: Endpoint::Open(end.into()),
119 }
120 }
121}
122
123impl<T> From<std::ops::RangeInclusive<T>> for Interval
124where
125 T: Into<f64>,
126{
127 fn from(range: std::ops::RangeInclusive<T>) -> Self {
128 let (start, end) = range.into_inner();
129
130 Self {
131 start: Endpoint::Closed(start.into()),
132 end: Endpoint::Closed(end.into()),
133 }
134 }
135}
136
137impl RangeBounds<f64> for Interval {
138 fn contains<U>(&self, item: &U) -> bool
139 where
140 U: PartialOrd<f64> + ?Sized,
141 {
142 let within_lower = match &self.start {
143 Endpoint::Open(lower) => item.gt(lower),
144 Endpoint::Closed(lower) => item.ge(lower),
145 };
146
147 let within_upper = match &self.end {
148 Endpoint::Open(upper) => item.lt(upper),
149 Endpoint::Closed(upper) => item.le(upper),
150 };
151
152 within_lower && within_upper
153 }
154
155 fn start_bound(&self) -> Bound<&f64> {
156 match &self.start {
157 Endpoint::Open(start) => Bound::Excluded(start),
158 Endpoint::Closed(start) => Bound::Included(start),
159 }
160 }
161
162 fn end_bound(&self) -> Bound<&f64> {
163 match &self.end {
164 Endpoint::Open(start) => Bound::Excluded(start),
165 Endpoint::Closed(start) => Bound::Included(start),
166 }
167 }
168}
169
170#[derive(Debug, Clone)]
171struct UnaryOperator<F> {
172 subformula: F,
173 bounds: Option<Interval>
174}
175
176#[derive(Debug, Clone, Error)]
185pub enum ForwardOperatorError<F> {
186 #[error("Bounded formula error: {0}")]
187 FormulaError(F),
188
189 #[error("Subtrace evaluation for interval {0} is empty")]
190 EmptySubtraceEvaluation(Interval),
191
192 #[error("Empty interval")]
193 EmptyInterval,
194}
195
196impl<F> UnaryOperator<F> {
197 fn new(bounds: Option<Interval>, subformula: F) -> Self {
198 Self { bounds, subformula }
199 }
200
201 fn evaluate<State, I, C, Metric>(&self, trace: &Trace<State>, init: I, combine: C) -> Result<Trace<F::Metric>, ForwardOperatorError<F::Error>>
202 where
203 F: Formula<State, Metric = Metric>,
204 I: Fn() -> Metric,
205 C: Fn(&Metric, &Metric) -> Metric,
206 {
207 if trace.is_empty() {
208 return Ok(Trace::from_iter([(0.0, init())]));
209 }
210
211 let inner = self.subformula
212 .evaluate(trace)
213 .map_err(ForwardOperatorError::FormulaError)?;
214
215 match &self.bounds {
216 None => {
217 let first = init();
218 let range = inner.range(..);
219 let result = ForwardIter::new(range, first, combine).collect();
220
221 Ok(result)
222 },
223 Some(interval) => {
224 let evaluate_time = |time: f64| -> Result<(f64, Metric), ForwardOperatorError<F::Error>> {
225 let shifted = interval.shift(time);
226 let range = inner.range((shifted.start_bound(), shifted.end_bound()));
227 let iter = ForwardIter::new(range, init(), &combine);
228
229 iter.last()
230 .ok_or(ForwardOperatorError::EmptySubtraceEvaluation(shifted))
231 .map(|(_, value)| (time, value))
232 };
233
234 inner.times().map(evaluate_time).collect()
235 },
236 }
237 }
238}
239
240#[derive(Clone)]
266pub struct Always<F>(UnaryOperator<F>);
267
268impl<F> Always<F> {
269 pub fn unbounded(formula: F) -> Self {
283 Self(UnaryOperator::new(None, formula))
284 }
285
286 pub fn bounded<I>(interval: I, formula: F) -> Self
302 where
303 I: Into<Interval>,
304 {
305 Self(UnaryOperator::new(Some(interval.into()), formula))
306 }
307}
308
309impl<State, F, M> Formula<State> for Always<F>
310where
311 F: Formula<State, Metric = M>,
312 M: Top + Meet,
313{
314 type Metric = M;
315 type Error = ForwardOperatorError<F::Error>;
316
317 fn evaluate(&self, trace: &Trace<State>) -> Result<Trace<Self::Metric>, Self::Error> {
318 self.0.evaluate(trace, M::top, M::min)
319 }
320}
321
322#[derive(Clone)]
348pub struct Eventually<F>(UnaryOperator<F>);
349
350impl<F> Eventually<F> {
351 pub fn unbounded(formula: F) -> Self {
365 Self(UnaryOperator::new(None, formula))
366 }
367
368 pub fn bounded<I>(interval: I, formula: F) -> Self
384 where
385 I: Into<Interval>,
386 {
387 Self(UnaryOperator::new(Some(interval.into()), formula))
388 }
389}
390
391impl<State, F, M> Formula<State> for Eventually<F>
392where
393 F: Formula<State, Metric = M>,
394 M: Bottom + Join,
395{
396 type Metric = M;
397 type Error = ForwardOperatorError<F::Error>;
398
399 fn evaluate(&self, trace: &Trace<State>) -> Result<Trace<Self::Metric>, Self::Error> {
400 self.0.evaluate(trace, M::bottom, M::max)
401 }
402}
403
/// "Next" operator: each sample takes the subformula's robustness at the
/// following sample.
#[derive(Clone, Debug)]
pub struct Next<F> {
    subformula: F,
}

impl<F> Next<F> {
    /// Wraps `subformula` in a `Next` operator.
    pub fn new(subformula: F) -> Self {
        Next { subformula }
    }
}
453
454fn next_op<T, F, U>(trace: Trace<T>, f: F) -> Trace<U>
455where
456 F: Fn(&T, T) -> U,
457 U: Bottom,
458{
459 let mut iter = trace.into_iter().rev();
460 let mut trace = Trace::default();
461
462 if let Some((time, mut metric)) = iter.next() {
463 trace.insert(time, U::bottom());
464
465 for (prev_time, prev_metric) in iter {
466 trace.insert(prev_time, f(&prev_metric, metric));
467 metric = prev_metric;
468 }
469 }
470
471 trace
472}
473
474impl<State, F, Metric> Formula<State> for Next<F>
475where
476 F: Formula<State, Metric = Metric>,
477 Metric: Bottom,
478{
479 type Metric = F::Metric;
480 type Error = F::Error;
481
482 fn evaluate(&self, trace: &Trace<State>) -> Result<Trace<Self::Metric>, Self::Error> {
483 self.subformula
484 .evaluate(trace)
485 .map(|inner_trace| next_op(inner_trace, |_, metric| metric))
486 }
487}
488
/// Future-time "until" operator `left U right`.
///
/// Derives `Clone` and `Debug` like the other operators in this module
/// (`Always`, `Eventually`, `Next`), so formulas containing an `Until` can be
/// cloned and debug-printed.
#[derive(Clone, Debug)]
pub struct Until<Left, Right> {
    left: Left,
    right: Right,
}

impl<Left, Right> Until<Left, Right> {
    /// Builds `left U right`.
    pub fn new(left: Left, right: Right) -> Self {
        Self { left, right }
    }
}
537
538fn until_eval_time<M>(left: &Trace<M>, time: f64, right: M, prev: &M) -> M
539where
540 M: Top + Meet + Join,
541{
542 let left_metric = left
543 .range(..=time)
544 .fold(M::top(), |l, (_, r)| l.min(r)); let combined_metric = left_metric.min(&right); combined_metric.max(prev) }
549
550fn until_op<M, I>(left: Trace<M>, right: I, mut prev_time: f64, mut prev_metric: M) -> Trace<M>
551where
552 I: Iterator<Item = (f64, M)>,
553 M: Top + Bottom + Meet + Join,
554{
555 let mut trace = Trace::default();
556 let bottom = M::bottom();
557
558 prev_metric = until_eval_time(&left, prev_time, prev_metric, &bottom);
559
560 for (time, right_metric) in right {
561 let next_metric = until_eval_time(&left, time, right_metric, &prev_metric);
562
563 trace.insert(prev_time, prev_metric);
564 prev_time = time;
565 prev_metric = next_metric;
566 }
567
568 trace.insert(prev_time, prev_metric);
569 trace
570}
571
572impl<Left, Right, State, Metric> Formula<State> for Until<Left, Right>
573where
574 Left: Formula<State, Metric = Metric>,
575 Right: Formula<State, Metric = Metric>,
576 Metric: Clone + Top + Bottom + Meet + Join,
577{
578 type Metric = Metric;
579 type Error = BinaryOperatorError<Left::Error, Right::Error>;
580
581 fn evaluate(&self, trace: &Trace<State>) -> Result<Trace<Self::Metric>, Self::Error> {
582 let left_trace = self
583 .left
584 .evaluate(trace)
585 .map_err(BinaryOperatorError::LeftError)?;
586
587 let right_trace = self
588 .right
589 .evaluate(trace)
590 .map_err(BinaryOperatorError::RightError)?;
591
592 let mut iter = right_trace.into_iter().rev();
593
594 let evaluated_trace = if let Some((prev_time, prev_metric)) = iter.next() {
595 until_op(left_trace, iter, prev_time, prev_metric)
596 } else {
597 Trace::default()
598 };
599
600 Ok(evaluated_trace)
601 }
602}
603
// Unit tests for the future-time operators. `Const`, `ConstLeft`,
// `ConstRight` and `ConstError` come from `crate::operators::test`; judging by
// the expected traces below, `Const` presumably passes each input value
// through as its robustness (and `ConstLeft`/`ConstRight` take the first/second
// tuple element) — confirm against that module.
#[cfg(test)]
mod tests {
    use crate::Formula;
    use crate::operators::BinaryOperatorError;
    use crate::operators::test::*;
    use crate::trace::Trace;
    use super::{Always, Eventually, Next, Until, ForwardOperatorError};

    // Unbounded always: each time takes the minimum over itself and all later samples.
    #[test]
    fn always() -> Result<(), ForwardOperatorError<ConstError>> {
        let input = Trace::from_iter([
            (0, 4.0),
            (1, 2.0),
            (2, 3.0),
            (3, 1.0),
            (4, 3.0),
        ]);

        let formula = Always::unbounded(Const);
        let robustness = formula.evaluate(&input)?;
        let expected = Trace::from_iter([
            (0, 1.0),
            (1, 1.0),
            (2, 1.0),
            (3, 1.0),
            (4, 3.0),
        ]);

        assert_eq!(robustness, expected);
        Ok(())
    }

    // Bounded always over [t, t + 2]: e.g. time 0 is min(4, 2, 3) = 2.
    #[test]
    fn bounded_always() -> Result<(), ForwardOperatorError<ConstError>> {
        let input = Trace::from_iter([
            (0, 4.0),
            (1, 2.0),
            (2, 3.0),
            (3, 1.0),
            (4, 3.0),
        ]);

        let formula = Always::bounded(0.0..=2.0, Const);
        let robustness = formula.evaluate(&input)?;
        let expected = Trace::from_iter([
            (0, 2.0),
            (1, 1.0),
            (2, 1.0),
            (3, 1.0),
            (4, 3.0),
        ]);

        assert_eq!(robustness, expected);
        Ok(())
    }

    // Unbounded eventually: each time takes the maximum over itself and all later samples.
    #[test]
    fn eventually() -> Result<(), ForwardOperatorError<ConstError>> {
        let input = Trace::from_iter([
            (0, 4.0),
            (1, 2.0),
            (2, 3.0),
            (3, 1.0),
            (4, 3.0),
        ]);

        let formula = Eventually::unbounded(Const);
        let robustness = formula.evaluate(&input)?;
        let expected = Trace::from_iter([
            (0, 4.0),
            (1, 3.0),
            (2, 3.0),
            (3, 3.0),
            (4, 3.0),
        ]);

        assert_eq!(robustness, expected);
        Ok(())
    }

    // Bounded eventually over [t, t + 2]: e.g. time 1 is max(2, 1, 5) = 5.
    #[test]
    fn bounded_eventually() -> Result<(), ForwardOperatorError<ConstError>> {
        let input = Trace::from_iter([
            (0, 4.0),
            (1, 2.0),
            (2, 1.0),
            (3, 5.0),
            (4, 3.0),
        ]);

        let formula = Eventually::bounded(0.0..=2.0, Const);
        let robustness = formula.evaluate(&input)?;
        let expected = Trace::from_iter([
            (0, 4.0),
            (1, 5.0),
            (2, 5.0),
            (3, 5.0),
            (4, 3.0),
        ]);

        assert_eq!(robustness, expected);
        Ok(())
    }

    // With integer-spaced samples, the closed window [t, t + 2] and the
    // half-open window [t, t + 3) select the same samples, so both interval
    // conversions must produce the same result.
    #[test]
    fn bounds() -> Result<(), ForwardOperatorError<ConstError>> {
        let input = Trace::from_iter([
            (0, 4.0),
            (1, 2.0),
            (2, 1.0),
            (3, 5.0),
            (4, 3.0),
        ]);

        let expected = Trace::from_iter([
            (0, 4.0),
            (1, 5.0),
            (2, 5.0),
            (3, 5.0),
            (4, 3.0),
        ]);

        let f1 = Eventually::bounded(0f64..=2f64, Const);
        let f2 = Eventually::bounded(0f64..3f64, Const);

        assert_eq!(f1.evaluate(&input)?, expected);
        assert_eq!(f2.evaluate(&input)?, expected);
        Ok(())
    }

    // Next: each sample takes its successor's value; the last sample has no
    // successor and gets f64::NEG_INFINITY (presumably `Bottom` for f64).
    #[test]
    fn next() -> Result<(), ConstError> {
        let input = Trace::from_iter([
            (0, 1.0),
            (1, 2.0),
            (2, 3.0),
            (3, 4.0),
        ]);

        let formula = Next::new(Const);
        let robustness = formula.evaluate(&input)?;
        let expected = Trace::from_iter([
            (0, 2.0),
            (1, 3.0),
            (2, 4.0),
            (3, f64::NEG_INFINITY),
        ]);

        assert_eq!(robustness, expected);
        Ok(())
    }

    // Until over (left, right) pairs; e.g. time 3 is min(left 1.1, right 2.2) = 1.1.
    #[test]
    fn until() -> Result<(), BinaryOperatorError<ConstError, ConstError>> {
        let input = Trace::from_iter([
            (0.0, (3.0, -2.1)),
            (1.0, (1.5, 3.7)),
            (2.0, (1.4, 1.2)),
            (3.0, (1.1, 2.2)),
        ]);

        let formula = Until::new(ConstLeft, ConstRight);
        let robustness = formula.evaluate(&input)?;

        assert_eq!(robustness[3.0], 1.1);
        assert_eq!(robustness[2.0], 1.2);
        assert_eq!(robustness[1.0], 1.5);
        assert_eq!(robustness[0.0], 1.5);

        Ok(())
    }
}