use std::collections::HashMap;
use crate::errors::CausalityError;
use crate::prelude::{Identifiable, IdentificationValue, NumericalValue};
/// A single cause that can be explained and verified against observed data.
///
/// Every implementor must also be [`Identifiable`].
pub trait Causable: Identifiable {
    /// Produces a human-readable explanation of this cause.
    ///
    /// # Errors
    ///
    /// Returns a [`CausalityError`] when no explanation can be produced.
    fn explain(&self) -> Result<String, CausalityError>;

    /// Reports whether this cause is currently active.
    fn is_active(&self) -> bool;

    /// Reports whether this cause is a singleton.
    ///
    /// A singleton cause is expected to be checked with
    /// [`Causable::verify_single_cause`] against one observation. A
    /// non-singleton cause is expected to be checked with
    /// [`Causable::verify_all_causes`] against a whole data slice.
    fn is_singleton(&self) -> bool;

    /// Verifies this cause against a single observation `obs`.
    ///
    /// `Ok(true)` means the cause is considered satisfied by `obs`.
    fn verify_single_cause(&self, obs: &NumericalValue) -> Result<bool, CausalityError>;

    /// Verifies this cause against the whole `data` slice.
    ///
    /// `data_index` is an optional identifier-to-identifier map.
    /// NOTE(review): presumably it maps cause ids to positions in `data`;
    /// confirm against implementors.
    fn verify_all_causes(
        &self,
        data: &[NumericalValue],
        data_index: Option<&HashMap<IdentificationValue, IdentificationValue>>,
    ) -> Result<bool, CausalityError>;
}
/// Aggregate reasoning over a collection of [`Causable`] items.
///
/// Implementors supply the four accessor methods (`len`, `is_empty`,
/// `to_vec`, `get_all_items`). Every other method has a default
/// implementation built on top of those accessors.
pub trait CausableReasoning<T>
where
    T: Causable,
{
    /// Returns the number of causes in the collection.
    fn len(&self) -> usize;

    /// Returns `true` if the collection contains no causes.
    fn is_empty(&self) -> bool;

    /// Returns the causes as an owned vector.
    fn to_vec(&self) -> Vec<T>;

    /// Returns borrowed references to the causes in the collection.
    ///
    /// The default methods of this trait iterate over this list. Its order
    /// determines which element of `data` a singleton cause is checked
    /// against in [`CausableReasoning::reason_all_causes`].
    fn get_all_items(&self) -> Vec<&T>;

    /// Returns `true` if every cause reports [`Causable::is_active`].
    ///
    /// An empty collection returns `true`.
    fn get_all_causes_true(&self) -> bool {
        self.get_all_items().iter().all(|cause| cause.is_active())
    }

    /// Returns references to all causes that are currently active.
    fn get_all_active_causes(&self) -> Vec<&T> {
        self.get_all_items()
            .into_iter()
            .filter(|cause| cause.is_active())
            .collect()
    }

    /// Returns references to all causes that are currently inactive.
    fn get_all_inactive_causes(&self) -> Vec<&T> {
        self.get_all_items()
            .into_iter()
            .filter(|cause| !cause.is_active())
            .collect()
    }

    /// Returns the number of active causes as a [`NumericalValue`].
    fn number_active(&self) -> NumericalValue {
        self.get_all_items()
            .iter()
            .filter(|cause| cause.is_active())
            .count() as NumericalValue
    }

    /// Returns the number of active causes as a percentage of `len()`.
    ///
    /// An empty collection returns `0` instead of dividing by zero. For a
    /// floating-point `NumericalValue`, dividing by zero would yield NaN.
    fn percent_active(&self) -> NumericalValue {
        let total = self.len();
        if total == 0 {
            return 0 as NumericalValue;
        }
        (self.number_active() / total as NumericalValue) * (100 as NumericalValue)
    }

    /// Verifies every cause in the collection against `data`.
    ///
    /// Causes are visited in `get_all_items` order:
    ///
    /// - A singleton cause at position `i` is verified against `data[i]`
    ///   via [`Causable::verify_single_cause`].
    /// - Any other cause is verified against the whole slice via
    ///   [`Causable::verify_all_causes`], with no data index.
    ///
    /// Evaluation stops and returns `Ok(false)` at the first cause that
    /// does not hold. It returns `Ok(true)` if all causes hold.
    ///
    /// # Errors
    ///
    /// Returns a [`CausalityError`] in any of these cases:
    ///
    /// - the collection is empty;
    /// - `data` has no element at a singleton cause's position;
    /// - a cause's own verification returns an error.
    fn reason_all_causes(&self, data: &[NumericalValue]) -> Result<bool, CausalityError> {
        if self.is_empty() {
            return Err(CausalityError("Causality collection is empty".into()));
        }

        for (i, cause) in self.get_all_items().into_iter().enumerate() {
            let holds = if cause.is_singleton() {
                // A `data` slice shorter than the cause list is a caller-input
                // problem, so report it as an error rather than panicking.
                let obs = data.get(i).ok_or_else(|| {
                    CausalityError("Data slice has no value for singleton cause".into())
                })?;
                cause.verify_single_cause(obs)?
            } else {
                cause.verify_all_causes(data, None)?
            };

            if !holds {
                return Ok(false);
            }
        }

        Ok(true)
    }

    /// Builds a bullet list with one explanation per cause.
    ///
    /// Each entry is formatted as `"\n * <explanation>\n"`. If a cause's
    /// [`Causable::explain`] returns an error, the entry contains that
    /// error's `Debug` output instead, so this method does not panic.
    fn explain(&self) -> String {
        let mut explanation = String::new();
        for cause in self.get_all_items() {
            let entry = match cause.explain() {
                Ok(text) => format!(" * {}", text),
                Err(e) => format!(" * <explanation unavailable: {:?}>", e),
            };
            explanation.push('\n');
            explanation.push_str(&entry);
            explanation.push('\n');
        }
        explanation
    }
}