1use std::collections::HashSet;
2use std::fmt::{self, Write};
3use std::hash::{Hash, Hasher};
4use std::marker::PhantomData;
5use std::mem;
6#[cfg(any(feature = "opencl", feature = "cuda"))]
7use std::path::PathBuf;
8#[cfg(any(feature = "opencl", feature = "cuda"))]
9use std::{env, fs};
10
11use ec_gpu::{GpuField, GpuName};
12use group::prime::PrimeCurveAffine;
13
// Raw OpenCL/CUDA source fragments embedded at compile time. The
// `FIELD`/`FIELD2`/`POINT`/`EXPONENT` placeholders inside them are substituted
// with concrete type names when the kernel source is assembled.
static COMMON_SRC: &str = include_str!("cl/common.cl");
static FIELD_SRC: &str = include_str!("cl/field.cl");
static FIELD2_SRC: &str = include_str!("cl/field2.cl");
static EC_SRC: &str = include_str!("cl/ec.cl");
static FFT_SRC: &str = include_str!("cl/fft.cl");
static MULTIEXP_SRC: &str = include_str!("cl/multiexp.cl");
20
/// The limb width the kernel source is generated for.
#[derive(Clone, Copy)]
enum Limb32Or64 {
    /// 32-bit limbs (used for CUDA, see `generate_cuda`).
    Limb32,
    /// 64-bit limbs (used for OpenCL, see `generate_opencl`).
    Limb64,
}
26
/// Anything that has a name and can produce kernel source code.
trait NameAndSource {
    /// The name that identifies this item; also used for de-duplication
    /// (see the `PartialEq`/`Hash` impls below).
    fn name(&self) -> String;
    /// The kernel source for the given limb width.
    fn source(&self, limb: Limb32Or64) -> String;
}
35
36impl PartialEq for dyn NameAndSource {
37 fn eq(&self, other: &Self) -> bool {
38 self.name() == other.name()
39 }
40}
41
// Sound because `PartialEq` above delegates to `String` equality, which is a
// full equivalence relation.
impl Eq for dyn NameAndSource {}
43
impl Hash for dyn NameAndSource {
    // Hash only the name, so that it stays consistent with the name-based
    // `PartialEq` impl (required for use as a `HashSet` key).
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name().hash(state)
    }
}
49
50impl fmt::Debug for dyn NameAndSource {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 if f.alternate() {
55 f.debug_map()
56 .entries(vec![
57 ("name", self.name()),
58 ("source", self.source(Limb32Or64::Limb32)),
59 ])
60 .finish()
61 } else {
62 write!(f, "{:?}", self.name())
63 }
64 }
65}
66
/// A field that kernel source can be generated for.
#[derive(Debug)]
enum Field<F: GpuField> {
    /// The field `F` itself.
    Field(PhantomData<F>),
    /// A sub-field of `F`, identified by its name only.
    SubField(String),
}
83
impl<F: GpuField> Field<F> {
    /// Creates a new [`Field`] representing `F` itself (not a sub-field).
    pub fn new() -> Self {
        Self::Field(PhantomData)
    }
}
92
impl<F: GpuField> Default for Field<F> {
    /// Defaults to the field `F` itself.
    fn default() -> Self {
        Self::new()
    }
}
98
99fn field_source<F: GpuField>(limb: Limb32Or64) -> String {
100 match limb {
101 Limb32Or64::Limb32 => [
102 params::<F, Limb32>(),
103 field_add_sub_nvidia::<F, Limb32>().expect("preallocated"),
104 String::from(FIELD_SRC),
105 ]
106 .join("\n"),
107 Limb32Or64::Limb64 => [
108 params::<F, Limb64>(),
109 field_add_sub_nvidia::<F, Limb64>().expect("preallocated"),
110 String::from(FIELD_SRC),
111 ]
112 .join("\n"),
113 }
114}
115
impl<F: GpuField> NameAndSource for Field<F> {
    fn name(&self) -> String {
        match self {
            Self::Field(_) => F::name(),
            Self::SubField(name) => name.to_string(),
        }
    }

