use crate::scalar::Scalar;
/// Records, for one predicate, a satisfaction flag per equation and per
/// inequality, together with an overflow flag.
///
/// When the `diagnostics` feature is enabled, the scalar slices passed to
/// [`Predicate::new`] are stored alongside the flags.
#[derive(Clone, Debug)]
pub struct Predicate {
/// Class index supplied to [`Predicate::new`].
/// NOTE(review): what a "class" denotes is not visible in this file — confirm with callers.
pub class_index: usize,
/// Copy of the `eq_residuals` slice given to [`Predicate::new`] (`diagnostics` builds only).
#[cfg(feature = "diagnostics")]
pub equation_residuals: Vec<Scalar>,
/// Copy of the `ineq_values` slice given to [`Predicate::new`] (`diagnostics` builds only).
#[cfg(feature = "diagnostics")]
pub inequality_values: Vec<Scalar>,
/// One flag per equation; `true` marks that equation as satisfied.
pub equations_satisfied: Vec<bool>,
/// One flag per inequality; `true` marks that inequality as satisfied.
pub inequalities_satisfied: Vec<bool>,
/// Overflow flag, returned unchanged by [`Predicate::has_overflow`].
/// NOTE(review): presumably set when an overflow forced promotion during evaluation — confirm.
pub overflow_promoted: bool,
}
impl Predicate {
#[allow(unused_variables)]
pub fn new(
class_index: usize,
eq_residuals: &[Scalar],
ineq_values: &[Scalar],
equations_satisfied: Vec<bool>,
inequalities_satisfied: Vec<bool>,
overflow_promoted: bool,
) -> Self {
Predicate {
class_index,
#[cfg(feature = "diagnostics")]
equation_residuals: eq_residuals.to_vec(),
#[cfg(feature = "diagnostics")]
inequality_values: ineq_values.to_vec(),
equations_satisfied,
inequalities_satisfied,
overflow_promoted,
}
}
pub fn is_satisfied(&self) -> bool {
self.equations_satisfied.iter().all(|&s| s)
&& self.inequalities_satisfied.iter().all(|&s| s)
}
pub fn has_overflow(&self) -> bool {
self.overflow_promoted
}
pub fn failing_indices(&self) -> Vec<usize> {
let mut failing = Vec::new();
for (i, &s) in self.equations_satisfied.iter().enumerate() {
if !s {
failing.push(i);
}
}
for (i, &s) in self.inequalities_satisfied.iter().enumerate() {
if !s {
failing.push(self.equations_satisfied.len() + i);
}
}
failing
}
#[cfg(feature = "diagnostics")]
pub fn residuals(&self) -> Vec<&Scalar> {
self.equation_residuals
.iter()
.chain(self.inequality_values.iter())
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All flags true: the predicate is satisfied, nothing is reported as
    /// failing, and no overflow is flagged.
    #[test]
    fn satisfied_predicate() {
        let p = Predicate::new(
            0,
            &[Scalar::from(0i64)],
            &[Scalar::from(5i64)],
            vec![true],
            vec![true],
            false,
        );
        assert!(p.is_satisfied());
        assert!(p.failing_indices().is_empty());
        assert!(!p.has_overflow());
    }

    /// A single failing equation is reported at its own index.
    #[test]
    fn failed_equation() {
        let p = Predicate::new(0, &[Scalar::from(3i64)], &[], vec![false], vec![], false);
        assert!(!p.is_satisfied());
        assert_eq!(p.failing_indices(), vec![0]);
    }

    /// A failing inequality is reported after the equations.
    #[test]
    fn failed_inequality() {
        let p = Predicate::new(
            0,
            &[Scalar::from(0i64)],
            &[Scalar::from(0i64)],
            vec![true],
            vec![false],
            false,
        );
        assert!(!p.is_satisfied());
        // Inequality indices are offset by the number of equations, so with
        // one equation the first inequality is index 1.
        assert_eq!(p.failing_indices(), vec![1]);
    }

    /// Failures on both sides share one combined numbering.
    #[test]
    fn mixed_failures_report_combined_indices() {
        let p = Predicate::new(
            0,
            &[Scalar::from(1i64), Scalar::from(0i64)],
            &[Scalar::from(2i64), Scalar::from(0i64)],
            vec![false, true],
            vec![true, false],
            false,
        );
        assert!(!p.is_satisfied());
        assert_eq!(p.failing_indices(), vec![0, 3]);
    }

    /// The overflow flag is independent of whether the predicate is satisfied.
    #[test]
    fn overflow_tracking() {
        let p = Predicate::new(0, &[Scalar::from(0i64)], &[], vec![true], vec![], true);
        assert!(p.is_satisfied());
        assert!(p.has_overflow());
    }
}