Skip to main content

singe_cusolver/
svd.rs

1#[allow(unused_imports)]
2use crate::{eigen::xsyevd, error::Status};
3
4use std::ptr;
5
6use singe_cuda::{
7    data_type::{DataType, DataTypeLike},
8    memory::DeviceMemory,
9    types::{Complex32, Complex64},
10};
11
12use crate::{
13    context::Context,
14    error::{Error, Result},
15    layout::{
16        ByteWorkspaceMut, MatrixMut, MatrixRef, StridedBatchedMatrixMut, StridedBatchedMatrixRef,
17        StridedBatchedVectorMut, StridedBatchedVectorRef, WorkspaceSizes,
18    },
19    params::Params,
20    sys, try_ffi,
21    types::{EigenMode, SvdMode, TruncatedSvdMode},
22    utility::{to_i32, to_i64, to_usize},
23};
24
/// Owning wrapper around a raw cuSOLVER `gesvdjInfo_t` handle.
///
/// Values are produced by [`GesvdjInfo::create`]; the `Drop` implementation hands the
/// handle back to `cusolverDnDestroyGesvdjInfo`.
#[derive(Debug)]
pub struct GesvdjInfo {
    // Raw handle; `create` refuses to build a value around a null pointer.
    handle: sys::gesvdjInfo_t,
}

