// ark_relations/utils/variable.rs

/// Variables in [`ConstraintSystem`]s
///
/// A `Variable` is a single `u64` split into two fields: the top 3 bits hold
/// a tag (the [`VarKind`] discriminant) and the low 61 bits hold a payload
/// (the variable's index, where the kind has one).
///
/// The derived `PartialOrd`/`Ord` compare the raw word, so variables order
/// first by kind (in [`VarKind`] discriminant order) and then by index.
#[derive(Copy, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
#[must_use]
pub struct Variable(u64);
6impl Variable {
7    // [ tag: payload: 61 bits | 3 bits ]
8    const TAG_BITS: u64 = 3;
9
10    /// Bit position where the tag field starts (61 = 64 − 3).
11    const TAG_SHIFT: u64 = 64 - Self::TAG_BITS;
12
13    /// Mask for the payload (low 61 bits) when the tag is in the top byte.
14    const PAYLOAD_MASK: u64 = (1u64 << Self::TAG_SHIFT) - 1;
15
16    /// The zero variable.
17    #[allow(non_upper_case_globals)]
18    pub const Zero: Variable = Variable::pack_unchecked(0, 0);
19
20    /// The one variable.
21    #[allow(non_upper_case_globals)]
22    pub const One: Variable = Variable::pack_unchecked(1, 0);
23
24    /// The zero variable.
25    #[inline(always)]
26    pub const fn zero() -> Self {
27        Self::Zero
28    }
29
30    /// Is `self` the zero variable?
31    #[inline(always)]
32    #[must_use]
33    pub const fn is_zero(&self) -> bool {
34        self.0 == 0
35    }
36
37    /// Is `self` the one variable?
38    #[inline(always)]
39    #[must_use]
40    pub const fn is_one(&self) -> bool {
41        self.0 == Self::One.0
42    }
43
44    /// The `one` variable.
45    #[inline(always)]
46    pub const fn one() -> Self {
47        Self::One
48    }
49
50    /// Construct an instance variable.
51    #[inline(always)]
52    pub const fn instance(i: usize) -> Self {
53        Self::pack_unchecked(0b010, i as u64)
54    }
55
56    /// Is `self` an instance variable?
57    #[inline(always)]
58    #[must_use]
59    pub const fn is_instance(self) -> bool {
60        self.tag() == VarKind::Instance as u8
61    }
62
63    /// Construct a new witness variable.
64    #[inline(always)]
65    pub const fn witness(i: usize) -> Self {
66        Self::pack_unchecked(0b011, i as u64)
67    }
68
69    /// Is `self` a witness variable?
70    #[inline(always)]
71    #[must_use]
72    pub const fn is_witness(self) -> bool {
73        self.tag() == VarKind::Witness as u8
74    }
75
76    /// Construct a symbolic linear combination variable.
77    #[inline(always)]
78    pub const fn symbolic_lc(i: usize) -> Self {
79        Self::pack_unchecked(0b100, i as u64)
80    }
81
82    /// Is `self` a symbolic linear combination variable?
83    #[inline(always)]
84    #[must_use]
85    pub const fn is_lc(self) -> bool {
86        self.tag() == VarKind::SymbolicLc as u8
87    }
88
89    /// Get the `usize` in `self` if `self.is_lc()`.
90    #[inline(always)]
91    #[must_use]
92    #[allow(clippy::cast_possible_truncation)]
93    pub const fn get_lc_index(&self) -> Option<usize> {
94        if self.is_lc() {
95            Some(self.payload() as usize)
96        } else {
97            None
98        }
99    }
100
101    /// Returns `Some(usize)` if `!self.is_lc()`, and `None` otherwise.
102    #[inline(always)]
103    #[must_use]
104    #[allow(clippy::cast_possible_truncation)]
105    pub const fn get_variable_index(&self, witness_offset: usize) -> Option<usize> {
106        match self.kind() {
107            // The one variable always has index 0
108            VarKind::One => Some(0),
109            VarKind::Instance => Some(self.payload() as usize),
110            VarKind::Witness => Some(self.payload() as usize + witness_offset),
111            _ => None,
112        }
113    }
114
115    /// Returns the tag of the variable.
116    #[inline(always)]
117    const fn tag(self) -> u8 {
118        (self.0 >> Self::TAG_SHIFT) as u8
119    }
120
121    /// Unconditionally returns the payload of the variable.
122    /// Note that when `self.tag() == 0` or `self.tag() == 1`, the data
123    /// value is not meaningful.
124    #[inline(always)]
125    const fn payload(self) -> u64 {
126        self.0 & Self::PAYLOAD_MASK
127    }
128
129    /// What kind of variable is this?
130    #[inline(always)]
131    #[allow(unsafe_code)]
132    pub const fn kind(self) -> VarKind {
133        match self.tag() {
134            0 => VarKind::Zero,
135            1 => VarKind::One,
136            2 => VarKind::Instance,
137            3 => VarKind::Witness,
138            4 => VarKind::SymbolicLc,
139            _ => unsafe { core::hint::unreachable_unchecked() },
140        }
141    }
142
143    /// If `self` is an instance, witness, or symbolic linear combination,
144    /// returns the index of that variable.
145    #[inline(always)]
146    #[must_use]
147    #[allow(clippy::cast_possible_truncation)]
148    pub const fn index(self) -> Option<usize> {
149        match self.kind() {
150            VarKind::Zero | VarKind::One => None,
151            _ => Some(self.payload() as usize),
152        }
153    }
154
155    /// Does not check that the tag and payload are valid.
156    const fn pack_unchecked(tag: u64, payload: u64) -> Self {
157        debug_assert!(payload <= Self::PAYLOAD_MASK);
158        Variable((tag << Self::TAG_SHIFT) | payload & Self::PAYLOAD_MASK)
159    }
160
161    #[cfg(test)]
162    const fn new(kind: VarKind, index: usize) -> Self {
163        match kind {
164            VarKind::Zero => Self::Zero,
165            VarKind::One => Self::One,
166            VarKind::Instance => Self::instance(index),
167            VarKind::Witness => Self::witness(index),
168            VarKind::SymbolicLc => Self::symbolic_lc(index),
169        }
170    }
171}
/// The kinds of variables that can be used in a constraint system.
///
/// Each discriminant is the tag value stored in the top 3 bits of a
/// [`Variable`], so the declaration order here is also the order in which
/// variables of different kinds compare.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[must_use]
pub enum VarKind {
    /// The constant zero variable (tag 0).
    Zero = 0,
    /// The constant one variable (tag 1).
    One = 1,
    /// An instance variable (tag 2).
    Instance = 2,
    /// A witness variable (tag 3).
    Witness = 3,
    /// A symbolic linear combination (tag 4).
    SymbolicLc = 4,
}
185impl core::fmt::Debug for Variable {
186    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
187        match (self.kind(), self.index()) {
188            (VarKind::Zero, _) => f.write_str("Zero"),
189            (VarKind::One, _) => f.write_str("One"),
190            (k, Some(i)) => f.debug_tuple(&format!("{k:?}")).field(&i).finish(),
191            _ => unreachable!(),
192        }
193    }
194}
196// Compile-time proof it really is 8 B.
197const _: () = assert!(core::mem::size_of::<Variable>() == 8);
#[cfg(test)]
mod tests {
    // Checks that `PartialOrd`/`Ord` on `Variable` agree with `Eq`/`PartialEq`:
    // variables of different kinds order by kind, and variables of the same
    // indexed kind order by their index.
    use super::*;

