1use std::ops::Neg;
2
3use thiserror::Error;
4
5use crate::Formula;
6use crate::metrics::{Meet, Join};
7use crate::trace::Trace;
8
/// Logical negation of a subformula.
///
/// Evaluating a `Not` evaluates the wrapped subformula and negates every
/// metric value of the resulting trace.
#[derive(Debug, Clone, PartialEq)]
pub struct Not<F> {
    /// The formula whose metric values get negated.
    subformula: F,
}

impl<F> Not<F> {
    /// Wraps `subformula` so that its metric values are negated on evaluation.
    pub fn new(subformula: F) -> Self {
        Not { subformula }
    }
}
40
41impl<T, F, M> Formula<T> for Not<F>
42where
43 F: Formula<T, Metric = M>,
44 M: Neg<Output = M>,
45{
46 type Metric = M;
47 type Error = F::Error;
48
49 fn evaluate(&self, trace: &Trace<T>) -> Result<Trace<Self::Metric>, Self::Error> {
50 self.subformula
51 .evaluate(trace)
52 .map(|result| {
53 result.into_iter().map_states(|state| state.neg()).collect()
54 })
55 }
56}
57
58#[derive(Debug, Clone, Error)]
66pub enum BinaryEvaluationError {
67 #[error("Metric traces have mismatched lengths: [{0}] [{1}]")]
68 MismatchedLengths(usize, usize),
69
70 #[error("Mismatched times between traces: [{0}] [{1}]")]
71 MismatchedTimes(f64, f64),
72}
73
74#[derive(Debug, Clone, Error)]
83pub enum BinaryOperatorError<L, R> {
84 #[error("Left suformula error: {0}")]
86 LeftError(L),
87
88 #[error("Right subformula error: {0}")]
90 RightError(R),
91
92 #[error("Error evaluation binary operator: {0}")]
93 EvaluationError(#[from] BinaryEvaluationError)
94}
95
96
97type BinOpResult<T, E1, E2> = Result<Trace<T>, BinaryOperatorError<E1, E2>>;
98
99#[derive(Debug, Clone)]
100struct Binop<Left, Right> {
101 left: Left,
102 right: Right,
103}
104
105impl<Left, Right> Binop<Left, Right> {
106 fn evaluate_left<State, Metric>(&self, trace: &Trace<State>) -> BinOpResult<Metric, Left::Error, Right::Error>
107 where
108 Left: Formula<State, Metric = Metric>,
109 Right: Formula<State>,
110 {
111 self.left.evaluate(trace).map_err(BinaryOperatorError::LeftError)
112 }
113
114 fn evaluate_right<State, Metric>(&self, trace: &Trace<State>) -> BinOpResult<Metric, Left::Error, Right::Error>
115 where
116 Left: Formula<State>,
117 Right: Formula<State, Metric = Metric>,
118 {
119 self.right.evaluate(trace).map_err(BinaryOperatorError::RightError)
120 }
121}
122
123fn binop<I1, I2, F, M>(left: I1, right: I2, f: F) -> Result<Trace<M>, BinaryEvaluationError>
124where
125 I1: ExactSizeIterator<Item = (f64, M)>,
126 I2: ExactSizeIterator<Item = (f64, M)>,
127 F: Fn(&M, &M) -> M,
128{
129 if left.len() != right.len() {
130 return Err(BinaryEvaluationError::MismatchedLengths(left.len(), right.len()));
131 }
132
133 left.zip(right)
134 .map(|((lt, ls), (rt, rs))| {
135 if lt == rt {
136 Ok((lt, f(&ls, &rs)))
137 } else {
138 Err(BinaryEvaluationError::MismatchedTimes(lt, rt))
139 }
140 })
141 .collect()
142}
143
144#[derive(Debug, Clone)]
171pub struct Or<Left, Right>(Binop<Left, Right>);
172
173impl<Left, Right> Or<Left, Right> {
174 pub fn new(left: Left, right: Right) -> Self {
175 Self(Binop { left, right })
176 }
177}
178
179impl<Left, Right, State, Metric> Formula<State> for Or<Left, Right>
180where
181 Left: Formula<State, Metric = Metric>,
182 Right: Formula<State, Metric = Metric>,
183 Metric: Join,
184{
185 type Metric = Metric;
186 type Error = BinaryOperatorError<Left::Error, Right::Error>;
187
188 fn evaluate(&self, trace: &Trace<State>) -> Result<Trace<Self::Metric>, Self::Error> {
189 let left = self.0.evaluate_left(trace)?;
190 let right = self.0.evaluate_right(trace)?;
191 let result = binop(left.into_iter(), right.into_iter(), Metric::max)?;
192
193 Ok(result)
194 }
195}
196
197#[derive(Debug, Clone)]
224pub struct And<Left, Right>(Binop<Left, Right>);
225
226impl<Left, Right> And<Left, Right> {
227 pub fn new(left: Left, right: Right) -> Self {
228 Self(Binop { left, right })
229 }
230}
231
232impl<Left, Right, State, Metric> Formula<State> for And<Left, Right>
233where
234 Left: Formula<State, Metric = Metric>,
235 Right: Formula<State, Metric = Metric>,
236 Metric: Meet,
237{
238 type Metric = Metric;
239 type Error = BinaryOperatorError<Left::Error, Right::Error>;
240
241 fn evaluate(&self, trace: &Trace<State>) -> Result<Trace<Self::Metric>, Self::Error> {
242 let left = self.0.evaluate_left(trace)?;
243 let right = self.0.evaluate_right(trace)?;
244 let result = binop(left.into_iter(), right.into_iter(), Metric::min)?;
245
246 Ok(result)
247 }
248}
249
250#[derive(Clone)]
277pub struct Implies<Ante, Cons>(Binop<Ante, Cons>);
278
279impl<Ante, Cons> Implies<Ante, Cons> {
280 pub fn new(ante: Ante, cons: Cons) -> Self {
281 Self(Binop { left: ante, right: cons })
282 }
283}
284
285impl<Ante, Cons, State, Metric> Formula<State> for Implies<Ante, Cons>
286where
287 Ante: Formula<State, Metric = Metric>,
288 Cons: Formula<State, Metric = Metric>,
289 Metric: Neg<Output = Metric> + Join,
290{
291 type Metric = Metric;
292 type Error = BinaryOperatorError<Ante::Error, Cons::Error>;
293
294 fn evaluate(&self, trace: &Trace<State>) -> Result<Trace<Self::Metric>, Self::Error> {
295 let ante = self.0
296 .evaluate_left(trace)?
297 .into_iter()
298 .map_states(|state| -state);
299
300 let cons = self.0.evaluate_right(trace)?;
301 let result = binop(ante, cons.into_iter(), |neg_a, c| Metric::max(neg_a, c))?;
302
303 Ok(result)
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::{And, Implies, Not, Or};
    use crate::operators::test::*;
    use crate::operators::BinaryOperatorError;
    use crate::trace::Trace;
    use crate::Formula;

    /// Error type returned by the binary-operator tests.
    type BinaryTestError = BinaryOperatorError<ConstError, ConstError>;

    #[test]
    fn not() -> Result<(), ConstError> {
        let formula = Not::new(Const);
        let input = Trace::from_iter([(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)]);

        let actual = formula.evaluate(&input)?;
        let expected = Trace::from_iter([(0, 0.0), (1, -1.0), (2, -2.0), (3, -3.0)]);

        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn or() -> Result<(), BinaryTestError> {
        // Each sample's result is the larger of its two components.
        let formula = Or::new(ConstLeft, ConstRight);
        let input = Trace::from_iter([
            (0, (0.0, 1.0)),
            (1, (1.0, 0.0)),
            (2, (2.0, 4.0)),
            (3, (3.0, 6.0)),
        ]);

        let actual = formula.evaluate(&input)?;
        let expected = Trace::from_iter([(0, 1.0), (1, 1.0), (2, 4.0), (3, 6.0)]);

        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn and() -> Result<(), BinaryTestError> {
        // Each sample's result is the smaller of its two components.
        let formula = And::new(ConstLeft, ConstRight);
        let input = Trace::from_iter([
            (0, (0.0, 1.0)),
            (1, (1.0, 0.0)),
            (2, (2.0, 4.0)),
            (3, (3.0, 6.0)),
        ]);

        let actual = formula.evaluate(&input)?;
        let expected = Trace::from_iter([(0, 0.0), (1, 0.0), (2, 2.0), (3, 3.0)]);

        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn implies() -> Result<(), BinaryTestError> {
        // Each sample's result is max(-first, second).
        let formula = Implies::new(ConstLeft, ConstRight);
        let input = Trace::from_iter([
            (0, (0.0, 1.0)),
            (1, (1.0, 0.0)),
            (2, (-4.0, 2.0)),
            (3, (3.0, 6.0)),
        ]);

        let actual = formula.evaluate(&input)?;
        let expected = Trace::from_iter([(0, 1.0), (1, 0.0), (2, 4.0), (3, 6.0)]);

        assert_eq!(actual, expected);
        Ok(())
    }
}