use std::{marker::PhantomData, mem, num::NonZeroI32};
use pindakaas::{Lit as RawLit, Var as RawVar};
use tracing::trace;
use crate::{
actions::{Trailed, TrailingActions},
helpers::bytes::Bytes,
};
/// Boolean state of a single SAT variable tracked by the [`Trail`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
struct BoolStore {
    /// Current value of the variable, or `None` while it is unassigned.
    value: Option<bool>,
    /// Value saved by an undo that keeps the assignment redoable (`undo::<true>`);
    /// moved back into `value` when the assignment is redone.
    restore: Option<bool>,
}
/// Undo/redo log of Boolean assignments and trailed values.
///
/// Events are stored as encoded `u32` words in `trail`. Entries before `pos` are currently
/// applied; entries from `pos` up to the end have been undone and can normally be redone.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct Trail {
    /// Encoded events: one word per Boolean assignment, three words per trailed-value change.
    trail: Vec<u32>,
    /// Current position in `trail`; every event before it is applied.
    pos: usize,
    /// Trail length recorded each time a new decision level was opened.
    prev_len: Vec<usize>,
    /// Current bytes of every trailed value, indexed by the handle's `index`.
    int_value: Vec<[u8; 8]>,
    /// Per-variable Boolean state, indexed by the variable's numeric id.
    sat_store: Vec<BoolStore>,
}
/// A single decoded entry of the [`Trail`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum TrailEvent {
    /// A Boolean variable received a value.
    SatAssignment(RawVar),
    /// A trailed value changed.
    IntAssignment {
        /// Index of the trailed slot that changed.
        index: u32,
        /// Bytes written back into the slot when this entry is reversed in the current
        /// direction: the previous value for an applied entry (decoded by `undo`), the newer
        /// value for an undone entry (decoded by `redo`).
        value: [u8; 8],
    },
}
impl Trail {
    // Trail encoding. Every event occupies one or three `u32` words:
    //
    // - `SatAssignment(var)`: one word holding the (positive) variable number.
    // - `IntAssignment { index, value }`: three words. While applied (before `pos`) the entry
    //   is laid out as `[low, high, -index]`; after `undo::<true>` it is rewritten in place
    //   as `[-index, high, low]`. `high`/`low` are the halves of the 8-byte value that the
    //   next reversal of the entry writes back into the slot.
    //
    // `undo` inspects the last word before `pos`, `redo` the first word at `pos`: a positive
    // word is a Boolean assignment, a non-positive word is the index word of a trailed-value
    // change. This requires every trailed index to fit in an `i32` (enforced by `track`).

    /// Trailed slot reserved for the current brancher (presumably the index of the active
    /// brancher — confirm against the search code).
    ///
    /// Slot 0 is pre-populated by [`Trail::default`], so this handle is always valid.
    pub(crate) const CURRENT_BRANCHER: Trailed<usize> = Trailed {
        index: 0,
        ty: PhantomData,
    };

    /// Makes `lit` true if its variable is still unassigned.
    ///
    /// Returns `None` when the variable was unassigned; the assignment is then recorded as
    /// a new trail event. Returns `Some(b)` when the variable already had a value, where `b`
    /// is the current truth value of `lit`; nothing is recorded in that case.
    ///
    /// # Panics
    ///
    /// Panics if the variable has no slot yet (see [`Trail::grow_to_boolvar`]).
    pub(crate) fn assign_lit(&mut self, lit: RawLit) -> Option<bool> {
        let var = lit.var();
        let store = &mut self.sat_store[Self::sat_index(var)].value;
        if let Some(val) = *store {
            Some(if lit.is_negated() { !val } else { val })
        } else {
            *store = Some(!lit.is_negated());
            self.push_trail(TrailEvent::SatAssignment(var));
            None
        }
    }

    /// Current decision level: the number of levels opened by
    /// [`Trail::notify_new_decision_level`] that have not been backtracked over.
    pub(crate) fn decision_level(&self) -> u32 {
        self.prev_len.len() as u32
    }

