use std::collections::BTreeMap;
use std::sync::Arc;
use cudarc::driver::LaunchConfig;
use xlog_core::{Result, XlogError};
use xlog_cuda::memory::TrackedCudaSlice;
use xlog_cuda::provider::{mc_resident_kernels, MC_RESIDENT_MODULE};
use xlog_cuda::{CudaKernelProvider, LaunchAsync};
use xlog_logic::ast::{Atom, BodyLiteral, Term};
use super::{McEvalConfig, McProgram, McSamplingMethod};
use crate::provenance::{GroundAtom, Value};
/// Largest predicate arity the resident engine accepts; wider predicates are
/// rejected with `ArityTooHigh`.
pub const MAX_ARITY: usize = 3;
/// Largest number of body literals per rule; longer bodies are rejected with
/// `BodyTooLong`.
pub const MAX_BODY: usize = 3;
/// Largest number of distinct variables per rule; more are rejected with
/// `TooManyVars`.
pub const MAX_VARS: usize = 8;
/// Upper bound on the total number of ground-atom slots summed over all
/// predicates; exceeding it is rejected with `UniverseTooLarge`.
pub const MAX_UNIVERSE: usize = 1 << 16;
/// Upper bound on the number of distinct constants in the program; exceeding it
/// is rejected with `DomainTooLarge`.
pub const MAX_DOMAIN: usize = 256;
// Words per encoded atom as written by `lower_rule`:
// [base slot, arity, term spec 0, term spec 1, term spec 2, stride of argument 0].
const ATOM_REC: usize = 6;
// Words per encoded rule: three header words (body length, variable count,
// domain size) followed by four atom records (the head plus up to three body
// atoms).
const RULE_REC: usize = 3 + 4 * ATOM_REC;
// High bit OR-ed into a term spec to mark it as a constant's domain index
// rather than a rule-local variable id.
const CONST_FLAG: u32 = 0x8000_0000;
// Environment variable holding an optional memory budget in bytes (parsed as u64).
const RESIDENT_BUDGET_ENV: &str = "XLOG_MC_RESIDENT_MEMORY_BUDGET_BYTES";
// Environment variable overriding the number of blocks launched per world
// (parsed as u32, must be >= 1, defaults to 1 when unset).
const RESIDENT_BLOCKS_PER_WORLD_ENV: &str = "XLOG_MC_RESIDENT_BLOCKS_PER_WORLD";
/// Machine-readable category explaining why a program cannot be compiled for
/// the resident MC engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResidentRejectKind {
    /// A rule body contains a negated literal.
    Negation,
    /// A rule body contains an epistemic literal.
    EpistemicLiteral,
    /// A rule body contains a comparison, `is` expression or univ literal.
    NonRelationalLiteral,
    /// A predicate has more than `MAX_ARITY` arguments.
    ArityTooHigh,
    /// A rule body has more than `MAX_BODY` literals.
    BodyTooLong,
    /// A rule uses more than `MAX_VARS` distinct variables.
    TooManyVars,
    /// A term is neither a variable nor a ground constant, a fact contains a
    /// variable, a constant is missing from the bounded domain, or an index
    /// does not fit in `u32`.
    UnboundedTerm,
    /// The program mentions more than `MAX_DOMAIN` distinct constants.
    DomainTooLarge,
    /// The total slot count exceeds `MAX_UNIVERSE`.
    UniverseTooLarge,
    /// A predicate is used with two different arities, or is unknown.
    InconsistentArity,
    /// Reserved tag for unsupported annotated disjunctions.
    // NOTE(review): no code visible in this file constructs this variant.
    AnnotatedDisjunctionUnsupported,
}

impl ResidentRejectKind {
    /// Returns the stable snake_case tag embedded in rejection error messages.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Negation => "negation",
            Self::EpistemicLiteral => "epistemic_literal",
            Self::NonRelationalLiteral => "non_relational_literal",
            Self::ArityTooHigh => "arity_too_high",
            Self::BodyTooLong => "body_too_long",
            Self::TooManyVars => "too_many_vars",
            Self::UnboundedTerm => "unbounded_term",
            Self::DomainTooLarge => "domain_too_large",
            Self::UniverseTooLarge => "universe_too_large",
            Self::InconsistentArity => "inconsistent_arity",
            Self::AnnotatedDisjunctionUnsupported => "annotated_disjunction_unsupported",
        }
    }
}
/// Structured description of why the resident engine refused a program.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResidentRejection {
    /// Category of the rejection.
    pub kind: ResidentRejectKind,
    /// The offending construct (a predicate name, a debug-printed term, ...).
    pub construct: String,
    /// Human-readable explanation of the violated limit.
    pub context: String,
}

impl ResidentRejection {
    /// Builds a rejection, converting both text arguments into owned strings.
    fn err(
        kind: ResidentRejectKind,
        construct: impl Into<String>,
        context: impl Into<String>,
    ) -> Self {
        let construct = construct.into();
        let context = context.into();
        Self {
            kind,
            construct,
            context,
        }
    }

    /// Consumes the rejection and renders it as an `XlogError::Compilation`
    /// whose message carries the kind tag, the construct and the context.
    pub fn into_error(self) -> XlogError {
        let Self {
            kind,
            construct,
            context,
        } = self;
        let message = format!(
            "resident MC engine rejected program [kind={}] construct=`{}` context=`{}`",
            kind.as_str(),
            construct,
            context
        );
        XlogError::Compilation(message)
    }
}
/// Totally ordered key identifying one constant of the bounded domain.
///
/// Floats are stored by bit pattern so the key can implement `Eq` and `Ord`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum ConstKey {
    Int(i64),
    Sym(u32),
    Str(String),
    FloatBits(u64),
}