    fn source(&self, limb: Limb32Or64) -> String {
        match self {
            Self::Field(_) => {
                if let Some(sub_field_name) = F::sub_field_name() {
                    // Extension field: use the `field2.cl` template. `FIELD2`
                    // must be replaced before `FIELD`, since `FIELD` is a
                    // substring of `FIELD2` and would corrupt it otherwise.
                    String::from(FIELD2_SRC)
                        .replace("FIELD2", &F::name())
                        .replace("FIELD", &sub_field_name)
                } else {
                    field_source::<F>(limb).replace("FIELD", &F::name())
                }
            }
            Self::SubField(sub_field_name) => {
                // NOTE(review): limb parameters are taken from `F` (the
                // extension field) while the name is the sub-field's; assumes
                // `F`'s `GpuField` limb data matches its sub-field's — confirm
                // against the `GpuField` impls.
                field_source::<F>(limb).replace("FIELD", sub_field_name)
            }
        }
    }
}
146
/// FFT kernel over the field `F`.
struct Fft<F: GpuName>(PhantomData<F>);
149
150impl<F: GpuName> NameAndSource for Fft<F> {
151 fn name(&self) -> String {
152 F::name()
153 }
154
155 fn source(&self, _limb: Limb32Or64) -> String {
156 String::from(FFT_SRC).replace("FIELD", &F::name())
157 }
158}
159
/// Multiexp (multi-scalar multiplication) kernel over curve point `P`, base
/// field `F` and exponent type `Exp`.
struct Multiexp<P: GpuName, F: GpuName, Exp: GpuName> {
    curve_point: PhantomData<P>,
    field: PhantomData<F>,
    exponent: PhantomData<Exp>,
}
166
167impl<P: GpuName, F: GpuName, Exp: GpuName> Multiexp<P, F, Exp> {
168 pub fn new() -> Self {
169 Self {
170 curve_point: PhantomData::<P>,
171 field: PhantomData::<F>,
172 exponent: PhantomData::<Exp>,
173 }
174 }
175}
176
177impl<P: GpuName, F: GpuName, Exp: GpuName> NameAndSource for Multiexp<P, F, Exp> {
178 fn name(&self) -> String {
179 P::name()
180 }
181
182 fn source(&self, _limb: Limb32Or64) -> String {
183 let ec = String::from(EC_SRC)
184 .replace("FIELD", &F::name())
185 .replace("POINT", &P::name());
186 let multiexp = String::from(MULTIEXP_SRC)
187 .replace("POINT", &P::name())
188 .replace("EXPONENT", &Exp::name());
189 [ec, multiexp].concat()
190 }
191}
192
/// Collects the pieces of kernel source (fields, FFTs, multiexps, extra
/// snippets) and assembles them into a single source string.
///
/// The `dyn NameAndSource` sets de-duplicate by name (see the `PartialEq`/
/// `Hash` impls above), so registering the same item twice is harmless.
pub struct SourceBuilder {
    /// Plain (non-extension) fields.
    fields: HashSet<Box<dyn NameAndSource>>,
    /// Extension fields; emitted after the plain fields.
    extension_fields: HashSet<Box<dyn NameAndSource>>,
    /// FFT kernels.
    ffts: HashSet<Box<dyn NameAndSource>>,
    /// Multiexp kernels.
    multiexps: HashSet<Box<dyn NameAndSource>>,
    /// Arbitrary caller-supplied source, appended at the end.
    extra_sources: Vec<String>,
}
224
225impl SourceBuilder {
226 pub fn new() -> Self {
228 Self {
229 fields: HashSet::new(),
230 extension_fields: HashSet::new(),
231 ffts: HashSet::new(),
232 multiexps: HashSet::new(),
233 extra_sources: Vec::new(),
234 }
235 }
236
237 pub fn add_field<F>(mut self) -> Self
241 where
242 F: GpuField + 'static,
243 {
244 let field = Field::<F>::new();
245 if let Some(sub_field_name) = F::sub_field_name() {
247 self.extension_fields.insert(Box::new(field));
248 let sub_field = Field::<F>::SubField(sub_field_name);
249 self.fields.insert(Box::new(sub_field));
250 } else {
251 self.fields.insert(Box::new(field));
252 }
253 self
254 }
255
256 pub fn add_fft<F>(self) -> Self
258 where
259 F: GpuField + 'static,
260 {
261 let mut config = self.add_field::<F>();
262 let fft = Fft::<F>(PhantomData);
263 config.ffts.insert(Box::new(fft));
264 config
265 }
266
267 pub fn add_multiexp<C, F>(self) -> Self
272 where
273 C: PrimeCurveAffine + GpuName,
274 C::Scalar: GpuField,
275 F: GpuField + 'static,
276 {
277 let mut config = self.add_field::<F>().add_field::<C::Scalar>();
278 let multiexp = Multiexp::<C, F, C::Scalar>::new();
279 config.multiexps.insert(Box::new(multiexp));
280 config
281 }
282
283 pub fn append_source(mut self, source: String) -> Self {
289 self.extra_sources.push(source);
290 self
291 }
292
293 pub fn build_32_bit_limbs(&self) -> String {
297 self.build(Limb32Or64::Limb32)
298 }
299
300 pub fn build_64_bit_limbs(&self) -> String {
304 self.build(Limb32Or64::Limb64)
305 }
306
307 fn build(&self, limb_size: Limb32Or64) -> String {
309 let fields = self
310 .fields
311 .iter()
312 .map(|field| field.source(limb_size))
313 .collect();
314 let extension_fields = self
315 .extension_fields
316 .iter()
317 .map(|field| field.source(limb_size))
318 .collect();
319 let ffts = self.ffts.iter().map(|fft| fft.source(limb_size)).collect();
320 let multiexps = self
321 .multiexps
322 .iter()
323 .map(|multiexp| multiexp.source(limb_size))
324 .collect();
325 let extra_sources = self.extra_sources.join("\n");
326 vec![
327 COMMON_SRC.to_string(),
328 fields,
329 extension_fields,
330 ffts,
331 multiexps,
332 extra_sources,
333 ]
334 .join("\n\n")
335 }
336}
337
impl Default for SourceBuilder {
    /// Same as [`SourceBuilder::new`]: an empty builder.
    fn default() -> Self {
        Self::new()
    }
}
343
/// A single machine-word limb of a multi-precision field element.
pub trait Limb: Sized + Clone + Copy {
    /// The underlying primitive integer type of the limb.
    type LimbType: Clone + std::fmt::Display;
    /// Returns the zero limb.
    fn zero() -> Self;
    /// Wraps a raw value into a limb.
    fn new(val: Self::LimbType) -> Self;
    /// Returns the raw value of the limb.
    fn value(&self) -> Self::LimbType;
    /// Number of bits in a limb, derived from the size of `LimbType`.
    fn bits() -> usize {
        mem::size_of::<Self::LimbType>() * 8
    }
    /// Returns the PTX type name and register prefix used in inline assembly.
    fn ptx_info() -> (&'static str, &'static str);
    /// Returns the OpenCL type name corresponding to this limb.
    fn opencl_type() -> &'static str;
    /// Returns the limbs of `F`'s one element (as provided by `GpuField::one`).
    fn one_limbs<F: GpuField>() -> Vec<Self>;
    /// Returns the limbs of `F`'s modulus (as provided by `GpuField::modulus`).
    fn modulus_limbs<F: GpuField>() -> Vec<Self>;
    /// Computes the Montgomery inversion constant from the lowest limb `a` of
    /// the modulus (an iterated square-and-multiply yielding `-a⁻¹ mod 2^bits`
    /// — see the impls below).
    fn calc_inv(a: Self) -> Self;
    /// Returns the limbs of `F`'s `R²` constant (as provided by `GpuField::r2`).
    fn calculate_r2<F: GpuField>() -> Vec<Self>;
}
373
374#[derive(Clone, Copy)]
376pub struct Limb32(u32);
377impl Limb for Limb32 {
378 type LimbType = u32;
379 fn zero() -> Self {
380 Self(0)
381 }
382 fn new(val: Self::LimbType) -> Self {
383 Self(val)
384 }
385 fn value(&self) -> Self::LimbType {
386 self.0
387 }
388 fn ptx_info() -> (&'static str, &'static str) {
389 ("u32", "r")
390 }
391 fn opencl_type() -> &'static str {
392 "uint"
393 }
394 fn one_limbs<F: GpuField>() -> Vec<Self> {
395 F::one().into_iter().map(Self::new).collect()
396 }
397 fn modulus_limbs<F: GpuField>() -> Vec<Self> {
398 F::modulus().into_iter().map(Self::new).collect()
399 }
400 fn calc_inv(a: Self) -> Self {
401 let mut inv = 1u32;
402 for _ in 0..31 {
403 inv = inv.wrapping_mul(inv);
404 inv = inv.wrapping_mul(a.value());
405 }
406 Self(inv.wrapping_neg())
407 }
408 fn calculate_r2<F: GpuField>() -> Vec<Self> {
409 F::r2().into_iter().map(Self::new).collect()
410 }
411}
412
413#[derive(Clone, Copy)]
415pub struct Limb64(u64);
416impl Limb for Limb64 {
417 type LimbType = u64;
418 fn zero() -> Self {
419 Self(0)
420 }
421 fn new(val: Self::LimbType) -> Self {
422 Self(val)
423 }
424 fn value(&self) -> Self::LimbType {
425 self.0
426 }
427 fn ptx_info() -> (&'static str, &'static str) {
428 ("u64", "l")
429 }
430 fn opencl_type() -> &'static str {
431 "ulong"
432 }
433 fn one_limbs<F: GpuField>() -> Vec<Self> {
434 F::one()
435 .chunks(2)
436 .map(|chunk| Self::new(((chunk[1] as u64) << 32) + (chunk[0] as u64)))
437 .collect()
438 }
439
440 fn modulus_limbs<F: GpuField>() -> Vec<Self> {
441 F::modulus()
442 .chunks(2)
443 .map(|chunk| Self::new(((chunk[1] as u64) << 32) + (chunk[0] as u64)))
444 .collect()
445 }
446
447 fn calc_inv(a: Self) -> Self {
448 let mut inv = 1u64;
449 for _ in 0..63 {
450 inv = inv.wrapping_mul(inv);
451 inv = inv.wrapping_mul(a.value());
452 }
453 Self(inv.wrapping_neg())
454 }
455 fn calculate_r2<F: GpuField>() -> Vec<Self> {
456 F::r2()
457 .chunks(2)
458 .map(|chunk| Self::new(((chunk[1] as u64) << 32) + (chunk[0] as u64)))
459 .collect()
460 }
461}
462
463fn const_field<L: Limb>(name: &str, limbs: Vec<L>) -> String {
464 format!(
465 "CONSTANT FIELD {} = {{ {{ {} }} }};",
466 name,
467 limbs
468 .iter()
469 .map(|l| l.value().to_string())
470 .collect::<Vec<_>>()
471 .join(", ")
472 )
473}
474
475fn params<F, L>() -> String
477where
478 F: GpuField,
479 L: Limb,
480{
481 let one = L::one_limbs::<F>(); let p = L::modulus_limbs::<F>(); let r2 = L::calculate_r2::<F>();
484 let limbs = one.len(); let inv = L::calc_inv(p[0]);
486 let limb_def = format!("#define FIELD_limb {}", L::opencl_type());
487 let limbs_def = format!("#define FIELD_LIMBS {}", limbs);
488 let limb_bits_def = format!("#define FIELD_LIMB_BITS {}", L::bits());
489 let p_def = const_field("FIELD_P", p);
490 let r2_def = const_field("FIELD_R2", r2);
491 let one_def = const_field("FIELD_ONE", one);
492 let zero_def = const_field("FIELD_ZERO", vec![L::zero(); limbs]);
493 let inv_def = format!("#define FIELD_INV {}", inv.value());
494 let typedef = "typedef struct { FIELD_limb val[FIELD_LIMBS]; } FIELD;".to_string();
495 [
496 limb_def,
497 limbs_def,
498 limb_bits_def,
499 inv_def,
500 typedef,
501 one_def,
502 p_def,
503 r2_def,
504 zero_def,
505 ]
506 .join("\n")
507}
508
/// Generates `FIELD_add_nvidia`/`FIELD_sub_nvidia` device functions that use
/// PTX carry-chain instructions, wrapped in a preprocessor guard so they are
/// only compiled for NVIDIA OpenCL or CUDA.
///
/// Only fails if formatting into the output `String` fails (callers `expect`
/// this never happens).
fn field_add_sub_nvidia<F, L>() -> Result<String, std::fmt::Error>
where
    F: GpuField,
    L: Limb,
{
    let mut result = String::new();
    let (ptx_type, ptx_reg) = L::ptx_info();

    writeln!(result, "#if defined(OPENCL_NVIDIA) || defined(CUDA)\n")?;
    for op in &["sub", "add"] {
        let len = L::one_limbs::<F>().len();

        writeln!(
            result,
            "DEVICE FIELD FIELD_{}_nvidia(FIELD a, FIELD b) {{",
            op
        )?;
        // A single-limb field needs no carry chain, so no asm is emitted and
        // `a` is returned unchanged in that case.
        if len > 1 {
            write!(result, "asm(")?;
            // First limb: plain add/sub that sets the carry flag. Registers
            // %0..%(len-1) are `a`'s limbs, %len..%(2*len-1) are `b`'s.
            writeln!(result, "\"{}.cc.{} %0, %0, %{};\\r\\n\"", op, ptx_type, len)?;

            // Middle limbs: consume and produce the carry (`addc.cc`/`subc.cc`).
            for i in 1..len - 1 {
                writeln!(
                    result,
                    "\"{}c.cc.{} %{}, %{}, %{};\\r\\n\"",
                    op,
                    ptx_type,
                    i,
                    i,
                    len + i
                )?;
            }
            // Last limb: consumes the carry without producing a new one.
            writeln!(
                result,
                "\"{}c.{} %{}, %{}, %{};\\r\\n\"",
                op,
                ptx_type,
                len - 1,
                len - 1,
                2 * len - 1
            )?;

            // Output operands: `a`'s limbs, read-modify-write ("+").
            write!(result, ":")?;
            for n in 0..len {
                write!(result, "\"+{}\"(a.val[{}])", ptx_reg, n)?;
                if n != len - 1 {
                    write!(result, ", ")?;
                }
            }

            // Input operands: `b`'s limbs.
            write!(result, "\n:")?;
            for n in 0..len {
                write!(result, "\"{}\"(b.val[{}])", ptx_reg, n)?;
                if n != len - 1 {
                    write!(result, ", ")?;
                }
            }
            writeln!(result, ");")?;
        }
        writeln!(result, "return a;\n}}")?;
    }
    writeln!(result, "#endif")?;

    Ok(result)
}
575
/// Generates the kernels for all enabled backends (CUDA and/or OpenCL).
///
/// Intended to be called from a build script; with neither feature enabled
/// this is a no-op (hence the `unused_variables` allow).
#[allow(unused_variables)]
pub fn generate(source_builder: &SourceBuilder) {
    #[cfg(feature = "cuda")]
    generate_cuda(source_builder);
    #[cfg(feature = "opencl")]
    generate_opencl(source_builder);
}
594
/// Compiles the kernel source into a CUDA fatbin with `nvcc` and returns the
/// path of the resulting file.
///
/// The path is also exported to the crate being built via the
/// `_EC_GPU_CUDA_KERNEL_FATBIN` environment variable. CUDA uses 32-bit limbs.
///
/// Panics if `OUT_DIR` is unset, the source cannot be written, or `nvcc`
/// cannot be run or fails.
#[cfg(feature = "cuda")]
fn generate_cuda(source_builder: &SourceBuilder) -> PathBuf {
    use sha2::{Digest, Sha256};

    // Docs.rs and Clippy builds have no CUDA toolchain; point the env var at
    // some existing file (this build script) so dependent code still compiles.
    if env::var("DOCS_RS").is_ok() || cfg!(feature = "cargo-clippy") {
        println!("cargo:rustc-env=_EC_GPU_CUDA_KERNEL_FATBIN=../build.rs");
        return PathBuf::from("../build.rs");
    }

    let kernel_source = source_builder.build_32_bit_limbs();
    let out_dir = env::var("OUT_DIR").expect("OUT_DIR was not set.");

    // The full nvcc invocation can be overridden via `EC_GPU_CUDA_NVCC_ARGS`;
    // otherwise a default set of GPU architectures is targeted.
    let mut nvcc = match env::var("EC_GPU_CUDA_NVCC_ARGS") {
        Ok(args) => execute::command(format!("nvcc {}", args)),
        Err(_) => {
            let mut command = std::process::Command::new("nvcc");
            command
                .arg("--optimize=6")
                .arg("--threads=0")
                .arg("--fatbin")
                .arg("--gpu-architecture=sm_86")
                .arg("--generate-code=arch=compute_86,code=sm_86")
                .arg("--generate-code=arch=compute_80,code=sm_80")
                .arg("--generate-code=arch=compute_75,code=sm_75");
            command
        }
    };

    // File names are derived from a digest over the source AND the compiler
    // invocation, so a cached fatbin is only reused when both are unchanged.
    let mut hasher = Sha256::new();
    hasher.update(kernel_source.as_bytes());
    hasher.update(&format!("{:?}", &nvcc));
    let kernel_digest = hex::encode(hasher.finalize());

    let source_path: PathBuf = [&out_dir, &format!("{}.cu", &kernel_digest)]
        .iter()
        .collect();
    let fatbin_path: PathBuf = [&out_dir, &format!("{}.fatbin", &kernel_digest)]
        .iter()
        .collect();

    fs::write(&source_path, &kernel_source).unwrap_or_else(|_| {
        panic!(
            "Cannot write kernel source at {}.",
            source_path.to_str().unwrap()
        )
    });

    // Only invoke nvcc when no cached fatbin for this digest exists yet.
    if !fatbin_path.as_path().exists() {
        let status = nvcc
            .arg("--output-file")
            .arg(&fatbin_path)
            .arg(&source_path)
            .status()
            .expect("Cannot run nvcc. Install the NVIDIA toolkit or disable the `cuda` feature.");

        if !status.success() {
            panic!(
                "nvcc failed. See the kernel source at {}",
                source_path.to_str().unwrap()
            );
        }
    }

    println!(
        "cargo:rustc-env=_EC_GPU_CUDA_KERNEL_FATBIN={}",
        fatbin_path.to_str().unwrap()
    );

    fatbin_path
}
676
/// Writes the OpenCL kernel source into `OUT_DIR` and returns its path.
///
/// The path is exported to the crate being built via the
/// `_EC_GPU_OPENCL_KERNEL_SOURCE` environment variable. OpenCL uses 64-bit
/// limbs.
///
/// Panics if `OUT_DIR` is unset or the file cannot be written.
#[cfg(feature = "opencl")]
fn generate_opencl(source_builder: &SourceBuilder) -> PathBuf {
    let kernel_source = source_builder.build_64_bit_limbs();
    let out_dir = env::var("OUT_DIR").expect("OUT_DIR was not set.");

    let source_path: PathBuf = [&out_dir, "kernel.cl"].iter().collect();

    fs::write(&source_path, kernel_source).unwrap_or_else(|_| {
        panic!(
            "Cannot write kernel source at {}.",
            source_path.to_str().unwrap()
        )
    });

    println!(
        "cargo:rustc-env=_EC_GPU_OPENCL_KERNEL_SOURCE={}",
        source_path.to_str().unwrap()
    );

    source_path
}
701
#[cfg(all(test, any(feature = "opencl", feature = "cuda")))]
mod tests {
    //! GPU integration tests: the generated kernel source is compiled for the
    //! enabled backend(s), individual field operations are run on a real
    //! device, and the results are compared against the CPU (`blstrs`)
    //! implementation — and, where applicable, across backends/limb widths.
    use super::*;

