// oxicuda_solver/dense/matrix_functions.rs

1//! Matrix functions: exponential, logarithm, and square root.
2//!
3//! Provides GPU-accelerated matrix function computations via PTX kernel
4//! generation:
5//!
6//! - **Matrix Exponential (`expm`)**: Scaling and squaring with Padé
7//!   approximation. Computes `e^A` for a square matrix A.
8//! - **Matrix Logarithm (`logm`)**: Inverse scaling and squaring with Padé
9//!   approximation. Computes `log(A)` for a matrix with eigenvalues in the
10//!   right half-plane.
11//! - **Matrix Square Root (`sqrtm`)**: Denman–Beavers iteration. Computes
12//!   `A^{1/2}` such that `sqrtm(A) * sqrtm(A) = A`.
13//!
14//! Each plan struct generates self-contained PTX kernels using
15//! [`KernelBuilder`]/[`BodyBuilder`] from the `oxicuda-ptx` crate.
16
17#![allow(dead_code)]
18
19use oxicuda_ptx::ir::PtxType;
20use oxicuda_ptx::prelude::*;
21
22use crate::error::{SolverError, SolverResult};
23
24// ---------------------------------------------------------------------------
25// Padé coefficients
26// ---------------------------------------------------------------------------
27
28/// Padé coefficients for the matrix exponential (numerator/denominator
29/// polynomial of `[p/p]` approximant to `e^x`).
30///
31/// Returns `(numerator_coeffs, denominator_coeffs)` for the given order.
32/// Orders 3, 5, 7, 9, 13 are supported, matching the Higham (2005) algorithm.
33fn pade_coefficients(order: u32) -> SolverResult<Vec<f64>> {
34    match order {
35        3 => Ok(vec![120.0, 60.0, 12.0, 1.0]),
36        5 => Ok(vec![30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0]),
37        7 => Ok(vec![
38            17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0,
39        ]),
40        9 => Ok(vec![
41            17643225600.0,
42            8821612800.0,
43            2075673600.0,
44            302702400.0,
45            30270240.0,
46            2162160.0,
47            110880.0,
48            3960.0,
49            90.0,
50            1.0,
51        ]),
52        13 => Ok(vec![
53            64764752532480000.0,
54            32382376266240000.0,
55            7771770303897600.0,
56            1187353796428800.0,
57            129060195264000.0,
58            10559470521600.0,
59            670442572800.0,
60            33522128640.0,
61            1323241920.0,
62            40840800.0,
63            960960.0,
64            16380.0,
65            182.0,
66            1.0,
67        ]),
68        _ => Err(SolverError::InternalError(format!(
69            "unsupported Padé order {order}; valid orders are 3, 5, 7, 9, 13"
70        ))),
71    }
72}
73
74/// Theta thresholds for each Padé order. If `||A|| <= theta[m]`, then the
75/// Padé approximant of that order achieves unit-roundoff accuracy.
76#[allow(clippy::excessive_precision)]
77fn pade_theta(order: u32) -> SolverResult<f64> {
78    match order {
79        3 => Ok(1.495_585_217_958_292e-2),
80        5 => Ok(2.539_398_330_063_230e-1),
81        7 => Ok(9.504_178_996_162_932e-1),
82        9 => Ok(2.097_847_961_257_068),
83        13 => Ok(5.371_920_351_148_152),
84        _ => Err(SolverError::InternalError(format!(
85            "no theta for Padé order {order}"
86        ))),
87    }
88}
89
90// =========================================================================
91// Matrix Exponential (expm)
92// =========================================================================
93
/// Configuration for matrix exponential computation.
///
/// Construct with [`MatrixExpConfig::new`]; it is validated when a
/// [`MatrixExpPlan`] is created from it.
#[derive(Debug, Clone)]
pub struct MatrixExpConfig {
    /// Matrix dimension (n × n). Must be > 0.
    pub n: u32,
    /// Precision: `"f32"` or `"f64"`.
    pub precision: String,
    /// Padé approximant order (3, 5, 7, 9, or 13). Default: 13.
    pub pade_order: u32,
}
104
105impl MatrixExpConfig {
106    /// Creates a new configuration with sensible defaults.
107    pub fn new(n: u32, precision: &str) -> Self {
108        Self {
109            n,
110            precision: precision.to_string(),
111            pade_order: 13,
112        }
113    }
114
115    /// Sets the Padé order.
116    pub fn with_pade_order(mut self, order: u32) -> Self {
117        self.pade_order = order;
118        self
119    }
120
121    /// Validates the configuration.
122    fn validate(&self) -> SolverResult<()> {
123        if self.n == 0 {
124            return Err(SolverError::DimensionMismatch(
125                "expm: matrix dimension must be > 0".into(),
126            ));
127        }
128        if self.precision != "f32" && self.precision != "f64" {
129            return Err(SolverError::InternalError(format!(
130                "expm: unsupported precision '{}'; use 'f32' or 'f64'",
131                self.precision
132            )));
133        }
134        // Validate Padé order.
135        pade_coefficients(self.pade_order)?;
136        Ok(())
137    }
138}
139
/// Execution plan for the matrix exponential.
///
/// The plan pre-computes the Padé coefficients and kernel names, then
/// generates PTX on demand.
#[derive(Debug, Clone)]
pub struct MatrixExpPlan {
    // Validated configuration (dimension, precision, Padé order).
    config: MatrixExpConfig,
    // Coefficients for the configured order, ascending degree.
    pade_coeffs: Vec<f64>,
    // Norm threshold at which this order reaches unit-roundoff accuracy.
    theta: f64,
}
150
impl MatrixExpPlan {
    /// Creates a plan from a validated configuration.
    ///
    /// # Errors
    /// Propagates the configuration's validation error, or the coefficient /
    /// theta lookup failure for an unsupported Padé order.
    pub fn new(config: MatrixExpConfig) -> SolverResult<Self> {
        config.validate()?;
        let pade_coeffs = pade_coefficients(config.pade_order)?;
        let theta = pade_theta(config.pade_order)?;
        Ok(Self {
            config,
            pade_coeffs,
            theta,
        })
    }

    /// Returns the Padé coefficients used by this plan (ascending degree).
    pub fn pade_coefficients(&self) -> &[f64] {
        &self.pade_coeffs
    }

    /// Returns the theta threshold for the configured Padé order.
    pub fn theta(&self) -> f64 {
        self.theta
    }