impl ConstKey {
    /// Maps a provenance `Value` onto its domain key.
    fn from_value(v: &Value) -> ConstKey {
        match v {
            Value::I64(int) => Self::Int(*int),
            Value::Symbol(sym) => Self::Sym(*sym),
            Value::String(text) => Self::Str(text.clone()),
            Value::F64(bits) => Self::FloatBits(*bits),
        }
    }

    /// Classifies an AST term as a variable or a ground constant.
    ///
    /// Any other kind of term is rejected with `UnboundedTerm`.
    fn from_term(t: &Term) -> std::result::Result<TermClass, ResidentRejection> {
        let class = match t {
            Term::Variable(name) => TermClass::Var(name.clone()),
            Term::Integer(int) => TermClass::Const(Self::Int(*int)),
            Term::Symbol(sym) => TermClass::Const(Self::Sym(*sym)),
            Term::String(text) => TermClass::Const(Self::Str(text.clone())),
            Term::Float(float) => TermClass::Const(Self::FloatBits(float.to_bits())),
            other => {
                return Err(ResidentRejection::err(
                    ResidentRejectKind::UnboundedTerm,
                    format!("{:?}", other),
                    "rule term must be a variable or ground constant",
                ))
            }
        };
        Ok(class)
    }
}
// Result of classifying an AST term via `ConstKey::from_term`.
enum TermClass {
// A variable, carrying its name.
Var(String),
// A ground constant, carrying its domain key.
Const(ConstKey),
}
/// Counters describing host involvement around a resident engine launch.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct McNoHostStats {
    /// Tracked host-to-device copies observed.
    pub tracked_htod_calls: u64,
    /// Tracked device-to-host copies observed.
    pub tracked_dtoh_calls: u64,
    /// Untracked metadata device-to-host reads observed.
    pub untracked_metadata_reads: u64,
    /// Number of engine kernel launches (not part of the no-host check).
    pub engine_launches: u64,
    /// Host-side loop iterations.
    pub host_loop_iterations: u64,
    /// Host-issued launches per sample.
    pub per_sample_host_launches: u64,
    /// Host-driven fixpoint iterations.
    pub host_fixpoint_iterations: u64,
    /// Host allocations performed per operator.
    pub per_operator_host_allocations: u64,
}

