rtvm_precompile/
bn128.rs

1use crate::{
2    utilities::{bool_to_bytes32, right_pad},
3    Address, Error, Precompile, PrecompileResult, PrecompileWithAddress,
4};
5use bn::{AffineG1, AffineG2, Fq, Fq2, Group, Gt, G1, G2};
6
7pub mod add {
8    use super::*;
9
10    const ADDRESS: Address = crate::u64_to_address(6);
11
12    pub const ISTANBUL_ADD_GAS_COST: u64 = 150;
13    pub const ISTANBUL: PrecompileWithAddress = PrecompileWithAddress(
14        ADDRESS,
15        Precompile::Standard(|input, gas_limit| run_add(input, ISTANBUL_ADD_GAS_COST, gas_limit)),
16    );
17
18    pub const BYZANTIUM_ADD_GAS_COST: u64 = 500;
19    pub const BYZANTIUM: PrecompileWithAddress = PrecompileWithAddress(
20        ADDRESS,
21        Precompile::Standard(|input, gas_limit| run_add(input, BYZANTIUM_ADD_GAS_COST, gas_limit)),
22    );
23}
24
25pub mod mul {
26    use super::*;
27
28    const ADDRESS: Address = crate::u64_to_address(7);
29
30    pub const ISTANBUL_MUL_GAS_COST: u64 = 6_000;
31    pub const ISTANBUL: PrecompileWithAddress = PrecompileWithAddress(
32        ADDRESS,
33        Precompile::Standard(|input, gas_limit| run_mul(input, ISTANBUL_MUL_GAS_COST, gas_limit)),
34    );
35
36    pub const BYZANTIUM_MUL_GAS_COST: u64 = 40_000;
37    pub const BYZANTIUM: PrecompileWithAddress = PrecompileWithAddress(
38        ADDRESS,
39        Precompile::Standard(|input, gas_limit| run_mul(input, BYZANTIUM_MUL_GAS_COST, gas_limit)),
40    );
41}
42
43pub mod pair {
44    use super::*;
45
46    const ADDRESS: Address = crate::u64_to_address(8);
47
48    pub const ISTANBUL_PAIR_PER_POINT: u64 = 34_000;
49    pub const ISTANBUL_PAIR_BASE: u64 = 45_000;
50    pub const ISTANBUL: PrecompileWithAddress = PrecompileWithAddress(
51        ADDRESS,
52        Precompile::Standard(|input, gas_limit| {
53            run_pair(
54                input,
55                ISTANBUL_PAIR_PER_POINT,
56                ISTANBUL_PAIR_BASE,
57                gas_limit,
58            )
59        }),
60    );
61
62    pub const BYZANTIUM_PAIR_PER_POINT: u64 = 80_000;
63    pub const BYZANTIUM_PAIR_BASE: u64 = 100_000;
64    pub const BYZANTIUM: PrecompileWithAddress = PrecompileWithAddress(
65        ADDRESS,
66        Precompile::Standard(|input, gas_limit| {
67            run_pair(
68                input,
69                BYZANTIUM_PAIR_PER_POINT,
70                BYZANTIUM_PAIR_BASE,
71                gas_limit,
72            )
73        }),
74    );
75}
76
/// Input length, in bytes, of the add operation: two uncompressed 64-byte G1 points.
pub const ADD_INPUT_LEN: usize = 2 * 64;

/// Input length, in bytes, of the multiplication operation: one uncompressed 64-byte G1 point
/// followed by a 32-byte scalar.
pub const MUL_INPUT_LEN: usize = 64 + 32;