    /// Generates PTX source for the matrix exponential kernels.
    ///
    /// The generated code contains multiple entry points:
    /// 1. **scale kernel** — computes `A_scaled = A / 2^s` where `s` is chosen
    ///    so that `||A_scaled|| <= theta`.
    /// 2. **Padé numerator/denominator kernels** — evaluates the Padé
    ///    polynomial pair `P(A)` and `Q(A)` using Horner's method.
    /// 3. **squaring kernel** — repeated squaring `F = F^{2^s}` via matrix
    ///    multiply.
    ///
    /// The kernels are emitted as separate PTX modules joined by newlines;
    /// the host drives them in the order above.
    pub fn generate_ptx(&self) -> SolverResult<String> {
        let n = self.config.n;
        let float_ty = precision_to_ptx_type(&self.config.precision)?;
        let sm = SmVersion::Sm75;

        let mut all_ptx = Vec::new();

        // Kernel 1: Scale matrix by 2^(-s).
        let scale_ptx = self.emit_scale_kernel(n, float_ty, sm)?;
        all_ptx.push(scale_ptx);

        // Kernel 2: Padé polynomial evaluation (Horner's method for P(A) and Q(A)).
        let pade_ptx = self.emit_pade_kernel(n, float_ty, sm)?;
        all_ptx.push(pade_ptx);

        // Kernel 3: Repeated squaring.
        let square_ptx = self.emit_squaring_kernel(n, float_ty, sm)?;
        all_ptx.push(square_ptx);

        Ok(all_ptx.join("\n"))
    }

    /// Emits PTX for `A_scaled[i,j] = A[i,j] / 2^s`.
    ///
    /// The divisor `2^s` is built bit-exactly by writing the biased exponent
    /// into an IEEE 754 bit pattern, so the scaling is a pure exponent shift
    /// for any in-range `s`.
    fn emit_scale_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let name = format!("solver_expm_scale_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("out_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("scale_exp", PtxType::U32)
            .body(move |b| {
                // Each thread handles one matrix element.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let out_ptr = b.load_param_u64("out_ptr");
                    let scale_exp = b.load_param_u32("scale_exp");

                    // Load element.
                    // NOTE(review): global_thread_id_x() is queried again
                    // inside the guarded body; presumably the builder yields
                    // the same value/register both times — confirm.
                    let gid_repeat = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    let addr = b.byte_offset_addr(a_ptr, gid_repeat.clone(), elem_size);
                    let val = load_float(b, float_ty, addr);

                    // Compute divisor = 2^scale_exp using IEEE 754 biased exponent.
                    // For f64: bits = (scale_exp + 1023) << 52.
                    // For f32: bits = (scale_exp + 127) << 23.
                    // Assumes scale_exp keeps the biased exponent within the
                    // field (< 1024 for f64, < 128 for f32) — the host picks
                    // `s`, so it must guarantee this.
                    let out_addr = b.byte_offset_addr(out_ptr, gid_repeat, elem_size);

                    let result = if float_ty == PtxType::F64 {
                        // Widen scale_exp to 64-bit.
                        let se64 = b.cvt_u32_to_u64(scale_exp);
                        // Add IEEE 754 f64 exponent bias (1023).
                        let biased = b.alloc_reg(PtxType::U64);
                        b.raw_ptx(&format!("add.u64 {biased}, {se64}, 1023;"));
                        // Shift left 52 to place in exponent field.
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 52;"));
                        let bits = b.shl_b64(biased, shift_amt);
                        // Reinterpret bits as f64.
                        let divisor = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mov.b64 {divisor}, {bits};"));
                        // result = val / 2^scale_exp.
                        let res = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("div.rn.f64 {res}, {val}, {divisor};"));
                        res
                    } else {
                        // For f32: bias is 127, exponent field starts at bit 23.
                        let biased = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("add.u32 {biased}, {scale_exp}, 127;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 23;"));
                        let bits = b.shl_b32(biased, shift_amt);
                        // Reinterpret bits as f32.
                        let divisor = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mov.b32 {divisor}, {bits};"));
                        // result = val / 2^scale_exp.
                        let res = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("div.rn.f32 {res}, {val}, {divisor};"));
                        res
                    };