// gesvdj info handles store solver options and expose mutation only through
// &mut self, so immutable sharing is allowed.
unsafe impl Send for GesvdjInfo {}
unsafe impl Sync for GesvdjInfo {}
34
35impl GesvdjInfo {
36    /// Creates `gesvdj` and `gesvdjBatched` parameter storage with default values.
37    ///
38    /// # Errors
39    ///
40    /// Returns an error if cuSOLVER cannot allocate the parameter storage or if
41    /// it does not return a valid handle.
42    pub fn create() -> Result<Self> {
43        let mut handle = ptr::null_mut();
44        unsafe {
45            try_ffi!(sys::cusolverDnCreateGesvdjInfo(&raw mut handle))?;
46        }
47
48        if handle.is_null() {
49            return Err(Error::NullHandle);
50        }
51
52        Ok(Self { handle })
53    }
54
55    /// Configures the `gesvdj` tolerance.
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if cuSOLVER rejects `tolerance`.
60    pub fn set_tolerance(&mut self, tolerance: f64) -> Result<()> {
61        unsafe {
62            try_ffi!(sys::cusolverDnXgesvdjSetTolerance(self.as_raw(), tolerance,))?;
63        }
64        Ok(())
65    }
66
67    /// Configures the maximum number of `gesvdj` sweeps.
68    /// The default value is 100.
69    ///
70    /// # Errors
71    ///
72    /// Returns an error if cuSOLVER rejects `max_sweeps`.
73    pub fn set_max_sweeps(&mut self, max_sweeps: i32) -> Result<()> {
74        unsafe {
75            try_ffi!(sys::cusolverDnXgesvdjSetMaxSweeps(
76                self.as_raw(),
77                max_sweeps,
78            ))?;
79        }
80        Ok(())
81    }
82
83    /// If `sort_eigenvalues` is false, the singular values are not sorted.
84    /// This setting only applies to `gesvdjBatched`.
85    /// `gesvdj` always sorts singular values in descending order.
86    /// By default, singular values are always sorted in descending order.
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if cuSOLVER rejects the sort setting.
91    pub fn set_sort_eigenvalues(&mut self, sort_eigenvalues: bool) -> Result<()> {
92        unsafe {
93            try_ffi!(sys::cusolverDnXgesvdjSetSortEig(
94                self.as_raw(),
95                i32::from(sort_eigenvalues),
96            ))?;
97        }
98        Ok(())
99    }
100
101    /// Returns the Frobenius norm of the internal residual reported by `gesvdj`.
102    /// Not the Frobenius norm of the exact residual.
103    ///
104    /// This accessor does not support `gesvdjBatched`.
105    /// Calling this after `gesvdjBatched` returns [`Status::NotSupported`].
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the info handle was used with `gesvdjBatched`,
110    /// which does not report a residual.
111    pub fn residual(&self, ctx: &Context) -> Result<f64> {
112        ctx.bind()?;
113
114        let mut residual = 0.0;
115        unsafe {
116            try_ffi!(sys::cusolverDnXgesvdjGetResidual(
117                ctx.as_raw(),
118                self.as_raw(),
119                &raw mut residual,
120            ))?;
121        }
122        Ok(residual)
123    }
124
125    /// Returns the number of executed `gesvdj` sweeps.
126    /// This accessor does not support `gesvdjBatched`.
127    /// Calling this after `gesvdjBatched` returns [`Status::NotSupported`].
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the info handle was used with `gesvdjBatched`,
132    /// which does not report a sweep count.
133    pub fn executed_sweeps(&self, ctx: &Context) -> Result<i32> {
134        ctx.bind()?;
135
136        let mut sweeps = 0;
137        unsafe {
138            try_ffi!(sys::cusolverDnXgesvdjGetSweeps(
139                ctx.as_raw(),
140                self.as_raw(),
141                &raw mut sweeps,
142            ))?;
143        }
144        Ok(sweeps)
145    }
146
147    pub fn as_raw(&self) -> sys::gesvdjInfo_t {
148        self.handle
149    }
150}
151
152impl Drop for GesvdjInfo {
153    fn drop(&mut self) {
154        unsafe {
155            if let Err(err) = try_ffi!(sys::cusolverDnDestroyGesvdjInfo(self.handle)) {
156                #[cfg(debug_assertions)]
157                eprintln!("failed to destroy cusolver gesvdj info: {err}");
158            }
159        }
160    }
161}
162
/// Problem description for the generic `gesvd` driver.
///
/// Holds the job modes and matrix dimensions that [`Gesvd::workspace_size`] and
/// [`Gesvd::execute`] forward to [`xgesvd_buffer_size`] and [`xgesvd`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Gesvd {
    /// Job mode for `U`; forwarded as `job_u`.
    pub job_u: SvdMode,
    /// Job mode for `VT`; forwarded as `job_vt`.
    pub job_vt: SvdMode,
    /// Number of rows of `A`; forwarded as `m`.
    pub rows: usize,
    /// Number of columns of `A`; forwarded as `n`.
    pub columns: usize,
}
170
171impl Gesvd {
172    pub fn new(job_u: SvdMode, job_vt: SvdMode, rows: usize, columns: usize) -> Self {
173        Self {
174            job_u,
175            job_vt,
176            rows,
177            columns,
178        }
179    }
180
181    pub fn workspace_size<
182        TA: DataTypeLike,
183        TS: DataTypeLike,
184        TU: DataTypeLike,
185        TVT: DataTypeLike,
186    >(
187        self,
188        ctx: &Context,
189        params: &Params,
190        input: GesvdInput<'_, TA, TS, TU, TVT>,
191    ) -> Result<WorkspaceSizes> {
192        xgesvd_buffer_size(
193            ctx,
194            params,
195            self.job_u,
196            self.job_vt,
197            self.rows,
198            self.columns,
199            input.a,
200            input.singular_values,
201            input.left_vectors,
202            input.right_vectors_transposed,
203        )
204    }
205
206    pub fn execute<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TVT: DataTypeLike>(
207        self,
208        ctx: &Context,
209        params: &Params,
210        bindings: GesvdBindings<'_, TA, TS, TU, TVT>,
211    ) -> Result<()> {
212        xgesvd(
213            ctx,
214            params,
215            self.job_u,
216            self.job_vt,
217            self.rows,
218            self.columns,
219            bindings.a,
220            bindings.singular_values,
221            bindings.left_vectors,
222            bindings.right_vectors_transposed,
223            bindings.workspace,
224            bindings.dev_info,
225        )
226    }
227}
228
/// Read-only operands consumed by [`Gesvd::workspace_size`].
#[derive(Debug, Clone, Copy)]
pub struct GesvdInput<'a, TA, TS, TU, TVT> {
    /// Matrix `A`; forwarded as `a`.
    pub a: MatrixRef<'a, TA>,
    /// Device buffer for the singular values; forwarded as `s`.
    pub singular_values: &'a DeviceMemory<TS>,
    /// Optional matrix for `U`; forwarded as `u`.
    pub left_vectors: Option<MatrixRef<'a, TU>>,
    /// Optional matrix for `VT`; forwarded as `vt`.
    pub right_vectors_transposed: Option<MatrixRef<'a, TVT>>,
}
236
/// Mutable operands consumed by [`Gesvd::execute`].
#[derive(Debug)]
pub struct GesvdBindings<'a, TA, TS, TU, TVT> {
    /// Matrix `A`; forwarded as `a`.
    pub a: MatrixMut<'a, TA>,
    /// Device buffer for the singular values; forwarded as `s`.
    pub singular_values: &'a mut DeviceMemory<TS>,
    /// Optional matrix for `U`; forwarded as `u`.
    pub left_vectors: Option<MatrixMut<'a, TU>>,
    /// Optional matrix for `VT`; forwarded as `vt`.
    pub right_vectors_transposed: Option<MatrixMut<'a, TVT>>,
    /// Device and host scratch space; see [`xgesvd_buffer_size`] for the required sizes.
    pub workspace: ByteWorkspaceMut<'a>,
    /// Device buffer handed to the solver as its `info` output.
    pub dev_info: &'a mut DeviceMemory<i32>,
}
246
/// Problem description for the `gesvdj` drivers.
///
/// The fields are forwarded unchanged to the precision-specific `*gesvdj` and
/// `*gesvdj_buffer_size` functions by the methods on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Gesvdj {
    /// Job mode forwarded to the solver.
    pub mode: EigenMode,
    /// Economy flag forwarded to the solver.
    // NOTE(review): presumably selects economy-size singular vectors — confirm against
    // the `*gesvdj` wrappers.
    pub economy: bool,
    /// Row count forwarded to the solver.
    pub rows: usize,
    /// Column count forwarded to the solver.
    pub columns: usize,
}
254
255impl Gesvdj {
256    pub fn new(mode: EigenMode, economy: bool, rows: usize, columns: usize) -> Self {
257        Self {
258            mode,
259            economy,
260            rows,
261            columns,
262        }
263    }
264
265    pub fn workspace_size_f32(
266        self,
267        ctx: &Context,
268        input: GesvdjInput<'_, f32, f32>,
269        params: &GesvdjInfo,
270    ) -> Result<usize> {
271        sgesvdj_buffer_size(
272            ctx,
273            self.mode,
274            self.economy,
275            self.rows,
276            self.columns,
277            input.a,
278            input.singular_values,
279            input.left_vectors,
280            input.right_vectors,
281            params,
282        )
283    }
284
285    pub fn execute_f32(
286        self,
287        ctx: &Context,
288        bindings: GesvdjBindings<'_, f32, f32>,
289        params: &GesvdjInfo,
290    ) -> Result<()> {
291        sgesvdj(
292            ctx,
293            self.mode,
294            self.economy,
295            self.rows,
296            self.columns,
297            bindings.a,
298            bindings.singular_values,
299            bindings.left_vectors,
300            bindings.right_vectors,
301            bindings.workspace,
302            bindings.dev_info,
303            params,
304        )
305    }
306
307    pub fn workspace_size_f64(
308        self,
309        ctx: &Context,
310        input: GesvdjInput<'_, f64, f64>,
311        params: &GesvdjInfo,
312    ) -> Result<usize> {
313        dgesvdj_buffer_size(
314            ctx,
315            self.mode,
316            self.economy,
317            self.rows,
318            self.columns,
319            input.a,
320            input.singular_values,
321            input.left_vectors,
322            input.right_vectors,
323            params,
324        )
325    }
326
327    pub fn execute_f64(
328        self,
329        ctx: &Context,
330        bindings: GesvdjBindings<'_, f64, f64>,
331        params: &GesvdjInfo,
332    ) -> Result<()> {
333        dgesvdj(
334            ctx,
335            self.mode,
336            self.economy,
337            self.rows,
338            self.columns,
339            bindings.a,
340            bindings.singular_values,
341            bindings.left_vectors,
342            bindings.right_vectors,
343            bindings.workspace,
344            bindings.dev_info,
345            params,
346        )
347    }
348
349    pub fn workspace_size_complex_f32(
350        self,
351        ctx: &Context,
352        input: GesvdjInput<'_, Complex32, f32>,
353        params: &GesvdjInfo,
354    ) -> Result<usize> {
355        cgesvdj_buffer_size(
356            ctx,
357            self.mode,
358            self.economy,
359            self.rows,
360            self.columns,
361            input.a,
362            input.singular_values,
363            input.left_vectors,
364            input.right_vectors,
365            params,
366        )
367    }
368
369    pub fn execute_complex_f32(
370        self,
371        ctx: &Context,
372        bindings: GesvdjBindings<'_, Complex32, f32>,
373        params: &GesvdjInfo,
374    ) -> Result<()> {
375        cgesvdj(
376            ctx,
377            self.mode,
378            self.economy,
379            self.rows,
380            self.columns,
381            bindings.a,
382            bindings.singular_values,
383            bindings.left_vectors,
384            bindings.right_vectors,
385            bindings.workspace,
386            bindings.dev_info,
387            params,
388        )
389    }
390
391    pub fn workspace_size_complex_f64(
392        self,
393        ctx: &Context,
394        input: GesvdjInput<'_, Complex64, f64>,
395        params: &GesvdjInfo,
396    ) -> Result<usize> {
397        zgesvdj_buffer_size(
398            ctx,
399            self.mode,
400            self.economy,
401            self.rows,
402            self.columns,
403            input.a,
404            input.singular_values,
405            input.left_vectors,
406            input.right_vectors,
407            params,
408        )
409    }
410
411    pub fn execute_complex_f64(
412        self,
413        ctx: &Context,
414        bindings: GesvdjBindings<'_, Complex64, f64>,
415        params: &GesvdjInfo,
416    ) -> Result<()> {
417        zgesvdj(
418            ctx,
419            self.mode,
420            self.economy,
421            self.rows,
422            self.columns,
423            bindings.a,
424            bindings.singular_values,
425            bindings.left_vectors,
426            bindings.right_vectors,
427            bindings.workspace,
428            bindings.dev_info,
429            params,
430        )
431    }
432}
433
/// Read-only operands consumed by the `Gesvdj::workspace_size_*` methods.
///
/// `TA` is the element type of the matrices and `TS` the element type of the
/// singular-value buffer.
#[derive(Debug, Clone, Copy)]
pub struct GesvdjInput<'a, TA, TS> {
    /// Input matrix forwarded to the solver.
    pub a: MatrixRef<'a, TA>,
    /// Device buffer for the singular values.
    pub singular_values: &'a DeviceMemory<TS>,
    /// Optional matrix for the left singular vectors.
    pub left_vectors: Option<MatrixRef<'a, TA>>,
    /// Optional matrix for the right singular vectors.
    // NOTE(review): the name suggests `V` rather than its conjugate transpose — confirm
    // against the `*gesvdj` wrappers.
    pub right_vectors: Option<MatrixRef<'a, TA>>,
}
441
/// Mutable operands consumed by the `Gesvdj::execute_*` methods.
///
/// `TA` is the element type of the matrices and of the workspace; `TS` is the element
/// type of the singular-value buffer.
#[derive(Debug)]
pub struct GesvdjBindings<'a, TA, TS> {
    /// Input matrix forwarded to the solver.
    pub a: MatrixMut<'a, TA>,
    /// Device buffer for the singular values.
    pub singular_values: &'a mut DeviceMemory<TS>,
    /// Optional matrix for the left singular vectors.
    pub left_vectors: Option<MatrixMut<'a, TA>>,
    /// Optional matrix for the right singular vectors.
    pub right_vectors: Option<MatrixMut<'a, TA>>,
    /// Device scratch space, typed in elements of `TA`.
    // NOTE(review): presumably sized with the matching `Gesvdj::workspace_size_*` call —
    // confirm the unit (elements vs. bytes) against the `*gesvdj` wrappers.
    pub workspace: &'a mut DeviceMemory<TA>,
    /// Device buffer handed to the solver as its `info` output.
    pub dev_info: &'a mut DeviceMemory<i32>,
}
451
452pub fn xgesvd_buffer_size<
453    TA: DataTypeLike,
454    TS: DataTypeLike,
455    TU: DataTypeLike,
456    TVT: DataTypeLike,
457>(
458    ctx: &Context,
459    params: &Params,
460    job_u: SvdMode,
461    job_vt: SvdMode,
462    m: usize,
463    n: usize,
464    a: MatrixRef<'_, TA>,
465    s: &DeviceMemory<TS>,
466    u: Option<MatrixRef<'_, TU>>,
467    vt: Option<MatrixRef<'_, TVT>>,
468) -> Result<WorkspaceSizes> {
469    let a_type = TA::data_type();
470    let s_type = TS::data_type();
471    let u_type = TU::data_type();
472    let vt_type = TVT::data_type();
473    ctx.bind()?;
474    validate_gesvd_dims(m, n)?;
475    validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
476    validate_x_vector(m.min(n), s.byte_len(), s_type)?;
477    validate_x_svd_output(m, m, matrix_ref_parts(u), job_u, u_type)?;
478    validate_x_svd_output(n, n, matrix_ref_parts(vt), job_vt, vt_type)?;
479    if matches!(job_u, SvdMode::Overwrite) && matches!(job_vt, SvdMode::Overwrite) {
480        return Err(Error::InvalidSvdMode);
481    }
482
483    let (u_ptr, ldu) = optional_x_matrix_ptr(matrix_ref_parts(u), m, m, job_u, u_type)?;
484    let (vt_ptr, ldvt) = optional_x_matrix_ptr(matrix_ref_parts(vt), n, n, job_vt, vt_type)?;
485    let mut device_bytes = 0;
486    let mut host_bytes = 0;
487    unsafe {
488        try_ffi!(sys::cusolverDnXgesvd_bufferSize(
489            ctx.as_raw(),
490            params.as_raw(),
491            job_u.as_raw(),
492            job_vt.as_raw(),
493            to_i64(m, "m")?,
494            to_i64(n, "n")?,
495            a_type.into(),
496            a.data.as_ptr().cast(),
497            to_i64(a.leading_dimension, "lda")?,
498            s_type.into(),
499            s.as_ptr().cast(),
500            u_type.into(),
501            u_ptr.cast(),
502            ldu,
503            vt_type.into(),
504            vt_ptr.cast(),
505            ldvt,
506            a_type.into(),
507            &raw mut device_bytes,
508            &raw mut host_bytes,
509        ))?;
510    }
511    Ok(WorkspaceSizes::new(
512        device_bytes as usize,
513        host_bytes as usize,
514    ))
515}
516
517/// Use [`xgesvd_buffer_size`] to calculate the sizes needed for pre-allocated
518/// workspace.
519///
520/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix
521/// `A` and the corresponding left and/or right singular vectors.
522/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an
523/// $m \times n$ matrix which is zero except for its `min(m,n)` diagonal
524/// elements, `U` is an $m \times m$ unitary matrix, and `V` is an
525/// $n \times n$ unitary matrix.
526/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
527/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
528///
529/// Provide device and host workspace through `workspace`.
530/// Use [`xgesvd_buffer_size`] to determine the required sizes for
531/// `workspace.device` and `workspace.host`.
532///
533/// If the reported `info` value is `-i`, the `i`th parameter is invalid. If `bdsqr` did not converge, `info` specifies how many superdiagonals of an intermediate bidiagonal form did not converge to zero.
534///
535/// Currently, [`xgesvd`] supports only the default algorithm.
536///
537/// **Algorithms supported by [`xgesvd`]**
538///
539/// | Algorithm | Notes |
540/// | --- | --- |
541/// | [`AlgorithmMode::Default`](crate::types::AlgorithmMode::Default) | Default algorithm. |
542///
543/// `gesvd` only supports `m >= n`.
544///
545/// Returns $V^H$, not `V`.
546///
547/// List of input arguments for [`xgesvd_buffer_size`] and [`xgesvd`]:
548///
549/// The generic cuSOLVER routine separates matrix, singular-value, vector, and compute data
550/// types: `data_type_a` is the data type of matrix `A`, `data_type_s` is the
551/// data type of vector `S`, `data_type_u` is the data type of matrix `U`,
552/// `data_type_vt` is the data type of matrix `VT`, and `compute_type` is the
553/// operation's compute type.
554/// [`xgesvd`] only supports the following four combinations.
555///
556/// **Valid combination of data type and compute type**
557///
558/// | **data_type_a** | **data_type_s** | **data_type_u** | **data_type_vt** | **compute_type** | **Meaning** |
559/// | --- | --- | --- | --- | --- | --- |
560/// | [`DataType::F32`] | [`DataType::F32`] | [`DataType::F32`] | [`DataType::F32`] | [`DataType::F32`] | `SGESVD` |
561/// | [`DataType::F64`] | [`DataType::F64`] | [`DataType::F64`] | [`DataType::F64`] | [`DataType::F64`] | `DGESVD` |
562/// | [`DataType::ComplexF32`] | [`DataType::F32`] | [`DataType::ComplexF32`] | [`DataType::ComplexF32`] | [`DataType::ComplexF32`] | `CGESVD` |
563/// | [`DataType::ComplexF64`] | [`DataType::F64`] | [`DataType::ComplexF64`] | [`DataType::ComplexF64`] | [`DataType::ComplexF64`] | `ZGESVD` |
564///
565/// # Errors
566///
567/// Returns an error if cuSOLVER has not been initialized, if the
568/// matrix dimensions, leading dimensions, output modes, or output buffers are
569/// invalid, or if cuSOLVER reports an internal failure.
570pub fn xgesvd<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TVT: DataTypeLike>(
571    ctx: &Context,
572    params: &Params,
573    job_u: SvdMode,
574    job_vt: SvdMode,
575    m: usize,
576    n: usize,
577    a: MatrixMut<'_, TA>,
578    s: &mut DeviceMemory<TS>,
579    u: Option<MatrixMut<'_, TU>>,
580    vt: Option<MatrixMut<'_, TVT>>,
581    workspace: ByteWorkspaceMut<'_>,
582    dev_info: &mut DeviceMemory<i32>,
583) -> Result<()> {
584    let a_type = TA::data_type();
585    let s_type = TS::data_type();
586    let u_type = TU::data_type();
587    let vt_type = TVT::data_type();
588    ctx.bind()?;
589    validate_gesvd_dims(m, n)?;
590    validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
591    validate_x_vector(m.min(n), s.byte_len(), s_type)?;
592    validate_x_svd_output(m, m, matrix_mut_ref_parts(u.as_ref()), job_u, u_type)?;
593    validate_x_svd_output(n, n, matrix_mut_ref_parts(vt.as_ref()), job_vt, vt_type)?;
594    if matches!(job_u, SvdMode::Overwrite) && matches!(job_vt, SvdMode::Overwrite) {
595        return Err(Error::InvalidSvdMode);
596    }
597    require_info_buffer(dev_info)?;
598
599    let workspace_sizes = xgesvd_buffer_size(
600        ctx,
601        params,
602        job_u,
603        job_vt,
604        m,
605        n,
606        a.as_ref(),
607        s,
608        matrix_mut_ref_option(u.as_ref()),
609        matrix_mut_ref_option(vt.as_ref()),
610    )?;
611    require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
612    require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
613
614    let (u_ptr, ldu) = optional_x_matrix_mut_ptr(matrix_mut_parts(u), m, m, job_u, u_type)?;
615    let (vt_ptr, ldvt) = optional_x_matrix_mut_ptr(matrix_mut_parts(vt), n, n, job_vt, vt_type)?;
616    unsafe {
617        try_ffi!(sys::cusolverDnXgesvd(
618            ctx.as_raw(),
619            params.as_raw(),
620            job_u.as_raw(),
621            job_vt.as_raw(),
622            to_i64(m, "m")?,
623            to_i64(n, "n")?,
624            a_type.into(),
625            a.data.as_mut_ptr().cast(),
626            to_i64(a.leading_dimension, "lda")?,
627            s_type.into(),
628            s.as_mut_ptr().cast(),
629            u_type.into(),
630            u_ptr.cast(),
631            ldu,
632            vt_type.into(),
633            vt_ptr.cast(),
634            ldvt,
635            a_type.into(),
636            workspace.device.as_mut_ptr().cast(),
637            workspace_sizes.device_bytes as _,
638            workspace.host.as_mut_ptr().cast(),
639            workspace_sizes.host_bytes as _,
640            dev_info.as_mut_ptr().cast(),
641        ))?;
642    }
643    Ok(())
644}
645
646pub fn xgesvdp_buffer_size<
647    TA: DataTypeLike,
648    TS: DataTypeLike,
649    TU: DataTypeLike,
650    TV: DataTypeLike,
651>(
652    ctx: &Context,
653    params: &Params,
654    jobz: EigenMode,
655    econ: bool,
656    m: usize,
657    n: usize,
658    a: MatrixRef<'_, TA>,
659    s: &DeviceMemory<TS>,
660    u: Option<MatrixRef<'_, TU>>,
661    v: Option<MatrixRef<'_, TV>>,
662) -> Result<WorkspaceSizes> {
663    let a_type = TA::data_type();
664    let s_type = TS::data_type();
665    let u_type = TU::data_type();
666    let v_type = TV::data_type();
667    ctx.bind()?;
668    validate_xgesvdp_inputs(
669        m,
670        n,
671        a.data.byte_len(),
672        a.leading_dimension,
673        a_type,
674        s.byte_len(),
675        s_type,
676        jobz,
677        econ,
678        matrix_ref_parts(u).as_ref(),
679        u_type,
680        matrix_ref_parts(v).as_ref(),
681        v_type,
682    )?;
683    let (u_ptr, ldu) = optional_x_eig_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ, u_type)?;
684    let (v_ptr, ldv) = optional_x_eig_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ, v_type)?;
685    let mut device_bytes = 0;
686    let mut host_bytes = 0;
687    unsafe {
688        try_ffi!(sys::cusolverDnXgesvdp_bufferSize(
689            ctx.as_raw(),
690            params.as_raw(),
691            jobz.into(),
692            i32::from(econ),
693            to_i64(m, "m")?,
694            to_i64(n, "n")?,
695            a_type.into(),
696            a.data.as_ptr().cast(),
697            to_i64(a.leading_dimension, "lda")?,
698            s_type.into(),
699            s.as_ptr().cast(),
700            u_type.into(),
701            u_ptr.cast(),
702            ldu,
703            v_type.into(),
704            v_ptr.cast(),
705            ldv,
706            a_type.into(),
707            &raw mut device_bytes,
708            &raw mut host_bytes,
709        ))?;
710    }
711    Ok(WorkspaceSizes::new(
712        device_bytes as usize,
713        host_bytes as usize,
714    ))
715}
716
717/// Use [`xgesvdp_buffer_size`] to calculate the sizes needed for pre-allocated
718/// workspace.
719///
720/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix
721/// `A` and the corresponding left and/or right singular vectors.
722/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an
723/// $m \times n$ matrix which is zero except for its `min(m,n)` diagonal
724/// elements, `U` is an $m \times m$ unitary matrix, and `V` is an
725/// $n \times n$ unitary matrix.
726/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
727/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
728///
729/// [`xgesvdp`] combines polar decomposition in \[14\] and the symmetric
730/// eigensolver used by this crate to compute the SVD.
731/// It is much faster than [`xgesvd`], which is based on a QR algorithm.
732/// However polar decomposition in \[14\] may not deliver a full unitary matrix when the matrix A has a singular value close to zero.
733/// To workaround the issue when the singular value is close to zero, we add a small perturbation so polar decomposition can deliver the correct result.
734/// The consequence is inaccurate singular values shifted by this perturbation.
735/// `residual` stores the magnitude of this perturbation when requested.
736/// In other words, it reports the accuracy of the SVD approximation.
737///
738/// Provide device and host workspace through `workspace`.
739/// Use [`xgesvdp_buffer_size`] to determine the required sizes for
740/// `workspace.device` and `workspace.host`.
741///
742/// If the reported `info` value is `-i`, the `i`th parameter is invalid.
743///
744/// Currently, [`xgesvdp`] supports only the default algorithm.
745///
746/// **Algorithms supported by [`xgesvdp`]**
747///
748/// | Algorithm | Notes |
749/// | --- | --- |
750/// | [`AlgorithmMode::Default`](crate::types::AlgorithmMode::Default) | Default algorithm. |
751///
752/// `gesvdp` also supports `n >= m`.
753///
754/// Returns `V`, not $V^{H}$.
755///
756/// List of input arguments for [`xgesvdp_buffer_size`] and [`xgesvdp`]:
757///
758/// The generic cuSOLVER routine separates matrix, singular-value, vector, and compute data
759/// types: `data_type_a` is the data type of matrix `A`, `data_type_s` is the
760/// data type of vector `S`, `data_type_u` is the data type of matrix `U`,
761/// `data_type_v` is the data type of matrix `V`, and `compute_type` is the
762/// operation's compute type.
763/// [`xgesvdp`] only supports the following four combinations:
764///
765/// **Valid combination of data type and compute type**
766///
767/// | **data_type_a** | **data_type_s** | **data_type_u** | **data_type_v** | **compute_type** | **Meaning** |
768/// | --- | --- | --- | --- | --- | --- |
769/// | [`DataType::F32`] | [`DataType::F32`] | [`DataType::F32`] | [`DataType::F32`] | [`DataType::F32`] | `SGESVDP` |
770/// | [`DataType::F64`] | [`DataType::F64`] | [`DataType::F64`] | [`DataType::F64`] | [`DataType::F64`] | `DGESVDP` |
771/// | [`DataType::ComplexF32`] | [`DataType::F32`] | [`DataType::ComplexF32`] | [`DataType::ComplexF32`] | [`DataType::ComplexF32`] | `CGESVDP` |
772/// | [`DataType::ComplexF64`] | [`DataType::F64`] | [`DataType::ComplexF64`] | [`DataType::ComplexF64`] | [`DataType::ComplexF64`] | `ZGESVDP` |
773///
774/// # Errors
775///
776/// Returns an error if cuSOLVER has not been initialized, if the
777/// matrix dimensions or leading dimensions are invalid, or if cuSOLVER reports
778/// an internal failure.
779pub fn xgesvdp<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TV: DataTypeLike>(
780    ctx: &Context,
781    params: &Params,
782    jobz: EigenMode,
783    econ: bool,
784    m: usize,
785    n: usize,
786    a: MatrixMut<'_, TA>,
787    s: &mut DeviceMemory<TS>,
788    u: Option<MatrixMut<'_, TU>>,
789    v: Option<MatrixMut<'_, TV>>,
790    workspace: ByteWorkspaceMut<'_>,
791    dev_info: &mut DeviceMemory<i32>,
792    err_sigma: Option<&mut f64>,
793) -> Result<()> {
794    let a_type = TA::data_type();
795    let s_type = TS::data_type();
796    let u_type = TU::data_type();
797    let v_type = TV::data_type();
798    ctx.bind()?;
799    validate_xgesvdp_inputs(
800        m,
801        n,
802        a.data.byte_len(),
803        a.leading_dimension,
804        a_type,
805        s.byte_len(),
806        s_type,
807        jobz,
808        econ,
809        matrix_mut_ref_parts(u.as_ref()).as_ref(),
810        u_type,
811        matrix_mut_ref_parts(v.as_ref()).as_ref(),
812        v_type,
813    )?;
814    require_info_buffer(dev_info)?;
815    let workspace_sizes = xgesvdp_buffer_size(
816        ctx,
817        params,
818        jobz,
819        econ,
820        m,
821        n,
822        a.as_ref(),
823        s,
824        matrix_mut_ref_option(u.as_ref()),
825        matrix_mut_ref_option(v.as_ref()),
826    )?;
827    require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
828    require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
829
830    let (u_ptr, ldu) =
831        optional_x_eig_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ, u_type)?;
832    let (v_ptr, ldv) =
833        optional_x_eig_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ, v_type)?;
834    unsafe {
835        try_ffi!(sys::cusolverDnXgesvdp(
836            ctx.as_raw(),
837            params.as_raw(),
838            jobz.into(),
839            i32::from(econ),
840            to_i64(m, "m")?,
841            to_i64(n, "n")?,
842            a_type.into(),
843            a.data.as_mut_ptr().cast(),
844            to_i64(a.leading_dimension, "lda")?,
845            s_type.into(),
846            s.as_mut_ptr().cast(),
847            u_type.into(),
848            u_ptr.cast(),
849            ldu,
850            v_type.into(),
851            v_ptr.cast(),
852            ldv,
853            a_type.into(),
854            workspace.device.as_mut_ptr().cast(),
855            workspace_sizes.device_bytes as _,
856            workspace.host.as_mut_ptr().cast(),
857            workspace_sizes.host_bytes as _,
858            dev_info.as_mut_ptr().cast(),
859            err_sigma.map_or(ptr::null_mut(), |value| value as *mut f64),
860        ))?;
861    }
862    Ok(())
863}
864
865pub fn xgesvdr_buffer_size<
866    TA: DataTypeLike,
867    TS: DataTypeLike,
868    TU: DataTypeLike,
869    TV: DataTypeLike,
870>(
871    ctx: &Context,
872    params: &Params,
873    job_u: TruncatedSvdMode,
874    job_v: TruncatedSvdMode,
875    m: usize,
876    n: usize,
877    k: usize,
878    p: usize,
879    niters: usize,
880    a: MatrixRef<'_, TA>,
881    s: &DeviceMemory<TS>,
882    u: Option<MatrixRef<'_, TU>>,
883    v: Option<MatrixRef<'_, TV>>,
884) -> Result<WorkspaceSizes> {
885    let a_type = TA::data_type();
886    let s_type = TS::data_type();
887    let u_type = TU::data_type();
888    let v_type = TV::data_type();
889    ctx.bind()?;
890    validate_xgesvdr_inputs(
891        m,
892        n,
893        k,
894        p,
895        niters,
896        a.data.byte_len(),
897        a.leading_dimension,
898        a_type,
899        s.byte_len(),
900        s_type,
901        job_u,
902        matrix_ref_parts(u).as_ref(),
903        u_type,
904        job_v,
905        matrix_ref_parts(v).as_ref(),
906        v_type,
907    )?;
908    let (u_ptr, ldu) = optional_x_truncated_u_ptr(matrix_ref_parts(u), m, k, job_u, u_type)?;
909    let (v_ptr, ldv) = optional_x_truncated_v_ptr(matrix_ref_parts(v), n, k, job_v, v_type)?;
910    let mut device_bytes = 0;
911    let mut host_bytes = 0;
912    unsafe {
913        try_ffi!(sys::cusolverDnXgesvdr_bufferSize(
914            ctx.as_raw(),
915            params.as_raw(),
916            job_u.as_raw(),
917            job_v.as_raw(),
918            to_i64(m, "m")?,
919            to_i64(n, "n")?,
920            to_i64(k, "k")?,
921            to_i64(p, "p")?,
922            to_i64(niters, "niters")?,
923            a_type.into(),
924            a.data.as_ptr().cast(),
925            to_i64(a.leading_dimension, "lda")?,
926            s_type.into(),
927            s.as_ptr().cast(),
928            u_type.into(),
929            u_ptr.cast(),
930            ldu,
931            v_type.into(),
932            v_ptr.cast(),
933            ldv,
934            a_type.into(),
935            &raw mut device_bytes,
936            &raw mut host_bytes,
937        ))?;
938    }
939    Ok(WorkspaceSizes::new(
940        device_bytes as usize,
941        host_bytes as usize,
942    ))
943}
944
945/// Use [`xgesvdr_buffer_size`] to calculate the sizes needed for pre-allocated
946/// workspace.
947///
948/// Computes the approximate rank-k singular value decomposition (k-SVD) of an
949/// $m \times n$ matrix `A` and the corresponding left and/or right singular
950/// vectors.
951/// The k-SVD is written as
952///
953/// where `Σ` is a $k \times k$ matrix which is zero except for its diagonal elements, `U` is an $m \times k$ orthonormal matrix, and `V` is an $k \times n$ orthonormal matrix.
954/// The diagonal elements of `Σ` are the approximated singular values of `A`; they are real and non-negative, and are returned in descending order.
955/// The columns of `U` and `V` are the top-`k` left and right singular vectors of `A`.
956///
957/// [`xgesvdr`] implements randomized methods described in \[15\] to compute k-SVD that is accurate with high probability if the conditions described in \[15\] hold.
958/// [`xgesvdr`] is intended to compute a small portion of the spectrum of `A`
959/// quickly and accurately, especially when `k` is much smaller than
960/// `min(m,n)` and the matrix dimensions are large.
961///
962/// The accuracy of the method depends on the spectrum of `A`, the number of power iterations `niters`, the oversampling parameter `p` and the ratio between `p` and the dimensions of the matrix `A`.
963/// Larger values of oversampling `p` or more iterations `niters` may produce
964/// more accurate approximations, but also increase the run time of
965/// [`xgesvdr`].
966///
967/// Our recommendation is to use two iterations and set the oversampling to at least `2k`.
968/// Once the solver provides enough accuracy, adjust the values of `k` and `niters` for better performance.
969///
970/// Provide device and host workspace through `workspace`.
971/// Use [`xgesvdr_buffer_size`] to determine the required sizes for
972/// `workspace.device` and `workspace.host`.
973///
974/// If the reported `info` value is `-i`, the `i`th parameter is invalid.
975///
976/// Currently, [`xgesvdr`] supports only the default algorithm.
977///
978/// **Algorithms supported by [`xgesvdr`]**
979///
980/// | Algorithm | Notes |
981/// | --- | --- |
982/// | [`AlgorithmMode::Default`](crate::types::AlgorithmMode::Default) | Default algorithm. |
983///
984/// `gesvdr` also supports `n >= m`.
985///
986/// Returns `V`, not $V^{H}$.
987///
988/// List of input arguments for [`xgesvdr_buffer_size`] and [`xgesvdr`]:
989///
990/// The generic cuSOLVER routine separates matrix, singular-value, vector, and compute data
991/// types: `data_type_a` is the data type of matrix `A`, `data_type_s` is the
992/// data type of vector `S`, `data_type_u` is the data type of matrix `U`,
993/// `data_type_v` is the data type of matrix `V`, and `compute_type` is the
994/// operation's compute type.
995/// [`xgesvdr`] only supports the following four combinations.
996///
997/// **Valid combination of data type and compute type**
998///
999/// | **data_type_a** | **data_type_s** | **data_type_u** | **data_type_v** | **compute_type** | **Meaning** |
1000/// | --- | --- | --- | --- | --- | --- |
1001/// | [`DataType::F32`] | [`DataType::F32`] | [`DataType::F32`] | [`DataType::F32`] | [`DataType::F32`] | `SGESVDR` |
1002/// | [`DataType::F64`] | [`DataType::F64`] | [`DataType::F64`] | [`DataType::F64`] | [`DataType::F64`] | `DGESVDR` |
1003/// | [`DataType::ComplexF32`] | [`DataType::F32`] | [`DataType::ComplexF32`] | [`DataType::ComplexF32`] | [`DataType::ComplexF32`] | `CGESVDR` |
1004/// | [`DataType::ComplexF64`] | [`DataType::F64`] | [`DataType::ComplexF64`] | [`DataType::ComplexF64`] | [`DataType::ComplexF64`] | `ZGESVDR` |
1005///
1006/// # Errors
1007///
1008/// Returns an error if cuSOLVER has not been initialized, if the
1009/// matrix dimensions or leading dimensions are invalid, or if cuSOLVER reports
1010/// an internal failure.
1011pub fn xgesvdr<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TV: DataTypeLike>(
1012    ctx: &Context,
1013    params: &Params,
1014    job_u: TruncatedSvdMode,
1015    job_v: TruncatedSvdMode,
1016    m: usize,
1017    n: usize,
1018    k: usize,
1019    p: usize,
1020    niters: usize,
1021    a: MatrixMut<'_, TA>,
1022    s: &mut DeviceMemory<TS>,
1023    u: Option<MatrixMut<'_, TU>>,
1024    v: Option<MatrixMut<'_, TV>>,
1025    workspace: ByteWorkspaceMut<'_>,
1026    dev_info: &mut DeviceMemory<i32>,
1027) -> Result<()> {
1028    let a_type = TA::data_type();
1029    let s_type = TS::data_type();
1030    let u_type = TU::data_type();
1031    let v_type = TV::data_type();
1032    ctx.bind()?;
1033    validate_xgesvdr_inputs(
1034        m,
1035        n,
1036        k,
1037        p,
1038        niters,
1039        a.data.byte_len(),
1040        a.leading_dimension,
1041        a_type,
1042        s.byte_len(),
1043        s_type,
1044        job_u,
1045        matrix_mut_ref_parts(u.as_ref()).as_ref(),
1046        u_type,
1047        job_v,
1048        matrix_mut_ref_parts(v.as_ref()).as_ref(),
1049        v_type,
1050    )?;
1051    require_info_buffer(dev_info)?;
1052    let workspace_sizes = xgesvdr_buffer_size(
1053        ctx,
1054        params,
1055        job_u,
1056        job_v,
1057        m,
1058        n,
1059        k,
1060        p,
1061        niters,
1062        a.as_ref(),
1063        s,
1064        matrix_mut_ref_option(u.as_ref()),
1065        matrix_mut_ref_option(v.as_ref()),
1066    )?;
1067    require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
1068    require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
1069    let (u_ptr, ldu) = optional_x_truncated_u_mut_ptr(matrix_mut_parts(u), m, k, job_u, u_type)?;
1070    let (v_ptr, ldv) = optional_x_truncated_v_mut_ptr(matrix_mut_parts(v), n, k, job_v, v_type)?;
1071    unsafe {
1072        try_ffi!(sys::cusolverDnXgesvdr(
1073            ctx.as_raw(),
1074            params.as_raw(),
1075            job_u.as_raw(),
1076            job_v.as_raw(),
1077            to_i64(m, "m")?,
1078            to_i64(n, "n")?,
1079            to_i64(k, "k")?,
1080            to_i64(p, "p")?,
1081            to_i64(niters, "niters")?,
1082            a_type.into(),
1083            a.data.as_mut_ptr().cast(),
1084            to_i64(a.leading_dimension, "lda")?,
1085            s_type.into(),
1086            s.as_mut_ptr().cast(),
1087            u_type.into(),
1088            u_ptr.cast(),
1089            ldu,
1090            v_type.into(),
1091            v_ptr.cast(),
1092            ldv,
1093            a_type.into(),
1094            workspace.device.as_mut_ptr().cast(),
1095            workspace_sizes.device_bytes as _,
1096            workspace.host.as_mut_ptr().cast(),
1097            workspace_sizes.host_bytes as _,
1098            dev_info.as_mut_ptr().cast(),
1099        ))?;
1100    }
1101    Ok(())
1102}
1103
1104pub fn sgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
1105    ctx.bind()?;
1106    validate_gesvd_dims(m, n)?;
1107    let mut lwork = 0;
1108    unsafe {
1109        try_ffi!(sys::cusolverDnSgesvd_bufferSize(
1110            ctx.as_raw(),
1111            to_i32(m, "m")?,
1112            to_i32(n, "n")?,
1113            &raw mut lwork,
1114        ))?;
1115    }
1116    to_usize(lwork, "lwork")
1117}
1118
1119pub fn dgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
1120    ctx.bind()?;
1121    validate_gesvd_dims(m, n)?;
1122    let mut lwork = 0;
1123    unsafe {
1124        try_ffi!(sys::cusolverDnDgesvd_bufferSize(
1125            ctx.as_raw(),
1126            to_i32(m, "m")?,
1127            to_i32(n, "n")?,
1128            &raw mut lwork,
1129        ))?;
1130    }
1131    to_usize(lwork, "lwork")
1132}
1133
1134pub fn cgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
1135    ctx.bind()?;
1136    validate_gesvd_dims(m, n)?;
1137    let mut lwork = 0;
1138    unsafe {
1139        try_ffi!(sys::cusolverDnCgesvd_bufferSize(
1140            ctx.as_raw(),
1141            to_i32(m, "m")?,
1142            to_i32(n, "n")?,
1143            &raw mut lwork,
1144        ))?;
1145    }
1146    to_usize(lwork, "lwork")
1147}
1148
1149pub fn zgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
1150    ctx.bind()?;
1151    validate_gesvd_dims(m, n)?;
1152    let mut lwork = 0;
1153    unsafe {
1154        try_ffi!(sys::cusolverDnZgesvd_bufferSize(
1155            ctx.as_raw(),
1156            to_i32(m, "m")?,
1157            to_i32(n, "n")?,
1158            &raw mut lwork,
1159        ))?;
1160    }
1161    to_usize(lwork, "lwork")
1162}
1163
1164/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
1165///
1166/// The S and D data types are real valued single and double precision, respectively.
1167///
1168/// The C and Z data types are complex valued single and double precision, respectively.
1169///
1170/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix `A` and the corresponding left and/or right singular vectors.
1171/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an
1172/// $m \times n$ matrix which is zero except for its `min(m,n)` diagonal
1173/// elements, `U` is an $m \times m$ unitary matrix, and `V` is an
1174/// $n \times n$ unitary matrix.
1175/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
1176/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
1177///
1178/// Provide workspace through `workspace`.
1179/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
1180/// The workspace size in bytes is `size_of::<T>() * lwork`.
1181///
1182/// If the reported `dev_info` value is `-i`, the `i`th parameter is invalid. If `bdsqr` did not converge, `dev_info` specifies how many superdiagonals of an intermediate bidiagonal form did not converge to zero.
1183///
1184/// `rwork` is a real workspace buffer with length `min(m, n) - 1`.
1185/// If `dev_info > 0` and `rwork` is `Some(_)`, it contains the unconverged superdiagonal elements of an upper bidiagonal matrix.
1186/// This is slightly different from LAPACK which puts unconverged superdiagonal elements in `work` if type is `real`; in `rwork` if type is `complex`.
1187/// Pass `None` for `rwork` when the unconverged superdiagonal elements are not needed.
1188///
1189/// - `gesvd` only supports `m >= n`.
1190///
1191/// - Returns $V^{H}$, not `V`.
1192///
1193/// # Errors
1194///
1195/// Returns an error if cuSOLVER has not been initialized, if the
1196/// matrix dimensions or leading dimensions are invalid, if the current GPU
1197/// architecture is unsupported, or if cuSOLVER reports an internal failure.
1198pub fn sgesvd(
1199    ctx: &Context,
1200    job_u: SvdMode,
1201    job_vt: SvdMode,
1202    m: usize,
1203    n: usize,
1204    a: MatrixMut<'_, f32>,
1205    s: &mut DeviceMemory<f32>,
1206    u: Option<MatrixMut<'_, f32>>,
1207    vt: Option<MatrixMut<'_, f32>>,
1208    workspace: &mut DeviceMemory<f32>,
1209    rwork: Option<&mut DeviceMemory<f32>>,
1210    dev_info: &mut DeviceMemory<i32>,
1211) -> Result<()> {
1212    ctx.bind()?;
1213    validate_gesvd_inputs(
1214        m,
1215        n,
1216        a.data.len(),
1217        a.leading_dimension,
1218        s.len(),
1219        job_u,
1220        matrix_mut_ref_parts(u.as_ref()).as_ref(),
1221        job_vt,
1222        matrix_mut_ref_parts(vt.as_ref()).as_ref(),
1223    )?;
1224    require_info_buffer(dev_info)?;
1225    require_rwork_buffer(rwork.as_deref(), m, n)?;
1226    let lwork = sgesvd_buffer_size(ctx, m, n)?;
1227    require_workspace(workspace.len(), lwork)?;
1228    let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
1229    let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
1230    unsafe {
1231        try_ffi!(sys::cusolverDnSgesvd(
1232            ctx.as_raw(),
1233            job_u.as_raw(),
1234            job_vt.as_raw(),
1235            to_i32(m, "m")?,
1236            to_i32(n, "n")?,
1237            a.data.as_mut_ptr().cast(),
1238            to_i32(a.leading_dimension, "lda")?,
1239            s.as_mut_ptr().cast(),
1240            u_ptr.cast(),
1241            ldu,
1242            vt_ptr.cast(),
1243            ldvt,
1244            workspace.as_mut_ptr().cast(),
1245            to_i32(lwork, "lwork")?,
1246            rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
1247            dev_info.as_mut_ptr().cast(),
1248        ))?;
1249    }
1250    Ok(())
1251}
1252
1253/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
1254///
1255/// The S and D data types are real valued single and double precision, respectively.
1256///
1257/// The C and Z data types are complex valued single and double precision, respectively.
1258///
1259/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix `A` and the corresponding left and/or right singular vectors.
1260/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an
1261/// $m \times n$ matrix which is zero except for its `min(m,n)` diagonal
1262/// elements, `U` is an $m \times m$ unitary matrix, and `V` is an
1263/// $n \times n$ unitary matrix.
1264/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
1265/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
1266///
1267/// Provide workspace through `workspace`.
1268/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
1269/// The workspace size in bytes is `size_of::<T>() * lwork`.
1270///
1271/// If the reported `dev_info` value is `-i`, the `i`th parameter is invalid. If `bdsqr` did not converge, `dev_info` specifies how many superdiagonals of an intermediate bidiagonal form did not converge to zero.
1272///
1273/// `rwork` is a real workspace buffer with length `min(m, n) - 1`.
1274/// If `dev_info > 0` and `rwork` is `Some(_)`, it contains the unconverged superdiagonal elements of an upper bidiagonal matrix.
1275/// This is slightly different from LAPACK which puts unconverged superdiagonal elements in `work` if type is `real`; in `rwork` if type is `complex`.
1276/// Pass `None` for `rwork` when the unconverged superdiagonal elements are not needed.
1277///
1278/// - `gesvd` only supports `m >= n`.
1279///
1280/// - Returns $V^{H}$, not `V`.
1281///
1282/// # Errors
1283///
1284/// Returns an error if cuSOLVER has not been initialized, if the
1285/// matrix dimensions or leading dimensions are invalid, if the current GPU
1286/// architecture is unsupported, or if cuSOLVER reports an internal failure.
1287pub fn dgesvd(
1288    ctx: &Context,
1289    job_u: SvdMode,
1290    job_vt: SvdMode,
1291    m: usize,
1292    n: usize,
1293    a: MatrixMut<'_, f64>,
1294    s: &mut DeviceMemory<f64>,
1295    u: Option<MatrixMut<'_, f64>>,
1296    vt: Option<MatrixMut<'_, f64>>,
1297    workspace: &mut DeviceMemory<f64>,
1298    rwork: Option<&mut DeviceMemory<f64>>,
1299    dev_info: &mut DeviceMemory<i32>,
1300) -> Result<()> {
1301    ctx.bind()?;
1302    validate_gesvd_inputs(
1303        m,
1304        n,
1305        a.data.len(),
1306        a.leading_dimension,
1307        s.len(),
1308        job_u,
1309        matrix_mut_ref_parts(u.as_ref()).as_ref(),
1310        job_vt,
1311        matrix_mut_ref_parts(vt.as_ref()).as_ref(),
1312    )?;
1313    require_info_buffer(dev_info)?;
1314    require_rwork_buffer(rwork.as_deref(), m, n)?;
1315    let lwork = dgesvd_buffer_size(ctx, m, n)?;
1316    require_workspace(workspace.len(), lwork)?;
1317    let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
1318    let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
1319    unsafe {
1320        try_ffi!(sys::cusolverDnDgesvd(
1321            ctx.as_raw(),
1322            job_u.as_raw(),
1323            job_vt.as_raw(),
1324            to_i32(m, "m")?,
1325            to_i32(n, "n")?,
1326            a.data.as_mut_ptr().cast(),
1327            to_i32(a.leading_dimension, "lda")?,
1328            s.as_mut_ptr().cast(),
1329            u_ptr.cast(),
1330            ldu,
1331            vt_ptr.cast(),
1332            ldvt,
1333            workspace.as_mut_ptr().cast(),
1334            to_i32(lwork, "lwork")?,
1335            rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
1336            dev_info.as_mut_ptr().cast(),
1337        ))?;
1338    }
1339    Ok(())
1340}
1341
1342/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
1343///
1344/// The S and D data types are real valued single and double precision, respectively.
1345///
1346/// The C and Z data types are complex valued single and double precision, respectively.
1347///
1348/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix `A` and the corresponding left and/or right singular vectors.
1349/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an
1350/// $m \times n$ matrix which is zero except for its `min(m,n)` diagonal
1351/// elements, `U` is an $m \times m$ unitary matrix, and `V` is an
1352/// $n \times n$ unitary matrix.
1353/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
1354/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
1355///
1356/// Provide workspace through `workspace`.
1357/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
1358/// The workspace size in bytes is `size_of::<T>() * lwork`.
1359///
1360/// If the reported `dev_info` value is `-i`, the `i`th parameter is invalid. If `bdsqr` did not converge, `dev_info` specifies how many superdiagonals of an intermediate bidiagonal form did not converge to zero.
1361///
1362/// `rwork` is a real workspace buffer with length `min(m, n) - 1`.
1363/// If `dev_info > 0` and `rwork` is `Some(_)`, it contains the unconverged superdiagonal elements of an upper bidiagonal matrix.
1364/// This is slightly different from LAPACK which puts unconverged superdiagonal elements in `work` if type is `real`; in `rwork` if type is `complex`.
1365/// Pass `None` for `rwork` when the unconverged superdiagonal elements are not needed.
1366///
1367/// - `gesvd` only supports `m >= n`.
1368///
1369/// - Returns $V^{H}$, not `V`.
1370///
1371/// # Errors
1372///
1373/// Returns an error if cuSOLVER has not been initialized, if the
1374/// matrix dimensions or leading dimensions are invalid, if the current GPU
1375/// architecture is unsupported, or if cuSOLVER reports an internal failure.
1376pub fn cgesvd(
1377    ctx: &Context,
1378    job_u: SvdMode,
1379    job_vt: SvdMode,
1380    m: usize,
1381    n: usize,
1382    a: MatrixMut<'_, Complex32>,
1383    s: &mut DeviceMemory<f32>,
1384    u: Option<MatrixMut<'_, Complex32>>,
1385    vt: Option<MatrixMut<'_, Complex32>>,
1386    workspace: &mut DeviceMemory<Complex32>,
1387    rwork: Option<&mut DeviceMemory<f32>>,
1388    dev_info: &mut DeviceMemory<i32>,
1389) -> Result<()> {
1390    ctx.bind()?;
1391    validate_gesvd_inputs(
1392        m,
1393        n,
1394        a.data.len(),
1395        a.leading_dimension,
1396        s.len(),
1397        job_u,
1398        matrix_mut_ref_parts(u.as_ref()).as_ref(),
1399        job_vt,
1400        matrix_mut_ref_parts(vt.as_ref()).as_ref(),
1401    )?;
1402    require_info_buffer(dev_info)?;
1403    require_rwork_buffer(rwork.as_deref(), m, n)?;
1404    let lwork = cgesvd_buffer_size(ctx, m, n)?;
1405    require_workspace(workspace.len(), lwork)?;
1406    let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
1407    let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
1408    unsafe {
1409        try_ffi!(sys::cusolverDnCgesvd(
1410            ctx.as_raw(),
1411            job_u.as_raw(),
1412            job_vt.as_raw(),
1413            to_i32(m, "m")?,
1414            to_i32(n, "n")?,
1415            a.data.as_mut_ptr().cast(),
1416            to_i32(a.leading_dimension, "lda")?,
1417            s.as_mut_ptr().cast(),
1418            u_ptr.cast(),
1419            ldu,
1420            vt_ptr.cast(),
1421            ldvt,
1422            workspace.as_mut_ptr().cast(),
1423            to_i32(lwork, "lwork")?,
1424            rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
1425            dev_info.as_mut_ptr().cast(),
1426        ))?;
1427    }
1428    Ok(())
1429}
1430
1431/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
1432///
1433/// The S and D data types are real valued single and double precision, respectively.
1434///
1435/// The C and Z data types are complex valued single and double precision, respectively.
1436///
1437/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix `A` and the corresponding left and/or right singular vectors.
1438/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an
1439/// $m \times n$ matrix which is zero except for its `min(m,n)` diagonal
1440/// elements, `U` is an $m \times m$ unitary matrix, and `V` is an
1441/// $n \times n$ unitary matrix.
1442/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
1443/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
1444///
1445/// Provide workspace through `workspace`.
1446/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
1447/// The workspace size in bytes is `size_of::<T>() * lwork`.
1448///
1449/// If the reported `dev_info` value is `-i`, the `i`th parameter is invalid. If `bdsqr` did not converge, `dev_info` specifies how many superdiagonals of an intermediate bidiagonal form did not converge to zero.
1450///
1451/// `rwork` is a real workspace buffer with length `min(m, n) - 1`.
1452/// If `dev_info > 0` and `rwork` is `Some(_)`, it contains the unconverged superdiagonal elements of an upper bidiagonal matrix.
1453/// This is slightly different from LAPACK which puts unconverged superdiagonal elements in `work` if type is `real`; in `rwork` if type is `complex`.
1454/// Pass `None` for `rwork` when the unconverged superdiagonal elements are not needed.
1455///
1456/// - `gesvd` only supports `m >= n`.
1457///
1458/// - Returns $V^{H}$, not `V`.
1459///
1460/// # Errors
1461///
1462/// Returns an error if cuSOLVER has not been initialized, if the
1463/// matrix dimensions or leading dimensions are invalid, if the current GPU
1464/// architecture is unsupported, or if cuSOLVER reports an internal failure.
1465pub fn zgesvd(
1466    ctx: &Context,
1467    job_u: SvdMode,
1468    job_vt: SvdMode,
1469    m: usize,
1470    n: usize,
1471    a: MatrixMut<'_, Complex64>,
1472    s: &mut DeviceMemory<f64>,
1473    u: Option<MatrixMut<'_, Complex64>>,
1474    vt: Option<MatrixMut<'_, Complex64>>,
1475    workspace: &mut DeviceMemory<Complex64>,
1476    rwork: Option<&mut DeviceMemory<f64>>,
1477    dev_info: &mut DeviceMemory<i32>,
1478) -> Result<()> {
1479    ctx.bind()?;
1480    validate_gesvd_inputs(
1481        m,
1482        n,
1483        a.data.len(),
1484        a.leading_dimension,
1485        s.len(),
1486        job_u,
1487        matrix_mut_ref_parts(u.as_ref()).as_ref(),
1488        job_vt,
1489        matrix_mut_ref_parts(vt.as_ref()).as_ref(),
1490    )?;
1491    require_info_buffer(dev_info)?;
1492    require_rwork_buffer(rwork.as_deref(), m, n)?;
1493    let lwork = zgesvd_buffer_size(ctx, m, n)?;
1494    require_workspace(workspace.len(), lwork)?;
1495    let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
1496    let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
1497    unsafe {
1498        try_ffi!(sys::cusolverDnZgesvd(
1499            ctx.as_raw(),
1500            job_u.as_raw(),
1501            job_vt.as_raw(),
1502            to_i32(m, "m")?,
1503            to_i32(n, "n")?,
1504            a.data.as_mut_ptr().cast(),
1505            to_i32(a.leading_dimension, "lda")?,
1506            s.as_mut_ptr().cast(),
1507            u_ptr.cast(),
1508            ldu,
1509            vt_ptr.cast(),
1510            ldvt,
1511            workspace.as_mut_ptr().cast(),
1512            to_i32(lwork, "lwork")?,
1513            rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
1514            dev_info.as_mut_ptr().cast(),
1515        ))?;
1516    }
1517    Ok(())
1518}
1519
1520pub fn sgesvdj_buffer_size(
1521    ctx: &Context,
1522    jobz: EigenMode,
1523    econ: bool,
1524    m: usize,
1525    n: usize,
1526    a: MatrixRef<'_, f32>,
1527    s: &DeviceMemory<f32>,
1528    u: Option<MatrixRef<'_, f32>>,
1529    v: Option<MatrixRef<'_, f32>>,
1530    params: &GesvdjInfo,
1531) -> Result<usize> {
1532    ctx.bind()?;
1533    validate_gesvdj_inputs(
1534        m,
1535        n,
1536        a.data.len(),
1537        a.leading_dimension,
1538        s.len(),
1539        jobz,
1540        econ,
1541        matrix_ref_parts(u),
1542        matrix_ref_parts(v),
1543    )?;
1544    let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
1545    let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
1546    let mut lwork = 0;
1547    unsafe {
1548        try_ffi!(sys::cusolverDnSgesvdj_bufferSize(
1549            ctx.as_raw(),
1550            jobz.into(),
1551            i32::from(econ),
1552            to_i32(m, "m")?,
1553            to_i32(n, "n")?,
1554            a.data.as_ptr().cast(),
1555            to_i32(a.leading_dimension, "lda")?,
1556            s.as_ptr().cast(),
1557            u_ptr.cast(),
1558            ldu,
1559            v_ptr.cast(),
1560            ldv,
1561            &raw mut lwork,
1562            params.as_raw(),
1563        ))?;
1564    }
1565    to_usize(lwork, "lwork")
1566}
1567
1568pub fn dgesvdj_buffer_size(
1569    ctx: &Context,
1570    jobz: EigenMode,
1571    econ: bool,
1572    m: usize,
1573    n: usize,
1574    a: MatrixRef<'_, f64>,
1575    s: &DeviceMemory<f64>,
1576    u: Option<MatrixRef<'_, f64>>,
1577    v: Option<MatrixRef<'_, f64>>,
1578    params: &GesvdjInfo,
1579) -> Result<usize> {
1580    ctx.bind()?;
1581    validate_gesvdj_inputs(
1582        m,
1583        n,
1584        a.data.len(),
1585        a.leading_dimension,
1586        s.len(),
1587        jobz,
1588        econ,
1589        matrix_ref_parts(u),
1590        matrix_ref_parts(v),
1591    )?;
1592    let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
1593    let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
1594    let mut lwork = 0;
1595    unsafe {
1596        try_ffi!(sys::cusolverDnDgesvdj_bufferSize(
1597            ctx.as_raw(),
1598            jobz.into(),
1599            i32::from(econ),
1600            to_i32(m, "m")?,
1601            to_i32(n, "n")?,
1602            a.data.as_ptr().cast(),
1603            to_i32(a.leading_dimension, "lda")?,
1604            s.as_ptr().cast(),
1605            u_ptr.cast(),
1606            ldu,
1607            v_ptr.cast(),
1608            ldv,
1609            &raw mut lwork,
1610            params.as_raw(),
1611        ))?;
1612    }
1613    to_usize(lwork, "lwork")
1614}
1615
1616pub fn cgesvdj_buffer_size(
1617    ctx: &Context,
1618    jobz: EigenMode,
1619    econ: bool,
1620    m: usize,
1621    n: usize,
1622    a: MatrixRef<'_, Complex32>,
1623    s: &DeviceMemory<f32>,
1624    u: Option<MatrixRef<'_, Complex32>>,
1625    v: Option<MatrixRef<'_, Complex32>>,
1626    params: &GesvdjInfo,
1627) -> Result<usize> {
1628    ctx.bind()?;
1629    validate_gesvdj_inputs(
1630        m,
1631        n,
1632        a.data.len(),
1633        a.leading_dimension,
1634        s.len(),
1635        jobz,
1636        econ,
1637        matrix_ref_parts(u),
1638        matrix_ref_parts(v),
1639    )?;
1640    let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
1641    let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
1642    let mut lwork = 0;
1643    unsafe {
1644        try_ffi!(sys::cusolverDnCgesvdj_bufferSize(
1645            ctx.as_raw(),
1646            jobz.into(),
1647            i32::from(econ),
1648            to_i32(m, "m")?,
1649            to_i32(n, "n")?,
1650            a.data.as_ptr().cast(),
1651            to_i32(a.leading_dimension, "lda")?,
1652            s.as_ptr().cast(),
1653            u_ptr.cast(),
1654            ldu,
1655            v_ptr.cast(),
1656            ldv,
1657            &raw mut lwork,
1658            params.as_raw(),
1659        ))?;
1660    }
1661    to_usize(lwork, "lwork")
1662}
1663
1664pub fn zgesvdj_buffer_size(
1665    ctx: &Context,
1666    jobz: EigenMode,
1667    econ: bool,
1668    m: usize,
1669    n: usize,
1670    a: MatrixRef<'_, Complex64>,
1671    s: &DeviceMemory<f64>,
1672    u: Option<MatrixRef<'_, Complex64>>,
1673    v: Option<MatrixRef<'_, Complex64>>,
1674    params: &GesvdjInfo,
1675) -> Result<usize> {
1676    ctx.bind()?;
1677    validate_gesvdj_inputs(
1678        m,
1679        n,
1680        a.data.len(),
1681        a.leading_dimension,
1682        s.len(),
1683        jobz,
1684        econ,
1685        matrix_ref_parts(u),
1686        matrix_ref_parts(v),
1687    )?;
1688    let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
1689    let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
1690    let mut lwork = 0;
1691    unsafe {
1692        try_ffi!(sys::cusolverDnZgesvdj_bufferSize(
1693            ctx.as_raw(),
1694            jobz.into(),
1695            i32::from(econ),
1696            to_i32(m, "m")?,
1697            to_i32(n, "n")?,
1698            a.data.as_ptr().cast(),
1699            to_i32(a.leading_dimension, "lda")?,
1700            s.as_ptr().cast(),
1701            u_ptr.cast(),
1702            ldu,
1703            v_ptr.cast(),
1704            ldv,
1705            &raw mut lwork,
1706            params.as_raw(),
1707        ))?;
1708    }
1709    to_usize(lwork, "lwork")
1710}
1711
1712/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
1713///
1714/// The S and D data types are real valued single and double precision, respectively.
1715///
1716/// The C and Z data types are complex valued single and double precision, respectively.
1717///
1718/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix `A` and the corresponding left and/or right singular vectors.
1719/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an $m \times n$ matrix which is zero except for its `min(m,n)` diagonal elements, `U` is an $m \times m$ unitary matrix, and `V` is an $n \times n$ unitary matrix.
1720/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
1721/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
1722///
1723/// `gesvdj` computes the same decomposition as `gesvd`.
1724/// The difference is that `gesvd` uses a QR algorithm and `gesvdj` uses the Jacobi method.
1725/// The Jacobi method gives GPUs better parallelism on small and medium-size matrices.
1726/// Callers can configure `gesvdj` to target a chosen accuracy.
1727///
1728/// `gesvdj` iteratively generates a sequence of unitary matrices that transform `A` toward $A = U(S + E)V^{H}$, where `S` is diagonal and the diagonal of `E` is zero.
1729///
1730/// During the iterations, the Frobenius norm of `E` decreases monotonically.
1731/// As `E` goes down to zero, `S` is the set of singular values.
1732/// In practice, the Jacobi method stops when the off-diagonal residual is below the configured tolerance `eps`.
1733/// If the real residual norm is computed, it differs from ${\\|{E}\\|}\_{F}$ by roundoff errors of order $N = max(m, n)$ while still meeting the standard SVD accuracy expectation.
1734///
1735/// $O(N)$ is typically $N$, but the constant depends on the number of sweeps, which gives an upper roundoff error bound of $sweeps \cdot N$.
1736///
1737/// `gesvdj` has two parameters to control the accuracy.
1738/// The first parameter is the tolerance (`eps`).
1739/// The default value is machine accuracy, but [`GesvdjInfo::set_tolerance`] can set an a priori tolerance.
1740/// The maximum-sweep parameter is the maximum number of sweeps, which controls the number of Jacobi iterations.
1741/// The default value is 100, but [`GesvdjInfo::set_max_sweeps`] can set a different bound.
1742/// Experiments show that 15 sweeps are enough to converge to machine accuracy.
1743/// `gesvdj` stops when either the tolerance or the maximum number of sweeps is reached.
1744///
1745/// The Jacobi method has quadratic convergence, so the accuracy is not proportional to the number of sweeps.
1746/// To guarantee a target accuracy, configure only the tolerance.
1747///
1748/// Provide workspace through `workspace`.
1749/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
1750/// The workspace size in bytes is `size_of::<T>() * lwork`.
1751///
1752/// If the reported `info` value is `-i`, the `i`th parameter is invalid.
1753/// If `info == min(m, n) + 1`, `gesvdj` did not converge within the given tolerance and maximum sweep count.
1754///
1755/// If the tolerance is too small, `gesvdj` may not converge.
1756/// Use a tolerance no smaller than machine accuracy.
1757///
1758/// - `gesvdj` supports any combination of `m` and `n`.
1759///
1760/// - Returns `V`, not $V^{H}$.
1761///   This is different from `gesvd`.
1762///
1763/// # Errors
1764///
1765/// Returns an error if cuSOLVER has not been initialized, if the
1766/// matrix dimensions, leading dimensions, or vector-computation mode are
1767/// invalid, or if cuSOLVER reports an internal failure.
1768pub fn sgesvdj(
1769    ctx: &Context,
1770    jobz: EigenMode,
1771    econ: bool,
1772    m: usize,
1773    n: usize,
1774    a: MatrixMut<'_, f32>,
1775    s: &mut DeviceMemory<f32>,
1776    u: Option<MatrixMut<'_, f32>>,
1777    v: Option<MatrixMut<'_, f32>>,
1778    workspace: &mut DeviceMemory<f32>,
1779    dev_info: &mut DeviceMemory<i32>,
1780    params: &GesvdjInfo,
1781) -> Result<()> {
1782    ctx.bind()?;
1783    validate_gesvdj_inputs(
1784        m,
1785        n,
1786        a.data.len(),
1787        a.leading_dimension,
1788        s.len(),
1789        jobz,
1790        econ,
1791        matrix_mut_ref_parts(u.as_ref()),
1792        matrix_mut_ref_parts(v.as_ref()),
1793    )?;
1794    require_info_buffer(dev_info)?;
1795    let lwork = sgesvdj_buffer_size(
1796        ctx,
1797        jobz,
1798        econ,
1799        m,
1800        n,
1801        a.as_ref(),
1802        s,
1803        matrix_mut_ref_option(u.as_ref()),
1804        matrix_mut_ref_option(v.as_ref()),
1805        params,
1806    )?;
1807    require_workspace(workspace.len(), lwork)?;
1808    let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
1809    let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
1810    unsafe {
1811        try_ffi!(sys::cusolverDnSgesvdj(
1812            ctx.as_raw(),
1813            jobz.into(),
1814            i32::from(econ),
1815            to_i32(m, "m")?,
1816            to_i32(n, "n")?,
1817            a.data.as_mut_ptr().cast(),
1818            to_i32(a.leading_dimension, "lda")?,
1819            s.as_mut_ptr().cast(),
1820            u_ptr.cast(),
1821            ldu,
1822            v_ptr.cast(),
1823            ldv,
1824            workspace.as_mut_ptr().cast(),
1825            to_i32(lwork, "lwork")?,
1826            dev_info.as_mut_ptr().cast(),
1827            params.as_raw(),
1828        ))?;
1829    }
1830    Ok(())
1831}
1832
1833/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
1834///
1835/// The S and D data types are real valued single and double precision, respectively.
1836///
1837/// The C and Z data types are complex valued single and double precision, respectively.
1838///
1839/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix `A` and the corresponding left and/or right singular vectors.
1840/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an $m \times n$ matrix which is zero except for its `min(m,n)` diagonal elements, `U` is an $m \times m$ unitary matrix, and `V` is an $n \times n$ unitary matrix.
1841/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
1842/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
1843///
1844/// `gesvdj` computes the same decomposition as `gesvd`.
1845/// The difference is that `gesvd` uses a QR algorithm and `gesvdj` uses the Jacobi method.
1846/// The Jacobi method gives GPUs better parallelism on small and medium-size matrices.
1847/// Callers can configure `gesvdj` to target a chosen accuracy.
1848///
1849/// `gesvdj` iteratively generates a sequence of unitary matrices that transform `A` toward $A = U(S + E)V^{H}$, where `S` is diagonal and the diagonal of `E` is zero.
1850///
1851/// During the iterations, the Frobenius norm of `E` decreases monotonically.
1852/// As `E` goes down to zero, `S` is the set of singular values.
1853/// In practice, the Jacobi method stops when the off-diagonal residual is below the configured tolerance `eps`.
1854/// If the real residual norm is computed, it differs from ${\\|{E}\\|}\_{F}$ by roundoff errors of order $N = max(m, n)$ while still meeting the standard SVD accuracy expectation.
1855///
1856/// $O(N)$ is typically $N$, but the constant depends on the number of sweeps, which gives an upper roundoff error bound of $sweeps \cdot N$.
1857///
1858/// `gesvdj` has two parameters to control the accuracy.
1859/// The first parameter is the tolerance (`eps`).
1860/// The default value is machine accuracy, but [`GesvdjInfo::set_tolerance`] can set an a priori tolerance.
1861/// The maximum-sweep parameter is the maximum number of sweeps, which controls the number of Jacobi iterations.
1862/// The default value is 100, but [`GesvdjInfo::set_max_sweeps`] can set a different bound.
1863/// Experiments show that 15 sweeps are enough to converge to machine accuracy.
1864/// `gesvdj` stops when either the tolerance or the maximum number of sweeps is reached.
1865///
1866/// The Jacobi method has quadratic convergence, so the accuracy is not proportional to the number of sweeps.
1867/// To guarantee a target accuracy, configure only the tolerance.
1868///
1869/// Provide workspace through `workspace`.
1870/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
1871/// The workspace size in bytes is `size_of::<T>() * lwork`.
1872///
1873/// If the reported `info` value is `-i`, the `i`th parameter is invalid.
1874/// If `info == min(m, n) + 1`, `gesvdj` did not converge within the given tolerance and maximum sweep count.
1875///
1876/// If the tolerance is too small, `gesvdj` may not converge.
1877/// Use a tolerance no smaller than machine accuracy.
1878///
1879/// - `gesvdj` supports any combination of `m` and `n`.
1880///
1881/// - Returns `V`, not $V^{H}$.
1882///   This is different from `gesvd`.
1883///
1884/// # Errors
1885///
1886/// Returns an error if cuSOLVER has not been initialized, if the
1887/// matrix dimensions, leading dimensions, or vector-computation mode are
1888/// invalid, or if cuSOLVER reports an internal failure.
1889pub fn dgesvdj(
1890    ctx: &Context,
1891    jobz: EigenMode,
1892    econ: bool,
1893    m: usize,
1894    n: usize,
1895    a: MatrixMut<'_, f64>,
1896    s: &mut DeviceMemory<f64>,
1897    u: Option<MatrixMut<'_, f64>>,
1898    v: Option<MatrixMut<'_, f64>>,
1899    workspace: &mut DeviceMemory<f64>,
1900    dev_info: &mut DeviceMemory<i32>,
1901    params: &GesvdjInfo,
1902) -> Result<()> {
1903    ctx.bind()?;
1904    validate_gesvdj_inputs(
1905        m,
1906        n,
1907        a.data.len(),
1908        a.leading_dimension,
1909        s.len(),
1910        jobz,
1911        econ,
1912        matrix_mut_ref_parts(u.as_ref()),
1913        matrix_mut_ref_parts(v.as_ref()),
1914    )?;
1915    require_info_buffer(dev_info)?;
1916    let lwork = dgesvdj_buffer_size(
1917        ctx,
1918        jobz,
1919        econ,
1920        m,
1921        n,
1922        a.as_ref(),
1923        s,
1924        matrix_mut_ref_option(u.as_ref()),
1925        matrix_mut_ref_option(v.as_ref()),
1926        params,
1927    )?;
1928    require_workspace(workspace.len(), lwork)?;
1929    let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
1930    let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
1931    unsafe {
1932        try_ffi!(sys::cusolverDnDgesvdj(
1933            ctx.as_raw(),
1934            jobz.into(),
1935            i32::from(econ),
1936            to_i32(m, "m")?,
1937            to_i32(n, "n")?,
1938            a.data.as_mut_ptr().cast(),
1939            to_i32(a.leading_dimension, "lda")?,
1940            s.as_mut_ptr().cast(),
1941            u_ptr.cast(),
1942            ldu,
1943            v_ptr.cast(),
1944            ldv,
1945            workspace.as_mut_ptr().cast(),
1946            to_i32(lwork, "lwork")?,
1947            dev_info.as_mut_ptr().cast(),
1948            params.as_raw(),
1949        ))?;
1950    }
1951    Ok(())
1952}
1953
1954/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
1955///
1956/// The S and D data types are real valued single and double precision, respectively.
1957///
1958/// The C and Z data types are complex valued single and double precision, respectively.
1959///
1960/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix `A` and the corresponding left and/or right singular vectors.
1961/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an $m \times n$ matrix which is zero except for its `min(m,n)` diagonal elements, `U` is an $m \times m$ unitary matrix, and `V` is an $n \times n$ unitary matrix.
1962/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
1963/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
1964///
1965/// `gesvdj` computes the same decomposition as `gesvd`.
1966/// The difference is that `gesvd` uses a QR algorithm and `gesvdj` uses the Jacobi method.
1967/// The Jacobi method gives GPUs better parallelism on small and medium-size matrices.
1968/// Callers can configure `gesvdj` to target a chosen accuracy.
1969///
1970/// `gesvdj` iteratively generates a sequence of unitary matrices that transform `A` toward $A = U(S + E)V^{H}$, where `S` is diagonal and the diagonal of `E` is zero.
1971///
1972/// During the iterations, the Frobenius norm of `E` decreases monotonically.
1973/// As `E` goes down to zero, `S` is the set of singular values.
1974/// In practice, the Jacobi method stops when the off-diagonal residual is below the configured tolerance `eps`.
1975/// If the real residual norm is computed, it differs from ${\\|{E}\\|}\_{F}$ by roundoff errors of order $N = max(m, n)$ while still meeting the standard SVD accuracy expectation.
1976///
1977/// $O(N)$ is typically $N$, but the constant depends on the number of sweeps, which gives an upper roundoff error bound of $sweeps \cdot N$.
1978///
1979/// `gesvdj` has two parameters to control the accuracy.
1980/// The first parameter is the tolerance (`eps`).
1981/// The default value is machine accuracy, but [`GesvdjInfo::set_tolerance`] can set an a priori tolerance.
1982/// The maximum-sweep parameter is the maximum number of sweeps, which controls the number of Jacobi iterations.
1983/// The default value is 100, but [`GesvdjInfo::set_max_sweeps`] can set a different bound.
1984/// Experiments show that 15 sweeps are enough to converge to machine accuracy.
1985/// `gesvdj` stops when either the tolerance or the maximum number of sweeps is reached.
1986///
1987/// The Jacobi method has quadratic convergence, so the accuracy is not proportional to the number of sweeps.
1988/// To guarantee a target accuracy, configure only the tolerance.
1989///
1990/// Provide workspace through `workspace`.
1991/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
1992/// The workspace size in bytes is `size_of::<T>() * lwork`.
1993///
1994/// If the reported `info` value is `-i`, the `i`th parameter is invalid.
1995/// If `info == min(m, n) + 1`, `gesvdj` did not converge within the given tolerance and maximum sweep count.
1996///
1997/// If the tolerance is too small, `gesvdj` may not converge.
1998/// Use a tolerance no smaller than machine accuracy.
1999///
2000/// - `gesvdj` supports any combination of `m` and `n`.
2001///
2002/// - Returns `V`, not $V^{H}$.
2003///   This is different from `gesvd`.
2004///
2005/// # Errors
2006///
2007/// Returns an error if cuSOLVER has not been initialized, if the
2008/// matrix dimensions, leading dimensions, or vector-computation mode are
2009/// invalid, or if cuSOLVER reports an internal failure.
2010pub fn cgesvdj(
2011    ctx: &Context,
2012    jobz: EigenMode,
2013    econ: bool,
2014    m: usize,
2015    n: usize,
2016    a: MatrixMut<'_, Complex32>,
2017    s: &mut DeviceMemory<f32>,
2018    u: Option<MatrixMut<'_, Complex32>>,
2019    v: Option<MatrixMut<'_, Complex32>>,
2020    workspace: &mut DeviceMemory<Complex32>,
2021    dev_info: &mut DeviceMemory<i32>,
2022    params: &GesvdjInfo,
2023) -> Result<()> {
2024    ctx.bind()?;
2025    validate_gesvdj_inputs(
2026        m,
2027        n,
2028        a.data.len(),
2029        a.leading_dimension,
2030        s.len(),
2031        jobz,
2032        econ,
2033        matrix_mut_ref_parts(u.as_ref()),
2034        matrix_mut_ref_parts(v.as_ref()),
2035    )?;
2036    require_info_buffer(dev_info)?;
2037    let lwork = cgesvdj_buffer_size(
2038        ctx,
2039        jobz,
2040        econ,
2041        m,
2042        n,
2043        a.as_ref(),
2044        s,
2045        matrix_mut_ref_option(u.as_ref()),
2046        matrix_mut_ref_option(v.as_ref()),
2047        params,
2048    )?;
2049    require_workspace(workspace.len(), lwork)?;
2050    let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
2051    let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
2052    unsafe {
2053        try_ffi!(sys::cusolverDnCgesvdj(
2054            ctx.as_raw(),
2055            jobz.into(),
2056            i32::from(econ),
2057            to_i32(m, "m")?,
2058            to_i32(n, "n")?,
2059            a.data.as_mut_ptr().cast(),
2060            to_i32(a.leading_dimension, "lda")?,
2061            s.as_mut_ptr().cast(),
2062            u_ptr.cast(),
2063            ldu,
2064            v_ptr.cast(),
2065            ldv,
2066            workspace.as_mut_ptr().cast(),
2067            to_i32(lwork, "lwork")?,
2068            dev_info.as_mut_ptr().cast(),
2069            params.as_raw(),
2070        ))?;
2071    }
2072    Ok(())
2073}
2074
2075/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
2076///
2077/// The S and D data types are real valued single and double precision, respectively.
2078///
2079/// The C and Z data types are complex valued single and double precision, respectively.
2080///
2081/// Computes the singular value decomposition (SVD) of an $m \times n$ matrix `A` and the corresponding left and/or right singular vectors.
2082/// The SVD is written as $A = U \Sigma V^{H}$, where `Σ` is an $m \times n$ matrix which is zero except for its `min(m,n)` diagonal elements, `U` is an $m \times m$ unitary matrix, and `V` is an $n \times n$ unitary matrix.
2083/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
2084/// The first `min(m,n)` columns of `U` and `V` are the left and right singular vectors of `A`.
2085///
2086/// `gesvdj` computes the same decomposition as `gesvd`.
2087/// The difference is that `gesvd` uses a QR algorithm and `gesvdj` uses the Jacobi method.
2088/// The Jacobi method gives GPUs better parallelism on small and medium-size matrices.
2089/// Callers can configure `gesvdj` to target a chosen accuracy.
2090///
2091/// `gesvdj` iteratively generates a sequence of unitary matrices that transform `A` toward $A = U(S + E)V^{H}$, where `S` is diagonal and the diagonal of `E` is zero.
2092///
2093/// During the iterations, the Frobenius norm of `E` decreases monotonically.
2094/// As `E` goes down to zero, `S` is the set of singular values.
2095/// In practice, the Jacobi method stops when the off-diagonal residual is below the configured tolerance `eps`.
2096/// If the real residual norm is computed, it differs from ${\\|{E}\\|}\_{F}$ by roundoff errors of order $N = max(m, n)$ while still meeting the standard SVD accuracy expectation.
2097///
2098/// $O(N)$ is typically $N$, but the constant depends on the number of sweeps, which gives an upper roundoff error bound of $sweeps \cdot N$.
2099///
2100/// `gesvdj` has two parameters to control the accuracy.
2101/// The first parameter is the tolerance (`eps`).
2102/// The default value is machine accuracy, but [`GesvdjInfo::set_tolerance`] can set an a priori tolerance.
2103/// The maximum-sweep parameter is the maximum number of sweeps, which controls the number of Jacobi iterations.
2104/// The default value is 100, but [`GesvdjInfo::set_max_sweeps`] can set a different bound.
2105/// Experiments show that 15 sweeps are enough to converge to machine accuracy.
2106/// `gesvdj` stops when either the tolerance or the maximum number of sweeps is reached.
2107///
2108/// The Jacobi method has quadratic convergence, so the accuracy is not proportional to the number of sweeps.
2109/// To guarantee a target accuracy, configure only the tolerance.
2110///
2111/// Provide workspace through `workspace`.
2112/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
2113/// The workspace size in bytes is `size_of::<T>() * lwork`.
2114///
2115/// If the reported `info` value is `-i`, the `i`th parameter is invalid.
2116/// If `info == min(m, n) + 1`, `gesvdj` did not converge within the given tolerance and maximum sweep count.
2117///
2118/// If the tolerance is too small, `gesvdj` may not converge.
2119/// Use a tolerance no smaller than machine accuracy.
2120///
2121/// - `gesvdj` supports any combination of `m` and `n`.
2122///
2123/// - Returns `V`, not $V^{H}$.
2124///   This is different from `gesvd`.
2125///
2126/// # Errors
2127///
2128/// Returns an error if cuSOLVER has not been initialized, if the
2129/// matrix dimensions, leading dimensions, or vector-computation mode are
2130/// invalid, or if cuSOLVER reports an internal failure.
2131pub fn zgesvdj(
2132    ctx: &Context,
2133    jobz: EigenMode,
2134    econ: bool,
2135    m: usize,
2136    n: usize,
2137    a: MatrixMut<'_, Complex64>,
2138    s: &mut DeviceMemory<f64>,
2139    u: Option<MatrixMut<'_, Complex64>>,
2140    v: Option<MatrixMut<'_, Complex64>>,
2141    workspace: &mut DeviceMemory<Complex64>,
2142    dev_info: &mut DeviceMemory<i32>,
2143    params: &GesvdjInfo,
2144) -> Result<()> {
2145    ctx.bind()?;
2146    validate_gesvdj_inputs(
2147        m,
2148        n,
2149        a.data.len(),
2150        a.leading_dimension,
2151        s.len(),
2152        jobz,
2153        econ,
2154        matrix_mut_ref_parts(u.as_ref()),
2155        matrix_mut_ref_parts(v.as_ref()),
2156    )?;
2157    require_info_buffer(dev_info)?;
2158    let lwork = zgesvdj_buffer_size(
2159        ctx,
2160        jobz,
2161        econ,
2162        m,
2163        n,
2164        a.as_ref(),
2165        s,
2166        matrix_mut_ref_option(u.as_ref()),
2167        matrix_mut_ref_option(v.as_ref()),
2168        params,
2169    )?;
2170    require_workspace(workspace.len(), lwork)?;
2171    let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
2172    let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
2173    unsafe {
2174        try_ffi!(sys::cusolverDnZgesvdj(
2175            ctx.as_raw(),
2176            jobz.into(),
2177            i32::from(econ),
2178            to_i32(m, "m")?,
2179            to_i32(n, "n")?,
2180            a.data.as_mut_ptr().cast(),
2181            to_i32(a.leading_dimension, "lda")?,
2182            s.as_mut_ptr().cast(),
2183            u_ptr.cast(),
2184            ldu,
2185            v_ptr.cast(),
2186            ldv,
2187            workspace.as_mut_ptr().cast(),
2188            to_i32(lwork, "lwork")?,
2189            dev_info.as_mut_ptr().cast(),
2190            params.as_raw(),
2191        ))?;
2192    }
2193    Ok(())
2194}
2195
2196pub fn sgesvdj_batched_buffer_size(
2197    ctx: &Context,
2198    jobz: EigenMode,
2199    m: usize,
2200    n: usize,
2201    a: MatrixRef<'_, f32>,
2202    s: &DeviceMemory<f32>,
2203    u: Option<MatrixRef<'_, f32>>,
2204    v: Option<MatrixRef<'_, f32>>,
2205    params: &GesvdjInfo,
2206    batch_size: usize,
2207) -> Result<usize> {
2208    ctx.bind()?;
2209    validate_gesvdj_batched_inputs(
2210        m,
2211        n,
2212        a.data.len(),
2213        a.leading_dimension,
2214        s.len(),
2215        jobz,
2216        matrix_ref_parts(u),
2217        matrix_ref_parts(v),
2218        batch_size,
2219    )?;
2220    let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
2221    let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
2222    let mut lwork = 0;
2223    unsafe {
2224        try_ffi!(sys::cusolverDnSgesvdjBatched_bufferSize(
2225            ctx.as_raw(),
2226            jobz.into(),
2227            to_i32(m, "m")?,
2228            to_i32(n, "n")?,
2229            a.data.as_ptr().cast(),
2230            to_i32(a.leading_dimension, "lda")?,
2231            s.as_ptr().cast(),
2232            u_ptr.cast(),
2233            ldu,
2234            v_ptr.cast(),
2235            ldv,
2236            &raw mut lwork,
2237            params.as_raw(),
2238            to_i32(batch_size, "batch_size")?,
2239        ))?;
2240    }
2241    to_usize(lwork, "lwork")
2242}
2243
2244pub fn dgesvdj_batched_buffer_size(
2245    ctx: &Context,
2246    jobz: EigenMode,
2247    m: usize,
2248    n: usize,
2249    a: MatrixRef<'_, f64>,
2250    s: &DeviceMemory<f64>,
2251    u: Option<MatrixRef<'_, f64>>,
2252    v: Option<MatrixRef<'_, f64>>,
2253    params: &GesvdjInfo,
2254    batch_size: usize,
2255) -> Result<usize> {
2256    ctx.bind()?;
2257    validate_gesvdj_batched_inputs(
2258        m,
2259        n,
2260        a.data.len(),
2261        a.leading_dimension,
2262        s.len(),
2263        jobz,
2264        matrix_ref_parts(u),
2265        matrix_ref_parts(v),
2266        batch_size,
2267    )?;
2268    let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
2269    let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
2270    let mut lwork = 0;
2271    unsafe {
2272        try_ffi!(sys::cusolverDnDgesvdjBatched_bufferSize(
2273            ctx.as_raw(),
2274            jobz.into(),
2275            to_i32(m, "m")?,
2276            to_i32(n, "n")?,
2277            a.data.as_ptr().cast(),
2278            to_i32(a.leading_dimension, "lda")?,
2279            s.as_ptr().cast(),
2280            u_ptr.cast(),
2281            ldu,
2282            v_ptr.cast(),
2283            ldv,
2284            &raw mut lwork,
2285            params.as_raw(),
2286            to_i32(batch_size, "batch_size")?,
2287        ))?;
2288    }
2289    to_usize(lwork, "lwork")
2290}
2291
2292pub fn cgesvdj_batched_buffer_size(
2293    ctx: &Context,
2294    jobz: EigenMode,
2295    m: usize,
2296    n: usize,
2297    a: MatrixRef<'_, Complex32>,
2298    s: &DeviceMemory<f32>,
2299    u: Option<MatrixRef<'_, Complex32>>,
2300    v: Option<MatrixRef<'_, Complex32>>,
2301    params: &GesvdjInfo,
2302    batch_size: usize,
2303) -> Result<usize> {
2304    ctx.bind()?;
2305    validate_gesvdj_batched_inputs(
2306        m,
2307        n,
2308        a.data.len(),
2309        a.leading_dimension,
2310        s.len(),
2311        jobz,
2312        matrix_ref_parts(u),
2313        matrix_ref_parts(v),
2314        batch_size,
2315    )?;
2316    let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
2317    let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
2318    let mut lwork = 0;
2319    unsafe {
2320        try_ffi!(sys::cusolverDnCgesvdjBatched_bufferSize(
2321            ctx.as_raw(),
2322            jobz.into(),
2323            to_i32(m, "m")?,
2324            to_i32(n, "n")?,
2325            a.data.as_ptr().cast(),
2326            to_i32(a.leading_dimension, "lda")?,
2327            s.as_ptr().cast(),
2328            u_ptr.cast(),
2329            ldu,
2330            v_ptr.cast(),
2331            ldv,
2332            &raw mut lwork,
2333            params.as_raw(),
2334            to_i32(batch_size, "batch_size")?,
2335        ))?;
2336    }
2337    to_usize(lwork, "lwork")
2338}
2339
2340pub fn zgesvdj_batched_buffer_size(
2341    ctx: &Context,
2342    jobz: EigenMode,
2343    m: usize,
2344    n: usize,
2345    a: MatrixRef<'_, Complex64>,
2346    s: &DeviceMemory<f64>,
2347    u: Option<MatrixRef<'_, Complex64>>,
2348    v: Option<MatrixRef<'_, Complex64>>,
2349    params: &GesvdjInfo,
2350    batch_size: usize,
2351) -> Result<usize> {
2352    ctx.bind()?;
2353    validate_gesvdj_batched_inputs(
2354        m,
2355        n,
2356        a.data.len(),
2357        a.leading_dimension,
2358        s.len(),
2359        jobz,
2360        matrix_ref_parts(u),
2361        matrix_ref_parts(v),
2362        batch_size,
2363    )?;
2364    let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
2365    let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
2366    let mut lwork = 0;
2367    unsafe {
2368        try_ffi!(sys::cusolverDnZgesvdjBatched_bufferSize(
2369            ctx.as_raw(),
2370            jobz.into(),
2371            to_i32(m, "m")?,
2372            to_i32(n, "n")?,
2373            a.data.as_ptr().cast(),
2374            to_i32(a.leading_dimension, "lda")?,
2375            s.as_ptr().cast(),
2376            u_ptr.cast(),
2377            ldu,
2378            v_ptr.cast(),
2379            ldv,
2380            &raw mut lwork,
2381            params.as_raw(),
2382            to_i32(batch_size, "batch_size")?,
2383        ))?;
2384    }
2385    to_usize(lwork, "lwork")
2386}
2387
2388/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
2389///
2390/// The S and D data types are real valued single and double precision, respectively.
2391///
2392/// The C and Z data types are complex valued single and double precision, respectively.
2393///
2394/// Computes singular values and singular vectors of a sequence of general $m \times n$ matrices
2395///
2396/// where $\Sigma\_{j}$ is a real $m \times n$ diagonal matrix which is zero except for its `min(m,n)` diagonal elements. $U\_{j}$ (left singular vectors) is an $m \times m$ unitary matrix and $V\_{j}$ (right singular vectors) is a $n \times n$ unitary matrix.
2397/// The diagonal elements of $\Sigma\_{j}$ are the singular values of $A\_{j}$ in either descending order or non-sorting order.
2398///
2399/// `gesvdjBatched` performs `gesvdj` on each matrix.
2400/// It requires that all matrices are of the same size `m,n` no greater than 32 and are packed contiguously,
2401///
2402/// Each matrix is column-major with leading dimension `lda`, so the formula for random access is $A\_{k}\operatorname{(i,j)} = {A\lbrack\ i\ +\ lda\cdot j\ +\ lda\cdot n\cdot k\rbrack}$.
2403///
2404/// The `S` parameter also contains the singular values of each matrix contiguously,
2405///
2406/// The formula for random access of `S` is $S\_{k}\operatorname{(j)} = {S\lbrack\ j\ +\ min(m,n)\cdot k\rbrack}$.
2407///
2408/// Except for tolerance and maximum sweeps, `gesvdjBatched` can either sort the singular values in descending order (default) or choose as-is (without sorting) with [`GesvdjInfo::set_sort_eigenvalues`].
2409/// If several tiny matrices are packed into diagonal blocks of one matrix, the non-sorting option can separate the singular values of those tiny matrices.
2410///
2411/// `gesvdjBatched` cannot report residual and executed sweeps through [`GesvdjInfo::residual`] and [`GesvdjInfo::executed_sweeps`].
2412/// Calling either accessor returns [`Status::NotSupported`].
2413/// Compute the residual explicitly when needed.
2414///
2415/// Provide workspace through `workspace`.
2416/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
2417/// The workspace size in bytes is `size_of::<T>() * lwork`.
2418///
2419/// `dev_info` has one entry per batch item.
2420/// If the call returns [`Status::InvalidValue`], `dev_info[0] == -i` indicates that the `i`th parameter is invalid.
2421/// Otherwise, `dev_info[i] == min(m, n) + 1` indicates that `gesvdjBatched` did not converge on the `i`th matrix within the given tolerance and maximum sweep count.
2422///
2423/// # Errors
2424///
2425/// Returns an error if cuSOLVER has not been initialized, if the
2426/// matrix dimensions, leading dimensions, vector-computation mode, or batch
2427/// size are invalid, or if cuSOLVER reports an internal failure.
2428pub fn sgesvdj_batched(
2429    ctx: &Context,
2430    jobz: EigenMode,
2431    m: usize,
2432    n: usize,
2433    a: MatrixMut<'_, f32>,
2434    s: &mut DeviceMemory<f32>,
2435    u: Option<MatrixMut<'_, f32>>,
2436    v: Option<MatrixMut<'_, f32>>,
2437    workspace: &mut DeviceMemory<f32>,
2438    dev_info: &mut DeviceMemory<i32>,
2439    params: &GesvdjInfo,
2440    batch_size: usize,
2441) -> Result<()> {
2442    ctx.bind()?;
2443    validate_gesvdj_batched_inputs(
2444        m,
2445        n,
2446        a.data.len(),
2447        a.leading_dimension,
2448        s.len(),
2449        jobz,
2450        matrix_mut_ref_parts(u.as_ref()),
2451        matrix_mut_ref_parts(v.as_ref()),
2452        batch_size,
2453    )?;
2454    require_info_buffer_len(dev_info, batch_size)?;
2455    let lwork = sgesvdj_batched_buffer_size(
2456        ctx,
2457        jobz,
2458        m,
2459        n,
2460        a.as_ref(),
2461        s,
2462        matrix_mut_ref_option(u.as_ref()),
2463        matrix_mut_ref_option(v.as_ref()),
2464        params,
2465        batch_size,
2466    )?;
2467    require_workspace(workspace.len(), lwork)?;
2468    let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
2469    let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
2470    unsafe {
2471        try_ffi!(sys::cusolverDnSgesvdjBatched(
2472            ctx.as_raw(),
2473            jobz.into(),
2474            to_i32(m, "m")?,
2475            to_i32(n, "n")?,
2476            a.data.as_mut_ptr().cast(),
2477            to_i32(a.leading_dimension, "lda")?,
2478            s.as_mut_ptr().cast(),
2479            u_ptr.cast(),
2480            ldu,
2481            v_ptr.cast(),
2482            ldv,
2483            workspace.as_mut_ptr().cast(),
2484            to_i32(lwork, "lwork")?,
2485            dev_info.as_mut_ptr().cast(),
2486            params.as_raw(),
2487            to_i32(batch_size, "batch_size")?,
2488        ))?;
2489    }
2490    Ok(())
2491}
2492
2493/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
2494///
2495/// The S and D data types are real valued single and double precision, respectively.
2496///
2497/// The C and Z data types are complex valued single and double precision, respectively.
2498///
2499/// Computes singular values and singular vectors of a sequence of general $m \times n$ matrices
2500///
2501/// where $\Sigma\_{j}$ is a real $m \times n$ diagonal matrix which is zero except for its `min(m,n)` diagonal elements. $U\_{j}$ (left singular vectors) is an $m \times m$ unitary matrix and $V\_{j}$ (right singular vectors) is a $n \times n$ unitary matrix.
2502/// The diagonal elements of $\Sigma\_{j}$ are the singular values of $A\_{j}$ in either descending order or non-sorting order.
2503///
2504/// `gesvdjBatched` performs `gesvdj` on each matrix.
2505/// It requires that all matrices are of the same size `m,n` no greater than 32 and are packed contiguously,
2506///
2507/// Each matrix is column-major with leading dimension `lda`, so the formula for random access is $A\_{k}\operatorname{(i,j)} = {A\lbrack\ i\ +\ lda\cdot j\ +\ lda\cdot n\cdot k\rbrack}$.
2508///
2509/// The `S` parameter also contains the singular values of each matrix contiguously,
2510///
2511/// The formula for random access of `S` is $S\_{k}\operatorname{(j)} = {S\lbrack\ j\ +\ min(m,n)\cdot k\rbrack}$.
2512///
2513/// Except for tolerance and maximum sweeps, `gesvdjBatched` can either sort the singular values in descending order (default) or choose as-is (without sorting) with [`GesvdjInfo::set_sort_eigenvalues`].
2514/// If several tiny matrices are packed into diagonal blocks of one matrix, the non-sorting option can separate the singular values of those tiny matrices.
2515///
2516/// `gesvdjBatched` cannot report residual and executed sweeps through [`GesvdjInfo::residual`] and [`GesvdjInfo::executed_sweeps`].
2517/// Calling either accessor returns [`Status::NotSupported`].
2518/// Compute the residual explicitly when needed.
2519///
2520/// Provide workspace through `workspace`.
2521/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
2522/// The workspace size in bytes is `size_of::<T>() * lwork`.
2523///
2524/// `dev_info` has one entry per batch item.
2525/// If the call returns [`Status::InvalidValue`], `dev_info[0] == -i` indicates that the `i`th parameter is invalid.
2526/// Otherwise, `dev_info[i] == min(m, n) + 1` indicates that `gesvdjBatched` did not converge on the `i`th matrix within the given tolerance and maximum sweep count.
2527///
2528/// # Errors
2529///
2530/// Returns an error if cuSOLVER has not been initialized, if the
2531/// matrix dimensions, leading dimensions, vector-computation mode, or batch
2532/// size are invalid, or if cuSOLVER reports an internal failure.
2533pub fn dgesvdj_batched(
2534    ctx: &Context,
2535    jobz: EigenMode,
2536    m: usize,
2537    n: usize,
2538    a: MatrixMut<'_, f64>,
2539    s: &mut DeviceMemory<f64>,
2540    u: Option<MatrixMut<'_, f64>>,
2541    v: Option<MatrixMut<'_, f64>>,
2542    workspace: &mut DeviceMemory<f64>,
2543    dev_info: &mut DeviceMemory<i32>,
2544    params: &GesvdjInfo,
2545    batch_size: usize,
2546) -> Result<()> {
2547    ctx.bind()?;
2548    validate_gesvdj_batched_inputs(
2549        m,
2550        n,
2551        a.data.len(),
2552        a.leading_dimension,
2553        s.len(),
2554        jobz,
2555        matrix_mut_ref_parts(u.as_ref()),
2556        matrix_mut_ref_parts(v.as_ref()),
2557        batch_size,
2558    )?;
2559    require_info_buffer_len(dev_info, batch_size)?;
2560    let lwork = dgesvdj_batched_buffer_size(
2561        ctx,
2562        jobz,
2563        m,
2564        n,
2565        a.as_ref(),
2566        s,
2567        matrix_mut_ref_option(u.as_ref()),
2568        matrix_mut_ref_option(v.as_ref()),
2569        params,
2570        batch_size,
2571    )?;
2572    require_workspace(workspace.len(), lwork)?;
2573    let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
2574    let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
2575    unsafe {
2576        try_ffi!(sys::cusolverDnDgesvdjBatched(
2577            ctx.as_raw(),
2578            jobz.into(),
2579            to_i32(m, "m")?,
2580            to_i32(n, "n")?,
2581            a.data.as_mut_ptr().cast(),
2582            to_i32(a.leading_dimension, "lda")?,
2583            s.as_mut_ptr().cast(),
2584            u_ptr.cast(),
2585            ldu,
2586            v_ptr.cast(),
2587            ldv,
2588            workspace.as_mut_ptr().cast(),
2589            to_i32(lwork, "lwork")?,
2590            dev_info.as_mut_ptr().cast(),
2591            params.as_raw(),
2592            to_i32(batch_size, "batch_size")?,
2593        ))?;
2594    }
2595    Ok(())
2596}
2597
2598/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
2599///
2600/// The S and D data types are real valued single and double precision, respectively.
2601///
2602/// The C and Z data types are complex valued single and double precision, respectively.
2603///
2604/// Computes singular values and singular vectors of a sequence of general $m \times n$ matrices
2605///
2606/// where $\Sigma\_{j}$ is a real $m \times n$ diagonal matrix which is zero except for its `min(m,n)` diagonal elements. $U\_{j}$ (left singular vectors) is an $m \times m$ unitary matrix and $V\_{j}$ (right singular vectors) is a $n \times n$ unitary matrix.
2607/// The diagonal elements of $\Sigma\_{j}$ are the singular values of $A\_{j}$ in either descending order or non-sorting order.
2608///
2609/// `gesvdjBatched` performs `gesvdj` on each matrix.
2610/// It requires that all matrices are of the same size `m,n` no greater than 32 and are packed contiguously,
2611///
2612/// Each matrix is column-major with leading dimension `lda`, so the formula for random access is $A\_{k}\operatorname{(i,j)} = {A\lbrack\ i\ +\ lda\cdot j\ +\ lda\cdot n\cdot k\rbrack}$.
2613///
2614/// The `S` parameter also contains the singular values of each matrix contiguously,
2615///
2616/// The formula for random access of `S` is $S\_{k}\operatorname{(j)} = {S\lbrack\ j\ +\ min(m,n)\cdot k\rbrack}$.
2617///
2618/// Except for tolerance and maximum sweeps, `gesvdjBatched` can either sort the singular values in descending order (default) or choose as-is (without sorting) with [`GesvdjInfo::set_sort_eigenvalues`].
2619/// If several tiny matrices are packed into diagonal blocks of one matrix, the non-sorting option can separate the singular values of those tiny matrices.
2620///
2621/// `gesvdjBatched` cannot report residual and executed sweeps through [`GesvdjInfo::residual`] and [`GesvdjInfo::executed_sweeps`].
2622/// Calling either accessor returns [`Status::NotSupported`].
2623/// Compute the residual explicitly when needed.
2624///
2625/// Provide workspace through `workspace`.
2626/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
2627/// The workspace size in bytes is `size_of::<T>() * lwork`.
2628///
2629/// `dev_info` has one entry per batch item.
2630/// If the call returns [`Status::InvalidValue`], `dev_info[0] == -i` indicates that the `i`th parameter is invalid.
2631/// Otherwise, `dev_info[i] == min(m, n) + 1` indicates that `gesvdjBatched` did not converge on the `i`th matrix within the given tolerance and maximum sweep count.
2632///
2633/// # Errors
2634///
2635/// Returns an error if cuSOLVER has not been initialized, if the
2636/// matrix dimensions, leading dimensions, vector-computation mode, or batch
2637/// size are invalid, or if cuSOLVER reports an internal failure.
2638pub fn cgesvdj_batched(
2639    ctx: &Context,
2640    jobz: EigenMode,
2641    m: usize,
2642    n: usize,
2643    a: MatrixMut<'_, Complex32>,
2644    s: &mut DeviceMemory<f32>,
2645    u: Option<MatrixMut<'_, Complex32>>,
2646    v: Option<MatrixMut<'_, Complex32>>,
2647    workspace: &mut DeviceMemory<Complex32>,
2648    dev_info: &mut DeviceMemory<i32>,
2649    params: &GesvdjInfo,
2650    batch_size: usize,
2651) -> Result<()> {
2652    ctx.bind()?;
2653    validate_gesvdj_batched_inputs(
2654        m,
2655        n,
2656        a.data.len(),
2657        a.leading_dimension,
2658        s.len(),
2659        jobz,
2660        matrix_mut_ref_parts(u.as_ref()),
2661        matrix_mut_ref_parts(v.as_ref()),
2662        batch_size,
2663    )?;
2664    require_info_buffer_len(dev_info, batch_size)?;
2665    let lwork = cgesvdj_batched_buffer_size(
2666        ctx,
2667        jobz,
2668        m,
2669        n,
2670        a.as_ref(),
2671        s,
2672        matrix_mut_ref_option(u.as_ref()),
2673        matrix_mut_ref_option(v.as_ref()),
2674        params,
2675        batch_size,
2676    )?;
2677    require_workspace(workspace.len(), lwork)?;
2678    let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
2679    let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
2680    unsafe {
2681        try_ffi!(sys::cusolverDnCgesvdjBatched(
2682            ctx.as_raw(),
2683            jobz.into(),
2684            to_i32(m, "m")?,
2685            to_i32(n, "n")?,
2686            a.data.as_mut_ptr().cast(),
2687            to_i32(a.leading_dimension, "lda")?,
2688            s.as_mut_ptr().cast(),
2689            u_ptr.cast(),
2690            ldu,
2691            v_ptr.cast(),
2692            ldv,
2693            workspace.as_mut_ptr().cast(),
2694            to_i32(lwork, "lwork")?,
2695            dev_info.as_mut_ptr().cast(),
2696            params.as_raw(),
2697            to_i32(batch_size, "batch_size")?,
2698        ))?;
2699    }
2700    Ok(())
2701}
2702
2703/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
2704///
2705/// The S and D data types are real valued single and double precision, respectively.
2706///
2707/// The C and Z data types are complex valued single and double precision, respectively.
2708///
2709/// Computes singular values and singular vectors of a sequence of general $m \times n$ matrices
2710///
2711/// where $\Sigma\_{j}$ is a real $m \times n$ diagonal matrix which is zero except for its `min(m,n)` diagonal elements. $U\_{j}$ (left singular vectors) is an $m \times m$ unitary matrix and $V\_{j}$ (right singular vectors) is a $n \times n$ unitary matrix.
2712/// The diagonal elements of $\Sigma\_{j}$ are the singular values of $A\_{j}$ in either descending order or non-sorting order.
2713///
2714/// `gesvdjBatched` performs `gesvdj` on each matrix.
2715/// It requires that all matrices are of the same size `m,n` no greater than 32 and are packed contiguously,
2716///
2717/// Each matrix is column-major with leading dimension `lda`, so the formula for random access is $A\_{k}\operatorname{(i,j)} = {A\lbrack\ i\ +\ lda\cdot j\ +\ lda\cdot n\cdot k\rbrack}$.
2718///
2719/// The `S` parameter also contains the singular values of each matrix contiguously,
2720///
2721/// The formula for random access of `S` is $S\_{k}\operatorname{(j)} = {S\lbrack\ j\ +\ min(m,n)\cdot k\rbrack}$.
2722///
2723/// Except for tolerance and maximum sweeps, `gesvdjBatched` can either sort the singular values in descending order (default) or choose as-is (without sorting) with [`GesvdjInfo::set_sort_eigenvalues`].
2724/// If several tiny matrices are packed into diagonal blocks of one matrix, the non-sorting option can separate the singular values of those tiny matrices.
2725///
2726/// `gesvdjBatched` cannot report residual and executed sweeps through [`GesvdjInfo::residual`] and [`GesvdjInfo::executed_sweeps`].
2727/// Calling either accessor returns [`Status::NotSupported`].
2728/// Compute the residual explicitly when needed.
2729///
2730/// Provide workspace through `workspace`.
2731/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
2732/// The workspace size in bytes is `size_of::<T>() * lwork`.
2733///
2734/// `dev_info` has one entry per batch item.
2735/// If the call returns [`Status::InvalidValue`], `dev_info[0] == -i` indicates that the `i`th parameter is invalid.
2736/// Otherwise, `dev_info[i] == min(m, n) + 1` indicates that `gesvdjBatched` did not converge on the `i`th matrix within the given tolerance and maximum sweep count.
2737///
2738/// # Errors
2739///
2740/// Returns an error if cuSOLVER has not been initialized, if the
2741/// matrix dimensions, leading dimensions, vector-computation mode, or batch
2742/// size are invalid, or if cuSOLVER reports an internal failure.
2743pub fn zgesvdj_batched(
2744    ctx: &Context,
2745    jobz: EigenMode,
2746    m: usize,
2747    n: usize,
2748    a: MatrixMut<'_, Complex64>,
2749    s: &mut DeviceMemory<f64>,
2750    u: Option<MatrixMut<'_, Complex64>>,
2751    v: Option<MatrixMut<'_, Complex64>>,
2752    workspace: &mut DeviceMemory<Complex64>,
2753    dev_info: &mut DeviceMemory<i32>,
2754    params: &GesvdjInfo,
2755    batch_size: usize,
2756) -> Result<()> {
2757    ctx.bind()?;
2758    validate_gesvdj_batched_inputs(
2759        m,
2760        n,
2761        a.data.len(),
2762        a.leading_dimension,
2763        s.len(),
2764        jobz,
2765        matrix_mut_ref_parts(u.as_ref()),
2766        matrix_mut_ref_parts(v.as_ref()),
2767        batch_size,
2768    )?;
2769    require_info_buffer_len(dev_info, batch_size)?;
2770    let lwork = zgesvdj_batched_buffer_size(
2771        ctx,
2772        jobz,
2773        m,
2774        n,
2775        a.as_ref(),
2776        s,
2777        matrix_mut_ref_option(u.as_ref()),
2778        matrix_mut_ref_option(v.as_ref()),
2779        params,
2780        batch_size,
2781    )?;
2782    require_workspace(workspace.len(), lwork)?;
2783    let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
2784    let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
2785    unsafe {
2786        try_ffi!(sys::cusolverDnZgesvdjBatched(
2787            ctx.as_raw(),
2788            jobz.into(),
2789            to_i32(m, "m")?,
2790            to_i32(n, "n")?,
2791            a.data.as_mut_ptr().cast(),
2792            to_i32(a.leading_dimension, "lda")?,
2793            s.as_mut_ptr().cast(),
2794            u_ptr.cast(),
2795            ldu,
2796            v_ptr.cast(),
2797            ldv,
2798            workspace.as_mut_ptr().cast(),
2799            to_i32(lwork, "lwork")?,
2800            dev_info.as_mut_ptr().cast(),
2801            params.as_raw(),
2802            to_i32(batch_size, "batch_size")?,
2803        ))?;
2804    }
2805    Ok(())
2806}
2807
2808pub fn sgesvda_strided_batched_buffer_size(
2809    ctx: &Context,
2810    jobz: EigenMode,
2811    rank: usize,
2812    m: usize,
2813    n: usize,
2814    a: StridedBatchedMatrixRef<'_, f32>,
2815    s: StridedBatchedVectorRef<'_, f32>,
2816    u: Option<StridedBatchedMatrixRef<'_, f32>>,
2817    v: Option<StridedBatchedMatrixRef<'_, f32>>,
2818    batch_size: usize,
2819) -> Result<usize> {
2820    ctx.bind()?;
2821    validate_gesvda_strided_batched_inputs(
2822        rank,
2823        m,
2824        n,
2825        a.data.len(),
2826        a.leading_dimension,
2827        a.stride,
2828        s.data.len(),
2829        s.stride,
2830        jobz,
2831        strided_batched_matrix_ref_parts(u),
2832        strided_batched_matrix_ref_parts(v),
2833        batch_size,
2834    )?;
2835    let (u_ptr, ldu, stride_u) =
2836        optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
2837    let (v_ptr, ldv, stride_v) =
2838        optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
2839    let mut lwork = 0;
2840    unsafe {
2841        try_ffi!(sys::cusolverDnSgesvdaStridedBatched_bufferSize(
2842            ctx.as_raw(),
2843            jobz.into(),
2844            to_i32(rank, "rank")?,
2845            to_i32(m, "m")?,
2846            to_i32(n, "n")?,
2847            a.data.as_ptr().cast(),
2848            to_i32(a.leading_dimension, "lda")?,
2849            to_i64(a.stride, "stride_a")?,
2850            s.data.as_ptr().cast(),
2851            to_i64(s.stride, "stride_s")?,
2852            u_ptr.cast(),
2853            ldu,
2854            stride_u,
2855            v_ptr.cast(),
2856            ldv,
2857            stride_v,
2858            &raw mut lwork,
2859            to_i32(batch_size, "batch_size")?,
2860        ))?;
2861    }
2862    to_usize(lwork, "lwork")
2863}
2864
2865pub fn dgesvda_strided_batched_buffer_size(
2866    ctx: &Context,
2867    jobz: EigenMode,
2868    rank: usize,
2869    m: usize,
2870    n: usize,
2871    a: StridedBatchedMatrixRef<'_, f64>,
2872    s: StridedBatchedVectorRef<'_, f64>,
2873    u: Option<StridedBatchedMatrixRef<'_, f64>>,
2874    v: Option<StridedBatchedMatrixRef<'_, f64>>,
2875    batch_size: usize,
2876) -> Result<usize> {
2877    ctx.bind()?;
2878    validate_gesvda_strided_batched_inputs(
2879        rank,
2880        m,
2881        n,
2882        a.data.len(),
2883        a.leading_dimension,
2884        a.stride,
2885        s.data.len(),
2886        s.stride,
2887        jobz,
2888        strided_batched_matrix_ref_parts(u),
2889        strided_batched_matrix_ref_parts(v),
2890        batch_size,
2891    )?;
2892    let (u_ptr, ldu, stride_u) =
2893        optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
2894    let (v_ptr, ldv, stride_v) =
2895        optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
2896    let mut lwork = 0;
2897    unsafe {
2898        try_ffi!(sys::cusolverDnDgesvdaStridedBatched_bufferSize(
2899            ctx.as_raw(),
2900            jobz.into(),
2901            to_i32(rank, "rank")?,
2902            to_i32(m, "m")?,
2903            to_i32(n, "n")?,
2904            a.data.as_ptr().cast(),
2905            to_i32(a.leading_dimension, "lda")?,
2906            to_i64(a.stride, "stride_a")?,
2907            s.data.as_ptr().cast(),
2908            to_i64(s.stride, "stride_s")?,
2909            u_ptr.cast(),
2910            ldu,
2911            stride_u,
2912            v_ptr.cast(),
2913            ldv,
2914            stride_v,
2915            &raw mut lwork,
2916            to_i32(batch_size, "batch_size")?,
2917        ))?;
2918    }
2919    to_usize(lwork, "lwork")
2920}
2921
2922pub fn cgesvda_strided_batched_buffer_size(
2923    ctx: &Context,
2924    jobz: EigenMode,
2925    rank: usize,
2926    m: usize,
2927    n: usize,
2928    a: StridedBatchedMatrixRef<'_, Complex32>,
2929    s: StridedBatchedVectorRef<'_, f32>,
2930    u: Option<StridedBatchedMatrixRef<'_, Complex32>>,
2931    v: Option<StridedBatchedMatrixRef<'_, Complex32>>,
2932    batch_size: usize,
2933) -> Result<usize> {
2934    ctx.bind()?;
2935    validate_gesvda_strided_batched_inputs(
2936        rank,
2937        m,
2938        n,
2939        a.data.len(),
2940        a.leading_dimension,
2941        a.stride,
2942        s.data.len(),
2943        s.stride,
2944        jobz,
2945        strided_batched_matrix_ref_parts(u),
2946        strided_batched_matrix_ref_parts(v),
2947        batch_size,
2948    )?;
2949    let (u_ptr, ldu, stride_u) =
2950        optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
2951    let (v_ptr, ldv, stride_v) =
2952        optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
2953    let mut lwork = 0;
2954    unsafe {
2955        try_ffi!(sys::cusolverDnCgesvdaStridedBatched_bufferSize(
2956            ctx.as_raw(),
2957            jobz.into(),
2958            to_i32(rank, "rank")?,
2959            to_i32(m, "m")?,
2960            to_i32(n, "n")?,
2961            a.data.as_ptr().cast(),
2962            to_i32(a.leading_dimension, "lda")?,
2963            to_i64(a.stride, "stride_a")?,
2964            s.data.as_ptr().cast(),
2965            to_i64(s.stride, "stride_s")?,
2966            u_ptr.cast(),
2967            ldu,
2968            stride_u,
2969            v_ptr.cast(),
2970            ldv,
2971            stride_v,
2972            &raw mut lwork,
2973            to_i32(batch_size, "batch_size")?,
2974        ))?;
2975    }
2976    to_usize(lwork, "lwork")
2977}
2978
2979pub fn zgesvda_strided_batched_buffer_size(
2980    ctx: &Context,
2981    jobz: EigenMode,
2982    rank: usize,
2983    m: usize,
2984    n: usize,
2985    a: StridedBatchedMatrixRef<'_, Complex64>,
2986    s: StridedBatchedVectorRef<'_, f64>,
2987    u: Option<StridedBatchedMatrixRef<'_, Complex64>>,
2988    v: Option<StridedBatchedMatrixRef<'_, Complex64>>,
2989    batch_size: usize,
2990) -> Result<usize> {
2991    ctx.bind()?;
2992    validate_gesvda_strided_batched_inputs(
2993        rank,
2994        m,
2995        n,
2996        a.data.len(),
2997        a.leading_dimension,
2998        a.stride,
2999        s.data.len(),
3000        s.stride,
3001        jobz,
3002        strided_batched_matrix_ref_parts(u),
3003        strided_batched_matrix_ref_parts(v),
3004        batch_size,
3005    )?;
3006    let (u_ptr, ldu, stride_u) =
3007        optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
3008    let (v_ptr, ldv, stride_v) =
3009        optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
3010    let mut lwork = 0;
3011    unsafe {
3012        try_ffi!(sys::cusolverDnZgesvdaStridedBatched_bufferSize(
3013            ctx.as_raw(),
3014            jobz.into(),
3015            to_i32(rank, "rank")?,
3016            to_i32(m, "m")?,
3017            to_i32(n, "n")?,
3018            a.data.as_ptr().cast(),
3019            to_i32(a.leading_dimension, "lda")?,
3020            to_i64(a.stride, "stride_a")?,
3021            s.data.as_ptr().cast(),
3022            to_i64(s.stride, "stride_s")?,
3023            u_ptr.cast(),
3024            ldu,
3025            stride_u,
3026            v_ptr.cast(),
3027            ldv,
3028            stride_v,
3029            &raw mut lwork,
3030            to_i32(batch_size, "batch_size")?,
3031        ))?;
3032    }
3033    to_usize(lwork, "lwork")
3034}
3035
3036/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
3037///
3038/// The S and D data types are real valued single and double precision, respectively.
3039///
3040/// The C and Z data types are complex valued single and double precision, respectively.
3041///
3042/// `gesvda` (`a` stands for approximate) approximates the singular value decomposition of a tall skinny $m \times n$ matrix `A` and the corresponding left and right singular vectors.
3043/// The economy form of SVD is written as $A = U \Sigma V^{H}$, where `Σ` is
3044/// an $n \times n$ matrix.
3045/// `U` is an $m \times n$ unitary matrix, and `V` is an $n \times n$ unitary matrix.
3046/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
3047/// `U` and `V` are the left and right singular vectors of `A`.
3048///
3049/// `gesvda` computes eigenvalues of $A^{T}A$, or $A^{H}A$ if `A` is
3050/// complex, to approximate singular values and singular vectors.
3051/// It generates matrices `U` and `V` and transforms matrix `A` to
3052/// $A = U(S + E)V^{H}$, where `S` is diagonal and `E` depends on rounding
3053/// errors.
3054/// To certain conditions, `U`, `V` and `S` approximate singular values and singular vectors up to machine zero of single precision.
3055/// In general, `V` is unitary, `S` is more accurate than `U`.
3056/// If singular value is far from zero, then left singular vector `U` is accurate.
3057/// In other words, the accuracy of singular values and left singular vectors depend on the distance between singular value and zero.
3058/// Since computing $A^{T}A$ or $A^{H}A$ can greatly amplify errors, use
3059/// `gesvda` only with well-conditioned data.
3060///
3061/// `rank` controls how many singular values and singular vectors are computed in `S`, `U`, and `V`.
3062///
3063/// `residual`, when requested, receives the Frobenius norm of the residual.
3064/// When `rank == n`, it measures how well `A` is approximated by the computed SVD.
3065/// Otherwise, it reports in the Frobenius norm sense how far `U` is from unitary.
3066///
3067/// `gesvdaStridedBatched` performs `gesvda` on each matrix.
3068/// It requires that all matrices are of the same size `m,n` and are packed contiguously,
3069///
3070/// Each matrix is column-major with leading dimension `lda`, so the formula for random access is $A\_{k}\operatorname{(i,j)} = {A\lbrack\ i\ +\ lda\cdot j\ +\ stride_a\cdot k\rbrack}$.
3071/// Similarly, the formula for random access of `S` is $S\_{k}\operatorname{(j)} = {S\lbrack\ j\ +\ stride_s\cdot k\rbrack}$, the formula for random access of `U` is $U\_{k}\operatorname{(i,j)} = {U\lbrack\ i\ +\ ldu\cdot j\ +\ stride_u\cdot k\rbrack}$ and the formula for random access of `V` is $V\_{k}\operatorname{(i,j)} = {V\lbrack\ i\ +\ ldv\cdot j\ +\ stride_v\cdot k\rbrack}$.
3072///
3073/// Provide workspace through `workspace`.
3074/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
3075/// The workspace size in bytes is `size_of::<T>() * lwork`.
3076///
3077/// `dev_info` has one entry per batch item.
3078/// If the call returns [`Status::InvalidValue`], `dev_info[0] == -i` indicates that the `i`th parameter is invalid.
3079/// Otherwise, `dev_info[i] == min(m, n) + 1` indicates that `gesvdaStridedBatched` did not converge on the `i`th matrix.
3080/// If `0 < dev_info[i] < min(m, n) + 1`, `gesvdaStridedBatched` could not compute an SVD of the `i`th matrix fully; the leading singular values `S_i[k]`, `0 <= k <= dev_info[i] - 1`, and corresponding singular vectors may still be useful.
3081/// In this case, if `residual` is requested, it is reported as if `rank` was set to `dev_info[i] - 1`.
3082///
3083/// The problem size is limited by `batch_size * stride{A/S/U/V} <= INT32_MAX` primarily due to the current implementation constraints.
3084///
3085/// - Returns `V`, not $V^{H}$.
3086///   This is different from `gesvd`.
3087///
3088/// - Only supports `m >= n`.
3089///
3090/// - Prefer an FP64 data type, such as `DgesvdaStridedBatched` or `ZgesvdaStridedBatched`.
3091///
3092/// - If singular values and singular vectors are known to be accurate, for example when the required singular value is far from zero, performance can be improved by passing `None` for `residual`, with no residual norm computation.
3093///
3094/// # Errors
3095///
3096/// Returns an error if cuSOLVER has not been initialized, if the
3097/// matrix dimensions, leading dimensions, vector-computation mode, strides, or
3098/// batch size are invalid, or if cuSOLVER reports an internal failure.
3099pub fn sgesvda_strided_batched(
3100    ctx: &Context,
3101    jobz: EigenMode,
3102    rank: usize,
3103    m: usize,
3104    n: usize,
3105    a: StridedBatchedMatrixRef<'_, f32>,
3106    s: StridedBatchedVectorMut<'_, f32>,
3107    u: Option<StridedBatchedMatrixMut<'_, f32>>,
3108    v: Option<StridedBatchedMatrixMut<'_, f32>>,
3109    workspace: &mut DeviceMemory<f32>,
3110    dev_info: &mut DeviceMemory<i32>,
3111    residual: Option<&mut f64>,
3112    batch_size: usize,
3113) -> Result<()> {
3114    ctx.bind()?;
3115    validate_gesvda_strided_batched_inputs(
3116        rank,
3117        m,
3118        n,
3119        a.data.len(),
3120        a.leading_dimension,
3121        a.stride,
3122        s.data.len(),
3123        s.stride,
3124        jobz,
3125        strided_batched_matrix_mut_ref_option(u.as_ref())
3126            .map(|m| (m.data, m.leading_dimension, m.stride)),
3127        strided_batched_matrix_mut_ref_option(v.as_ref())
3128            .map(|m| (m.data, m.leading_dimension, m.stride)),
3129        batch_size,
3130    )?;
3131    require_info_buffer_len(dev_info, batch_size)?;
3132    let lwork = sgesvda_strided_batched_buffer_size(
3133        ctx,
3134        jobz,
3135        rank,
3136        m,
3137        n,
3138        a,
3139        s.as_ref(),
3140        strided_batched_matrix_mut_ref_option(u.as_ref()),
3141        strided_batched_matrix_mut_ref_option(v.as_ref()),
3142        batch_size,
3143    )?;
3144    require_workspace(workspace.len(), lwork)?;
3145    let (u_ptr, ldu, stride_u) =
3146        optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
3147    let (v_ptr, ldv, stride_v) =
3148        optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
3149    unsafe {
3150        try_ffi!(sys::cusolverDnSgesvdaStridedBatched(
3151            ctx.as_raw(),
3152            jobz.into(),
3153            to_i32(rank, "rank")?,
3154            to_i32(m, "m")?,
3155            to_i32(n, "n")?,
3156            a.data.as_ptr().cast(),
3157            to_i32(a.leading_dimension, "lda")?,
3158            to_i64(a.stride, "stride_a")?,
3159            s.data.as_mut_ptr().cast(),
3160            to_i64(s.stride, "stride_s")?,
3161            u_ptr.cast(),
3162            ldu,
3163            stride_u,
3164            v_ptr.cast(),
3165            ldv,
3166            stride_v,
3167            workspace.as_mut_ptr().cast(),
3168            to_i32(lwork, "lwork")?,
3169            dev_info.as_mut_ptr().cast(),
3170            residual.map_or(ptr::null_mut(), |value| value as *mut f64),
3171            to_i32(batch_size, "batch_size")?,
3172        ))?;
3173    }
3174    Ok(())
3175}
3176
3177/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
3178///
3179/// The S and D data types are real valued single and double precision, respectively.
3180///
3181/// The C and Z data types are complex valued single and double precision, respectively.
3182///
3183/// `gesvda` (`a` stands for approximate) approximates the singular value decomposition of a tall skinny $m \times n$ matrix `A` and the corresponding left and right singular vectors.
3184/// The economy form of SVD is written as $A = U \Sigma V^{H}$, where `Σ` is
3185/// an $n \times n$ matrix.
3186/// `U` is an $m \times n$ unitary matrix, and `V` is an $n \times n$ unitary matrix.
3187/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
3188/// `U` and `V` are the left and right singular vectors of `A`.
3189///
3190/// `gesvda` computes eigenvalues of $A^{T}A$, or $A^{H}A$ if `A` is
3191/// complex, to approximate singular values and singular vectors.
3192/// It generates matrices `U` and `V` and transforms matrix `A` to
3193/// $A = U(S + E)V^{H}$, where `S` is diagonal and `E` depends on rounding
3194/// errors.
3195/// To certain conditions, `U`, `V` and `S` approximate singular values and singular vectors up to machine zero of single precision.
3196/// In general, `V` is unitary, `S` is more accurate than `U`.
3197/// If singular value is far from zero, then left singular vector `U` is accurate.
3198/// In other words, the accuracy of singular values and left singular vectors depend on the distance between singular value and zero.
3199/// Since computing $A^{T}A$ or $A^{H}A$ can greatly amplify errors, use
3200/// `gesvda` only with well-conditioned data.
3201///
3202/// `rank` controls how many singular values and singular vectors are computed in `S`, `U`, and `V`.
3203///
3204/// `residual`, when requested, receives the Frobenius norm of the residual.
3205/// When `rank == n`, it measures how well `A` is approximated by the computed SVD.
3206/// Otherwise, it reports in the Frobenius norm sense how far `U` is from unitary.
3207///
3208/// `gesvdaStridedBatched` performs `gesvda` on each matrix.
3209/// It requires that all matrices are of the same size `m,n` and are packed contiguously,
3210///
3211/// Each matrix is column-major with leading dimension `lda`, so the formula for random access is $A\_{k}\operatorname{(i,j)} = {A\lbrack\ i\ +\ lda\cdot j\ +\ stride_a\cdot k\rbrack}$.
3212/// Similarly, the formula for random access of `S` is $S\_{k}\operatorname{(j)} = {S\lbrack\ j\ +\ stride_s\cdot k\rbrack}$, the formula for random access of `U` is $U\_{k}\operatorname{(i,j)} = {U\lbrack\ i\ +\ ldu\cdot j\ +\ stride_u\cdot k\rbrack}$ and the formula for random access of `V` is $V\_{k}\operatorname{(i,j)} = {V\lbrack\ i\ +\ ldv\cdot j\ +\ stride_v\cdot k\rbrack}$.
3213///
3214/// Provide workspace through `workspace`.
3215/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
3216/// The workspace size in bytes is `size_of::<T>() * lwork`.
3217///
3218/// `dev_info` has one entry per batch item.
3219/// If the call returns [`Status::InvalidValue`], `dev_info[0] == -i` indicates that the `i`th parameter is invalid.
3220/// Otherwise, `dev_info[i] == min(m, n) + 1` indicates that `gesvdaStridedBatched` did not converge on the `i`th matrix.
3221/// If `0 < dev_info[i] < min(m, n) + 1`, `gesvdaStridedBatched` could not compute an SVD of the `i`th matrix fully; the leading singular values `S_i[k]`, `0 <= k <= dev_info[i] - 1`, and corresponding singular vectors may still be useful.
3222/// In this case, if `residual` is requested, it is reported as if `rank` was set to `dev_info[i] - 1`.
3223///
3224/// The problem size is limited by `batch_size * stride{A/S/U/V} <= INT32_MAX` primarily due to the current implementation constraints.
3225///
3226/// - Returns `V`, not $V^{H}$.
3227///   This is different from `gesvd`.
3228///
3229/// - Only supports `m >= n`.
3230///
3231/// - Prefer an FP64 data type, such as `DgesvdaStridedBatched` or `ZgesvdaStridedBatched`.
3232///
3233/// - If singular values and singular vectors are known to be accurate, for example when the required singular value is far from zero, performance can be improved by passing `None` for `residual`, with no residual norm computation.
3234///
3235/// # Errors
3236///
3237/// Returns an error if cuSOLVER has not been initialized, if the
3238/// matrix dimensions, leading dimensions, vector-computation mode, strides, or
3239/// batch size are invalid, or if cuSOLVER reports an internal failure.
3240pub fn dgesvda_strided_batched(
3241    ctx: &Context,
3242    jobz: EigenMode,
3243    rank: usize,
3244    m: usize,
3245    n: usize,
3246    a: StridedBatchedMatrixRef<'_, f64>,
3247    s: StridedBatchedVectorMut<'_, f64>,
3248    u: Option<StridedBatchedMatrixMut<'_, f64>>,
3249    v: Option<StridedBatchedMatrixMut<'_, f64>>,
3250    workspace: &mut DeviceMemory<f64>,
3251    dev_info: &mut DeviceMemory<i32>,
3252    residual: Option<&mut f64>,
3253    batch_size: usize,
3254) -> Result<()> {
3255    ctx.bind()?;
3256    validate_gesvda_strided_batched_inputs(
3257        rank,
3258        m,
3259        n,
3260        a.data.len(),
3261        a.leading_dimension,
3262        a.stride,
3263        s.data.len(),
3264        s.stride,
3265        jobz,
3266        strided_batched_matrix_mut_ref_option(u.as_ref())
3267            .map(|m| (m.data, m.leading_dimension, m.stride)),
3268        strided_batched_matrix_mut_ref_option(v.as_ref())
3269            .map(|m| (m.data, m.leading_dimension, m.stride)),
3270        batch_size,
3271    )?;
3272    require_info_buffer_len(dev_info, batch_size)?;
3273    let lwork = dgesvda_strided_batched_buffer_size(
3274        ctx,
3275        jobz,
3276        rank,
3277        m,
3278        n,
3279        a,
3280        s.as_ref(),
3281        strided_batched_matrix_mut_ref_option(u.as_ref()),
3282        strided_batched_matrix_mut_ref_option(v.as_ref()),
3283        batch_size,
3284    )?;
3285    require_workspace(workspace.len(), lwork)?;
3286    let (u_ptr, ldu, stride_u) =
3287        optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
3288    let (v_ptr, ldv, stride_v) =
3289        optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
3290    unsafe {
3291        try_ffi!(sys::cusolverDnDgesvdaStridedBatched(
3292            ctx.as_raw(),
3293            jobz.into(),
3294            to_i32(rank, "rank")?,
3295            to_i32(m, "m")?,
3296            to_i32(n, "n")?,
3297            a.data.as_ptr().cast(),
3298            to_i32(a.leading_dimension, "lda")?,
3299            to_i64(a.stride, "stride_a")?,
3300            s.data.as_mut_ptr().cast(),
3301            to_i64(s.stride, "stride_s")?,
3302            u_ptr.cast(),
3303            ldu,
3304            stride_u,
3305            v_ptr.cast(),
3306            ldv,
3307            stride_v,
3308            workspace.as_mut_ptr().cast(),
3309            to_i32(lwork, "lwork")?,
3310            dev_info.as_mut_ptr().cast(),
3311            residual.map_or(ptr::null_mut(), |value| value as *mut f64),
3312            to_i32(batch_size, "batch_size")?,
3313        ))?;
3314    }
3315    Ok(())
3316}
3317
3318/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
3319///
3320/// The S and D data types are real valued single and double precision, respectively.
3321///
3322/// The C and Z data types are complex valued single and double precision, respectively.
3323///
3324/// `gesvda` (`a` stands for approximate) approximates the singular value decomposition of a tall skinny $m \times n$ matrix `A` and the corresponding left and right singular vectors.
3325/// The economy form of SVD is written as $A = U \Sigma V^{H}$, where `Σ` is
3326/// an $n \times n$ matrix.
3327/// `U` is an $m \times n$ unitary matrix, and `V` is an $n \times n$ unitary matrix.
3328/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
3329/// `U` and `V` are the left and right singular vectors of `A`.
3330///
3331/// `gesvda` computes eigenvalues of $A^{T}A$, or $A^{H}A$ if `A` is
3332/// complex, to approximate singular values and singular vectors.
3333/// It generates matrices `U` and `V` and transforms matrix `A` to
3334/// $A = U(S + E)V^{H}$, where `S` is diagonal and `E` depends on rounding
3335/// errors.
3336/// To certain conditions, `U`, `V` and `S` approximate singular values and singular vectors up to machine zero of single precision.
3337/// In general, `V` is unitary, `S` is more accurate than `U`.
3338/// If singular value is far from zero, then left singular vector `U` is accurate.
3339/// In other words, the accuracy of singular values and left singular vectors depend on the distance between singular value and zero.
3340/// Since computing $A^{T}A$ or $A^{H}A$ can greatly amplify errors, use
3341/// `gesvda` only with well-conditioned data.
3342///
3343/// `rank` controls how many singular values and singular vectors are computed in `S`, `U`, and `V`.
3344///
3345/// `residual`, when requested, receives the Frobenius norm of the residual.
3346/// When `rank == n`, it measures how well `A` is approximated by the computed SVD.
3347/// Otherwise, it reports in the Frobenius norm sense how far `U` is from unitary.
3348///
3349/// `gesvdaStridedBatched` performs `gesvda` on each matrix.
3350/// It requires that all matrices are of the same size `m,n` and are packed contiguously,
3351///
3352/// Each matrix is column-major with leading dimension `lda`, so the formula for random access is $A\_{k}\operatorname{(i,j)} = {A\lbrack\ i\ +\ lda\cdot j\ +\ stride_a\cdot k\rbrack}$.
3353/// Similarly, the formula for random access of `S` is $S\_{k}\operatorname{(j)} = {S\lbrack\ j\ +\ stride_s\cdot k\rbrack}$, the formula for random access of `U` is $U\_{k}\operatorname{(i,j)} = {U\lbrack\ i\ +\ ldu\cdot j\ +\ stride_u\cdot k\rbrack}$ and the formula for random access of `V` is $V\_{k}\operatorname{(i,j)} = {V\lbrack\ i\ +\ ldv\cdot j\ +\ stride_v\cdot k\rbrack}$.
3354///
3355/// Provide workspace through `workspace`.
3356/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
3357/// The workspace size in bytes is `size_of::<T>() * lwork`.
3358///
3359/// `dev_info` has one entry per batch item.
3360/// If the call returns [`Status::InvalidValue`], `dev_info[0] == -i` indicates that the `i`th parameter is invalid.
3361/// Otherwise, `dev_info[i] == min(m, n) + 1` indicates that `gesvdaStridedBatched` did not converge on the `i`th matrix.
3362/// If `0 < dev_info[i] < min(m, n) + 1`, `gesvdaStridedBatched` could not compute an SVD of the `i`th matrix fully; the leading singular values `S_i[k]`, `0 <= k <= dev_info[i] - 1`, and corresponding singular vectors may still be useful.
3363/// In this case, if `residual` is requested, it is reported as if `rank` was set to `dev_info[i] - 1`.
3364///
3365/// The problem size is limited by `batch_size * stride{A/S/U/V} <= INT32_MAX` primarily due to the current implementation constraints.
3366///
3367/// - Returns `V`, not $V^{H}$.
3368///   This is different from `gesvd`.
3369///
3370/// - Only supports `m >= n`.
3371///
3372/// - Prefer an FP64 data type, such as `DgesvdaStridedBatched` or `ZgesvdaStridedBatched`.
3373///
3374/// - If singular values and singular vectors are known to be accurate, for example when the required singular value is far from zero, performance can be improved by passing `None` for `residual`, with no residual norm computation.
3375///
3376/// # Errors
3377///
3378/// Returns an error if cuSOLVER has not been initialized, if the
3379/// matrix dimensions, leading dimensions, vector-computation mode, strides, or
3380/// batch size are invalid, or if cuSOLVER reports an internal failure.
3381pub fn cgesvda_strided_batched(
3382    ctx: &Context,
3383    jobz: EigenMode,
3384    rank: usize,
3385    m: usize,
3386    n: usize,
3387    a: StridedBatchedMatrixRef<'_, Complex32>,
3388    s: StridedBatchedVectorMut<'_, f32>,
3389    u: Option<StridedBatchedMatrixMut<'_, Complex32>>,
3390    v: Option<StridedBatchedMatrixMut<'_, Complex32>>,
3391    workspace: &mut DeviceMemory<Complex32>,
3392    dev_info: &mut DeviceMemory<i32>,
3393    residual: Option<&mut f64>,
3394    batch_size: usize,
3395) -> Result<()> {
3396    ctx.bind()?;
3397    validate_gesvda_strided_batched_inputs(
3398        rank,
3399        m,
3400        n,
3401        a.data.len(),
3402        a.leading_dimension,
3403        a.stride,
3404        s.data.len(),
3405        s.stride,
3406        jobz,
3407        strided_batched_matrix_mut_ref_option(u.as_ref())
3408            .map(|m| (m.data, m.leading_dimension, m.stride)),
3409        strided_batched_matrix_mut_ref_option(v.as_ref())
3410            .map(|m| (m.data, m.leading_dimension, m.stride)),
3411        batch_size,
3412    )?;
3413    require_info_buffer_len(dev_info, batch_size)?;
3414    let lwork = cgesvda_strided_batched_buffer_size(
3415        ctx,
3416        jobz,
3417        rank,
3418        m,
3419        n,
3420        a,
3421        s.as_ref(),
3422        strided_batched_matrix_mut_ref_option(u.as_ref()),
3423        strided_batched_matrix_mut_ref_option(v.as_ref()),
3424        batch_size,
3425    )?;
3426    require_workspace(workspace.len(), lwork)?;
3427    let (u_ptr, ldu, stride_u) =
3428        optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
3429    let (v_ptr, ldv, stride_v) =
3430        optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
3431    unsafe {
3432        try_ffi!(sys::cusolverDnCgesvdaStridedBatched(
3433            ctx.as_raw(),
3434            jobz.into(),
3435            to_i32(rank, "rank")?,
3436            to_i32(m, "m")?,
3437            to_i32(n, "n")?,
3438            a.data.as_ptr().cast(),
3439            to_i32(a.leading_dimension, "lda")?,
3440            to_i64(a.stride, "stride_a")?,
3441            s.data.as_mut_ptr().cast(),
3442            to_i64(s.stride, "stride_s")?,
3443            u_ptr.cast(),
3444            ldu,
3445            stride_u,
3446            v_ptr.cast(),
3447            ldv,
3448            stride_v,
3449            workspace.as_mut_ptr().cast(),
3450            to_i32(lwork, "lwork")?,
3451            dev_info.as_mut_ptr().cast(),
3452            residual.map_or(ptr::null_mut(), |value| value as *mut f64),
3453            to_i32(batch_size, "batch_size")?,
3454        ))?;
3455    }
3456    Ok(())
3457}
3458
3459/// Use the matching buffer-size helper to calculate the sizes needed for pre-allocated workspace.
3460///
3461/// The S and D data types are real valued single and double precision, respectively.
3462///
3463/// The C and Z data types are complex valued single and double precision, respectively.
3464///
3465/// `gesvda` (`a` stands for approximate) approximates the singular value decomposition of a tall skinny $m \times n$ matrix `A` and the corresponding left and right singular vectors.
3466/// The economy form of SVD is written as $A = U \Sigma V^{H}$, where `Σ` is
3467/// an $n \times n$ matrix.
3468/// `U` is an $m \times n$ unitary matrix, and `V` is an $n \times n$ unitary matrix.
3469/// The diagonal elements of `Σ` are the singular values of `A`; they are real and non-negative, and are returned in descending order.
3470/// `U` and `V` are the left and right singular vectors of `A`.
3471///
3472/// `gesvda` computes eigenvalues of $A^{T}A$, or $A^{H}A$ if `A` is
3473/// complex, to approximate singular values and singular vectors.
3474/// It generates matrices `U` and `V` and transforms matrix `A` to
3475/// $A = U(S + E)V^{H}$, where `S` is diagonal and `E` depends on rounding
3476/// errors.
3477/// To certain conditions, `U`, `V` and `S` approximate singular values and singular vectors up to machine zero of single precision.
3478/// In general, `V` is unitary, `S` is more accurate than `U`.
3479/// If singular value is far from zero, then left singular vector `U` is accurate.
3480/// In other words, the accuracy of singular values and left singular vectors depend on the distance between singular value and zero.
3481/// Since computing $A^{T}A$ or $A^{H}A$ can greatly amplify errors, use
3482/// `gesvda` only with well-conditioned data.
3483///
3484/// `rank` controls how many singular values and singular vectors are computed in `S`, `U`, and `V`.
3485///
3486/// `residual`, when requested, receives the Frobenius norm of the residual.
3487/// When `rank == n`, it measures how well `A` is approximated by the computed SVD.
3488/// Otherwise, it reports in the Frobenius norm sense how far `U` is from unitary.
3489///
3490/// `gesvdaStridedBatched` performs `gesvda` on each matrix.
3491/// It requires that all matrices are of the same size `m,n` and are packed contiguously,
3492///
3493/// Each matrix is column-major with leading dimension `lda`, so the formula for random access is $A\_{k}\operatorname{(i,j)} = {A\lbrack\ i\ +\ lda\cdot j\ +\ stride_a\cdot k\rbrack}$.
3494/// Similarly, the formula for random access of `S` is $S\_{k}\operatorname{(j)} = {S\lbrack\ j\ +\ stride_s\cdot k\rbrack}$, the formula for random access of `U` is $U\_{k}\operatorname{(i,j)} = {U\lbrack\ i\ +\ ldu\cdot j\ +\ stride_u\cdot k\rbrack}$ and the formula for random access of `V` is $V\_{k}\operatorname{(i,j)} = {V\lbrack\ i\ +\ ldv\cdot j\ +\ stride_v\cdot k\rbrack}$.
3495///
3496/// Provide workspace through `workspace`.
3497/// Use the corresponding `*_buffer_size` helper to query the required workspace length.
3498/// The workspace size in bytes is `size_of::<T>() * lwork`.
3499///
3500/// `dev_info` has one entry per batch item.
3501/// If the call returns [`Status::InvalidValue`], `dev_info[0] == -i` indicates that the `i`th parameter is invalid.
3502/// Otherwise, `dev_info[i] == min(m, n) + 1` indicates that `gesvdaStridedBatched` did not converge on the `i`th matrix.
3503/// If `0 < dev_info[i] < min(m, n) + 1`, `gesvdaStridedBatched` could not compute an SVD of the `i`th matrix fully; the leading singular values `S_i[k]`, `0 <= k <= dev_info[i] - 1`, and corresponding singular vectors may still be useful.
3504/// In this case, if `residual` is requested, it is reported as if `rank` was set to `dev_info[i] - 1`.
3505///
3506/// The problem size is limited by `batch_size * stride{A/S/U/V} <= INT32_MAX` primarily due to the current implementation constraints.
3507///
3508/// - Returns `V`, not $V^{H}$.
3509///   This is different from `gesvd`.
3510///
3511/// - Only supports `m >= n`.
3512///
3513/// - Prefer an FP64 data type, such as `DgesvdaStridedBatched` or `ZgesvdaStridedBatched`.
3514///
3515/// - If singular values and singular vectors are known to be accurate, for example when the required singular value is far from zero, performance can be improved by passing `None` for `residual`, with no residual norm computation.
3516///
3517/// # Errors
3518///
3519/// Returns an error if cuSOLVER has not been initialized, if the
3520/// matrix dimensions, leading dimensions, vector-computation mode, strides, or
3521/// batch size are invalid, or if cuSOLVER reports an internal failure.
3522pub fn zgesvda_strided_batched(
3523    ctx: &Context,
3524    jobz: EigenMode,
3525    rank: usize,
3526    m: usize,
3527    n: usize,
3528    a: StridedBatchedMatrixRef<'_, Complex64>,
3529    s: StridedBatchedVectorMut<'_, f64>,
3530    u: Option<StridedBatchedMatrixMut<'_, Complex64>>,
3531    v: Option<StridedBatchedMatrixMut<'_, Complex64>>,
3532    workspace: &mut DeviceMemory<Complex64>,
3533    dev_info: &mut DeviceMemory<i32>,
3534    residual: Option<&mut f64>,
3535    batch_size: usize,
3536) -> Result<()> {
3537    ctx.bind()?;
3538    validate_gesvda_strided_batched_inputs(
3539        rank,
3540        m,
3541        n,
3542        a.data.len(),
3543        a.leading_dimension,
3544        a.stride,
3545        s.data.len(),
3546        s.stride,
3547        jobz,
3548        strided_batched_matrix_mut_ref_option(u.as_ref())
3549            .map(|m| (m.data, m.leading_dimension, m.stride)),
3550        strided_batched_matrix_mut_ref_option(v.as_ref())
3551            .map(|m| (m.data, m.leading_dimension, m.stride)),
3552        batch_size,
3553    )?;
3554    require_info_buffer_len(dev_info, batch_size)?;
3555    let lwork = zgesvda_strided_batched_buffer_size(
3556        ctx,
3557        jobz,
3558        rank,
3559        m,
3560        n,
3561        a,
3562        s.as_ref(),
3563        strided_batched_matrix_mut_ref_option(u.as_ref()),
3564        strided_batched_matrix_mut_ref_option(v.as_ref()),
3565        batch_size,
3566    )?;
3567    require_workspace(workspace.len(), lwork)?;
3568    let (u_ptr, ldu, stride_u) =
3569        optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
3570    let (v_ptr, ldv, stride_v) =
3571        optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
3572    unsafe {
3573        try_ffi!(sys::cusolverDnZgesvdaStridedBatched(
3574            ctx.as_raw(),
3575            jobz.into(),
3576            to_i32(rank, "rank")?,
3577            to_i32(m, "m")?,
3578            to_i32(n, "n")?,
3579            a.data.as_ptr().cast(),
3580            to_i32(a.leading_dimension, "lda")?,
3581            to_i64(a.stride, "stride_a")?,
3582            s.data.as_mut_ptr().cast(),
3583            to_i64(s.stride, "stride_s")?,
3584            u_ptr.cast(),
3585            ldu,
3586            stride_u,
3587            v_ptr.cast(),
3588            ldv,
3589            stride_v,
3590            workspace.as_mut_ptr().cast(),
3591            to_i32(lwork, "lwork")?,
3592            dev_info.as_mut_ptr().cast(),
3593            residual.map_or(ptr::null_mut(), |value| value as *mut f64),
3594            to_i32(batch_size, "batch_size")?,
3595        ))?;
3596    }
3597    Ok(())
3598}
3599
3600fn validate_gesvd_dims(m: usize, n: usize) -> Result<()> {
3601    if m == 0 || n == 0 || m < n {
3602        return Err(Error::InvalidMatrixShape);
3603    }
3604    Ok(())
3605}
3606
3607fn validate_xgesvdp_inputs<TU, TV>(
3608    m: usize,
3609    n: usize,
3610    a_bytes: usize,
3611    lda: usize,
3612    a_type: DataType,
3613    s_bytes: usize,
3614    s_type: DataType,
3615    jobz: EigenMode,
3616    econ: bool,
3617    u: Option<&(&DeviceMemory<TU>, usize)>,
3618    u_type: DataType,
3619    v: Option<&(&DeviceMemory<TV>, usize)>,
3620    v_type: DataType,
3621) -> Result<()> {
3622    if m == 0 || n == 0 {
3623        return Err(Error::InvalidMatrixShape);
3624    }
3625    validate_x_matrix(m, n, a_bytes, lda, a_type)?;
3626    validate_x_vector(m.min(n), s_bytes, s_type)?;
3627    match jobz {
3628        EigenMode::NoVector => Ok(()),
3629        EigenMode::Vector => {
3630            let Some((u, ldu)) = u else {
3631                return Err(Error::InvalidMatrixShape);
3632            };
3633            let Some((v, ldv)) = v else {
3634                return Err(Error::InvalidMatrixShape);
3635            };
3636            validate_x_eig_output(m, n, u.byte_len(), *ldu, econ, u_type)?;
3637            validate_x_eig_output(n, n, v.byte_len(), *ldv, econ, v_type)
3638        }
3639    }
3640}
3641
3642fn matrix_ref_parts<T>(matrix: Option<MatrixRef<'_, T>>) -> Option<(&DeviceMemory<T>, usize)> {
3643    matrix.map(|matrix| (matrix.data, matrix.leading_dimension))
3644}
3645
3646fn matrix_mut_parts<T>(matrix: Option<MatrixMut<'_, T>>) -> Option<(&mut DeviceMemory<T>, usize)> {
3647    matrix.map(|matrix| (matrix.data, matrix.leading_dimension))
3648}
3649
3650fn matrix_mut_ref_parts<'a, T>(
3651    matrix: Option<&'a MatrixMut<'a, T>>,
3652) -> Option<(&'a DeviceMemory<T>, usize)> {
3653    matrix.map(|matrix| (&*matrix.data, matrix.leading_dimension))
3654}
3655
3656fn matrix_mut_ref_option<'a, T>(matrix: Option<&'a MatrixMut<'a, T>>) -> Option<MatrixRef<'a, T>> {
3657    matrix.map(MatrixMut::as_ref)
3658}
3659
3660fn strided_batched_matrix_ref_parts<T>(
3661    matrix: Option<StridedBatchedMatrixRef<'_, T>>,
3662) -> Option<(&DeviceMemory<T>, usize, usize)> {
3663    matrix.map(|matrix| (matrix.data, matrix.leading_dimension, matrix.stride))
3664}
3665
3666fn strided_batched_matrix_mut_parts<T>(
3667    matrix: Option<StridedBatchedMatrixMut<'_, T>>,
3668) -> Option<(&mut DeviceMemory<T>, usize, usize)> {
3669    matrix.map(|matrix| (matrix.data, matrix.leading_dimension, matrix.stride))
3670}
3671
3672fn strided_batched_matrix_mut_ref_option<'a, T>(
3673    matrix: Option<&'a StridedBatchedMatrixMut<'a, T>>,
3674) -> Option<StridedBatchedMatrixRef<'a, T>> {
3675    matrix.map(StridedBatchedMatrixMut::as_ref)
3676}
3677
3678fn validate_xgesvdr_inputs<TU, TV>(
3679    m: usize,
3680    n: usize,
3681    k: usize,
3682    p: usize,
3683    niters: usize,
3684    a_bytes: usize,
3685    lda: usize,
3686    a_type: DataType,
3687    s_bytes: usize,
3688    s_type: DataType,
3689    job_u: TruncatedSvdMode,
3690    u: Option<&(&DeviceMemory<TU>, usize)>,
3691    u_type: DataType,
3692    job_v: TruncatedSvdMode,
3693    v: Option<&(&DeviceMemory<TV>, usize)>,
3694    v_type: DataType,
3695) -> Result<()> {
3696    if m == 0 || n == 0 || k == 0 || k >= m.min(n) || p == 0 || k.checked_add(p).is_none() {
3697        return Err(Error::InvalidMatrixShape);
3698    }
3699    let kp = k.checked_add(p).ok_or(Error::InvalidMatrixShape)?;
3700    if kp >= m.min(n) || niters == 0 {
3701        return Err(Error::InvalidMatrixShape);
3702    }
3703    validate_x_matrix(m, n, a_bytes, lda, a_type)?;
3704    validate_x_vector(k, s_bytes, s_type)?;
3705    if matches!(job_u, TruncatedSvdMode::Some) {
3706        let Some((u, ldu)) = u else {
3707            return Err(Error::InvalidMatrixShape);
3708        };
3709        validate_x_matrix(m, k, u.byte_len(), *ldu, u_type)?;
3710    }
3711    if matches!(job_v, TruncatedSvdMode::Some) {
3712        let Some((v, ldv)) = v else {
3713            return Err(Error::InvalidMatrixShape);
3714        };
3715        validate_x_matrix(n, k, v.byte_len(), *ldv, v_type)?;
3716    }
3717    Ok(())
3718}
3719
3720fn validate_gesvdj_inputs<T>(
3721    m: usize,
3722    n: usize,
3723    a_len: usize,
3724    lda: usize,
3725    s_len: usize,
3726    jobz: EigenMode,
3727    econ: bool,
3728    u: Option<(&DeviceMemory<T>, usize)>,
3729    v: Option<(&DeviceMemory<T>, usize)>,
3730) -> Result<()> {
3731    if m == 0 || n == 0 {
3732        return Err(Error::InvalidMatrixShape);
3733    }
3734    validate_matrix(m, n, a_len, lda)?;
3735    if s_len < m.min(n) {
3736        return Err(Error::InvalidVectorShape);
3737    }
3738    validate_gesvdj_output(m, n, jobz, econ, u)?;
3739    validate_gesvdj_output(n, n, jobz, econ, v)?;
3740    Ok(())
3741}
3742
3743fn validate_gesvda_strided_batched_inputs<T>(
3744    rank: usize,
3745    m: usize,
3746    n: usize,
3747    a_len: usize,
3748    lda: usize,
3749    stride_a: usize,
3750    s_len: usize,
3751    stride_s: usize,
3752    jobz: EigenMode,
3753    u: Option<(&DeviceMemory<T>, usize, usize)>,
3754    v: Option<(&DeviceMemory<T>, usize, usize)>,
3755    batch_size: usize,
3756) -> Result<()> {
3757    if batch_size == 0 || m == 0 || n == 0 || m < n || rank == 0 || rank > n {
3758        return Err(Error::InvalidMatrixShape);
3759    }
3760
3761    validate_strided_matrix(m, n, a_len, lda, stride_a, batch_size)?;
3762    validate_strided_vector(s_len, n, stride_s, batch_size)?;
3763
3764    match jobz {
3765        EigenMode::NoVector => {}
3766        EigenMode::Vector => {
3767            let Some((u, ldu, stride_u)) = u else {
3768                return Err(Error::InvalidMatrixShape);
3769            };
3770            let Some((v, ldv, stride_v)) = v else {
3771                return Err(Error::InvalidMatrixShape);
3772            };
3773            validate_strided_matrix(m, rank, u.len(), ldu, stride_u, batch_size)?;
3774            validate_strided_matrix(n, rank, v.len(), ldv, stride_v, batch_size)?;
3775        }
3776    }
3777    Ok(())
3778}
3779
3780fn validate_gesvdj_batched_inputs<T>(
3781    m: usize,
3782    n: usize,
3783    a_len: usize,
3784    lda: usize,
3785    s_len: usize,
3786    jobz: EigenMode,
3787    u: Option<(&DeviceMemory<T>, usize)>,
3788    v: Option<(&DeviceMemory<T>, usize)>,
3789    batch_size: usize,
3790) -> Result<()> {
3791    if batch_size == 0 || m == 0 || n == 0 || m > 32 || n > 32 {
3792        return Err(Error::InvalidMatrixShape);
3793    }
3794
3795    let a_cols = n.checked_mul(batch_size).ok_or(Error::InvalidMatrixShape)?;
3796    validate_matrix(m, a_cols, a_len, lda)?;
3797
3798    let s_required = m
3799        .min(n)
3800        .checked_mul(batch_size)
3801        .ok_or(Error::InvalidVectorShape)?;
3802    if s_len < s_required {
3803        return Err(Error::InvalidVectorShape);
3804    }
3805
3806    validate_gesvdj_batched_output(m, n, jobz, u, batch_size)?;
3807    validate_gesvdj_batched_output(n, n, jobz, v, batch_size)?;
3808    Ok(())
3809}
3810
3811fn validate_gesvdj_output<T>(
3812    rows: usize,
3813    cols: usize,
3814    jobz: EigenMode,
3815    econ: bool,
3816    matrix: Option<(&DeviceMemory<T>, usize)>,
3817) -> Result<()> {
3818    match jobz {
3819        EigenMode::NoVector => Ok(()),
3820        EigenMode::Vector => {
3821            let Some((matrix, ld)) = matrix else {
3822                return Err(Error::InvalidMatrixShape);
3823            };
3824            let out_cols = if econ { rows.min(cols) } else { cols };
3825            validate_matrix(rows, out_cols, matrix.len(), ld)
3826        }
3827    }
3828}
3829
3830fn validate_gesvdj_batched_output<T>(
3831    rows: usize,
3832    cols: usize,
3833    jobz: EigenMode,
3834    matrix: Option<(&DeviceMemory<T>, usize)>,
3835    batch_size: usize,
3836) -> Result<()> {
3837    match jobz {
3838        EigenMode::NoVector => Ok(()),
3839        EigenMode::Vector => {
3840            let Some((matrix, ld)) = matrix else {
3841                return Err(Error::InvalidMatrixShape);
3842            };
3843            let out_cols = rows
3844                .min(cols)
3845                .checked_mul(batch_size)
3846                .ok_or(Error::InvalidMatrixShape)?;
3847            validate_matrix(rows, out_cols, matrix.len(), ld)
3848        }
3849    }
3850}
3851
3852fn optional_gesvda_output_ptr<T>(
3853    matrix: Option<(&DeviceMemory<T>, usize, usize)>,
3854    rows: usize,
3855    cols: usize,
3856    jobz: EigenMode,
3857) -> Result<(*mut T, i32, i64)> {
3858    match jobz {
3859        EigenMode::NoVector => Ok((ptr::null_mut(), 1, 0)),
3860        EigenMode::Vector => {
3861            let Some((matrix, ld, stride)) = matrix else {
3862                return Err(Error::InvalidMatrixShape);
3863            };
3864            validate_strided_matrix(rows, cols, matrix.len(), ld, stride, 1)?;
3865            Ok((
3866                matrix.as_ptr() as *mut T,
3867                to_i32(ld, "ld")?,
3868                to_i64(stride, "stride")?,
3869            ))
3870        }
3871    }
3872}
3873
3874fn optional_gesvda_output_mut_ptr<T>(
3875    matrix: Option<(&mut DeviceMemory<T>, usize, usize)>,
3876    rows: usize,
3877    cols: usize,
3878    jobz: EigenMode,
3879) -> Result<(*mut T, i32, i64)> {
3880    match jobz {
3881        EigenMode::NoVector => Ok((ptr::null_mut(), 1, 0)),
3882        EigenMode::Vector => {
3883            let Some((matrix, ld, stride)) = matrix else {
3884                return Err(Error::InvalidMatrixShape);
3885            };
3886            validate_strided_matrix(rows, cols, matrix.len(), ld, stride, 1)?;
3887            Ok((
3888                matrix.as_mut_ptr().cast(),
3889                to_i32(ld, "ld")?,
3890                to_i64(stride, "stride")?,
3891            ))
3892        }
3893    }
3894}
3895
3896fn validate_gesvd_inputs<T>(
3897    m: usize,
3898    n: usize,
3899    a_len: usize,
3900    lda: usize,
3901    s_len: usize,
3902    job_u: SvdMode,
3903    u: Option<&(&DeviceMemory<T>, usize)>,
3904    job_vt: SvdMode,
3905    vt: Option<&(&DeviceMemory<T>, usize)>,
3906) -> Result<()> {
3907    validate_gesvd_dims(m, n)?;
3908    validate_matrix(m, n, a_len, lda)?;
3909    if s_len < n {
3910        return Err(Error::InvalidVectorShape);
3911    }
3912    validate_svd_output(m, m, job_u, u)?;
3913    validate_svd_output(n, n, job_vt, vt)?;
3914    Ok(())
3915}
3916
3917fn validate_x_svd_output<T>(
3918    rows: usize,
3919    full_cols: usize,
3920    matrix: Option<(&DeviceMemory<T>, usize)>,
3921    mode: SvdMode,
3922    data_type: DataType,
3923) -> Result<()> {
3924    match mode {
3925        SvdMode::None | SvdMode::Overwrite => Ok(()),
3926        SvdMode::All => {
3927            let Some((matrix, ld)) = matrix else {
3928                return Err(Error::InvalidMatrixShape);
3929            };
3930            validate_x_matrix(rows, full_cols, matrix.byte_len(), ld, data_type)
3931        }
3932        SvdMode::Some => {
3933            let Some((matrix, ld)) = matrix else {
3934                return Err(Error::InvalidMatrixShape);
3935            };
3936            validate_x_matrix(rows, full_cols.min(rows), matrix.byte_len(), ld, data_type)
3937        }
3938    }
3939}
3940
3941fn validate_svd_output<T>(
3942    rows: usize,
3943    full_cols: usize,
3944    mode: SvdMode,
3945    matrix: Option<&(&DeviceMemory<T>, usize)>,
3946) -> Result<()> {
3947    match mode {
3948        SvdMode::None | SvdMode::Overwrite => Ok(()),
3949        SvdMode::All => {
3950            let Some((matrix, ld)) = matrix else {
3951                return Err(Error::InvalidMatrixShape);
3952            };
3953            validate_matrix(rows, full_cols, matrix.len(), *ld)
3954        }
3955        SvdMode::Some => {
3956            let Some((matrix, ld)) = matrix else {
3957                return Err(Error::InvalidMatrixShape);
3958            };
3959            validate_matrix(rows, full_cols.min(rows), matrix.len(), *ld)
3960        }
3961    }
3962}
3963
3964fn validate_x_eig_output(
3965    rows: usize,
3966    cols: usize,
3967    bytes: usize,
3968    ld: usize,
3969    econ: bool,
3970    data_type: DataType,
3971) -> Result<()> {
3972    let out_cols = if econ { rows.min(cols) } else { cols };
3973    validate_x_matrix(rows, out_cols, bytes, ld, data_type)
3974}
3975
3976fn optional_matrix_ptr<T>(
3977    matrix: Option<(&mut DeviceMemory<T>, usize)>,
3978    order: usize,
3979    mode: SvdMode,
3980) -> Result<(*mut T, i32)> {
3981    match mode {
3982        SvdMode::None | SvdMode::Overwrite => Ok((ptr::null_mut(), to_i32(order.max(1), "ld")?)),
3983        SvdMode::All | SvdMode::Some => {
3984            let Some((matrix, ld)) = matrix else {
3985                return Err(Error::InvalidMatrixShape);
3986            };
3987            Ok((matrix.as_mut_ptr().cast(), to_i32(ld, "ld")?))
3988        }
3989    }
3990}
3991
3992fn optional_x_matrix_ptr<T>(
3993    matrix: Option<(&DeviceMemory<T>, usize)>,
3994    rows: usize,
3995    cols: usize,
3996    mode: SvdMode,
3997    data_type: DataType,
3998) -> Result<(*mut T, i64)> {
3999    match mode {
4000        SvdMode::None | SvdMode::Overwrite => Ok((ptr::null_mut(), 1)),
4001        SvdMode::All => {
4002            let Some((matrix, ld)) = matrix else {
4003                return Err(Error::InvalidMatrixShape);
4004            };
4005            validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4006            Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4007        }
4008        SvdMode::Some => {
4009            let Some((matrix, ld)) = matrix else {
4010                return Err(Error::InvalidMatrixShape);
4011            };
4012            validate_x_matrix(rows, cols.min(rows), matrix.byte_len(), ld, data_type)?;
4013            Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4014        }
4015    }
4016}
4017
4018fn optional_x_matrix_mut_ptr<T>(
4019    matrix: Option<(&mut DeviceMemory<T>, usize)>,
4020    rows: usize,
4021    cols: usize,
4022    mode: SvdMode,
4023    data_type: DataType,
4024) -> Result<(*mut T, i64)> {
4025    match mode {
4026        SvdMode::None | SvdMode::Overwrite => Ok((ptr::null_mut(), 1)),
4027        SvdMode::All => {
4028            let Some((matrix, ld)) = matrix else {
4029                return Err(Error::InvalidMatrixShape);
4030            };
4031            validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4032            Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4033        }
4034        SvdMode::Some => {
4035            let Some((matrix, ld)) = matrix else {
4036                return Err(Error::InvalidMatrixShape);
4037            };
4038            validate_x_matrix(rows, cols.min(rows), matrix.byte_len(), ld, data_type)?;
4039            Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4040        }
4041    }
4042}
4043
4044fn optional_x_eig_matrix_ptr<T>(
4045    matrix: Option<(&DeviceMemory<T>, usize)>,
4046    rows: usize,
4047    cols: usize,
4048    jobz: EigenMode,
4049    econ: bool,
4050    data_type: DataType,
4051) -> Result<(*mut T, i64)> {
4052    match jobz {
4053        EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
4054        EigenMode::Vector => {
4055            let Some((matrix, ld)) = matrix else {
4056                return Err(Error::InvalidMatrixShape);
4057            };
4058            validate_x_eig_output(rows, cols, matrix.byte_len(), ld, econ, data_type)?;
4059            Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4060        }
4061    }
4062}
4063
4064fn optional_x_eig_matrix_mut_ptr<T>(
4065    matrix: Option<(&mut DeviceMemory<T>, usize)>,
4066    rows: usize,
4067    cols: usize,
4068    jobz: EigenMode,
4069    econ: bool,
4070    data_type: DataType,
4071) -> Result<(*mut T, i64)> {
4072    match jobz {
4073        EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
4074        EigenMode::Vector => {
4075            let Some((matrix, ld)) = matrix else {
4076                return Err(Error::InvalidMatrixShape);
4077            };
4078            validate_x_eig_output(rows, cols, matrix.byte_len(), ld, econ, data_type)?;
4079            Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4080        }
4081    }
4082}
4083
4084fn optional_x_truncated_u_ptr<T>(
4085    matrix: Option<(&DeviceMemory<T>, usize)>,
4086    rows: usize,
4087    cols: usize,
4088    mode: TruncatedSvdMode,
4089    data_type: DataType,
4090) -> Result<(*mut T, i64)> {
4091    match mode {
4092        TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
4093        TruncatedSvdMode::Some => {
4094            let Some((matrix, ld)) = matrix else {
4095                return Err(Error::InvalidMatrixShape);
4096            };
4097            validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4098            Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4099        }
4100    }
4101}
4102
4103fn optional_x_truncated_u_mut_ptr<T>(
4104    matrix: Option<(&mut DeviceMemory<T>, usize)>,
4105    rows: usize,
4106    cols: usize,
4107    mode: TruncatedSvdMode,
4108    data_type: DataType,
4109) -> Result<(*mut T, i64)> {
4110    match mode {
4111        TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
4112        TruncatedSvdMode::Some => {
4113            let Some((matrix, ld)) = matrix else {
4114                return Err(Error::InvalidMatrixShape);
4115            };
4116            validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4117            Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4118        }
4119    }
4120}
4121
4122fn optional_x_truncated_v_ptr<T>(
4123    matrix: Option<(&DeviceMemory<T>, usize)>,
4124    rows: usize,
4125    cols: usize,
4126    mode: TruncatedSvdMode,
4127    data_type: DataType,
4128) -> Result<(*mut T, i64)> {
4129    match mode {
4130        TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
4131        TruncatedSvdMode::Some => {
4132            let Some((matrix, ld)) = matrix else {
4133                return Err(Error::InvalidMatrixShape);
4134            };
4135            validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4136            Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
4137        }
4138    }
4139}
4140
4141fn optional_x_truncated_v_mut_ptr<T>(
4142    matrix: Option<(&mut DeviceMemory<T>, usize)>,
4143    rows: usize,
4144    cols: usize,
4145    mode: TruncatedSvdMode,
4146    data_type: DataType,
4147) -> Result<(*mut T, i64)> {
4148    match mode {
4149        TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
4150        TruncatedSvdMode::Some => {
4151            let Some((matrix, ld)) = matrix else {
4152                return Err(Error::InvalidMatrixShape);
4153            };
4154            validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
4155            Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
4156        }
4157    }
4158}
4159
4160fn optional_gesvdj_matrix_ptr<T>(
4161    matrix: Option<(&DeviceMemory<T>, usize)>,
4162    rows: usize,
4163    cols: usize,
4164    jobz: EigenMode,
4165    econ: bool,
4166) -> Result<(*mut T, i32)> {
4167    match jobz {
4168        EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
4169        EigenMode::Vector => {
4170            let Some((matrix, ld)) = matrix else {
4171                return Err(Error::InvalidMatrixShape);
4172            };
4173            let out_cols = if econ { rows.min(cols) } else { cols };
4174            validate_matrix(rows, out_cols, matrix.len(), ld)?;
4175            Ok((matrix.as_ptr() as *mut T, to_i32(ld, "ld")?))
4176        }
4177    }
4178}
4179
4180fn optional_gesvdj_matrix_mut_ptr<T>(
4181    matrix: Option<(&mut DeviceMemory<T>, usize)>,
4182    rows: usize,
4183    cols: usize,
4184    jobz: EigenMode,
4185    econ: bool,
4186) -> Result<(*mut T, i32)> {
4187    match jobz {
4188        EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
4189        EigenMode::Vector => {
4190            let Some((matrix, ld)) = matrix else {
4191                return Err(Error::InvalidMatrixShape);
4192            };
4193            let out_cols = if econ { rows.min(cols) } else { cols };
4194            validate_matrix(rows, out_cols, matrix.len(), ld)?;
4195            Ok((matrix.as_mut_ptr().cast(), to_i32(ld, "ld")?))
4196        }
4197    }
4198}
4199
4200fn require_rwork_buffer<T>(rwork: Option<&DeviceMemory<T>>, m: usize, n: usize) -> Result<()> {
4201    let required = n.saturating_sub(1).min(m.saturating_sub(1));
4202    if let Some(rwork) = rwork
4203        && rwork.len() < required
4204    {
4205        return Err(Error::InvalidVectorShape);
4206    }
4207    Ok(())
4208}
4209
4210fn validate_matrix(rows: usize, cols: usize, len: usize, lda: usize) -> Result<()> {
4211    if rows == 0 || cols == 0 {
4212        return Err(Error::InvalidMatrixShape);
4213    }
4214    if lda < rows {
4215        return Err(Error::InvalidLeadingDimension);
4216    }
4217    let required = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
4218    if len < required {
4219        return Err(Error::InvalidMatrixShape);
4220    }
4221    Ok(())
4222}
4223
4224fn validate_x_matrix(
4225    rows: usize,
4226    cols: usize,
4227    bytes: usize,
4228    lda: usize,
4229    data_type: DataType,
4230) -> Result<()> {
4231    if rows == 0 || cols == 0 {
4232        return Err(Error::InvalidMatrixShape);
4233    }
4234    if lda < rows {
4235        return Err(Error::InvalidLeadingDimension);
4236    }
4237    let elem_size = data_type.size_in_bytes();
4238    let required = lda
4239        .checked_mul(cols)
4240        .and_then(|count| count.checked_mul(elem_size))
4241        .ok_or(Error::InvalidMatrixShape)?;
4242    if bytes < required {
4243        return Err(Error::InvalidMatrixShape);
4244    }
4245    Ok(())
4246}
4247
4248fn validate_x_vector(len: usize, bytes: usize, data_type: DataType) -> Result<()> {
4249    let required = len
4250        .checked_mul(data_type.size_in_bytes())
4251        .ok_or(Error::InvalidVectorShape)?;
4252    if bytes < required {
4253        return Err(Error::InvalidVectorShape);
4254    }
4255    Ok(())
4256}
4257
4258fn validate_strided_matrix(
4259    rows: usize,
4260    cols: usize,
4261    len: usize,
4262    lda: usize,
4263    stride: usize,
4264    batch_size: usize,
4265) -> Result<()> {
4266    validate_matrix(rows, cols, len, lda)?;
4267    if batch_size == 0 {
4268        return Err(Error::InvalidMatrixShape);
4269    }
4270    let footprint = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
4271    if stride < footprint {
4272        return Err(Error::InvalidMatrixShape);
4273    }
4274    let required = if batch_size == 1 {
4275        footprint
4276    } else {
4277        stride
4278            .checked_mul(batch_size - 1)
4279            .and_then(|base| base.checked_add(footprint))
4280            .ok_or(Error::InvalidMatrixShape)?
4281    };
4282    if len < required {
4283        return Err(Error::InvalidMatrixShape);
4284    }
4285    Ok(())
4286}
4287
4288fn validate_strided_vector(
4289    len: usize,
4290    width: usize,
4291    stride: usize,
4292    batch_size: usize,
4293) -> Result<()> {
4294    if width == 0 || batch_size == 0 {
4295        return Err(Error::InvalidVectorShape);
4296    }
4297    if stride < width {
4298        return Err(Error::InvalidVectorShape);
4299    }
4300    let required = if batch_size == 1 {
4301        width
4302    } else {
4303        stride
4304            .checked_mul(batch_size - 1)
4305            .and_then(|base| base.checked_add(width))
4306            .ok_or(Error::InvalidVectorShape)?
4307    };
4308    if len < required {
4309        return Err(Error::InvalidVectorShape);
4310    }
4311    Ok(())
4312}
4313
4314fn require_workspace(actual: usize, required: usize) -> Result<()> {
4315    if actual < required {
4316        return Err(Error::InsufficientWorkspaceSize { required, actual });
4317    }
4318    Ok(())
4319}
4320
4321fn require_workspace_bytes(actual: usize, required: usize) -> Result<()> {
4322    if actual < required {
4323        return Err(Error::InsufficientWorkspaceSize { required, actual });
4324    }
4325    Ok(())
4326}
4327
4328fn require_host_workspace(actual: usize, required: usize) -> Result<()> {
4329    if actual < required {
4330        return Err(Error::InsufficientWorkspaceSize { required, actual });
4331    }
4332    Ok(())
4333}
4334
4335fn require_info_buffer(dev_info: &DeviceMemory<i32>) -> Result<()> {
4336    if dev_info.is_empty() {
4337        return Err(Error::InvalidVectorShape);
4338    }
4339    Ok(())
4340}
4341
4342fn require_info_buffer_len(dev_info: &DeviceMemory<i32>, required: usize) -> Result<()> {
4343    if dev_info.len() < required {
4344        return Err(Error::InvalidVectorShape);
4345    }
4346    Ok(())
4347}
4348
#[cfg(all(test, feature = "testing"))]
mod tests {
    use singe_core::assert_close;
    use singe_cuda::memory::DeviceMemory;

