ec_gpu_gen/
source.rs

1use std::collections::HashSet;
2use std::fmt::{self, Write};
3use std::hash::{Hash, Hasher};
4use std::marker::PhantomData;
5use std::mem;
6#[cfg(any(feature = "opencl", feature = "cuda"))]
7use std::path::PathBuf;
8#[cfg(any(feature = "opencl", feature = "cuda"))]
9use std::{env, fs};
10
11use ec_gpu::{GpuField, GpuName};
12use group::prime::PrimeCurveAffine;
13
14static COMMON_SRC: &str = include_str!("cl/common.cl");
15static FIELD_SRC: &str = include_str!("cl/field.cl");
16static FIELD2_SRC: &str = include_str!("cl/field2.cl");
17static EC_SRC: &str = include_str!("cl/ec.cl");
18static FFT_SRC: &str = include_str!("cl/fft.cl");
19static MULTIEXP_SRC: &str = include_str!("cl/multiexp.cl");
20
/// Selector for the limb width the generated source code should use.
///
/// The field arithmetic can be generated for either 32-bit or 64-bit limbs; see
/// [`SourceBuilder::build_32_bit_limbs`] and [`SourceBuilder::build_64_bit_limbs`].
#[derive(Clone, Copy)]
enum Limb32Or64 {
    Limb32,
    Limb64,
}
26
/// This trait is used to uniquely identify items by some identifier (`name`) and to return the GPU
/// source code they produce.
///
/// The `name()` is the sole identity used by the `PartialEq`/`Eq`/`Hash` impls below, so two
/// items with the same name are deduplicated when stored in a `HashSet`.
trait NameAndSource {
    /// The name to identify the item.
    fn name(&self) -> String;
    /// The GPU source code that is generated for the given limb width.
    fn source(&self, limb: Limb32Or64) -> String;
}
35
36impl PartialEq for dyn NameAndSource {
37    fn eq(&self, other: &Self) -> bool {
38        self.name() == other.name()
39    }
40}
41
42impl Eq for dyn NameAndSource {}
43
44impl Hash for dyn NameAndSource {
45    fn hash<H: Hasher>(&self, state: &mut H) {
46        self.name().hash(state)
47    }
48}
49
50/// Prints the name by default, the source code of the 32-bit limb in the alternate mode via
51/// `{:#?}`.
52impl fmt::Debug for dyn NameAndSource {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        if f.alternate() {
55            f.debug_map()
56                .entries(vec![
57                    ("name", self.name()),
58                    ("source", self.source(Limb32Or64::Limb32)),
59                ])
60                .finish()
61        } else {
62            write!(f, "{:?}", self.name())
63        }
64    }
65}
66
/// A field that might also be an extension field.
///
/// When the field is an extension field, we also add its sub-field to the list of fields. This
/// enum is used to indicate that it's a sub-field that has a corresponding extension field. This
/// way we can make sure that when the source is generated, that also the source for the sub-field
/// is generated, while not having duplicated field definitions.
// Storing the sub-field as a string is a bit of a hack around Rust's type system. If we would
// store the generic type, then the enum would need to be generic over two fields, even in
// the case when no extension field is used. This would make the API harder to use.
#[derive(Debug)]
enum Field<F: GpuField> {
    /// A field, might be an extension field.
    Field(PhantomData<F>),
    /// A sub-field with the given name that has a corresponding extension field.
    ///
    /// The `F` of the enclosing enum is the *extension* field; the sub-field's constants are
    /// taken from `F`'s `GpuField` implementation (see `NameAndSource::source` below).
    SubField(String),
}
83
impl<F: GpuField> Field<F> {
    /// Create a new field for the given generic type.
    ///
    /// Always returns the plain `Field` variant, even for extension fields.
    pub fn new() -> Self {
        // By default it's added as a field. If it's an extension field, then the `add_field()`
        // function will create a copy of it, as `SubField` variant.
        Self::Field(PhantomData)
    }
}
92
93impl<F: GpuField> Default for Field<F> {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99fn field_source<F: GpuField>(limb: Limb32Or64) -> String {
100    match limb {
101        Limb32Or64::Limb32 => [
102            params::<F, Limb32>(),
103            field_add_sub_nvidia::<F, Limb32>().expect("preallocated"),
104            String::from(FIELD_SRC),
105        ]
106        .join("\n"),
107        Limb32Or64::Limb64 => [
108            params::<F, Limb64>(),
109            field_add_sub_nvidia::<F, Limb64>().expect("preallocated"),
110            String::from(FIELD_SRC),
111        ]
112        .join("\n"),
113    }
114}
115
impl<F: GpuField> NameAndSource for Field<F> {
    fn name(&self) -> String {
        match self {
            Self::Field(_) => F::name(),
            Self::SubField(name) => name.to_string(),
        }
    }

    fn source(&self, limb: Limb32Or64) -> String {
        match self {
            Self::Field(_) => {
                // If it's an extension field.
                if let Some(sub_field_name) = F::sub_field_name() {
                    // Replace the longer `FIELD2` placeholder first, so that the subsequent
                    // `FIELD` replacement (a prefix of `FIELD2`) cannot clobber it.
                    String::from(FIELD2_SRC)
                        .replace("FIELD2", &F::name())
                        .replace("FIELD", &sub_field_name)
                } else {
                    field_source::<F>(limb).replace("FIELD", &F::name())
                }
            }
            Self::SubField(sub_field_name) => {
                // The `GpuField` implementation of the extension field contains the constants of
                // the sub-field. Hence we can just forward the `F`. It's important that those
                // functions do *not* use the name of the field, else we might generate the
                // sub-field named like the extension field.
                field_source::<F>(limb).replace("FIELD", sub_field_name)
            }
        }
    }
}
146
147/// Struct that generates FFT GPU source code.
148struct Fft<F: GpuName>(PhantomData<F>);
149
150impl<F: GpuName> NameAndSource for Fft<F> {
151    fn name(&self) -> String {
152        F::name()
153    }
154
155    fn source(&self, _limb: Limb32Or64) -> String {
156        String::from(FFT_SRC).replace("FIELD", &F::name())
157    }
158}
159
/// Struct that generates multiexp GPU source code.
struct Multiexp<P: GpuName, F: GpuName, Exp: GpuName> {
    /// The curve point type the multiexp operates on.
    curve_point: PhantomData<P>,
    /// The field used for the point's coordinates (passed explicitly to
    /// `SourceBuilder::add_multiexp`).
    field: PhantomData<F>,
    /// The exponent type; `SourceBuilder::add_multiexp` uses the curve's scalar field here.
    exponent: PhantomData<Exp>,
}
166
167impl<P: GpuName, F: GpuName, Exp: GpuName> Multiexp<P, F, Exp> {
168    pub fn new() -> Self {
169        Self {
170            curve_point: PhantomData::<P>,
171            field: PhantomData::<F>,
172            exponent: PhantomData::<Exp>,
173        }
174    }
175}
176
177impl<P: GpuName, F: GpuName, Exp: GpuName> NameAndSource for Multiexp<P, F, Exp> {
178    fn name(&self) -> String {
179        P::name()
180    }
181
182    fn source(&self, _limb: Limb32Or64) -> String {
183        let ec = String::from(EC_SRC)
184            .replace("FIELD", &F::name())
185            .replace("POINT", &P::name());
186        let multiexp = String::from(MULTIEXP_SRC)
187            .replace("POINT", &P::name())
188            .replace("EXPONENT", &Exp::name());
189        [ec, multiexp].concat()
190    }
191}
192
/// Builder to create the source code of a GPU kernel.
///
/// # Example
///
/// ```
/// use blstrs::{Fp, Fp2, G1Affine, G2Affine, Scalar};
/// use ec_gpu_gen::SourceBuilder;
///
/// # #[cfg(any(feature = "cuda", feature = "opencl"))]
/// let source = SourceBuilder::new()
///     .add_fft::<Scalar>()
///     .add_multiexp::<G1Affine, Fp>()
///     .add_multiexp::<G2Affine, Fp2>()
///     .build_32_bit_limbs();
/// ```
// In the `HashSet`s the concrete types cannot be used, as each item of the set should be able to
// have its own (different) generic type.
// We distinguish between extension fields and other fields as sub-fields need to be defined first
// in the source code (due to being C, where the order of declaration matters).
pub struct SourceBuilder {
    /// The [`Field`]s that are used in this kernel.
    fields: HashSet<Box<dyn NameAndSource>>,
    /// The extension [`Field`]s that are used in this kernel.
    extension_fields: HashSet<Box<dyn NameAndSource>>,
    /// The [`Fft`]s that are used in this kernel.
    ffts: HashSet<Box<dyn NameAndSource>>,
    /// The [`Multiexp`]s that are used in this kernel.
    multiexps: HashSet<Box<dyn NameAndSource>>,
    /// Additional source that is appended at the end of the generated source.
    extra_sources: Vec<String>,
}
224
impl SourceBuilder {
    /// Create a new, empty configuration to generate a GPU kernel.
    pub fn new() -> Self {
        Self {
            fields: HashSet::new(),
            extension_fields: HashSet::new(),
            ffts: HashSet::new(),
            multiexps: HashSet::new(),
            extra_sources: Vec::new(),
        }
    }

    /// Add a field to the configuration.
    ///
    /// If it is an extension field, then the extension field *and* the sub-field is added.
    pub fn add_field<F>(mut self) -> Self
    where
        F: GpuField + 'static,
    {
        let field = Field::<F>::new();
        // If it's an extension field, also add the corresponding sub-field.
        if let Some(sub_field_name) = F::sub_field_name() {
            self.extension_fields.insert(Box::new(field));
            // The sub-field re-uses the extension field's `GpuField` implementation (which
            // contains the sub-field's constants), only under the sub-field's name.
            let sub_field = Field::<F>::SubField(sub_field_name);
            self.fields.insert(Box::new(sub_field));
        } else {
            self.fields.insert(Box::new(field));
        }
        self
    }

    /// Add an FFT kernel function to the configuration.
    ///
    /// The field itself is added as well (see [`SourceBuilder::add_field`]).
    pub fn add_fft<F>(self) -> Self
    where
        F: GpuField + 'static,
    {
        let mut config = self.add_field::<F>();
        let fft = Fft::<F>(PhantomData);
        config.ffts.insert(Box::new(fft));
        config
    }

    /// Add a Multiexp kernel function to the configuration.
    ///
    /// The field must be given explicitly as currently it cannot be derived from the curve point
    /// directly.
    pub fn add_multiexp<C, F>(self) -> Self
    where
        C: PrimeCurveAffine + GpuName,
        C::Scalar: GpuField,
        F: GpuField + 'static,
    {
        // Both the coordinate field and the scalar field (the exponent) are needed.
        let mut config = self.add_field::<F>().add_field::<C::Scalar>();
        let multiexp = Multiexp::<C, F, C::Scalar>::new();
        config.multiexps.insert(Box::new(multiexp));
        config
    }

    /// Appends some given source at the end of the generated source.
    ///
    /// This is useful for cases where you use this library as building block, but have your own
    /// kernel implementation. If this function is called several times, then those sources are
    /// appended in that call order.
    pub fn append_source(mut self, source: String) -> Self {
        self.extra_sources.push(source);
        self
    }

    /// Generate the GPU kernel source code based on the current configuration with 32-bit limbs.
    ///
    /// On CUDA 32-bit limbs are recommended.
    pub fn build_32_bit_limbs(&self) -> String {
        self.build(Limb32Or64::Limb32)
    }

    /// Generate the GPU kernel source code based on the current configuration with 64-bit limbs.
    ///
    /// On OpenCL 64-bit limbs are recommended.
    pub fn build_64_bit_limbs(&self) -> String {
        self.build(Limb32Or64::Limb64)
    }

    /// Generate the GPU kernel source code based on the current configuration.
    ///
    /// Plain fields come before extension fields in the output, as the sub-field definitions
    /// must precede the extension fields that use them (C requires declaration before use).
    fn build(&self, limb_size: Limb32Or64) -> String {
        // Each `collect()` below concatenates the sources of one category into a single
        // `String` (no separator between the items of a category).
        let fields = self
            .fields
            .iter()
            .map(|field| field.source(limb_size))
            .collect();
        let extension_fields = self
            .extension_fields
            .iter()
            .map(|field| field.source(limb_size))
            .collect();
        let ffts = self.ffts.iter().map(|fft| fft.source(limb_size)).collect();
        let multiexps = self
            .multiexps
            .iter()
            .map(|multiexp| multiexp.source(limb_size))
            .collect();
        let extra_sources = self.extra_sources.join("\n");
        vec![
            COMMON_SRC.to_string(),
            fields,
            extension_fields,
            ffts,
            multiexps,
            extra_sources,
        ]
        .join("\n\n")
    }
}
337
338impl Default for SourceBuilder {
339    fn default() -> Self {
340        Self::new()
341    }
342}
343
/// Trait to implement limbs of different underlying bit sizes.
pub trait Limb: Sized + Clone + Copy {
    /// The underlying size of the limb, e.g. `u32`
    type LimbType: Clone + std::fmt::Display;
    /// Returns the value representing zero.
    fn zero() -> Self;
    /// Returns a new limb.
    fn new(val: Self::LimbType) -> Self;
    /// Returns the raw value of the limb.
    fn value(&self) -> Self::LimbType;
    /// Returns the bit size of the limb.
    fn bits() -> usize {
        mem::size_of::<Self::LimbType>() * 8
    }
    /// Returns a tuple with the strings that PTX is using to describe the type and the register.
    fn ptx_info() -> (&'static str, &'static str);
    /// Returns the type that OpenCL is using to represent the limb.
    fn opencl_type() -> &'static str;
    /// Returns the limbs that represent the multiplicative identity of the given field.
    fn one_limbs<F: GpuField>() -> Vec<Self>;
    /// Returns the field modulus in non-Montgomery form as a vector of `Self::LimbType` (least
    /// significant limb first).
    fn modulus_limbs<F: GpuField>() -> Vec<Self>;
    /// Calculate the `INV` parameter of Montgomery reduction algorithm for 32/64bit limbs
    /// * `a` - Is the first limb of modulus.
    fn calc_inv(a: Self) -> Self;
    /// Returns the limbs that represent `R ^ 2 mod P`.
    fn calculate_r2<F: GpuField>() -> Vec<Self>;
}
373
374/// A 32-bit limb.
375#[derive(Clone, Copy)]
376pub struct Limb32(u32);
377impl Limb for Limb32 {
378    type LimbType = u32;
379    fn zero() -> Self {
380        Self(0)
381    }
382    fn new(val: Self::LimbType) -> Self {
383        Self(val)
384    }
385    fn value(&self) -> Self::LimbType {
386        self.0
387    }
388    fn ptx_info() -> (&'static str, &'static str) {
389        ("u32", "r")
390    }
391    fn opencl_type() -> &'static str {
392        "uint"
393    }
394    fn one_limbs<F: GpuField>() -> Vec<Self> {
395        F::one().into_iter().map(Self::new).collect()
396    }
397    fn modulus_limbs<F: GpuField>() -> Vec<Self> {
398        F::modulus().into_iter().map(Self::new).collect()
399    }
400    fn calc_inv(a: Self) -> Self {
401        let mut inv = 1u32;
402        for _ in 0..31 {
403            inv = inv.wrapping_mul(inv);
404            inv = inv.wrapping_mul(a.value());
405        }
406        Self(inv.wrapping_neg())
407    }
408    fn calculate_r2<F: GpuField>() -> Vec<Self> {
409        F::r2().into_iter().map(Self::new).collect()
410    }
411}
412
413/// A 64-bit limb.
414#[derive(Clone, Copy)]
415pub struct Limb64(u64);
416impl Limb for Limb64 {
417    type LimbType = u64;
418    fn zero() -> Self {
419        Self(0)
420    }
421    fn new(val: Self::LimbType) -> Self {
422        Self(val)
423    }
424    fn value(&self) -> Self::LimbType {
425        self.0
426    }
427    fn ptx_info() -> (&'static str, &'static str) {
428        ("u64", "l")
429    }
430    fn opencl_type() -> &'static str {
431        "ulong"
432    }
433    fn one_limbs<F: GpuField>() -> Vec<Self> {
434        F::one()
435            .chunks(2)
436            .map(|chunk| Self::new(((chunk[1] as u64) << 32) + (chunk[0] as u64)))
437            .collect()
438    }
439
440    fn modulus_limbs<F: GpuField>() -> Vec<Self> {
441        F::modulus()
442            .chunks(2)
443            .map(|chunk| Self::new(((chunk[1] as u64) << 32) + (chunk[0] as u64)))
444            .collect()
445    }
446
447    fn calc_inv(a: Self) -> Self {
448        let mut inv = 1u64;
449        for _ in 0..63 {
450            inv = inv.wrapping_mul(inv);
451            inv = inv.wrapping_mul(a.value());
452        }
453        Self(inv.wrapping_neg())
454    }
455    fn calculate_r2<F: GpuField>() -> Vec<Self> {
456        F::r2()
457            .chunks(2)
458            .map(|chunk| Self::new(((chunk[1] as u64) << 32) + (chunk[0] as u64)))
459            .collect()
460    }
461}
462
463fn const_field<L: Limb>(name: &str, limbs: Vec<L>) -> String {
464    format!(
465        "CONSTANT FIELD {} = {{ {{ {} }} }};",
466        name,
467        limbs
468            .iter()
469            .map(|l| l.value().to_string())
470            .collect::<Vec<_>>()
471            .join(", ")
472    )
473}
474
475/// Generates CUDA/OpenCL constants and type definitions of prime-field `F`
476fn params<F, L>() -> String
477where
478    F: GpuField,
479    L: Limb,
480{
481    let one = L::one_limbs::<F>(); // Get Montgomery form of F::one()
482    let p = L::modulus_limbs::<F>(); // Get field modulus in non-Montgomery form
483    let r2 = L::calculate_r2::<F>();
484    let limbs = one.len(); // Number of limbs
485    let inv = L::calc_inv(p[0]);
486    let limb_def = format!("#define FIELD_limb {}", L::opencl_type());
487    let limbs_def = format!("#define FIELD_LIMBS {}", limbs);
488    let limb_bits_def = format!("#define FIELD_LIMB_BITS {}", L::bits());
489    let p_def = const_field("FIELD_P", p);
490    let r2_def = const_field("FIELD_R2", r2);
491    let one_def = const_field("FIELD_ONE", one);
492    let zero_def = const_field("FIELD_ZERO", vec![L::zero(); limbs]);
493    let inv_def = format!("#define FIELD_INV {}", inv.value());
494    let typedef = "typedef struct { FIELD_limb val[FIELD_LIMBS]; } FIELD;".to_string();
495    [
496        limb_def,
497        limbs_def,
498        limb_bits_def,
499        inv_def,
500        typedef,
501        one_def,
502        p_def,
503        r2_def,
504        zero_def,
505    ]
506    .join("\n")
507}
508
/// Generates PTX-Assembly implementation of FIELD_add_/FIELD_sub_
///
/// The generated functions use carry-chained PTX add/sub instructions to combine the limbs of
/// `a` and `b` in place, returning the modified `a`.
fn field_add_sub_nvidia<F, L>() -> Result<String, std::fmt::Error>
where
    F: GpuField,
    L: Limb,
{
    let mut result = String::new();
    let (ptx_type, ptx_reg) = L::ptx_info();

    // The inline assembly is only valid on NVIDIA targets.
    writeln!(result, "#if defined(OPENCL_NVIDIA) || defined(CUDA)\n")?;
    for op in &["sub", "add"] {
        let len = L::one_limbs::<F>().len();

        writeln!(
            result,
            "DEVICE FIELD FIELD_{}_nvidia(FIELD a, FIELD b) {{",
            op
        )?;
        // NOTE(review): for `len == 1` no asm is emitted and `a` is returned unchanged —
        // presumably single-limb fields never reach this code path; confirm.
        if len > 1 {
            write!(result, "asm(")?;
            // First limb: `<op>.cc` starts the carry chain (operand %0 with %len of `b`).
            writeln!(result, "\"{}.cc.{} %0, %0, %{};\\r\\n\"", op, ptx_type, len)?;

            // Middle limbs: `<op>c.cc` consumes and propagates the carry.
            for i in 1..len - 1 {
                writeln!(
                    result,
                    "\"{}c.cc.{} %{}, %{}, %{};\\r\\n\"",
                    op,
                    ptx_type,
                    i,
                    i,
                    len + i
                )?;
            }
            // Last limb: `<op>c` consumes the carry without propagating it further.
            writeln!(
                result,
                "\"{}c.{} %{}, %{}, %{};\\r\\n\"",
                op,
                ptx_type,
                len - 1,
                len - 1,
                2 * len - 1
            )?;

            // Output operands %0..%len-1: the limbs of `a`, read-write ("+").
            write!(result, ":")?;
            for n in 0..len {
                write!(result, "\"+{}\"(a.val[{}])", ptx_reg, n)?;
                if n != len - 1 {
                    write!(result, ", ")?;
                }
            }

            // Input operands %len..%2len-1: the limbs of `b`.
            write!(result, "\n:")?;
            for n in 0..len {
                write!(result, "\"{}\"(b.val[{}])", ptx_reg, n)?;
                if n != len - 1 {
                    write!(result, ", ")?;
                }
            }
            writeln!(result, ");")?;
        }
        writeln!(result, "return a;\n}}")?;
    }
    writeln!(result, "#endif")?;

    Ok(result)
}
575
/// Convenience function to generate a kernel/source based on a source builder.
///
/// When the `cuda` feature is enabled it will compile a CUDA fatbin. The path to the file is
/// stored in the `_EC_GPU_CUDA_KERNEL_FATBIN` environment variable, that will automatically be
/// used by the `ec-gpu-gen` functionality that needs a kernel.
///
/// When the `opencl` feature is enabled it will generate the source code for OpenCL. The path to
/// the source file is stored in the `_EC_GPU_OPENCL_KERNEL_SOURCE` environment variable, that will
/// automatically be used by the `ec-gpu-gen` functionality that needs a kernel. OpenCL compiles
/// the source at run time.
// `source_builder` is unused when neither feature is enabled, hence the `allow`.
#[allow(unused_variables)]
pub fn generate(source_builder: &SourceBuilder) {
    #[cfg(feature = "cuda")]
    generate_cuda(source_builder);
    #[cfg(feature = "opencl")]
    generate_opencl(source_builder);
}
594
/// Compiles the builder's source into a CUDA fatbin via `nvcc` and returns the fatbin path.
///
/// The result is cached in `OUT_DIR`, keyed by a hash of the source and the compiler flags, and
/// exported through the `_EC_GPU_CUDA_KERNEL_FATBIN` compile-time environment variable.
#[cfg(feature = "cuda")]
fn generate_cuda(source_builder: &SourceBuilder) -> PathBuf {
    use sha2::{Digest, Sha256};

    // This is a hack when no properly compiled kernel is needed. That's the case when the
    // documentation is built on docs.rs and when Clippy is run. We can use arbitrary bytes as
    // input then.
    if env::var("DOCS_RS").is_ok() || cfg!(feature = "cargo-clippy") {
        println!("cargo:rustc-env=_EC_GPU_CUDA_KERNEL_FATBIN=../build.rs");
        return PathBuf::from("../build.rs");
    }

    // CUDA uses 32-bit limbs (see `generate`'s docs and `SourceBuilder::build_32_bit_limbs`).
    let kernel_source = source_builder.build_32_bit_limbs();
    let out_dir = env::var("OUT_DIR").expect("OUT_DIR was not set.");

    // Make it possible to override the default options. Though the source and output file is
    // always set automatically.
    let mut nvcc = match env::var("EC_GPU_CUDA_NVCC_ARGS") {
        Ok(args) => execute::command(format!("nvcc {}", args)),
        Err(_) => {
            let mut command = std::process::Command::new("nvcc");
            command
                .arg("--optimize=6")
                // Compile with as many threads as CPUs are available.
                .arg("--threads=0")
                .arg("--fatbin")
                .arg("--gpu-architecture=sm_86")
                .arg("--generate-code=arch=compute_86,code=sm_86")
                .arg("--generate-code=arch=compute_80,code=sm_80")
                .arg("--generate-code=arch=compute_75,code=sm_75");
            command
        }
    };

    // Hash the source and the compile flags. Use that as the filename, so that the kernel is only
    // rebuilt if any of them change.
    let mut hasher = Sha256::new();
    hasher.update(kernel_source.as_bytes());
    hasher.update(&format!("{:?}", &nvcc));
    let kernel_digest = hex::encode(hasher.finalize());

    let source_path: PathBuf = [&out_dir, &format!("{}.cu", &kernel_digest)]
        .iter()
        .collect();
    let fatbin_path: PathBuf = [&out_dir, &format!("{}.fatbin", &kernel_digest)]
        .iter()
        .collect();

    fs::write(&source_path, &kernel_source).unwrap_or_else(|_| {
        panic!(
            "Cannot write kernel source at {}.",
            source_path.to_str().unwrap()
        )
    });

    // Only compile if the output doesn't exist yet.
    if !fatbin_path.as_path().exists() {
        let status = nvcc
            .arg("--output-file")
            .arg(&fatbin_path)
            .arg(&source_path)
            .status()
            .expect("Cannot run nvcc. Install the NVIDIA toolkit or disable the `cuda` feature.");

        if !status.success() {
            panic!(
                "nvcc failed. See the kernel source at {}",
                source_path.to_str().unwrap()
            );
        }
    }

    // The idea to put the path to the fatbin into a compile-time env variable is from
    // https://github.com/LutzCle/fast-interconnects-demo/blob/b80ea8e04825167f486ab8ac1b5d67cf7dd51d2c/rust-demo/build.rs
    println!(
        "cargo:rustc-env=_EC_GPU_CUDA_KERNEL_FATBIN={}",
        fatbin_path.to_str().unwrap()
    );

    fatbin_path
}
676
/// Writes the builder's OpenCL source into `OUT_DIR` and returns the path to it.
///
/// The path is exported through the `_EC_GPU_OPENCL_KERNEL_SOURCE` compile-time environment
/// variable; OpenCL compiles the source at runtime.
#[cfg(feature = "opencl")]
fn generate_opencl(source_builder: &SourceBuilder) -> PathBuf {
    // OpenCL uses 64-bit limbs (see `SourceBuilder::build_64_bit_limbs`).
    let kernel_source = source_builder.build_64_bit_limbs();
    let out_dir = env::var("OUT_DIR").expect("OUT_DIR was not set.");

    // Generating the kernel source is cheap, hence use a fixed name and override it on every
    // build.
    let source_path: PathBuf = [&out_dir, "kernel.cl"].iter().collect();

    fs::write(&source_path, kernel_source).unwrap_or_else(|_| {
        panic!(
            "Cannot write kernel source at {}.",
            source_path.to_str().unwrap()
        )
    });

    // For OpenCL we only need the kernel source, it is compiled at runtime.
    println!(
        "cargo:rustc-env=_EC_GPU_OPENCL_KERNEL_SOURCE={}",
        source_path.to_str().unwrap()
    );

    source_path
}
701
// GPU integration tests: they compile the generated kernels and compare the GPU results of the
// basic field operations against the CPU implementation from `blstrs`. They only run when at
// least one GPU backend feature is enabled.
#[cfg(all(test, any(feature = "opencl", feature = "cuda")))]
mod tests {
    use super::*;

