1#![allow(dead_code)]
16
17use std::sync::Arc;
18
19use oxicuda_blas::GpuFloat;
20use oxicuda_driver::Module;
21use oxicuda_launch::{Kernel, LaunchParams};
22use oxicuda_memory::DeviceBuffer;
23use oxicuda_ptx::ir::PtxType;
24use oxicuda_ptx::prelude::*;
25
26use crate::error::{SolverError, SolverResult};
27use crate::handle::SolverHandle;
28use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
29
/// Maximum number of implicit QR sweeps before the iteration gives up.
const TRIDIAG_QR_MAX_ITER: u32 = 300;

/// Relative tolerance used to deflate negligible off-diagonal entries.
const TRIDIAG_QR_TOL: f64 = 1e-14;

/// Column panel width used by the blocked tridiagonalization sweep.
const TRIDIAG_BLOCK_SIZE: u32 = 64;
38
/// Selects how much work `syevd` performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EigJob {
    /// Compute eigenvalues only; the input matrix is still modified by the
    /// tridiagonal reduction.
    ValuesOnly,
    /// Compute eigenvalues and overwrite the matrix with the eigenvectors.
    ValuesAndVectors,
}
51
52pub fn syevd<T: GpuFloat>(
78 handle: &mut SolverHandle,
79 a: &mut DeviceBuffer<T>,
80 n: u32,
81 lda: u32,
82 eigenvalues: &mut DeviceBuffer<T>,
83 job: EigJob,
84) -> SolverResult<()> {
85 if n == 0 {
87 return Ok(());
88 }
89 if lda < n {
90 return Err(SolverError::DimensionMismatch(format!(
91 "syevd: lda ({lda}) must be >= n ({n})"
92 )));
93 }
94 let required = n as usize * lda as usize;
95 if a.len() < required {
96 return Err(SolverError::DimensionMismatch(format!(
97 "syevd: buffer too small ({} < {required})",
98 a.len()
99 )));
100 }
101 if eigenvalues.len() < n as usize {
102 return Err(SolverError::DimensionMismatch(format!(
103 "syevd: eigenvalues buffer too small ({} < {n})",
104 eigenvalues.len()
105 )));
106 }
107
108 let tau_size = n.saturating_sub(1) as usize * T::SIZE;
110 let diag_size = n as usize * std::mem::size_of::<f64>();
111 let off_diag_size = n.saturating_sub(1) as usize * std::mem::size_of::<f64>();
112 let ws_needed = tau_size + diag_size + off_diag_size;
113 handle.ensure_workspace(ws_needed)?;
114
115 let mut tau = DeviceBuffer::<T>::zeroed(n.saturating_sub(1) as usize)?;
117 tridiagonalize(handle, a, n, lda, &mut tau)?;
118
119 let mut d = vec![0.0_f64; n as usize];
121 let mut e = vec![0.0_f64; n.saturating_sub(1) as usize];
122 extract_tridiagonal::<T>(a, n, lda, &mut d, &mut e)?;
123
124 let mut vectors = if job == EigJob::ValuesAndVectors {
126 let mut v = vec![0.0_f64; n as usize * n as usize];
127 for i in 0..n as usize {
129 v[i * n as usize + i] = 1.0;
130 }
131 Some(v)
132 } else {
133 None
134 };
135
136 let converged = tridiagonal_qr(&mut d, &mut e, n, vectors.as_deref_mut())?;
137
138 if !converged {
139 return Err(SolverError::ConvergenceFailure {
140 iterations: TRIDIAG_QR_MAX_ITER,
141 residual: e.iter().map(|v| v * v).sum::<f64>().sqrt(),
142 });
143 }
144
145 sort_eigenvalues(&mut d, vectors.as_deref_mut(), n as usize);
147
148 let eig_stage = stage_eigenvalues_to_device::<T>(eigenvalues.len(), &d);
150 eigenvalues.copy_from_host(&eig_stage)?;
151
152 if job == EigJob::ValuesAndVectors {
154 if let Some(ref _vecs) = vectors {
155 back_transform_eigenvectors(handle, a, n, lda, &tau, vectors.as_deref())?;
158 }
159 }
160
161 Ok(())
162}
163
/// Reduces the symmetric device matrix `a` (`n x n`, leading dimension
/// `lda`) to tridiagonal form with a blocked Householder sweep, storing the
/// reflector scalars in `tau` (length `n - 1`).
///
/// # Errors
///
/// Propagates PTX emission, module load, allocation, and kernel launch
/// failures.
fn tridiagonalize<T: GpuFloat>(
    handle: &SolverHandle,
    a: &mut DeviceBuffer<T>,
    n: u32,
    lda: u32,
    tau: &mut DeviceBuffer<T>,
) -> SolverResult<()> {
    // 0x0 and 1x1 matrices are already tridiagonal.
    if n <= 1 {
        return Ok(());
    }

    // JIT-compile the step kernel for this device's SM version and scalar type.
    let sm = handle.sm_version();
    let ptx = emit_tridiag_step::<T>(sm)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &tridiag_step_name::<T>())?;

    // Walk columns 0..n-1 in panels of at most TRIDIAG_BLOCK_SIZE.
    let nb = TRIDIAG_BLOCK_SIZE.min(n - 1);
    let num_blocks = (n - 1).div_ceil(nb);

    for block_idx in 0..num_blocks {
        let j = block_idx * nb; // first column of this panel
        let jb = nb.min(n - 1 - j); // panel width (last panel may be narrower)
        let trailing = n - j; // extent of the trailing submatrix

        // Shared memory sized to stage the trailing panel for one block.
        let shared_bytes = trailing * jb * T::size_u32();
        let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);

        // Byte offsets of the panel's diagonal corner in `a` and its slot in `tau`.
        let a_offset = (j as u64 + j as u64 * lda as u64) * T::SIZE as u64;
        let tau_offset = j as u64 * T::SIZE as u64;

        let args = (
            a.as_device_ptr() + a_offset,
            tau.as_device_ptr() + tau_offset,
            trailing,
            jb,
            lda,
        );
        kernel.launch(&params, handle.stream(), &args)?;
    }

    Ok(())
}
219
220fn t_to_f64<T: GpuFloat>(val: T) -> f64 {
225 if T::SIZE == 8 {
226 f64::from_bits(val.to_bits_u64())
227 } else {
228 f64::from(f32::from_bits(val.to_bits_u64() as u32))
229 }
230}
231
232fn from_f64_to_t<T: GpuFloat>(val: f64) -> T {
233 if T::SIZE == 8 {
234 T::from_bits_u64(val.to_bits())
235 } else {
236 T::from_bits_u64(u64::from((val as f32).to_bits()))
237 }
238}
239
/// Copies the device matrix to the host and extracts its diagonal into `d`
/// (length >= n) and its adjacent off-diagonal into `e` (length >= n-1),
/// widening every element to `f64`.
///
/// # Errors
///
/// Returns [`SolverError::InternalError`] if the device-to-host copy fails.
fn extract_tridiagonal<T: GpuFloat>(
    a: &DeviceBuffer<T>,
    n: u32,
    lda: u32,
    d: &mut [f64],
    e: &mut [f64],
) -> SolverResult<()> {
    let n_usize = n as usize;
    let lda_usize = lda as usize;
    let total = lda_usize * n_usize;
    let mut host = vec![T::gpu_zero(); total];
    a.copy_to_host(&mut host).map_err(|e_err| {
        SolverError::InternalError(format!("extract_tridiagonal copy_to_host failed: {e_err}"))
    })?;

    // Diagonal: element (i, i) sits at i * lda + i in either storage order.
    for i in 0..n_usize {
        d[i] = t_to_f64(host[i * lda_usize + i]);
    }

    // Off-diagonal: for a symmetric matrix the sub- and superdiagonal agree,
    // so reading one neighbor of each diagonal entry suffices.
    for i in 0..n_usize.saturating_sub(1) {
        e[i] = t_to_f64(host[i * lda_usize + (i + 1)]);
    }

    Ok(())
}
271
272fn tridiagonal_qr(
283 d: &mut [f64],
284 e: &mut [f64],
285 n: u32,
286 mut vectors: Option<&mut [f64]>,
287) -> SolverResult<bool> {
288 let n_usize = n as usize;
289 if n_usize <= 1 {
290 return Ok(true);
291 }
292
293 let tol = TRIDIAG_QR_TOL;
294
295 for _iter in 0..TRIDIAG_QR_MAX_ITER {
296 let mut q = n_usize - 1;
298 while q > 0 && e[q - 1].abs() <= tol * (d[q - 1].abs() + d[q].abs()) {
299 e[q - 1] = 0.0;
300 q -= 1;
301 }
302 if q == 0 {
303 return Ok(true);
304 }
305
306 let mut p = q - 1;
307 while p > 0 && e[p - 1].abs() > tol * (d[p - 1].abs() + d[p].abs()) {
308 p -= 1;
309 }
310
311 implicit_qr_step(d, e, p, q, vectors.as_deref_mut(), n_usize);
313 }
314
315 let off_norm: f64 = e.iter().map(|v| v * v).sum::<f64>().sqrt();
317 Ok(off_norm <= tol)
318}
319
/// Performs one implicit-shift QR sweep on the unreduced tridiagonal block
/// `[start, end]` of (`d`, `e`), chasing the bulge with Givens rotations.
///
/// Uses the Wilkinson shift (the eigenvalue of the trailing 2x2 block
/// closest to `d[end]`). When `vectors` is present, each rotation is also
/// applied to the matching pair of eigenvector columns (column-major,
/// `col * n + row`).
fn implicit_qr_step(
    d: &mut [f64],
    e: &mut [f64],
    start: usize,
    end: usize,
    mut vectors: Option<&mut [f64]>,
    n: usize,
) {
    // Wilkinson shift from the trailing 2x2 block; the sign factor picks the
    // larger-magnitude denominator to avoid cancellation.
    let delta = (d[end - 1] - d[end]) * 0.5;
    let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
    let e_sq = e[end - 1] * e[end - 1];
    let mu = d[end] - e_sq / (delta + sign_delta * (delta * delta + e_sq).sqrt());

    // Seed the bulge chase with the shifted first column of the block.
    let mut x = d[start] - mu;
    let mut z = e[start];

    for k in start..end {
        let (cs, sn) = givens_rotation(x, z);

        // Finalize the off-diagonal produced by the previous rotation.
        if k > start {
            e[k - 1] = cs * x + sn * z;
        }
        let dk = d[k];
        let dk1 = d[k + 1];
        let ek = e[k];

        // Apply G^T * T * G to the 2x2 pivot block (similarity transform,
        // so the trace d[k] + d[k+1] is preserved).
        d[k] = cs * cs * dk + 2.0 * cs * sn * ek + sn * sn * dk1;
        d[k + 1] = sn * sn * dk - 2.0 * cs * sn * ek + cs * cs * dk1;
        e[k] = cs * sn * (dk1 - dk) + (cs * cs - sn * sn) * ek;

        // The rotation spills a bulge past e[k+1]; set up its elimination in
        // the next iteration.
        if k + 1 < end {
            x = e[k];
            z = sn * e[k + 1];
            e[k + 1] *= cs;
        }

        // Accumulate the rotation into eigenvector columns k and k+1.
        if let Some(ref mut vecs) = vectors.as_deref_mut() {
            for i in 0..n {
                let vi_k = vecs[k * n + i];
                let vi_k1 = vecs[(k + 1) * n + i];
                vecs[k * n + i] = cs * vi_k + sn * vi_k1;
                vecs[(k + 1) * n + i] = -sn * vi_k + cs * vi_k1;
            }
        }
    }
}
376
/// Computes a Givens rotation `(cos, sin)` such that applying it to the pair
/// `(a, b)` annihilates the second component: `[c s; -s c] [a; b] = [r; 0]`.
///
/// Inputs with magnitude below ~1e-300 short-circuit to exact axis rotations
/// so the norm division never operates on a vanishing value.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    const NEAR_ZERO: f64 = 1e-300;

    if b.abs() < NEAR_ZERO {
        // Nothing to annihilate: identity rotation.
        (1.0, 0.0)
    } else if a.abs() < NEAR_ZERO {
        // Pure quarter turn; the sign keeps the resulting r non-negative.
        let sn = if b >= 0.0 { 1.0 } else { -1.0 };
        (0.0, sn)
    } else {
        let r = (a * a + b * b).sqrt();
        (a / r, b / r)
    }
}
388
/// Sorts the eigenvalues in `d[..n]` into ascending order in place, applying
/// the same permutation to the eigenvector columns when they are present.
///
/// `vectors`, when given, is an `n x n` column-major matrix
/// (`col * n + row`) whose columns are swapped in lockstep with `d`.
fn sort_eigenvalues(d: &mut [f64], mut vectors: Option<&mut [f64]>, n: usize) {
    // Selection sort keeps the paired column swaps simple; its cost is
    // negligible next to the QR iteration that precedes it.
    for pos in 0..n {
        // Index of the smallest remaining value (first occurrence on ties).
        let smallest =
            (pos + 1..n).fold(pos, |best, j| if d[j] < d[best] { j } else { best });
        if smallest == pos {
            continue;
        }
        d.swap(pos, smallest);
        if let Some(ref mut vecs) = vectors.as_deref_mut() {
            // Swap entire eigenvector columns `pos` and `smallest`.
            for row in 0..n {
                vecs.swap(pos * n + row, smallest * n + row);
            }
        }
    }
}
414
/// Writes the accumulated tridiagonal eigenvectors into the device matrix
/// `a` (column-major with leading dimension `lda`). A no-op when `vectors`
/// is `None`.
///
/// NOTE(review): `_tau` (the Householder reflectors from the reduction) is
/// never applied here — the staged vectors are copied to the device as-is,
/// so these are eigenvectors of the tridiagonal matrix, not of the original
/// matrix, unless the reduction was trivial. Confirm whether applying Q is
/// pending kernel work.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `a` is smaller than
/// `n * lda` elements, and propagates staging/copy failures.
fn back_transform_eigenvectors<T: GpuFloat>(
    _handle: &SolverHandle,
    a: &mut DeviceBuffer<T>,
    n: u32,
    lda: u32,
    _tau: &DeviceBuffer<T>,
    vectors: Option<&[f64]>,
) -> SolverResult<()> {
    let Some(vecs) = vectors else {
        return Ok(());
    };

    let n_usize = n as usize;
    let lda_usize = lda as usize;
    let required = n_usize * lda_usize;
    if a.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "back_transform_eigenvectors: matrix buffer too small ({} < {required})",
            a.len()
        )));
    }

    // Repack n x n host vectors into the lda-strided device layout and upload.
    let stage = stage_eigenvectors_col_major_to_lda::<T>(vecs, n_usize, lda_usize, a.len())?;
    a.copy_from_host(&stage)?;

    Ok(())
}
447
448fn stage_eigenvalues_to_device<T: GpuFloat>(dst_len: usize, d: &[f64]) -> Vec<T> {
449 let mut out = vec![T::gpu_zero(); dst_len];
450 for (idx, &val) in d.iter().enumerate() {
451 if idx >= dst_len {
452 break;
453 }
454 out[idx] = from_f64_to_t(val);
455 }
456 out
457}
458
459fn stage_eigenvectors_col_major_to_lda<T: GpuFloat>(
460 vectors: &[f64],
461 n: usize,
462 lda: usize,
463 dst_len: usize,
464) -> SolverResult<Vec<T>> {
465 if vectors.len() < n * n {
466 return Err(SolverError::DimensionMismatch(format!(
467 "stage_eigenvectors_col_major_to_lda: vectors too small ({} < {})",
468 vectors.len(),
469 n * n
470 )));
471 }
472 if dst_len < n * lda {
473 return Err(SolverError::DimensionMismatch(format!(
474 "stage_eigenvectors_col_major_to_lda: destination too small ({} < {})",
475 dst_len,
476 n * lda
477 )));
478 }
479
480 let mut out = vec![T::gpu_zero(); dst_len];
481 for col in 0..n {
482 for row in 0..n {
483 out[col * lda + row] = from_f64_to_t(vectors[col * n + row]);
485 }
486 }
487 Ok(out)
488}
489
490fn tridiag_step_name<T: GpuFloat>() -> String {
495 format!("solver_tridiag_step_{}", T::NAME)
496}
497
/// Emits PTX source for the tridiagonalization step kernel targeting `sm`.
///
/// Declared kernel parameters: `a_ptr` / `tau_ptr` (device addresses),
/// `trailing` (trailing submatrix extent), `jb` (panel width), and `lda`
/// (leading dimension).
///
/// NOTE(review): the kernel body is currently a stub — the loaded parameters
/// are discarded (`let _ = ...`) and the kernel returns immediately, so
/// launching it performs no work yet.
///
/// # Errors
///
/// Propagates PTX builder failures.
fn emit_tridiag_step<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
    let name = tridiag_step_name::<T>();
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("tau_ptr", PtxType::U64)
        .param("trailing", PtxType::U32)
        .param("jb", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let trailing = b.load_param_u32("trailing");
            let jb = b.load_param_u32("jb");
            let lda = b.load_param_u32("lda");

            // Placeholder: silence unused-value warnings until the real
            // Householder step body is generated.
            let _ = (tid, trailing, jb, lda, float_ty);

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
537
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn eig_job_equality() {
        assert_eq!(EigJob::ValuesOnly, EigJob::ValuesOnly);
        assert_ne!(EigJob::ValuesOnly, EigJob::ValuesAndVectors);
    }

    #[test]
    fn givens_rotation_basic() {
        // Rotating (3, 4) must yield r = sqrt(3^2 + 4^2) = 5.
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
    }

    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    #[test]
    fn sort_eigenvalues_basic() {
        let mut d = vec![3.0, 1.0, 2.0];
        sort_eigenvalues(&mut d, None, 3);
        assert_eq!(d, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn sort_eigenvalues_already_sorted() {
        let mut d = vec![1.0, 2.0, 3.0];
        sort_eigenvalues(&mut d, None, 3);
        assert_eq!(d, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn tridiag_qr_trivial() {
        // An already-diagonal input must converge immediately.
        let mut d = vec![1.0, 2.0, 3.0];
        let mut e = vec![0.0, 0.0];
        let converged = tridiagonal_qr(&mut d, &mut e, 3, None).expect("qr should not error");
        assert!(converged);
    }

    #[test]
    fn tridiag_qr_single() {
        let mut d = vec![5.0];
        let mut e: Vec<f64> = vec![];
        assert!(tridiagonal_qr(&mut d, &mut e, 1, None).is_ok());
    }

    #[test]
    fn tridiag_step_name_format() {
        assert!(tridiag_step_name::<f32>().contains("f32"));
    }

    #[test]
    fn tridiag_step_name_f64() {
        assert!(tridiag_step_name::<f64>().contains("f64"));
    }

    #[test]
    fn stage_eigenvalues_prefix_copy() {
        let d = vec![1.5_f64, 2.5, 3.5];
        let out = stage_eigenvalues_to_device::<f64>(5, &d);
        // Values copied in order; the remaining slots stay zeroed.
        assert_eq!(out, vec![1.5, 2.5, 3.5, 0.0, 0.0]);
    }

    #[test]
    fn stage_eigenvectors_to_lda_maps_columns() {
        // Two 2-element columns staged into a 3-row (lda = 3) destination:
        // each column keeps its values and gains a zero pad row.
        let vecs = vec![1.0_f64, 2.0, 3.0, 4.0];
        let out = stage_eigenvectors_col_major_to_lda::<f64>(&vecs, 2, 3, 6)
            .expect("staging should succeed");
        assert_eq!(out, vec![1.0, 2.0, 0.0, 3.0, 4.0, 0.0]);
    }
}