use crate::reed_solomon::ReedSolomon;
use binary_fields::BinaryFieldElement;
#[cfg(feature = "webgpu")]
use std::sync::{Arc, Mutex};
#[cfg(feature = "webgpu")]
use crate::gpu::{fft::GpuFft, GpuDevice};
/// Preferred compute backend, normally read from the `LIGERITO_BACKEND` environment variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendHint {
    /// Try the GPU and fall back to the CPU without reporting a GPU initialisation failure.
    Auto,
    /// Always use the CPU.
    Cpu,
    /// Prefer the GPU; a failure to use it is reported and the CPU is used instead.
    Gpu,
}

impl BackendHint {
    /// Reads the hint from the `LIGERITO_BACKEND` environment variable.
    ///
    /// `cpu`, `gpu` and `auto` are accepted in any ASCII letter case, and surrounding
    /// whitespace is ignored. An unset, non-Unicode or unrecognised value yields
    /// [`BackendHint::Auto`].
    pub fn from_env() -> Self {
        std::env::var("LIGERITO_BACKEND")
            .map(|value| Self::parse(&value))
            .unwrap_or(BackendHint::Auto)
    }

    /// Maps a textual backend name to a hint. Unrecognised names map to `Auto`.
    fn parse(value: &str) -> Self {
        let value = value.trim();
        if value.eq_ignore_ascii_case("cpu") {
            BackendHint::Cpu
        } else if value.eq_ignore_ascii_case("gpu") {
            BackendHint::Gpu
        } else {
            // Covers "auto" in any case as well as unknown values.
            BackendHint::Auto
        }
    }
}
/// A compute backend for FFTs and Reed–Solomon column encoding.
///
/// `Send + Sync` lets a single backend value be shared between threads. Both operations are
/// generic over the field type, so the trait is not dyn-compatible. It is dispatched
/// statically, for example through an enum of concrete backends.
pub trait Backend: Send + Sync {
    /// Transforms `data` in place with an FFT.
    ///
    /// `twiddles` are the twiddle factors for the transform. An implementation may ignore
    /// them and use its own (the GPU path does). `parallel` is forwarded to the CPU routines
    /// and presumably enables multi-threading there.
    fn fft_inplace<T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static>(
        &self,
        data: &mut [T],
        twiddles: &[T],
        parallel: bool,
    ) -> crate::Result<()>;

    /// Encodes the columns of `poly_mat` in place with the Reed–Solomon code `rs`.
    ///
    /// `parallel` has the same meaning as in [`Backend::fft_inplace`].
    fn encode_cols<T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static>(
        &self,
        poly_mat: &mut Vec<Vec<T>>,
        rs: &ReedSolomon<T>,
        parallel: bool,
    ) -> crate::Result<()>;

    /// Short, human-readable name of the backend, such as `"CPU"`.
    fn name(&self) -> &'static str;
}
/// Backend that runs every operation on the CPU, using the crate's own routines.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBackend;

impl Backend for CpuBackend {
    /// Delegates to `crate::reed_solomon::fft`. Always returns `Ok(())`.
    fn fft_inplace<F>(&self, data: &mut [F], twiddles: &[F], parallel: bool) -> crate::Result<()>
    where
        F: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
    {
        crate::reed_solomon::fft(data, twiddles, parallel);
        Ok(())
    }

    /// Delegates to `crate::ligero::encode_cols`. Always returns `Ok(())`.
    fn encode_cols<F>(
        &self,
        poly_mat: &mut Vec<Vec<F>>,
        rs: &ReedSolomon<F>,
        parallel: bool,
    ) -> crate::Result<()>
    where
        F: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
    {
        crate::ligero::encode_cols(poly_mat, rs, parallel);
        Ok(())
    }

