Skip to main content

singe_cusolver/
irs.rs

1#[allow(unused_imports)]
2use crate::error::Status;
3
4use std::{ptr, slice};
5
6use singe_cuda::{data_type::DataTypeLike, memory::DeviceMemory};
7
8use crate::{
9    context::Context,
10    error::{Error, Result},
11    layout::{MatrixMut, MatrixRef},
12    sys, try_ffi,
13    types::{IrsRefinement, PrecisionType},
14    utility::{to_i32, to_u64},
15};
16
17#[derive(Debug)]
18pub struct IrsParams {
19    handle: sys::cusolverDnIRSParams_t,
20    main_precision: Option<PrecisionType>,
21    lowest_precision: Option<PrecisionType>,
22}
23
24#[derive(Debug, Default)]
25pub struct IrsInfos {
26    handle: sys::cusolverDnIRSInfos_t,
27    residual_history_requested: bool,
28}
29
/// One row of the IRS residual history.
///
/// Both values use the same element type `T` because cuSOLVER reports the
/// iteration count and the residual norm in a single matrix of that type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ResidualHistoryEntry<T> {
    /// First-column value: total number of iterations performed up to this row.
    pub total_iterations: T,
    /// Second-column value: residual norm recorded for this row.
    pub residual_norm: T,
}
35
36#[derive(Debug, Clone, PartialEq)]
37pub struct ResidualHistory<T> {
38    pub rows: Vec<ResidualHistoryEntry<T>>,
39    pub leading_dimension: usize,
40}
41
/// Problem dimensions for an IRS `gesv`-style solve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IrsSolve {
    /// Dimension `n` forwarded to the solver.
    pub n: usize,
    /// Number of right-hand sides forwarded to the solver.
    pub right_hand_sides: usize,
}
47
48impl IrsSolve {
49    pub fn new(n: usize, right_hand_sides: usize) -> Self {
50        Self {
51            n,
52            right_hand_sides,
53        }
54    }
55
56    pub fn workspace_size<T: DataTypeLike>(
57        self,
58        ctx: &Context,
59        params: &mut IrsParams,
60    ) -> Result<usize> {
61        xgesv_buffer_size::<T>(ctx, params, self.n, self.right_hand_sides)
62    }
63
64    pub fn execute<T: DataTypeLike>(
65        self,
66        ctx: &Context,
67        params: &mut IrsParams,
68        infos: &IrsInfos,
69        bindings: IrsSolveBindings<'_, T>,
70    ) -> Result<i32> {
71        xgesv(
72            ctx,
73            params,
74            infos,
75            self.n,
76            self.right_hand_sides,
77            bindings.a,
78            bindings.b,
79            bindings.x,
80            bindings.device_workspace,
81            bindings.dev_info,
82        )
83    }
84}
85
86#[derive(Debug)]
87pub struct IrsSolveBindings<'a, T> {
88    pub a: MatrixMut<'a, T>,
89    pub b: MatrixRef<'a, T>,
90    pub x: MatrixMut<'a, T>,
91    pub device_workspace: &'a mut DeviceMemory<u8>,
92    pub dev_info: &'a mut DeviceMemory<i32>,
93}
94
95// IRS parameter/info handles expose mutation through &mut self and inspection
96// through shared references, so immutable sharing follows the cuSOLVER contract.
97unsafe impl Send for IrsParams {}
98unsafe impl Sync for IrsParams {}
99unsafe impl Send for IrsInfos {}
100unsafe impl Sync for IrsInfos {}
101
102impl IrsParams {
103    /// Creates and initializes the parameter structure for IRS solvers such as
104    /// [`xgesv`] and [`xgels`].
105    ///
106    /// The returned parameter structure can be reused across calls to the same
107    /// IRS solver or to different IRS solvers.
108    ///
109    /// In CUDA 10.2, the behavior was different and a new parameter structure
110    /// was required for each IRS solve call.
111    ///
112    /// You can also reconfigure the parameters between solves, but only after
113    /// the previous IRS call has completed.
114    ///
115    /// # Errors
116    ///
117    /// Returns an error if cuSOLVER cannot allocate the required resources
118    /// or does not return a valid handle.
119    pub fn create() -> Result<Self> {
120        let mut handle = ptr::null_mut();
121        unsafe {
122            try_ffi!(sys::cusolverDnIRSParamsCreate(&raw mut handle))?;
123        }
124        if handle.is_null() {
125            return Err(Error::NullHandle);
126        }
127        let mut params = Self {
128            handle,
129            main_precision: None,
130            lowest_precision: None,
131        };
132        params.set_refinement_solver(IrsRefinement::None)?;
133        Ok(params)
134    }
135
136    /// Sets the refinement solver used by IRS operations such as [`xgesv`] and
137    /// [`xgels`].
138    ///
139    /// [`IrsParams::create`] initializes the refinement solver to [`IrsRefinement::None`]; call this method before the first IRS solve to select a different refinement algorithm.
140    ///
141    /// The supported values are described below.
142    ///
143    /// [`IrsRefinement::NotSet`]: Solver is not set. The IRS solver returns an
144    /// error if this value is used.
145    ///
146    /// [`IrsRefinement::None`]: No refinement solver; the IRS solver performs a factorization followed by a solve without any refinement.
147    /// For example, if the IRS solver was [`xgesv`], this is equivalent to an
148    /// [`xgesv`] solve without refinement, with the factorization carried out in
149    /// the lowest configured precision.
150    /// If both the main and lowest precision are [`PrecisionType::R64F`], the
151    /// solve is effectively performed in `f64`.
152    ///
153    /// [`IrsRefinement::Classical`]: Classical iterative refinement solver.
154    /// Similar to the value used in LAPACK operations.
155    ///
156    /// [`IrsRefinement::Gmres`]: GMRES (Generalized Minimal Residual) based iterative refinement solver.
157    /// Recent studies use GMRES as a refinement solver that can outperform
158    /// classical iterative refinement.
159    /// Recommended setting based on cuSOLVER experimentation.
160    ///
161    /// [`IrsRefinement::ClassicalGmres`]: Classical iterative refinement solver that uses the GMRES (Generalized Minimal Residual) internally to solve the correction equation at each iteration.
162    /// The classical refinement iteration is the outer iteration, and GMRES is
163    /// the inner iteration.
164    /// If the tolerance of the inner GMRES is set very low, for
165    /// example near machine precision, then the outer *classical refinement
166    /// iteration* performs only one iteration and this option behaves like
167    /// [`IrsRefinement::Gmres`].
168    ///
169    /// [`IrsRefinement::GmresGmres`]: GMRES-based iterative refinement solver
170    /// that uses another GMRES solve internally for the preconditioned system.
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if cuSOLVER rejects the parameter structure.
175    pub fn set_refinement_solver(&mut self, refinement: IrsRefinement) -> Result<()> {
176        unsafe {
177            try_ffi!(sys::cusolverDnIRSParamsSetRefinementSolver(
178                self.as_raw(),
179                refinement.into(),
180            ))?;
181        }
182        Ok(())
183    }
184
185    /// Sets the main precision for the Iterative Refinement Solver (IRS).
186    ///
187    /// The main precision is the type of the input and output data.
188    /// Configure both the main and lowest precision before the first IRS solve. Those
189    /// values are not inferred when the parameter structure is created because
190    /// they depend on the input/output data type and the requested solver
191    /// configuration. You can set them independently or together with
192    /// [`IrsParams::set_solver_precisions`].
193    ///
194    /// # Errors
195    ///
196    /// Returns an error if cuSOLVER rejects the parameter structure.
197    pub fn set_main_precision(&mut self, precision: PrecisionType) -> Result<()> {
198        unsafe {
199            try_ffi!(sys::cusolverDnIRSParamsSetSolverMainPrecision(
200                self.as_raw(),
201                precision.into(),
202            ))?;
203        }
204        self.main_precision = Some(precision);
205        Ok(())
206    }
207
208    /// Sets the lowest precision that the IRS solver may use.
209    ///
210    /// The lowest precision is the minimum compute precision used
211    /// during the LU factorization process.
212    ///
213    /// Configure both the main and lowest precision before the first IRS solve. They
214    /// are not inferred when creating the parameter structure because they
215    /// depend on the input and output data types and the requested solver
216    /// configuration.
217    /// Usually the lowest precision defines the speedup that can be achieved.
218    /// The ratio between the performance of the lowest precision and the main
219    /// precision gives an approximate upper bound on the speedup.
220    /// More precisely, it depends on many factors, but for large matrices it is
221    /// often tied to the performance ratio of large GEMM-like kernels.
222    /// For instance, if the input/output precision is real double precision
223    /// [`PrecisionType::R64F`] and the lowest precision is
224    /// [`PrecisionType::R32F`], then a speedup of at most about 2x is expected
225    /// for large problem sizes.
226    /// If the lowest precision is [`PrecisionType::R16F`], expect 3x-4x.
227    /// A reasonable strategy accounts for the number of right-hand sides, the matrix size, and the convergence rate.
228    ///
229    /// # Errors
230    ///
231    /// Returns an error if cuSOLVER rejects the parameter structure.
232    pub fn set_lowest_precision(&mut self, precision: PrecisionType) -> Result<()> {
233        unsafe {
234            try_ffi!(sys::cusolverDnIRSParamsSetSolverLowestPrecision(
235                self.as_raw(),
236                precision.into(),
237            ))?;
238        }
239        self.lowest_precision = Some(precision);
240        Ok(())
241    }
242
243    /// Sets both the main and lowest precision for the Iterative Refinement
244    /// Solver (IRS).
245    ///
246    /// The main precision is the precision of the input and output data.
247    /// The lowest precision is the minimum compute precision used
248    /// during the LU factorization process.
249    ///
250    /// Configure both values before the first IRS solve. They are not inferred when
251    /// creating the parameter structure because they depend on the input and
252    /// output data types and the requested solver configuration.
253    ///
254    /// Convenience wrapper around
255    /// [`IrsParams::set_main_precision`] and
256    /// [`IrsParams::set_lowest_precision`].
257    /// All possible combinations of main/lowest precision are described in the table below.
258    /// Usually the lowest precision defines the speedup that can be achieved.
259    /// The ratio between the performance of the lowest precision and the main
260    /// precision gives an approximate upper bound on the speedup.
261    /// More precisely, it depends on many factors, but for large matrices it is
262    /// often tied to the performance ratio of large GEMM-like kernels.
263    /// For instance, if the input/output precision is real double precision
264    /// [`PrecisionType::R64F`] and the lowest precision is
265    /// [`PrecisionType::R32F`], then a speedup of at most about 2x is expected
266    /// for large problem sizes.
267    /// If the lowest precision is [`PrecisionType::R16F`], expect 3x-4x.
268    /// A reasonable strategy accounts for the number of right-hand sides, the matrix size, and the convergence rate.
269    ///
270    /// **Supported input/output data type and lower precision for the IRS solver**
271    ///
272    /// | **input/output Data Type (for example, main precision)** | **Supported values for the lowest precision** |
273    /// | --- | --- |
274    /// | [`PrecisionType::C64F`] | [`PrecisionType::C64F`], [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
275    /// | [`PrecisionType::C32F`] | [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
276    /// | [`PrecisionType::R64F`] | [`PrecisionType::R64F`], [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
277    /// | [`PrecisionType::R32F`] | [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
278    ///
279    /// # Errors
280    ///
281    /// Returns an error if cuSOLVER rejects the parameter structure.
282    pub fn set_solver_precisions(
283        &mut self,
284        main_precision: PrecisionType,
285        lowest_precision: PrecisionType,
286    ) -> Result<()> {
287        unsafe {
288            try_ffi!(sys::cusolverDnIRSParamsSetSolverPrecisions(
289                self.as_raw(),
290                main_precision.into(),
291                lowest_precision.into(),
292            ))?;
293        }
294        self.main_precision = Some(main_precision);
295        self.lowest_precision = Some(lowest_precision);
296        Ok(())
297    }
298
299    /// Sets the tolerance for the refinement solver.
300    /// By default it is such that all the RHS satisfy:
301    ///
302    /// `RNRM &lt; SQRT(N)*XNRM*ANRM*EPS*BWDMAX` where
303    ///
304    /// * RNRM is the infinity-norm of the residual
305    /// * XNRM is the infinity-norm of the solution
306    /// * ANRM is the infinity-operator-norm of the matrix A
307    /// * EPS is the machine epsilon for the input/output data type that matches
308    ///   LAPACK `xLAMCH('Epsilon')`
309    /// * BWDMAX, the value BWDMAX is fixed to 1.0
310    ///
311    /// Use this to set the tolerance to a lower or higher value.
312    /// The tolerance value is always stored in real double precision,
313    /// regardless of the input and output data type.
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if cuSOLVER rejects the parameter structure.
318    pub fn set_tolerance(&mut self, tolerance: f64) -> Result<()> {
319        unsafe {
320            try_ffi!(sys::cusolverDnIRSParamsSetTol(self.as_raw(), tolerance))?;
321        }
322        Ok(())
323    }
324
325    /// Sets the tolerance for the inner refinement solver when
326    /// the refinement solver consists of two levels, for example
327    /// [`IrsRefinement::ClassicalGmres`] or [`IrsRefinement::GmresGmres`].
328    /// Ignored for one-level refinement solvers such as [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`].
329    /// The default value is 1e-4.
330    /// This sets the tolerance for the inner solver, such as the inner GMRES.
331    /// For example, if the refinement solver is
332    /// [`IrsRefinement::ClassicalGmres`], setting this tolerance means that the
333    /// inner GMRES solver converges to that tolerance at each outer
334    /// iteration of the classical refinement solver.
335    /// The tolerance value is always stored in real double precision,
336    /// regardless of the input and output data type.
337    ///
338    /// # Errors
339    ///
340    /// Returns an error if cuSOLVER rejects the parameter structure.
341    pub fn set_inner_tolerance(&mut self, tolerance: f64) -> Result<()> {
342        unsafe {
343            try_ffi!(sys::cusolverDnIRSParamsSetTolInner(
344                self.as_raw(),
345                tolerance,
346            ))?;
347        }
348        Ok(())
349    }
350
351    /// Sets the total number of allowed refinement iterations before the solver stops.
352    /// The total is the sum of the outer and inner iterations. Inner iterations are meaningful when a two-level refinement solver is configured.
353    /// The default value is 50.
354    ///
355    /// # Errors
356    ///
357    /// Returns an error if cuSOLVER rejects the parameter structure.
358    pub fn set_max_iterations(&mut self, max_iterations: i32) -> Result<()> {
359        unsafe {
360            try_ffi!(sys::cusolverDnIRSParamsSetMaxIters(
361                self.as_raw(),
362                max_iterations,
363            ))?;
364        }
365        Ok(())
366    }
367
368    /// Sets the maximum number of iterations allowed for the inner refinement solver.
369    /// Ignored for one-level refinement solvers such as [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`].
370    /// The inner refinement solver stops after reaching either the inner tolerance or `MaxItersInner`.
371    /// The default value is 50.
372    /// Cannot be larger than `MaxIters` because `MaxIters` is the total number of allowed iterations.
373    /// If [`IrsParams::set_max_iterations`] is called after this method, it has priority and overwrites `MaxItersInner` with `min(MaxIters, MaxItersInner)`.
374    ///
375    /// # Errors
376    ///
377    /// Returns an error if `max_iterations` is larger than `MaxIters`, or if
378    /// cuSOLVER rejects the parameter structure.
379    pub fn set_max_inner_iterations(&mut self, max_iterations: i32) -> Result<()> {
380        unsafe {
381            try_ffi!(sys::cusolverDnIRSParamsSetMaxItersInner(
382                self.as_raw(),
383                max_iterations,
384            ))?;
385        }
386        Ok(())
387    }
388
389    /// Returns the current maximum-iteration setting in this parameter structure.
390    /// Current parameter configuration, distinct from [`IrsInfos::max_iterations`], which returns the maximum number of iterations allowed for a particular IRS solver call.
391    /// The parameter structure can be reused across many IRS solver calls.
392    /// The allowed `MaxIters` value can change between calls, while the `Infos` structure contains information about one particular call and cannot be reused for different calls.
393    ///
394    /// # Errors
395    ///
396    /// Returns an error if cuSOLVER rejects the parameter structure.
397    pub fn max_iterations(&self) -> Result<i32> {
398        let mut value = 0;
399        unsafe {
400            try_ffi!(sys::cusolverDnIRSParamsGetMaxIters(
401                self.as_raw(),
402                &raw mut value,
403            ))?;
404        }
405        Ok(value)
406    }
407
408    /// Enables fallback to the main precision if the Iterative Refinement Solver (IRS) fails to converge.
409    /// If the IRS solver fails to converge, it returns a non-convergence code such as `niter < 0`.
410    /// With fallback disabled, it returns the non-convergent solution as-is.
411    /// With fallback enabled, it falls back to the main precision, which is the input/output data precision, and solves the problem again from scratch.
412    /// This fallback is the default behavior.
413    ///
414    /// # Errors
415    ///
416    /// Returns an error if cuSOLVER rejects the parameter structure.
417    pub fn enable_fallback(&mut self) -> Result<()> {
418        unsafe {
419            try_ffi!(sys::cusolverDnIRSParamsEnableFallback(self.as_raw()))?;
420        }
421        Ok(())
422    }
423
424    /// Disables fallback to the main precision if the Iterative Refinement Solver (IRS) fails to converge.
425    /// If the IRS solver fails to converge, it returns a non-convergence code such as `niter < 0`.
426    /// With fallback disabled, the returned solution is whatever the refinement solver reached before returning.
427    /// Disabling fallback does not guarantee that the solution is accurate.
428    /// Re-enable fallback with [`IrsParams::enable_fallback`].
429    ///
430    /// # Errors
431    ///
432    /// Returns an error if cuSOLVER rejects the parameter structure.
433    pub fn disable_fallback(&mut self) -> Result<()> {
434        unsafe {
435            try_ffi!(sys::cusolverDnIRSParamsDisableFallback(self.as_raw()))?;
436        }
437        Ok(())
438    }
439
440    fn ensure_type_precision<T: DataTypeLike>(&mut self) -> Result<()> {
441        let precision = PrecisionType::from_data_type(T::data_type())
442            .ok_or(Error::InvalidPrecisionConfiguration)?;
443        match self.main_precision {
444            Some(existing) if existing != precision => {
445                return Err(Error::InvalidPrecisionConfiguration);
446            }
447            None => self.set_main_precision(precision)?,
448            _ => {}
449        }
450        if self.lowest_precision.is_none() {
451            self.set_lowest_precision(precision)?;
452        }
453        Ok(())
454    }
455
456    pub fn as_raw(&self) -> sys::cusolverDnIRSParams_t {
457        self.handle
458    }
459}
460
461impl Drop for IrsParams {
462    fn drop(&mut self) {
463        unsafe {
464            if let Err(err) = try_ffi!(sys::cusolverDnIRSParamsDestroy(self.handle)) {
465                #[cfg(debug_assertions)]
466                eprintln!("failed to destroy cusolver irs params: {err}");
467            }
468        }
469    }
470}
471
472impl IrsInfos {
473    /// Creates and initializes the `Infos` structure that holds refinement information for an Iterative Refinement Solver (IRS) call.
474    /// Such information includes the total number of iterations needed to converge (`Niters`), the number of outer iterations (meaningful when a two-level preconditioner such as [`IrsRefinement::ClassicalGmres`] is used), the maximum number of iterations allowed for that call, and a pointer to the convergence-history residual norm matrix.
475    /// Construct the `Infos` structure before calling an IRS solver.
476    /// The `Infos` structure is valid for only one call to an IRS solver, since it holds information about that solve; each solve requires its own `Infos` structure.
477    ///
478    /// # Errors
479    ///
480    /// Returns an error if cuSOLVER cannot allocate the required resources
481    /// or does not return a valid handle.
482    pub fn create() -> Result<Self> {
483        let mut handle = ptr::null_mut();
484        unsafe {
485            try_ffi!(sys::cusolverDnIRSInfosCreate(&raw mut handle))?;
486        }
487        if handle.is_null() {
488            return Err(Error::NullHandle);
489        }
490        Ok(Self {
491            handle,
492            residual_history_requested: false,
493        })
494    }
495
496    /// Returns the total number of iterations performed by the IRS solver.
497    /// If this value is negative, the IRS solver did not converge. If fallback to full precision was enabled, the solver fell back to a full-precision solution.
498    /// See [`xgesv`] and [`xgels`] for the meaning of negative `niters` values.
499    ///
500    /// # Errors
501    ///
502    /// Returns an error if cuSOLVER rejects the `Infos` structure.
503    pub fn niters(&self) -> Result<i32> {
504        let mut value = 0;
505        unsafe {
506            try_ffi!(sys::cusolverDnIRSInfosGetNiters(
507                self.as_raw(),
508                &raw mut value,
509            ))?;
510        }
511        Ok(value)
512    }
513
514    /// Returns the number of iterations performed by the outer refinement loop of the IRS solver.
515    /// For one-level solvers such as [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`], this is the same as `Niters`.
516    /// For two-level solvers such as [`IrsRefinement::ClassicalGmres`] or [`IrsRefinement::GmresGmres`], this is the number of outer-loop iterations.
517    /// See [`IrsRefinement`] for refinement mode details.
518    ///
519    /// # Errors
520    ///
521    /// Returns an error if cuSOLVER rejects the `Infos` structure.
522    pub fn outer_niters(&self) -> Result<i32> {
523        let mut value = 0;
524        unsafe {
525            try_ffi!(sys::cusolverDnIRSInfosGetOuterNiters(
526                self.as_raw(),
527                &raw mut value,
528            ))?;
529        }
530        Ok(value)
531    }
532
533    /// Returns the maximum number of iterations allowed for the corresponding IRS solver call.
534    /// Setting used when that call happened, distinct from [`IrsParams::max_iterations`], which returns the current setting in the `params` configuration structure.
535    /// The `params` structure can be reused for many IRS solver calls.
536    /// The allowed `MaxIters` value can change between calls, while this `Infos` structure contains information about one particular call and cannot be reused for different calls.
537    ///
538    /// # Errors
539    ///
540    /// Returns an error if cuSOLVER rejects the `Infos` structure.
541    pub fn max_iterations(&self) -> Result<i32> {
542        let mut value = 0;
543        unsafe {
544            try_ffi!(sys::cusolverDnIRSInfosGetMaxIters(
545                self.as_raw(),
546                &raw mut value,
547            ))?;
548        }
549        Ok(value)
550    }
551
552    /// Asks the IRS solver to store the convergence history
553    /// (residual norms) of the refinement phase so it can later be queried with
554    /// [`IrsInfos::residual_history_f32`] or [`IrsInfos::residual_history_f64`].
555    ///
556    /// # Errors
557    ///
558    /// Returns an error if cuSOLVER rejects the `Infos` structure.
559    pub fn request_residual_history(&mut self) -> Result<()> {
560        unsafe {
561            try_ffi!(sys::cusolverDnIRSInfosRequestResidual(self.as_raw()))?;
562        }
563        self.residual_history_requested = true;
564        Ok(())
565    }
566
567    /// Returns the convergence history stored by the IRS solver when [`IrsInfos::request_residual_history`] was called before solving.
568    /// The residual norm type depends on the input and output precision.
569    /// Double-precision real and complex configurations report `f64` residuals, while single-precision real and complex configurations report `f32` residuals.
570    ///
571    /// The residual history matrix has two columns, even for multiple right-hand sides, and `MaxIters + 1` rows.
572    /// Only the first `OuterNiters + 1` rows contain residual norms; the remaining rows are undefined.
573    /// In the first column, each row `i` contains the total number of iterations performed up to outer iteration `i`.
574    /// In the second column, each row contains the residual norm for that outer iteration.
575    /// Row 0 contains the initial residual before the refinement loop starts, and subsequent rows contain residuals obtained at each outer iteration.
576    /// The history only covers the outer loop.
577    ///
578    /// If the refinement solver was [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`], then `OuterNiters == Niters`, and there are `Niters + 1` rows of norms corresponding to the `Niters` outer iterations.
579    ///
580    /// If the refinement solver was [`IrsRefinement::ClassicalGmres`] or [`IrsRefinement::GmresGmres`], then `OuterNiters <= Niters` corresponds to the outer iterations performed by the outer refinement loop.
581    /// There are `OuterNiters + 1` residual norms. Row `i` corresponds to outer iteration `i`; the first column contains the total number of outer and inner iterations performed up to that step, and the second column contains the residual norm at that step.
582    ///
583    /// For example, if [`IrsRefinement::ClassicalGmres`] needs 3 outer iterations to converge and 4, 3, and 3 inner iterations at each outer iteration, it performs 10 total iterations.
584    /// Row 0 corresponds to the first residual before the refinement start, so it has 0 in its first column.
585    /// Row 1 corresponds to outer iteration 1 and contains 4 in its first column, row 2 contains 7, and row 3 contains 10.
586    ///
587    /// In summary, let `ldh = MaxIters + 1`, the leading dimension of the residual matrix. Then `residual_history[i]` contains the total number of iterations performed at outer iteration `i`, and `residual_history[i + ldh]` contains the residual norm at that outer iteration.
588    ///
589    /// # Errors
590    ///
591    /// Returns an error if residual history was not requested before solving,
592    /// or if cuSOLVER rejects the `Infos` structure.
593    pub fn residual_history_f32(&self) -> Result<ResidualHistory<f32>> {
594        if !self.residual_history_requested {
595            return Err(Error::InvalidPrecisionConfiguration);
596        }
597        let (leading_dimension, valid_rows) = self.residual_history_layout()?;
598        let mut history = ptr::null_mut();
599        unsafe {
600            try_ffi!(sys::cusolverDnIRSInfosGetResidualHistory(
601                self.as_raw(),
602                &raw mut history,
603            ))?;
604            Ok(copy_residual_history(
605                history.cast::<f32>(),
606                leading_dimension,
607                valid_rows,
608            ))
609        }
610    }
611
612    /// Returns the convergence history stored by the IRS solver when [`IrsInfos::request_residual_history`] was called before solving.
613    /// The residual norm type depends on the input and output precision.
614    /// Double-precision real and complex configurations report `f64` residuals, while single-precision real and complex configurations report `f32` residuals.
615    ///
616    /// The residual history matrix has two columns, even for multiple right-hand sides, and `MaxIters + 1` rows.
617    /// Only the first `OuterNiters + 1` rows contain residual norms; the remaining rows are undefined.
618    /// In the first column, each row `i` contains the total number of iterations performed up to outer iteration `i`.
619    /// In the second column, each row contains the residual norm for that outer iteration.
620    /// Row 0 contains the initial residual before the refinement loop starts, and subsequent rows contain residuals obtained at each outer iteration.
621    /// The history only covers the outer loop.
622    ///
623    /// If the refinement solver was [`IrsRefinement::Classical`] or [`IrsRefinement::Gmres`], then `OuterNiters == Niters`, and there are `Niters + 1` rows of norms corresponding to the `Niters` outer iterations.
624    ///
625    /// If the refinement solver was [`IrsRefinement::ClassicalGmres`] or [`IrsRefinement::GmresGmres`], then `OuterNiters <= Niters` corresponds to the outer iterations performed by the outer refinement loop.
626    /// There are `OuterNiters + 1` residual norms. Row `i` corresponds to outer iteration `i`; the first column contains the total number of outer and inner iterations performed up to that step, and the second column contains the residual norm at that step.
627    ///
628    /// For example, if [`IrsRefinement::ClassicalGmres`] needs 3 outer iterations to converge and 4, 3, and 3 inner iterations at each outer iteration, it performs 10 total iterations.
629    /// Row 0 corresponds to the first residual before the refinement start, so it has 0 in its first column.
630    /// Row 1 corresponds to outer iteration 1 and contains 4 in its first column, row 2 contains 7, and row 3 contains 10.
631    ///
632    /// In summary, let `ldh = MaxIters + 1`, the leading dimension of the residual matrix. Then `residual_history[i]` contains the total number of iterations performed at outer iteration `i`, and `residual_history[i + ldh]` contains the residual norm at that outer iteration.
633    ///
634    /// # Errors
635    ///
636    /// Returns an error if residual history was not requested before solving,
637    /// or if cuSOLVER rejects the `Infos` structure.
638    pub fn residual_history_f64(&self) -> Result<ResidualHistory<f64>> {
639        if !self.residual_history_requested {
640            return Err(Error::InvalidPrecisionConfiguration);
641        }
642        let (leading_dimension, valid_rows) = self.residual_history_layout()?;
643        let mut history = ptr::null_mut();
644        unsafe {
645            try_ffi!(sys::cusolverDnIRSInfosGetResidualHistory(
646                self.as_raw(),
647                &raw mut history,
648            ))?;
649            Ok(copy_residual_history(
650                history.cast::<f64>(),
651                leading_dimension,
652                valid_rows,
653            ))
654        }
655    }
656
657    pub fn as_raw(&self) -> sys::cusolverDnIRSInfos_t {
658        self.handle
659    }
660
661    fn residual_history_layout(&self) -> Result<(usize, usize)> {
662        let leading_dimension = self
663            .max_iterations()?
664            .checked_add(1)
665            .ok_or(Error::InvalidResidualHistory)
666            .and_then(|value| {
667                usize::try_from(value).map_err(|_| Error::OutOfRange {
668                    name: "residual history leading dimension".into(),
669                })
670            })?;
671        let valid_rows = self
672            .outer_niters()?
673            .checked_add(1)
674            .ok_or(Error::InvalidResidualHistory)
675            .and_then(|value| {
676                usize::try_from(value).map_err(|_| Error::OutOfRange {
677                    name: "residual history rows".into(),
678                })
679            })?;
680
681        if valid_rows > leading_dimension {
682            return Err(Error::InvalidResidualHistory);
683        }
684
685        Ok((leading_dimension, valid_rows))
686    }
687}
688
689impl Drop for IrsInfos {
690    fn drop(&mut self) {
691        unsafe {
692            if let Err(err) = try_ffi!(sys::cusolverDnIRSInfosDestroy(self.handle)) {
693                #[cfg(debug_assertions)]
694                eprintln!("failed to destroy cusolver irs infos: {err}");
695            }
696        }
697    }
698}
699
700pub fn xgesv_buffer_size<T: DataTypeLike>(
701    ctx: &Context,
702    params: &mut IrsParams,
703    n: usize,
704    nrhs: usize,
705) -> Result<usize> {
706    ctx.bind()?;
707    if n == 0 || nrhs == 0 {
708        return Err(Error::InvalidMatrixShape);
709    }
710    params.ensure_type_precision::<T>()?;
711    let mut workspace_bytes = 0;
712    unsafe {
713        try_ffi!(sys::cusolverDnIRSXgesv_bufferSize(
714            ctx.as_raw(),
715            params.as_raw(),
716            to_i32(n, "n")?,
717            to_i32(nrhs, "nrhs")?,
718            &raw mut workspace_bytes,
719        ))?;
720    }
721    Ok(workspace_bytes as usize)
722}
723
724/// Provides the same solve as the typed cuSOLVER `gesv` entry
725/// points, but through a generic Rust wrapper that exposes IRS configuration
726/// and reporting more directly.
727/// [`xgesv`] allows additional control of the solver parameters such as setting:
728///
729/// * the main precision (input/output precision) of the solver
730/// * the lowest precision to be used internally by the solver
731/// * the refinement solver type
732/// * the maximum allowed number of iterations in the refinement phase
733/// * the tolerance of the refinement solver
734/// * the fallback to main precision
735/// * and more
736///
737/// through [`IrsParams`] and its helper methods.
738/// Moreover, [`xgesv`] provides additional output information such as the convergence history (for example, residual norms) at each iteration and the number of iterations needed to converge.
739/// [`IrsInfos`] exposes the information reported for a particular solve.
740///
741/// The returned value describes the solving process.
742/// `Ok` indicates that the solve finished successfully. An error indicates that one of the arguments is incorrect, that the parameter or info structures are misconfigured, or that the solve did not finish successfully.
743/// Check `niters` and `dinfo` for additional error details.
744/// Provide the required device workspace through `workspace`.
745/// Query the required byte count with [`xgesv_buffer_size`].
746/// Apply any required configuration through the parameter structure before calling [`xgesv_buffer_size`] so the workspace size matches that configuration.
747///
748/// Tensor Float (TF32), introduced with NVIDIA Ampere architecture GPUs, is the most robust tensor core accelerated compute mode for the iterative refinement solver.
749/// It solves a broad range of HPC problems and can provide up to 4x and 5x
750/// speedups for real and complex systems, respectively.
751/// On Volta and Turing architecture GPUs, half precision tensor core acceleration is recommended.
752/// In cases where the iterative refinement solver fails to converge to the desired accuracy (main precision, input/output data precision), it is recommended to use main precision as internal lowest precision.
753///
754/// The following table provides all possible lowest-precision values corresponding to the input/output data type.
755/// If the lowest precision matches the input/output data type, the main
756/// precision factorization is used.
757///
758/// **Supported input/output data type and lower precision for the IRS solver**
759///
760/// | **input/output Data Type (for example, main precision)** | **Supported values for the lowest precision** |
761/// | --- | --- |
762/// | [`PrecisionType::C64F`] | [`PrecisionType::C64F`], [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
763/// | [`PrecisionType::C32F`] | [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
764/// | [`PrecisionType::R64F`] | [`PrecisionType::R64F`], [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
765/// | [`PrecisionType::R32F`] | [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
766///
767/// [`xgesv_buffer_size`] returns the required workspace size in bytes for the
768/// current [`IrsParams`] configuration.
769///
770/// # Errors
771///
772/// Returns an error if cuSOLVER rejects the matrix dimensions, leading
773/// dimensions, parameter structure, info structure, or workspace. The workspace
774/// can become invalid if [`xgesv_buffer_size`] is called and then an IRS
775/// configuration value, such as the lowest precision, is changed. cuSOLVER can
776/// also report an error if host memory allocation fails, if the selected IRS
777/// configuration is not supported on the current GPU architecture, if the
778/// library has not been initialized, or if the solve ends with an internal or
779/// numerical failure. Check `niters` and `dinfo` for additional solver details.
780pub fn xgesv<T: DataTypeLike>(
781    ctx: &Context,
782    params: &mut IrsParams,
783    infos: &IrsInfos,
784    n: usize,
785    nrhs: usize,
786    a: MatrixMut<'_, T>,
787    b: MatrixRef<'_, T>,
788    x: MatrixMut<'_, T>,
789    device_workspace: &mut DeviceMemory<u8>,
790    dev_info: &mut DeviceMemory<i32>,
791) -> Result<i32> {
792    ctx.bind()?;
793    validate_matrix(n, n, a.data.len(), a.leading_dimension)?;
794    validate_matrix(n, nrhs, b.data.len(), b.leading_dimension)?;
795    validate_matrix(n, nrhs, x.data.len(), x.leading_dimension)?;
796    require_info_buffer(dev_info)?;
797    let workspace_bytes = xgesv_buffer_size::<T>(ctx, params, n, nrhs)?;
798    require_workspace_bytes(device_workspace.byte_len(), workspace_bytes)?;
799    let mut niters = 0;
800    unsafe {
801        try_ffi!(sys::cusolverDnIRSXgesv(
802            ctx.as_raw(),
803            params.as_raw(),
804            infos.as_raw(),
805            to_i32(n, "n")?,
806            to_i32(nrhs, "nrhs")?,
807            a.data.as_mut_ptr() as _,
808            to_i32(a.leading_dimension, "ldda")?,
809            b.data.as_ptr() as _,
810            to_i32(b.leading_dimension, "lddb")?,
811            x.data.as_mut_ptr() as _,
812            to_i32(x.leading_dimension, "lddx")?,
813            device_workspace.as_mut_ptr() as _,
814            to_u64(workspace_bytes, "lwork_bytes")?,
815            &raw mut niters,
816            dev_info.as_mut_ptr() as _,
817        ))?;
818    }
819    Ok(niters)
820}
821
822pub fn xgels_buffer_size<T: DataTypeLike>(
823    ctx: &Context,
824    params: &mut IrsParams,
825    m: usize,
826    n: usize,
827    nrhs: usize,
828) -> Result<usize> {
829    ctx.bind()?;
830    if m == 0 || n == 0 || nrhs == 0 || n > m {
831        return Err(Error::InvalidMatrixShape);
832    }
833    params.ensure_type_precision::<T>()?;
834    let mut workspace_bytes = 0;
835    unsafe {
836        try_ffi!(sys::cusolverDnIRSXgels_bufferSize(
837            ctx.as_raw(),
838            params.as_raw(),
839            to_i32(m, "m")?,
840            to_i32(n, "n")?,
841            to_i32(nrhs, "nrhs")?,
842            &raw mut workspace_bytes,
843        ))?;
844    }
845    Ok(workspace_bytes as usize)
846}
847
848/// Provides the same solve as the typed cuSOLVER `gels` entry
849/// points, but through a generic Rust wrapper that exposes IRS configuration
850/// and reporting more directly.
851/// [`xgels`] allows additional control of the solver parameters such as setting:
852///
853/// * the main precision (input/output precision) of the solver,
854/// * the lowest precision to be used internally by the solver,
855/// * the refinement solver type
856/// * the maximum allowed number of iterations in the refinement phase
857/// * the tolerance of the refinement solver
858/// * the fallback to main precision
859/// * and others
860///
861/// through [`IrsParams`] and its helper methods.
862/// Moreover, [`xgels`] provides additional output information such as the convergence history (for example, residual norms) at each iteration and the number of iterations needed to converge.
863/// [`IrsInfos`] exposes the information reported for a particular solve.
864///
865/// The returned value describes the solving process.
866/// `Ok` indicates that the solve finished successfully. An error indicates that one of the arguments is incorrect, that the parameter or info structures are misconfigured, or that the solve did not finish successfully.
867/// Check `niters` and `dinfo` for additional error details.
868/// Provide the required device workspace through `workspace`.
869/// Query the required byte count with [`xgels_buffer_size`].
870/// Apply any required configuration through the parameter structure before calling [`xgels_buffer_size`] so the workspace size matches that configuration.
871///
872/// The following table provides all possible lowest-precision values corresponding to the input/output data type.
873/// If the lowest precision matches the input/output data type, the main
874/// precision factorization is used.
875///
876/// Tensor Float (TF32), introduced with NVIDIA Ampere architecture GPUs, is the most robust tensor core accelerated compute mode for the iterative refinement solver.
877/// It solves a broad range of HPC problems and can provide up to 4x and 5x
878/// speedups for real and complex systems, respectively.
879/// On Volta and Turing architecture GPUs, half precision tensor core acceleration is recommended.
880/// In cases where the iterative refinement solver fails to converge to the desired accuracy (main precision, input/output data precision), it is recommended to use main precision as internal lowest precision.
881///
882/// **Supported input/output data type and lower precision for the IRS solver**
883///
884/// | **input/output Data Type (for example, main precision)** | **Supported values for the lowest precision** |
885/// | --- | --- |
886/// | [`PrecisionType::C64F`] | [`PrecisionType::C64F`], [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
887/// | [`PrecisionType::C32F`] | [`PrecisionType::C32F`], [`PrecisionType::C16F`], [`PrecisionType::C16Bf`], [`PrecisionType::CTf32`] |
888/// | [`PrecisionType::R64F`] | [`PrecisionType::R64F`], [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
889/// | [`PrecisionType::R32F`] | [`PrecisionType::R32F`], [`PrecisionType::R16F`], [`PrecisionType::R16Bf`], [`PrecisionType::RTf32`] |
890///
891/// [`xgels_buffer_size`] returns the required workspace size in bytes for the
892/// current [`IrsParams`] configuration.
893///
894/// # Errors
895///
896/// Returns an error if cuSOLVER rejects the matrix dimensions, leading
897/// dimensions, parameter structure, info structure, or workspace. The workspace
898/// can become invalid if [`xgels_buffer_size`] is called and then an IRS
899/// configuration value, such as the lowest precision, is changed. cuSOLVER can
900/// also report an error if host memory allocation fails, if the selected IRS
901/// configuration is not supported on the current GPU architecture, if the
902/// library has not been initialized, or if the solve ends with an internal or
903/// numerical failure. Check `niters` and `dinfo` for additional solver details.
904pub fn xgels<T: DataTypeLike>(
905    ctx: &Context,
906    params: &mut IrsParams,
907    infos: &IrsInfos,
908    m: usize,
909    n: usize,
910    nrhs: usize,
911    a: MatrixMut<'_, T>,
912    b: MatrixRef<'_, T>,
913    x: MatrixMut<'_, T>,
914    device_workspace: &mut DeviceMemory<u8>,
915    dev_info: &mut DeviceMemory<i32>,
916) -> Result<i32> {
917    ctx.bind()?;
918    if n > m {
919        return Err(Error::InvalidMatrixShape);
920    }
921    validate_matrix(m, n, a.data.len(), a.leading_dimension)?;
922    validate_matrix(m, nrhs, b.data.len(), b.leading_dimension)?;
923    validate_matrix(n, nrhs, x.data.len(), x.leading_dimension)?;
924    require_info_buffer(dev_info)?;
925    let workspace_bytes = xgels_buffer_size::<T>(ctx, params, m, n, nrhs)?;
926    require_workspace_bytes(device_workspace.byte_len(), workspace_bytes)?;
927    let mut niters = 0;
928    unsafe {
929        try_ffi!(sys::cusolverDnIRSXgels(
930            ctx.as_raw(),
931            params.as_raw(),
932            infos.as_raw(),
933            to_i32(m, "m")?,
934            to_i32(n, "n")?,
935            to_i32(nrhs, "nrhs")?,
936            a.data.as_mut_ptr() as _,
937            to_i32(a.leading_dimension, "ldda")?,
938            b.data.as_ptr() as _,
939            to_i32(b.leading_dimension, "lddb")?,
940            x.data.as_mut_ptr() as _,
941            to_i32(x.leading_dimension, "lddx")?,
942            device_workspace.as_mut_ptr() as _,
943            to_u64(workspace_bytes, "lwork_bytes")?,
944            &raw mut niters,
945            dev_info.as_mut_ptr() as _,
946        ))?;
947    }
948    Ok(niters)
949}
950
951fn require_info_buffer(dev_info: &DeviceMemory<i32>) -> Result<()> {
952    if dev_info.is_empty() {
953        return Err(Error::InvalidVectorShape);
954    }
955    Ok(())
956}
957
958fn require_workspace_bytes(actual: usize, required: usize) -> Result<()> {
959    if actual < required {
960        return Err(Error::InsufficientWorkspaceSize { required, actual });
961    }
962    Ok(())
963}
964
965unsafe fn copy_residual_history<T: Copy>(
966    history: *const T,
967    leading_dimension: usize,
968    valid_rows: usize,
969) -> ResidualHistory<T> {
970    let history = unsafe { slice::from_raw_parts(history, leading_dimension.saturating_mul(2)) };
971    let mut rows = Vec::with_capacity(valid_rows);
972    for row in 0..valid_rows {
973        rows.push(ResidualHistoryEntry {
974            total_iterations: history[row],
975            residual_norm: history[row + leading_dimension],
976        });
977    }
978    ResidualHistory {
979        rows,
980        leading_dimension,
981    }
982}
983
984fn validate_matrix(rows: usize, cols: usize, len: usize, lda: usize) -> Result<()> {
985    if rows == 0 || cols == 0 {
986        return Err(Error::InvalidMatrixShape);
987    }
988    if lda < rows {
989        return Err(Error::InvalidLeadingDimension);
990    }
991    let required = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
992    if len < required {
993        return Err(Error::InvalidMatrixShape);
994    }
995    Ok(())
996}
997
#[cfg(all(test, feature = "testing"))]
mod tests {
    use singe_cuda::memory::DeviceMemory;

