/// A variable handle packed into a single `u64`.
///
/// The top 3 bits hold a [`VarKind`] tag and the remaining 61 bits hold the
/// variable's index (its "payload"). `Zero` and `One` carry no index.
///
/// Because the tag occupies the most significant bits, the derived `Ord`
/// compares kinds first (`Zero < One < Instance < Witness < SymbolicLc`) and
/// then indices within a kind.
#[derive(Copy, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
#[must_use]
pub struct Variable(u64);
impl Variable {
    /// Number of most-significant bits that hold the [`VarKind`] tag.
    const TAG_BITS: u64 = 3;
    /// Bit position where the tag starts; the payload lives in the bits below it.
    const TAG_SHIFT: u64 = 64 - Self::TAG_BITS;
    /// Mask selecting the payload bits (the low `TAG_SHIFT` = 61 bits).
    const PAYLOAD_MASK: u64 = (1u64 << Self::TAG_SHIFT) - 1;

    // Tag values as returned by `tag()`. They are taken from the `VarKind` discriminants so
    // that the encoders below and the decoder in `kind()` share one source of truth; the
    // `unreachable_unchecked` in `kind()` is only sound while the two agree.
    const TAG_ZERO: u8 = VarKind::Zero as u8;
    const TAG_ONE: u8 = VarKind::One as u8;
    const TAG_INSTANCE: u8 = VarKind::Instance as u8;
    const TAG_WITNESS: u8 = VarKind::Witness as u8;
    const TAG_SYMBOLIC_LC: u8 = VarKind::SymbolicLc as u8;

    /// The `Zero` variable (kind [`VarKind::Zero`], payload 0); its packed value is `0`.
    #[allow(non_upper_case_globals)]
    pub const Zero: Variable = Variable::pack_unchecked(VarKind::Zero as u64, 0);

    /// The `One` variable (kind [`VarKind::One`], payload 0).
    #[allow(non_upper_case_globals)]
    pub const One: Variable = Variable::pack_unchecked(VarKind::One as u64, 0);

    /// Returns [`Variable::Zero`].
    #[inline(always)]
    pub const fn zero() -> Self {
        Self::Zero
    }

    /// Returns `true` if `self` is [`Variable::Zero`] (packed value `0`).
    #[inline(always)]
    #[must_use]
    pub const fn is_zero(&self) -> bool {
        self.0 == 0
    }

    /// Returns `true` if `self` is [`Variable::One`].
    #[inline(always)]
    #[must_use]
    pub const fn is_one(&self) -> bool {
        self.0 == Self::One.0
    }

    /// Returns [`Variable::One`].
    #[inline(always)]
    pub const fn one() -> Self {
        Self::One
    }

    /// Creates the instance variable with index `i`.
    ///
    /// `i` must fit in 61 bits. This is only checked when debug assertions are
    /// enabled; otherwise the high bits of `i` are discarded.
    #[inline(always)]
    pub const fn instance(i: usize) -> Self {
        Self::pack_unchecked(VarKind::Instance as u64, i as u64)
    }

    /// Returns `true` if `self` is an instance variable.
    #[inline(always)]
    #[must_use]
    pub const fn is_instance(self) -> bool {
        self.tag() == Self::TAG_INSTANCE
    }

    /// Creates the witness variable with index `i`.
    ///
    /// `i` must fit in 61 bits. This is only checked when debug assertions are
    /// enabled; otherwise the high bits of `i` are discarded.
    #[inline(always)]
    pub const fn witness(i: usize) -> Self {
        Self::pack_unchecked(VarKind::Witness as u64, i as u64)
    }

    /// Returns `true` if `self` is a witness variable.
    #[inline(always)]
    #[must_use]
    pub const fn is_witness(self) -> bool {
        self.tag() == Self::TAG_WITNESS
    }

    /// Creates the symbolic LC variable with index `i`.
    ///
    /// `i` must fit in 61 bits. This is only checked when debug assertions are
    /// enabled; otherwise the high bits of `i` are discarded.
    #[inline(always)]
    pub const fn symbolic_lc(i: usize) -> Self {
        Self::pack_unchecked(VarKind::SymbolicLc as u64, i as u64)
    }

    /// Returns `true` if `self` is a symbolic LC variable.
    #[inline(always)]
    #[must_use]
    pub const fn is_lc(self) -> bool {
        self.tag() == Self::TAG_SYMBOLIC_LC
    }

    /// Returns the index of a symbolic LC variable, or `None` for any other kind.
    #[inline(always)]
    #[must_use]
    // Payloads are built from `usize` indices, so casting back to `usize` is lossless.
    #[allow(clippy::cast_possible_truncation)]
    pub const fn get_lc_index(&self) -> Option<usize> {
        if self.is_lc() {
            Some(self.payload() as usize)
        } else {
            None
        }
    }

