1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
use super::backend::init_backend;
use crate::compute::{curve::SwCurveConfig, CurveId, ElementP2};
use ark_ec::short_weierstrass::Affine;
use rayon::prelude::*;
use std::{ffi::CString, marker::PhantomData};
/// Count how many packed scalar rows each output uses.
///
/// Every row is `num_output_bytes = ceil(sum(output_bit_table) / 8)` bytes wide, so the
/// result is `scalars_len / num_output_bytes`.
///
/// # Panics
///
/// Panics if the table specifies zero bits in total (including an empty table), if
/// `scalars_len` is not a multiple of `num_output_bytes`, or if the count does not fit
/// in a `u32`.
fn count_scalars_per_output(scalars_len: usize, output_bit_table: &[u32]) -> u32 {
    // Accumulate in u64 so the bit total cannot overflow, even on 32-bit targets.
    let bit_sum: u64 = output_bit_table.iter().map(|&bits| u64::from(bits)).sum();
    let num_output_bytes = bit_sum.div_ceil(8);
    // Without this check a zero width would surface as an opaque division-by-zero panic.
    assert!(
        num_output_bytes > 0,
        "output_bit_table must specify at least one scalar bit"
    );
    // usize -> u64 is lossless on every supported target.
    let scalars_len = scalars_len as u64;
    assert!(
        scalars_len % num_output_bytes == 0,
        "scalars length must be a multiple of the packed scalar width in bytes"
    );
    u32::try_from(scalars_len / num_output_bytes)
        .expect("number of scalars per output must fit in a u32")
}
/// Handle to compute multi-scalar multiplications (MSMs) with pre-specified generators
///
/// The generators are passed to the blitzar backend once, when the handle is built;
/// later MSMs computed through the handle refer to them only through the backend handle.
///
/// # Example 1 - compute an MSM using the handle
///```no_run
#[doc = include_str!("../../examples/simple_fixed_msm.rs")]
///```
pub struct MsmHandle<T>
where
    T: CurveId,
{
    /// Opaque backend handle; released in this type's `Drop` impl.
    handle: *mut blitzar_sys::sxt_multiexp_handle,
    /// Records the curve element type `T` without storing a value of it.
    phantom: PhantomData<T>,
}