impl McNoHostStats {
    /// Returns `true` when every host-involvement counter is zero.
    ///
    /// `engine_launches` is deliberately excluded from the check.
    pub fn is_no_host(&self) -> bool {
        let host_side_counters = [
            self.tracked_htod_calls,
            self.tracked_dtoh_calls,
            self.untracked_metadata_reads,
            self.host_loop_iterations,
            self.per_sample_host_launches,
            self.host_fixpoint_iterations,
            self.per_operator_host_allocations,
        ];
        host_side_counters.iter().all(|&count| count == 0)
    }
}
/// Outcome of a resident MC run. All buffers stay on the device; nothing is
/// copied back to the host by `run_resident`.
pub struct McResidentResult {
/// Device buffer with one `u32` per query (at least one element).
pub query_counts: TrackedCudaSlice<u32>,
/// Single-element device buffer.
// NOTE(review): presumably the number of worlds consistent with the evidence — confirm against the kernel.
pub evidence_count: TrackedCudaSlice<u32>,
/// Device buffer with one `u32` per world (at least one element).
pub iter_trace: TrackedCudaSlice<u32>,
/// Device buffer with one `u32` per world (at least one element).
pub sparse_final_row_counts: TrackedCudaSlice<u32>,
/// Device buffer with `worlds + 1` `u32` entries.
pub sparse_offsets: TrackedCudaSlice<u32>,
/// Device buffer with `4 * worlds + 1` `u32` entries.
pub resident_status_flags: TrackedCudaSlice<u32>,
/// Number of samples requested in the evaluation config.
pub total_samples: usize,
/// Seed taken from the evaluation config.
pub seed: u64,
/// Confidence level taken from the evaluation config.
pub confidence: f64,
/// Sampling method as resolved by `McProgram::resolve_sampling_method`.
pub sampling_method: McSamplingMethod,
/// Number of queries in the compiled plan.
pub num_queries: usize,
/// Host-involvement counters measured around the kernel launch.
pub no_host: McNoHostStats,
}
/// Flattened, device-friendly encoding of an MC program produced by
/// `compile_resident_plan`.
#[derive(Debug, Clone)]
pub struct ResidentPlan {
/// Total number of ground-atom slots over all predicates.
pub universe_size: u32,
/// Number of distinct constants in the bounded domain.
pub domain_size: u32,
/// Iteration cap, computed as `universe_size + 1` (saturating).
pub max_iters: u32,
// Slot of every program fact.
edb_slots: Vec<u32>,
// Slot of every probabilistic fact, parallel to `pf_var`.
pf_slot: Vec<u32>,
// Variable index of every probabilistic fact, parallel to `pf_slot`.
pf_var: Vec<u32>,
// Concatenated rule records, `RULE_REC` words each.
rule_data: Vec<u32>,
// Number of records stored in `rule_data`.
num_rules: u32,
// Slot of every query atom.
q_slot: Vec<u32>,
// Slot of every evidence atom, parallel to `ev_expected`.
ev_slot: Vec<u32>,
// Expected truth value of every evidence atom (1 = true, 0 = false).
ev_expected: Vec<u8>,
// Concatenated annotated-disjunction records:
// [choice count, decision-var count, decision vars..., choice slots...].
ad_data: Vec<u32>,
// Number of records stored in `ad_data`.
num_ads: u32,
/// Number of Bernoulli variables (length of `bernoulli_probs`).
pub num_vars: usize,
// Copy of the program's Bernoulli probabilities.
bernoulli_probs: Vec<f32>,
}
/// Reads the optional resident memory budget from the environment.
///
/// Returns `Ok(None)` when the variable is unset and `Ok(Some(bytes))` when it
/// holds a valid `u64`.
///
/// # Errors
///
/// Returns `XlogError::Execution` when the variable cannot be read (for
/// example, it is not valid Unicode) or does not parse as `u64`.
fn resident_memory_budget_bytes() -> Result<Option<u64>> {
    let raw = match std::env::var(RESIDENT_BUDGET_ENV) {
        Ok(value) => value,
        Err(std::env::VarError::NotPresent) => return Ok(None),
        Err(e) => {
            return Err(XlogError::Execution(format!(
                "invalid {RESIDENT_BUDGET_ENV}: {e}"
            )))
        }
    };
    match raw.parse::<u64>() {
        Ok(bytes) => Ok(Some(bytes)),
        Err(e) => Err(XlogError::Execution(format!(
            "invalid {RESIDENT_BUDGET_ENV} value `{raw}`: {e}"
        ))),
    }
}
/// Reads the blocks-per-world override from the environment, defaulting to 1.
///
/// # Errors
///
/// Returns `XlogError::Execution` when the variable cannot be read, does not
/// parse as `u32`, or is zero.
fn resident_blocks_per_world() -> Result<u32> {
    let raw = match std::env::var(RESIDENT_BLOCKS_PER_WORLD_ENV) {
        Ok(value) => value,
        Err(std::env::VarError::NotPresent) => return Ok(1),
        Err(e) => {
            return Err(XlogError::Execution(format!(
                "invalid {RESIDENT_BLOCKS_PER_WORLD_ENV}: {e}"
            )))
        }
    };
    match raw.parse::<u32>() {
        // At least one block per world is required.
        Ok(0) => Err(XlogError::Execution(format!(
            "invalid {RESIDENT_BLOCKS_PER_WORLD_ENV} value `{raw}`: must be >= 1"
        ))),
        Ok(blocks) => Ok(blocks),
        Err(e) => Err(XlogError::Execution(format!(
            "invalid {RESIDENT_BLOCKS_PER_WORLD_ENV} value `{raw}`: {e}"
        ))),
    }
}
/// Multiplies two `u64` values, clamping to `u64::MAX` instead of overflowing.
fn sat_mul(a: u64, b: u64) -> u64 {
    a.checked_mul(b).unwrap_or(u64::MAX)
}
/// Raises `base` to `exp` by repeated squaring, clamping to `u64::MAX` on
/// overflow. `sat_pow(x, 0)` is 1 for every `x`.
fn sat_pow(mut base: u64, mut exp: u32) -> u64 {
    let mut result: u64 = 1;
    loop {
        if exp == 0 {
            return result;
        }
        if exp % 2 != 0 {
            result = sat_mul(result, base);
        }
        exp /= 2;
        // Square only while higher bits remain, so the base is never squared
        // once more than the exponent needs.
        if exp != 0 {
            base = sat_mul(base, base);
        }
    }
}
/// Computes an upper bound, in bytes, on the device memory a resident run of
/// `plan` over `num_worlds` worlds allocates.
///
/// The bound mirrors the buffers allocated by `run_resident` (every count is
/// clamped to at least one, as the allocations are) plus a per-rule join term,
/// and uses saturating arithmetic throughout, so an overflowing estimate
/// reports `u64::MAX` rather than wrapping.
fn estimate_resident_bound_bytes(plan: &ResidentPlan, num_worlds: u32) -> u64 {
    // Number of `u32` words in the launch-configuration array (`cfg_host`)
    // uploaded by `run_resident`. It must track that array's length, otherwise
    // the "bound" undercounts.
    const CFG_WORDS: usize = 19;

    let worlds = u64::from(num_worlds.max(1));
    let vars = plan.num_vars.max(1) as u64;
    let universe = u64::from(plan.universe_size.max(1));

    // Metadata words: every plan array (evidence flags are widened to u32 on
    // upload) plus the configuration array.
    let meta_words = plan
        .edb_slots
        .len()
        .saturating_add(plan.pf_slot.len())
        .saturating_add(plan.pf_var.len())
        .saturating_add(plan.rule_data.len())
        .saturating_add(plan.q_slot.len())
        .saturating_add(plan.ev_slot.len())
        .saturating_add(plan.ev_expected.len())
        .saturating_add(plan.ad_data.len())
        .saturating_add(CFG_WORDS);
    let sparse_cap = universe;

    // Force mask + forced value bytes.
    let setup_bytes = sat_mul(2, vars)
        // Sample matrix, one byte per (world, variable).
        .saturating_add(sat_mul(worlds, vars))
        // Relation buffer: 2 * universe u32 words per world.
        .saturating_add(sat_mul(sat_mul(sat_mul(worlds, 2), universe), 4))
        // Sparse columns: 2 * sparse_cap rows of four u32 words per world.
        .saturating_add(sat_mul(sat_mul(sat_mul(worlds, 2), sparse_cap), 16))
        // Sparse counts: two u32 words per world.
        .saturating_add(sat_mul(sat_mul(worlds, 2), 4))
        // Sparse final counts: one u32 per world.
        .saturating_add(sat_mul(worlds, 4))
        // Sparse offsets: worlds + 1 u32 words.
        .saturating_add(sat_mul(worlds.saturating_add(1), 4))
        // Status flags: 4 * worlds + 1 u32 words.
        .saturating_add(sat_mul(worlds.saturating_mul(4).saturating_add(1), 4))
        // Metadata + configuration words.
        .saturating_add(sat_mul(meta_words as u64, 4))
        // Query counters.
        .saturating_add(sat_mul(plan.q_slot.len().max(1) as u64, 4))
        // Evidence counter.
        .saturating_add(4)
        // Iteration trace: one u32 per world.
        .saturating_add(sat_mul(worlds, 4));

    // Largest per-rule join footprint: domain^n_vars assignments per world,
    // each (n_body + 1) u32 words wide. Rules with fewer than two body atoms
    // contribute nothing.
    let mut sparse_join_bytes = 0u64;
    for rule in plan.rule_data.chunks_exact(RULE_REC) {
        let n_body = rule[0];
        if n_body < 2 {
            continue;
        }
        let n_vars = rule[1];
        let assignments = sat_pow(u64::from(plan.domain_size.max(1)), n_vars);
        let row_bytes = sat_mul(u64::from(n_body).saturating_add(1), 4);
        sparse_join_bytes = sparse_join_bytes.max(sat_mul(sat_mul(worlds, assignments), row_bytes));
    }
    setup_bytes.saturating_add(sparse_join_bytes)
}
// Per-predicate layout information inside the slot universe.
struct PredInfo {
// Number of arguments of the predicate.
arity: usize,
// First slot belonging to the predicate; its ground atoms occupy the
// following `domain_size ^ arity` slots (a single slot when arity is 0).
base: u32,
}
/// Bounded universe of ground atoms: the constant domain plus the slot range
/// assigned to every predicate.
struct Universe {
    /// Domain index of every known constant.
    domain: BTreeMap<ConstKey, u32>,
    /// Layout of every known predicate.
    preds: BTreeMap<String, PredInfo>,
    /// Number of constants in `domain`.
    domain_size: u32,
}

