use xlog_core::{Result, XlogError};
use super::{McProgram, McSamplingMethod};
/// Why the program's evidence can (or cannot) be enforced by clamping sampled variables.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ForceabilityReason {
    /// Every evidence atom is a probabilistic fact or an annotated-disjunction choice,
    /// and all resulting clamps are mutually consistent.
    AllForceable,
    /// Some evidence atom is neither a probabilistic fact nor an annotated-disjunction
    /// choice (presumably it is derived by rules), so no variable can be clamped for it.
    ContainsDerivedEvidence,
    /// An annotated-disjunction choice is observed `false`; the clamping scheme used
    /// here only encodes positive choice evidence.
    ContainsNegativeAdHeadEvidence,
    /// The program has no evidence at all.
    NoEvidence,
    /// Two evidence entries require the same variable to take different values
    /// (e.g. an atom observed both true and false, or two choices of the same
    /// annotated disjunction both observed true).
    ConflictingEvidence,
}

/// Per-variable clamping plan derived from the program's evidence.
#[derive(Debug, Clone)]
pub struct EvidenceForcing {
    /// `1` for every Bernoulli variable that must be clamped, `0` otherwise.
    pub force_mask: Vec<u8>,
    /// Value (`0`/`1`) each clamped variable is forced to; only meaningful where
    /// `force_mask` is `1`.
    pub forced_value: Vec<u8>,
    /// Whether the evidence can be enforced purely by clamping.
    pub forceable: bool,
    /// Explanation for `forceable`.
    pub reason: ForceabilityReason,
}

impl McProgram {
    /// Picks the sampling method to use and returns it together with the evidence forcing.
    ///
    /// An explicit request is honoured as-is, except that `EvidenceClamping` is rejected
    /// when the evidence is not forceable. With no request, clamping is used whenever the
    /// evidence is forceable and rejection sampling otherwise.
    ///
    /// # Errors
    /// Returns `XlogError::Execution` when `EvidenceClamping` is requested but the evidence
    /// cannot be clamped, and propagates any error from [`Self::compile_evidence_forcing`].
    pub(super) fn resolve_sampling_method(
        &self,
        requested: Option<McSamplingMethod>,
    ) -> Result<(McSamplingMethod, EvidenceForcing)> {
        let forcing = self.compile_evidence_forcing()?;
        let method = match requested {
            Some(McSamplingMethod::EvidenceClamping) if !forcing.forceable => {
                return Err(XlogError::Execution(format!(
                    "Cannot use EvidenceClamping: {:?}",
                    forcing.reason
                )));
            }
            Some(McSamplingMethod::EvidenceClamping) => McSamplingMethod::EvidenceClamping,
            Some(McSamplingMethod::Rejection) => McSamplingMethod::Rejection,
            None if forcing.forceable => McSamplingMethod::EvidenceClamping,
            None => McSamplingMethod::Rejection,
        };
        Ok((method, forcing))
    }

    /// Translates the program's evidence into per-variable clamps.
    ///
    /// Evidence on a probabilistic fact clamps its variable to the observed value.
    /// Positive evidence on choice `k` of an annotated disjunction clamps decision
    /// variables `0..k` to `0` and decision variable `k` (if it exists) to `1`. A choice
    /// index past the last decision variable clamps all decision variables to `0`.
    /// NOTE(review): this assumes the sampler selects the first decision variable that is `1`.
    ///
    /// The result is non-forceable (all-zero masks) when there is no evidence, when an
    /// evidence atom has no backing variable, when an AD choice is observed false, or when
    /// two evidence entries demand contradictory values for the same variable.
    pub fn compile_evidence_forcing(&self) -> Result<EvidenceForcing> {
        let num_vars = self.bernoulli_probs.len();
        if self.evidence.is_empty() {
            return Ok(not_forceable(num_vars, ForceabilityReason::NoEvidence));
        }

        let mut force_mask = vec![0u8; num_vars];
        let mut forced_value = vec![0u8; num_vars];

        for (atom, expected) in &self.evidence {
            let consistent = if let Some(spec) = self.prob_facts.iter().find(|s| &s.atom == atom)
            {
                clamp_var(
                    &mut force_mask,
                    &mut forced_value,
                    spec.var_idx,
                    u8::from(*expected),
                )
            } else if let Some((ad, choice_idx)) = self
                .annotated_disjunctions
                .iter()
                .find_map(|ad| ad.choices.iter().position(|c| c == atom).map(|idx| (ad, idx)))
            {
                if !*expected {
                    return Ok(not_forceable(
                        num_vars,
                        ForceabilityReason::ContainsNegativeAdHeadEvidence,
                    ));
                }
                let decision_vars = &ad.decision_vars;
                // Every decision variable before the chosen one must decline.
                let declined = &decision_vars[..choice_idx.min(decision_vars.len())];
                let mut ok = true;
                for &dv in declined {
                    ok &= clamp_var(&mut force_mask, &mut forced_value, dv, 0);
                }
                // The last choice may have no decision variable of its own.
                if let Some(&dv) = decision_vars.get(choice_idx) {
                    ok &= clamp_var(&mut force_mask, &mut forced_value, dv, 1);
                }
                ok
            } else {
                return Ok(not_forceable(
                    num_vars,
                    ForceabilityReason::ContainsDerivedEvidence,
                ));
            };

            // Previously a later clamp silently overwrote an earlier one, so clamped
            // samples could violate part of the evidence. Fall back to rejection instead.
            if !consistent {
                return Ok(not_forceable(
                    num_vars,
                    ForceabilityReason::ConflictingEvidence,
                ));
            }
        }

        Ok(EvidenceForcing {
            force_mask,
            forced_value,
            forceable: true,
            reason: ForceabilityReason::AllForceable,
        })
    }
}

/// Builds a non-forceable result with all-zero masks sized for `num_vars` variables.
fn not_forceable(num_vars: usize, reason: ForceabilityReason) -> EvidenceForcing {
    EvidenceForcing {
        force_mask: vec![0u8; num_vars],
        forced_value: vec![0u8; num_vars],
        forceable: false,
        reason,
    }
}

/// Clamps variable `var` to `value`.
///
/// Returns `false`, leaving the buffers untouched, if `var` is already clamped to a
/// different value; re-clamping to the same value is accepted.
fn clamp_var(force_mask: &mut [u8], forced_value: &mut [u8], var: usize, value: u8) -> bool {
    if force_mask[var] != 0 && forced_value[var] != value {
        return false;
    }
    force_mask[var] = 1;
    forced_value[var] = value;
    true
}