Skip to main content

arcis_compiler/
traits.rs

1use crate::{
2    core::circuits::boolean::{boolean_value::Boolean, byte::Byte},
3    utils::{
4        crypto::key::X25519PublicKey,
5        curve_point::Curve,
6        elliptic_curve::F25519,
7        number::Number,
8    },
9};
10use std::ops::Not;
11
/// Equality comparison whose result type is chosen by the implementor.
///
/// Unlike [`PartialEq`], both operands are taken by value and the answer is
/// returned as [`Equal::Output`] instead of a plain `bool`. The only
/// requirement on that type is that negating it yields the same type again.
pub trait Equal<Other>: Sized {
    /// Result of a comparison; negating it must give back the same type.
    type Output: Not<Output = Self::Output>;

    /// Compares `self` with `other` for equality.
    fn eq(self, other: Other) -> Self::Output;

    /// Compares `self` with `other` for inequality.
    ///
    /// The provided implementation is the negation of [`Equal::eq`].
    fn ne(self, other: Other) -> Self::Output {
        !<Self as Equal<Other>>::eq(self, other)
    }
}
20
/// Zero test whose result type is chosen by the implementor.
pub trait IsZero {
    /// Type in which the answer is reported (not necessarily `bool`).
    type Output;

    /// Reports whether `self` is zero, without consuming it.
    fn is_zero(&self) -> Self::Output;
}
26
/// `>=` / `<` comparison whose result type is chosen by the implementor.
///
/// Both operands are taken by value and the answer is returned as
/// [`GreaterEqual::Output`], which must negate to the same type.
pub trait GreaterEqual<Other>: Sized {
    /// Result of a comparison; negating it must give back the same type.
    type Output: Not<Output = Self::Output>;

    /// Tests whether `self >= other`.
    fn ge(self, other: Other) -> Self::Output;

    /// Tests whether `self < other`.
    ///
    /// The provided implementation is the plain negation of
    /// [`GreaterEqual::ge`], which is only the same thing as `<` when every
    /// pair of values is comparable.
    fn lt(self, other: Other) -> Self::Output {
        !<Self as GreaterEqual<Other>>::ge(self, other)
    }
}
35
/// `>` / `<=` comparison whose result type is chosen by the implementor.
///
/// Both operands are taken by value and the answer is returned as
/// [`GreaterThan::Output`], which must negate to the same type.
pub trait GreaterThan<Other>: Sized {
    /// Result of a comparison; negating it must give back the same type.
    type Output: Not<Output = Self::Output>;

    /// Tests whether `self > other`.
    fn gt(self, other: Other) -> Self::Output;

    /// Tests whether `self <= other`.
    ///
    /// The provided implementation is the plain negation of
    /// [`GreaterThan::gt`], which is only the same thing as `<=` when every
    /// pair of values is comparable.
    fn le(self, other: Other) -> Self::Output {
        !<Self as GreaterThan<Other>>::gt(self, other)
    }
}
44
45/// Implement [`Equal`] for all types that implement [`Eq`].
46impl<T> Equal<T> for T
47where
48    T: Eq,
49{
50    type Output = bool;
51
52    fn eq(self, other: T) -> Self::Output {
53        self == other
54    }
55}
56
57/// Implement [`GreaterEqual`] for all types that implement [`PartialOrd`].
58impl<T> GreaterEqual<T> for T
59where
60    T: PartialOrd,
61{
62    type Output = bool;
63
64    fn ge(self, other: T) -> Self::Output {
65        self >= other
66    }
67}
68
69/// Implement [`GreaterThan`] for all types that implement [`PartialOrd`].
70impl<T> GreaterThan<T> for T
71where
72    T: PartialOrd,
73{
74    type Output = bool;
75
76    fn gt(self, other: T) -> Self::Output {
77        self > other
78    }
79}
80
/// Types from which a conditional selection against a `T` can be built.
///
/// `T` defaults to `Self`, i.e. a selection between two values of one type.
pub trait Selectable<T = Self> {
    /// Type of the condition that drives the selection.
    type Conditional;
    /// Type of the value produced by the selection.
    type Output;