    use std::sync::Mutex;

    #[cfg(feature = "cuda")]
    use rust_gpu_tools::cuda;
    #[cfg(feature = "opencl")]
    use rust_gpu_tools::opencl;
    use rust_gpu_tools::{program_closures, Device, GPUError, Program};

    use blstrs::Scalar;
    use ff::{Field as _, PrimeField};
    use lazy_static::lazy_static;
    use rand::{thread_rng, Rng};

    static TEST_SRC: &str = include_str!("./cl/test.cl");

    // Transparent wrapper so a `Scalar` can be handed to a kernel by value.
    #[derive(PartialEq, Debug, Clone, Copy)]
    #[repr(transparent)]
    pub struct GpuScalar(pub Scalar);
    impl Default for GpuScalar {
        fn default() -> Self {
            Self(Scalar::ZERO)
        }
    }

    #[cfg(feature = "cuda")]
    impl cuda::KernelArgument for GpuScalar {
        fn as_c_void(&self) -> *mut std::ffi::c_void {
            &self.0 as *const _ as _
        }
    }

    #[cfg(feature = "opencl")]
    impl opencl::KernelArgument for GpuScalar {
        fn push(&self, kernel: &mut opencl::Kernel) {
            unsafe { kernel.builder.set_arg(&self.0) };
        }
    }

