use std::sync::Arc;
use std::ffi::c_void;
use cudarc::driver::LaunchConfig;
use xlog_core::{Result, XlogError};
use xlog_cuda::memory::TrackedCudaSlice;
use xlog_cuda::provider::sat_kernels;
use xlog_cuda::provider::SAT_MODULE;
use xlog_cuda::{AsKernelParam, CudaKernelProvider, LaunchAsync};
use xlog_solve::{GpuCdclConfig, GpuCdclSolver, GpuCnf};
#[cfg(debug_assertions)]
use crate::compilation::gpu_d4::validate_cnf_gpu;
use crate::gpu::GpuXgcf;
/// Upper bound on the launch grid's x-dimension enforced by [`checked_launch_grid`].
///
/// NOTE(review): 65 535 is the x-dimension limit for compute capability < 3.0 (and the
/// y/z limit on newer devices); CC >= 3.0 allows up to 2^31 - 1 in x. Kept as-is until the
/// kernels are confirmed safe with larger grids.
const MAX_GRID_X: u64 = 65_535;

/// Returns how many `block`-sized thread blocks are needed to cover `elements` work items.
///
/// A zero-element launch still gets one block, because CUDA rejects an empty grid.
///
/// # Errors
/// Returns [`XlogError::Kernel`] (prefixed with `context`) when `block` is zero or when the
/// resulting grid would exceed [`MAX_GRID_X`].
fn checked_launch_grid(elements: u32, block: u32, context: &str) -> Result<u32> {
    if block == 0 {
        return Err(XlogError::Kernel(format!(
            "{context}: CUDA launch block size must be nonzero"
        )));
    }
    let grid = match elements {
        0 => 1,
        n => u64::from(n).div_ceil(u64::from(block)),
    };
    if grid <= MAX_GRID_X {
        // Lossless: `grid` is bounded by `MAX_GRID_X`, which fits in `u32`.
        return Ok(grid as u32);
    }
    Err(XlogError::Kernel(format!(
        "{context}: launch grid {grid} exceeds x-dimension limit {MAX_GRID_X} \
         for {elements} elements with block size {block}"
    )))
}
/// Returns the length of a clause-offset array for `clause_cap` clauses.
///
/// The array has one entry per clause plus a trailing terminator entry, so its length is
/// `clause_cap + 1`.
///
/// # Errors
/// Returns [`XlogError::Kernel`] (prefixed with `context`) if `clause_cap + 1` overflows `u32`.
fn checked_clause_offset_span(clause_cap: u32, context: &str) -> Result<u32> {
    match clause_cap.checked_add(1) {
        Some(span) => Ok(span),
        None => Err(XlogError::Kernel(format!(
            "{context}: clause offset span overflow"
        ))),
    }
}
/// Options controlling the GPU equivalence check.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate must start from
/// `GpuEquivalenceConfig::default()` and then assign the public fields.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct GpuEquivalenceConfig {
/// Configuration handed to `GpuCdclSolver::new` for solving both queries.
pub cdcl: GpuCdclConfig,
/// When `true`, both queries are solved in one workspace sized for the larger of the two.
/// When `false`, the solver's non-workspace entry points are used. Defaults to `false`.
pub reuse_workspace: bool,
}
/// The two SAT queries that the equivalence check expects to be UNSAT.
pub struct GpuEquivalenceQueries {
/// `phi ∧ ¬C`, built by `build_phi_and_not_c`.
pub q1: GpuCnf,
/// `C ∧ ¬phi`, built by `build_c_and_not_phi`.
pub q2: GpuCnf,
/// One-element device buffer holding the first auxiliary variable of the `¬phi` encoding
/// inside `q2`. It is passed to the solver as a decision-range bound.
pub q2_unsat_var_base: TrackedCudaSlice<u32>,
}
/// Device-resident CNF encoding of a circuit, produced by `build_circuit_cnf`.
struct CircuitCnf {
/// The encoded clauses; its variable space includes the base formula's variables.
cnf: GpuCnf,
/// Exclusive prefix sum (one entry per circuit node) of per-node internal-variable counts.
/// It is later passed to `sat_xgcf_write_root_unit_clause`, presumably to locate the
/// root's literal.
internal_prefix: TrackedCudaSlice<u32>,
}
/// Encodes `circuit` as a CNF on the device, on top of the variables of a base formula.
///
/// Host side, this validates the inputs and derives capacity bounds:
/// `var_cap = base_var_cap + num_nodes`, `clause_cap = num_edges + 4 * num_nodes`,
/// `lit_cap = 3 * num_edges + 12 * num_nodes`.
///
/// Device side, it runs the pipeline: per-node counts -> capture last counts ->
/// in-place exclusive scans -> totals -> emit -> offset terminator.
/// NOTE(review): the per-node clause shapes (presumably a Tseitin-style encoding with one
/// auxiliary variable per internal node) live in the `sat_xgcf_cnf_*` kernels and are not
/// visible here.
///
/// * `base_num_vars` - device buffer holding the base formula's live variable count.
/// * `base_var_cap` - host-side bound on that count. It must be nonzero and at least
///   `circuit.max_var()`.
/// * `compile_needed` - device flag forwarded to the count/emit kernels. The ungated path
///   uploads `1`; behaviour for `0` is kernel-defined (TODO confirm).
///
/// This function only queues work; it never synchronizes the device.
///
/// # Errors
/// * `XlogError::Compilation` for an invalid base capacity, an out-of-range variable, an
///   empty circuit, an out-of-bounds root, or more than `u32::MAX` nodes.
/// * `XlogError::Kernel` for capacity overflow, a missing kernel, or a failed launch.
/// * Allocation and scan errors are propagated unchanged.
fn build_circuit_cnf(
provider: &Arc<CudaKernelProvider>,
circuit: &GpuXgcf,
base_num_vars: &TrackedCudaSlice<u32>,
base_var_cap: u32,
compile_needed: &TrackedCudaSlice<u32>,
) -> Result<CircuitCnf> {
if base_var_cap == 0 {
return Err(XlogError::Compilation(
"GPU equivalence verifier requires base_var_cap > 0".to_string(),
));
}
if circuit.max_var() > base_var_cap {
return Err(XlogError::Compilation(format!(
"Circuit references var {} but base CNF has only {} vars",
circuit.max_var(),
base_var_cap
)));
}
let num_nodes = circuit.num_nodes();
if num_nodes == 0 {
return Err(XlogError::Compilation(
"GPU equivalence verifier requires circuit with num_nodes > 0".to_string(),
));
}
if circuit.root() as usize >= num_nodes {
return Err(XlogError::Compilation(format!(
"GPU equivalence verifier: circuit root {} out of bounds (num_nodes={})",
circuit.root(),
num_nodes
)));
}
let num_nodes_u32 = u32::try_from(num_nodes).map_err(|_| {
XlogError::Compilation(format!(
"GPU equivalence verifier: circuit num_nodes {} exceeds u32::MAX",
num_nodes
))
})?;
let num_edges = circuit.num_edges();
// Capacity arithmetic is done in u64 and then narrowed, so overflow surfaces as an error.
let n64 = num_nodes as u64;
let e64 = num_edges as u64;
let var_cap = u32::try_from((base_var_cap as u64).saturating_add(n64))
.map_err(|_| XlogError::Kernel("Circuit CNF var capacity exceeds u32::MAX".to_string()))?;
let clause_cap =
u32::try_from(e64.checked_add(4u64.saturating_mul(n64)).ok_or_else(|| {
XlogError::Kernel("Circuit CNF clause capacity overflow".to_string())
})?)
.map_err(|_| {
XlogError::Kernel("Circuit CNF clause capacity exceeds u32::MAX".to_string())
})?;
let lit_cap = u32::try_from(
(3u64.saturating_mul(e64))
.checked_add(12u64.saturating_mul(n64))
.ok_or_else(|| {
XlogError::Kernel("Circuit CNF literal capacity overflow".to_string())
})?,
)
.map_err(|_| XlogError::Kernel("Circuit CNF literal capacity exceeds u32::MAX".to_string()))?;
let memory = provider.memory();
let device = provider.device().inner();
// Per-node counts (internal vars, clauses, literals) written by sat_xgcf_cnf_counts.
// These are later scanned in place into per-node base offsets.
let mut internal_prefix = memory.alloc::<u32>(num_nodes)?;
let mut clause_base = memory.alloc::<u32>(num_nodes)?;
let mut lit_base = memory.alloc::<u32>(num_nodes)?;
let counts_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_COUNTS)
.ok_or_else(|| XlogError::Kernel("sat_xgcf_cnf_counts kernel not found".to_string()))?;
let block = 256u32;
let grid = checked_launch_grid(num_nodes_u32, block, "sat_xgcf_cnf_counts")?;
unsafe {
counts_fn.clone().launch(
LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (block, 1, 1),
shared_mem_bytes: 0,
},
(
compile_needed,
circuit.node_type(),
circuit.child_offsets(),
num_nodes_u32,
&mut internal_prefix,
&mut clause_base,
&mut lit_base,
),
)
}
.map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_counts failed: {}", e)))?;
// Save the last per-node count of each array BEFORE the in-place exclusive scans below
// discard it. The totals kernel receives both the scanned arrays and these saved values,
// presumably computing total = scanned[n - 1] + last.
let mut internal_last = memory.alloc::<u32>(1)?;
let mut clause_last = memory.alloc::<u32>(1)?;
let mut lit_last = memory.alloc::<u32>(1)?;
let capture_last_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_CAPTURE_LAST_COUNTS)
.ok_or_else(|| {
XlogError::Kernel("sat_xgcf_cnf_capture_last_counts kernel not found".to_string())
})?;
unsafe {
capture_last_fn.clone().launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
(
&internal_prefix,
&clause_base,
&lit_base,
num_nodes_u32,
&mut internal_last,
&mut clause_last,
&mut lit_last,
),
)
}
.map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_capture_last_counts failed: {}", e)))?;
// Turn the per-node counts into per-node starting offsets.
provider.exclusive_scan_u32_inplace(&mut internal_prefix, num_nodes_u32)?;
provider.exclusive_scan_u32_inplace(&mut clause_base, num_nodes_u32)?;
provider.exclusive_scan_u32_inplace(&mut lit_base, num_nodes_u32)?;
// Output CNF buffers, sized by the host-side capacity bounds. The live counts are written
// on device by sat_xgcf_cnf_compute_totals. `d_offsets` has one extra slot for the
// terminator entry.
let d_num_vars = memory.alloc::<u32>(1)?;
let d_num_clauses = memory.alloc::<u32>(1)?;
let d_num_lits = memory.alloc::<u32>(1)?;
let mut d_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
let d_lits = memory.alloc::<i32>(lit_cap as usize)?;
let totals_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_COMPUTE_TOTALS)
.ok_or_else(|| {
XlogError::Kernel("sat_xgcf_cnf_compute_totals kernel not found".to_string())
})?;
// Positional raw kernel arguments. The order must match the kernel signature exactly, and
// every referenced local must stay alive until the launch call returns.
let mut totals_params: Vec<*mut c_void> = vec![
(&internal_prefix).as_kernel_param(),
(&clause_base).as_kernel_param(),
(&lit_base).as_kernel_param(),
(&internal_last).as_kernel_param(),
(&clause_last).as_kernel_param(),
(&lit_last).as_kernel_param(),
num_nodes_u32.as_kernel_param(),
(base_num_vars).as_kernel_param(),
clause_cap.as_kernel_param(),
lit_cap.as_kernel_param(),
(&d_num_vars).as_kernel_param(),
(&d_num_clauses).as_kernel_param(),
(&d_num_lits).as_kernel_param(),
];
unsafe {
totals_fn.clone().launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
&mut totals_params,
)
}
.map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_compute_totals failed: {}", e)))?;
// Emit each node's clauses at its scanned clause/literal base. This reuses the same grid
// as the counts kernel, since both cover num_nodes.
let emit_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_EMIT)
.ok_or_else(|| XlogError::Kernel("sat_xgcf_cnf_emit kernel not found".to_string()))?;
let mut params: Vec<*mut c_void> = vec![
compile_needed.as_kernel_param(),
circuit.node_type().as_kernel_param(),
circuit.child_offsets().as_kernel_param(),
circuit.child_indices().as_kernel_param(),
circuit.lit().as_kernel_param(),
circuit.decision_var().as_kernel_param(),
circuit.decision_child_false().as_kernel_param(),
circuit.decision_child_true().as_kernel_param(),
(&internal_prefix).as_kernel_param(),
(&clause_base).as_kernel_param(),
(&lit_base).as_kernel_param(),
(base_num_vars).as_kernel_param(),
num_nodes_u32.as_kernel_param(),
(&d_offsets).as_kernel_param(),
(&d_lits).as_kernel_param(),
];
unsafe {
emit_fn.clone().launch(
LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (block, 1, 1),
shared_mem_bytes: 0,
},
&mut params,
)
}
.map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_emit failed: {}", e)))?;
// Write the closing clause-offset entry. Presumably this sets
// offsets[num_clauses] = num_lits; TODO confirm in the kernel.
let term_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_CNF_WRITE_TERMINATOR)
.ok_or_else(|| {
XlogError::Kernel("sat_cnf_write_terminator kernel not found".to_string())
})?;
unsafe {
term_fn.clone().launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
(&mut d_offsets, &d_num_clauses, &d_num_lits),
)
}
.map_err(|e| XlogError::Kernel(format!("sat_cnf_write_terminator failed: {}", e)))?;
Ok(CircuitCnf {
cnf: GpuCnf {
var_cap,
clause_cap,
lit_cap,
num_vars: d_num_vars,
num_clauses: d_num_clauses,
num_lits: d_num_lits,
clause_offsets: d_offsets,
literals: d_lits,
},
internal_prefix,
})
}
/// Builds query `q1 = phi ∧ ¬C` on the device.
///
/// Output layout:
/// * `phi` is copied at clause/literal offset 0.
/// * `circuit_cnf` is copied at offsets `phi.num_clauses` / `phi.num_lits` (device scalars).
/// * `sat_xgcf_write_root_unit_clause` appends a unit clause on the circuit root with
///   `force_true = 0` (presumably the negated root literal) and writes the output counts.
///
/// Capacities:
/// * `clause_cap = phi.clause_cap + C.clause_cap + 1`
/// * `lit_cap = phi.lit_cap + C.lit_cap + 1`
/// * `var_cap = C.var_cap`, because the circuit CNF's variable space was sized from
///   `phi.var_cap`.
///
/// Only queues device work; no synchronization.
///
/// # Errors
/// Returns `XlogError::Kernel` on capacity overflow, a failed zero upload, a missing kernel,
/// or a failed launch. Allocation errors are propagated.
fn build_phi_and_not_c(
provider: &Arc<CudaKernelProvider>,
phi: &GpuCnf,
circuit: &GpuXgcf,
circuit_cnf: &CircuitCnf,
compile_needed: &TrackedCudaSlice<u32>,
) -> Result<GpuCnf> {
let device = provider.device().inner();
let memory = provider.memory();
let phi_clause_cap = phi.clause_cap;
let phi_lit_cap = phi.lit_cap;
// +1 clause and +1 literal for the root unit clause.
let clause_cap = u32::try_from(
(phi_clause_cap as u64)
.checked_add(circuit_cnf.cnf.clause_cap as u64)
.and_then(|v| v.checked_add(1))
.ok_or_else(|| XlogError::Kernel("phi ∧ ¬C clause capacity overflow".to_string()))?,
)
.map_err(|_| XlogError::Kernel("phi ∧ ¬C clause capacity exceeds u32::MAX".to_string()))?;
let lit_cap = u32::try_from(
(phi_lit_cap as u64)
.checked_add(circuit_cnf.cnf.lit_cap as u64)
.and_then(|v| v.checked_add(1))
.ok_or_else(|| XlogError::Kernel("phi ∧ ¬C literal capacity overflow".to_string()))?,
)
.map_err(|_| XlogError::Kernel("phi ∧ ¬C literal capacity exceeds u32::MAX".to_string()))?;
let var_cap = circuit_cnf.cnf.var_cap;
let out_num_vars = memory.alloc::<u32>(1)?;
let out_num_clauses = memory.alloc::<u32>(1)?;
let out_num_lits = memory.alloc::<u32>(1)?;
// Output slots of the unit-clause kernel that this query does not use. In
// build_c_and_not_phi, the same positions receive the ¬phi var/clause/literal bases.
let d_unused0 = memory.alloc::<u32>(1)?;
let d_unused1 = memory.alloc::<u32>(1)?;
let d_unused2 = memory.alloc::<u32>(1)?;
// Device-side zero. It serves as the destination offset for the phi copy and as the
// "no extra vars/clauses/lits" counts for the unit-clause kernel.
let mut d_zero = memory.alloc::<u32>(1)?;
provider
.htod_launch_metadata_sync_copy_into(&[0u32], &mut d_zero)
.map_err(|e| XlogError::Kernel(format!("Failed to upload zero: {}", e)))?;
let mut out_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
let mut out_lits = memory.alloc::<i32>(lit_cap as usize)?;
let copy_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_CNF_COPY_INTO)
.ok_or_else(|| XlogError::Kernel("sat_cnf_copy_into kernel not found".to_string()))?;
let block = 256u32;
// Launch enough threads to cover the larger of (clause offsets incl. terminator, literals).
let phi_copy_elems =
checked_clause_offset_span(phi_clause_cap, "sat_cnf_copy_into(phi)")?.max(phi_lit_cap);
let grid = checked_launch_grid(phi_copy_elems, block, "sat_cnf_copy_into(phi)")?;
// Copy phi to the start of the output (destination clause/literal base = 0).
unsafe {
copy_fn.clone().launch(
LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (block, 1, 1),
shared_mem_bytes: 0,
},
(
&phi.clause_offsets,
&phi.literals,
&phi.num_clauses,
&phi.num_lits,
phi.clause_cap,
phi.lit_cap,
&d_zero,
&d_zero,
clause_cap,
lit_cap,
&mut out_offsets,
&mut out_lits,
),
)
}
.map_err(|e| XlogError::Kernel(format!("sat_cnf_copy_into(phi) failed: {}", e)))?;
let circuit_copy_elems =
checked_clause_offset_span(circuit_cnf.cnf.clause_cap, "sat_cnf_copy_into(circuit)")?
.max(circuit_cnf.cnf.lit_cap);
let grid_c = checked_launch_grid(circuit_copy_elems, block, "sat_cnf_copy_into(circuit)")?;
// Append the circuit CNF right after phi (destination base = phi's live counts).
unsafe {
copy_fn.clone().launch(
LaunchConfig {
grid_dim: (grid_c, 1, 1),
block_dim: (block, 1, 1),
shared_mem_bytes: 0,
},
(
&circuit_cnf.cnf.clause_offsets,
&circuit_cnf.cnf.literals,
&circuit_cnf.cnf.num_clauses,
&circuit_cnf.cnf.num_lits,
circuit_cnf.cnf.clause_cap,
circuit_cnf.cnf.lit_cap,
&phi.num_clauses,
&phi.num_lits,
clause_cap,
lit_cap,
&mut out_offsets,
&mut out_lits,
),
)
}
.map_err(|e| XlogError::Kernel(format!("sat_cnf_copy_into(C) failed: {}", e)))?;
let unit_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_XGCF_WRITE_ROOT_UNIT_CLAUSE)
.ok_or_else(|| {
XlogError::Kernel("sat_xgcf_write_root_unit_clause kernel not found".to_string())
})?;
let root = circuit.root();
// 0 = do not force the root true (presumably asserts ¬root, giving ¬C).
let force_true: i32 = 0;
let out_var_cap = var_cap;
let out_clause_cap = clause_cap;
let out_lit_cap = lit_cap;
// Positional arguments, in kernel-signature order:
// gate flag, circuit node data, internal_prefix, base var count, root, force_true,
// base clause/lit offsets where C starts (phi's counts), C's counts,
// extra (vars, clauses, lits) = 0, output caps, output counts,
// unused ¬phi bases, output buffers.
let mut params: Vec<*mut c_void> = vec![
compile_needed.as_kernel_param(),
circuit.node_type().as_kernel_param(),
circuit.lit().as_kernel_param(),
(&circuit_cnf.internal_prefix).as_kernel_param(),
(&phi.num_vars).as_kernel_param(),
root.as_kernel_param(),
force_true.as_kernel_param(), (&phi.num_clauses).as_kernel_param(),
(&phi.num_lits).as_kernel_param(),
(&circuit_cnf.cnf.num_vars).as_kernel_param(),
(&circuit_cnf.cnf.num_clauses).as_kernel_param(),
(&circuit_cnf.cnf.num_lits).as_kernel_param(),
(&d_zero).as_kernel_param(), (&d_zero).as_kernel_param(), (&d_zero).as_kernel_param(), out_var_cap.as_kernel_param(),
out_clause_cap.as_kernel_param(),
out_lit_cap.as_kernel_param(),
(&out_num_vars).as_kernel_param(),
(&out_num_clauses).as_kernel_param(),
(&out_num_lits).as_kernel_param(),
(&d_unused0).as_kernel_param(),
(&d_unused1).as_kernel_param(),
(&d_unused2).as_kernel_param(),
(&out_offsets).as_kernel_param(),
(&out_lits).as_kernel_param(),
];
unsafe {
unit_fn.clone().launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
&mut params,
)
}
.map_err(|e| XlogError::Kernel(format!("sat_xgcf_write_root_unit_clause failed: {}", e)))?;
Ok(GpuCnf {
var_cap,
clause_cap,
lit_cap,
num_vars: out_num_vars,
num_clauses: out_num_clauses,
num_lits: out_num_lits,
clause_offsets: out_offsets,
literals: out_lits,
})
}
/// Builds query `q2 = C ∧ ¬phi` on the device.
///
/// Output layout:
/// * the circuit CNF at clause/literal offset 0;
/// * a unit clause on the circuit root, written with `force_true = 1`;
/// * the `¬phi` encoding, emitted by `sat_emit_not_phi`.
///
/// The `¬phi` encoding is sized for:
/// * one auxiliary variable per `phi` clause (`var_cap = C.var_cap + phi.clause_cap`);
/// * `phi.lit_cap + phi.clause_cap + 1` clauses;
/// * `3 * phi.lit_cap + 2 * phi.clause_cap` literals.
///
/// NOTE(review): those bounds match one "clause i is falsified" variable u_i per clause,
/// with a binary clause (¬u_i ∨ ¬l) per literal, a reverse clause (u_i ∨ l_1 ∨ … ∨ l_k) per
/// clause, and one clause (u_1 ∨ … ∨ u_m). Confirm against the kernel.
///
/// Returns the query and a one-element device buffer holding the first `¬phi` auxiliary
/// variable (written by `sat_xgcf_write_root_unit_clause`).
///
/// Only queues device work; no synchronization.
///
/// # Errors
/// Returns `XlogError::Kernel` on capacity overflow, a failed zero upload, a missing kernel,
/// or a failed launch. Allocation errors are propagated.
fn build_c_and_not_phi(
provider: &Arc<CudaKernelProvider>,
phi: &GpuCnf,
circuit: &GpuXgcf,
circuit_cnf: &CircuitCnf,
compile_needed: &TrackedCudaSlice<u32>,
) -> Result<(GpuCnf, TrackedCudaSlice<u32>)> {
let device = provider.device().inner();
let memory = provider.memory();
let phi_clause_cap = phi.clause_cap;
let phi_lit_cap = phi.lit_cap;
let notphi_clause_cap = u32::try_from(
(phi_lit_cap as u64)
.checked_add(phi_clause_cap as u64)
.and_then(|v| v.checked_add(1))
.ok_or_else(|| XlogError::Kernel("¬phi clause count overflow".to_string()))?,
)
.map_err(|_| XlogError::Kernel("¬phi clause count exceeds u32::MAX".to_string()))?;
let notphi_lit_cap = u32::try_from(
(phi_lit_cap as u64)
.checked_mul(3)
.and_then(|v| v.checked_add(2u64.saturating_mul(phi_clause_cap as u64)))
.ok_or_else(|| XlogError::Kernel("¬phi literal count overflow".to_string()))?,
)
.map_err(|_| XlogError::Kernel("¬phi literal count exceeds u32::MAX".to_string()))?;
// One auxiliary variable per phi clause on top of the circuit CNF's variables.
let var_cap = circuit_cnf
.cnf
.var_cap
.checked_add(phi_clause_cap)
.ok_or_else(|| XlogError::Kernel("C ∧ ¬phi var capacity overflow".to_string()))?;
// Circuit clauses + 1 root unit clause + ¬phi clauses (literals analogously).
let clause_cap = u32::try_from(
(circuit_cnf.cnf.clause_cap as u64)
.checked_add(1)
.and_then(|v| v.checked_add(notphi_clause_cap as u64))
.ok_or_else(|| XlogError::Kernel("C ∧ ¬phi clause capacity overflow".to_string()))?,
)
.map_err(|_| XlogError::Kernel("C ∧ ¬phi clause capacity exceeds u32::MAX".to_string()))?;
let lit_cap = u32::try_from(
(circuit_cnf.cnf.lit_cap as u64)
.checked_add(1)
.and_then(|v| v.checked_add(notphi_lit_cap as u64))
.ok_or_else(|| XlogError::Kernel("C ∧ ¬phi literal capacity overflow".to_string()))?,
)
.map_err(|_| XlogError::Kernel("C ∧ ¬phi literal capacity exceeds u32::MAX".to_string()))?;
let out_num_vars = memory.alloc::<u32>(1)?;
let out_num_clauses = memory.alloc::<u32>(1)?;
let out_num_lits = memory.alloc::<u32>(1)?;
// Device-side zero: destination base for the circuit copy and the unit-clause base offsets.
let mut d_zero = memory.alloc::<u32>(1)?;
provider
.htod_launch_metadata_sync_copy_into(&[0u32], &mut d_zero)
.map_err(|e| XlogError::Kernel(format!("Failed to upload zero: {}", e)))?;
// Vars/clauses/literals the ¬phi encoding adds, computed on device by sat_not_phi_counts
// from phi's live counts.
let mut d_extra_num_vars = memory.alloc::<u32>(1)?;
let mut d_extra_num_clauses = memory.alloc::<u32>(1)?;
let mut d_extra_num_lits = memory.alloc::<u32>(1)?;
// Written by the unit-clause kernel: where the ¬phi auxiliaries, clauses and literals
// start. sat_emit_not_phi reads them.
let d_unsat_var_base = memory.alloc::<u32>(1)?;
let d_notphi_clause_base = memory.alloc::<u32>(1)?;
let d_notphi_lit_base = memory.alloc::<u32>(1)?;
let mut out_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
let mut out_lits = memory.alloc::<i32>(lit_cap as usize)?;
let copy_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_CNF_COPY_INTO)
.ok_or_else(|| XlogError::Kernel("sat_cnf_copy_into kernel not found".to_string()))?;
let block = 256u32;
let circuit_copy_elems =
checked_clause_offset_span(circuit_cnf.cnf.clause_cap, "sat_cnf_copy_into(circuit)")?
.max(circuit_cnf.cnf.lit_cap);
let grid = checked_launch_grid(circuit_copy_elems, block, "sat_cnf_copy_into(circuit)")?;
// Copy the circuit CNF to the start of the output (destination base = 0).
unsafe {
copy_fn.clone().launch(
LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (block, 1, 1),
shared_mem_bytes: 0,
},
(
&circuit_cnf.cnf.clause_offsets,
&circuit_cnf.cnf.literals,
&circuit_cnf.cnf.num_clauses,
&circuit_cnf.cnf.num_lits,
circuit_cnf.cnf.clause_cap,
circuit_cnf.cnf.lit_cap,
&d_zero,
&d_zero,
clause_cap,
lit_cap,
&mut out_offsets,
&mut out_lits,
),
)
}
.map_err(|e| XlogError::Kernel(format!("sat_cnf_copy_into(C) failed: {}", e)))?;
let notphi_counts_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_NOT_PHI_COUNTS)
.ok_or_else(|| XlogError::Kernel("sat_not_phi_counts kernel not found".to_string()))?;
// Must run before the unit-clause kernel, which consumes the d_extra_* counts.
unsafe {
notphi_counts_fn.clone().launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
(
compile_needed,
&phi.num_clauses,
&phi.num_lits,
&mut d_extra_num_vars,
&mut d_extra_num_clauses,
&mut d_extra_num_lits,
),
)
}
.map_err(|e| XlogError::Kernel(format!("sat_not_phi_counts failed: {}", e)))?;
let unit_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_XGCF_WRITE_ROOT_UNIT_CLAUSE)
.ok_or_else(|| {
XlogError::Kernel("sat_xgcf_write_root_unit_clause kernel not found".to_string())
})?;
let root = circuit.root();
// 1 = force the root literal true (asserts C).
let force_true: i32 = 1;
let out_var_cap = var_cap;
let out_clause_cap = clause_cap;
let out_lit_cap = lit_cap;
// Same positional layout as in build_phi_and_not_c, except:
// base clause/lit offsets are 0 (C sits at the start),
// extras are the ¬phi counts,
// and the three trailing scalar outputs receive the ¬phi bases.
let mut params: Vec<*mut c_void> = vec![
compile_needed.as_kernel_param(),
circuit.node_type().as_kernel_param(),
circuit.lit().as_kernel_param(),
(&circuit_cnf.internal_prefix).as_kernel_param(),
(&phi.num_vars).as_kernel_param(),
root.as_kernel_param(),
force_true.as_kernel_param(), (&d_zero).as_kernel_param(), (&d_zero).as_kernel_param(), (&circuit_cnf.cnf.num_vars).as_kernel_param(),
(&circuit_cnf.cnf.num_clauses).as_kernel_param(),
(&circuit_cnf.cnf.num_lits).as_kernel_param(),
(&d_extra_num_vars).as_kernel_param(), (&d_extra_num_clauses).as_kernel_param(), (&d_extra_num_lits).as_kernel_param(), out_var_cap.as_kernel_param(),
out_clause_cap.as_kernel_param(),
out_lit_cap.as_kernel_param(),
(&out_num_vars).as_kernel_param(),
(&out_num_clauses).as_kernel_param(),
(&out_num_lits).as_kernel_param(),
(&d_unsat_var_base).as_kernel_param(),
(&d_notphi_clause_base).as_kernel_param(),
(&d_notphi_lit_base).as_kernel_param(),
(&out_offsets).as_kernel_param(),
(&out_lits).as_kernel_param(),
];
unsafe {
unit_fn.clone().launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
&mut params,
)
}
.map_err(|e| XlogError::Kernel(format!("sat_xgcf_write_root_unit_clause failed: {}", e)))?;
let not_phi_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_EMIT_NOT_PHI)
.ok_or_else(|| XlogError::Kernel("sat_emit_not_phi kernel not found".to_string()))?;
// Re-binds `block` to the same value used above.
let block = 256u32;
// Sized by phi's clause capacity; the kernel receives phi.num_clauses for the live count.
let grid = checked_launch_grid(phi_clause_cap, block, "sat_emit_not_phi")?;
unsafe {
not_phi_fn.clone().launch(
LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (block, 1, 1),
shared_mem_bytes: 0,
},
(
compile_needed,
&phi.clause_offsets,
&phi.literals,
&phi.num_clauses,
&d_unsat_var_base,
&d_notphi_clause_base,
&d_notphi_lit_base,
&mut out_offsets,
&mut out_lits,
),
)
}
.map_err(|e| XlogError::Kernel(format!("sat_emit_not_phi failed: {}", e)))?;
Ok((
GpuCnf {
var_cap,
clause_cap,
lit_cap,
num_vars: out_num_vars,
num_clauses: out_num_clauses,
num_lits: out_num_lits,
clause_offsets: out_offsets,
literals: out_lits,
},
d_unsat_var_base,
))
}
pub(crate) fn check_equivalence_gpu(
phi: &GpuCnf,
phi_decision_var_limit: &TrackedCudaSlice<u32>,
circuit: &GpuXgcf,
provider: &Arc<CudaKernelProvider>,
config: GpuEquivalenceConfig,
) -> Result<()> {
let queries = build_equivalence_queries_gpu(phi, circuit, provider)?;
#[cfg(debug_assertions)]
{
validate_cnf_gpu(&queries.q1, provider.as_ref())?;
validate_cnf_gpu(&queries.q2, provider.as_ref())?;
}
let solver = GpuCdclSolver::new(provider.clone(), config.cdcl);
if config.reuse_workspace {
let max_var_cap = std::cmp::max(queries.q1.var_cap, queries.q2.var_cap);
let max_clause_cap = std::cmp::max(queries.q1.clause_cap, queries.q2.clause_cap);
let mut ws = solver.new_workspace(max_var_cap, max_clause_cap)?;
solver.solve_expect_unsat_with_branch_limit_ws(
&mut ws,
&queries.q1,
phi_decision_var_limit,
)?;
solver.solve_expect_unsat_with_decision_ranges_ws(
&mut ws,
&queries.q2,
phi_decision_var_limit,
&queries.q2_unsat_var_base,
&phi.num_clauses,
)?;
} else {
solver.solve_expect_unsat_with_branch_limit(&queries.q1, phi_decision_var_limit)?;
solver.solve_expect_unsat_with_decision_ranges(
&queries.q2,
phi_decision_var_limit,
&queries.q2_unsat_var_base,
&phi.num_clauses,
)?;
}
Ok(())
}
/// Builds both equivalence queries for `phi` versus `circuit` on the device, ungated.
///
/// A one-element device flag holding `1` is uploaded and used as `compile_needed` for:
/// * the circuit encoding;
/// * `q1 = phi ∧ ¬C`;
/// * `q2 = C ∧ ¬phi`.
///
/// The gated path instead takes this flag from its caller.
///
/// # Errors
/// Returns `XlogError::Kernel` if the flag upload fails. Otherwise propagates allocation
/// errors and any error from the three builders.
pub fn build_equivalence_queries_gpu(
    phi: &GpuCnf,
    circuit: &GpuXgcf,
    provider: &Arc<CudaKernelProvider>,
) -> Result<GpuEquivalenceQueries> {
    let mut always_on = provider.memory().alloc::<u32>(1)?;
    provider
        .htod_launch_metadata_sync_copy_into(&[1u32], &mut always_on)
        .map_err(|e| XlogError::Kernel(format!("Failed to upload compile_needed=1: {}", e)))?;

    let circuit_cnf =
        build_circuit_cnf(provider, circuit, &phi.num_vars, phi.var_cap, &always_on)?;
    let q1 = build_phi_and_not_c(provider, phi, circuit, &circuit_cnf, &always_on)?;
    let (q2, q2_unsat_var_base) =
        build_c_and_not_phi(provider, phi, circuit, &circuit_cnf, &always_on)?;

    Ok(GpuEquivalenceQueries {
        q1,
        q2,
        q2_unsat_var_base,
    })
}
/// Gated variant of [`check_equivalence_gpu`].
///
/// Every builder and solver call receives the caller's device-side `compile_needed` flag
/// instead of a constant `1`.
///
/// In debug builds:
/// * each stage is followed by a device synchronization and a progress line on stderr, so an
///   asynchronous kernel failure is attributed to the stage that queued it;
/// * both queries are checked with `validate_cnf_gpu` before solving.
///
/// # Errors
/// Propagates errors from query construction, debug synchronization/validation, workspace
/// allocation and both `solve_expect_unsat_*_gated` calls.
pub(crate) fn check_equivalence_gpu_gated(
    phi: &GpuCnf,
    phi_decision_var_limit: &TrackedCudaSlice<u32>,
    circuit: &GpuXgcf,
    provider: &Arc<CudaKernelProvider>,
    config: GpuEquivalenceConfig,
    compile_needed: &TrackedCudaSlice<u32>,
) -> Result<()> {
    // Debug-only checkpoint: wait for queued device work and tag any failure with `stage`.
    #[cfg(debug_assertions)]
    fn sync_after(provider: &Arc<CudaKernelProvider>, stage: &str) -> Result<()> {
        provider
            .device()
            .synchronize()
            .map_err(|e| XlogError::Kernel(format!("sync after {} failed: {}", stage, e)))?;
        Ok(())
    }

    #[cfg(debug_assertions)]
    eprintln!("[xlog-prob] equivalence: build_circuit_cnf");
    let circuit_cnf =
        build_circuit_cnf(provider, circuit, &phi.num_vars, phi.var_cap, compile_needed)?;

    #[cfg(debug_assertions)]
    {
        sync_after(provider, "build_circuit_cnf")?;
        eprintln!("[xlog-prob] equivalence: build_phi_and_not_c");
    }
    let q1 = build_phi_and_not_c(provider, phi, circuit, &circuit_cnf, compile_needed)?;

    #[cfg(debug_assertions)]
    {
        sync_after(provider, "build_phi_and_not_c")?;
        eprintln!("[xlog-prob] equivalence: build_c_and_not_phi");
    }
    let (q2, q2_unsat_var_base) =
        build_c_and_not_phi(provider, phi, circuit, &circuit_cnf, compile_needed)?;

    #[cfg(debug_assertions)]
    {
        sync_after(provider, "build_c_and_not_phi")?;
        eprintln!(
            "[xlog-prob] equivalence: caps: phi(v={} c={} l={}) circuit_cnf(v={} c={} l={}) q1(v={} c={} l={}) q2(v={} c={} l={})",
            phi.var_cap,
            phi.clause_cap,
            phi.lit_cap,
            circuit_cnf.cnf.var_cap,
            circuit_cnf.cnf.clause_cap,
            circuit_cnf.cnf.lit_cap,
            q1.var_cap,
            q1.clause_cap,
            q1.lit_cap,
            q2.var_cap,
            q2.clause_cap,
            q2.lit_cap,
        );
        eprintln!("[xlog-prob] equivalence: solve_expect_unsat q1");
        validate_cnf_gpu(&q1, provider.as_ref())?;
        validate_cnf_gpu(&q2, provider.as_ref())?;
    }

    let solver = GpuCdclSolver::new(provider.clone(), config.cdcl);
    // With `reuse_workspace`, one workspace sized for the larger query serves both solves.
    let mut shared_ws = if config.reuse_workspace {
        Some(solver.new_workspace(
            q1.var_cap.max(q2.var_cap),
            q1.clause_cap.max(q2.clause_cap),
        )?)
    } else {
        None
    };

    match shared_ws.as_mut() {
        Some(ws) => {
            solver.solve_expect_unsat_with_branch_limit_gated_ws(
                ws,
                &q1,
                compile_needed,
                phi_decision_var_limit,
            )?;
        }
        None => {
            solver.solve_expect_unsat_with_branch_limit_gated(
                &q1,
                compile_needed,
                phi_decision_var_limit,
            )?;
        }
    }

    #[cfg(debug_assertions)]
    {
        sync_after(provider, "solve_expect_unsat(q1)")?;
        eprintln!("[xlog-prob] equivalence: solve_expect_unsat q2");
    }

    match shared_ws.as_mut() {
        Some(ws) => {
            solver.solve_expect_unsat_with_decision_ranges_gated_ws(
                ws,
                &q2,
                compile_needed,
                phi_decision_var_limit,
                &q2_unsat_var_base,
                &phi.num_clauses,
            )?;
        }
        None => {
            solver.solve_expect_unsat_with_decision_ranges_gated(
                &q2,
                compile_needed,
                phi_decision_var_limit,
                &q2_unsat_var_base,
                &phi.num_clauses,
            )?;
        }
    }
    // Free the shared workspace before the final debug checkpoint.
    drop(shared_ws);

    #[cfg(debug_assertions)]
    {
        sync_after(provider, "solve_expect_unsat(q2)")?;
        eprintln!("[xlog-prob] equivalence: done");
    }
    Ok(())
}
/// Public entry point for the ungated GPU equivalence check between `phi` and `circuit`.
///
/// Delegates directly to `check_equivalence_gpu` with the same arguments.
///
/// # Errors
/// Returns whatever `check_equivalence_gpu` returns.
pub fn validate_equivalence_gpu(
phi: &GpuCnf,
phi_decision_var_limit: &TrackedCudaSlice<u32>,
circuit: &GpuXgcf,
provider: &Arc<CudaKernelProvider>,
config: GpuEquivalenceConfig,
) -> Result<()> {
check_equivalence_gpu(phi, phi_decision_var_limit, circuit, provider, config)
}
/// Public entry point for the gated GPU equivalence check.
///
/// `compile_needed` is a device-side flag forwarded to every builder kernel and solver call.
/// The function delegates directly to `check_equivalence_gpu_gated`.
///
/// # Errors
/// Returns whatever `check_equivalence_gpu_gated` returns.
pub fn validate_equivalence_gpu_gated(
phi: &GpuCnf,
phi_decision_var_limit: &TrackedCudaSlice<u32>,
circuit: &GpuXgcf,
provider: &Arc<CudaKernelProvider>,
config: GpuEquivalenceConfig,
compile_needed: &TrackedCudaSlice<u32>,
) -> Result<()> {
check_equivalence_gpu_gated(
phi,
phi_decision_var_limit,
circuit,
provider,
config,
compile_needed,
)
}