    fn name(&self) -> &'static str {
        "CPU"
    }
}
/// WebGPU-accelerated backend that switches to a CPU fallback once the GPU has failed.
#[cfg(feature = "webgpu")]
pub struct GpuBackend {
    /// GPU FFT engine. The mutex serialises access to it.
    fft: Arc<Mutex<GpuFft>>,
    /// CPU implementation used whenever the GPU path is not taken.
    cpu_fallback: CpuBackend,
    /// `true` until the first GPU failure. Nothing in this module sets it back to `true`.
    /// It is an atomic rather than a `Mutex<bool>`: a plain flag needs no lock and cannot be
    /// poisoned.
    enabled: Arc<std::sync::atomic::AtomicBool>,
}

#[cfg(feature = "webgpu")]
impl GpuBackend {
    /// Initialises a WebGPU device, blocking until it is ready, and builds the GPU FFT on it.
    ///
    /// # Errors
    /// Returns `LigeritoError::GpuInitFailed` if the device cannot be created.
    pub fn new() -> crate::Result<Self> {
        let device = pollster::block_on(GpuDevice::new())
            .map_err(|e| crate::LigeritoError::GpuInitFailed(e.to_string()))?;
        Ok(Self {
            fft: Arc::new(Mutex::new(GpuFft::new(device))),
            cpu_fallback: CpuBackend,
            enabled: Arc::new(std::sync::atomic::AtomicBool::new(true)),
        })
    }

    /// Whether the GPU path is still in use.
    fn is_enabled(&self) -> bool {
        self.enabled.load(std::sync::atomic::Ordering::SeqCst)
    }

    /// Permanently routes later calls to the CPU fallback.
    fn disable(&self) {
        self.enabled
            .store(false, std::sync::atomic::Ordering::SeqCst);
    }
}
#[cfg(feature = "webgpu")]
impl Backend for GpuBackend {
    /// Runs the FFT on the GPU and switches this backend to the CPU permanently after a failure.
    ///
    /// The GPU path receives only `data`. `twiddles` and `parallel` are used only by the CPU
    /// fallback. A poisoned GPU mutex (an earlier GPU call panicked while holding it) is
    /// handled like any other GPU failure rather than causing a second panic.
    fn fft_inplace<F>(&self, data: &mut [F], twiddles: &[F], parallel: bool) -> crate::Result<()>
    where
        F: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
    {
        if !self.is_enabled() {
            return self.cpu_fallback.fft_inplace(data, twiddles, parallel);
        }

        // The guard lives only inside this `match`, so the GPU lock is released before the
        // CPU fallback below runs.
        let failure = match self.fft.lock() {
            Ok(mut fft) => match pollster::block_on(fft.fft_inplace(data)) {
                Ok(_) => return Ok(()),
                Err(e) => format!("GPU FFT failed: {}. Falling back to CPU.", e),
            },
            Err(_) => {
                String::from("GPU FFT lock poisoned by an earlier panic. Falling back to CPU.")
            }
        };

        eprintln!("{}", failure);
        self.disable();
        // NOTE(review): this assumes a failed GPU FFT leaves `data` unmodified. Confirm against
        // GpuFft; otherwise the CPU fallback would run on partially transformed input.
        self.cpu_fallback.fft_inplace(data, twiddles, parallel)
    }

    /// Column encoding has no GPU implementation here and always runs on the CPU.
    fn encode_cols<F>(
        &self,
        poly_mat: &mut Vec<Vec<F>>,
        rs: &ReedSolomon<F>,
        parallel: bool,
    ) -> crate::Result<()>
    where
        F: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static,
    {
        self.cpu_fallback.encode_cols(poly_mat, rs, parallel)
    }

    /// Reports whether the GPU is still in use or the backend has fallen back to the CPU.
    fn name(&self) -> &'static str {
        if self.is_enabled() {
            "GPU (WebGPU)"
        } else {
            "GPU (disabled, using CPU)"
        }
    }
}
/// Statically dispatched union of the concrete backends compiled into this build.
pub enum BackendImpl {
    /// The pure-CPU backend.
    Cpu(CpuBackend),
    /// The WebGPU backend. It exists only when the `webgpu` feature is enabled.
    #[cfg(feature = "webgpu")]
    Gpu(GpuBackend),
}