    // Error type for the program closures; the closures `unwrap` internally,
    // so no error details need to be carried.
    #[derive(Debug)]
    struct NoError;
    impl From<GPUError> for NoError {
        fn from(_error: GPUError) -> Self {
            Self
        }
    }

    // Builds the kernel source for the tests: the scalar field plus the
    // `test.cl` kernels specialized to that field.
    fn test_source() -> SourceBuilder {
        let test_source = String::from(TEST_SRC).replace("FIELD", &Scalar::name());
        SourceBuilder::new()
            .add_field::<Scalar>()
            .append_source(test_source)
    }

    #[cfg(feature = "cuda")]
    lazy_static! {
        // Compiled once and shared (behind a mutex) by all tests.
        static ref CUDA_PROGRAM: Mutex<Program> = {
            use std::ffi::CString;

            let source = test_source();
            let fatbin_path = generate_cuda(&source);

            let device = *Device::all().first().expect("Cannot get a default device.");
            let cuda_device = device.cuda_device().unwrap();
            let fatbin_path_cstring =
                CString::new(fatbin_path.to_str().expect("path is not valid UTF-8."))
                    .expect("path contains NULL byte.");
            let program =
                cuda::Program::from_binary(cuda_device, fatbin_path_cstring.as_c_str()).unwrap();
            Mutex::new(Program::Cuda(program))
        };
    }