    /// Maps `self` to a flat variable index:
    ///
    /// - `One` → `Some(0)`
    /// - instance with index `i` → `Some(i)`
    /// - witness with index `i` → `Some(i + witness_offset)`
    /// - `Zero` and symbolic LC variables → `None`
    ///
    /// # Panics
    ///
    /// With overflow checks enabled, panics if `i + witness_offset` overflows `usize`.
    #[inline(always)]
    #[must_use]
    // Payloads are built from `usize` indices, so casting back to `usize` is lossless.
    #[allow(clippy::cast_possible_truncation)]
    pub const fn get_variable_index(&self, witness_offset: usize) -> Option<usize> {
        // NOTE(review): `One` and the instance with index 0 both map to 0; presumably
        // instance indices start at 1 — confirm with callers.
        match self.kind() {
            VarKind::One => Some(0),
            VarKind::Instance => Some(self.payload() as usize),
            VarKind::Witness => Some(self.payload() as usize + witness_offset),
            VarKind::Zero | VarKind::SymbolicLc => None,
        }
    }

    /// Returns the tag bits (the top `TAG_BITS` bits), which encode the `VarKind`.
    #[inline(always)]
    const fn tag(self) -> u8 {
        (self.0 >> Self::TAG_SHIFT) as u8
    }

    /// Returns the payload bits (the low 61 bits), i.e. the variable's index.
    #[inline(always)]
    const fn payload(self) -> u64 {
        self.0 & Self::PAYLOAD_MASK
    }

    /// Decodes the [`VarKind`] stored in the tag bits.
    #[inline(always)]
    #[allow(unsafe_code)]
    pub const fn kind(self) -> VarKind {
        match self.tag() {
            Self::TAG_ZERO => VarKind::Zero,
            Self::TAG_ONE => VarKind::One,
            Self::TAG_INSTANCE => VarKind::Instance,
            Self::TAG_WITNESS => VarKind::Witness,
            Self::TAG_SYMBOLIC_LC => VarKind::SymbolicLc,
            // SAFETY: the tuple field of `Variable` is private, and every constructor in this
            // impl builds values through `pack_unchecked` with a `VarKind` discriminant as the
            // tag. `pack_unchecked` masks the payload so it can never spill into the tag bits.
            // Hence the tag is always one of the arms above. Any new constructor must keep
            // this invariant.
            _ => unsafe { core::hint::unreachable_unchecked() },
        }
    }

    /// Returns the index of an instance, witness, or symbolic LC variable, and `None`
    /// for `Zero` and `One`.
    #[inline(always)]
    #[must_use]
    // Payloads are built from `usize` indices, so casting back to `usize` is lossless.
    #[allow(clippy::cast_possible_truncation)]
    pub const fn index(self) -> Option<usize> {
        match self.kind() {
            VarKind::Zero | VarKind::One => None,
            VarKind::Instance | VarKind::Witness | VarKind::SymbolicLc => {
                Some(self.payload() as usize)
            }
        }
    }

    /// Packs `tag` into the top `TAG_BITS` bits and `payload` into the rest.
    ///
    /// `tag` must be a `VarKind` discriminant (see the SAFETY note in `kind`).
    /// `payload` is expected to fit in the payload bits. This is checked only when debug
    /// assertions are enabled; otherwise the excess high bits are masked off so they
    /// never corrupt the tag.
    const fn pack_unchecked(tag: u64, payload: u64) -> Self {
        debug_assert!(payload <= Self::PAYLOAD_MASK);
        Variable((tag << Self::TAG_SHIFT) | (payload & Self::PAYLOAD_MASK))
    }

    /// Test helper: builds a variable of the given kind. `index` is ignored for
    /// `Zero` and `One`.
    #[cfg(test)]
    const fn new(kind: VarKind, index: usize) -> Self {
        match kind {
            VarKind::Zero => Self::Zero,
            VarKind::One => Self::One,
            VarKind::Instance => Self::instance(index),
            VarKind::Witness => Self::witness(index),
            VarKind::SymbolicLc => Self::symbolic_lc(index),
        }
    }
}
/// The kind of a [`Variable`], stored in its top tag bits.
///
/// The explicit discriminants are the tag values themselves. Their numeric order is
/// therefore the order in which `Variable`s of different kinds compare.
#[must_use]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum VarKind {
    /// The `Zero` variable; carries no index.
    Zero = 0,
    /// The `One` variable; carries no index.
    One = 1,
    /// An instance variable; its payload is an index.
    Instance = 2,
    /// A witness variable; its payload is an index.
    Witness = 3,
    /// A symbolic LC variable; its payload is the index returned by
    /// [`Variable::get_lc_index`].
    SymbolicLc = 4,
}
impl core::fmt::Debug for Variable {
    // Formats `Zero`/`One` as bare names and indexed kinds as tuples, e.g. `Witness(3)`.
    // Static names avoid a heap allocation on every call and keep this impl free of
    // `alloc`; the exhaustive match removes the need for an `unreachable!()` arm.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let name = match self.kind() {
            VarKind::Zero => return f.write_str("Zero"),
            VarKind::One => return f.write_str("One"),
            VarKind::Instance => "Instance",
            VarKind::Witness => "Witness",
            VarKind::SymbolicLc => "SymbolicLc",
        };
        f.debug_tuple(name).field(&self.payload()).finish()
    }
}
const _: () = assert!(core::mem::size_of::<Variable>() == 8);
#[cfg(test)]
mod tests {
    use super::*;
    use ark_std::rand::Rng;

