// File: logic/expressions/predicate.rs

1use std::collections::HashMap;
2
3use crate::trace::Trace;
4use crate::{Formula, Metric};
5
/// Right-hand side of a [`Predicate`]'s `weighted_sum <= bound` comparison.
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// A fixed numeric bound.
    Constant(f64),
    /// The name of a variable whose value is looked up in the state being evaluated.
    Variable(String),
}

/// A linear inequality over named variables: `sum(coefficient * value) <= bound`.
///
/// `coefficient_map` maps each variable name to its coefficient; `bound` is either a constant
/// or another variable of the same state.
#[derive(Debug, Clone, PartialEq)]
pub struct Predicate {
    coefficient_map: HashMap<String, f64>,
    bound: Bound,
}
17
18impl Predicate {
19    pub fn new(coefficient_map: HashMap<String, f64>, bound: Bound) -> Predicate {
20        Predicate { coefficient_map, bound }
21    }
22
23    fn weighted_sum(&self, value_map: &HashMap<String, f64>) -> Result<f64, super::error::Error> {
24        self.coefficient_map
25            .iter()
26            .map(|(name, weight)| {
27                let value = value_map
28                    .get(name)
29                    .ok_or_else(|| super::error::Error::MissingVariable(name.clone()))?;
30                Ok(value * weight)
31            })
32            .sum()
33    }
34}
35
36impl Formula<HashMap<String, f64>> for Predicate {
37    type Error = super::error::Error;
38
39    fn satisfied_by(&self, value_map: &HashMap<String, f64>) -> Result<bool, Self::Error> {
40        let sum = self.weighted_sum(value_map)?;
41        let bound_value = match &self.bound {
42            Bound::Constant(value) => value,
43            Bound::Variable(name) => value_map
44                .get(name)
45                .ok_or_else(|| super::error::Error::MissingVariable(name.clone()))?,
46        };
47
48        Ok(sum <= *bound_value)
49    }
50}
51
52impl Formula<Trace<HashMap<String, f64>>> for Predicate {
53    type Error = super::error::TimedError;
54
55    fn satisfied_by(&self, trace: &Trace<HashMap<String, f64>>) -> Result<bool, Self::Error> {
56        let first_state = trace.first_state().ok_or(super::error::TimedError::EmptyTrace)?;
57
58        self.satisfied_by(first_state.state)
59            .map_err(|error| super::error::TimedError::ValuationError(first_state.time, error))
60    }
61}
62
63impl Metric<HashMap<String, f64>> for Predicate {
64    type Error = super::error::Error;
65
66    fn distance(&self, value_map: &HashMap<String, f64>) -> Result<f64, Self::Error> {
67        // The distance of a predicate is the distance of the named value from the specified bound.
68        // The distance is negative if the value is outside the bound (falsifying) and positive if the bound is inside
69        // the bound (nonfalsifying).
70        // For the equality, the distance is positive infinity if the value and bound are equal, and negative otherwise
71
72        let sum = self.weighted_sum(value_map)?;
73        let bound_value = match &self.bound {
74            Bound::Constant(value) => value,
75            Bound::Variable(name) => value_map
76                .get(name)
77                .ok_or_else(|| super::error::Error::MissingVariable(name.clone()))?,
78        };
79
80        Ok(*bound_value - sum)
81    }
82}
83
84impl Metric<Trace<HashMap<String, f64>>> for Predicate {
85    type Error = super::error::TimedError;
86
87    fn distance(&self, trace: &Trace<HashMap<String, f64>>) -> Result<f64, Self::Error> {
88        let first_state = trace.first_state().ok_or(super::error::TimedError::EmptyTrace)?;
89
90        self.distance(first_state.state)
91            .map_err(|error| super::error::TimedError::ValuationError(first_state.time, error))
92    }
93}