/// Length, in bytes, of a single pairing element: an uncompressed 64-byte G1 point followed by
/// an uncompressed 128-byte G2 point (twice the width of a G1 point).
pub const PAIR_ELEMENT_LEN: usize = 64 + 2 * 64;
89
90/// Reads a single `Fq` from the input slice.
91///
92/// # Panics
93///
94/// Panics if the input is not at least 32 bytes long.
95#[inline]
96pub fn read_fq(input: &[u8]) -> Result<Fq, Error> {
97    Fq::from_slice(&input[..32]).map_err(|_| Error::Bn128FieldPointNotAMember)
98}
99
100/// Reads the `x` and `y` points from the input slice.
101///
102/// # Panics
103///
104/// Panics if the input is not at least 64 bytes long.
105#[inline]
106pub fn read_point(input: &[u8]) -> Result<G1, Error> {
107    let px = read_fq(&input[0..32])?;
108    let py = read_fq(&input[32..64])?;
109    new_g1_point(px, py)
110}
111
112/// Creates a new `G1` point from the given `x` and `y` coordinates.
113pub fn new_g1_point(px: Fq, py: Fq) -> Result<G1, Error> {
114    if px == Fq::zero() && py == Fq::zero() {
115        Ok(G1::zero())
116    } else {
117        AffineG1::new(px, py)
118            .map(Into::into)
119            .map_err(|_| Error::Bn128AffineGFailedToCreate)
120    }
121}
122
123pub fn run_add(input: &[u8], gas_cost: u64, gas_limit: u64) -> PrecompileResult {
124    if gas_cost > gas_limit {
125        return Err(Error::OutOfGas);
126    }
127
128    let input = right_pad::<ADD_INPUT_LEN>(input);
129
130    let p1 = read_point(&input[..64])?;
131    let p2 = read_point(&input[64..])?;
132
133    let mut output = [0u8; 64];
134    if let Some(sum) = AffineG1::from_jacobian(p1 + p2) {
135        sum.x().to_big_endian(&mut output[..32]).unwrap();
136        sum.y().to_big_endian(&mut output[32..]).unwrap();
137    }
138    Ok((gas_cost, output.into()))
139}
140
141pub fn run_mul(input: &[u8], gas_cost: u64, gas_limit: u64) -> PrecompileResult {
142    if gas_cost > gas_limit {
143        return Err(Error::OutOfGas);
144    }
145
146    let input = right_pad::<MUL_INPUT_LEN>(input);
147
148    let p = read_point(&input[..64])?;
149
150    // `Fr::from_slice` can only fail when the length is not 32.
151    let fr = bn::Fr::from_slice(&input[64..96]).unwrap();
152
153    let mut output = [0u8; 64];
154    if let Some(mul) = AffineG1::from_jacobian(p * fr) {
155        mul.x().to_big_endian(&mut output[..32]).unwrap();
156        mul.y().to_big_endian(&mut output[32..]).unwrap();
157    }
158    Ok((gas_cost, output.into()))
159}
160
161pub fn run_pair(
162    input: &[u8],
163    pair_per_point_cost: u64,
164    pair_base_cost: u64,
165    gas_limit: u64,
166) -> PrecompileResult {
167    let gas_used = (input.len() / PAIR_ELEMENT_LEN) as u64 * pair_per_point_cost + pair_base_cost;
168    if gas_used > gas_limit {
169        return Err(Error::OutOfGas);
170    }
171
172    if input.len() % PAIR_ELEMENT_LEN != 0 {
173        return Err(Error::Bn128PairLength);
174    }
175
176    let success = if input.is_empty() {
177        true
178    } else {
179        let elements = input.len() / PAIR_ELEMENT_LEN;
180
181        let mut mul = Gt::one();
182        for idx in 0..elements {
183            let read_fq_at = |n: usize| {
184                debug_assert!(n < PAIR_ELEMENT_LEN / 32);
185                let start = idx * PAIR_ELEMENT_LEN + n * 32;
186                // SAFETY: We're reading `6 * 32 == PAIR_ELEMENT_LEN` bytes from `input[idx..]`
187                // per iteration. This is guaranteed to be in-bounds.
188                let slice = unsafe { input.get_unchecked(start..start + 32) };
189                Fq::from_slice(slice).map_err(|_| Error::Bn128FieldPointNotAMember)
190            };
191            let ax = read_fq_at(0)?;
192            let ay = read_fq_at(1)?;
193            let bay = read_fq_at(2)?;
194            let bax = read_fq_at(3)?;
195            let bby = read_fq_at(4)?;
196            let bbx = read_fq_at(5)?;
197
198            let a = new_g1_point(ax, ay)?;
199            let b = {
200                let ba = Fq2::new(bax, bay);
201                let bb = Fq2::new(bbx, bby);
202                if ba.is_zero() && bb.is_zero() {
203                    G2::zero()
204                } else {
205                    G2::from(AffineG2::new(ba, bb).map_err(|_| Error::Bn128AffineGFailedToCreate)?)
206                }
207            };
208
209            mul = mul * bn::pairing(a, b);
210        }
211
212        mul == Gt::one()
213    };
214    Ok((gas_used, bool_to_bytes32(success)))
215}
216
#[cfg(test)]
mod tests {
    use crate::bn128::add::BYZANTIUM_ADD_GAS_COST;
    use crate::bn128::mul::BYZANTIUM_MUL_GAS_COST;
    use crate::bn128::pair::{BYZANTIUM_PAIR_BASE, BYZANTIUM_PAIR_PER_POINT};
    use rtvm_primitives::hex;