    use std::sync::Mutex;

    #[cfg(feature = "cuda")]
    use rust_gpu_tools::cuda;
    #[cfg(feature = "opencl")]
    use rust_gpu_tools::opencl;
    use rust_gpu_tools::{program_closures, Device, GPUError, Program};

    use blstrs::Scalar;
    use ff::{Field as _, PrimeField};
    use lazy_static::lazy_static;
    use rand::{thread_rng, Rng};

    static TEST_SRC: &str = include_str!("./cl/test.cl");

    /// Wrapper around [`Scalar`] so that it can be passed as a raw GPU kernel argument.
    #[derive(PartialEq, Debug, Clone, Copy)]
    #[repr(transparent)]
    pub struct GpuScalar(pub Scalar);
    impl Default for GpuScalar {
        fn default() -> Self {
            Self(Scalar::ZERO)
        }
    }

    #[cfg(feature = "cuda")]
    impl cuda::KernelArgument for GpuScalar {
        fn as_c_void(&self) -> *mut std::ffi::c_void {
            &self.0 as *const _ as _
        }
    }

    #[cfg(feature = "opencl")]
    impl opencl::KernelArgument for GpuScalar {
        fn push(&self, kernel: &mut opencl::Kernel) {
            // SAFETY: `GpuScalar` is `#[repr(transparent)]` over `Scalar`; relies on the
            // contract of `set_arg` — NOTE(review): confirm against `rust_gpu_tools` docs.
            unsafe { kernel.builder.set_arg(&self.0) };
        }
    }