// SAFETY: NOTE(review): these impls assume the blitzar backend allows a multiexp handle
// to be moved to, and used concurrently (through `&self`) from, any thread — confirm
// against the blitzar documentation.
unsafe impl<T> Send for MsmHandle<T> where T: CurveId {}
unsafe impl<T> Sync for MsmHandle<T> where T: CurveId {}
impl<T: CurveId> MsmHandle<T> {
/// New handle from the specified generators.
///
/// Note: any MSMs computed with the handle must have length less than or equal
/// to the number of generators used to create the handle.
pub fn new(generators: &[T]) -> Self {
init_backend();
unsafe {
let handle = blitzar_sys::sxt_multiexp_handle_new(
T::CURVE_ID,
generators.as_ptr() as *const std::ffi::c_void,
generators.len() as u32,
);
Self {
handle,
phantom: PhantomData,
}
}
}
/// New handle from a serialized file.
///
/// Note: any MSMs computed with the handle must have length less than or equal
/// to the number of generators used to create the handle.
pub fn new_from_file(filename: &str) -> Self {
init_backend();
let filename = CString::new(filename).expect("filename cannot have null bytes");
unsafe {
let handle =
blitzar_sys::sxt_multiexp_handle_new_from_file(T::CURVE_ID, filename.as_ptr());
Self {
handle,
phantom: PhantomData,
}
}
}
/// Serialize the handle to a file.
///
/// This function can be used together with new_from_file to reduce
/// the cost of creating a handle.
pub fn write(&self, filename: &str) {
let filename = CString::new(filename).expect("filename cannot have null bytes");
unsafe {
blitzar_sys::sxt_multiexp_handle_write_to_file(self.handle, filename.as_ptr());
}
}
/// Compute an MSM using pre-specified generators.
///
/// Suppose g_1, ..., g_n are pre-specified generators and
///
/// s_11, s_12, ..., s_1n
/// s_21, s_22, ..., s_2n
/// .
/// . .
/// . .
/// s_m1, sm2, ..., s_mn
///
/// is an array of scalars of element_num_bytes each.
///
/// If msm is called with the slice of scalars of size element_num_bytes * m * n
/// defined by
///
/// scalars = [s_11, s_21, ..., s_m1, s_12, s_22, ..., s_m2, ..., s_mn ]
///
/// then res will contain the MSM result
///
/// res[0] = s_11 * g_1 + s_12 * g_2 + ... + s_1n * g_n
/// .
/// .
/// .
/// res[m-1] = s_m1 * g_1 + s_12 * g_2 + ... + s_mn * g_n
pub fn msm(&self, res: &mut [T], element_num_bytes: u32, scalars: &[u8]) {
let num_outputs = res.len() as u32;
assert!(scalars.len() as u32 % (num_outputs * element_num_bytes) == 0);
let n = scalars.len() as u32 / (num_outputs * element_num_bytes);
unsafe {
blitzar_sys::sxt_fixed_multiexponentiation(
res.as_ptr() as *mut std::ffi::c_void,
self.handle,
element_num_bytes,
num_outputs,
n,
scalars.as_ptr(),
);
}
}
/// Compute an MSM in packed format using pre-specified generators.
///
/// On completion `res` contains an array of size `num_outputs` for the multiexponentiation
/// of the given `scalars` array.
///
/// An entry output_bit_table[output_index] specifies the number of scalar bits used for
/// output_index.
///
/// Put
/// bit_sum = sum_{output_index} output_bit_table[output_index]
/// and let num_bytes denote the smallest integer greater than or equal to bit_sum that is a
/// multiple of 8.
///
///
/// `scalars` specifies a contiguous multi-dimension `num_bytes` by `n` array laid out in
/// a packed column-major order as specified by output_bit_table. A given row determines the scalar
/// exponents for generator g_i with the output scalars packed contiguously and padded with zeros.
pub fn packed_msm(&self, res: &mut [T], output_bit_table: &[u32], scalars: &[u8]) {
let num_outputs = res.len() as u32;
let n = count_scalars_per_output(scalars.len(), output_bit_table);
unsafe {
blitzar_sys::sxt_fixed_packed_multiexponentiation(
res.as_ptr() as *mut std::ffi::c_void,
self.handle,
output_bit_table.as_ptr(),
num_outputs,
n,
scalars.as_ptr(),
);
}
}
/// Compute a varying lengthing multiexponentiation of scalars in packed format using a handle to
/// pre-specified generators.
///
/// On completion `res` contains an array of size `num_outputs` for the multiexponentiation
/// of the given `scalars` array.
///
/// An entry output_bit_table[output_index] specifies the number of scalar bits used for
/// output_index and output_lengths[output_index] specifies the length used for output_index.
///
/// Note: output_lengths must be sorted in ascending order
///
/// Put
/// bit_sum = sum_{output_index} output_bit_table[output_index]
/// and let num_bytes denote the smallest integer greater than or equal to bit_sum that is a
/// multiple of 8.
///
/// Let n denote the length of the longest output. Then `scalars` specifies a contiguous
/// multi-dimension `num_bytes` by `n` array laid out in a packed column-major order as specified by
/// output_bit_table. A given row determines the scalar exponents for generator g_i with the output
/// scalars packed contiguously and padded with zeros.
pub fn vlen_msm(
&self,
res: &mut [T],
output_bit_table: &[u32],
output_lengths: &[u32],
scalars: &[u8],
) {
let num_outputs = res.len() as u32;
assert_eq!(output_bit_table.len(), num_outputs as usize);
assert_eq!(output_lengths.len(), num_outputs as usize);
unsafe {
blitzar_sys::sxt_fixed_vlen_multiexponentiation(
res.as_ptr() as *mut std::ffi::c_void,
self.handle,
output_bit_table.as_ptr(),
output_lengths.as_ptr(),
num_outputs,
scalars.as_ptr(),
);
}
}
}
impl<T> Drop for MsmHandle<T>
where
    T: CurveId,
{
    /// Hand the backend handle back to blitzar for release.
    fn drop(&mut self) {
        // SAFETY: `self.handle` came from a blitzar constructor and, since `drop` runs at
        // most once, it is freed exactly once here.
        unsafe {
            blitzar_sys::sxt_multiexp_handle_free(self.handle);
        }
    }
}
/// Affine-coordinate interface over a multiexponentiation handle for short Weierstrass
/// curve elements.
pub trait SwMsmHandle {
    /// Curve element type, in affine form, that generators and results are expressed in.
    type AffineElement;