    /// Builds the selection between `a` and `b` governed by `condition`.
    ///
    /// NOTE(review): which operand corresponds to a "true" condition is not
    /// fixed by this trait; confirm against the implementors.
    fn construct_selection(condition: Self::Conditional, a: Self, b: T) -> Self::Output;
}
87
/// Selection between a `T` and a `V` that produces a `U`.
///
/// The receiver is consumed by the call.
/// NOTE(review): presumably the receiver plays the role of the condition;
/// confirm against the implementors.
pub trait Select<T, U, V> {
    /// Selects between `a` and `b` and returns the result as a `U`.
    fn select(self, a: T, b: V) -> U;
}
91
/// Values that can be revealed as a `T`.
///
/// NOTE(review): judging by the name, implementors are presumably encrypted
/// or secret values and `T` their plaintext form; confirm against the
/// implementors.
pub trait Enc<T> {
    /// Consumes `self` and returns the revealed `T`.
    fn reveal(self) -> T;
}
95
96pub trait FromLeBits<B: Boolean> {
97    fn from_le_bits(bits: Vec<B>, signed: bool) -> Self;
98}
99
100pub trait GetBit {
101    type Output: Boolean;
102
103    fn get_bit(&self, index: usize, signed: bool) -> Self::Output;
104}
105
/// Construction of a value from a fixed 32-byte encoding.
pub trait FromLeBytes {
    /// Builds `Self` from 32 plain bytes given in little-endian order.
    fn from_le_bytes(bytes: [u8; 32]) -> Self;
}
109
110pub trait ToLeBytes {
111    type BooleanOutput: Boolean;
112
113    fn to_le_bytes(self) -> [Byte<Self::BooleanOutput>; 32];
114}
115
/// Types for which a random value can be produced.
pub trait Random {
    /// Returns a random value of this type.
    ///
    /// NOTE(review): the source of randomness and the distribution are up to
    /// the implementor; nothing here constrains them.
    fn random() -> Self;
}
119
/// Types for which a random bit can be produced.
///
/// The method has the same name and signature as [`Random::random`], so call
/// sites with both traits in scope need a qualified path.
pub trait RandomBit {
    /// Returns a random value of this type.
    ///
    /// NOTE(review): presumably the result is restricted to a single bit
    /// (0 or 1); confirm against the implementors.
    fn random() -> Self;
}
123
/// Values that can be revealed without changing their type.
pub trait Reveal {
    /// Consumes `self` and returns the revealed value, of the same type.
    ///
    /// NOTE(review): what "revealed" means (presumably secret -> public) is
    /// defined by the implementors; confirm there.
    fn reveal(self) -> Self;
}
127
/// Values that have an inverse of the same type.
pub trait Invert {
    /// Consumes `self` and returns its inverse.
    ///
    /// NOTE(review): `is_expected_non_zero` presumably tells the implementor
    /// that the caller expects `self` to be non-zero; how a zero input is
    /// handled in either case is not visible here. Confirm against the
    /// implementors.
    fn invert(self, is_expected_non_zero: bool) -> Self;
}
131
132pub trait Pow {
133    fn pow(self, e: &Number, is_expected_non_zero: bool) -> Self;
134}
135
136pub trait Keccak {
137    fn f1600(state: [Byte<Self>; 200]) -> [Byte<Self>; 200]
138    where
139        Self: Boolean;
140
141    fn sponge<const N: usize>(
142        rate: usize,
143        capacity: usize,
144        input_bytes: Vec<Byte<Self>>,
145    ) -> [Byte<Self>; N]
146    where
147        Self: Boolean,
148    {
149        if input_bytes.len() > 1 << 20 {
150            panic!(
151                "sha3 not supported on inputs of more than 2^20 bytes (found {})",
152                input_bytes.len()
153            );
154        }
155        if rate + capacity != 1600 || rate % 8 != 0 {
156            panic!("rate + capacity must equal 1600 and rate must be a multiple of 8 (found rate: {rate}, capacity: {capacity})");
157        }
158        let mut state = [Byte::from(0u8); 200];
159        let rate_in_bytes = rate / 8;
160        // absorb the input blocks
161        input_bytes.chunks(rate_in_bytes).for_each(|chunk| {
162            chunk.iter().copied().enumerate().for_each(|(i, c)| {
163                state[i] ^= c;
164            });
165            if chunk.len() == rate_in_bytes {
166                state = Keccak::f1600(state);
167            }
168        });
169        // do the padding
170        let block_size = input_bytes.len() % rate_in_bytes;
171        state[block_size] ^= Byte::from(0x06);
172        state[rate_in_bytes - 1] ^= Byte::from(0x80);
173        state = Keccak::f1600(state);
174        // squeezing phase
175        (0..N)
176            .step_by(rate_in_bytes)
177            .fold(Vec::new(), |mut acc, pos| {
178                let block_size = (N - pos).min(rate_in_bytes);
179                acc.append(&mut state[0..block_size].to_vec());
180                if acc.len() < N {
181                    state = Keccak::f1600(state);
182                }
183                acc
184            })
185            .try_into()
186            .unwrap_or_else(|v: Vec<Byte<Self>>| {
187                panic!("Expected a Vec of length {N} (found {})", v.len())
188            })
189    }
190}
191
/// Values that can produce a copy of themselves carrying boolean bounds.
pub trait WithBooleanBounds {
    /// Returns a new value derived from `self`; `self` is left untouched.
    ///
    /// NOTE(review): presumably the result is `self` annotated or constrained
    /// to the boolean range (0 or 1); confirm against the implementors.
    fn with_boolean_bounds(&self) -> Self;
}
195
196pub trait ToMontgomery {
197    type Output: F25519;
198
199    fn to_montgomery(self, is_expected_non_identity: bool) -> (Self::Output, Self::Output);
200}
201
/// Types that can represent the MXE's X25519 private key.
pub trait MxeX25519PrivateKey {
    /// Returns the MXE X25519 private key as a value of this type.
    ///
    /// NOTE(review): where the key comes from is up to the implementor;
    /// confirm there.
    fn mxe_x25519_private_key() -> Self;
}
205
/// Types that can represent a component of the MXE's Rescue key.
pub trait MxeRescueKey {
    /// Returns the MXE Rescue key entry selected by `i`.
    ///
    /// NOTE(review): the valid range of `i` is not visible from this
    /// declaration; confirm against the implementors.
    fn mxe_rescue_key(i: usize) -> Self;
}
209
210pub trait GetSharedRescueKey<C: Curve> {
211    fn get_shared_rescue_key(pubkey: X25519PublicKey<C>, i: usize) -> Self;
212}
213
214/// Trait used to convert the ECDH output to the target field.
215/// The implementor must make sure that the conversion is injective!
216pub trait FromF25519<T: F25519> {
217    #[allow(non_snake_case)]
218    fn from_F25519(value: T) -> Vec<Self>
219    where
220        Self: Sized;
221}