    #[cfg(feature = "opencl")]
    lazy_static! {
        // Both limb widths are compiled, so every kernel call can be checked
        // on 32-bit as well as 64-bit limbs.
        static ref OPENCL_PROGRAM: Mutex<(Program, Program)> = {
            let device = *Device::all().first().expect("Cannot get a default device");
            let opencl_device = device.opencl_device().unwrap();
            let source_32 = test_source().build_32_bit_limbs();
            let program_32 = opencl::Program::from_opencl(opencl_device, &source_32).unwrap();
            let source_64 = test_source().build_64_bit_limbs();
            let program_64 = opencl::Program::from_opencl(opencl_device, &source_64).unwrap();
            Mutex::new((Program::Opencl(program_32), Program::Opencl(program_64)))
        };
    }

    // Runs the named test kernel with the given arguments on every enabled
    // backend/limb-width combination, asserts they agree, and returns the
    // single result scalar the kernel wrote into its output buffer.
    fn call_kernel(name: &str, scalars: &[GpuScalar], uints: &[u32]) -> Scalar {
        let closures = program_closures!(|program, _args| -> Result<Scalar, NoError> {
            // One-element output buffer the kernel writes its result into.
            let mut cpu_buffer = vec![GpuScalar::default()];
            let buffer = program.create_buffer_from_slice(&cpu_buffer).unwrap();

            let mut kernel = program.create_kernel(name, 1, 64).unwrap();
            for scalar in scalars {
                kernel = kernel.arg(scalar);
            }
            for uint in uints {
                kernel = kernel.arg(uint);
            }
            kernel.arg(&buffer).run().unwrap();

            program.read_into_buffer(&buffer, &mut cpu_buffer).unwrap();
            Ok(cpu_buffer[0].0)
        });

        #[cfg(all(feature = "cuda", not(feature = "opencl")))]
        return CUDA_PROGRAM.lock().unwrap().run(closures, ()).unwrap();

        #[cfg(all(feature = "opencl", not(feature = "cuda")))]
        {
            let result_32 = OPENCL_PROGRAM.lock().unwrap().0.run(closures, ()).unwrap();
            let result_64 = OPENCL_PROGRAM.lock().unwrap().1.run(closures, ()).unwrap();
            assert_eq!(
                result_32, result_64,
                "Results for 32-bit and 64-bit limbs must be the same."
            );
            result_32
        }

        #[cfg(all(feature = "cuda", feature = "opencl"))]
        {
            let cuda_result = CUDA_PROGRAM.lock().unwrap().run(closures, ()).unwrap();
            let opencl_32_result = OPENCL_PROGRAM.lock().unwrap().0.run(closures, ()).unwrap();
            let opencl_64_result = OPENCL_PROGRAM.lock().unwrap().1.run(closures, ()).unwrap();
            assert_eq!(
                opencl_32_result, opencl_64_result,
                "Results for 32-bit and 64-bit limbs on OpenCL must be the same."
            );
            assert_eq!(
                cuda_result, opencl_32_result,
                "Results for CUDA and OpenCL must be the same."
            );
            cuda_result
        }
    }

