ovr_evm_precompile_curve25519/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of Frontier.
3//
4// Copyright (c) 2020 Parity Technologies (UK) Ltd.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18extern crate alloc;
19use alloc::vec::Vec;
20use curve25519_dalek::{
21	ristretto::{CompressedRistretto, RistrettoPoint},
22	scalar::Scalar,
23	traits::Identity,
24};
25use fp_evm::{ExitError, ExitSucceed, LinearCostPrecompile, PrecompileFailure};
26
27// Adds at most 10 curve25519 points and returns the CompressedRistretto bytes representation
28pub struct Curve25519Add;
29
30impl LinearCostPrecompile for Curve25519Add {
31	const BASE: u64 = 60;
32	const WORD: u64 = 12;
33
34	fn execute(input: &[u8], _: u64) -> Result<(ExitSucceed, Vec<u8>), PrecompileFailure> {
35		if input.len() % 32 != 0 {
36			return Err(PrecompileFailure::Error {
37				exit_status: ExitError::Other("input must contain multiple of 32 bytes".into()),
38			});
39		};
40
41		if input.len() > 320 {
42			return Err(PrecompileFailure::Error {
43				exit_status: ExitError::Other(
44					"input cannot be greater than 320 bytes (10 compressed points)".into(),
45				),
46			});
47		};
48
49		let mut points = Vec::new();
50		let mut temp_buf = input.clone();
51		while temp_buf.len() > 0 {
52			let mut buf = [0; 32];
53			buf.copy_from_slice(&temp_buf[0..32]);
54			let point = CompressedRistretto::from_slice(&buf);
55			points.push(point);
56			temp_buf = &temp_buf[32..];
57		}
58
59		let sum = points
60			.iter()
61			.fold(RistrettoPoint::identity(), |acc, point| {
62				let pt = point
63					.decompress()
64					.unwrap_or_else(|| RistrettoPoint::identity());
65				acc + pt
66			});
67
68		Ok((ExitSucceed::Returned, sum.compress().to_bytes().to_vec()))
69	}
70}
71
72// Multiplies a scalar field element with an elliptic curve point
73pub struct Curve25519ScalarMul;
74
75impl LinearCostPrecompile for Curve25519ScalarMul {
76	const BASE: u64 = 60;
77	const WORD: u64 = 12;
78
79	fn execute(input: &[u8], _: u64) -> Result<(ExitSucceed, Vec<u8>), PrecompileFailure> {
80		if input.len() != 64 {
81			return Err(PrecompileFailure::Error {
82				exit_status: ExitError::Other(
83					"input must contain 64 bytes (scalar - 32 bytes, point - 32 bytes)".into(),
84				),
85			});
86		};
87
88		// first 32 bytes is for the scalar value
89		let mut scalar_buf = [0; 32];
90		scalar_buf.copy_from_slice(&input[0..32]);
91		let scalar = Scalar::from_bytes_mod_order(scalar_buf);
92
93		// second 32 bytes is for the compressed ristretto point bytes
94		let mut pt_buf = [0; 32];
95		pt_buf.copy_from_slice(&input[32..64]);
96		let point: RistrettoPoint = CompressedRistretto::from_slice(&pt_buf)
97			.decompress()
98			.unwrap_or_else(|| RistrettoPoint::identity());
99
100		let scalar_mul = scalar * point;
101		Ok((
102			ExitSucceed::Returned,
103			scalar_mul.compress().to_bytes().to_vec(),
104		))
105	}
106}
107
#[cfg(test)]
mod tests {
	use super::*;
	use curve25519_dalek::constants;

	/// Asserts that `result` is a `PrecompileFailure::Error` carrying
	/// `ExitError::Other(message)`; panics if the call succeeded.
	fn assert_other_error(
		result: Result<(ExitSucceed, Vec<u8>), PrecompileFailure>,
		message: &'static str,
	) {
		match result {
			Ok(_) => panic!("Test not expected to work"),
			Err(e) => assert_eq!(
				e,
				PrecompileFailure::Error {
					exit_status: ExitError::Other(message.into())
				}
			),
		}
	}

	/// Adding two multiples of the basepoint matches dalek's own point summation.
	#[test]
	fn test_sum() -> Result<(), PrecompileFailure> {
		let p1 = &constants::RISTRETTO_BASEPOINT_POINT * &Scalar::from(999u64);
		let p2 = &constants::RISTRETTO_BASEPOINT_POINT * &Scalar::from(333u64);

		let mut input = Vec::with_capacity(64);
		input.extend_from_slice(&p1.compress().to_bytes());
		input.extend_from_slice(&p2.compress().to_bytes());

		let expected: RistrettoPoint = [p1, p2].iter().sum();

		let (_, out) = Curve25519Add::execute(&input, 1)?;
		assert_eq!(out, expected.compress().to_bytes());
		Ok(())
	}

	/// Test that sum works for the empty input: the result is the identity.
	#[test]
	fn test_empty() -> Result<(), PrecompileFailure> {
		let (_, out) = Curve25519Add::execute(&[], 1)?;
		assert_eq!(out, RistrettoPoint::identity().compress().to_bytes());
		Ok(())
	}

	/// Multiplying the basepoint by a scalar matches dalek's result for that
	/// scalar and differs from the result for another scalar.
	#[test]
	fn test_scalar_mul() -> Result<(), PrecompileFailure> {
		let s1 = Scalar::from(999u64);
		let expected = &constants::RISTRETTO_BASEPOINT_POINT * &s1;
		let unrelated = &constants::RISTRETTO_BASEPOINT_POINT * &Scalar::from(333u64);

		let mut input = Vec::with_capacity(64);
		input.extend_from_slice(&s1.to_bytes());
		input.extend_from_slice(&constants::RISTRETTO_BASEPOINT_POINT.compress().to_bytes());

		let (_, out) = Curve25519ScalarMul::execute(&input, 1)?;
		assert_eq!(out, expected.compress().to_bytes());
		assert_ne!(out, unrelated.compress().to_bytes());
		Ok(())
	}

	/// Scalar multiplication rejects empty input with the length error.
	#[test]
	fn test_scalar_mul_empty_error() -> Result<(), PrecompileFailure> {
		assert_other_error(
			Curve25519ScalarMul::execute(&[], 1),
			"input must contain 64 bytes (scalar - 32 bytes, point - 32 bytes)",
		);
		Ok(())
	}

	/// Addition rejects input whose length is not a multiple of 32.
	#[test]
	fn test_point_addition_bad_length() -> Result<(), PrecompileFailure> {
		let input = [0u8; 33];

		assert_other_error(
			Curve25519Add::execute(&input, 1),
			"input must contain multiple of 32 bytes",
		);
		Ok(())
	}

	/// Addition rejects more than 10 points (here: 11 copies of the basepoint).
	#[test]
	fn test_point_addition_too_many_points() -> Result<(), PrecompileFailure> {
		let input = constants::RISTRETTO_BASEPOINT_POINT
			.compress()
			.to_bytes()
			.repeat(11);

		assert_other_error(
			Curve25519Add::execute(&input, 1),
			"input cannot be greater than 320 bytes (10 compressed points)",
		);
		Ok(())
	}
}