    use ark_std::rand::Rng;

    #[test]
    fn test_variable_ordering() {
        use core::cmp::Ordering::*;
        use VarKind::*;
        let mut rng = ark_std::test_rng();
        // Listed in ascending tag order, which is the expected sort order.
        let kinds = [Zero, One, Instance, Witness, SymbolicLc];
        for this_kind in kinds {
            let this_payload: u32 = rng.gen();
            let this = Variable::new(this_kind, this_payload as usize);
            for other_kind in kinds {
                // `other_1` shares `this`'s payload; `other_2` gets a fresh one.
                let other_1 = Variable::new(other_kind, this_payload as usize);

                let other_payload: u32 = rng.gen();
                let other_2 = Variable::new(other_kind, other_payload as usize);

                // Same indexed kind: equality and order follow the payload.
                let eq_case_with_payload = || {
                    assert_eq!(this, other_1, "{this:?} != {other_1:?}");
                    match this_payload.cmp(&other_payload) {
                        Less => assert!(this < other_2, "{this:?} >= {other_2:?}"),
                        Greater => assert!(this > other_2, "{this:?} <= {other_2:?}"),
                        Equal => assert_eq!(this, other_2, "{this:?} != {other_2:?}"),
                    }
                    assert_eq!(this.cmp(&other_1), Equal);
                };
                // Same payload-free kind (`Zero`/`One`): the index is ignored.
                let eq_case = || {
                    assert_eq!(this, other_1, "{this:?} != {other_1:?}");
                    assert_eq!(this, other_2, "{this:?} != {other_2:?}");
                    assert_eq!(this.cmp(&other_1), Equal);
                };
                // `this` has the smaller kind: less than both, whatever the payload.
                let lt_case = || {
                    assert!(this < other_1, "{this:?} >= {other_1:?}");
                    assert!(this < other_2, "{this:?} >= {other_2:?}");
                };
                // `this` has the larger kind: greater than both, whatever the payload.
                let gt_case = || {
                    assert!(this > other_1, "{this:?} <= {other_1:?}");
                    assert!(this > other_2, "{this:?} <= {other_2:?}");
                };
                match (this_kind, other_kind) {
                    (Zero, Zero) => eq_case(),
                    (One, One) => eq_case(),
                    (Instance, Instance) => eq_case_with_payload(),
                    (Witness, Witness) => eq_case_with_payload(),
                    (SymbolicLc, SymbolicLc) => eq_case_with_payload(),

                    (Zero, _) => lt_case(),
                    (_, Zero) => gt_case(),

                    (One, _) => lt_case(),
                    (_, One) => gt_case(),

                    (Instance, _) => lt_case(),
                    (_, Instance) => gt_case(),

                    (Witness, _) => lt_case(),
                    (_, Witness) => gt_case(),
                }
            }
        }
    }
}