    /// Variables order by kind first, then by index within a kind.
    #[test]
    fn test_variable_ordering() {
        use core::cmp::Ordering::*;
        use VarKind::*;
        let mut rng = ark_std::test_rng();
        let kinds = [Zero, One, Instance, Witness, SymbolicLc];
        for this_kind in kinds {
            let this_payload: u32 = rng.gen();
            let this = Variable::new(this_kind, this_payload as usize);
            for other_kind in kinds {
                let other_1 = Variable::new(other_kind, this_payload as usize);
                let other_payload: u32 = rng.gen();
                let other_2 = Variable::new(other_kind, other_payload as usize);
                let eq_case_with_payload = || {
                    assert_eq!(this, other_1, "{this:?} != {other_1:?}");
                    if this_payload < other_payload {
                        assert!(this < other_2, "{this:?} >= {other_2:?}");
                    } else if this_payload > other_payload {
                        assert!(this > other_2, "{this:?} <= {other_2:?}");
                    } else {
                        assert_eq!(this, other_2, "{this:?} != {other_2:?}");
                    }
                    assert_eq!(this.cmp(&other_1), Equal);
                };
                let eq_case = || {
                    assert_eq!(this, other_1, "{this:?} != {other_1:?}");
                    assert_eq!(this, other_2, "{this:?} != {other_2:?}");
                    assert_eq!(this.cmp(&other_1), Equal);
                };
                let lt_case = || {
                    assert!(this < other_1, "{this:?} >= {other_1:?}");
                    assert!(this < other_2, "{this:?} >= {other_2:?}");
                };
                let gt_case = || {
                    assert!(this > other_1, "{this:?} <= {other_1:?}");
                    assert!(this > other_2, "{this:?} <= {other_2:?}");
                };
                match (this_kind, other_kind) {
                    (Zero, Zero) => eq_case(),
                    (One, One) => eq_case(),
                    (Instance, Instance) => eq_case_with_payload(),
                    (Witness, Witness) => eq_case_with_payload(),
                    (SymbolicLc, SymbolicLc) => eq_case_with_payload(),
                    (Zero, _) => lt_case(),
                    (_, Zero) => gt_case(),
                    (One, _) => lt_case(),
                    (_, One) => gt_case(),
                    (Instance, _) => lt_case(),
                    (_, Instance) => gt_case(),
                    (Witness, _) => lt_case(),
                    (_, Witness) => gt_case(),
                }
            }
        }
    }

    /// Packing then unpacking recovers the kind and index for every constructor, and the
    /// derived accessors agree with them. This exercises the tag decoding in `kind()`.
    #[test]
    fn test_variable_roundtrip() {
        use VarKind::*;
        let offset = 7usize;
        // The largest index stays below `usize::MAX - offset` even on 32-bit targets.
        for i in [0usize, 1, 42, (u32::MAX >> 1) as usize] {
            let inst = Variable::instance(i);
            assert_eq!(inst.kind(), Instance);
            assert!(inst.is_instance() && !inst.is_witness() && !inst.is_lc());
            assert_eq!(inst.index(), Some(i));
            assert_eq!(inst.get_lc_index(), None);
            assert_eq!(inst.get_variable_index(offset), Some(i));

            let wit = Variable::witness(i);
            assert_eq!(wit.kind(), Witness);
            assert!(wit.is_witness() && !wit.is_instance() && !wit.is_lc());
            assert_eq!(wit.index(), Some(i));
            assert_eq!(wit.get_lc_index(), None);
            assert_eq!(wit.get_variable_index(offset), Some(i + offset));

            let lc = Variable::symbolic_lc(i);
            assert_eq!(lc.kind(), SymbolicLc);
            assert!(lc.is_lc() && !lc.is_instance() && !lc.is_witness());
            assert_eq!(lc.index(), Some(i));
            assert_eq!(lc.get_lc_index(), Some(i));
            assert_eq!(lc.get_variable_index(offset), None);
        }

        let zero = Variable::zero();
        assert_eq!(zero.kind(), Zero);
        assert!(zero.is_zero() && !zero.is_one());
        assert_eq!(zero.index(), None);
        assert_eq!(zero.get_lc_index(), None);
        assert_eq!(zero.get_variable_index(offset), None);

        let one = Variable::one();
        assert_eq!(one.kind(), One);
        assert!(one.is_one() && !one.is_zero());
        assert_eq!(one.index(), None);
        assert_eq!(one.get_lc_index(), None);
        assert_eq!(one.get_variable_index(offset), Some(0));
    }
}