    use super::*;

    #[test]
    fn test_alt_bn128_add() {
        let input = hex::decode(
            "\
             18b18acfb4c2c30276db5411368e7185b311dd124691610c5d3b74034e093dc9\
             063c909c4720840cb5134cb9f59fa749755796819658d32efc0d288198f37266\
             07c2b7f58a84bd6145f00c9c2bc0bb1a187f20ff2c92963a88019e7c6a014eed\
             06614e20c147e940f2d70da3f74c9a17df361706a4485c742bd6788478fa17d7",
        )
        .unwrap();
        let expected = hex::decode(
            "\
            2243525c5efd4b9c3d3c45ac0ca3fe4dd85e830a4ce6b65fa1eeaee202839703\
            301d1d33be6da8e509df21cc35964723180eed7532537db9ae5e7d48f195c915",
        )
        .unwrap();

        let (_, res) = run_add(&input, BYZANTIUM_ADD_GAS_COST, 500).unwrap();
        assert_eq!(res, expected);

        // Zero sum: two all-zero (infinity) points add to the all-zero encoding.
        let zeros = [0u8; ADD_INPUT_LEN];
        let (_, res) = run_add(&zeros, BYZANTIUM_ADD_GAS_COST, 500).unwrap();
        assert_eq!(res, vec![0u8; 64]);

        // Out of gas: one unit below the cost must fail.
        let res = run_add(&zeros, BYZANTIUM_ADD_GAS_COST, 499);
        assert!(matches!(res, Err(Error::OutOfGas)));

        // No input: padded to zeros, so it also yields the all-zero encoding.
        let (_, res) = run_add(&[], BYZANTIUM_ADD_GAS_COST, 500).unwrap();
        assert_eq!(res, vec![0u8; 64]);

        // Point not on curve.
        let res = run_add(&[0x11u8; ADD_INPUT_LEN], BYZANTIUM_ADD_GAS_COST, 500);
        assert!(matches!(res, Err(Error::Bn128AffineGFailedToCreate)));
    }

    #[test]
    fn test_alt_bn128_mul() {
        let input = hex::decode(
            "\
            2bd3e6d0f3b142924f5ca7b49ce5b9d54c4703d7ae5648e61d02268b1a0a9fb7\
            21611ce0a6af85915e2f1d70300909ce2e49dfad4a4619c8390cae66cefdb204\
            00000000000000000000000000000000000000000000000011138ce750fa15c2",
        )
        .unwrap();
        let expected = hex::decode(
            "\
            070a8d6a982153cae4be29d434e8faef8a47b274a053f5a4ee2a6c9c13c31e5c\
            031b8ce914eba3a9ffb989f9cdd5b0f01943074bf4f0f315690ec3cec6981afc",
        )
        .unwrap();

        let (_, res) = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000).unwrap();
        assert_eq!(res, expected);

        // The all-zero (infinity) point with a nonzero scalar.
        let scaled_infinity = hex::decode(
            "\
            0000000000000000000000000000000000000000000000000000000000000000\
            0000000000000000000000000000000000000000000000000000000000000000\
            0200000000000000000000000000000000000000000000000000000000000000",
        )
        .unwrap();