    /// Build a handle from generators supplied in affine form.
    fn new_with_affine(generators: &[Self::AffineElement]) -> Self;

    /// Compute an MSM, writing each result into `res` as an affine element.
    fn affine_msm(&self, res: &mut [Self::AffineElement], element_num_bytes: u32, scalars: &[u8]);

    /// Compute a packed MSM, writing each result into `res` as an affine element.
    fn affine_packed_msm(
        &self,
        res: &mut [Self::AffineElement],
        output_bit_table: &[u32],
        scalars: &[u8],
    );

    /// Compute a variable-length packed MSM, writing each result into `res` as an affine
    /// element.
    fn affine_vlen_msm(
        &self,
        res: &mut [Self::AffineElement],
        output_bit_table: &[u32],
        output_lengths: &[u32],
        scalars: &[u8],
    );
}
/// Run `compute` on a default-initialized projective buffer with one slot per entry of
/// `res`, then convert every projective result to affine form in parallel, storing it in
/// the matching slot of `res`.
fn run_with_affine_output<C, F>(res: &mut [Affine<C>], compute: F)
where
    C: SwCurveConfig + Clone,
    F: FnOnce(&mut [ElementP2<C>]),
{
    let mut projective = vec![ElementP2::<C>::default(); res.len()];
    compute(&mut projective);
    res.par_iter_mut()
        .zip(projective)
        .for_each(|(slot, point)| *slot = point.into());
}

impl<C: SwCurveConfig + Clone> SwMsmHandle for MsmHandle<ElementP2<C>> {
    type AffineElement = Affine<C>;

    /// Convert the affine generators to `ElementP2` and build a handle from them.
    fn new_with_affine(generators: &[Self::AffineElement]) -> Self {
        let projective: Vec<ElementP2<C>> = generators.iter().map(Into::into).collect();
        MsmHandle::new(&projective)
    }

    /// Delegate to `msm` and convert the results to affine form.
    fn affine_msm(&self, res: &mut [Self::AffineElement], element_num_bytes: u32, scalars: &[u8]) {
        run_with_affine_output(res, |projective| {
            self.msm(projective, element_num_bytes, scalars)
        });
    }

    /// Delegate to `packed_msm` and convert the results to affine form.
    fn affine_packed_msm(
        &self,
        res: &mut [Self::AffineElement],
        output_bit_table: &[u32],
        scalars: &[u8],
    ) {
        run_with_affine_output(res, |projective| {
            self.packed_msm(projective, output_bit_table, scalars)
        });
    }

    /// Delegate to `vlen_msm` and convert the results to affine form.
    fn affine_vlen_msm(
        &self,
        res: &mut [Self::AffineElement],
        output_bit_table: &[u32],
        output_lengths: &[u32],
        scalars: &[u8],
    ) {
        run_with_affine_output(res, |projective| {
            self.vlen_msm(projective, output_bit_table, output_lengths, scalars)
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn we_can_count_the_number_of_scalars_per_output() {
        let output_bit_table = [1];
        let n = count_scalars_per_output(1, &output_bit_table);
        assert_eq!(n, 1);
        let output_bit_table = [14, 2];
        let n = count_scalars_per_output(10, &output_bit_table);
        assert_eq!(n, 5);
        // a bit sum that is not a multiple of 8 rounds up to whole bytes
        let output_bit_table = [9];
        let n = count_scalars_per_output(6, &output_bit_table);
        assert_eq!(n, 3);
    }

    // The input length below only fits in a 64-bit usize.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn we_handle_bit_sums_that_overflow_a_u32() {
        let output_bit_table = [u32::MAX, 1];
        let n = count_scalars_per_output((u32::MAX as usize) + 1, &output_bit_table);
        assert_eq!(n, 8);
    }

    #[test]
    #[should_panic]
    fn we_cannot_count_scalars_with_an_empty_output_bit_table() {
        let _ = count_scalars_per_output(4, &[]);
    }

    #[test]
    #[should_panic]
    fn we_cannot_count_scalars_with_zero_output_bits() {
        let _ = count_scalars_per_output(4, &[0, 0]);
    }

    #[test]
    #[should_panic]
    fn we_cannot_count_scalars_when_the_length_is_not_a_multiple_of_the_output_bytes() {
        let _ = count_scalars_per_output(3, &[16]);
    }
}