    use super::*;
    use crate::testing::setup_context_if_available;

    /// Solves `diag(2, 4) * x = [6, 8]` through [`xgesv`] and expects
    /// `x = [3, 2]` with a zero value in the device info buffer.
    #[test]
    fn test_xgesv_solves_diagonal_system() -> Result<()> {
        // Without a usable context the test returns early as a pass.
        let ctx = match setup_context_if_available()? {
            Some(ctx) => ctx,
            None => return Ok(()),
        };
        let mut params = IrsParams::create()?;
        let infos = IrsInfos::create()?;

        // The 2 x 2 coefficient matrix is diagonal, so its storage order
        // does not affect the values.
        let coefficients = [2.0_f32, 0.0, 0.0, 4.0];
        let right_hand_side = [6.0_f32, 8.0];

        let mut a = DeviceMemory::from_slice(&coefficients)?;
        let b = DeviceMemory::from_slice(&right_hand_side)?;
        let mut x = DeviceMemory::create(2)?;
        let workspace_bytes = xgesv_buffer_size::<f32>(&ctx, &mut params, 2, 1)?;
        // Allocate at least one byte even if the query reports zero.
        let mut workspace = DeviceMemory::create(workspace_bytes.max(1))?;
        let mut dev_info = DeviceMemory::create(1)?;

        // The iteration count is not asserted on; only success matters here.
        let _iterations = xgesv(
            &ctx,
            &mut params,
            &infos,
            2,
            1,
            MatrixMut::new(&mut a, 2),
            MatrixRef::new(&b, 2),
            MatrixMut::new(&mut x, 2),
            &mut workspace,
            &mut dev_info,
        )?;

        let status = dev_info.copy_to_host_vec()?;
        let solution = x.copy_to_host_vec()?;
        assert_eq!(status, vec![0]);
        assert_eq!(solution, vec![3.0, 2.0]);
        Ok(())
    }
}