use super::super::context::{GpuLinalgContext, LinalgMetadata};
use crate::gpu::buffer::GpuBuffer;
use crate::{Result, Shape, TensorError};
use bytemuck::{Pod, Zeroable};
use scirs2_core::numeric::{Float, One};
use wgpu::{BufferDescriptor, BufferUsages};
pub fn eigenvalues<T>(
context: &mut GpuLinalgContext,
input: &GpuBuffer<T>,
eigenvalues: &GpuBuffer<T>,
eigenvectors: Option<&GpuBuffer<T>>,
shape: &Shape,
) -> Result<()>
where
T: Float + Pod + Zeroable + Clone + Send + Sync + 'static,
{
context.eigenvalues(input, eigenvalues, eigenvectors, shape)
}
impl GpuLinalgContext {
/// Computes the eigenvalues (and optionally the eigenvectors) of a square
/// `n x n` matrix with an iterative Jacobi/Givens scheme on the GPU.
///
/// Allocates the scratch buffers (working copy of the matrix and the
/// rotation accumulator `Q`), builds the operation metadata, and hands off
/// to `execute_eigenvalue_computation`.
///
/// # Errors
/// * Invalid-shape error when `shape` is not a 2-D square shape.
/// * `ComputeError` for matrices smaller than 4x4 — callers should use a
///   CPU fallback for those.
///
/// Returns `Ok(())` immediately for an empty (0x0) matrix.
pub fn eigenvalues<T>(
    &mut self,
    input: &GpuBuffer<T>,
    eigenvalues: &GpuBuffer<T>,
    eigenvectors: Option<&GpuBuffer<T>>,
    shape: &Shape,
) -> Result<()>
where
    T: Float + Pod + Zeroable + Clone + Send + Sync + 'static,
{
    if shape.len() != 2 || shape[0] != shape[1] {
        return Err(TensorError::invalid_shape_simple(
            "Eigenvalue computation requires a square matrix".to_string(),
        ));
    }
    let n = shape[0];
    if n == 0 {
        // Nothing to compute for an empty matrix.
        return Ok(());
    }
    if n < 4 {
        return Err(TensorError::ComputeError {
            operation: "gpu_eigenvalue".to_string(),
            details: "GPU eigenvalue computation requires matrices >= 4x4 - use CPU fallback for smaller matrices".to_string(),
            retry_possible: false,
            context: None,
        });
    }
    // Lazily build the cached pipeline on first use.
    if self.eigenvalue_pipeline.is_none() {
        self.initialize_eigenvalue_pipeline()?;
    }
    let matrix_size = n * n;
    let buffer_bytes = (matrix_size * std::mem::size_of::<T>()) as u64;
    // Scratch copy of the input that the rotation sweeps update in place.
    let working_matrix = self.device().create_buffer(&BufferDescriptor {
        label: Some("eigen_working_matrix"),
        size: buffer_bytes,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    // Accumulates the product of applied rotations (the eigenvector basis).
    let q_matrix = self.device().create_buffer(&BufferDescriptor {
        label: Some("eigen_q_matrix"),
        size: buffer_bytes,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    // The shader-side tolerance is an `f32`, so pass it directly. The old
    // `T::from(1e-10) -> to_f64 -> as f32` round-trip produced the same
    // value for f32/f64, but its `unwrap_or_else` fallback silently
    // degraded the tolerance to 0.0 (a never-converge value) whenever the
    // conversion failed.
    let metadata = LinalgMetadata::new(n, n)
        .with_tolerance(1e-10_f32)
        .with_max_iterations(100 * n as u32);
    let metadata_buffer = self.create_metadata_buffer(&metadata)?;
    self.execute_eigenvalue_computation(
        input,
        eigenvalues,
        eigenvectors,
        &working_matrix,
        &q_matrix,
        &metadata_buffer,
        &metadata,
        n,
    )
}
/// Drives the full GPU eigenvalue computation: builds the compute
/// pipelines, binds all six buffers, initializes the working matrices,
/// runs the Jacobi rotation sweeps, and extracts/post-processes results.
///
/// NOTE(review): pipelines are rebuilt from the shader source on every
/// call even though `eigenvalues` lazily initializes
/// `self.eigenvalue_pipeline` beforehand — consider reusing that cache.
fn execute_eigenvalue_computation<T>(
&mut self,
input: &GpuBuffer<T>,
eigenvalues: &GpuBuffer<T>,
eigenvectors: Option<&GpuBuffer<T>>,
working_matrix: &wgpu::Buffer,
q_matrix: &wgpu::Buffer,
metadata_buffer: &wgpu::Buffer,
metadata: &LinalgMetadata,
n: usize,
) -> Result<()>
where
T: Float + Pod + Zeroable + Clone + Send + Sync + 'static,
{
let pipelines = self.create_eigenvalue_pipelines()?;
// Binding 2 must always be backed by a STORAGE buffer, so when the caller
// does not want eigenvectors we allocate a throwaway one. The borrow of
// the temporary in the `else` arm stays alive via `let`-initializer
// temporary lifetime extension.
let eigenvectors_buffer = if let Some(eigenvecs) = eigenvectors {
eigenvecs.buffer()
} else {
&self.device().create_buffer(&BufferDescriptor {
label: Some("temp_eigenvectors"),
size: (n * n * std::mem::size_of::<T>()) as u64,
usage: BufferUsages::STORAGE,
mapped_at_creation: false,
})
};
// Bind group built against the init pipeline's auto-generated layout; it
// is reused for the extract/sort/normalize passes, which assumes those
// entry points expose a compatible binding interface — TODO confirm this
// holds under wgpu's bind-group-layout compatibility rules.
let bind_group = self.device().create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("eigen_bind_group"),
layout: &pipelines.init.get_bind_group_layout(0),
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: input.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: eigenvalues.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: eigenvectors_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: working_matrix.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: q_matrix.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: metadata_buffer.as_entire_binding(),
},
],
});
// Stage 1: seed the working matrix / Q accumulator.
self.initialize_eigenvalue_matrices(&pipelines.init, &bind_group, n)?;
// Stage 2: iterative rotation sweeps (builds its own per-pair bind groups).
self.perform_jacobi_iterations(
&pipelines.givens,
input,
eigenvalues,
eigenvectors_buffer,
working_matrix,
q_matrix,
metadata,
n,
)?;
// Stage 3: extract the diagonal; sort/normalize only when eigenvectors
// were requested.
self.finalize_eigenvalue_computation(
&pipelines,
&bind_group,
eigenvectors.is_some(),
n,
)?;
Ok(())
}
/// Runs the `initialize_eigen` kernel once over the full matrix to seed
/// the working matrix and the rotation accumulator before iteration.
fn initialize_eigenvalue_matrices(
    &mut self,
    init_pipeline: &wgpu::ComputePipeline,
    bind_group: &wgpu::BindGroup,
    n: usize,
) -> Result<()> {
    let descriptor = wgpu::CommandEncoderDescriptor {
        label: Some("eigen_init_encoder"),
    };
    let mut encoder = self.device().create_command_encoder(&descriptor);
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("eigen_init_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(init_pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        // One thread per matrix element, dispatched in 16x16 workgroups;
        // the grid is square, so x and y use the same workgroup count.
        let groups = (n as u32).div_ceil(16);
        pass.dispatch_workgroups(groups, groups, 1);
    }
    self.queue().submit(Some(encoder.finish()));
    Ok(())
}
/// Serially launches one Givens rotation kernel per off-diagonal pair
/// `(i, j)`, repeating the full sweep a fixed number of times.
///
/// The rotation pair is passed to the shader through the otherwise-unused
/// `rows_b` / `cols_b` fields of [`LinalgMetadata`].
///
/// NOTE(review): the previous implementation carried a `converged` flag
/// that was cleared unconditionally after every rotation, so its
/// early-exit check was dead code — without a GPU -> CPU readback of the
/// off-diagonal norm, convergence cannot be observed on the host. The
/// misleading flag has been removed; the loop always runs the capped
/// sweep count, exactly as before.
///
/// NOTE(review): a metadata buffer, bind group, encoder, and queue
/// submission are created per rotation; batching a whole sweep into one
/// command encoder would cut submission overhead substantially.
fn perform_jacobi_iterations<T>(
    &mut self,
    givens_pipeline: &wgpu::ComputePipeline,
    input: &GpuBuffer<T>,
    eigenvalues: &GpuBuffer<T>,
    eigenvectors_buffer: &wgpu::Buffer,
    working_matrix: &wgpu::Buffer,
    q_matrix: &wgpu::Buffer,
    metadata: &LinalgMetadata,
    n: usize,
) -> Result<()>
where
    T: Float + Pod + Zeroable + Clone + Send + Sync + 'static,
{
    // No off-diagonal pairs to rotate; matches the old behavior of
    // exiting after an empty first sweep without any GPU work.
    if n < 2 {
        return Ok(());
    }
    // Hard cap so an oversized `max_iterations` cannot stall the queue.
    let sweeps = metadata.max_iterations.min(100);
    for _ in 0..sweeps {
        for i in 0..n {
            for j in (i + 1)..n {
                let rotation_metadata = LinalgMetadata {
                    rows_a: n as u32,
                    cols_a: n as u32,
                    // Rotation indices for the shader (see doc comment).
                    rows_b: i as u32,
                    cols_b: j as u32,
                    batch_size: 1,
                    tolerance: metadata.tolerance,
                    max_iterations: metadata.max_iterations,
                    _padding: 0,
                };
                let iter_metadata_buffer = self.create_metadata_buffer(&rotation_metadata)?;
                let iter_bind_group =
                    self.device().create_bind_group(&wgpu::BindGroupDescriptor {
                        label: Some("eigen_iter_bind_group"),
                        layout: &givens_pipeline.get_bind_group_layout(0),
                        entries: &[
                            wgpu::BindGroupEntry {
                                binding: 0,
                                resource: input.buffer().as_entire_binding(),
                            },
                            wgpu::BindGroupEntry {
                                binding: 1,
                                resource: eigenvalues.buffer().as_entire_binding(),
                            },
                            wgpu::BindGroupEntry {
                                binding: 2,
                                resource: eigenvectors_buffer.as_entire_binding(),
                            },
                            wgpu::BindGroupEntry {
                                binding: 3,
                                resource: working_matrix.as_entire_binding(),
                            },
                            wgpu::BindGroupEntry {
                                binding: 4,
                                resource: q_matrix.as_entire_binding(),
                            },
                            wgpu::BindGroupEntry {
                                binding: 5,
                                resource: iter_metadata_buffer.as_entire_binding(),
                            },
                        ],
                    });
                let mut encoder =
                    self.device()
                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                            label: Some("eigen_givens_encoder"),
                        });
                {
                    let mut compute_pass =
                        encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                            label: Some("eigen_givens_pass"),
                            timestamp_writes: None,
                        });
                    compute_pass.set_pipeline(givens_pipeline);
                    compute_pass.set_bind_group(0, &iter_bind_group, &[]);
                    compute_pass.dispatch_workgroups((n as u32 + 255) / 256, 1, 1);
                }
                self.queue().submit(std::iter::once(encoder.finish()));
            }
        }
    }
    Ok(())
}
/// Runs the `extract_eigenvalues` kernel to copy results into the output
/// buffer and, when eigenvectors were requested, the sort and normalize
/// kernels as well.
fn finalize_eigenvalue_computation(
    &mut self,
    pipelines: &EigenvaluePipelines,
    bind_group: &wgpu::BindGroup,
    compute_eigenvectors: bool,
    n: usize,
) -> Result<()> {
    let descriptor = wgpu::CommandEncoderDescriptor {
        label: Some("eigen_extract_encoder"),
    };
    let mut encoder = self.device().create_command_encoder(&descriptor);
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("eigen_extract_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipelines.extract);
        pass.set_bind_group(0, bind_group, &[]);
        pass.dispatch_workgroups((n as u32).div_ceil(256), 1, 1);
    }
    self.queue().submit(Some(encoder.finish()));
    if !compute_eigenvectors {
        // NOTE(review): the values-only path skips sorting entirely, so
        // eigenvalue order differs from the with-eigenvectors path —
        // confirm this asymmetry is intentional.
        return Ok(());
    }
    self.sort_eigenvalues(&pipelines.sort, bind_group)?;
    self.normalize_eigenvectors(&pipelines.normalize, bind_group, n)
}
/// Dispatches the `sort_eigenvalues` kernel in a single workgroup; the
/// ordering semantics live entirely in the WGSL shader.
fn sort_eigenvalues(
    &mut self,
    sort_pipeline: &wgpu::ComputePipeline,
    bind_group: &wgpu::BindGroup,
) -> Result<()> {
    let descriptor = wgpu::CommandEncoderDescriptor {
        label: Some("eigen_sort_encoder"),
    };
    let mut encoder = self.device().create_command_encoder(&descriptor);
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("eigen_sort_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(sort_pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        // Sorting is intentionally a single-workgroup dispatch.
        pass.dispatch_workgroups(1, 1, 1);
    }
    self.queue().submit(Some(encoder.finish()));
    Ok(())
}
/// Dispatches the `normalize_eigenvectors` kernel with one workgroup of
/// work per 256 columns; normalization semantics live in the WGSL shader.
fn normalize_eigenvectors(
    &mut self,
    normalize_pipeline: &wgpu::ComputePipeline,
    bind_group: &wgpu::BindGroup,
    n: usize,
) -> Result<()> {
    let descriptor = wgpu::CommandEncoderDescriptor {
        label: Some("eigen_normalize_encoder"),
    };
    let mut encoder = self.device().create_command_encoder(&descriptor);
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("eigen_normalize_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(normalize_pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        pass.dispatch_workgroups((n as u32).div_ceil(256), 1, 1);
    }
    self.queue().submit(Some(encoder.finish()));
    Ok(())
}
/// Compiles the eigenvalue WGSL shader module and builds one compute
/// pipeline per kernel entry point.
///
/// The five pipeline stanzas previously differed only in label and entry
/// point, so the construction is factored through a local closure. All
/// pipelines use auto-derived layouts (`layout: None`), so the entry
/// points must share a compatible binding interface in the shader.
fn create_eigenvalue_pipelines(&self) -> Result<EigenvaluePipelines> {
    let shader_source = include_str!("../../shaders/linalg_eigenvalue.wgsl");
    let shader_module = self
        .device()
        .create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("eigenvalue_shader"),
            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
        });
    // Builds one pipeline for the given (label, entry point) pair.
    let make_pipeline = |label: &'static str, entry_point: &'static str| {
        self.device()
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some(label),
                layout: None,
                module: &shader_module,
                entry_point: Some(entry_point),
                cache: None,
                compilation_options: Default::default(),
            })
    };
    Ok(EigenvaluePipelines {
        init: make_pipeline("eigen_init_pipeline", "initialize_eigen"),
        givens: make_pipeline("eigen_givens_pipeline", "apply_givens_eigen"),
        extract: make_pipeline("eigen_extract_pipeline", "extract_eigenvalues"),
        sort: make_pipeline("eigen_sort_pipeline", "sort_eigenvalues"),
        normalize: make_pipeline("eigen_normalize_pipeline", "normalize_eigenvectors"),
    })
}
}
/// Compute pipelines for each stage of the GPU eigenvalue computation,
/// all built from the `linalg_eigenvalue.wgsl` shader module by
/// `create_eigenvalue_pipelines`.
struct EigenvaluePipelines {
// Seeds the working matrix / Q accumulator (`initialize_eigen`).
init: wgpu::ComputePipeline,
// Applies one Givens rotation per dispatch (`apply_givens_eigen`).
givens: wgpu::ComputePipeline,
// Copies results into the eigenvalue buffer (`extract_eigenvalues`).
extract: wgpu::ComputePipeline,
// Orders the computed eigenvalues (`sort_eigenvalues`).
sort: wgpu::ComputePipeline,
// Post-processes the eigenvector matrix (`normalize_eigenvectors`).
normalize: wgpu::ComputePipeline,
}