impl Universe {
    /// Stride of the first argument of a predicate with the given arity:
    /// `domain_size ^ (arity - 1)`, or 1 for arity 0 and 1.
    fn stride0(&self, arity: usize) -> u32 {
        match arity {
            0 | 1 => 1,
            wide => self.domain_size.pow((wide - 1) as u32),
        }
    }

    /// Stride of argument `arg_idx`: `domain_size ^ (arity - arg_idx - 1)`,
    /// or 1 for the last argument and any index beyond it.
    fn arg_stride(&self, arity: usize, arg_idx: usize) -> u32 {
        if arity > arg_idx + 1 {
            self.domain_size.pow((arity - arg_idx - 1) as u32)
        } else {
            1
        }
    }

    /// Computes the slot of a ground atom: the predicate's base plus the sum
    /// of each argument's domain index times its stride.
    ///
    /// Rejects atoms whose predicate is unknown or used with a different
    /// arity (`InconsistentArity`) and atoms mentioning a constant outside the
    /// domain (`UnboundedTerm`).
    fn ground_slot(&self, atom: &GroundAtom) -> std::result::Result<u32, ResidentRejection> {
        let info = match self.preds.get(&atom.predicate) {
            Some(info) => info,
            None => {
                return Err(ResidentRejection::err(
                    ResidentRejectKind::InconsistentArity,
                    atom.predicate.clone(),
                    "ground atom references unknown predicate",
                ))
            }
        };
        let given = atom.args.len();
        if given != info.arity {
            return Err(ResidentRejection::err(
                ResidentRejectKind::InconsistentArity,
                atom.predicate.clone(),
                format!("expected arity {} got {}", info.arity, given),
            ));
        }
        atom.args
            .iter()
            .enumerate()
            .try_fold(info.base, |slot, (pos, value)| {
                let idx = match self.domain.get(&ConstKey::from_value(value)) {
                    Some(&idx) => idx,
                    None => {
                        return Err(ResidentRejection::err(
                            ResidentRejectKind::UnboundedTerm,
                            format!("{:?}", value),
                            "ground constant absent from bounded domain",
                        ))
                    }
                };
                Ok(slot + idx * self.arg_stride(info.arity, pos))
            })
    }
}
pub fn compile_resident_plan(
mc: &McProgram,
) -> std::result::Result<ResidentPlan, ResidentRejection> {
let program = &mc.program;
let mut arities: BTreeMap<String, usize> = BTreeMap::new();
let mut note_pred = |pred: &str, arity: usize| -> std::result::Result<(), ResidentRejection> {
if arity > MAX_ARITY {
return Err(ResidentRejection::err(
ResidentRejectKind::ArityTooHigh,
pred.to_string(),
format!("arity {} exceeds max {}", arity, MAX_ARITY),
));
}
match arities.get(pred) {
Some(&existing) if existing != arity => Err(ResidentRejection::err(
ResidentRejectKind::InconsistentArity,
pred.to_string(),
format!("arity {} vs {}", existing, arity),
)),
_ => {
arities.insert(pred.to_string(), arity);
Ok(())
}
}
};
let mut domain: BTreeMap<ConstKey, u32> = BTreeMap::new();
let mut note_const = |key: ConstKey, domain: &mut BTreeMap<ConstKey, u32>| {
let next = domain.len() as u32;
domain.entry(key).or_insert(next);
};
for fact in program.facts() {
note_pred(&fact.head.predicate, fact.head.terms.len())?;
for t in &fact.head.terms {
match ConstKey::from_term(t)? {
TermClass::Const(k) => note_const(k, &mut domain),
TermClass::Var(_) => {
return Err(ResidentRejection::err(
ResidentRejectKind::UnboundedTerm,
fact.head.predicate.clone(),
"fact head contains a variable",
))
}
}
}
}
for pf in &mc.prob_facts {
note_pred(&pf.atom.predicate, pf.atom.args.len())?;
for v in &pf.atom.args {
note_const(ConstKey::from_value(v), &mut domain);
}
}
for q in &mc.queries {
note_pred(&q.predicate, q.args.len())?;
for v in &q.args {
note_const(ConstKey::from_value(v), &mut domain);
}
}
for (e, _) in &mc.evidence {
note_pred(&e.predicate, e.args.len())?;
for v in &e.args {
note_const(ConstKey::from_value(v), &mut domain);
}
}
for ad in &mc.annotated_disjunctions {
for atom in &ad.choices {
note_pred(&atom.predicate, atom.args.len())?;
for v in &atom.args {
note_const(ConstKey::from_value(v), &mut domain);
}
}
}
for rule in &program.rules {
if rule.is_fact() {
continue;
}
note_pred(&rule.head.predicate, rule.head.terms.len())?;
collect_atom_consts(&rule.head, &mut domain, &mut note_const)?;
if rule.body.len() > MAX_BODY {
return Err(ResidentRejection::err(
ResidentRejectKind::BodyTooLong,
rule.head.predicate.clone(),
format!("body length {} exceeds max {}", rule.body.len(), MAX_BODY),
));
}
for lit in &rule.body {
let atom = classify_body_literal(lit, &rule.head.predicate)?;
note_pred(&atom.predicate, atom.terms.len())?;
collect_atom_consts(atom, &mut domain, &mut note_const)?;
}
}
if domain.len() > MAX_DOMAIN {
return Err(ResidentRejection::err(
ResidentRejectKind::DomainTooLarge,
format!("{} constants", domain.len()),
format!("domain exceeds max {}", MAX_DOMAIN),
));
}
let domain_size = domain.len() as u32;
let mut preds: BTreeMap<String, PredInfo> = BTreeMap::new();
let mut base: u64 = 0;
for (pred, &arity) in &arities {
let slot_count: u64 = if arity == 0 {
1
} else {
(domain_size as u64).pow(arity as u32)
};
preds.insert(
pred.clone(),
PredInfo {
arity,
base: base as u32,
},
);
base += slot_count;
if base > MAX_UNIVERSE as u64 {
return Err(ResidentRejection::err(
ResidentRejectKind::UniverseTooLarge,
format!("{} slots", base),
format!("universe exceeds max {}", MAX_UNIVERSE),
));
}
}
let universe_size = base as u32;
let universe = Universe {
domain,
preds,
domain_size,
};
let mut edb_slots = Vec::new();
for fact in program.facts() {
let ga = ground_atom_from_atom(&fact.head)?;
edb_slots.push(universe.ground_slot(&ga)?);
}
let mut pf_slot = Vec::new();
let mut pf_var = Vec::new();
for pf in &mc.prob_facts {
pf_slot.push(universe.ground_slot(&pf.atom)?);
pf_var.push(pf.var_idx as u32);
}
let mut q_slot = Vec::new();
for q in &mc.queries {
q_slot.push(universe.ground_slot(q)?);
}
let mut ev_slot = Vec::new();
let mut ev_expected = Vec::new();
for (e, v) in &mc.evidence {
ev_slot.push(universe.ground_slot(e)?);
ev_expected.push(if *v { 1u8 } else { 0u8 });
}
let mut rule_data = Vec::new();
let mut num_rules = 0u32;
for rule in &program.rules {
if rule.is_fact() {
continue;
}
let rec = lower_rule(rule, &universe)?;
rule_data.extend_from_slice(&rec);
num_rules += 1;
}
let mut ad_data: Vec<u32> = Vec::new();
let mut num_ads = 0u32;
for ad in &mc.annotated_disjunctions {
let n_choices = ad.choices.len() as u32;
let n_dvars = ad.decision_vars.len() as u32;
ad_data.push(n_choices);
ad_data.push(n_dvars);
for &dv in &ad.decision_vars {
ad_data.push(u32::try_from(dv).map_err(|_| {
ResidentRejection::err(
ResidentRejectKind::UnboundedTerm,
"decision_var",
"AD decision var index exceeds u32",
)
})?);
}
for atom in &ad.choices {
ad_data.push(universe.ground_slot(atom)?);
}
num_ads += 1;
}
let max_iters = universe_size.saturating_add(1).max(1);
Ok(ResidentPlan {
universe_size,
domain_size,
max_iters,
edb_slots,
pf_slot,
pf_var,
rule_data,
num_rules,
q_slot,
ev_slot,
ev_expected,
ad_data,
num_ads,
num_vars: mc.bernoulli_probs.len(),
bernoulli_probs: mc.bernoulli_probs.clone(),
})
}
/// Feeds every ground constant of `atom` to `note_const`, skipping variables.
///
/// Rejects the atom with `ArityTooHigh` when it has more than `MAX_ARITY`
/// terms, and propagates the rejection of any term that is neither a variable
/// nor a ground constant.
fn collect_atom_consts<F: FnMut(ConstKey, &mut BTreeMap<ConstKey, u32>)>(
    atom: &Atom,
    domain: &mut BTreeMap<ConstKey, u32>,
    note_const: &mut F,
) -> std::result::Result<(), ResidentRejection> {
    let arity = atom.terms.len();
    if arity > MAX_ARITY {
        return Err(ResidentRejection::err(
            ResidentRejectKind::ArityTooHigh,
            atom.predicate.clone(),
            format!("arity {} exceeds max {}", arity, MAX_ARITY),
        ));
    }
    atom.terms
        .iter()
        .try_for_each(|term| -> std::result::Result<(), ResidentRejection> {
            match ConstKey::from_term(term)? {
                TermClass::Const(key) => note_const(key, domain),
                TermClass::Var(_) => {}
            }
            Ok(())
        })
}
/// Returns the atom of a positive body literal and rejects every other kind
/// of literal.
///
/// `rule_ctx` names the rule (by head predicate) in the rejection context.
fn classify_body_literal<'a>(
    lit: &'a BodyLiteral,
    rule_ctx: &str,
) -> std::result::Result<&'a Atom, ResidentRejection> {
    // Only positive relational literals are supported; everything else maps
    // to a (kind, construct, context) triple describing the rejection.
    let (kind, construct, context) = match lit {
        BodyLiteral::Positive(atom) => return Ok(atom),
        BodyLiteral::Negated(atom) => (
            ResidentRejectKind::Negation,
            atom.predicate.clone(),
            format!("negated literal in rule for `{}`", rule_ctx),
        ),
        BodyLiteral::Epistemic(literal) => (
            ResidentRejectKind::EpistemicLiteral,
            literal.atom.predicate.clone(),
            format!("epistemic literal in rule for `{}`", rule_ctx),
        ),
        BodyLiteral::Comparison(_) | BodyLiteral::IsExpr(_) | BodyLiteral::Univ(_) => (
            ResidentRejectKind::NonRelationalLiteral,
            String::from("comparison/is/univ"),
            format!("non-relational literal in rule for `{}`", rule_ctx),
        ),
    };
    Err(ResidentRejection::err(kind, construct, context))
}
/// Lowers one rule into a `RULE_REC`-word record.
///
/// Layout: `[body length, variable count, domain size]`, then the head atom
/// record, then one atom record per body literal; unused trailing words stay
/// zero. Each atom record is `[base slot, arity, term specs..., stride0]`,
/// where a term spec is either a rule-local variable id or
/// `CONST_FLAG | domain index`.
///
/// Variables are numbered in order of first appearance, head first.
fn lower_rule(
    rule: &xlog_logic::ast::Rule,
    universe: &Universe,
) -> std::result::Result<Vec<u32>, ResidentRejection> {
    let head: &Atom = &rule.head;

    // Reject unsupported literals before doing any other work.
    let body_atoms = rule
        .body
        .iter()
        .map(|lit| classify_body_literal(lit, &head.predicate))
        .collect::<std::result::Result<Vec<&Atom>, ResidentRejection>>()?;

    // Number the variables: head terms first, then body atoms left to right.
    let mut var_ids: BTreeMap<String, u32> = BTreeMap::new();
    for atom in std::iter::once(head).chain(body_atoms.iter().copied()) {
        for term in &atom.terms {
            if let TermClass::Var(name) = ConstKey::from_term(term)? {
                if var_ids.contains_key(&name) {
                    continue;
                }
                let next_id = var_ids.len() as u32;
                if next_id as usize >= MAX_VARS {
                    return Err(ResidentRejection::err(
                        ResidentRejectKind::TooManyVars,
                        head.predicate.clone(),
                        format!("more than {} distinct variables", MAX_VARS),
                    ));
                }
                var_ids.insert(name, next_id);
            }
        }
    }

    let encode_atom = |atom: &Atom| -> std::result::Result<[u32; ATOM_REC], ResidentRejection> {
        let info = match universe.preds.get(&atom.predicate) {
            Some(info) => info,
            None => {
                return Err(ResidentRejection::err(
                    ResidentRejectKind::InconsistentArity,
                    atom.predicate.clone(),
                    "rule atom references unknown predicate",
                ))
            }
        };
        let mut encoded = [0u32; ATOM_REC];
        encoded[0] = info.base;
        encoded[1] = info.arity as u32;
        encoded[5] = universe.stride0(info.arity);
        for (pos, term) in atom.terms.iter().enumerate() {
            encoded[2 + pos] = match ConstKey::from_term(term)? {
                TermClass::Var(name) => *var_ids.get(&name).expect("var assigned above"),
                TermClass::Const(key) => match universe.domain.get(&key) {
                    Some(&idx) => CONST_FLAG | idx,
                    None => {
                        return Err(ResidentRejection::err(
                            ResidentRejectKind::UnboundedTerm,
                            format!("{:?}", key),
                            "rule constant absent from bounded domain",
                        ))
                    }
                },
            };
        }
        Ok(encoded)
    };

    let mut rec = vec![0u32; RULE_REC];
    rec[0] = body_atoms.len() as u32;
    rec[1] = var_ids.len() as u32;
    rec[2] = universe.domain_size;

    // Atom records follow the three header words: head first, then the body.
    let mut cursor = 3;
    for atom in std::iter::once(head).chain(body_atoms.iter().copied()) {
        let encoded = encode_atom(atom)?;
        rec[cursor..cursor + ATOM_REC].copy_from_slice(&encoded);
        cursor += ATOM_REC;
    }
    Ok(rec)
}
fn ground_atom_from_atom(atom: &Atom) -> std::result::Result<GroundAtom, ResidentRejection> {
let mut args = Vec::with_capacity(atom.terms.len());
for t in &atom.terms {
let v = match t {
Term::Integer(i) => Value::I64(*i),
Term::Symbol(s) => Value::Symbol(*s),
Term::String(s) => Value::String(s.clone()),
Term::Float(f) => Value::F64(f.to_bits()),
other => {
return Err(ResidentRejection::err(
ResidentRejectKind::UnboundedTerm,
format!("{:?}", other),
"fact term must be a ground constant",
))
}
};
args.push(v);
}
Ok(GroundAtom {
predicate: atom.predicate.clone(),
args,
})
}
impl McProgram {
    /// Evaluates the program on the resident engine using the given provider.
    ///
    /// # Errors
    ///
    /// Fails when the configuration is invalid, when the program is rejected
    /// by `compile_resident_plan` (reported as a compilation error), or when
    /// the resident run itself fails.
    pub fn evaluate_resident_with_provider(
        &self,
        cfg: McEvalConfig,
        provider: Arc<CudaKernelProvider>,
    ) -> Result<McResidentResult> {
        cfg.validate()?;
        let plan = match compile_resident_plan(self) {
            Ok(plan) => plan,
            Err(rejection) => return Err(rejection.into_error()),
        };
        run_resident(&plan, &cfg, self, provider)
    }

