Skip to main content

trueno/backends/gpu/device/
eigen.rs

1//! GPU eigendecomposition operations
2//!
3//! Symmetric eigendecomposition using Jacobi algorithm with GPU-accelerated
4//! Givens rotations, plus CPU fallback for small matrices.
5
6#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
7use super::super::runtime;
8use super::super::shaders;
9use super::GpuDevice;
10
11impl GpuDevice {
12    /// Execute symmetric eigendecomposition on GPU (sync, native only)
13    ///
14    /// Computes eigenvalues and eigenvectors using Jacobi algorithm with GPU-accelerated
15    /// Givens rotations. Returns (eigenvalues, eigenvector_data) where eigenvector_data
16    /// is in row-major format.
17    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
18    pub fn symmetric_eigen(
19        &self,
20        matrix: &[f32],
21        n: usize,
22    ) -> Result<(Vec<f32>, Vec<f32>), String> {
23        runtime::block_on(async { self.symmetric_eigen_async(matrix, n).await })
24    }
25
26    /// Execute symmetric eigendecomposition on GPU (async, works on all platforms)
27    ///
28    /// Computes eigenvalues and eigenvectors using Jacobi algorithm with GPU-accelerated
29    /// Givens rotations.
30    pub async fn symmetric_eigen_async(
31        &self,
32        matrix: &[f32],
33        n: usize,
34    ) -> Result<(Vec<f32>, Vec<f32>), String> {
35        if matrix.len() != n * n {
36            return Err(format!(
37                "Matrix size mismatch: expected {} elements for {}x{} matrix, got {}",
38                n * n,
39                n,
40                n,
41                matrix.len()
42            ));
43        }
44
45        if n == 0 {
46            return Ok((Vec::new(), Vec::new()));
47        }
48
49        // For small matrices, use CPU (GPU overhead not worth it)
50        if n < 64 {
51            return self.symmetric_eigen_cpu(matrix, n);
52        }
53
54        // Create shader module for Jacobi rotation
55        let rotation_shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
56            label: Some("Jacobi Rotation Shader"),
57            source: wgpu::ShaderSource::Wgsl(shaders::JACOBI_ROTATION_SHADER.into()),
58        });
59
60        // Create buffers
61        let matrix_size = (n * n * std::mem::size_of::<f32>()) as u64;
62
63        let matrix_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
64            label: Some("Matrix Buffer"),
65            size: matrix_size,
66            usage: wgpu::BufferUsages::STORAGE
67                | wgpu::BufferUsages::COPY_DST
68                | wgpu::BufferUsages::COPY_SRC,
69            mapped_at_creation: false,
70        });
71
72        let eigenvectors_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
73            label: Some("Eigenvectors Buffer"),
74            size: matrix_size,
75            usage: wgpu::BufferUsages::STORAGE
76                | wgpu::BufferUsages::COPY_DST
77                | wgpu::BufferUsages::COPY_SRC,
78            mapped_at_creation: false,
79        });
80
81        // JacobiParams uniform buffer
82        #[repr(C)]
83        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
84        struct JacobiParams {
85            n: u32,
86            p: u32,
87            q: u32,
88            c: f32,
89            s: f32,
90            _padding: [u32; 3],
91        }
92
93        let params_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
94            label: Some("Jacobi Params"),
95            size: std::mem::size_of::<JacobiParams>() as u64,
96            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
97            mapped_at_creation: false,
98        });
99
100        // Initialize eigenvectors to identity matrix
101        let mut eigenvectors = vec![0.0f32; n * n];
102        for i in 0..n {
103            eigenvectors[i * n + i] = 1.0;
104        }
105
106        // Write initial data
107        self.queue.write_buffer(&matrix_buffer, 0, bytemuck::cast_slice(matrix));
108        self.queue.write_buffer(&eigenvectors_buffer, 0, bytemuck::cast_slice(&eigenvectors));
109
110        // Create bind group layout
111        let bind_group_layout =
112            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
113                label: Some("Jacobi Bind Group Layout"),
114                entries: &[
115                    wgpu::BindGroupLayoutEntry {
116                        binding: 0,
117                        visibility: wgpu::ShaderStages::COMPUTE,
118                        ty: wgpu::BindingType::Buffer {
119                            ty: wgpu::BufferBindingType::Storage { read_only: false },
120                            has_dynamic_offset: false,
121                            min_binding_size: None,
122                        },
123                        count: None,
124                    },
125                    wgpu::BindGroupLayoutEntry {
126                        binding: 1,
127                        visibility: wgpu::ShaderStages::COMPUTE,
128                        ty: wgpu::BindingType::Buffer {
129                            ty: wgpu::BufferBindingType::Storage { read_only: false },
130                            has_dynamic_offset: false,
131                            min_binding_size: None,
132                        },
133                        count: None,
134                    },
135                    wgpu::BindGroupLayoutEntry {
136                        binding: 2,
137                        visibility: wgpu::ShaderStages::COMPUTE,
138                        ty: wgpu::BindingType::Buffer {
139                            ty: wgpu::BufferBindingType::Uniform,
140                            has_dynamic_offset: false,
141                            min_binding_size: None,
142                        },
143                        count: None,
144                    },
145                ],
146            });
147
148        // Create bind group
149        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
150            label: Some("Jacobi Bind Group"),
151            layout: &bind_group_layout,
152            entries: &[
153                wgpu::BindGroupEntry { binding: 0, resource: matrix_buffer.as_entire_binding() },
154                wgpu::BindGroupEntry {
155                    binding: 1,
156                    resource: eigenvectors_buffer.as_entire_binding(),
157                },
158                wgpu::BindGroupEntry { binding: 2, resource: params_buffer.as_entire_binding() },
159            ],
160        });
161
162        // Create pipeline
163        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
164            label: Some("Jacobi Pipeline Layout"),
165            bind_group_layouts: &[&bind_group_layout],
166            push_constant_ranges: &[],
167        });
168
169        let rotation_pipeline =
170            self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
171                label: Some("Jacobi Rotation Pipeline"),
172                layout: Some(&pipeline_layout),
173                module: &rotation_shader,
174                entry_point: Some("main"),
175                compilation_options: wgpu::PipelineCompilationOptions::default(),
176                cache: None,
177            });
178
179        // Jacobi iteration
180        let max_sweeps = 50;
181        let tolerance = 1e-7 * (matrix.iter().map(|x| x * x).sum::<f32>().sqrt()).max(1.0);
182
183        // Working copy of matrix for CPU-side pivot selection
184        let mut a = matrix.to_vec();
185
186        for _sweep in 0..max_sweeps {
187            let mut converged = true;
188
189            // Cyclic Jacobi: process all pairs (i, j) where i < j
190            for i in 0..n {
191                for j in (i + 1)..n {
192                    let aij = a[i * n + j];
193
194                    if aij.abs() < tolerance {
195                        continue;
196                    }
197
198                    converged = false;
199
200                    // Compute rotation parameters
201                    let aii = a[i * n + i];
202                    let ajj = a[j * n + j];
203
204                    let tau = (ajj - aii) / (2.0 * aij);
205                    let t = if tau >= 0.0 {
206                        1.0 / (tau + (1.0 + tau * tau).sqrt())
207                    } else {
208                        -1.0 / (-tau + (1.0 + tau * tau).sqrt())
209                    };
210
211                    let c = 1.0 / (1.0 + t * t).sqrt();
212                    let s = t * c;
213
214                    // Update params and dispatch GPU
215                    let params = JacobiParams {
216                        n: n as u32,
217                        p: i as u32,
218                        q: j as u32,
219                        c,
220                        s,
221                        _padding: [0; 3],
222                    };
223
224                    self.queue.write_buffer(&params_buffer, 0, bytemuck::bytes_of(&params));
225
226                    // Create command encoder and dispatch
227                    let mut encoder =
228                        self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
229                            label: Some("Jacobi Rotation Encoder"),
230                        });
231
232                    {
233                        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
234                            label: Some("Jacobi Rotation Pass"),
235                            timestamp_writes: None,
236                        });
237                        pass.set_pipeline(&rotation_pipeline);
238                        pass.set_bind_group(0, &bind_group, &[]);
239                        pass.dispatch_workgroups((n as u32).div_ceil(256), 1, 1);
240                    }
241
242                    self.queue.submit(Some(encoder.finish()));
243
244                    // Update local copy of diagonal and off-diagonal
245                    a[i * n + i] = aii - t * aij;
246                    a[j * n + j] = ajj + t * aij;
247                    a[i * n + j] = 0.0;
248                    a[j * n + i] = 0.0;
249
250                    // Update off-diagonal elements in rows/columns i and j
251                    for k in 0..n {
252                        if k != i && k != j {
253                            let aki = a[k * n + i];
254                            let akj = a[k * n + j];
255                            a[k * n + i] = c * aki - s * akj;
256                            a[i * n + k] = a[k * n + i];
257                            a[k * n + j] = s * aki + c * akj;
258                            a[j * n + k] = a[k * n + j];
259                        }
260                    }
261                }
262            }
263
264            if converged {
265                break;
266            }
267        }
268
269        // Read back results
270        let staging_matrix = self.device.create_buffer(&wgpu::BufferDescriptor {
271            label: Some("Staging Matrix"),
272            size: matrix_size,
273            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
274            mapped_at_creation: false,
275        });
276
277        let staging_eigenvectors = self.device.create_buffer(&wgpu::BufferDescriptor {
278            label: Some("Staging Eigenvectors"),
279            size: matrix_size,
280            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
281            mapped_at_creation: false,
282        });
283
284        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
285            label: Some("Copy Encoder"),
286        });
287
288        encoder.copy_buffer_to_buffer(&matrix_buffer, 0, &staging_matrix, 0, matrix_size);
289        encoder.copy_buffer_to_buffer(
290            &eigenvectors_buffer,
291            0,
292            &staging_eigenvectors,
293            0,
294            matrix_size,
295        );
296
297        self.queue.submit(Some(encoder.finish()));
298
299        // Map and read eigenvectors
300        let eigenvector_slice = staging_eigenvectors.slice(..);
301        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
302        eigenvector_slice.map_async(wgpu::MapMode::Read, move |result| {
303            sender.send(result).expect("oneshot channel receiver dropped");
304        });
305
306        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
307
308        receiver
309            .receive()
310            .await
311            .ok_or("Failed to receive mapping result")?
312            .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
313
314        let mut result_eigenvectors = vec![0.0f32; n * n];
315        {
316            let data = eigenvector_slice.get_mapped_range();
317            let output_data: &[f32] = bytemuck::cast_slice(&data);
318            result_eigenvectors.copy_from_slice(output_data);
319        }
320        staging_eigenvectors.unmap();
321
322        // Extract eigenvalues from diagonal of working matrix
323        let eigenvalues: Vec<f32> = (0..n).map(|i| a[i * n + i]).collect();
324
325        Ok((eigenvalues, result_eigenvectors))
326    }
327
328    /// CPU fallback for small matrices (GPU overhead not worthwhile)
329    pub(super) fn symmetric_eigen_cpu(
330        &self,
331        matrix: &[f32],
332        n: usize,
333    ) -> Result<(Vec<f32>, Vec<f32>), String> {
334        let max_sweeps = 50;
335        let tolerance = 1e-7 * (matrix.iter().map(|x| x * x).sum::<f32>().sqrt()).max(1.0);
336
337        let mut a = matrix.to_vec();
338        let mut v = vec![0.0f32; n * n];
339        for i in 0..n {
340            v[i * n + i] = 1.0;
341        }
342
343        for _sweep in 0..max_sweeps {
344            let mut converged = true;
345
346            for i in 0..n {
347                for j in (i + 1)..n {
348                    let aij = a[i * n + j];
349
350                    if aij.abs() < tolerance {
351                        continue;
352                    }
353
354                    converged = false;
355
356                    let aii = a[i * n + i];
357                    let ajj = a[j * n + j];
358
359                    let tau = (ajj - aii) / (2.0 * aij);
360                    let t = if tau >= 0.0 {
361                        1.0 / (tau + (1.0 + tau * tau).sqrt())
362                    } else {
363                        -1.0 / (-tau + (1.0 + tau * tau).sqrt())
364                    };
365
366                    let c = 1.0 / (1.0 + t * t).sqrt();
367                    let s = t * c;
368
369                    // Update diagonal
370                    a[i * n + i] = aii - t * aij;
371                    a[j * n + j] = ajj + t * aij;
372                    a[i * n + j] = 0.0;
373                    a[j * n + i] = 0.0;
374
375                    // Update off-diagonal
376                    for k in 0..n {
377                        if k != i && k != j {
378                            let aki = a[k * n + i];
379                            let akj = a[k * n + j];
380                            a[k * n + i] = c * aki - s * akj;
381                            a[i * n + k] = a[k * n + i];
382                            a[k * n + j] = s * aki + c * akj;
383                            a[j * n + k] = a[k * n + j];
384                        }
385                    }
386
387                    // Update eigenvectors
388                    for k in 0..n {
389                        let vki = v[k * n + i];
390                        let vkj = v[k * n + j];
391                        v[k * n + i] = c * vki - s * vkj;
392                        v[k * n + j] = s * vki + c * vkj;
393                    }
394                }
395            }
396
397            if converged {
398                break;
399            }
400        }
401
402        let eigenvalues: Vec<f32> = (0..n).map(|i| a[i * n + i]).collect();
403        Ok((eigenvalues, v))
404    }
405}