logic/expressions/
predicate.rs1use std::collections::HashMap;
2
3use crate::trace::Trace;
4use crate::{Formula, Metric};
5
/// Right-hand side of a [`Predicate`]'s inequality.
///
/// The bound is either a fixed number or the name of a variable whose value is
/// looked up in the same valuation the predicate is evaluated against.
#[derive(Clone, Debug, PartialEq)]
pub enum Bound {
    /// A fixed numeric bound.
    Constant(f64),
    /// A bound read from the valuation under this variable name.
    Variable(String),
}

/// A linear inequality over named variables:
/// `sum over v of coefficient_map[v] * value[v] <= bound`.
///
/// Evaluation fails with a missing-variable error when a variable named in
/// `coefficient_map`, or by a [`Bound::Variable`], is absent from the valuation.
#[derive(Clone, Debug, PartialEq)]
pub struct Predicate {
    /// Weight applied to each named variable in the left-hand sum.
    coefficient_map: HashMap<String, f64>,
    /// Right-hand side of the inequality.
    bound: Bound,
}
17
18impl Predicate {
19 pub fn new(coefficient_map: HashMap<String, f64>, bound: Bound) -> Predicate {
20 Predicate { coefficient_map, bound }
21 }
22
23 fn weighted_sum(&self, value_map: &HashMap<String, f64>) -> Result<f64, super::error::Error> {
24 self.coefficient_map
25 .iter()
26 .map(|(name, weight)| {
27 let value = value_map
28 .get(name)
29 .ok_or_else(|| super::error::Error::MissingVariable(name.clone()))?;
30 Ok(value * weight)
31 })
32 .sum()
33 }
34}
35
36impl Formula<HashMap<String, f64>> for Predicate {
37 type Error = super::error::Error;
38
39 fn satisfied_by(&self, value_map: &HashMap<String, f64>) -> Result<bool, Self::Error> {
40 let sum = self.weighted_sum(value_map)?;
41 let bound_value = match &self.bound {
42 Bound::Constant(value) => value,
43 Bound::Variable(name) => value_map
44 .get(name)
45 .ok_or_else(|| super::error::Error::MissingVariable(name.clone()))?,
46 };
47
48 Ok(sum <= *bound_value)
49 }
50}
51
52impl Formula<Trace<HashMap<String, f64>>> for Predicate {
53 type Error = super::error::TimedError;
54
55 fn satisfied_by(&self, trace: &Trace<HashMap<String, f64>>) -> Result<bool, Self::Error> {
56 let first_state = trace.first_state().ok_or(super::error::TimedError::EmptyTrace)?;
57
58 self.satisfied_by(first_state.state)
59 .map_err(|error| super::error::TimedError::ValuationError(first_state.time, error))
60 }
61}
62
63impl Metric<HashMap<String, f64>> for Predicate {
64 type Error = super::error::Error;
65
66 fn distance(&self, value_map: &HashMap<String, f64>) -> Result<f64, Self::Error> {
67 let sum = self.weighted_sum(value_map)?;
73 let bound_value = match &self.bound {
74 Bound::Constant(value) => value,
75 Bound::Variable(name) => value_map
76 .get(name)
77 .ok_or_else(|| super::error::Error::MissingVariable(name.clone()))?,
78 };
79
80 Ok(*bound_value - sum)
81 }
82}
83
84impl Metric<Trace<HashMap<String, f64>>> for Predicate {
85 type Error = super::error::TimedError;
86
87 fn distance(&self, trace: &Trace<HashMap<String, f64>>) -> Result<f64, Self::Error> {
88 let first_state = trace.first_state().ok_or(super::error::TimedError::EmptyTrace)?;
89
90 self.distance(first_state.state)
91 .map_err(|error| super::error::TimedError::ValuationError(first_state.time, error))
92 }
93}