    use super::*;
    use crate::{params::Params, testing::setup_context_if_available};

    /// `sgesvd` on a 2x2 diagonal matrix with entries 3 and 2 must report
    /// those entries as its singular values.
    #[test]
    fn test_sgesvd_returns_expected_singular_values() -> Result<()> {
        // Pass trivially when no context could be set up.
        let ctx = match setup_context_if_available()? {
            Some(ctx) => ctx,
            None => return Ok(()),
        };

        let mut matrix = DeviceMemory::from_slice(&[
            3.0_f32, 0.0, //
            0.0, 2.0,
        ])?;
        let mut sigma = DeviceMemory::create(2)?;
        let workspace_len = sgesvd_buffer_size(&ctx, 2, 2)?;
        let mut workspace = DeviceMemory::create(workspace_len)?;
        let mut dev_info = DeviceMemory::create(1)?;

        // Singular values only: neither U nor VT is requested.
        sgesvd(
            &ctx,
            SvdMode::None,
            SvdMode::None,
            2,
            2,
            MatrixMut::new(&mut matrix, 2),
            &mut sigma,
            None,
            None,
            &mut workspace,
            None,
            &mut dev_info,
        )?;

        let info = dev_info.copy_to_host_vec()?;
        let singular_values = sigma.copy_to_host_vec()?;

        assert_eq!(info, vec![0]);
        assert_close!(&singular_values, &[3.0, 2.0], 1.0e-5);
        Ok(())
    }