    /// Moves the trail position to just before the assignment of `lit`'s variable.
    ///
    /// If the variable is currently unassigned, undone events are redone until its
    /// assignment is replayed, and that one event is then undone again; if the trail holds
    /// no such event, the position ends up at the trail head. Otherwise applied events are
    /// undone (staying redoable) until the variable's assignment has been undone.
    pub(crate) fn goto_assign_lit(&mut self, lit: RawLit) {
        let var = lit.var();
        if self.sat_store[Self::sat_index(var)].value.is_none() {
            while let Some(event) = self.redo() {
                if matches!(event, TrailEvent::SatAssignment(r) if r == var) {
                    // Step back over the assignment so the literal is unset again.
                    let e: Option<TrailEvent> = self.undo::<true>();
                    debug_assert_eq!(e, Some(TrailEvent::SatAssignment(var)));
                    trace!(
                        target: "solver",
                        len = self.pos,
                        lit = i32::from(lit),
                        "redo to when literal was set"
                    );
                    return;
                }
            }
            trace!(
                target: "solver",
                len = self.pos,
                lit = i32::from(lit),
                "trail reset for unknown literal"
            );
            return;
        }
        while let Some(event) = self.undo::<true>() {
            if matches!(event, TrailEvent::SatAssignment(r) if r == var) {
                trace!(
                    target: "solver",
                    len = self.pos,
                    lit = i32::from(lit),
                    "undo to when literal was set"
                );
                return;
            }
        }
    }

    /// Ensures there is an (unassigned) slot for `var` and for every variable numbered
    /// below it.
    pub(crate) fn grow_to_boolvar(&mut self, var: RawVar) {
        let idx = Self::sat_index(var);
        if idx >= self.sat_store.len() {
            self.sat_store.resize(idx + 1, Default::default());
        }
    }

    /// Backtracks to decision level `level`, discarding every event recorded after it.
    ///
    /// Does nothing if `level` is not below the current decision level. Otherwise the
    /// position is moved to the trail length saved when level `level + 1` was opened.
    /// Applied events beyond that length are undone without keeping redo information,
    /// because they are discarded. Undone events before it are redone, because they belong
    /// to a retained level. Finally the trail is truncated to that length.
    pub(crate) fn notify_backtrack(&mut self, level: usize) {
        if level >= self.prev_len.len() {
            return;
        }
        let len = self.prev_len[level];
        self.prev_len.truncate(level);
        debug_assert!(
            len <= self.trail.len(),
            "backtracking to level {level} length {len}, but trail is already at length {}",
            self.trail.len()
        );
        if len <= self.pos {
            while self.pos > len {
                self.undo::<false>();
            }
        } else {
            while self.pos < len {
                self.redo();
            }
        }
        debug_assert_eq!(self.pos, len);
        self.trail.truncate(len);
    }

    /// Opens a new decision level, remembering the current trail length as its start.
    pub(crate) fn notify_new_decision_level(&mut self) {
        self.prev_len.push(self.trail.len());
    }

    /// Appends `event` in forward (applied) encoding and moves `pos` past it.
    ///
    /// Must only be called while `pos` is at the trail head (checked in debug builds).
    fn push_trail(&mut self, event: TrailEvent) {
        debug_assert_eq!(self.pos, self.trail.len());
        match event {
            TrailEvent::SatAssignment(_) => self.trail.push(0),
            TrailEvent::IntAssignment { .. } => self.trail.extend([0; 3]),
        }
        event.write_trail(&mut self.trail[self.pos..]);
        self.pos = self.trail.len();
    }

