use super::CudaProver;
use sp1_core_executor::SP1CoreOpts;
use sp1_core_machine::riscv::RiscvAir;
use sp1_cuda::CudaProver as CudaProverImpl;
use sp1_hypercube::Machine;
use sp1_primitives::SP1Field;
use sp1_prover::worker::SP1LightNode;
use crate::blocking::block_on;
/// Builder for a [`CudaProver`].
///
/// Collects an optional CUDA device id, optional core executor options, and the
/// RISC-V machine used to construct the prover's light node. Call
/// [`CudaProverBuilder::build`] to create the prover.
#[derive(Debug)]
pub struct CudaProverBuilder {
    /// CUDA device to create the prover on; `None` defers to `CudaProverImpl::new()`.
    cuda_device_id: Option<u32>,
    /// Core executor options; `None` means `SP1CoreOpts::default()` is used at build time.
    core_opts: Option<SP1CoreOpts>,
    /// Machine handed to `SP1LightNode::with_opts_and_machine` at build time.
    machine: Machine<SP1Field, RiscvAir<SP1Field>>,
}
impl Default for CudaProverBuilder {
fn default() -> Self {
Self::new()
}
}
impl CudaProverBuilder {
    /// Creates a builder for the default RISC-V machine ([`RiscvAir::machine`]).
    ///
    /// No CUDA device id and no core options are set.
    #[must_use]
    pub fn new() -> Self {
        Self::new_with_machine(RiscvAir::machine())
    }

    /// Creates a builder that will use `machine` when constructing the prover's light node.
    ///
    /// No CUDA device id and no core options are set.
    #[must_use]
    pub const fn new_with_machine(machine: Machine<SP1Field, RiscvAir<SP1Field>>) -> Self {
        Self { cuda_device_id: None, core_opts: None, machine }
    }

    /// Selects the CUDA device id passed to `CudaProverImpl::new_with_id` at build time.
    #[must_use]
    pub fn with_device_id(mut self, id: u32) -> Self {
        self.cuda_device_id = Some(id);
        self
    }

    /// Sets the core executor options passed to the light node.
    ///
    /// If this is never called, `SP1CoreOpts::default()` is used.
    #[must_use]
    pub fn core_opts(mut self, opts: SP1CoreOpts) -> Self {
        self.core_opts = Some(opts);
        self
    }

    /// Alias for [`CudaProverBuilder::core_opts`].
    #[must_use]
    pub fn with_opts(self, opts: SP1CoreOpts) -> Self {
        self.core_opts(opts)
    }

    /// Builds the [`CudaProver`].
    ///
    /// First constructs an [`SP1LightNode`] from the configured machine and core
    /// options, falling back to `SP1CoreOpts::default()` if none were set. Then it
    /// creates the CUDA prover: with `CudaProverImpl::new_with_id` if
    /// [`with_device_id`](Self::with_device_id) was called, otherwise with
    /// `CudaProverImpl::new()`. Both constructors are run through `block_on`.
    ///
    /// # Panics
    ///
    /// Panics if the CUDA prover cannot be created.
    #[must_use]
    pub fn build(self) -> CudaProver {
        tracing::info!("initializing cuda prover");
        let node = block_on(SP1LightNode::with_opts_and_machine(
            self.machine,
            self.core_opts.unwrap_or_default(),
        ));
        let prover = match self.cuda_device_id {
            Some(id) => block_on(CudaProverImpl::new_with_id(id)),
            None => block_on(CudaProverImpl::new()),
        }
        .expect("Failed to create the CUDA prover impl");
        CudaProver { node, prover }
    }
}