    /// Evaluates the program on the resident engine with a provider obtained
    /// from `self.provider()`.
    pub fn evaluate_resident(&self, cfg: McEvalConfig) -> Result<McResidentResult> {
        let owned_provider = self.provider()?;
        self.evaluate_resident_with_provider(cfg, Arc::new(owned_provider))
    }
}
/// Allocates and initialises every device buffer, uploads the plan, launches
/// the resident engine kernel once and returns the device-side result buffers.
///
/// No result data is copied back to the host here; the returned
/// `McResidentResult` owns the device buffers. Host transfer and allocation
/// counters are sampled right before and right after the launch and reported
/// in `McResidentResult::no_host`.
///
/// # Errors
///
/// Fails when the sampling method cannot be resolved, the sample count does
/// not fit in `u32`, an environment override is invalid, the memory budget is
/// exceeded (`XlogError::ResourceExhausted`), or any allocation, copy, kernel
/// lookup, launch or synchronisation fails.
fn run_resident(
plan: &ResidentPlan,
cfg: &McEvalConfig,
mc: &McProgram,
provider: Arc<CudaKernelProvider>,
) -> Result<McResidentResult> {
let (method, forcing) = mc.resolve_sampling_method(cfg.sampling_method)?;
// One world per sample.
let num_worlds = u32::try_from(cfg.samples)
.map_err(|_| XlogError::Execution("MC samples exceed u32::MAX".to_string()))?;
let blocks_per_world = resident_blocks_per_world()?;
let num_vars = plan.num_vars;
// Optional budget check, performed before anything is allocated.
if let Some(budget_bytes) = resident_memory_budget_bytes()? {
let bound_bytes = estimate_resident_bound_bytes(plan, num_worlds);
if bound_bytes > budget_bytes {
return Err(XlogError::ResourceExhausted {
context: format!(
"resident_resource_budget operator=sparse_wcoj bound_bytes={bound_bytes} budget_bytes={budget_bytes}"
),
estimated_bytes: bound_bytes,
budget_bytes,
});
}
}
let dev = provider.device();
// Per-variable forcing buffers (at least one element each). They are filled
// from `forcing` only for evidence clamping with at least one variable, and
// zeroed otherwise.
let mut d_force_mask = provider.memory().alloc::<u8>(num_vars.max(1))?;
let mut d_forced_value = provider.memory().alloc::<u8>(num_vars.max(1))?;
if method == McSamplingMethod::EvidenceClamping && num_vars > 0 {
provider.htod_sync_copy_into_tracked(&forcing.force_mask, &mut d_force_mask)?;
provider.htod_sync_copy_into_tracked(&forcing.forced_value, &mut d_forced_value)?;
} else {
dev.inner()
.memset_zeros(&mut d_force_mask)
.map_err(|e| XlogError::Kernel(format!("zero force_mask: {e}")))?;
dev.inner()
.memset_zeros(&mut d_forced_value)
.map_err(|e| XlogError::Kernel(format!("zero forced_value: {e}")))?;
}
// Bernoulli samples are drawn on the device. With no variables or no samples
// a one-byte placeholder is allocated instead (it is not zeroed here).
let samples_device = if num_vars == 0 || cfg.samples == 0 {
provider.memory().alloc::<u8>(1)?
} else {
provider.sample_bernoulli_matrix_device(
&plan.bernoulli_probs,
cfg.samples,
cfg.seed,
&d_force_mask.slice(..),
&d_forced_value.slice(..),
)?
};
// Relation buffer: 2 * universe u32 words per world, zero-initialised.
// NOTE(review): presumably two truth-value planes per world — confirm against the kernel.
let u = plan.universe_size.max(1) as usize;
let rel_len = (num_worlds as usize)
.saturating_mul(u)
.saturating_mul(2)
.max(1);
let mut d_rel = provider.memory().alloc::<u32>(rel_len)?;
dev.inner()
.memset_zeros(&mut d_rel)
.map_err(|e| XlogError::Kernel(format!("zero rel: {e}")))?;
// Sparse buffers: columns hold 2 * sparse_cap rows of four u32 words per
// world; counts, final counts, offsets and status flags are sized per world.
let sparse_cap = u.max(1);
let sparse_len = (num_worlds as usize)
.saturating_mul(2)
.saturating_mul(sparse_cap)
.max(1);
let mut d_sparse_columns = provider
.memory()
.alloc::<u32>(sparse_len.saturating_mul(4).max(1))?;
let mut d_sparse_counts = provider
.memory()
.alloc::<u32>((num_worlds as usize).saturating_mul(2).max(1))?;
let mut d_sparse_final_counts = provider
.memory()
.alloc::<u32>((num_worlds as usize).max(1))?;
let mut d_sparse_offsets = provider
.memory()
.alloc::<u32>((num_worlds as usize).saturating_add(1).max(1))?;
let mut d_resident_status_flags = provider.memory().alloc::<u32>(
(num_worlds as usize)
.saturating_mul(4)
.saturating_add(1)
.max(1),
)?;
dev.inner()
.memset_zeros(&mut d_sparse_columns)
.map_err(|e| XlogError::Kernel(format!("zero sparse_columns: {e}")))?;
dev.inner()
.memset_zeros(&mut d_sparse_counts)
.map_err(|e| XlogError::Kernel(format!("zero sparse_counts: {e}")))?;
dev.inner()
.memset_zeros(&mut d_sparse_final_counts)
.map_err(|e| XlogError::Kernel(format!("zero sparse_final_counts: {e}")))?;
dev.inner()
.memset_zeros(&mut d_sparse_offsets)
.map_err(|e| XlogError::Kernel(format!("zero sparse_offsets: {e}")))?;
dev.inner()
.memset_zeros(&mut d_resident_status_flags)
.map_err(|e| XlogError::Kernel(format!("zero resident_status_flags: {e}")))?;
let q_count = plan.q_slot.len();
// Evidence flags are widened to u32 so they can live in the shared u32 blob.
let ev_expected_u32: Vec<u32> = plan.ev_expected.iter().map(|&b| b as u32).collect();
// All plan arrays are packed into one metadata blob; `push_meta` appends an
// array and returns the word offset at which it starts.
let mut meta: Vec<u32> = Vec::new();
let push_meta = |data: &[u32], meta: &mut Vec<u32>| -> u32 {
let off = meta.len() as u32;
meta.extend_from_slice(data);
off
};
let edb_off = push_meta(&plan.edb_slots, &mut meta);
let pf_slot_off = push_meta(&plan.pf_slot, &mut meta);
let pf_var_off = push_meta(&plan.pf_var, &mut meta);
let rules_off = push_meta(&plan.rule_data, &mut meta);
let q_off = push_meta(&plan.q_slot, &mut meta);
let ev_slot_off = push_meta(&plan.ev_slot, &mut meta);
let ev_exp_off = push_meta(&ev_expected_u32, &mut meta);
let ad_off = push_meta(&plan.ad_data, &mut meta);
// Launch configuration: sizes plus the offsets/lengths of every array inside
// the metadata blob.
// NOTE(review): the element order presumably has to match the layout the kernel expects — confirm against the kernel source.
let cfg_host: [u32; 19] = [
num_worlds,
plan.universe_size,
num_vars as u32,
plan.max_iters,
edb_off,
plan.edb_slots.len() as u32,
pf_slot_off,
pf_var_off,
plan.pf_slot.len() as u32,
rules_off,
plan.num_rules,
q_off,
q_count as u32,
ev_slot_off,
ev_exp_off,
plan.ev_slot.len() as u32,
ad_off,
plan.num_ads,
blocks_per_world,
];
let mut d_cfg = provider.memory().alloc::<u32>(cfg_host.len())?;
provider.htod_sync_copy_into_tracked(&cfg_host, &mut d_cfg)?;
let mut d_meta = provider.memory().alloc::<u32>(meta.len().max(1))?;
if !meta.is_empty() {
provider.htod_sync_copy_into_tracked(&meta, &mut d_meta)?;
}
// Output buffers, zero-initialised.
let mut d_query_counts = provider.memory().alloc::<u32>(q_count.max(1))?;
dev.inner()
.memset_zeros(&mut d_query_counts)
.map_err(|e| XlogError::Kernel(format!("zero query_counts: {e}")))?;
let mut d_evidence_count = provider.memory().alloc::<u32>(1)?;
dev.inner()
.memset_zeros(&mut d_evidence_count)
.map_err(|e| XlogError::Kernel(format!("zero evidence_count: {e}")))?;
let mut d_iter_trace = provider.memory().alloc::<u32>(num_worlds.max(1) as usize)?;
dev.inner()
.memset_zeros(&mut d_iter_trace)
.map_err(|e| XlogError::Kernel(format!("zero iter_trace: {e}")))?;
let engine_fn = dev
.inner()
.get_func(MC_RESIDENT_MODULE, mc_resident_kernels::MC_RESIDENT_ENGINE)
.ok_or_else(|| XlogError::Kernel("mc_resident_engine kernel not found".to_string()))?;
// Synchronise before taking the baseline so that all setup work above is
// excluded from the no-host counters.
dev.synchronize()?;
let pre = provider.host_transfer_stats();
let pre_untracked = provider.untracked_metadata_dtoh_count();
let pre_allocs = provider.memory().alloc_count();
let mut engine_launches = 0u64;
// Grid: `blocks_per_world` blocks per world (at least one world), 128
// threads per block.
let block_dim = 128u32;
let grid_dim = num_worlds
.max(1)
.checked_mul(blocks_per_world)
.ok_or_else(|| {
XlogError::Execution(format!(
"resident grid overflow: worlds={num_worlds} blocks_per_world={blocks_per_world}"
))
})?;
let launch_cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (block_dim, 1, 1),
shared_mem_bytes: 0,
};
// NOTE(review): no SAFETY justification is recorded for this launch; the
// argument order and types must match the kernel's signature — confirm
// against the kernel source.
unsafe {
let args = (
&d_cfg,
&d_meta,
&mut d_rel,
&samples_device,
&mut d_query_counts,
&mut d_evidence_count,
&mut d_iter_trace,
&mut d_sparse_columns,
&mut d_sparse_counts,
&mut d_sparse_final_counts,
&mut d_sparse_offsets,
&mut d_resident_status_flags,
sparse_cap as u32,
);
// A single block per world uses a regular launch; several blocks per world
// use a cooperative launch.
if blocks_per_world == 1 {
engine_fn
.launch(launch_cfg, args)
.map_err(|e| XlogError::Kernel(format!("mc_resident_engine launch failed: {e}")))?;
} else {
engine_fn
.launch_cooperative(launch_cfg, args)
.map_err(|e| {
XlogError::Kernel(format!("mc_resident_engine cooperative launch failed: {e}"))
})?;
}
}
engine_launches += 1;
// Wait for the kernel, then diff the counters against the baseline.
dev.synchronize()?;
let post = provider.host_transfer_stats();
let post_untracked = provider.untracked_metadata_dtoh_count();
let post_allocs = provider.memory().alloc_count();
let no_host = McNoHostStats {
tracked_htod_calls: post.htod_calls.saturating_sub(pre.htod_calls),
tracked_dtoh_calls: post.dtoh_calls.saturating_sub(pre.dtoh_calls),
untracked_metadata_reads: post_untracked.saturating_sub(pre_untracked),
engine_launches,
// This function runs no host-side loops or per-sample launches, so these
// counters are reported as zero.
host_loop_iterations: 0,
host_fixpoint_iterations: 0,
per_operator_host_allocations: post_allocs.saturating_sub(pre_allocs),
per_sample_host_launches: 0,
};
Ok(McResidentResult {
query_counts: d_query_counts,
evidence_count: d_evidence_count,
iter_trace: d_iter_trace,
sparse_final_row_counts: d_sparse_final_counts,
sparse_offsets: d_sparse_offsets,
resident_status_flags: d_resident_status_flags,
total_samples: cfg.samples,
seed: cfg.seed,
confidence: cfg.confidence,
sampling_method: method,
num_queries: q_count,
no_host,
})
}