    /// Re-applies the undone event at `pos` and converts it back to forward encoding.
    ///
    /// Returns `None` when `pos` is already at the trail head. For a trailed-value change,
    /// the returned event carries the value just written into the slot.
    fn redo(&mut self) -> Option<TrailEvent> {
        debug_assert!(self.pos <= self.trail.len());
        if self.pos == self.trail.len() {
            return None;
        }
        // The first word of an undone entry identifies its kind (see encoding note above).
        let event = if (self.trail[self.pos] as i32).is_positive() {
            self.pos += 1;
            TrailEvent::SatAssignment(
                RawLit::from_raw(NonZeroI32::new(self.trail[self.pos - 1] as i32).unwrap()).var(),
            )
        } else {
            self.pos += 3;
            TrailEvent::int_from_rev_trail(self.trail[self.pos - 3..self.pos].try_into().unwrap())
        };
        match event {
            TrailEvent::SatAssignment(r) => {
                let store = &mut self.sat_store[Self::sat_index(r)];
                debug_assert!(store.restore.is_some());
                debug_assert!(store.value.is_none());
                mem::swap(&mut store.restore, &mut store.value);
            }
            TrailEvent::IntAssignment { index, value } => {
                // Save the value being overwritten so a later undo can restore it.
                let x = self.int_value[index as usize];
                TrailEvent::IntAssignment { index, value: x }
                    .write_trail(&mut self.trail[self.pos - 3..self.pos]);
                self.int_value[index as usize] = value;
            }
        }
        Some(event)
    }

    /// Redoes every undone event, so that `pos` is back at the trail head.
    pub(crate) fn reset_to_trail_head(&mut self) {
        while self.redo().is_some() {}
    }

    /// Maps a variable to its slot in the Boolean store (its numeric id).
    #[inline]
    fn sat_index(var: RawVar) -> usize {
        i32::from(var) as usize
    }

    /// Current truth value of `lit`. Returns `None` if its variable is unassigned or has
    /// no slot.
    pub(crate) fn sat_value(&self, lit: impl Into<RawLit>) -> Option<bool> {
        let lit = lit.into();
        self.sat_store
            .get(Self::sat_index(lit.var()))
            .and_then(|store| store.value)
            .map(|x| if lit.is_negated() { !x } else { x })
    }

    /// Registers a new trailed value initialised to `val` and returns its handle.
    ///
    /// # Panics
    ///
    /// Panics if the new slot index does not fit in an `i32`. Indices are stored negated
    /// on the trail, and a larger index cannot be encoded as the non-positive word that
    /// distinguishes a trailed-value change from a Boolean assignment.
    pub(crate) fn track<T: Bytes>(&mut self, val: T) -> Trailed<T> {
        let index = self.int_value.len();
        assert!(
            i32::try_from(index).is_ok(),
            "too many trailed values: index {index} cannot be encoded on the trail"
        );
        self.int_value.push(val.to_bytes());
        Trailed {
            // Lossless: checked above to fit in a non-negative `i32`.
            index: index as u32,
            ty: PhantomData,
        }
    }