    /// The `run` call needs to return a result, use this struct as placeholder.
    #[derive(Debug)]
    struct NoError;
    impl From<GPUError> for NoError {
        fn from(_error: GPUError) -> Self {
            Self
        }
    }

    // Builds a `SourceBuilder` containing the `Scalar` field plus the test kernels.
    fn test_source() -> SourceBuilder {
        let test_source = String::from(TEST_SRC).replace("FIELD", &Scalar::name());
        SourceBuilder::new()
            .add_field::<Scalar>()
            .append_source(test_source)
    }

    #[cfg(feature = "cuda")]
    lazy_static! {
        // Compiled once and shared (behind a `Mutex`) across all test functions.
        static ref CUDA_PROGRAM: Mutex<Program> = {
            use std::ffi::CString;

            let source = test_source();
            let fatbin_path = generate_cuda(&source);

            let device = *Device::all().first().expect("Cannot get a default device.");
            let cuda_device = device.cuda_device().unwrap();
            let fatbin_path_cstring =
                CString::new(fatbin_path.to_str().expect("path is not valid UTF-8."))
                    .expect("path contains NULL byte.");
            let program =
                cuda::Program::from_binary(cuda_device, fatbin_path_cstring.as_c_str()).unwrap();
            Mutex::new(Program::Cuda(program))
        };
    }