                    store_float(b, float_ty, out_addr, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits PTX for Padé polynomial evaluation using Horner's method.
    ///
    /// Evaluates:
    ///   P(A) = c_0 * I + c_1 * A + c_2 * A^2 + ... + c_p * A^p
    ///   Q(A) = c_0 * I - c_1 * A + c_2 * A^2 - ... ± c_p * A^p
    ///
    /// where the even coefficients are the same and odd coefficients differ
    /// only in sign.
    ///
    /// NOTE(review): the kernel applies the scalar Horner recurrence to each
    /// element independently, which equals the matrix polynomial only for
    /// diagonal data; the in-code comment says the host orchestrates GEMMs
    /// for large matrices — confirm the host driver's exact usage.
    fn emit_pade_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let order = self.config.pade_order;
        let name = format!(
            "solver_expm_pade_{}_n{}_p{}",
            ptx_type_suffix(float_ty),
            n,
            order
        );

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("p_ptr", PtxType::U64)
            .param("q_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("coeffs_ptr", PtxType::U64)
            .param("num_coeffs", PtxType::U32)
            .body(move |b| {
                // Each thread computes one element of P(A) and Q(A).
                // For small matrices, this is feasible; for large matrices,
                // the host code orchestrates multiple GEMM calls.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let p_ptr = b.load_param_u64("p_ptr");
                    let q_ptr = b.load_param_u64("q_ptr");
                    let coeffs_ptr = b.load_param_u64("coeffs_ptr");
                    let num_coeffs = b.load_param_u32("num_coeffs");

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    // Coefficients are always stored as f64 regardless of precision.
                    const COEFF_SIZE: u32 = 8u32;
                    let gid_r = b.global_thread_id_x();

                    // Load A[gid] — the input matrix element.
                    let a_addr = b.byte_offset_addr(a_ptr, gid_r.clone(), elem_size);
                    let a_val = load_float(b, float_ty, a_addr);

                    // Horner's method: traverse coefficients from highest to lowest.
                    // P(x) = c[m] + x*(c[m-2] + x*(...))  (even-indexed terms)
                    // Q(x) = c[m] + x*(-c[m-2] + x*(...))  (odd terms negate for Q)
                    // Simplified scalar Horner: P = c[m], Q = c[m], then for each
                    // lower degree: P = P*x + c[k], Q = Q*x + sign(k)*c[k].
                    // We load each coefficient from coeffs_ptr (f64 array).
                    //
                    // Loop: idx from (num_coeffs-1) down to 0.
                    // acc_p = 0, acc_q = 0.
                    // For k = num_coeffs-1 downto 0:
                    //   load c_k from coeffs_ptr[k] (f64), convert to float_ty
                    //   acc_p = fma(acc_p, a_val, c_k)
                    //   sign = (-1)^k → for Q: if k is odd, negate c_k
                    //   acc_q = fma(acc_q, a_val, ±c_k)
                    //
                    // Starting the accumulators at 0 folds the "P = c[m]" base
                    // case into the first fma (0 * x + c[m] = c[m]).

                    let acc_p = zero_const(b, float_ty);
                    let acc_q = zero_const(b, float_ty);

                    // idx_reg counts down from (num_coeffs) to 0.
                    let idx_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mov.u32 {idx_reg}, {num_coeffs};"));

                    let horner_loop = b.fresh_label("horner_loop");
                    let horner_exit = b.fresh_label("horner_exit");

                    b.raw_ptx(&format!("{horner_loop}:"));
                    // Check idx_reg == 0; if so, exit.
                    let done_pred = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {done_pred}, {idx_reg}, 0;"));
                    b.raw_ptx(&format!("@{done_pred} bra {horner_exit};"));

                    // Pre-decrement: after this, idx_reg holds the index of
                    // the coefficient processed on this iteration.
                    b.raw_ptx(&format!("sub.u32 {idx_reg}, {idx_reg}, 1;"));

                    // Load f64 coefficient from coeffs_ptr[idx_reg].
                    let coeff_addr =
                        b.byte_offset_addr(coeffs_ptr.clone(), idx_reg.clone(), COEFF_SIZE);
                    let coeff_f64 = load_float(b, PtxType::F64, coeff_addr);

                    // Convert coefficient to working precision.
                    let c_k = if float_ty == PtxType::F64 {
                        coeff_f64.clone()
                    } else {
                        // cvt.rn.f32.f64
                        let dst = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("cvt.rn.f32.f64 {dst}, {coeff_f64};"));
                        dst
                    };

                    // Horner step for P: acc_p = acc_p * a_val + c_k.
                    let new_acc_p = if float_ty == PtxType::F64 {
                        b.fma_f64(acc_p.clone(), a_val.clone(), c_k.clone())
                    } else {
                        b.fma_f32(acc_p.clone(), a_val.clone(), c_k.clone())
                    };
                    b.raw_ptx(&format!(
                        "mov{} {acc_p}, {new_acc_p};",
                        float_ty.as_ptx_str()
                    ));

                    // For Q: negate c_k when idx_reg is odd (the current index after decrement).
                    // idx_reg is the coefficient index; odd index → negate for Q.
                    let odd_pred = b.alloc_reg(PtxType::Pred);
                    let lsb = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("and.b32 {lsb}, {idx_reg}, 1;"));
                    b.raw_ptx(&format!("setp.ne.u32 {odd_pred}, {lsb}, 0;"));

                    // neg_c_k = -c_k.
                    let neg_c_k = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!("neg{} {neg_c_k}, {c_k};", float_ty.as_ptx_str()));
                    // q_coeff = odd ? neg_c_k : c_k (branchless select).
                    let q_coeff = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {q_coeff}, {neg_c_k}, {c_k}, {odd_pred};",
                        float_ty.as_ptx_str()
                    ));

                    // Horner step for Q: acc_q = acc_q * a_val + q_coeff.
                    let new_acc_q = if float_ty == PtxType::F64 {
                        b.fma_f64(acc_q.clone(), a_val.clone(), q_coeff)
                    } else {
                        b.fma_f32(acc_q.clone(), a_val.clone(), q_coeff)
                    };
                    b.raw_ptx(&format!(
                        "mov{} {acc_q}, {new_acc_q};",
                        float_ty.as_ptx_str()
                    ));

                    b.raw_ptx(&format!("bra {horner_loop};"));
                    b.raw_ptx(&format!("{horner_exit}:"));

                    // Store results.
                    let p_addr = b.byte_offset_addr(p_ptr, gid_r.clone(), elem_size);
                    let q_addr = b.byte_offset_addr(q_ptr, gid_r, elem_size);
                    store_float(b, float_ty, p_addr, acc_p);
                    store_float(b, float_ty, q_addr, acc_q);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits PTX for the repeated squaring step: `F = F * F` applied `s` times.
    ///
    /// Writes the product into `tmp_ptr`; the host is expected to swap or
    /// copy buffers between the `s` launches (the kernel itself performs a
    /// single multiply).
    fn emit_squaring_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_expm_square_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("f_ptr", PtxType::U64)
            .param("tmp_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                // Each thread computes one element of the product F * F.
                // Row = gid / n, Col = gid % n.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let f_ptr = b.load_param_u64("f_ptr");
                    let tmp_ptr = b.load_param_u64("tmp_ptr");
                    let n_inner = b.load_param_u32("n");

                    // Decode column-major index: gid = col * n + row.
                    let gid_r = b.global_thread_id_x();
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Accumulate dot product: tmp[col*n+row] = sum_{k=0}^{n-1} F[k*n+row] * F[col*n+k].
                    // Column-major layout: element (r, c) is at index c*n + r.
                    let acc = zero_const(b, float_ty);
                    let k_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mov.u32 {k_reg}, 0;"));

                    let loop_label = b.fresh_label("sq_loop");
                    let exit_label = b.fresh_label("sq_exit");

                    b.raw_ptx(&format!("{loop_label}:"));
                    // Check k < n; if not, exit loop.
                    let pred = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.ge.u32 {pred}, {k_reg}, {n_inner};"));
                    b.raw_ptx(&format!("@{pred} bra {exit_label};"));

                    // Load F[k*n + row] — k-th column, row-th element (column-major).
                    let a_idx_base = b.mul_lo_u32(k_reg.clone(), n_inner.clone());
                    let a_idx = b.add_u32(a_idx_base, row.clone());
                    let a_addr = b.byte_offset_addr(f_ptr.clone(), a_idx, elem_size);
                    let a_val = load_float(b, float_ty, a_addr);

                    // Load F[col*n + k] — col-th column, k-th element (column-major).
                    let b_idx_base = b.mul_lo_u32(col.clone(), n_inner.clone());
                    let b_idx = b.add_u32(b_idx_base, k_reg.clone());
                    let b_addr = b.byte_offset_addr(f_ptr.clone(), b_idx, elem_size);
                    let b_val = load_float(b, float_ty, b_addr);

                    // acc = fma(a_val, b_val, acc).
                    let new_acc = if float_ty == PtxType::F64 {
                        b.fma_f64(a_val, b_val, acc.clone())
                    } else {
                        b.fma_f32(a_val, b_val, acc.clone())
                    };
                    // Move new_acc into acc register via raw PTX.
                    b.raw_ptx(&format!("mov{} {acc}, {new_acc};", float_ty.as_ptx_str()));

                    // k += 1.
                    b.raw_ptx(&format!("add.u32 {k_reg}, {k_reg}, 1;"));
                    b.raw_ptx(&format!("bra {loop_label};"));

                    b.raw_ptx(&format!("{exit_label}:"));

                    // Store accumulated result to tmp[col*n + row].
                    let out_idx_base = b.mul_lo_u32(col, n_inner);
                    let out_idx = b.add_u32(out_idx_base, row);
                    let out_addr = b.byte_offset_addr(tmp_ptr, out_idx, elem_size);
                    store_float(b, float_ty, out_addr, acc);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }
}
528
529// =========================================================================
530// Matrix Logarithm (logm)
531// =========================================================================
532
/// Configuration for matrix logarithm computation.
///
/// Construct with [`MatrixLogConfig::new`]; it is validated when a
/// [`MatrixLogPlan`] is created from it.
#[derive(Debug, Clone)]
pub struct MatrixLogConfig {
    /// Matrix dimension (n × n). Must be > 0.
    pub n: u32,
    /// Precision: `"f32"` or `"f64"`.
    pub precision: String,
    /// Maximum number of square-root iterations. Default: 100. Must be > 0.
    pub max_sqrt_iters: u32,
}
543
544impl MatrixLogConfig {
545    /// Creates a new configuration with sensible defaults.
546    pub fn new(n: u32, precision: &str) -> Self {
547        Self {
548            n,
549            precision: precision.to_string(),
550            max_sqrt_iters: 100,
551        }
552    }
553
554    /// Sets the maximum number of square-root iterations.
555    pub fn with_max_sqrt_iters(mut self, iters: u32) -> Self {
556        self.max_sqrt_iters = iters;
557        self
558    }
559
560    /// Validates the configuration.
561    fn validate(&self) -> SolverResult<()> {
562        if self.n == 0 {
563            return Err(SolverError::DimensionMismatch(
564                "logm: matrix dimension must be > 0".into(),
565            ));
566        }
567        if self.precision != "f32" && self.precision != "f64" {
568            return Err(SolverError::InternalError(format!(
569                "logm: unsupported precision '{}'; use 'f32' or 'f64'",
570                self.precision
571            )));
572        }
573        if self.max_sqrt_iters == 0 {
574            return Err(SolverError::InternalError(
575                "logm: max_sqrt_iters must be > 0".into(),
576            ));
577        }
578        Ok(())
579    }
580}
581
/// Execution plan for the matrix logarithm.
///
/// Uses inverse scaling and squaring:
/// 1. Reduce `A` via repeated matrix square roots until `||A - I||` is small.
/// 2. Apply Padé approximation of `log(I + X)` for the reduced matrix.
/// 3. Scale the result back by `2^s`.
#[derive(Debug, Clone)]
pub struct MatrixLogPlan {
    // Validated configuration (dimension, precision, iteration cap).
    config: MatrixLogConfig,
}
592
593impl MatrixLogPlan {
594    /// Creates a plan from a validated configuration.
595    pub fn new(config: MatrixLogConfig) -> SolverResult<Self> {
596        config.validate()?;
597        Ok(Self { config })
598    }
599
600    /// Returns the maximum allowed square-root iterations.
601    pub fn max_sqrt_iters(&self) -> u32 {
602        self.config.max_sqrt_iters
603    }
604
605    /// Generates PTX source for the matrix logarithm kernels.
606    ///
607    /// The generated code contains entry points for:
608    /// 1. **shift kernel** — computes `X = A - I`.
609    /// 2. **square-root iteration kernel** — Denman–Beavers step applied to A
610    ///    until `||A - I||` is below threshold.
611    /// 3. **Padé log kernel** — evaluates the `[m/m]` Padé approximant to
612    ///    `log(I + X)` for small `X`.
613    /// 4. **scale-back kernel** — multiplies the result by `2^s`.
614    pub fn generate_ptx(&self) -> SolverResult<String> {
615        let n = self.config.n;
616        let float_ty = precision_to_ptx_type(&self.config.precision)?;
617        let sm = SmVersion::Sm75;
618
619        let mut all_ptx = Vec::new();
620
621        // Kernel 1: A_shifted = A - I.
622        let shift_ptx = self.emit_shift_kernel(n, float_ty, sm)?;
623        all_ptx.push(shift_ptx);
624
625        // Kernel 2: Matrix square root step (for reducing A close to I).
626        let sqrt_step_ptx = self.emit_sqrt_step_kernel(n, float_ty, sm)?;
627        all_ptx.push(sqrt_step_ptx);
628
629        // Kernel 3: Padé approximation of log(I + X).
630        let pade_log_ptx = self.emit_pade_log_kernel(n, float_ty, sm)?;
631        all_ptx.push(pade_log_ptx);
632
633        // Kernel 4: Scale back by 2^s.
634        let scale_ptx = self.emit_scale_back_kernel(n, float_ty, sm)?;
635        all_ptx.push(scale_ptx);
636
637        Ok(all_ptx.join("\n"))
638    }
639
    /// Emits PTX for `X[i,j] = A[i,j] - delta(i,j)` where delta is Kronecker.
    ///
    /// Subtracting the identity yields the `X` in `log(I + X)` consumed by
    /// the Padé phase.
    fn emit_shift_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let name = format!("solver_logm_shift_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("out_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                // One thread per matrix element, guarded against overshoot.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let out_ptr = b.load_param_u64("out_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Load A[gid].
                    let src_addr = b.byte_offset_addr(a_ptr, gid_r.clone(), elem_size);
                    let val = load_float(b, float_ty, src_addr);

                    // Compute row and column to check if on diagonal.
                    // Column-major layout: gid = col * n + row.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));

                    // Subtract 1.0 from diagonal elements: out[i,j] = A[i,j] - delta(i,j).
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);
                    // diag_sub = 1.0 if on diagonal, else 0.0 (branchless select).
                    let diag_sub = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {diag_sub}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    // result = val - diag_sub.
                    let result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "sub{} {result}, {val}, {diag_sub};",
                        float_ty.as_ptx_str()
                    ));

                    let dst_addr = b.byte_offset_addr(out_ptr, gid_r, elem_size);
                    store_float(b, float_ty, dst_addr, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }
701
    /// Emits PTX for one Denman–Beavers iteration used during the square-root
    /// reduction phase of the matrix logarithm.
    ///
    /// The host is expected to have already computed `M_y = Y_k * Z_k^{-1}`
    /// into `y_ptr` and `M_z = Z_k * Y_k^{-1}` into `z_ptr`; this kernel only
    /// performs the element-wise averaging step
    /// `Y_{k+1} = (M_y + I) / 2` and `Z_{k+1} = (M_z + I) / 2`, writing the
    /// results to `y_next_ptr` / `z_next_ptr`. One thread handles one matrix
    /// element of both channels.
    ///
    /// The `n` argument is baked into the kernel *name* only; the runtime
    /// dimension is taken from the kernel's `n` parameter.
    fn emit_sqrt_step_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_logm_sqrt_step_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("y_ptr", PtxType::U64)
            .param("z_ptr", PtxType::U64)
            .param("y_next_ptr", PtxType::U64)
            .param("z_next_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                // Guard: only the first n*n threads touch memory.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let y_ptr = b.load_param_u64("y_ptr");
                    let z_ptr = b.load_param_u64("z_ptr");
                    let y_next_ptr = b.load_param_u64("y_next_ptr");
                    let z_next_ptr = b.load_param_u64("z_next_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Compute row and column for the diagonal check:
                    // row = gid % n, col = gid / n. Only the row == col test
                    // below uses these, so the storage-order convention does
                    // not affect the result.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));

                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);

                    // After host computes M_y = Y_k * Z_k^{-1} (stored in y_ptr) and
                    // M_z = Z_k * Y_k^{-1} (stored in z_ptr), this kernel computes:
                    //   Y_{k+1}[i,j] = (M_y[i,j] + delta(i,j)) / 2
                    //   Z_{k+1}[i,j] = (M_z[i,j] + delta(i,j)) / 2
                    // diag_add is the delta(i,j) term: 1 on the diagonal, else 0.
                    let diag_add = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {diag_add}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    let half = half_const(b, float_ty);

                    // Process Y channel: Y_next = (M_y + I) * 0.5.
                    let y_src = b.byte_offset_addr(y_ptr, gid_r.clone(), elem_size);
                    let y_val = load_float(b, float_ty, y_src);
                    let y_sum = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "add{} {y_sum}, {y_val}, {diag_add};",
                        float_ty.as_ptx_str()
                    ));
                    let y_result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {y_result}, {y_sum}, {half};",
                        float_ty.as_ptx_str()
                    ));
                    let y_dst = b.byte_offset_addr(y_next_ptr, gid_r.clone(), elem_size);
                    store_float(b, float_ty, y_dst, y_result);

                    // Process Z channel: Z_next = (M_z + I) * 0.5 (same diag_add).
                    let z_src = b.byte_offset_addr(z_ptr, gid_r.clone(), elem_size);
                    let z_val = load_float(b, float_ty, z_src);
                    let z_sum = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "add{} {z_sum}, {z_val}, {diag_add};",
                        float_ty.as_ptx_str()
                    ));
                    let z_result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {z_result}, {z_sum}, {half};",
                        float_ty.as_ptx_str()
                    ));
                    let z_dst = b.byte_offset_addr(z_next_ptr, gid_r, elem_size);
                    store_float(b, float_ty, z_dst, z_result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }
795
    /// Emits PTX that evaluates a truncated series for `log(1 + x)`
    /// element-wise on the reduced matrix `X`.
    ///
    /// Each thread computes, for its own element `x`,
    /// `log(1 + x) ≈ sum_{k=1}^{num_terms} (-1)^{k+1} x^k / k`
    /// using Horner's method.
    ///
    /// NOTE(review): despite "pade" in the kernel name, this is a scalar
    /// truncated Taylor (Mercator) series applied per element, not a matrix
    /// Padé approximant `P(X) * Q(X)^{-1}` — confirm this matches the
    /// accuracy expectations of the inverse scaling-and-squaring driver.
    fn emit_pade_log_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_logm_pade_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("x_ptr", PtxType::U64)
            .param("result_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("num_terms", PtxType::U32)
            .body(move |b| {
                // Guard: only the first n*n threads participate.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let x_ptr = b.load_param_u64("x_ptr");
                    let result_ptr = b.load_param_u64("result_ptr");
                    let num_terms = b.load_param_u32("num_terms");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Scalar element-wise evaluation of the truncated log(1+x) series:
                    // for small |x|, log(1+x) ≈ sum_{k=1}^{m} (-1)^{k+1} * x^k / k,
                    // with m = num_terms, evaluated by Horner's method.
                    //
                    // Algorithm as emitted below:
                    //   acc = 0
                    //   for k = num_terms down to 1:
                    //       acc = (-1)^{k+1} / k + x * acc
                    //   result = x * acc
                    // After the loop, acc = sum_k (-1)^{k+1} x^{k-1} / k; the final
                    // multiply restores the leading factor of x so the series
                    // starts at x^1.

                    let src = b.byte_offset_addr(x_ptr, gid_r.clone(), elem_size);
                    let x_val = load_float(b, float_ty, src);

                    // acc_reg holds the running Horner accumulator.
                    let acc_reg = b.alloc_reg(float_ty);
                    // Initialize acc = 0.
                    let zero = zero_const(b, float_ty);
                    b.raw_ptx(&format!("mov{} {acc_reg}, {zero};", float_ty.as_ptx_str()));

                    // k_reg starts at num_terms and decrements to 1.
                    let k_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mov.u32 {k_reg}, {num_terms};"));

                    let log_loop = b.fresh_label("log_loop");
                    let log_exit = b.fresh_label("log_exit");

                    b.raw_ptx(&format!("{log_loop}:"));
                    // Exit when k_reg == 0 (also covers num_terms == 0: result is 0).
                    let done_pred = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {done_pred}, {k_reg}, 0;"));
                    b.raw_ptx(&format!("@{done_pred} bra {log_exit};"));

                    // Convert k to float: k_f = (float) k_reg.
                    let k_f = b.alloc_reg(float_ty);
                    if float_ty == PtxType::F64 {
                        b.raw_ptx(&format!("cvt.rn.f64.u32 {k_f}, {k_reg};"));
                    } else {
                        b.raw_ptx(&format!("cvt.rn.f32.u32 {k_f}, {k_reg};"));
                    }

                    // inv_k = 1.0 / k_f.
                    // NOTE(review): depending on what the builder's rcp helper
                    // emits, this may be an approximate reciprocal — confirm
                    // the rounding mode is acceptable for the series terms.
                    let inv_k = if float_ty == PtxType::F64 {
                        b.rcp_f64(k_f)
                    } else {
                        b.rcp_f32(k_f)
                    };

                    // Determine sign: (-1)^{k+1} = +1 when k is odd, -1 when k is even.
                    let odd_pred = b.alloc_reg(PtxType::Pred);
                    let lsb = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("and.b32 {lsb}, {k_reg}, 1;"));
                    b.raw_ptx(&format!("setp.ne.u32 {odd_pred}, {lsb}, 0;"));

                    let neg_inv_k = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "neg{} {neg_inv_k}, {inv_k};",
                        float_ty.as_ptx_str()
                    ));
                    // signed_inv_k = odd ? +inv_k : -inv_k  (sign = (-1)^{k+1}).
                    let signed_inv_k = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {signed_inv_k}, {inv_k}, {neg_inv_k}, {odd_pred};",
                        float_ty.as_ptx_str()
                    ));

                    // Horner step: acc = x * acc + signed_inv_k, emitted as a
                    // single fused multiply-add.
                    let new_acc = if float_ty == PtxType::F64 {
                        b.fma_f64(x_val.clone(), acc_reg.clone(), signed_inv_k)
                    } else {
                        b.fma_f32(x_val.clone(), acc_reg.clone(), signed_inv_k)
                    };
                    b.raw_ptx(&format!(
                        "mov{} {acc_reg}, {new_acc};",
                        float_ty.as_ptx_str()
                    ));

                    // k -= 1.
                    b.raw_ptx(&format!("sub.u32 {k_reg}, {k_reg}, 1;"));
                    b.raw_ptx(&format!("bra {log_loop};"));
                    b.raw_ptx(&format!("{log_exit}:"));

                    // result = x * acc (restores the leading power of x).
                    let result = if float_ty == PtxType::F64 {
                        let r = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mul.rn.f64 {r}, {x_val}, {acc_reg};"));
                        r
                    } else {
                        let r = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mul.rn.f32 {r}, {x_val}, {acc_reg};"));
                        r
                    };

                    let dst = b.byte_offset_addr(result_ptr, gid_r, elem_size);
                    store_float(b, float_ty, dst, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }
934
    /// Emits PTX for scaling the result by `2^s`: `result *= 2^s`, in place.
    ///
    /// `scale_exp` is the kernel parameter `s`; multiplying by `2^s` undoes
    /// the repeated square-root scaling applied earlier. The factor `2^s` is
    /// built directly from its IEEE 754 bit pattern (biased exponent shifted
    /// into place), so no floating-point `exp2` instruction is needed.
    ///
    /// NOTE(review): the bit construction only yields a valid finite factor
    /// while the biased exponent stays in range (`scale_exp` < 128 for f32,
    /// < 1024 for f64) — confirm the host bounds the squaring count.
    fn emit_scale_back_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!(
            "solver_logm_scale_back_{}_n{}",
            ptx_type_suffix(float_ty),
            n
        );

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("result_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("scale_exp", PtxType::U32)
            .body(move |b| {
                // Guard: only the first n*n threads touch memory.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let result_ptr = b.load_param_u64("result_ptr");
                    let scale_exp = b.load_param_u32("scale_exp");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Multiply each element by 2^scale_exp using IEEE 754 bit construction.
                    // For f64: bits = (scale_exp + 1023) << 52.
                    // For f32: bits = (scale_exp + 127) << 23.
                    // The same address is used for the load and the store below,
                    // i.e. the buffer is updated in place.
                    let addr = b.byte_offset_addr(result_ptr, gid_r, elem_size);
                    let val = load_float(b, float_ty, addr.clone());

                    let result = if float_ty == PtxType::F64 {
                        let se64 = b.cvt_u32_to_u64(scale_exp);
                        let biased = b.alloc_reg(PtxType::U64);
                        b.raw_ptx(&format!("add.u64 {biased}, {se64}, 1023;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 52;"));
                        let bits = b.shl_b64(biased, shift_amt);
                        let factor = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mov.b64 {factor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mul.rn.f64 {res}, {val}, {factor};"));
                        res
                    } else {
                        let biased = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("add.u32 {biased}, {scale_exp}, 127;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 23;"));
                        let bits = b.shl_b32(biased, shift_amt);
                        let factor = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mov.b32 {factor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mul.rn.f32 {res}, {val}, {factor};"));
                        res
                    };

                    store_float(b, float_ty, addr, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }
1005}
1006
1007// =========================================================================
1008// Matrix Square Root (sqrtm)
1009// =========================================================================
1010
/// Configuration for matrix square root computation.
///
/// Construct via [`MatrixSqrtConfig::new`], customize with the `with_*`
/// builder methods; all fields are checked when a plan is created.
#[derive(Debug, Clone)]
pub struct MatrixSqrtConfig {
    /// Matrix dimension (n × n). Must be > 0.
    pub n: u32,
    /// Precision: `"f32"` or `"f64"`.
    pub precision: String,
    /// Maximum Denman–Beavers iterations. Default: 50. Must be > 0.
    pub max_iters: u32,
    /// Convergence tolerance. Default: 1e-12. Must be positive and finite.
    // NOTE(review): the convergence kernel accumulates the *squared*
    // Frobenius norm of the update; confirm on the host side whether `tol`
    // is compared against the squared or the plain norm.
    pub tol: f64,
}
1023
1024impl MatrixSqrtConfig {
1025    /// Creates a new configuration with sensible defaults.
1026    pub fn new(n: u32, precision: &str) -> Self {
1027        Self {
1028            n,
1029            precision: precision.to_string(),
1030            max_iters: 50,
1031            tol: 1e-12,
1032        }
1033    }
1034
1035    /// Sets the maximum number of iterations.
1036    pub fn with_max_iters(mut self, iters: u32) -> Self {
1037        self.max_iters = iters;
1038        self
1039    }
1040
1041    /// Sets the convergence tolerance.
1042    pub fn with_tol(mut self, tol: f64) -> Self {
1043        self.tol = tol;
1044        self
1045    }
1046
1047    /// Validates the configuration.
1048    fn validate(&self) -> SolverResult<()> {
1049        if self.n == 0 {
1050            return Err(SolverError::DimensionMismatch(
1051                "sqrtm: matrix dimension must be > 0".into(),
1052            ));
1053        }
1054        if self.precision != "f32" && self.precision != "f64" {
1055            return Err(SolverError::InternalError(format!(
1056                "sqrtm: unsupported precision '{}'; use 'f32' or 'f64'",
1057                self.precision
1058            )));
1059        }
1060        if self.max_iters == 0 {
1061            return Err(SolverError::InternalError(
1062                "sqrtm: max_iters must be > 0".into(),
1063            ));
1064        }
1065        if self.tol <= 0.0 || !self.tol.is_finite() {
1066            return Err(SolverError::InternalError(format!(
1067                "sqrtm: tolerance must be positive and finite, got {}",
1068                self.tol
1069            )));
1070        }
1071        Ok(())
1072    }
1073}
1074
/// Execution plan for the matrix square root using Denman–Beavers iteration.
///
/// The iteration produces sequences `Y_k → sqrt(A)` and `Z_k → sqrt(A)^{-1}`:
///
/// ```text
///   Y_0 = A,       Z_0 = I
///   Y_{k+1} = (Y_k * Z_k^{-1} + I) / 2
///   Z_{k+1} = (Z_k * Y_k^{-1} + I) / 2
/// ```
///
/// Convergence is detected when `||Y_{k+1} - Y_k||_F < tol`.
#[derive(Debug, Clone)]
pub struct MatrixSqrtPlan {
    // Validated in `MatrixSqrtPlan::new`; drives kernel naming and precision.
    config: MatrixSqrtConfig,
}
1090
impl MatrixSqrtPlan {
    /// Creates a plan from a validated configuration.
    ///
    /// # Errors
    /// Returns the first validation failure (zero dimension, unsupported
    /// precision, zero iteration budget, or non-positive/non-finite tolerance).
    pub fn new(config: MatrixSqrtConfig) -> SolverResult<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Returns the convergence tolerance.
    pub fn tolerance(&self) -> f64 {
        self.config.tol
    }

    /// Returns the maximum number of iterations.
    pub fn max_iters(&self) -> u32 {
        self.config.max_iters
    }

    /// Generates PTX source for the matrix square root kernels.
    ///
    /// The generated code contains entry points for:
    /// 1. **init kernel** — sets `Y_0 = A` and `Z_0 = I`.
    /// 2. **iteration kernel** — computes the element-wise `(M + I) / 2`
    ///    step after the matrix product and inverse have been done by BLAS.
    /// 3. **convergence kernel** — accumulates `||Y_{k+1} - Y_k||_F²` into a
    ///    global scalar via per-thread atomic adds.
    ///
    /// NOTE(review): the three kernels are emitted as separate strings joined
    /// with a single newline — confirm the module loader accepts them
    /// concatenated into one source blob.
    pub fn generate_ptx(&self) -> SolverResult<String> {
        let n = self.config.n;
        let float_ty = precision_to_ptx_type(&self.config.precision)?;
        // Target architecture is fixed at Sm75 for all generated kernels.
        let sm = SmVersion::Sm75;

        let mut all_ptx = Vec::new();

        // Kernel 1: Initialize Y = A, Z = I.
        let init_ptx = self.emit_init_kernel(n, float_ty, sm)?;
        all_ptx.push(init_ptx);

        // Kernel 2: Element-wise (M + I) / 2.
        let iter_ptx = self.emit_iteration_kernel(n, float_ty, sm)?;
        all_ptx.push(iter_ptx);

        // Kernel 3: Frobenius norm difference (convergence check).
        let conv_ptx = self.emit_convergence_kernel(n, float_ty, sm)?;
        all_ptx.push(conv_ptx);

        Ok(all_ptx.join("\n"))
    }

    /// Emits PTX that copies A into Y and sets Z to the identity matrix.
    ///
    /// One thread initializes one element of each buffer. `n` is baked into
    /// the kernel *name* only; the runtime size comes from the `n` parameter.
    fn emit_init_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let name = format!("solver_sqrtm_init_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("y_ptr", PtxType::U64)
            .param("z_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                // Guard: only the first n*n threads touch memory.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let y_ptr = b.load_param_u64("y_ptr");
                    let z_ptr = b.load_param_u64("z_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Y = A: copy element.
                    let a_addr = b.byte_offset_addr(a_ptr, gid_r.clone(), elem_size);
                    let val = load_float(b, float_ty, a_addr);
                    let y_addr = b.byte_offset_addr(y_ptr, gid_r.clone(), elem_size);
                    store_float(b, float_ty, y_addr, val);

                    // Z = I: diagonal = 1, off-diagonal = 0.
                    // row = gid % n, col = gid / n; only the symmetric
                    // row == col test uses them, so storage order is immaterial.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));
                    let z_addr = b.byte_offset_addr(z_ptr, gid_r, elem_size);

                    // Set 1.0 if on diagonal, 0.0 otherwise.
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);

                    // Use select based on row == col comparison.
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let z_val = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {z_val}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    store_float(b, float_ty, z_addr, z_val);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits PTX for the element-wise `(M + I) / 2` step.
    ///
    /// After the host computes `M = Y_k * Z_k^{-1}` via BLAS, this kernel
    /// computes `Y_{k+1}[i,j] = (M[i,j] + delta(i,j)) / 2`.
    /// (The host runs the same kernel with `M = Z_k * Y_k^{-1}` for the
    /// Z channel — only one channel is processed per launch.)
    fn emit_iteration_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_sqrtm_iter_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("m_ptr", PtxType::U64)
            .param("out_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                // Guard: only the first n*n threads touch memory.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let m_ptr = b.load_param_u64("m_ptr");
                    let out_ptr = b.load_param_u64("out_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Load M[i,j].
                    let m_addr = b.byte_offset_addr(m_ptr, gid_r.clone(), elem_size);
                    let m_val = load_float(b, float_ty, m_addr);

                    // Compute row, col for diagonal check (row = gid % n, col = gid / n).
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));

                    // Add 1.0 if on diagonal (the identity term of (M + I)).
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);
                    let diag_add = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {diag_add}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));

                    // sum = M[i,j] + diag_add.
                    let sum = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "add{} {sum}, {m_val}, {diag_add};",
                        float_ty.as_ptx_str()
                    ));

                    // result = sum / 2 (multiplication by 0.5 is exact).
                    let half = half_const(b, float_ty);
                    let result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {result}, {sum}, {half};",
                        float_ty.as_ptx_str()
                    ));

                    let out_addr = b.byte_offset_addr(out_ptr, gid_r, elem_size);
                    store_float(b, float_ty, out_addr, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits PTX for computing `||Y_{k+1} - Y_k||_F²`.
    ///
    /// Each thread computes `(Y_new[i] - Y_old[i])²` for its element and
    /// atomically adds it into the scalar accumulator at `norm_ptr`.
    ///
    /// NOTE(review): there is no warp- or block-level pre-reduction here —
    /// every thread issues its own global atomic add. The host is presumably
    /// expected to zero `*norm_ptr` before each launch; confirm both points
    /// against the host-side driver.
    fn emit_convergence_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_sqrtm_conv_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("y_new_ptr", PtxType::U64)
            .param("y_old_ptr", PtxType::U64)
            .param("norm_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                // Guard: only the first n*n threads contribute.
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let y_new_ptr = b.load_param_u64("y_new_ptr");
                    let y_old_ptr = b.load_param_u64("y_old_ptr");
                    let gid_r = b.global_thread_id_x();

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // diff = Y_new[gid] - Y_old[gid].
                    let new_addr = b.byte_offset_addr(y_new_ptr, gid_r.clone(), elem_size);
                    let old_addr = b.byte_offset_addr(y_old_ptr, gid_r, elem_size);
                    let new_val = load_float(b, float_ty, new_addr);
                    let old_val = load_float(b, float_ty, old_addr);

                    let diff = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "sub{} {diff}, {new_val}, {old_val};",
                        float_ty.as_ptx_str()
                    ));

                    // diff_sq = diff * diff.
                    let diff_sq = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {diff_sq}, {diff}, {diff};",
                        float_ty.as_ptx_str()
                    ));

                    // Atomically accumulate diff^2 into the global norm accumulator.
                    // This implements a parallel Frobenius norm computation:
                    //   ||Y_new - Y_old||_F^2 = sum_i (Y_new[i] - Y_old[i])^2
                    // The previous value returned by the atomic is not needed.
                    let norm_ptr = b.load_param_u64("norm_ptr");
                    if float_ty == PtxType::F64 {
                        let _old = b.atom_global_add_f64(norm_ptr, diff_sq);
                    } else {
                        let _old = b.atom_global_add_f32(norm_ptr, diff_sq);
                    }
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }
}
1345
1346// =========================================================================
1347// PTX helper utilities
1348// =========================================================================
1349
1350/// Converts a precision string to the corresponding PTX floating-point type.
1351fn precision_to_ptx_type(precision: &str) -> SolverResult<PtxType> {
1352    match precision {
1353        "f32" => Ok(PtxType::F32),
1354        "f64" => Ok(PtxType::F64),
1355        other => Err(SolverError::InternalError(format!(
1356            "unsupported precision '{other}'"
1357        ))),
1358    }
1359}
1360
1361/// Returns a short suffix for kernel names based on the PTX type.
1362fn ptx_type_suffix(ty: PtxType) -> &'static str {
1363    match ty {
1364        PtxType::F32 => "f32",
1365        PtxType::F64 => "f64",
1366        _ => "unknown",
1367    }
1368}
1369
1370/// Loads a float value from global memory.
1371fn load_float(b: &mut BodyBuilder<'_>, float_ty: PtxType, addr: Register) -> Register {
1372    let dst = b.alloc_reg(float_ty);
1373    b.raw_ptx(&format!(
1374        "ld.global{} {dst}, [{addr}];",
1375        float_ty.as_ptx_str()
1376    ));
1377    dst
1378}
1379
1380/// Stores a float value to global memory.
1381fn store_float(b: &mut BodyBuilder<'_>, float_ty: PtxType, addr: Register, val: Register) {
1382    b.raw_ptx(&format!(
1383        "st.global{} [{addr}], {val};",
1384        float_ty.as_ptx_str()
1385    ));
1386}
1387
1388/// Returns a register containing 0.0 in the given float type.
1389fn zero_const(b: &mut BodyBuilder<'_>, float_ty: PtxType) -> Register {
1390    let dst = b.alloc_reg(float_ty);
1391    if float_ty == PtxType::F32 {
1392        let bits = b.alloc_reg(PtxType::U32);
1393        b.raw_ptx(&format!("mov.u32 {bits}, 0;"));
1394        b.raw_ptx(&format!("mov.b32 {dst}, {bits};"));
1395    } else {
1396        let bits = b.alloc_reg(PtxType::U64);
1397        b.raw_ptx(&format!("mov.u64 {bits}, 0;"));
1398        b.raw_ptx(&format!("mov.b64 {dst}, {bits};"));
1399    }
1400    dst
1401}
1402
1403/// Returns a register containing 1.0 in the given float type.
1404fn one_const(b: &mut BodyBuilder<'_>, float_ty: PtxType) -> Register {
1405    let dst = b.alloc_reg(float_ty);
1406    if float_ty == PtxType::F32 {
1407        // IEEE 754: 1.0f32 = 0x3F800000
1408        let bits = b.alloc_reg(PtxType::U32);
1409        b.raw_ptx(&format!("mov.u32 {bits}, 1065353216;"));
1410        b.raw_ptx(&format!("mov.b32 {dst}, {bits};"));
1411    } else {
1412        // IEEE 754: 1.0f64 = 0x3FF0000000000000 = 4607182418800017408
1413        let bits = b.alloc_reg(PtxType::U64);
1414        b.raw_ptx(&format!("mov.u64 {bits}, 4607182418800017408;"));
1415        b.raw_ptx(&format!("mov.b64 {dst}, {bits};"));
1416    }
1417    dst
1418}
1419
1420/// Returns a register containing 0.5 in the given float type.
1421fn half_const(b: &mut BodyBuilder<'_>, float_ty: PtxType) -> Register {
1422    let dst = b.alloc_reg(float_ty);
1423    if float_ty == PtxType::F32 {
1424        // IEEE 754: 0.5f32 = 0x3F000000 = 1056964608
1425        let bits = b.alloc_reg(PtxType::U32);
1426        b.raw_ptx(&format!("mov.u32 {bits}, 1056964608;"));
1427        b.raw_ptx(&format!("mov.b32 {dst}, {bits};"));
1428    } else {
1429        // IEEE 754: 0.5f64 = 0x3FE0000000000000 = 4602678819172646912
1430        let bits = b.alloc_reg(PtxType::U64);
1431        b.raw_ptx(&format!("mov.u64 {bits}, 4602678819172646912;"));
1432        b.raw_ptx(&format!("mov.b64 {dst}, {bits};"));
1433    }
1434    dst
1435}
1436
1437// =========================================================================
1438// Tests (in a separate file to stay under the 2000-line refactoring limit)
1439// =========================================================================
1440
// The `#[path]` attribute pulls the test module in from a sibling file so
// this source file itself stays within the repository's size limit.
#[cfg(test)]
#[path = "matrix_functions_tests.rs"]
mod tests;