    /// Reverts the last applied event (the one ending at `pos`) and moves `pos` before it.
    ///
    /// With `RESTORE = true`, the event stays redoable:
    /// - a Boolean value is saved in the store's `restore` field;
    /// - a trailed-value entry is rewritten in reverse encoding, holding the value being
    ///   replaced.
    ///
    /// With `RESTORE = false`, that bookkeeping is skipped. This is only sound if the
    /// caller discards the event afterwards, as `notify_backtrack` does.
    ///
    /// Returns `None` when nothing is applied (`pos == 0`). For a trailed-value change, the
    /// returned event carries the value restored into the slot.
    fn undo<const RESTORE: bool>(&mut self) -> Option<TrailEvent> {
        debug_assert!(self.pos <= self.trail.len());
        if self.pos == 0 {
            return None;
        }
        // The last word of an applied entry identifies its kind (see encoding note above).
        let event = if (self.trail[self.pos - 1] as i32).is_positive() {
            self.pos -= 1;
            TrailEvent::SatAssignment(
                RawLit::from_raw(NonZeroI32::new(self.trail[self.pos] as i32).unwrap()).var(),
            )
        } else {
            self.pos -= 3;
            TrailEvent::int_from_trail(self.trail[self.pos..=self.pos + 2].try_into().unwrap())
        };
        match event {
            TrailEvent::SatAssignment(r) => {
                let store = &mut self.sat_store[Self::sat_index(r)];
                let b = mem::take(&mut store.value);
                if RESTORE {
                    store.restore = b;
                }
            }
            TrailEvent::IntAssignment { index, value } => {
                if RESTORE {
                    // Remember the value being replaced so `redo` can reinstate it.
                    let x = self.int_value[index as usize];
                    TrailEvent::IntAssignment { index, value: x }
                        .write_rev_trail(&mut self.trail[self.pos..=self.pos + 2]);
                }
                self.int_value[index as usize] = value;
            }
        }
        Some(event)
    }
}
impl Default for Trail {
    /// Builds an empty trail with no decision levels and no Boolean slots.
    ///
    /// Exactly one trailed slot, index 0 (the slot `Trail::CURRENT_BRANCHER` refers to), is
    /// pre-registered and initialised from `0_u64`.
    fn default() -> Self {
        let reserved_slot = 0_u64.to_bytes();
        Self {
            int_value: Vec::from([reserved_slot]),
            pos: 0,
            prev_len: Vec::new(),
            sat_store: Vec::new(),
            trail: Vec::new(),
        }
    }
}
impl TrailingActions for Trail {
    /// Stores `v` in trailed slot `i` and returns the value the slot held before.
    ///
    /// A trail event is only recorded when the stored bytes actually change. Writing the
    /// current value again leaves the trail untouched.
    fn set_trailed<T: Bytes>(&mut self, i: Trailed<T>, v: T) -> T {
        let slot = i.index as usize;
        let incoming = v.to_bytes();
        if self.int_value[slot] != incoming {
            let previous = mem::replace(&mut self.int_value[slot], incoming);
            self.push_trail(TrailEvent::IntAssignment {
                index: i.index,
                value: previous,
            });
            return T::from_bytes(previous);
        }
        T::from_bytes(incoming)
    }

    /// Reads the current value of trailed slot `i`.
    fn trailed<T: Bytes>(&self, i: Trailed<T>) -> T {
        let stored = self.int_value[i.index as usize];
        T::from_bytes(stored)
    }
}
impl TrailEvent {
    /// Flips the sign of a trailed-index word. For indices up to `i32::MAX` the mapping is
    /// its own inverse, so it serves both for encoding and decoding.
    #[inline]
    fn negate_index(word: u32) -> u32 {
        -(word as i32) as u32
    }

    /// Reassembles a trailed value's native-endian bytes from its 32-bit halves.
    #[inline]
    fn join_value(high: u32, low: u32) -> [u8; 8] {
        ((u64::from(high) << 32) | u64::from(low)).to_ne_bytes()
    }

    /// Splits a trailed value, read as a native-endian `u64`, into `(high, low)` halves.
    #[inline]
    fn split_value(value: [u8; 8]) -> (u32, u32) {
        let word = u64::from_ne_bytes(value);
        ((word >> 32) as u32, word as u32)
    }

    /// Decodes a trailed-value change stored in reverse (undone) order `[-index, high, low]`.
    #[inline]
    fn int_from_rev_trail(raw: [u32; 3]) -> Self {
        let [index_word, high, low] = raw;
        TrailEvent::IntAssignment {
            index: Self::negate_index(index_word),
            value: Self::join_value(high, low),
        }
    }

    /// Decodes a trailed-value change stored in forward (applied) order `[low, high, -index]`.
    #[inline]
    fn int_from_trail(raw: [u32; 3]) -> Self {
        let [low, high, index_word] = raw;
        TrailEvent::IntAssignment {
            index: Self::negate_index(index_word),
            value: Self::join_value(high, low),
        }
    }

    /// Encodes the event at the start of `trail` in reverse (undone) order.
    ///
    /// Panics if `trail` is shorter than the encoded event (1 or 3 words).
    #[inline]
    fn write_rev_trail(&self, trail: &mut [u32]) {
        match self {
            TrailEvent::SatAssignment(var) => trail[0] = i32::from(*var) as u32,
            TrailEvent::IntAssignment { index, value } => {
                let (high, low) = Self::split_value(*value);
                trail[..3].copy_from_slice(&[Self::negate_index(*index), high, low]);
            }
        }
    }

