use std::ffi::c_void;
use std::sync::Arc;
use cudarc::driver::LaunchConfig;
use xlog_core::{Result, XlogError};
use xlog_cuda::memory::TrackedCudaSlice;
use xlog_cuda::provider::{sat_kernels, SAT_MODULE};
use xlog_cuda::{AsKernelParam, CudaKernelProvider, DeviceSlice, LaunchAsync};
use crate::gpu_cnf::GpuCnf;
/// `out_status` value that the UNSAT-expecting entry points require after
/// running `sat_cdcl_solve`.
const SAT_STATUS_UNSAT: i32 = 0;
/// `out_status` value that the SAT-expecting entry points require after
/// running `sat_cdcl_solve`.
const SAT_STATUS_SAT: i32 = 1;
/// Device buffers retained from one `sat_cdcl_solve` launch performed by
/// `launch_cdcl_with_decision_ranges_gated`.
///
/// Only these buffers outlive the launch function; its other per-run buffers
/// are dropped when that function returns.
struct GpuCdclRun {
    /// Variable assignment array (allocated with `var_cap + 1` entries).
    assignment: TrackedCudaSlice<i8>,
    // Never read (hence `allow(dead_code)`); kept alive as long as the run.
    #[allow(dead_code)]
    decision_heap: TrackedCudaSlice<u32>,
    // Never read (hence `allow(dead_code)`); kept alive as long as the run.
    #[allow(dead_code)]
    decision_heap_pos: TrackedCudaSlice<u32>,
    /// Learned-clause offsets (allocated with `max_learned_clauses + 1` entries).
    learned_offsets: TrackedCudaSlice<u32>,
    /// Learned-clause literal arena (allocated with `max_learned_lits` entries).
    learned_lits: TrackedCudaSlice<i32>,
    /// Proof offsets (allocated with `max_learned_clauses + 1` entries); passed to
    /// the UNSAT proof-marking kernel.
    proof_offsets: TrackedCudaSlice<u32>,
    /// Proof data arena (allocated with `max_proof_u32` words); passed to the
    /// UNSAT proof-marking kernel.
    proof_data: TrackedCudaSlice<u32>,
    /// Single-element solver status, compared against `SAT_STATUS_SAT` /
    /// `SAT_STATUS_UNSAT`.
    out_status: TrackedCudaSlice<i32>,
    /// Single-element solver error code; `require_expected_status` treats any
    /// non-zero value as failure.
    out_error: TrackedCudaSlice<i32>,
    /// Single-element learned-clause count; the same buffer is also passed to
    /// the kernel as the learned-import count.
    out_learned_count: TrackedCudaSlice<u32>,
}
/// Capacity and heuristic limits for [`GpuCdclSolver`].
///
/// Every field is forwarded to the `sat_cdcl_solve` kernel. The three
/// `max_*` fields also size the device buffers the solver allocates, and
/// solver entry points reject configs with `max_learned_clauses == 0`,
/// `max_learned_lits == 0` or `max_proof_u32 < 2`.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct GpuCdclConfig {
    /// Maximum number of learned clauses; sizes every per-learned-clause buffer.
    pub max_learned_clauses: u32,
    /// Total literal capacity shared by all learned clauses.
    pub max_learned_lits: u32,
    /// Capacity of the proof-data buffer, in `u32` words.
    pub max_proof_u32: u32,
    /// Restart schedule parameter passed to the kernel (presumably the
    /// conflict count of the first restart interval — confirm against the kernel).
    pub restart_base: u32,
    /// Learned-clause reduction interval passed to the kernel (presumably in
    /// conflicts — confirm against the kernel).
    pub reduce_interval: u32,
}

impl Default for GpuCdclConfig {
    /// 32K learned clauses, 256K learned literals, 1M proof words,
    /// restart base 100 and reduction interval 2000.
    fn default() -> Self {
        const K: u32 = 1024;
        GpuCdclConfig {
            max_learned_clauses: 32 * K,
            max_learned_lits: 256 * K,
            max_proof_u32: K * K,
            restart_base: 100,
            reduce_interval: 2_000,
        }
    }
}
/// GPU CDCL SAT solver bound to a single CUDA kernel provider.
///
/// Device buffers handed to the solver must come from `provider`'s memory
/// manager; this is checked at every launch via `require_slice_on_provider`.
pub struct GpuCdclSolver {
    /// Provider used for kernel lookup, launches, allocation and transfers.
    provider: Arc<CudaKernelProvider>,
    /// Capacity and heuristic limits forwarded to the kernels.
    config: GpuCdclConfig,
}
pub struct GpuCdclWorkspace {
pub(crate) var_cap: usize,
pub(crate) clause_total_cap: usize,
pub(crate) assign: TrackedCudaSlice<i8>,
pub(crate) level: TrackedCudaSlice<u32>,
pub(crate) reason: TrackedCudaSlice<i32>,
pub(crate) var_activity: TrackedCudaSlice<u32>,
pub(crate) var_phase: TrackedCudaSlice<i8>,
pub(crate) decision_heap: TrackedCudaSlice<u32>,
pub(crate) decision_heap_pos: TrackedCudaSlice<u32>,
pub(crate) trail: TrackedCudaSlice<i32>,
pub(crate) trail_lim: TrackedCudaSlice<u32>,
pub(crate) seen: TrackedCudaSlice<u8>,
pub(crate) learnt_tmp: TrackedCudaSlice<i32>,
pub(crate) proof_vars_tmp: TrackedCudaSlice<u32>,
pub(crate) proof_reason_tmp: TrackedCudaSlice<u32>,
pub(crate) watch0_pos: TrackedCudaSlice<u32>, pub(crate) watch1_pos: TrackedCudaSlice<u32>, pub(crate) watch_head: TrackedCudaSlice<i32>, pub(crate) watch_next: TrackedCudaSlice<i32>, pub(crate) watch_prev: TrackedCudaSlice<i32>,
pub(crate) learned_offsets: TrackedCudaSlice<u32>, pub(crate) learned_lits: TrackedCudaSlice<i32>, pub(crate) learned_deleted: TrackedCudaSlice<u8>, pub(crate) learned_lbd: TrackedCudaSlice<u32>, pub(crate) learned_activity: TrackedCudaSlice<u32>, pub(crate) learned_locked: TrackedCudaSlice<u8>,
pub(crate) proof_offsets: TrackedCudaSlice<u32>, pub(crate) proof_data: TrackedCudaSlice<u32>,
pub(crate) out_status: TrackedCudaSlice<i32>, pub(crate) out_error: TrackedCudaSlice<i32>, pub(crate) out_learned_count: TrackedCudaSlice<u32>, }
impl GpuCdclWorkspace {
#[inline]
pub(crate) fn reset_for_solve(&mut self) {
}
#[inline]
#[allow(dead_code)] pub(crate) fn var_cap(&self) -> usize {
self.var_cap
}
#[inline]
#[allow(dead_code)] pub(crate) fn clause_total_cap(&self) -> usize {
self.clause_total_cap
}
#[inline]
pub fn assign_device_ptr(&self) -> cudarc::driver::sys::CUdeviceptr {
self.assign.device_ptr_value()
}
}
/// Raw device buffers from a solver run, returned by the `solve_raw_*`
/// methods after they synchronize the device, without interpreting the status.
pub struct GpuCdclRawOutput {
    /// Variable assignment array (`var_cap + 1` entries).
    pub assignment: TrackedCudaSlice<i8>,
    /// Single-element solver status (compare against the SAT/UNSAT status codes).
    pub out_status: TrackedCudaSlice<i32>,
    /// Single-element solver error code; the `solve_expect_*` paths treat any
    /// non-zero value as failure.
    pub out_error: TrackedCudaSlice<i32>,
    /// Single-element learned-clause count.
    pub out_learned_count: TrackedCudaSlice<u32>,
}
/// Returns `value + 1`, or `XlogError::Kernel("<context> length overflow")`
/// if the sum does not fit in `usize`.
fn checked_solver_len_add_one(context: &str, value: usize) -> Result<usize> {
    match value.checked_add(1) {
        Some(len) => Ok(len),
        None => Err(XlogError::Kernel(format!("{context} length overflow"))),
    }
}

