#![cfg(feature = "host-io")]
use xlog_cuda::CudaDevice;
use xlog_prob::mc::{ForceabilityReason, McEvalConfig, McProgram, McSamplingMethod};
use xlog_prob::provenance::Value;
/// Reports whether a CUDA device can be opened, so GPU-backed tests can skip
/// themselves on machines without one.
///
/// Returns `true` exactly when `CudaDevice::new(0)` yields `Ok`.
fn has_cuda_device() -> bool {
    let probe = CudaDevice::new(0);
    probe.is_ok()
}
/// Extracts a textual name from a provenance value.
///
/// `Value::Symbol` ids are resolved through `xlog_core::symbol::resolve`,
/// `Value::String` payloads are cloned, and every other variant yields `None`.
fn value_as_symbol_name(v: &Value) -> Option<String> {
    if let Value::Symbol(id) = v {
        return Some(xlog_core::symbol::resolve(*id));
    }
    if let Value::String(text) = v {
        return Some(text.clone());
    }
    None
}
/// Looks up the estimated probability of the zero-argument query atom named
/// `predicate` in `result.query_estimates`.
///
/// # Panics
///
/// Panics with `missing query for <predicate>()` when no estimate matches both
/// the predicate name and the empty argument list.
fn prob_of_atom(result: &xlog_prob::mc::McResult, predicate: &str) -> f64 {
    for estimate in result.query_estimates.iter() {
        let atom = &estimate.atom;
        if atom.predicate == predicate && atom.args.is_empty() {
            return estimate.prob;
        }
    }
    panic!("missing query for {}()", predicate)
}
/// Builds the evaluation config shared by the tests in this file.
///
/// Starts from `McEvalConfig::default()` and overrides the sample count, RNG
/// seed, non-monotone iteration cap and sampling method with the arguments;
/// `confidence` is always pinned to `0.95`. Any field not assigned here keeps
/// its default value.
fn mc_config(
    samples: usize,
    seed: u64,
    max_nonmonotone_iterations: usize,
    sampling_method: Option<McSamplingMethod>,
) -> McEvalConfig {
    let mut cfg = McEvalConfig::default();
    cfg.sampling_method = sampling_method;
    cfg.max_nonmonotone_iterations = max_nonmonotone_iterations;
    cfg.confidence = 0.95;
    cfg.seed = seed;
    cfg.samples = samples;
    cfg
}
/// A lone probabilistic fact `0.7::rain()` with no evidence: the estimated
/// marginal must be within 0.06 of 0.7, and `evidence_samples` must equal
/// `total_samples`.
#[test]
fn test_mc_probabilistic_fact_marginal_is_reasonable() {
    if !has_cuda_device() {
        eprintln!("Skipping test: no CUDA device available");
        return;
    }

    let model_text = r#"
0.7::rain().
query(rain()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(5_000, 123, 128, None)).unwrap();

    let p = prob_of_atom(&result, "rain");
    let deviation = (p - 0.7).abs();
    assert!(deviation < 0.06, "p={}", p);

    // The program declares no evidence, so the two counters must agree.
    assert_eq!(result.evidence_samples, result.total_samples);
}
/// Conditions on the derived atom `wet()` (true when it rains or the sprinkler
/// runs) and compares the sampled posteriors of `rain()` and `sprinkler()`
/// against the closed-form values `prior / P(wet)`, with tolerance 0.04.
#[test]
fn test_mc_wet_conditioning_close_to_exact() {
    if !has_cuda_device() {
        eprintln!("Skipping test: no CUDA device available");
        return;
    }

    let model_text = r#"
0.7::rain().
0.2::sprinkler().
wet() :- rain().
wet() :- sprinkler().
evidence(wet(), true).
query(rain()).
query(sprinkler()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(20_000, 7, 128, None)).unwrap();

    // P(wet) = 1 - P(no rain) * P(no sprinkler); each cause implies wet, so
    // its posterior is simply its prior divided by P(wet).
    let p_wet = 1.0 - (1.0 - 0.7) * (1.0 - 0.2);
    let expected_rain = 0.7 / p_wet;
    let expected_sprinkler = 0.2 / p_wet;

    let got_rain = prob_of_atom(&result, "rain");
    let got_sprinkler = prob_of_atom(&result, "sprinkler");

    let rain_error = (got_rain - expected_rain).abs();
    assert!(rain_error < 0.04, "got_rain={}", got_rain);

    let sprinkler_error = (got_sprinkler - expected_sprinkler).abs();
    assert!(sprinkler_error < 0.04, "got_sprinkler={}", got_sprinkler);

    assert!(result.evidence_samples > 0);
}
/// Evaluates a program whose rules negate each other (`q :- not p`,
/// `p :- not q`) alongside a fair `flip()`. Both `flip()` and `p()` must be
/// estimated within 0.08 of 0.5, and the result must report a positive
/// `nonmonotone_sccs` count.
#[test]
fn test_mc_nonmonotone_recursion_runs_and_is_stable() {
    if !has_cuda_device() {
        eprintln!("Skipping test: no CUDA device available");
        return;
    }

    let model_text = r#"
0.5::flip().
p() :- flip().
q() :- not p().
p() :- not q().
query(p()).
query(flip()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(2_000, 999, 128, None)).unwrap();

    let p_flip = prob_of_atom(&result, "flip");
    let p_p = prob_of_atom(&result, "p");

    let flip_error = (p_flip - 0.5).abs();
    assert!(flip_error < 0.08, "p_flip={}", p_flip);

    let p_error = (p_p - 0.5).abs();
    assert!(p_error < 0.08, "p_p={}", p_p);

    assert!(result.nonmonotone_sccs > 0);
}
/// With the annotated disjunction `0.3::coin(1); 0.3::coin(2)` and evidence
/// that `coin(1)` is true, the estimate for `coin(2)` must be exactly 0.0 and
/// at least one sample must satisfy the evidence.
#[test]
fn test_mc_annotated_disjunction_is_exclusive_under_evidence() {
    if !has_cuda_device() {
        eprintln!("Skipping test: no CUDA device available");
        return;
    }

    let model_text = r#"
0.3::coin(1); 0.3::coin(2).
evidence(coin(1), true).
query(coin(2)).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(2_000, 2026, 128, None)).unwrap();

    // `prob_of_atom` only matches zero-argument atoms, so look up the
    // one-argument query `coin(2)` by hand.
    let coin2_estimate = result
        .query_estimates
        .iter()
        .find(|q| {
            let atom = &q.atom;
            atom.predicate == "coin"
                && atom.args.len() == 1
                && atom.args[0] == xlog_prob::provenance::Value::I64(2)
        })
        .unwrap_or_else(|| panic!("missing query for coin(2)"));
    let p_coin2 = coin2_estimate.prob;

    assert_eq!(p_coin2, 0.0);
    assert!(result.evidence_samples > 0);
}
/// Positive evidence on a plain probabilistic fact must be forceable: entry 0
/// of the plan is masked and forced to 1, while entry 1 stays unmasked.
///
/// NOTE(review): entries 0 and 1 presumably correspond to `rain()` and
/// `sprinkler()` in declaration order — confirm against the forcing compiler.
#[test]
fn test_evidence_forcing_prob_fact_true() {
    let model_text = r#"
0.3::rain().
0.7::sprinkler().
evidence(rain(), true).
query(sprinkler()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let forcing = program.compile_evidence_forcing().unwrap();

    assert!(forcing.forceable);
    assert_eq!(forcing.reason, ForceabilityReason::AllForceable);

    let mask = &forcing.force_mask;
    let values = &forcing.forced_value;
    assert_eq!(mask[0], 1);
    assert_eq!(values[0], 1);
    assert_eq!(mask[1], 0);
}
/// Negative evidence on a plain probabilistic fact must be forceable: entry 0
/// of the plan is masked and forced to 0.
#[test]
fn test_evidence_forcing_prob_fact_false() {
    let model_text = r#"
0.3::rain().
evidence(rain(), false).
query(rain()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let forcing = program.compile_evidence_forcing().unwrap();

    assert!(forcing.forceable);
    assert_eq!(forcing.reason, ForceabilityReason::AllForceable);

    let mask = &forcing.force_mask;
    let values = &forcing.forced_value;
    assert_eq!(mask[0], 1);
    assert_eq!(values[0], 0);
}
/// Evidence on the rule-derived atom `wet()` must be reported as not
/// forceable, with reason `ContainsDerivedEvidence`.
#[test]
fn test_evidence_forcing_derived_atom_not_forceable() {
    let model_text = r#"
0.3::rain().
wet() :- rain().
evidence(wet(), true).
query(rain()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let forcing = program.compile_evidence_forcing().unwrap();

    assert!(!forcing.forceable);
    assert_eq!(forcing.reason, ForceabilityReason::ContainsDerivedEvidence);
}
/// Positive evidence on the middle head of a three-way annotated disjunction
/// must be forceable. The plan must force entry 0 to 0 and entry 1 to 1, and
/// leave entry 2 unmasked.
///
/// NOTE(review): the mapping from plan entries to the disjunction's internal
/// choice variables is not visible here — confirm against the forcing compiler.
#[test]
fn test_evidence_forcing_ad_3way_middle_head() {
    let model_text = r#"
0.2::color(red); 0.3::color(blue); 0.4::color(green).
evidence(color(blue), true).
query(color(red)).
query(color(green)).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let forcing = program.compile_evidence_forcing().unwrap();

    assert!(
        forcing.forceable,
        "3-way AD positive evidence should be forceable"
    );

    let mask = &forcing.force_mask;
    let values = &forcing.forced_value;
    assert_eq!(mask[0], 1);
    assert_eq!(values[0], 0);
    assert_eq!(mask[1], 1);
    assert_eq!(values[1], 1);
    assert_eq!(mask[2], 0);
}
/// Positive evidence on the last head of a two-way annotated disjunction
/// whose head probabilities sum to 1.0 (0.4 + 0.6) must be forceable, with
/// entry 0 of the plan masked and forced to 0.
#[test]
fn test_evidence_forcing_ad_last_head_no_none() {
    let model_text = r#"
0.4::coin(heads); 0.6::coin(tails).
evidence(coin(tails), true).
query(coin(heads)).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let forcing = program.compile_evidence_forcing().unwrap();

    assert!(forcing.forceable);

    let mask = &forcing.force_mask;
    let values = &forcing.forced_value;
    assert_eq!(mask[0], 1);
    assert_eq!(values[0], 0);
}
/// With positive evidence on the independent fact `sprinkler()` and no
/// explicit sampling method, the evaluator must report `EvidenceClamping`,
/// keep every sample, and estimate `rain()` within 0.06 of its prior 0.7.
#[test]
fn test_evidence_clamping_prob_fact_true_matches_exact() {
    if !has_cuda_device() {
        eprintln!("Skipping test: no CUDA device available");
        return;
    }

    let model_text = r#"
0.7::rain().
0.2::sprinkler().
evidence(sprinkler(), true).
query(rain()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(5_000, 42, 128, None)).unwrap();

    assert_eq!(result.sampling_method, McSamplingMethod::EvidenceClamping);
    assert_eq!(result.evidence_samples, result.total_samples);

    let p = prob_of_atom(&result, "rain");
    let deviation = (p - 0.7).abs();
    assert!(deviation < 0.06, "p={}", p);
}
/// With negative evidence on the independent fact `sprinkler()` and no
/// explicit sampling method, the evaluator must report `EvidenceClamping`,
/// keep every sample, and estimate `rain()` within 0.06 of its prior 0.7.
#[test]
fn test_evidence_clamping_prob_fact_false_matches_exact() {
    if !has_cuda_device() {
        eprintln!("Skipping test: no CUDA device available");
        return;
    }

    let model_text = r#"
0.7::rain().
0.2::sprinkler().
evidence(sprinkler(), false).
query(rain()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(5_000, 42, 128, None)).unwrap();

    assert_eq!(result.sampling_method, McSamplingMethod::EvidenceClamping);
    assert_eq!(result.evidence_samples, result.total_samples);

    let p = prob_of_atom(&result, "rain");
    let deviation = (p - 0.7).abs();
    assert!(deviation < 0.06, "p={}", p);
}
/// Evidence on a fact with prior 0.01 (`rare()`): the evaluator must report
/// `EvidenceClamping`, all 1000 requested samples must count as evidence
/// samples, and `other()` must be estimated within 0.05 of 0.5.
#[test]
fn test_evidence_clamping_all_samples_count() {
    if !has_cuda_device() {
        eprintln!("Skipping test: no CUDA device available");
        return;
    }

    let model_text = r#"
0.01::rare().
0.5::other().
evidence(rare(), true).
query(other()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(1000, 7, 128, None)).unwrap();

    assert_eq!(result.sampling_method, McSamplingMethod::EvidenceClamping);
    assert_eq!(result.evidence_samples, 1000);

    let p = prob_of_atom(&result, "other");
    let deviation = (p - 0.5).abs();
    assert!(deviation < 0.05, "p={}", p);
}
/// Evidence on the derived atom `wet()` with no explicit sampling method: the
/// evaluator must report `Rejection`. Because `wet()` is only derivable from
/// `rain()`, the estimate for `rain()` must be within 0.05 of 1.0.
#[test]
fn test_evidence_clamping_derived_evidence_falls_back() {
    if !has_cuda_device() {
        eprintln!("Skipping: no CUDA device");
        return;
    }

    let model_text = r#"
0.3::rain().
wet() :- rain().
evidence(wet(), true).
query(rain()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(5_000, 7, 128, None)).unwrap();

    assert_eq!(result.sampling_method, McSamplingMethod::Rejection);

    let p = prob_of_atom(&result, "rain");
    let deviation = (p - 1.0).abs();
    assert!(deviation < 0.05, "p={}", p);
}
/// Negative evidence on a head of an annotated disjunction, with no explicit
/// sampling method: the evaluator must report `Rejection`.
#[test]
fn test_evidence_clamping_negative_ad_falls_back() {
    if !has_cuda_device() {
        eprintln!("Skipping: no CUDA device");
        return;
    }

    let model_text = r#"
0.3::coin(1); 0.3::coin(2).
evidence(coin(1), false).
query(coin(2)).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(2_000, 2026, 128, None)).unwrap();

    assert_eq!(result.sampling_method, McSamplingMethod::Rejection);
}
/// Explicitly requesting `EvidenceClamping` for a program whose evidence is on
/// a derived atom must make `evaluate` return an error whose text mentions
/// either `EvidenceClamping` or `forceable`.
///
/// NOTE(review): unlike the other `evaluate` tests this one has no CUDA guard;
/// presumably the error is raised before any device is needed — confirm.
#[test]
fn test_explicit_clamping_unforceable_evidence_errors() {
    let model_text = r#"
0.3::rain().
wet() :- rain().
evidence(wet(), true).
query(rain()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let requested = Some(McSamplingMethod::EvidenceClamping);
    let result = program.evaluate(mc_config(1000, 7, 128, requested));

    assert!(result.is_err());

    let err = result.unwrap_err().to_string();
    let mentions_clamping = err.contains("EvidenceClamping") || err.contains("forceable");
    assert!(mentions_clamping, "Error should mention clamping: {}", err);
}
/// Checks the `sampling_method` recorded on the result when none is requested:
/// a program without evidence must report `Rejection`, and a program with
/// evidence on a plain probabilistic fact must report `EvidenceClamping`.
#[test]
fn test_sampling_method_in_result_metadata() {
    if !has_cuda_device() {
        eprintln!("Skipping: no CUDA device");
        return;
    }

    // Case 1: no evidence at all.
    let text_without_evidence = r#"
0.5::a().
query(a()).
"#;
    let plain_program = McProgram::compile_source(text_without_evidence).unwrap();
    let plain_result = plain_program.evaluate(mc_config(100, 0, 128, None)).unwrap();
    assert_eq!(plain_result.sampling_method, McSamplingMethod::Rejection);

    // Case 2: evidence on a probabilistic fact.
    let text_with_evidence = r#"
0.5::a().
0.3::b().
evidence(a(), true).
query(b()).
"#;
    let conditioned_program = McProgram::compile_source(text_with_evidence).unwrap();
    let conditioned_result = conditioned_program
        .evaluate(mc_config(100, 0, 128, None))
        .unwrap();
    assert_eq!(
        conditioned_result.sampling_method,
        McSamplingMethod::EvidenceClamping
    );
}
/// Explicitly requesting `Rejection` must be honoured even when the evidence
/// is on a plain probabilistic fact. `rain()` must be estimated within 0.06 of
/// 0.7, and strictly fewer samples than the total may count as evidence
/// samples.
#[test]
fn test_rejection_unchanged() {
    if !has_cuda_device() {
        eprintln!("Skipping: no CUDA device");
        return;
    }

    let model_text = r#"
0.7::rain().
0.2::sprinkler().
evidence(sprinkler(), true).
query(rain()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let requested = Some(McSamplingMethod::Rejection);
    let result = program.evaluate(mc_config(5_000, 7, 128, requested)).unwrap();

    assert_eq!(result.sampling_method, McSamplingMethod::Rejection);

    let p = prob_of_atom(&result, "rain");
    let deviation = (p - 0.7).abs();
    assert!(deviation < 0.06, "p={}", p);

    assert!(result.evidence_samples < result.total_samples);
}
/// Positive evidence on `color(blue)` in a three-way annotated disjunction:
/// the evaluator must report `EvidenceClamping`, keep every sample, and
/// estimate both `color(red)` and `color(green)` as exactly 0.0.
#[test]
fn test_evidence_clamping_ad_head_3way() {
    if !has_cuda_device() {
        eprintln!("Skipping: no CUDA device");
        return;
    }

    let model_text = r#"
0.2::color(red); 0.3::color(blue); 0.4::color(green).
evidence(color(blue), true).
query(color(red)).
query(color(green)).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let result = program.evaluate(mc_config(10_000, 42, 128, None)).unwrap();

    assert_eq!(result.sampling_method, McSamplingMethod::EvidenceClamping);
    assert_eq!(result.evidence_samples, result.total_samples);

    // Finds the estimate for `color(<wanted>)`; `missing` is the panic text
    // used when no such query is present in the result.
    let color_prob = |wanted: &str, missing: &str| -> f64 {
        result
            .query_estimates
            .iter()
            .find(|q| {
                q.atom.predicate == "color"
                    && !q.atom.args.is_empty()
                    && value_as_symbol_name(&q.atom.args[0]).as_deref() == Some(wanted)
            })
            .expect(missing)
            .prob
    };

    let p_red = color_prob("red", "missing query for color(red)");
    let p_green = color_prob("green", "missing query for color(green)");

    assert_eq!(p_red, 0.0, "P(color(red) | color(blue)) should be 0.0");
    assert_eq!(p_green, 0.0, "P(color(green) | color(blue)) should be 0.0");
}
/// Runs `evaluate_gpu` on a program mixing a deterministic fact (`base()`)
/// with a fair probabilistic one (`flip()`). The atom derived from `base()`
/// must have probability 1.0 (to within 1e-9), and the atom derived from
/// `flip()` must be within 0.08 of 0.5.
#[test]
fn test_mc_sample_reset_plan_preserves_base_and_clears_sampled_relations() {
    if !has_cuda_device() {
        eprintln!("Skipping: no CUDA device");
        return;
    }

    let model_text = r#"
base().
0.5::flip().
seen_base() :- base().
seen_flip() :- flip().
query(seen_base()).
query(seen_flip()).
"#;
    let program = McProgram::compile_source(model_text).unwrap();
    let requested = Some(McSamplingMethod::Rejection);
    let result = program
        .evaluate_gpu(mc_config(2000, 9, 64, requested))
        .unwrap();

    let p_base = prob_of_atom(&result, "seen_base");
    let p_flip = prob_of_atom(&result, "seen_flip");

    let base_error = (p_base - 1.0).abs();
    assert!(base_error < 1e-9, "p_base={}", p_base);

    let flip_error = (p_flip - 0.5).abs();
    assert!(flip_error < 0.08, "p_flip={}", p_flip);
}
/// `McTimingBreakdown::total_us` must return the sum of all five phase
/// timings: 10 + 20 + 30 + 40 + 50 = 150.
#[test]
fn test_mc_timing_breakdown_totals_sum() {
    let breakdown = xlog_prob::mc::McTimingBreakdown {
        sampler_us: 10,
        sample_reset_us: 20,
        sample_build_us: 30,
        eval_us: 40,
        count_us: 50,
    };
    let total = breakdown.total_us();
    assert_eq!(total, 150);
}
/// `McCountStrategy::from_method` must map `EvidenceClamping` to
/// `QueriesOnly` and `Rejection` to `QueriesAndEvidence`.
#[test]
fn test_count_strategy_clamped_skips_evidence_side() {
    use xlog_prob::mc::{McCountStrategy, McSamplingMethod};

    let for_clamping = McCountStrategy::from_method(McSamplingMethod::EvidenceClamping);
    assert_eq!(for_clamping, McCountStrategy::QueriesOnly);

    let for_rejection = McCountStrategy::from_method(McSamplingMethod::Rejection);
    assert_eq!(for_rejection, McCountStrategy::QueriesAndEvidence);
}