    /// Encodes the event at the start of `trail` in forward (applied) order.
    ///
    /// Panics if `trail` is shorter than the encoded event (1 or 3 words).
    #[inline]
    fn write_trail(&self, trail: &mut [u32]) {
        match self {
            TrailEvent::SatAssignment(var) => trail[0] = i32::from(*var) as u32,
            TrailEvent::IntAssignment { index, value } => {
                let (high, low) = Self::split_value(*value);
                trail[..3].copy_from_slice(&[low, high, Self::negate_index(*index)]);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use pindakaas::{ClauseDatabase, solver::cadical::Cadical};
    use crate::{
        IntVal,
        actions::TrailingActions,
        helpers::bytes::Bytes,
        solver::trail::{Trail, TrailEvent},
    };

    /// Records interleaved Boolean assignments and trailed-value changes, then undoes them
    /// newest-first, checking each decoded event and the restored state.
    #[test]
    fn test_trail_event() {
        let mut slv = Cadical::default();
        let mut trail = Trail::default();
        let lits = slv.new_var_range(10);
        trail.grow_to_boolvar(lits.end());
        let int_events: Vec<_> = [
            0,
            1,
            -1,
            IntVal::MAX,
            IntVal::MIN,
            4084,
            -9967,
            9076,
            -4312,
            1718,
        ]
        .into_iter()
        .map(|i| (trail.track(0), i))
        .collect();
        for (l, &(i, v)) in lits.zip(int_events.iter()) {
            trail.assign_lit(if i.index % 2 == 0 { l.into() } else { !l });
            trail.set_trailed(i, v);
        }
        for (l, &(i, v)) in lits.rev().zip(int_events.iter().rev()) {
            assert_eq!(trail.trailed(i), v);
            // Setting a slot to its current value (0) records no event, so nothing to undo.
            if v != 0 {
                let e = trail.undo::<true>().unwrap();
                let TrailEvent::IntAssignment { index, value } = e else {
                    panic!("unexpected trail event type {e:?}");
                };
                assert_eq!(i.index, index);
                assert_eq!(trail.trailed(i), i64::from_bytes(value));
            }
            assert_eq!(trail.sat_value(l), Some(i.index % 2 == 0));
            let e = trail.undo::<true>().unwrap();
            assert_eq!(e, TrailEvent::SatAssignment(l));
            assert_eq!(trail.sat_value(l), None);
        }
    }

    /// Covers the paths that replay undone events:
    /// - `reset_to_trail_head`;
    /// - `notify_backtrack` redoing events of the retained level;
    /// - `notify_backtrack` undoing events of the abandoned level.
    #[test]
    fn test_trail_redo_and_backtrack() {
        let mut trail = Trail::default();
        let t = trail.track::<IntVal>(0);

        // Undo a change while keeping it redoable, then replay it.
        trail.set_trailed(t, 7);
        let e = trail.undo::<true>().unwrap();
        let TrailEvent::IntAssignment { index, value } = e else {
            panic!("unexpected trail event type {e:?}");
        };
        assert_eq!(index, t.index);
        assert_eq!(IntVal::from_bytes(value), 0);
        assert_eq!(trail.trailed(t), 0);
        trail.reset_to_trail_head();
        assert_eq!(trail.trailed(t), 7);

        // Backtracking must redo undone events that belong to the retained level...
        trail.notify_new_decision_level();
        assert_eq!(trail.decision_level(), 1);
        trail.set_trailed(t, 9);
        assert!(trail.undo::<true>().is_some());
        assert!(trail.undo::<true>().is_some());
        assert_eq!(trail.trailed(t), 0);
        trail.notify_backtrack(0);
        assert_eq!(trail.decision_level(), 0);
        assert_eq!(trail.trailed(t), 7);

        // ...and undo applied events that belong to the abandoned level.
        trail.notify_new_decision_level();
        trail.set_trailed(t, 11);
        assert_eq!(trail.trailed(t), 11);
        trail.notify_backtrack(0);
        assert_eq!(trail.decision_level(), 0);
        assert_eq!(trail.trailed(t), 7);
    }
}