        // Out of gas: one unit below the cost must fail.
        let res = run_mul(&scaled_infinity, BYZANTIUM_MUL_GAS_COST, 39_999);
        assert!(matches!(res, Err(Error::OutOfGas)));

        // Zero multiplication: infinity times any scalar is the all-zero encoding.
        let (_, res) = run_mul(&scaled_infinity, BYZANTIUM_MUL_GAS_COST, 40_000).unwrap();
        assert_eq!(res, vec![0u8; 64]);

        // No input: padded to zeros, so it also yields the all-zero encoding.
        let (_, res) = run_mul(&[], BYZANTIUM_MUL_GAS_COST, 40_000).unwrap();
        assert_eq!(res, vec![0u8; 64]);

        // Point not on curve.
        let input = hex::decode(
            "\
            1111111111111111111111111111111111111111111111111111111111111111\
            1111111111111111111111111111111111111111111111111111111111111111\
            0f00000000000000000000000000000000000000000000000000000000000000",
        )
        .unwrap();

        let res = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000);
        assert!(matches!(res, Err(Error::Bn128AffineGFailedToCreate)));
    }

    #[test]
    fn test_alt_bn128_pair() {
        // Two pair elements whose pairing product is one.
        let input = hex::decode(
            "\
            1c76476f4def4bb94541d57ebba1193381ffa7aa76ada664dd31c16024c43f59\
            3034dd2920f673e204fee2811c678745fc819b55d3e9d294e45c9b03a76aef41\
            209dd15ebff5d46c4bd888e51a93cf99a7329636c63514396b4a452003a35bf7\
            04bf11ca01483bfa8b34b43561848d28905960114c8ac04049af4b6315a41678\
            2bb8324af6cfc93537a2ad1a445cfd0ca2a71acd7ac41fadbf933c2a51be344d\
            120a2a4cf30c1bf9845f20c6fe39e07ea2cce61f0c9bb048165fe5e4de877550\
            111e129f1cf1097710d41c4ac70fcdfa5ba2023c6ff1cbeac322de49d1b6df7c\
            2032c61a830e3c17286de9462bf242fca2883585b93870a73853face6a6bf411\
            198e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c2\
            1800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992f6ed\
            090689d0585ff075ec9e99ad690c3395bc4b313370b38ef355acdadcd122975b\
            12c85ea5db8c6deb4aab71808dcb408fe3d1e7690c43d37b4ce6cc0166fa7daa",
        )
        .unwrap();
        let expected =
            hex::decode("0000000000000000000000000000000000000000000000000000000000000001")
                .unwrap();

        let (_, res) = run_pair(
            &input,
            BYZANTIUM_PAIR_PER_POINT,
            BYZANTIUM_PAIR_BASE,
            260_000,
        )
        .unwrap();
        assert_eq!(res, expected);

        // Out of gas: the same two elements cost exactly 260_000.
        let res = run_pair(
            &input,
            BYZANTIUM_PAIR_PER_POINT,
            BYZANTIUM_PAIR_BASE,
            259_999,
        );
        assert!(matches!(res, Err(Error::OutOfGas)));

        // No input: the empty product succeeds.
        let (_, res) = run_pair(
            &[],
            BYZANTIUM_PAIR_PER_POINT,
            BYZANTIUM_PAIR_BASE,
            260_000,
        )
        .unwrap();
        assert_eq!(res, expected);

        // Point not on curve.
        let res = run_pair(
            &[0x11u8; PAIR_ELEMENT_LEN],
            BYZANTIUM_PAIR_PER_POINT,
            BYZANTIUM_PAIR_BASE,
            260_000,
        );
        assert!(matches!(res, Err(Error::Bn128AffineGFailedToCreate)));

        // Invalid input length: 79 bytes is not a multiple of PAIR_ELEMENT_LEN.
        let res = run_pair(
            &[0x11u8; 79],
            BYZANTIUM_PAIR_PER_POINT,
            BYZANTIUM_PAIR_BASE,
            260_000,
        );
        assert!(matches!(res, Err(Error::Bn128PairLength)));
    }
}