use super::CudaProver;
use sp1_core_executor::SP1CoreOpts;
use sp1_cuda::CudaProver as CudaProverImpl;
use sp1_prover::worker::SP1LightNode;
/// Builder for a [`CudaProver`].
///
/// Every setting is optional. Unset settings fall back to the behavior described on each field.
#[derive(Debug, Default)]
pub struct CudaProverBuilder {
    /// CUDA device id for the prover. When `None`, `build` calls `CudaProverImpl::new()`
    /// instead of `CudaProverImpl::new_with_id`.
    cuda_device_id: Option<u32>,
    /// Core options for the light node. When `None`, `SP1CoreOpts::default()` is used.
    core_opts: Option<SP1CoreOpts>,
}
impl CudaProverBuilder {
    /// Requests that the CUDA prover be created on the device with the given `id`.
    ///
    /// Without this call, [`CudaProverBuilder::build`] creates the prover via
    /// `CudaProverImpl::new()` rather than `CudaProverImpl::new_with_id`.
    #[must_use]
    pub fn with_device_id(self, id: u32) -> Self {
        Self { cuda_device_id: Some(id), ..self }
    }

    /// Sets the [`SP1CoreOpts`] used to construct the [`SP1LightNode`].
    ///
    /// Without this call, `SP1CoreOpts::default()` is used.
    #[must_use]
    pub fn core_opts(self, opts: SP1CoreOpts) -> Self {
        Self { core_opts: Some(opts), ..self }
    }

    /// Alias for [`CudaProverBuilder::core_opts`]: sets the [`SP1CoreOpts`] for the light node.
    #[must_use]
    pub fn with_opts(self, opts: SP1CoreOpts) -> Self {
        Self { core_opts: Some(opts), ..self }
    }

    /// Consumes the builder and assembles a [`CudaProver`].
    ///
    /// The steps run in this order:
    /// 1. The [`SP1LightNode`] is created from the configured core options, falling back to
    ///    `SP1CoreOpts::default()` when none were set.
    /// 2. The CUDA prover is created on the configured device, or with `CudaProverImpl::new()`
    ///    when no device id was set.
    ///
    /// # Panics
    ///
    /// Panics if creating the CUDA prover returns an error.
    #[must_use]
    pub async fn build(self) -> CudaProver {
        let Self { cuda_device_id, core_opts } = self;

        // The light node is built before the CUDA prover; this ordering is kept deliberately.
        let node = SP1LightNode::with_opts(core_opts.unwrap_or_default()).await;

        let created = if let Some(device_id) = cuda_device_id {
            CudaProverImpl::new_with_id(device_id).await
        } else {
            CudaProverImpl::new().await
        };
        let prover = created.expect("Failed to create the CUDA prover impl");

        CudaProver { node, prover }
    }
}