    #[cfg(feature = "opencl")]
    lazy_static! {
        // Two programs: one compiled with 32-bit limbs, one with 64-bit limbs.
        static ref OPENCL_PROGRAM: Mutex<(Program, Program)> = {
            let device = *Device::all().first().expect("Cannot get a default device");
            let opencl_device = device.opencl_device().unwrap();
            let source_32 = test_source().build_32_bit_limbs();
            let program_32 = opencl::Program::from_opencl(opencl_device, &source_32).unwrap();
            let source_64 = test_source().build_64_bit_limbs();
            let program_64 = opencl::Program::from_opencl(opencl_device, &source_64).unwrap();
            Mutex::new((Program::Opencl(program_32), Program::Opencl(program_64)))
        };
    }

    // Runs the named test kernel with the given scalar and uint arguments on every enabled
    // backend/limb-width combination and checks that all of them agree on the result.
    fn call_kernel(name: &str, scalars: &[GpuScalar], uints: &[u32]) -> Scalar {
        let closures = program_closures!(|program, _args| -> Result<Scalar, NoError> {
            // A single-element buffer receives the kernel's result.
            let mut cpu_buffer = vec![GpuScalar::default()];
            let buffer = program.create_buffer_from_slice(&cpu_buffer).unwrap();

            let mut kernel = program.create_kernel(name, 1, 64).unwrap();
            for scalar in scalars {
                kernel = kernel.arg(scalar);
            }
            for uint in uints {
                kernel = kernel.arg(uint);
            }
            kernel.arg(&buffer).run().unwrap();

            program.read_into_buffer(&buffer, &mut cpu_buffer).unwrap();
            Ok(cpu_buffer[0].0)
        });

        // For CUDA we only test 32-bit limbs.
        #[cfg(all(feature = "cuda", not(feature = "opencl")))]
        return CUDA_PROGRAM.lock().unwrap().run(closures, ()).unwrap();

        // For OpenCL we test for 32 and 64-bit limbs.
        #[cfg(all(feature = "opencl", not(feature = "cuda")))]
        {
            let result_32 = OPENCL_PROGRAM.lock().unwrap().0.run(closures, ()).unwrap();
            let result_64 = OPENCL_PROGRAM.lock().unwrap().1.run(closures, ()).unwrap();
            assert_eq!(
                result_32, result_64,
                "Results for 32-bit and 64-bit limbs must be the same."
            );
            result_32
        }

        // When both features are enabled, check if the results are the same
        #[cfg(all(feature = "cuda", feature = "opencl"))]
        {
            let cuda_result = CUDA_PROGRAM.lock().unwrap().run(closures, ()).unwrap();
            let opencl_32_result = OPENCL_PROGRAM.lock().unwrap().0.run(closures, ()).unwrap();
            let opencl_64_result = OPENCL_PROGRAM.lock().unwrap().1.run(closures, ()).unwrap();
            assert_eq!(
                opencl_32_result, opencl_64_result,
                "Results for 32-bit and 64-bit limbs on OpenCL must be the same."
            );
            assert_eq!(
                cuda_result, opencl_32_result,
                "Results for CUDA and OpenCL must be the same."
            );
            cuda_result
        }
    }