    #[test]
    fn test_add() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = Scalar::random(&mut rng);
            let c = a + b;

            assert_eq!(
                call_kernel("test_add", &[GpuScalar(a), GpuScalar(b)], &[]),
                c
            );
        }
    }

    #[test]
    fn test_sub() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = Scalar::random(&mut rng);
            let c = a - b;
            assert_eq!(
                call_kernel("test_sub", &[GpuScalar(a), GpuScalar(b)], &[]),
                c
            );
        }
    }

    #[test]
    fn test_mul() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = Scalar::random(&mut rng);
            let c = a * b;

            assert_eq!(
                call_kernel("test_mul", &[GpuScalar(a), GpuScalar(b)], &[]),
                c
            );
        }
    }

    #[test]
    fn test_pow() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = rng.gen::<u32>();
            let c = a.pow_vartime([b as u64]);
            assert_eq!(call_kernel("test_pow", &[GpuScalar(a)], &[b]), c);
        }
    }

    #[test]
    fn test_sqr() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = a.square();

            assert_eq!(call_kernel("test_sqr", &[GpuScalar(a)], &[]), b);
        }
    }

    #[test]
    fn test_double() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = a.double();

            assert_eq!(call_kernel("test_double", &[GpuScalar(a)], &[]), b);
        }
    }

    #[test]
    fn test_unmont() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            // Reinterpret the canonical repr bytes as a scalar to obtain the
            // non-Montgomery form the kernel is expected to produce.
            let b: Scalar = unsafe { std::mem::transmute(a.to_repr()) };
            assert_eq!(call_kernel("test_unmont", &[GpuScalar(a)], &[]), b);
        }
    }

    #[test]
    fn test_mont() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a_repr = Scalar::random(&mut rng).to_repr();
            // `a` holds the raw (non-Montgomery) value; `b` is the expected
            // result of converting it into Montgomery form.
            let a: Scalar = unsafe { std::mem::transmute(a_repr) };
            let b = Scalar::from_repr(a_repr).unwrap();
            assert_eq!(call_kernel("test_mont", &[GpuScalar(a)], &[]), b);
        }
    }
}