/// Forwards every call to whichever backend the enum currently holds.
impl Backend for BackendImpl {
    fn fft_inplace<T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static>(
        &self,
        data: &mut [T],
        twiddles: &[T],
        parallel: bool,
    ) -> crate::Result<()> {
        match *self {
            Self::Cpu(ref inner) => inner.fft_inplace(data, twiddles, parallel),
            #[cfg(feature = "webgpu")]
            Self::Gpu(ref inner) => inner.fft_inplace(data, twiddles, parallel),
        }
    }

    fn encode_cols<T: BinaryFieldElement + Send + Sync + bytemuck::Pod + 'static>(
        &self,
        poly_mat: &mut Vec<Vec<T>>,
        rs: &ReedSolomon<T>,
        parallel: bool,
    ) -> crate::Result<()> {
        match *self {
            Self::Cpu(ref inner) => inner.encode_cols(poly_mat, rs, parallel),
            #[cfg(feature = "webgpu")]
            Self::Gpu(ref inner) => inner.encode_cols(poly_mat, rs, parallel),
        }
    }

    fn name(&self) -> &'static str {
        match *self {
            Self::Cpu(ref inner) => inner.name(),
            #[cfg(feature = "webgpu")]
            Self::Gpu(ref inner) => inner.name(),
        }
    }
}
/// Chooses and owns a concrete [`BackendImpl`] based on a [`BackendHint`].
pub struct BackendSelector {
    backend: BackendImpl,
    /// The hint the backend was chosen from. It is kept so that `clone` can repeat the choice.
    hint: BackendHint,
}

impl BackendSelector {
    /// Selects a backend for `hint` immediately. This may initialise a GPU device.
    pub fn new(hint: BackendHint) -> Self {
        Self {
            backend: Self::select_backend(hint),
            hint,
        }
    }

    /// Builds a selector from the `LIGERITO_BACKEND` environment variable.
    pub fn auto() -> Self {
        let hint = BackendHint::from_env();
        Self::new(hint)
    }

    /// The backend that was selected.
    pub fn backend(&self) -> &BackendImpl {
        &self.backend
    }

    /// Resolves a hint to a backend. Every path that does not produce a GPU backend ends on
    /// the CPU backend.
    fn select_backend(hint: BackendHint) -> BackendImpl {
        let gpu = match hint {
            BackendHint::Cpu => None,
            BackendHint::Gpu => Self::try_gpu(true),
            BackendHint::Auto => Self::try_gpu(false),
        };
        gpu.unwrap_or(BackendImpl::Cpu(CpuBackend))
    }

    /// Attempts to initialise the GPU backend.
    ///
    /// A failure is reported on stderr only when the GPU was `explicitly_requested`.
    #[cfg(feature = "webgpu")]
    fn try_gpu(explicitly_requested: bool) -> Option<BackendImpl> {
        match GpuBackend::new() {
            Ok(gpu) => {
                #[cfg(not(target_arch = "wasm32"))]
                {
                    if explicitly_requested {
                        eprintln!("GPU initialized successfully");
                    } else {
                        eprintln!("GPU detected and initialized");
                    }
                }
                Some(BackendImpl::Gpu(gpu))
            }
            Err(e) => {
                if explicitly_requested {
                    eprintln!("GPU initialization failed: {:?}. Falling back to CPU.", e);
                }
                None
            }
        }
    }

    /// Without the `webgpu` feature there is no GPU backend. An explicit request is reported
    /// on stderr.
    #[cfg(not(feature = "webgpu"))]
    fn try_gpu(explicitly_requested: bool) -> Option<BackendImpl> {
        if explicitly_requested {
            eprintln!("GPU requested but not compiled. Use --features webgpu. Falling back to CPU.");
        }
        None
    }
}
impl Default for BackendSelector {
fn default() -> Self {
Self::auto()
}
}
impl Clone for BackendSelector {
fn clone(&self) -> Self {
Self::new(self.hint)
}
}