1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use std::collections::HashMap;

use crate::trace::Trace;
use crate::{Formula, Metric};

/// Right-hand side of a linear [`Predicate`] inequality.
#[derive(Clone, Debug)]
pub enum Bound {
    /// A fixed numeric bound.
    Constant(f64),
    /// The name of a variable whose value is looked up in the valuation at evaluation time.
    Variable(String),
}

/// A linear inequality `sum(coefficient * variable) <= bound` over named `f64` variables.
///
/// `coefficient_map` maps each variable name to its coefficient; `bound` is the right-hand side.
#[derive(Clone, Debug)]
pub struct Predicate {
    coefficient_map: HashMap<String, f64>,
    bound: Bound,
}

impl Predicate {
    pub fn new(coefficient_map: HashMap<String, f64>, bound: Bound) -> Predicate {
        Predicate { coefficient_map, bound }
    }

    fn weighted_sum(&self, value_map: &HashMap<String, f64>) -> Result<f64, super::error::Error> {
        self.coefficient_map
            .iter()
            .map(|(name, weight)| {
                let value = value_map
                    .get(name)
                    .ok_or_else(|| super::error::Error::MissingVariable(name.clone()))?;
                Ok(value * weight)
            })
            .sum()
    }
}

impl Formula<HashMap<String, f64>> for Predicate {
    type Error = super::error::Error;

    fn satisfied_by(&self, value_map: &HashMap<String, f64>) -> Result<bool, Self::Error> {
        let sum = self.weighted_sum(value_map)?;
        let bound_value = match &self.bound {
            Bound::Constant(value) => value,
            Bound::Variable(name) => value_map
                .get(name)
                .ok_or_else(|| super::error::Error::MissingVariable(name.clone()))?,
        };

        Ok(sum <= *bound_value)
    }
}

impl Formula<Trace<HashMap<String, f64>>> for Predicate {
    type Error = super::error::TimedError;

    fn satisfied_by(&self, trace: &Trace<HashMap<String, f64>>) -> Result<bool, Self::Error> {
        let first_state = trace.first_state().ok_or(super::error::TimedError::EmptyTrace)?;

        self.satisfied_by(first_state.state)
            .map_err(|error| super::error::TimedError::ValuationError(first_state.time, error))
    }
}

impl Metric<HashMap<String, f64>> for Predicate {
    type Error = super::error::Error;

    fn distance(&self, value_map: &HashMap<String, f64>) -> Result<f64, Self::Error> {
        // The distance of a predicate is the distance of the named value from the specified bound.
        // The distance is negative if the value is outside the bound (falsifying) and positive if the bound is inside
        // the bound (nonfalsifying).
        // For the equality, the distance is positive infinity if the value and bound are equal, and negative otherwise

        let sum = self.weighted_sum(value_map)?;
        let bound_value = match &self.bound {
            Bound::Constant(value) => value,
            Bound::Variable(name) => value_map
                .get(name)
                .ok_or_else(|| super::error::Error::MissingVariable(name.clone()))?,
        };

        Ok(*bound_value - sum)
    }
}

impl Metric<Trace<HashMap<String, f64>>> for Predicate {
    type Error = super::error::TimedError;

    fn distance(&self, trace: &Trace<HashMap<String, f64>>) -> Result<f64, Self::Error> {
        let first_state = trace.first_state().ok_or(super::error::TimedError::EmptyTrace)?;

        self.distance(first_state.state)
            .map_err(|error| super::error::TimedError::ValuationError(first_state.time, error))
    }
}