/// Returns `2 * value`, or `XlogError::Kernel("<context> length overflow")`
/// if the product does not fit in `usize`.
fn checked_solver_len_double(context: &str, value: usize) -> Result<usize> {
    match value.checked_add(value) {
        Some(len) => Ok(len),
        None => Err(XlogError::Kernel(format!("{context} length overflow"))),
    }
}
impl GpuCdclSolver {
/// Reads the solver's status and error scalars back to the host and fails
/// unless the error is `0` and the status equals `expected_status`.
///
/// Status is read before error. `context` prefixes every error message.
///
/// # Errors
///
/// `XlogError::Kernel` if either readback fails or the values do not match.
fn require_expected_status(
    &self,
    out_status: &TrackedCudaSlice<i32>,
    out_error: &TrackedCudaSlice<i32>,
    expected_status: i32,
    context: &'static str,
) -> Result<()> {
    let provider = &self.provider;
    let actual_status = provider
        .dtoh_scalar_untracked(out_status, 0)
        .map_err(|e| XlogError::Kernel(format!("Failed to read {context} status: {e}")))?;
    let actual_error = provider
        .dtoh_scalar_untracked(out_error, 0)
        .map_err(|e| XlogError::Kernel(format!("Failed to read {context} error: {e}")))?;
    let as_expected = actual_error == 0 && actual_status == expected_status;
    if as_expected {
        Ok(())
    } else {
        Err(XlogError::Kernel(format!(
            "{context} expected status {expected_status}, got status {actual_status} error {actual_error}"
        )))
    }
}
/// Address of the provider's memory manager, used as an identity token when
/// checking which manager owns a slice.
fn provider_memory_ptr(&self) -> usize {
    let manager = self.provider.memory();
    Arc::as_ptr(manager) as usize
}
/// Fails unless `slice` was allocated by this solver's memory manager.
///
/// # Errors
///
/// `XlogError::UnsupportedEpistemicConstruct` naming `name` and both manager
/// addresses when the slice belongs to a different manager.
fn require_slice_on_provider<T: cudarc::driver::DeviceRepr>(
    &self,
    name: &'static str,
    slice: &TrackedCudaSlice<T>,
) -> Result<()> {
    let expected_memory = self.provider_memory_ptr();
    let actual_memory = slice.memory_manager_ptr_value();
    if actual_memory == expected_memory {
        return Ok(());
    }
    Err(XlogError::UnsupportedEpistemicConstruct {
        construct: "GPU CDCL solver provider boundary".to_string(),
        context: format!(
            "{name} belongs to memory manager {actual_memory}, expected {expected_memory}"
        ),
    })
}
/// Verifies that every device buffer in `ws` belongs to this solver's memory
/// manager.
///
/// Buffers are checked in declaration order; the first mismatch is reported
/// under the name `workspace.<field>`.
pub(crate) fn require_workspace_on_provider(&self, ws: &GpuCdclWorkspace) -> Result<()> {
    macro_rules! check_fields {
        ($($field:ident),+ $(,)?) => {
            $(
                self.require_slice_on_provider(
                    concat!("workspace.", stringify!($field)),
                    &ws.$field,
                )?;
            )+
        };
    }
    check_fields!(
        assign,
        level,
        reason,
        var_activity,
        var_phase,
        decision_heap,
        decision_heap_pos,
        trail,
        trail_lim,
        seen,
        learnt_tmp,
        proof_vars_tmp,
        proof_reason_tmp,
        watch0_pos,
        watch1_pos,
        watch_head,
        watch_next,
        watch_prev,
        learned_offsets,
        learned_lits,
        learned_deleted,
        learned_lbd,
        learned_activity,
        learned_locked,
        proof_offsets,
        proof_data,
        out_status,
        out_error,
        out_learned_count,
    );
    Ok(())
}
pub(crate) fn require_workspace_capacity_for_cnf(
&self,
ws: &GpuCdclWorkspace,
var_cap: u32,
clause_cap: u32,
) -> Result<()> {
let num_vars_cap = var_cap as usize;
if num_vars_cap > ws.var_cap {
return Err(XlogError::Kernel(format!(
"CNF var_cap {} exceeds workspace var_cap {}",
num_vars_cap, ws.var_cap
)));
}
let max_learned_clauses = self.config.max_learned_clauses as usize;
let max_total_clauses = (clause_cap as usize)
.checked_add(max_learned_clauses)
.ok_or_else(|| XlogError::Kernel("SAT clause capacity overflow".to_string()))?;
if max_total_clauses > ws.clause_total_cap {
return Err(XlogError::Kernel(format!(
"CNF clause_total {} exceeds workspace clause_total_cap {}",
max_total_clauses, ws.clause_total_cap
)));
}
Ok(())
}
pub fn new(provider: Arc<CudaKernelProvider>, config: GpuCdclConfig) -> Self {
Self { provider, config }
}
/// Allocates a reusable [`GpuCdclWorkspace`] for CNFs with at most
/// `max_var_cap` variables and `max_clause_cap` original clauses.
///
/// Buffer sizes:
/// - per-variable buffers: `max_var_cap + 1` entries;
/// - clause-indexed watch buffers: `max_clause_cap + max_learned_clauses`;
/// - learned-clause and proof arenas: this solver's [`GpuCdclConfig`].
///
/// The learned-clause import count (`out_learned_count`) is zeroed here. A
/// first solve that imports existing learned clauses therefore reads a count
/// of zero rather than whatever the fresh allocation happens to contain.
///
/// # Errors
///
/// - `XlogError::Compilation` if `max_var_cap == 0`, `max_learned_clauses == 0`,
///   `max_learned_lits == 0`, or `max_proof_u32 < 2`.
/// - `XlogError::Kernel` on size overflow or if the import count cannot be
///   initialized.
/// - Allocation errors are propagated.
pub fn new_workspace(&self, max_var_cap: u32, max_clause_cap: u32) -> Result<GpuCdclWorkspace> {
    if max_var_cap == 0 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver workspace requires max_var_cap > 0".to_string(),
        ));
    }
    if self.config.max_learned_clauses == 0 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires max_learned_clauses > 0".to_string(),
        ));
    }
    if self.config.max_learned_lits == 0 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires max_learned_lits > 0".to_string(),
        ));
    }
    if self.config.max_proof_u32 < 2 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires max_proof_u32 >= 2".to_string(),
        ));
    }
    let num_vars_cap = max_var_cap as usize;
    let num_clauses_cap = max_clause_cap as usize;
    let max_learned_clauses = self.config.max_learned_clauses as usize;
    let max_learned_lits = self.config.max_learned_lits as usize;
    let max_proof_u32 = self.config.max_proof_u32 as usize;
    let max_total_clauses = num_clauses_cap
        .checked_add(max_learned_clauses)
        .ok_or_else(|| XlogError::Kernel("SAT clause capacity overflow".to_string()))?;
    let vars_plus_one =
        checked_solver_len_add_one("SAT workspace variable arena", num_vars_cap)?;
    let learned_offsets_len =
        checked_solver_len_add_one("SAT workspace learned offsets", max_learned_clauses)?;
    let watch_head_len = checked_solver_len_double("SAT workspace watch head", num_vars_cap)?;
    let watch_clause_len =
        checked_solver_len_double("SAT workspace watch clauses", max_total_clauses)?;
    let memory = self.provider.memory();
    let mut ws = GpuCdclWorkspace {
        var_cap: num_vars_cap,
        clause_total_cap: max_total_clauses,
        assign: memory.alloc::<i8>(vars_plus_one)?,
        level: memory.alloc::<u32>(vars_plus_one)?,
        reason: memory.alloc::<i32>(vars_plus_one)?,
        var_activity: memory.alloc::<u32>(vars_plus_one)?,
        var_phase: memory.alloc::<i8>(vars_plus_one)?,
        decision_heap: memory.alloc::<u32>(vars_plus_one)?,
        decision_heap_pos: memory.alloc::<u32>(vars_plus_one)?,
        trail: memory.alloc::<i32>(vars_plus_one)?,
        trail_lim: memory.alloc::<u32>(vars_plus_one)?,
        seen: memory.alloc::<u8>(vars_plus_one)?,
        learnt_tmp: memory.alloc::<i32>(vars_plus_one)?,
        proof_vars_tmp: memory.alloc::<u32>(vars_plus_one)?,
        proof_reason_tmp: memory.alloc::<u32>(vars_plus_one)?,
        watch0_pos: memory.alloc::<u32>(max_total_clauses)?,
        watch1_pos: memory.alloc::<u32>(max_total_clauses)?,
        watch_head: memory.alloc::<i32>(watch_head_len)?,
        watch_next: memory.alloc::<i32>(watch_clause_len)?,
        watch_prev: memory.alloc::<i32>(watch_clause_len)?,
        learned_offsets: memory.alloc::<u32>(learned_offsets_len)?,
        learned_lits: memory.alloc::<i32>(max_learned_lits)?,
        learned_deleted: memory.alloc::<u8>(max_learned_clauses)?,
        learned_lbd: memory.alloc::<u32>(max_learned_clauses)?,
        learned_activity: memory.alloc::<u32>(max_learned_clauses)?,
        learned_locked: memory.alloc::<u8>(max_learned_clauses)?,
        proof_offsets: memory.alloc::<u32>(learned_offsets_len)?,
        proof_data: memory.alloc::<u32>(max_proof_u32)?,
        out_status: memory.alloc::<i32>(1)?,
        out_error: memory.alloc::<i32>(1)?,
        out_learned_count: memory.alloc::<u32>(1)?,
    };
    // The per-run launch path zeroes its freshly allocated count before use.
    // Do the same here, so `import_existing_learned = true` on a brand-new
    // workspace imports nothing instead of a stale count.
    self.provider
        .htod_launch_metadata_sync_copy_into(&[0u32], &mut ws.out_learned_count)
        .map_err(|e| XlogError::Kernel(format!("Failed to init learned import count: {}", e)))?;
    Ok(ws)
}
/// Allocates a one-element device buffer holding `value`.
///
/// Used for gate (`compile_needed`) and zero scalars passed to the kernels.
///
/// # Errors
///
/// Propagates allocation errors; an upload failure becomes `XlogError::Kernel`.
fn alloc_u32_scalar(&self, value: u32) -> Result<TrackedCudaSlice<u32>> {
    let mut scalar = self.provider.memory().alloc::<u32>(1)?;
    let host = [value];
    self.provider
        .htod_launch_metadata_sync_copy_into(&host, &mut scalar)
        .map_err(|e| XlogError::Kernel(format!("GpuCdclSolver gate upload failed: {}", e)))?;
    Ok(scalar)
}
/// Validates the inputs, allocates a fresh set of per-run device buffers, and
/// launches `sat_cdcl_solve` on a single block of 256 threads.
///
/// This function does not synchronize the device. Callers synchronize, or
/// perform a blocking readback, before reading the returned buffers.
///
/// `compile_needed`, `decision_base_limit`, `decision_extra_base` and
/// `decision_extra_count` must each be single-element buffers owned by this
/// solver's memory manager. They are forwarded to the kernel as-is.
///
/// # Errors
///
/// Returns an error on any of:
/// - a provider-boundary mismatch or a wrong scalar length;
/// - `cnf.var_cap == 0` or an invalid config;
/// - size overflow or allocation failure;
/// - failure to zero the learned-import count;
/// - a missing kernel or a failed launch.
fn launch_cdcl_with_decision_ranges_gated(
    &self,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
    decision_base_limit: &TrackedCudaSlice<u32>,
    decision_extra_base: &TrackedCudaSlice<u32>,
    decision_extra_count: &TrackedCudaSlice<u32>,
) -> Result<GpuCdclRun> {
    let num_vars_cap = cnf.var_cap as usize;
    let num_clauses_cap = cnf.clause_cap as usize;
    // Every device input must come from this solver's memory manager.
    cnf.require_provider_memory(&self.provider, "GPU CDCL solver provider boundary")?;
    self.require_slice_on_provider("compile_needed", compile_needed)?;
    self.require_slice_on_provider("decision_base_limit", decision_base_limit)?;
    self.require_slice_on_provider("decision_extra_base", decision_extra_base)?;
    self.require_slice_on_provider("decision_extra_count", decision_extra_count)?;
    // Scalar arguments must be exactly one element long.
    if compile_needed.len() != 1 {
        return Err(XlogError::Compilation(format!(
            "GpuCdclSolver requires compile_needed len=1, got {}",
            compile_needed.len()
        )));
    }
    if cnf.var_cap == 0 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires num_vars > 0".to_string(),
        ));
    }
    if decision_base_limit.len() != 1 {
        return Err(XlogError::Compilation(format!(
            "GpuCdclSolver requires decision_base_limit len=1, got {}",
            decision_base_limit.len()
        )));
    }
    if decision_extra_base.len() != 1 {
        return Err(XlogError::Compilation(format!(
            "GpuCdclSolver requires decision_extra_base len=1, got {}",
            decision_extra_base.len()
        )));
    }
    if decision_extra_count.len() != 1 {
        return Err(XlogError::Compilation(format!(
            "GpuCdclSolver requires decision_extra_count len=1, got {}",
            decision_extra_count.len()
        )));
    }
    // Config limits size the buffers below, so they must be usable.
    if self.config.max_learned_clauses == 0 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires max_learned_clauses > 0".to_string(),
        ));
    }
    if self.config.max_learned_lits == 0 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires max_learned_lits > 0".to_string(),
        ));
    }
    if self.config.max_proof_u32 < 2 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires max_proof_u32 >= 2".to_string(),
        ));
    }
    let max_learned_clauses = self.config.max_learned_clauses as usize;
    let max_learned_lits = self.config.max_learned_lits as usize;
    let max_proof_u32 = self.config.max_proof_u32 as usize;
    // Clause-indexed buffers cover original clauses plus learned clauses.
    let max_total_clauses = num_clauses_cap
        .checked_add(max_learned_clauses)
        .ok_or_else(|| XlogError::Kernel("SAT clause capacity overflow".to_string()))?;
    let vars_plus_one = checked_solver_len_add_one("SAT variable arena", num_vars_cap)?;
    let learned_offsets_len =
        checked_solver_len_add_one("SAT learned offsets", max_learned_clauses)?;
    let watch_head_len = checked_solver_len_double("SAT watch head", num_vars_cap)?;
    let watch_clause_len = checked_solver_len_double("SAT watch clauses", max_total_clauses)?;
    let memory = self.provider.memory();
    // Per-variable buffers (`var_cap + 1` entries).
    let assign = memory.alloc::<i8>(vars_plus_one)?;
    let level = memory.alloc::<u32>(vars_plus_one)?;
    let reason = memory.alloc::<i32>(vars_plus_one)?;
    let var_activity = memory.alloc::<u32>(vars_plus_one)?;
    let var_phase = memory.alloc::<i8>(vars_plus_one)?;
    let decision_heap = memory.alloc::<u32>(vars_plus_one)?;
    let decision_heap_pos = memory.alloc::<u32>(vars_plus_one)?;
    let trail = memory.alloc::<i32>(vars_plus_one)?;
    let trail_lim = memory.alloc::<u32>(vars_plus_one)?;
    let seen = memory.alloc::<u8>(vars_plus_one)?;
    let learnt_tmp = memory.alloc::<i32>(vars_plus_one)?;
    let proof_vars_tmp = memory.alloc::<u32>(vars_plus_one)?;
    let proof_reason_tmp = memory.alloc::<u32>(vars_plus_one)?;
    // Watch structures.
    let watch0_pos = memory.alloc::<u32>(max_total_clauses)?;
    let watch1_pos = memory.alloc::<u32>(max_total_clauses)?;
    let watch_head = memory.alloc::<i32>(watch_head_len)?;
    let watch_next = memory.alloc::<i32>(watch_clause_len)?;
    let watch_prev = memory.alloc::<i32>(watch_clause_len)?;
    // Learned-clause database and proof log.
    let learned_offsets = memory.alloc::<u32>(learned_offsets_len)?;
    let learned_lits = memory.alloc::<i32>(max_learned_lits)?;
    let learned_deleted = memory.alloc::<u8>(max_learned_clauses)?;
    let learned_lbd = memory.alloc::<u32>(max_learned_clauses)?;
    let learned_activity = memory.alloc::<u32>(max_learned_clauses)?;
    let learned_locked = memory.alloc::<u8>(max_learned_clauses)?;
    let proof_offsets = memory.alloc::<u32>(learned_offsets_len)?;
    let proof_data = memory.alloc::<u32>(max_proof_u32)?;
    // Result scalars.
    let out_status = memory.alloc::<i32>(1)?;
    let out_error = memory.alloc::<i32>(1)?;
    let mut out_learned_count = memory.alloc::<u32>(1)?;
    // This buffer is also the learned-import count passed to the kernel; a
    // fresh run imports nothing, so start it at zero.
    self.provider
        .htod_launch_metadata_sync_copy_into(&[0u32], &mut out_learned_count)
        .map_err(|e| {
            XlogError::Kernel(format!("Failed to init learned import count: {}", e))
        })?;
    let sat_fn = self
        .provider
        .device()
        .inner()
        .get_func(SAT_MODULE, sat_kernels::SAT_CDCL_SOLVE)
        .ok_or_else(|| XlogError::Kernel("sat_cdcl_solve kernel not found".to_string()))?;
    // By-value kernel arguments are bound to locals so the parameter pointers
    // stay valid until the launch below.
    let cnf_var_cap = cnf.var_cap;
    let cnf_clause_cap = cnf.clause_cap;
    let cfg_max_learned_clauses = self.config.max_learned_clauses;
    let cfg_max_learned_lits = self.config.max_learned_lits;
    let cfg_max_proof_u32 = self.config.max_proof_u32;
    let cfg_restart_base = self.config.restart_base;
    let cfg_reduce_interval = self.config.reduce_interval;
    let learned_import_count_param = (&out_learned_count).as_kernel_param();
    // The order of this list is the kernel's argument order; it must match
    // the `sat_cdcl_solve` signature exactly.
    let mut params: Vec<*mut c_void> = vec![
        compile_needed.as_kernel_param(),
        (&cnf.clause_offsets).as_kernel_param(),
        (&cnf.literals).as_kernel_param(),
        (&cnf.num_vars).as_kernel_param(),
        (&cnf.num_clauses).as_kernel_param(),
        decision_base_limit.as_kernel_param(),
        decision_extra_base.as_kernel_param(),
        decision_extra_count.as_kernel_param(),
        cnf_var_cap.as_kernel_param(),
        cnf_clause_cap.as_kernel_param(),
        cfg_max_learned_clauses.as_kernel_param(),
        cfg_max_learned_lits.as_kernel_param(),
        cfg_max_proof_u32.as_kernel_param(),
        cfg_restart_base.as_kernel_param(),
        cfg_reduce_interval.as_kernel_param(),
        learned_import_count_param,
        (&assign).as_kernel_param(),
        (&level).as_kernel_param(),
        (&reason).as_kernel_param(),
        (&var_activity).as_kernel_param(),
        (&var_phase).as_kernel_param(),
        (&decision_heap).as_kernel_param(),
        (&decision_heap_pos).as_kernel_param(),
        (&trail).as_kernel_param(),
        (&trail_lim).as_kernel_param(),
        (&seen).as_kernel_param(),
        (&learnt_tmp).as_kernel_param(),
        (&proof_vars_tmp).as_kernel_param(),
        (&proof_reason_tmp).as_kernel_param(),
        (&watch0_pos).as_kernel_param(),
        (&watch1_pos).as_kernel_param(),
        (&watch_head).as_kernel_param(),
        (&watch_next).as_kernel_param(),
        (&watch_prev).as_kernel_param(),
        (&learned_offsets).as_kernel_param(),
        (&learned_lits).as_kernel_param(),
        (&learned_deleted).as_kernel_param(),
        (&learned_lbd).as_kernel_param(),
        (&learned_activity).as_kernel_param(),
        (&learned_locked).as_kernel_param(),
        (&proof_offsets).as_kernel_param(),
        (&proof_data).as_kernel_param(),
        (&out_status).as_kernel_param(),
        (&out_error).as_kernel_param(),
        (&out_learned_count).as_kernel_param(),
    ];
    // SAFETY (review): `params` must list exactly the kernel's arguments in
    // order, and every pointee must outlive the launch call; all of them are
    // locals or borrowed arguments still in scope here.
    unsafe {
        sat_fn.clone().launch(
            LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (256, 1, 1),
                shared_mem_bytes: 0,
            },
            &mut params,
        )
    }
    .map_err(|e| XlogError::Kernel(format!("Failed to launch SAT solver kernel: {}", e)))?;
    // NOTE(review): buffers not moved into `GpuCdclRun` (level, reason, trail,
    // watch lists, learned_deleted, ...) are dropped when this function
    // returns, before any caller synchronizes. That is only safe if dropping a
    // `TrackedCudaSlice` is stream-ordered or waits for the kernel — confirm
    // against xlog_cuda's memory manager.
    Ok(GpuCdclRun {
        assignment: assign,
        decision_heap,
        decision_heap_pos,
        learned_offsets,
        learned_lits,
        proof_offsets,
        proof_data,
        out_status,
        out_error,
        out_learned_count,
    })
}
/// Validates the inputs and launches `sat_cdcl_solve` using the buffers of a
/// reusable workspace instead of fresh allocations.
///
/// When `import_existing_learned` is `false`, `ws.out_learned_count` is zeroed
/// first. When it is `true`, the count is left untouched, so the kernel sees
/// whatever value the previous run on this workspace left there. This is
/// presumably how learned clauses are carried over between solves — confirm
/// against the kernel.
///
/// This function does not synchronize the device. Results in `ws` are valid
/// only after the caller synchronizes.
///
/// # Errors
///
/// Returns an error on any of:
/// - a provider-boundary mismatch or a wrong scalar length;
/// - insufficient workspace capacity, `cnf.var_cap == 0`, or an invalid config;
/// - failure to zero the import count;
/// - a missing kernel or a failed launch.
// Seven inputs besides `self`, all forwarded to the kernel.
#[allow(clippy::too_many_arguments)]
fn launch_cdcl_with_workspace_gated(
    &self,
    ws: &mut GpuCdclWorkspace,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
    decision_base_limit: &TrackedCudaSlice<u32>,
    decision_extra_base: &TrackedCudaSlice<u32>,
    decision_extra_count: &TrackedCudaSlice<u32>,
    import_existing_learned: bool,
) -> Result<()> {
    // Every device input, including all workspace buffers, must come from this
    // solver's memory manager.
    cnf.require_provider_memory(&self.provider, "GPU CDCL solver provider boundary")?;
    self.require_workspace_on_provider(ws)?;
    self.require_slice_on_provider("compile_needed", compile_needed)?;
    self.require_slice_on_provider("decision_base_limit", decision_base_limit)?;
    self.require_slice_on_provider("decision_extra_base", decision_extra_base)?;
    self.require_slice_on_provider("decision_extra_count", decision_extra_count)?;
    if compile_needed.len() != 1 {
        return Err(XlogError::Compilation(format!(
            "GpuCdclSolver requires compile_needed len=1, got {}",
            compile_needed.len()
        )));
    }
    // The workspace must be big enough for this CNF and this solver's config.
    self.require_workspace_capacity_for_cnf(ws, cnf.var_cap, cnf.clause_cap)?;
    if cnf.var_cap == 0 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires num_vars > 0".to_string(),
        ));
    }
    if decision_base_limit.len() != 1 {
        return Err(XlogError::Compilation(format!(
            "GpuCdclSolver requires decision_base_limit len=1, got {}",
            decision_base_limit.len()
        )));
    }
    if decision_extra_base.len() != 1 {
        return Err(XlogError::Compilation(format!(
            "GpuCdclSolver requires decision_extra_base len=1, got {}",
            decision_extra_base.len()
        )));
    }
    if decision_extra_count.len() != 1 {
        return Err(XlogError::Compilation(format!(
            "GpuCdclSolver requires decision_extra_count len=1, got {}",
            decision_extra_count.len()
        )));
    }
    if self.config.max_learned_clauses == 0 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires max_learned_clauses > 0".to_string(),
        ));
    }
    if self.config.max_learned_lits == 0 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires max_learned_lits > 0".to_string(),
        ));
    }
    if self.config.max_proof_u32 < 2 {
        return Err(XlogError::Compilation(
            "GpuCdclSolver requires max_proof_u32 >= 2".to_string(),
        ));
    }
    ws.reset_for_solve();
    // Without an import, the kernel must see a learned-import count of zero.
    if !import_existing_learned {
        self.provider
            .htod_launch_metadata_sync_copy_into(&[0u32], &mut ws.out_learned_count)
            .map_err(|e| {
                XlogError::Kernel(format!("Failed to init learned import count: {}", e))
            })?;
    }
    let sat_fn = self
        .provider
        .device()
        .inner()
        .get_func(SAT_MODULE, sat_kernels::SAT_CDCL_SOLVE)
        .ok_or_else(|| XlogError::Kernel("sat_cdcl_solve kernel not found".to_string()))?;
    // By-value kernel arguments are bound to locals so the parameter pointers
    // stay valid until the launch below.
    let cnf_var_cap = cnf.var_cap;
    let cnf_clause_cap = cnf.clause_cap;
    let cfg_max_learned_clauses = self.config.max_learned_clauses;
    let cfg_max_learned_lits = self.config.max_learned_lits;
    let cfg_max_proof_u32 = self.config.max_proof_u32;
    let cfg_restart_base = self.config.restart_base;
    let cfg_reduce_interval = self.config.reduce_interval;
    let learned_import_count_param = (&ws.out_learned_count).as_kernel_param();
    // Same argument order as `launch_cdcl_with_decision_ranges_gated`; it must
    // match the `sat_cdcl_solve` signature exactly.
    let mut params: Vec<*mut c_void> = vec![
        compile_needed.as_kernel_param(),
        (&cnf.clause_offsets).as_kernel_param(),
        (&cnf.literals).as_kernel_param(),
        (&cnf.num_vars).as_kernel_param(),
        (&cnf.num_clauses).as_kernel_param(),
        decision_base_limit.as_kernel_param(),
        decision_extra_base.as_kernel_param(),
        decision_extra_count.as_kernel_param(),
        cnf_var_cap.as_kernel_param(),
        cnf_clause_cap.as_kernel_param(),
        cfg_max_learned_clauses.as_kernel_param(),
        cfg_max_learned_lits.as_kernel_param(),
        cfg_max_proof_u32.as_kernel_param(),
        cfg_restart_base.as_kernel_param(),
        cfg_reduce_interval.as_kernel_param(),
        learned_import_count_param,
        (&ws.assign).as_kernel_param(),
        (&ws.level).as_kernel_param(),
        (&ws.reason).as_kernel_param(),
        (&ws.var_activity).as_kernel_param(),
        (&ws.var_phase).as_kernel_param(),
        (&ws.decision_heap).as_kernel_param(),
        (&ws.decision_heap_pos).as_kernel_param(),
        (&ws.trail).as_kernel_param(),
        (&ws.trail_lim).as_kernel_param(),
        (&ws.seen).as_kernel_param(),
        (&ws.learnt_tmp).as_kernel_param(),
        (&ws.proof_vars_tmp).as_kernel_param(),
        (&ws.proof_reason_tmp).as_kernel_param(),
        (&ws.watch0_pos).as_kernel_param(),
        (&ws.watch1_pos).as_kernel_param(),
        (&ws.watch_head).as_kernel_param(),
        (&ws.watch_next).as_kernel_param(),
        (&ws.watch_prev).as_kernel_param(),
        (&ws.learned_offsets).as_kernel_param(),
        (&ws.learned_lits).as_kernel_param(),
        (&ws.learned_deleted).as_kernel_param(),
        (&ws.learned_lbd).as_kernel_param(),
        (&ws.learned_activity).as_kernel_param(),
        (&ws.learned_locked).as_kernel_param(),
        (&ws.proof_offsets).as_kernel_param(),
        (&ws.proof_data).as_kernel_param(),
        (&ws.out_status).as_kernel_param(),
        (&ws.out_error).as_kernel_param(),
        (&ws.out_learned_count).as_kernel_param(),
    ];
    // SAFETY (review): `params` must list exactly the kernel's arguments in
    // order, and every pointee must outlive the launch call; they are locals,
    // borrowed arguments, or buffers owned by `ws`.
    unsafe {
        sat_fn.clone().launch(
            LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (256, 1, 1),
                shared_mem_bytes: 0,
            },
            &mut params,
        )
    }
    .map_err(|e| XlogError::Kernel(format!("Failed to launch SAT solver kernel: {}", e)))?;
    Ok(())
}
/// Ungated form of [`Self::solve_raw_with_branch_limit_gated`]: uploads a gate
/// scalar of `1` and forwards the remaining arguments unchanged.
pub fn solve_raw_with_branch_limit(
    &self,
    cnf: &GpuCnf,
    branch_var_limit: &TrackedCudaSlice<u32>,
) -> Result<GpuCdclRawOutput> {
    let gate_on = self.alloc_u32_scalar(1)?;
    let output = self.solve_raw_with_branch_limit_gated(cnf, &gate_on, branch_var_limit)?;
    Ok(output)
}
/// Runs the solver with `branch_var_limit` as the decision base limit and no
/// extra decision range. Returns the raw buffers after synchronizing.
///
/// This is [`Self::solve_raw_with_decision_ranges_gated`] with a single zero
/// scalar supplied as both the extra-range base and the extra-range count.
pub fn solve_raw_with_branch_limit_gated(
    &self,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
    branch_var_limit: &TrackedCudaSlice<u32>,
) -> Result<GpuCdclRawOutput> {
    let no_extra = self.alloc_u32_scalar(0)?;
    self.solve_raw_with_decision_ranges_gated(
        cnf,
        compile_needed,
        branch_var_limit,
        &no_extra,
        &no_extra,
    )
}
/// Ungated form of [`Self::solve_raw_with_decision_ranges_gated`]: uploads a
/// gate scalar of `1` and forwards the decision-range scalars unchanged.
pub fn solve_raw_with_decision_ranges(
    &self,
    cnf: &GpuCnf,
    decision_base_limit: &TrackedCudaSlice<u32>,
    decision_extra_base: &TrackedCudaSlice<u32>,
    decision_extra_count: &TrackedCudaSlice<u32>,
) -> Result<GpuCdclRawOutput> {
    let gate_on = self.alloc_u32_scalar(1)?;
    let output = self.solve_raw_with_decision_ranges_gated(
        cnf,
        &gate_on,
        decision_base_limit,
        decision_extra_base,
        decision_extra_count,
    )?;
    Ok(output)
}
/// Launches the solver, waits for the device, and returns the assignment and
/// result scalars without checking the status.
///
/// The other per-run buffers are released only after the synchronization.
pub fn solve_raw_with_decision_ranges_gated(
    &self,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
    decision_base_limit: &TrackedCudaSlice<u32>,
    decision_extra_base: &TrackedCudaSlice<u32>,
    decision_extra_count: &TrackedCudaSlice<u32>,
) -> Result<GpuCdclRawOutput> {
    let run = self.launch_cdcl_with_decision_ranges_gated(
        cnf,
        compile_needed,
        decision_base_limit,
        decision_extra_base,
        decision_extra_count,
    )?;
    // The launch is not synchronized; wait before handing buffers back.
    self.provider.device().synchronize()?;
    Ok(GpuCdclRawOutput {
        assignment: run.assignment,
        out_status: run.out_status,
        out_error: run.out_error,
        out_learned_count: run.out_learned_count,
    })
}
/// Ungated SAT expectation over all variables; see
/// [`Self::solve_expect_sat_gated`].
pub fn solve_expect_sat(&self, cnf: &GpuCnf) -> Result<TrackedCudaSlice<i8>> {
    let gate_on = self.alloc_u32_scalar(1)?;
    let assignment = self.solve_expect_sat_gated(cnf, &gate_on)?;
    Ok(assignment)
}
/// SAT expectation using `cnf.num_vars` as the branch (decision base) limit.
pub fn solve_expect_sat_gated(
    &self,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
) -> Result<TrackedCudaSlice<i8>> {
    let all_vars = &cnf.num_vars;
    self.solve_expect_sat_with_branch_limit_gated(cnf, compile_needed, all_vars)
}
/// Ungated form of [`Self::solve_expect_sat_with_decision_ranges_gated`]:
/// uploads a gate scalar of `1` and forwards the decision-range scalars.
pub fn solve_expect_sat_with_decision_ranges(
    &self,
    cnf: &GpuCnf,
    decision_base_limit: &TrackedCudaSlice<u32>,
    decision_extra_base: &TrackedCudaSlice<u32>,
    decision_extra_count: &TrackedCudaSlice<u32>,
) -> Result<TrackedCudaSlice<i8>> {
    let gate_on = self.alloc_u32_scalar(1)?;
    let assignment = self.solve_expect_sat_with_decision_ranges_gated(
        cnf,
        &gate_on,
        decision_base_limit,
        decision_extra_base,
        decision_extra_count,
    )?;
    Ok(assignment)
}
/// Runs the solver, requires a SAT result, and returns the device-resident
/// assignment after it has been checked against the CNF.
///
/// Steps:
/// 1. Launch `sat_cdcl_solve` via `launch_cdcl_with_decision_ranges_gated`.
/// 2. Read `out_status` / `out_error` on the host, and fail unless the status
///    is `SAT_STATUS_SAT` and the error is `0`.
/// 3. Launch `sat_assert_status` with the same expectation, then synchronize.
/// 4. Initialize a model-check flag to `1`, launch `sat_check_model` over the
///    CNF's clauses and `sat_assert_ok` on the flag, then synchronize.
///
/// `compile_needed` is forwarded to every kernel. In debug builds, setting
/// `XLOG_CDCL_TRACE` prints elapsed times to stderr.
///
/// # Errors
///
/// Returns an error if any of the following fails:
/// - launch validation or allocation;
/// - the status/error readback, or the values do not indicate SAT;
/// - kernel lookup or launch;
/// - flag initialization;
/// - device synchronization.
pub fn solve_expect_sat_with_decision_ranges_gated(
    &self,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
    decision_base_limit: &TrackedCudaSlice<u32>,
    decision_extra_base: &TrackedCudaSlice<u32>,
    decision_extra_count: &TrackedCudaSlice<u32>,
) -> Result<TrackedCudaSlice<i8>> {
    #[cfg(debug_assertions)]
    let trace = std::env::var_os("XLOG_CDCL_TRACE").is_some();
    #[cfg(debug_assertions)]
    let t0 = std::time::Instant::now();
    let run = self.launch_cdcl_with_decision_ranges_gated(
        cnf,
        compile_needed,
        decision_base_limit,
        decision_extra_base,
        decision_extra_count,
    )?;
    // Host-side check of the solver verdict.
    self.require_expected_status(
        &run.out_status,
        &run.out_error,
        SAT_STATUS_SAT,
        "GPU CDCL SAT expectation",
    )?;
    let device = self.provider.device().inner();
    let memory = self.provider.memory();
    // Device-side assertion of the same expectation.
    let assert_status_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_STATUS)
        .ok_or_else(|| XlogError::Kernel("sat_assert_status kernel not found".to_string()))?;
    unsafe {
        assert_status_fn
            .clone()
            .launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (1, 1, 1),
                    shared_mem_bytes: 0,
                },
                (
                    compile_needed,
                    &run.out_status,
                    &run.out_error,
                    SAT_STATUS_SAT,
                ),
            )
            .map_err(|e| {
                XlogError::Kernel(format!("Failed to launch sat_assert_status: {}", e))
            })?;
    }
    self.provider.device().synchronize()?;
    #[cfg(debug_assertions)]
    if trace {
        eprintln!("[xlog-solve] cdcl(sat) time: {:?}", t0.elapsed());
    }
    let mut out_ok = memory.alloc::<i32>(1)?;
    // Start the model-check flag at "ok" (1), as the UNSAT proof check does.
    // The result then never depends on the contents of a fresh allocation,
    // which nothing here initializes.
    self.provider
        .htod_launch_metadata_sync_copy_into(&[1i32], &mut out_ok)
        .map_err(|e| XlogError::Kernel(format!("Failed to init model check out_ok: {}", e)))?;
    let check_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_CHECK_MODEL)
        .ok_or_else(|| XlogError::Kernel("sat_check_model kernel not found".to_string()))?;
    unsafe {
        check_fn
            .clone()
            .launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (256, 1, 1),
                    shared_mem_bytes: 0,
                },
                (
                    compile_needed,
                    &cnf.clause_offsets,
                    &cnf.literals,
                    &cnf.num_clauses,
                    &run.assignment,
                    &mut out_ok,
                ),
            )
            .map_err(|e| {
                XlogError::Kernel(format!("Failed to launch SAT model check: {}", e))
            })?;
    }
    let assert_ok_fn = device
        .get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_OK)
        .ok_or_else(|| XlogError::Kernel("sat_assert_ok kernel not found".to_string()))?;
    unsafe {
        assert_ok_fn
            .clone()
            .launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (1, 1, 1),
                    shared_mem_bytes: 0,
                },
                (compile_needed, &out_ok),
            )
            .map_err(|e| XlogError::Kernel(format!("Failed to launch sat_assert_ok: {}", e)))?;
    }
    self.provider.device().synchronize()?;
    #[cfg(debug_assertions)]
    if trace {
        eprintln!(
            "[xlog-solve] cdcl(sat)+model_check time: {:?}",
            t0.elapsed()
        );
    }
    Ok(run.assignment)
}
/// Ungated form of [`Self::solve_expect_sat_with_branch_limit_gated`].
pub fn solve_expect_sat_with_branch_limit(
    &self,
    cnf: &GpuCnf,
    branch_var_limit: &TrackedCudaSlice<u32>,
) -> Result<TrackedCudaSlice<i8>> {
    let gate_on = self.alloc_u32_scalar(1)?;
    let assignment =
        self.solve_expect_sat_with_branch_limit_gated(cnf, &gate_on, branch_var_limit)?;
    Ok(assignment)
}
/// SAT expectation with `branch_var_limit` as the decision base limit and no
/// extra decision range (one zero scalar used for both base and count).
pub fn solve_expect_sat_with_branch_limit_gated(
    &self,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
    branch_var_limit: &TrackedCudaSlice<u32>,
) -> Result<TrackedCudaSlice<i8>> {
    let no_extra = self.alloc_u32_scalar(0)?;
    let (extra_base, extra_count) = (&no_extra, &no_extra);
    self.solve_expect_sat_with_decision_ranges_gated(
        cnf,
        compile_needed,
        branch_var_limit,
        extra_base,
        extra_count,
    )
}
/// Ungated UNSAT expectation over all variables; see
/// [`Self::solve_expect_unsat_gated`].
pub fn solve_expect_unsat(&self, cnf: &GpuCnf) -> Result<()> {
    let gate_on = self.alloc_u32_scalar(1)?;
    self.solve_expect_unsat_gated(cnf, &gate_on)?;
    Ok(())
}
/// UNSAT expectation using `cnf.num_vars` as the branch (decision base) limit.
pub fn solve_expect_unsat_gated(
    &self,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
) -> Result<()> {
    let all_vars = &cnf.num_vars;
    self.solve_expect_unsat_with_branch_limit_gated(cnf, compile_needed, all_vars)
}
/// Ungated form of [`Self::solve_expect_unsat_with_branch_limit_gated`].
pub fn solve_expect_unsat_with_branch_limit(
    &self,
    cnf: &GpuCnf,
    branch_var_limit: &TrackedCudaSlice<u32>,
) -> Result<()> {
    let gate_on = self.alloc_u32_scalar(1)?;
    self.solve_expect_unsat_with_branch_limit_gated(cnf, &gate_on, branch_var_limit)?;
    Ok(())
}
/// Ungated form of [`Self::solve_expect_unsat_with_decision_ranges_gated`]:
/// uploads a gate scalar of `1` and forwards the decision-range scalars.
pub fn solve_expect_unsat_with_decision_ranges(
    &self,
    cnf: &GpuCnf,
    decision_base_limit: &TrackedCudaSlice<u32>,
    decision_extra_base: &TrackedCudaSlice<u32>,
    decision_extra_count: &TrackedCudaSlice<u32>,
) -> Result<()> {
    let gate_on = self.alloc_u32_scalar(1)?;
    self.solve_expect_unsat_with_decision_ranges_gated(
        cnf,
        &gate_on,
        decision_base_limit,
        decision_extra_base,
        decision_extra_count,
    )?;
    Ok(())
}
/// UNSAT expectation with `branch_var_limit` as the decision base limit and
/// no extra decision range (one zero scalar used for both base and count).
pub fn solve_expect_unsat_with_branch_limit_gated(
    &self,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
    branch_var_limit: &TrackedCudaSlice<u32>,
) -> Result<()> {
    let no_extra = self.alloc_u32_scalar(0)?;
    let (extra_base, extra_count) = (&no_extra, &no_extra);
    self.solve_expect_unsat_with_decision_ranges_gated(
        cnf,
        compile_needed,
        branch_var_limit,
        extra_base,
        extra_count,
    )
}
pub fn solve_expect_unsat_with_decision_ranges_gated(
&self,
cnf: &GpuCnf,
compile_needed: &TrackedCudaSlice<u32>,
decision_base_limit: &TrackedCudaSlice<u32>,
decision_extra_base: &TrackedCudaSlice<u32>,
decision_extra_count: &TrackedCudaSlice<u32>,
) -> Result<()> {
#[cfg(debug_assertions)]
let trace = std::env::var_os("XLOG_CDCL_TRACE").is_some();
#[cfg(debug_assertions)]
let t0 = std::time::Instant::now();
let run = self.launch_cdcl_with_decision_ranges_gated(
cnf,
compile_needed,
decision_base_limit,
decision_extra_base,
decision_extra_count,
)?;
self.require_expected_status(
&run.out_status,
&run.out_error,
SAT_STATUS_UNSAT,
"GPU CDCL UNSAT expectation",
)?;
let device = self.provider.device().inner();
let memory = self.provider.memory();
let assert_status_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_STATUS)
.ok_or_else(|| XlogError::Kernel("sat_assert_status kernel not found".to_string()))?;
unsafe {
assert_status_fn
.clone()
.launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
(
compile_needed,
&run.out_status,
&run.out_error,
SAT_STATUS_UNSAT,
),
)
.map_err(|e| {
XlogError::Kernel(format!("Failed to launch sat_assert_status: {}", e))
})?;
}
self.provider.device().synchronize()?;
#[cfg(debug_assertions)]
if trace {
eprintln!("[xlog-solve] cdcl(unsat) time: {:?}", t0.elapsed());
}
let mut out_ok = memory.alloc::<i32>(1)?;
self.provider
.htod_launch_metadata_sync_copy_into(&[1i32], &mut out_ok)
.map_err(|e| XlogError::Kernel(format!("Failed to init proof out_ok: {}", e)))?;
let scratch_cap_u32 = cnf
.var_cap
.checked_add(1)
.ok_or_else(|| XlogError::Kernel("Proof scratch capacity overflow".to_string()))?;
let scratch_cap = scratch_cap_u32 as usize;
let mut proof_blocks: usize = 1;
let mut scratch_a = None;
let mut scratch_b = None;
let mut scratch_map = None;
let mut last_alloc_err: Option<XlogError> = None;
for blocks in [512usize, 256, 128, 64, 32, 16, 8, 4, 2, 1] {
let len = match scratch_cap.checked_mul(blocks) {
Some(v) => v,
None => {
last_alloc_err = Some(XlogError::Kernel(
"Proof scratch allocation length overflow".to_string(),
));
continue;
}
};
let a = match memory.alloc::<i32>(len) {
Ok(buf) => buf,
Err(e) => {
last_alloc_err = Some(e);
continue;
}
};
let b = match memory.alloc::<i32>(len) {
Ok(buf) => buf,
Err(e) => {
last_alloc_err = Some(e);
drop(a);
continue;
}
};
let m = match memory.alloc::<u32>(len) {
Ok(buf) => buf,
Err(e) => {
last_alloc_err = Some(e);
drop(a);
drop(b);
continue;
}
};
proof_blocks = blocks;
scratch_a = Some(a);
scratch_b = Some(b);
scratch_map = Some(m);
break;
}
let scratch_a = scratch_a.ok_or_else(|| {
last_alloc_err.unwrap_or_else(|| {
XlogError::Kernel("Failed to allocate proof scratch buffers".to_string())
})
})?;
let scratch_b = scratch_b
.ok_or_else(|| XlogError::Kernel("Missing proof scratch buffer".to_string()))?;
let mut scratch_map = scratch_map
.ok_or_else(|| XlogError::Kernel("Missing proof scratch map buffer".to_string()))?;
device
.memset_zeros(&mut scratch_map)
.map_err(|e| XlogError::Kernel(format!("Failed to zero proof scratch map: {}", e)))?;
#[cfg(debug_assertions)]
if trace {
eprintln!("[xlog-solve] proof_check blocks: {}", proof_blocks);
}
#[cfg(debug_assertions)]
let t_mark = std::time::Instant::now();
let needed_cap_u32 = self.config.max_learned_clauses;
let needed_cap = needed_cap_u32 as usize;
let mut needed = memory.alloc::<u8>(needed_cap)?;
device
.memset_zeros(&mut needed)
.map_err(|e| XlogError::Kernel(format!("Failed to zero proof needed mask: {}", e)))?;
let mark_needed_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_PROOF_MARK_NEEDED)
.ok_or_else(|| {
XlogError::Kernel("sat_proof_mark_needed kernel not found".to_string())
})?;
let mut mark_params: Vec<*mut c_void> = vec![
compile_needed.as_kernel_param(),
(&cnf.num_clauses).as_kernel_param(),
(&run.out_learned_count).as_kernel_param(),
(&run.proof_offsets).as_kernel_param(),
(&run.proof_data).as_kernel_param(),
needed_cap_u32.as_kernel_param(),
(&needed).as_kernel_param(),
];
unsafe {
mark_needed_fn
.clone()
.launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
&mut mark_params,
)
.map_err(|e| {
XlogError::Kernel(format!("Failed to launch sat_proof_mark_needed: {}", e))
})?;
}
self.provider.device().synchronize()?;
#[cfg(debug_assertions)]
if trace {
eprintln!(
"[xlog-solve] proof_mark_needed time: {:?}",
t_mark.elapsed()
);
}
let proof_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_PROOF_CHECK)
.ok_or_else(|| XlogError::Kernel("sat_proof_check kernel not found".to_string()))?;
#[cfg(debug_assertions)]
let t_proof = std::time::Instant::now();
let proof_blocks_u32 = u32::try_from(proof_blocks)
.map_err(|_| XlogError::Kernel("Proof check grid dim exceeds u32::MAX".to_string()))?;
let mut proof_params: Vec<*mut c_void> = vec![
compile_needed.as_kernel_param(),
(&cnf.clause_offsets).as_kernel_param(),
(&cnf.literals).as_kernel_param(),
(&cnf.num_clauses).as_kernel_param(),
(&run.learned_offsets).as_kernel_param(),
(&run.learned_lits).as_kernel_param(),
(&run.out_learned_count).as_kernel_param(),
(&run.proof_offsets).as_kernel_param(),
(&run.proof_data).as_kernel_param(),
(&needed).as_kernel_param(),
needed_cap_u32.as_kernel_param(),
(&scratch_a).as_kernel_param(),
(&scratch_b).as_kernel_param(),
(&scratch_map).as_kernel_param(),
scratch_cap_u32.as_kernel_param(),
(&out_ok).as_kernel_param(),
];
unsafe {
proof_fn
.clone()
.launch(
LaunchConfig {
grid_dim: (proof_blocks_u32, 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
},
&mut proof_params,
)
.map_err(|e| {
XlogError::Kernel(format!("Failed to launch SAT proof check: {}", e))
})?;
}
let assert_ok_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_OK)
.ok_or_else(|| XlogError::Kernel("sat_assert_ok kernel not found".to_string()))?;
unsafe {
assert_ok_fn
.clone()
.launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
(compile_needed, &out_ok),
)
.map_err(|e| XlogError::Kernel(format!("Failed to launch sat_assert_ok: {}", e)))?;
}
self.provider.device().synchronize()?;
#[cfg(debug_assertions)]
if trace {
eprintln!("[xlog-solve] proof_check time: {:?}", t_proof.elapsed());
eprintln!(
"[xlog-solve] cdcl(unsat)+proof_check time: {:?}",
t0.elapsed()
);
}
Ok(())
}
/// Requires `cnf` to be UNSAT on the GPU using workspace `ws`, with `branch_var_limit`
/// as the decision limit.
///
/// Ungated convenience wrapper. It allocates a device scalar holding `1` for the
/// `compile_needed` argument and delegates to
/// [`Self::solve_expect_unsat_with_branch_limit_gated_ws`].
///
/// # Errors
///
/// Propagates errors from `alloc_u32_scalar` and from the delegated solve/proof check.
pub fn solve_expect_unsat_with_branch_limit_ws(
    &self,
    ws: &mut GpuCdclWorkspace,
    cnf: &GpuCnf,
    branch_var_limit: &TrackedCudaSlice<u32>,
) -> Result<()> {
    // A device-side `1` for `compile_needed`; presumably this enables every gated kernel.
    let always_needed = self.alloc_u32_scalar(1)?;
    self.solve_expect_unsat_with_branch_limit_gated_ws(ws, cnf, &always_needed, branch_var_limit)
}
/// Gated variant of the branch-limit UNSAT check.
///
/// `branch_var_limit` is forwarded as the decision base limit. The extra decision range
/// is described by one device scalar holding `0`, which is used for both its base and
/// its count (presumably an empty range). Delegates to
/// [`Self::solve_expect_unsat_with_decision_ranges_gated_ws`].
///
/// # Errors
///
/// Propagates errors from `alloc_u32_scalar` and from the delegated solve/proof check.
pub fn solve_expect_unsat_with_branch_limit_gated_ws(
    &self,
    ws: &mut GpuCdclWorkspace,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
    branch_var_limit: &TrackedCudaSlice<u32>,
) -> Result<()> {
    // One zero scalar serves as both the extra-range base and the extra-range count.
    let empty_range = self.alloc_u32_scalar(0)?;
    let (extra_base, extra_count) = (&empty_range, &empty_range);
    self.solve_expect_unsat_with_decision_ranges_gated_ws(
        ws,
        cnf,
        compile_needed,
        branch_var_limit,
        extra_base,
        extra_count,
    )
}
/// Requires `cnf` to be UNSAT on the GPU using workspace `ws`, with explicit decision
/// ranges (`decision_base_limit`, `decision_extra_base`, `decision_extra_count`).
///
/// Ungated convenience wrapper. It allocates a device scalar holding `1` for the
/// `compile_needed` argument and delegates to
/// [`Self::solve_expect_unsat_with_decision_ranges_gated_ws`].
///
/// # Errors
///
/// Propagates errors from `alloc_u32_scalar` and from the delegated solve/proof check.
pub fn solve_expect_unsat_with_decision_ranges_ws(
    &self,
    ws: &mut GpuCdclWorkspace,
    cnf: &GpuCnf,
    decision_base_limit: &TrackedCudaSlice<u32>,
    decision_extra_base: &TrackedCudaSlice<u32>,
    decision_extra_count: &TrackedCudaSlice<u32>,
) -> Result<()> {
    // Device-side `1` for the gate flag; presumably this enables every gated kernel.
    let compile_flag = self.alloc_u32_scalar(1)?;
    let compile_needed = &compile_flag;
    self.solve_expect_unsat_with_decision_ranges_gated_ws(
        ws,
        cnf,
        compile_needed,
        decision_base_limit,
        decision_extra_base,
        decision_extra_count,
    )
}
/// Gated UNSAT check with explicit decision ranges.
///
/// Delegates to the private inner implementation with `import_existing_learned` set to
/// `false`. Judging by the flag's name, this run does not import learned clauses already
/// held in `ws`.
///
/// # Errors
///
/// Propagates every error from the inner solve/proof-check implementation.
pub fn solve_expect_unsat_with_decision_ranges_gated_ws(
    &self,
    ws: &mut GpuCdclWorkspace,
    cnf: &GpuCnf,
    compile_needed: &TrackedCudaSlice<u32>,
    decision_base_limit: &TrackedCudaSlice<u32>,
    decision_extra_base: &TrackedCudaSlice<u32>,
    decision_extra_count: &TrackedCudaSlice<u32>,
) -> Result<()> {
    let import_existing_learned = false;
    self.solve_expect_unsat_with_decision_ranges_gated_ws_inner(
        ws,
        cnf,
        compile_needed,
        decision_base_limit,
        decision_extra_base,
        decision_extra_count,
        import_existing_learned,
    )
}
/// Branch-limit UNSAT check that asks the solver to import learned clauses already held
/// in `ws` (`import_existing_learned = true`).
///
/// Ungated: `compile_needed` is a device scalar holding `1`. `branch_var_limit` is the
/// decision base limit. The extra decision range uses one zero scalar for both its base
/// and its count.
///
/// # Errors
///
/// Propagates errors from `alloc_u32_scalar` and from the inner solve/proof check.
pub fn solve_expect_unsat_with_branch_limit_ws_importing_learned(
    &self,
    ws: &mut GpuCdclWorkspace,
    cnf: &GpuCnf,
    branch_var_limit: &TrackedCudaSlice<u32>,
) -> Result<()> {
    // Allocation order (gate flag first, then the zero scalar) matches the other wrappers.
    let compile_flag = self.alloc_u32_scalar(1)?;
    let no_extra = self.alloc_u32_scalar(0)?;
    let import_existing_learned = true;
    self.solve_expect_unsat_with_decision_ranges_gated_ws_inner(
        ws,
        cnf,
        &compile_flag,
        branch_var_limit,
        &no_extra,
        &no_extra,
        import_existing_learned,
    )
}
#[allow(clippy::too_many_arguments)]
fn solve_expect_unsat_with_decision_ranges_gated_ws_inner(
&self,
ws: &mut GpuCdclWorkspace,
cnf: &GpuCnf,
compile_needed: &TrackedCudaSlice<u32>,
decision_base_limit: &TrackedCudaSlice<u32>,
decision_extra_base: &TrackedCudaSlice<u32>,
decision_extra_count: &TrackedCudaSlice<u32>,
import_existing_learned: bool,
) -> Result<()> {
#[cfg(debug_assertions)]
let trace = std::env::var_os("XLOG_CDCL_TRACE").is_some();
#[cfg(debug_assertions)]
let t0 = std::time::Instant::now();
self.launch_cdcl_with_workspace_gated(
ws,
cnf,
compile_needed,
decision_base_limit,
decision_extra_base,
decision_extra_count,
import_existing_learned,
)?;
self.require_expected_status(
&ws.out_status,
&ws.out_error,
SAT_STATUS_UNSAT,
"GPU CDCL workspace UNSAT expectation",
)?;
let device = self.provider.device().inner();
let memory = self.provider.memory();
let assert_status_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_STATUS)
.ok_or_else(|| XlogError::Kernel("sat_assert_status kernel not found".to_string()))?;
unsafe {
assert_status_fn
.clone()
.launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
(
compile_needed,
&ws.out_status,
&ws.out_error,
SAT_STATUS_UNSAT,
),
)
.map_err(|e| {
XlogError::Kernel(format!("Failed to launch sat_assert_status: {}", e))
})?;
}
self.provider.device().synchronize()?;
#[cfg(debug_assertions)]
if trace {
eprintln!("[xlog-solve] cdcl_ws(unsat) time: {:?}", t0.elapsed());
}
let mut out_ok = memory.alloc::<i32>(1)?;
self.provider
.htod_launch_metadata_sync_copy_into(&[1i32], &mut out_ok)
.map_err(|e| XlogError::Kernel(format!("Failed to init proof out_ok: {}", e)))?;
let scratch_cap_u32 = cnf
.var_cap
.checked_add(1)
.ok_or_else(|| XlogError::Kernel("Proof scratch capacity overflow".to_string()))?;
let scratch_cap = scratch_cap_u32 as usize;
let mut proof_blocks: usize = 1;
let mut scratch_a = None;
let mut scratch_b = None;
let mut scratch_map = None;
let mut last_alloc_err: Option<XlogError> = None;
for blocks in [512usize, 256, 128, 64, 32, 16, 8, 4, 2, 1] {
let len = match scratch_cap.checked_mul(blocks) {
Some(v) => v,
None => {
last_alloc_err = Some(XlogError::Kernel(
"Proof scratch allocation length overflow".to_string(),
));
continue;
}
};
let a = match memory.alloc::<i32>(len) {
Ok(buf) => buf,
Err(e) => {
last_alloc_err = Some(e);
continue;
}
};
let b = match memory.alloc::<i32>(len) {
Ok(buf) => buf,
Err(e) => {
last_alloc_err = Some(e);
drop(a);
continue;
}
};
let m = match memory.alloc::<u32>(len) {
Ok(buf) => buf,
Err(e) => {
last_alloc_err = Some(e);
drop(a);
drop(b);
continue;
}
};
proof_blocks = blocks;
scratch_a = Some(a);
scratch_b = Some(b);
scratch_map = Some(m);
break;
}
let scratch_a = scratch_a.ok_or_else(|| {
last_alloc_err.unwrap_or_else(|| {
XlogError::Kernel("Failed to allocate proof scratch buffers".to_string())
})
})?;
let scratch_b = scratch_b
.ok_or_else(|| XlogError::Kernel("Missing proof scratch buffer".to_string()))?;
let mut scratch_map = scratch_map
.ok_or_else(|| XlogError::Kernel("Missing proof scratch map buffer".to_string()))?;
device
.memset_zeros(&mut scratch_map)
.map_err(|e| XlogError::Kernel(format!("Failed to zero proof scratch map: {}", e)))?;
#[cfg(debug_assertions)]
if trace {
eprintln!("[xlog-solve] proof_check_ws blocks: {}", proof_blocks);
}
#[cfg(debug_assertions)]
let t_mark = std::time::Instant::now();
let needed_cap_u32 = self.config.max_learned_clauses;
let needed_cap = needed_cap_u32 as usize;
let mut needed = memory.alloc::<u8>(needed_cap)?;
device
.memset_zeros(&mut needed)
.map_err(|e| XlogError::Kernel(format!("Failed to zero proof needed mask: {}", e)))?;
let mark_needed_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_PROOF_MARK_NEEDED)
.ok_or_else(|| {
XlogError::Kernel("sat_proof_mark_needed kernel not found".to_string())
})?;
let mut mark_params: Vec<*mut c_void> = vec![
compile_needed.as_kernel_param(),
(&cnf.num_clauses).as_kernel_param(),
(&ws.out_learned_count).as_kernel_param(),
(&ws.proof_offsets).as_kernel_param(),
(&ws.proof_data).as_kernel_param(),
needed_cap_u32.as_kernel_param(),
(&needed).as_kernel_param(),
];
unsafe {
mark_needed_fn
.clone()
.launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
&mut mark_params,
)
.map_err(|e| {
XlogError::Kernel(format!("Failed to launch sat_proof_mark_needed: {}", e))
})?;
}
self.provider.device().synchronize()?;
#[cfg(debug_assertions)]
if trace {
eprintln!(
"[xlog-solve] proof_mark_needed_ws time: {:?}",
t_mark.elapsed()
);
}
let proof_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_PROOF_CHECK)
.ok_or_else(|| XlogError::Kernel("sat_proof_check kernel not found".to_string()))?;
#[cfg(debug_assertions)]
let t_proof = std::time::Instant::now();
let proof_blocks_u32 = u32::try_from(proof_blocks)
.map_err(|_| XlogError::Kernel("Proof check grid dim exceeds u32::MAX".to_string()))?;
let mut proof_params: Vec<*mut c_void> = vec![
compile_needed.as_kernel_param(),
(&cnf.clause_offsets).as_kernel_param(),
(&cnf.literals).as_kernel_param(),
(&cnf.num_clauses).as_kernel_param(),
(&ws.learned_offsets).as_kernel_param(),
(&ws.learned_lits).as_kernel_param(),
(&ws.out_learned_count).as_kernel_param(),
(&ws.proof_offsets).as_kernel_param(),
(&ws.proof_data).as_kernel_param(),
(&needed).as_kernel_param(),
needed_cap_u32.as_kernel_param(),
(&scratch_a).as_kernel_param(),
(&scratch_b).as_kernel_param(),
(&scratch_map).as_kernel_param(),
scratch_cap_u32.as_kernel_param(),
(&out_ok).as_kernel_param(),
];
unsafe {
proof_fn
.clone()
.launch(
LaunchConfig {
grid_dim: (proof_blocks_u32, 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
},
&mut proof_params,
)
.map_err(|e| {
XlogError::Kernel(format!("Failed to launch SAT proof check: {}", e))
})?;
}
let assert_ok_fn = device
.get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_OK)
.ok_or_else(|| XlogError::Kernel("sat_assert_ok kernel not found".to_string()))?;
unsafe {
assert_ok_fn
.clone()
.launch(
LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
},
(compile_needed, &out_ok),
)
.map_err(|e| XlogError::Kernel(format!("Failed to launch sat_assert_ok: {}", e)))?;
}
self.provider.device().synchronize()?;
#[cfg(debug_assertions)]
if trace {
eprintln!("[xlog-solve] proof_check_ws time: {:?}", t_proof.elapsed());
eprintln!(
"[xlog-solve] cdcl_ws(unsat)+proof_check time: {:?}",
t0.elapsed()
);
}
Ok(())
}
}