    #[test]
    fn test_add() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = Scalar::random(&mut rng);
            let c = a + b;

            assert_eq!(
                call_kernel("test_add", &[GpuScalar(a), GpuScalar(b)], &[]),
                c
            );
        }
    }

    #[test]
    fn test_sub() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = Scalar::random(&mut rng);
            let c = a - b;
            assert_eq!(
                call_kernel("test_sub", &[GpuScalar(a), GpuScalar(b)], &[]),
                c
            );
        }
    }

    #[test]
    fn test_mul() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = Scalar::random(&mut rng);
            let c = a * b;

            assert_eq!(
                call_kernel("test_mul", &[GpuScalar(a), GpuScalar(b)], &[]),
                c
            );
        }
    }

    #[test]
    fn test_pow() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = rng.gen::<u32>();
            let c = a.pow_vartime([b as u64]);
            assert_eq!(call_kernel("test_pow", &[GpuScalar(a)], &[b]), c);
        }
    }

    #[test]
    fn test_sqr() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = a.square();

            assert_eq!(call_kernel("test_sqr", &[GpuScalar(a)], &[]), b);
        }
    }

    #[test]
    fn test_double() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            let b = a.double();

            assert_eq!(call_kernel("test_double", &[GpuScalar(a)], &[]), b);
        }
    }

    #[test]
    fn test_unmont() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a = Scalar::random(&mut rng);
            // Reinterprets the canonical (non-Montgomery) representation as a `Scalar`.
            let b: Scalar = unsafe { std::mem::transmute(a.to_repr()) };
            assert_eq!(call_kernel("test_unmont", &[GpuScalar(a)], &[]), b);
        }
    }

    #[test]
    fn test_mont() {
        let mut rng = thread_rng();
        for _ in 0..10 {
            let a_repr = Scalar::random(&mut rng).to_repr();
            // Reinterprets a canonical representation as a `Scalar` so the kernel can convert it
            // back into Montgomery form.
            let a: Scalar = unsafe { std::mem::transmute(a_repr) };
            let b = Scalar::from_repr(a_repr).unwrap();
            assert_eq!(call_kernel("test_mont", &[GpuScalar(a)], &[]), b);
        }
    }
}