    /// `xgesvd` on a 2x2 diagonal matrix with entries 3 and 2 must report
    /// those entries as its singular values.
    #[test]
    fn test_xgesvd_returns_expected_singular_values() -> Result<()> {
        // Pass trivially when no context could be set up.
        let ctx = match setup_context_if_available()? {
            Some(ctx) => ctx,
            None => return Ok(()),
        };
        let params = Params::create()?;

        let mut matrix = DeviceMemory::from_slice(&[
            3.0_f32, 0.0, //
            0.0, 2.0,
        ])?;
        let mut sigma = DeviceMemory::create(2)?;

        // Query the device and host workspace requirements first.
        let sizes = xgesvd_buffer_size::<f32, f32, f32, f32>(
            &ctx,
            &params,
            SvdMode::None,
            SvdMode::None,
            2,
            2,
            MatrixRef::new(&matrix, 2),
            &sigma,
            None,
            None,
        )?;

        // Allocate at least one byte on each side even when the query reports zero.
        let mut device_workspace = DeviceMemory::create(sizes.device_bytes.max(1))?;
        let mut host_workspace = vec![0_u8; sizes.host_bytes.max(1)];
        let mut dev_info = DeviceMemory::create(1)?;

        xgesvd::<f32, f32, f32, f32>(
            &ctx,
            &params,
            SvdMode::None,
            SvdMode::None,
            2,
            2,
            MatrixMut::new(&mut matrix, 2),
            &mut sigma,
            None,
            None,
            ByteWorkspaceMut::new(&mut device_workspace, &mut host_workspace),
            &mut dev_info,
        )?;

        let info = dev_info.copy_to_host_vec()?;
        let singular_values = sigma.copy_to_host_vec()?;

        assert_eq!(info, vec![0]);
        assert_close!(&singular_values, &[3.0, 2.0], 1.0e-5);
        Ok(())
    }
}