1#![allow(dead_code)]
18
19use oxicuda_ptx::ir::PtxType;
20use oxicuda_ptx::prelude::*;
21
22use crate::error::{SolverError, SolverResult};
23
24fn pade_coefficients(order: u32) -> SolverResult<Vec<f64>> {
34 match order {
35 3 => Ok(vec![120.0, 60.0, 12.0, 1.0]),
36 5 => Ok(vec![30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0]),
37 7 => Ok(vec![
38 17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0,
39 ]),
40 9 => Ok(vec![
41 17643225600.0,
42 8821612800.0,
43 2075673600.0,
44 302702400.0,
45 30270240.0,
46 2162160.0,
47 110880.0,
48 3960.0,
49 90.0,
50 1.0,
51 ]),
52 13 => Ok(vec![
53 64764752532480000.0,
54 32382376266240000.0,
55 7771770303897600.0,
56 1187353796428800.0,
57 129060195264000.0,
58 10559470521600.0,
59 670442572800.0,
60 33522128640.0,
61 1323241920.0,
62 40840800.0,
63 960960.0,
64 16380.0,
65 182.0,
66 1.0,
67 ]),
68 _ => Err(SolverError::InternalError(format!(
69 "unsupported Padé order {order}; valid orders are 3, 5, 7, 9, 13"
70 ))),
71 }
72}
73
74#[allow(clippy::excessive_precision)]
77fn pade_theta(order: u32) -> SolverResult<f64> {
78 match order {
79 3 => Ok(1.495_585_217_958_292e-2),
80 5 => Ok(2.539_398_330_063_230e-1),
81 7 => Ok(9.504_178_996_162_932e-1),
82 9 => Ok(2.097_847_961_257_068),
83 13 => Ok(5.371_920_351_148_152),
84 _ => Err(SolverError::InternalError(format!(
85 "no theta for Padé order {order}"
86 ))),
87 }
88}
89
/// Configuration for planning matrix-exponential (expm) PTX kernels.
#[derive(Debug, Clone)]
pub struct MatrixExpConfig {
    // Matrix dimension; the input is treated as n x n.
    pub n: u32,
    // Element precision: "f32" or "f64" (validated in `validate`).
    pub precision: String,
    // Padé approximant order; must be one of 3, 5, 7, 9, 13.
    pub pade_order: u32,
}
104
105impl MatrixExpConfig {
106 pub fn new(n: u32, precision: &str) -> Self {
108 Self {
109 n,
110 precision: precision.to_string(),
111 pade_order: 13,
112 }
113 }
114
115 pub fn with_pade_order(mut self, order: u32) -> Self {
117 self.pade_order = order;
118 self
119 }
120
121 fn validate(&self) -> SolverResult<()> {
123 if self.n == 0 {
124 return Err(SolverError::DimensionMismatch(
125 "expm: matrix dimension must be > 0".into(),
126 ));
127 }
128 if self.precision != "f32" && self.precision != "f64" {
129 return Err(SolverError::InternalError(format!(
130 "expm: unsupported precision '{}'; use 'f32' or 'f64'",
131 self.precision
132 )));
133 }
134 pade_coefficients(self.pade_order)?;
136 Ok(())
137 }
138}
139
/// Validated plan for matrix-exponential kernel generation, carrying the
/// precomputed Padé coefficient table and scaling threshold.
#[derive(Debug, Clone)]
pub struct MatrixExpPlan {
    // The validated configuration this plan was built from.
    config: MatrixExpConfig,
    // Padé numerator coefficients for `config.pade_order`, constant term first.
    pade_coeffs: Vec<f64>,
    // Scaling threshold theta for `config.pade_order`.
    theta: f64,
}
150
impl MatrixExpPlan {
    /// Validates `config` and precomputes the Padé coefficient table and the
    /// scaling threshold theta for the configured order.
    pub fn new(config: MatrixExpConfig) -> SolverResult<Self> {
        config.validate()?;
        let pade_coeffs = pade_coefficients(config.pade_order)?;
        let theta = pade_theta(config.pade_order)?;
        Ok(Self {
            config,
            pade_coeffs,
            theta,
        })
    }

    /// The Padé numerator coefficients for this plan's order, constant term
    /// first (index `k` holds the coefficient of `x^k`).
    pub fn pade_coefficients(&self) -> &[f64] {
        &self.pade_coeffs
    }

    /// The scaling threshold theta for this plan's Padé order.
    pub fn theta(&self) -> f64 {
        self.theta
    }

    /// Emits the PTX for all three expm kernels (scale, Padé evaluation,
    /// squaring) and joins them into a single newline-separated string.
    ///
    /// # Errors
    /// Fails if the configured precision is unsupported or any kernel build
    /// fails.
    pub fn generate_ptx(&self) -> SolverResult<String> {
        let n = self.config.n;
        let float_ty = precision_to_ptx_type(&self.config.precision)?;
        // NOTE(review): the SM target is hard-coded to 7.5 here — confirm
        // this is intentional rather than something the caller should pick.
        let sm = SmVersion::Sm75;

        let mut all_ptx = Vec::new();

        let scale_ptx = self.emit_scale_kernel(n, float_ty, sm)?;
        all_ptx.push(scale_ptx);

        let pade_ptx = self.emit_pade_kernel(n, float_ty, sm)?;
        all_ptx.push(pade_ptx);

        let square_ptx = self.emit_squaring_kernel(n, float_ty, sm)?;
        all_ptx.push(square_ptx);

        Ok(all_ptx.join("\n"))
    }

    /// Emits an element-wise kernel that divides every entry of `a_ptr` by
    /// 2^scale_exp and writes the result to `out_ptr`.
    ///
    /// The divisor 2^scale_exp is built directly as an IEEE-754 bit pattern:
    /// biased exponent (`scale_exp + 1023` for f64, `+ 127` for f32) shifted
    /// into the exponent field (bit 52 / bit 23), then reinterpreted as a
    /// float via `mov.bNN`.
    fn emit_scale_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let name = format!("solver_expm_scale_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("out_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("scale_exp", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                // One thread per matrix entry: guard against gid >= n*n.
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let out_ptr = b.load_param_u64("out_ptr");
                    let scale_exp = b.load_param_u32("scale_exp");

                    let gid_repeat = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    let addr = b.byte_offset_addr(a_ptr, gid_repeat.clone(), elem_size);
                    let val = load_float(b, float_ty, addr);

                    let out_addr = b.byte_offset_addr(out_ptr, gid_repeat, elem_size);

                    let result = if float_ty == PtxType::F64 {
                        // Construct 2^scale_exp: biased exponent into bits 62..52.
                        let se64 = b.cvt_u32_to_u64(scale_exp);
                        let biased = b.alloc_reg(PtxType::U64);
                        b.raw_ptx(&format!("add.u64 {biased}, {se64}, 1023;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 52;"));
                        let bits = b.shl_b64(biased, shift_amt);
                        let divisor = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mov.b64 {divisor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("div.rn.f64 {res}, {val}, {divisor};"));
                        res
                    } else {
                        // Same construction for f32: bias 127, exponent at bit 23.
                        let biased = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("add.u32 {biased}, {scale_exp}, 127;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 23;"));
                        let bits = b.shl_b32(biased, shift_amt);
                        let divisor = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mov.b32 {divisor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("div.rn.f32 {res}, {val}, {divisor};"));
                        res
                    };

                    store_float(b, float_ty, out_addr, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits a kernel that Horner-evaluates the Padé numerator `p` and
    /// denominator `q` per element of `a_ptr`, writing to `p_ptr`/`q_ptr`.
    ///
    /// Coefficients are read as f64 from `coeffs_ptr` (converted to f32 when
    /// needed). `q` reuses the same table with odd-index coefficients
    /// negated, i.e. it evaluates `p(-x)`.
    ///
    /// NOTE(review): the Horner evaluation here is element-wise on matrix
    /// entries; the matrix-power terms of the Padé approximant are presumably
    /// formed by the host between launches — confirm against the caller.
    fn emit_pade_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let order = self.config.pade_order;
        let name = format!(
            "solver_expm_pade_{}_n{}_p{}",
            ptx_type_suffix(float_ty),
            n,
            order
        );

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("p_ptr", PtxType::U64)
            .param("q_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("coeffs_ptr", PtxType::U64)
            .param("num_coeffs", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let p_ptr = b.load_param_u64("p_ptr");
                    let q_ptr = b.load_param_u64("q_ptr");
                    let coeffs_ptr = b.load_param_u64("coeffs_ptr");
                    let num_coeffs = b.load_param_u32("num_coeffs");

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    // Coefficients are always stored as f64 on the device.
                    const COEFF_SIZE: u32 = 8u32;
                    let gid_r = b.global_thread_id_x();

                    let a_addr = b.byte_offset_addr(a_ptr, gid_r.clone(), elem_size);
                    let a_val = load_float(b, float_ty, a_addr);

                    let acc_p = zero_const(b, float_ty);
                    let acc_q = zero_const(b, float_ty);

                    // Horner loop counter: runs idx = num_coeffs-1 down to 0.
                    let idx_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mov.u32 {idx_reg}, {num_coeffs};"));

                    let horner_loop = b.fresh_label("horner_loop");
                    let horner_exit = b.fresh_label("horner_exit");

                    b.raw_ptx(&format!("{horner_loop}:"));
                    let done_pred = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {done_pred}, {idx_reg}, 0;"));
                    b.raw_ptx(&format!("@{done_pred} bra {horner_exit};"));

                    b.raw_ptx(&format!("sub.u32 {idx_reg}, {idx_reg}, 1;"));

                    let coeff_addr =
                        b.byte_offset_addr(coeffs_ptr.clone(), idx_reg.clone(), COEFF_SIZE);
                    let coeff_f64 = load_float(b, PtxType::F64, coeff_addr);

                    // Narrow the f64 coefficient when emitting an f32 kernel.
                    let c_k = if float_ty == PtxType::F64 {
                        coeff_f64.clone()
                    } else {
                        let dst = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("cvt.rn.f32.f64 {dst}, {coeff_f64};"));
                        dst
                    };

                    // Horner step for p: acc_p = acc_p * a + c_k.
                    let new_acc_p = if float_ty == PtxType::F64 {
                        b.fma_f64(acc_p.clone(), a_val.clone(), c_k.clone())
                    } else {
                        b.fma_f32(acc_p.clone(), a_val.clone(), c_k.clone())
                    };
                    b.raw_ptx(&format!(
                        "mov{} {acc_p}, {new_acc_p};",
                        float_ty.as_ptx_str()
                    ));

                    // q uses (-1)^idx * c_k: test the low bit of idx.
                    let odd_pred = b.alloc_reg(PtxType::Pred);
                    let lsb = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("and.b32 {lsb}, {idx_reg}, 1;"));
                    b.raw_ptx(&format!("setp.ne.u32 {odd_pred}, {lsb}, 0;"));

                    let neg_c_k = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!("neg{} {neg_c_k}, {c_k};", float_ty.as_ptx_str()));
                    let q_coeff = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {q_coeff}, {neg_c_k}, {c_k}, {odd_pred};",
                        float_ty.as_ptx_str()
                    ));

                    // Horner step for q: acc_q = acc_q * a + (+/-)c_k.
                    let new_acc_q = if float_ty == PtxType::F64 {
                        b.fma_f64(acc_q.clone(), a_val.clone(), q_coeff)
                    } else {
                        b.fma_f32(acc_q.clone(), a_val.clone(), q_coeff)
                    };
                    b.raw_ptx(&format!(
                        "mov{} {acc_q}, {new_acc_q};",
                        float_ty.as_ptx_str()
                    ));

                    b.raw_ptx(&format!("bra {horner_loop};"));
                    b.raw_ptx(&format!("{horner_exit}:"));

                    let p_addr = b.byte_offset_addr(p_ptr, gid_r.clone(), elem_size);
                    let q_addr = b.byte_offset_addr(q_ptr, gid_r, elem_size);
                    store_float(b, float_ty, p_addr, acc_p);
                    store_float(b, float_ty, q_addr, acc_q);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits a naive matrix-multiply kernel computing `tmp = F * F`, one
    /// thread per output entry, used for the repeated-squaring phase.
    ///
    /// Index arithmetic: for `row = gid % n`, `col = gid / n`, the output at
    /// `col*n + row` accumulates sum over k of `f[k*n + row] * f[col*n + k]`.
    fn emit_squaring_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_expm_square_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("f_ptr", PtxType::U64)
            .param("tmp_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let f_ptr = b.load_param_u64("f_ptr");
                    let tmp_ptr = b.load_param_u64("tmp_ptr");
                    let n_inner = b.load_param_u32("n");

                    // Decompose the flat thread id into (row, col).
                    let gid_r = b.global_thread_id_x();
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Dot-product accumulator and loop counter k in [0, n).
                    let acc = zero_const(b, float_ty);
                    let k_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mov.u32 {k_reg}, 0;"));

                    let loop_label = b.fresh_label("sq_loop");
                    let exit_label = b.fresh_label("sq_exit");

                    b.raw_ptx(&format!("{loop_label}:"));
                    let pred = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.ge.u32 {pred}, {k_reg}, {n_inner};"));
                    b.raw_ptx(&format!("@{pred} bra {exit_label};"));

                    // Left factor element: f[k*n + row].
                    let a_idx_base = b.mul_lo_u32(k_reg.clone(), n_inner.clone());
                    let a_idx = b.add_u32(a_idx_base, row.clone());
                    let a_addr = b.byte_offset_addr(f_ptr.clone(), a_idx, elem_size);
                    let a_val = load_float(b, float_ty, a_addr);

                    // Right factor element: f[col*n + k].
                    let b_idx_base = b.mul_lo_u32(col.clone(), n_inner.clone());
                    let b_idx = b.add_u32(b_idx_base, k_reg.clone());
                    let b_addr = b.byte_offset_addr(f_ptr.clone(), b_idx, elem_size);
                    let b_val = load_float(b, float_ty, b_addr);

                    let new_acc = if float_ty == PtxType::F64 {
                        b.fma_f64(a_val, b_val, acc.clone())
                    } else {
                        b.fma_f32(a_val, b_val, acc.clone())
                    };
                    b.raw_ptx(&format!("mov{} {acc}, {new_acc};", float_ty.as_ptx_str()));

                    b.raw_ptx(&format!("add.u32 {k_reg}, {k_reg}, 1;"));
                    b.raw_ptx(&format!("bra {loop_label};"));

                    b.raw_ptx(&format!("{exit_label}:"));

                    // Store at the same flat position: tmp[col*n + row].
                    let out_idx_base = b.mul_lo_u32(col, n_inner);
                    let out_idx = b.add_u32(out_idx_base, row);
                    let out_addr = b.byte_offset_addr(tmp_ptr, out_idx, elem_size);
                    store_float(b, float_ty, out_addr, acc);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }
}
528
/// Configuration for planning matrix-logarithm (logm) PTX kernels.
#[derive(Debug, Clone)]
pub struct MatrixLogConfig {
    // Matrix dimension; the input is treated as n x n.
    pub n: u32,
    // Element precision: "f32" or "f64" (validated in `validate`).
    pub precision: String,
    // Upper bound on inverse-square-root iterations; must be > 0.
    pub max_sqrt_iters: u32,
}
543
544impl MatrixLogConfig {
545 pub fn new(n: u32, precision: &str) -> Self {
547 Self {
548 n,
549 precision: precision.to_string(),
550 max_sqrt_iters: 100,
551 }
552 }
553
554 pub fn with_max_sqrt_iters(mut self, iters: u32) -> Self {
556 self.max_sqrt_iters = iters;
557 self
558 }
559
560 fn validate(&self) -> SolverResult<()> {
562 if self.n == 0 {
563 return Err(SolverError::DimensionMismatch(
564 "logm: matrix dimension must be > 0".into(),
565 ));
566 }
567 if self.precision != "f32" && self.precision != "f64" {
568 return Err(SolverError::InternalError(format!(
569 "logm: unsupported precision '{}'; use 'f32' or 'f64'",
570 self.precision
571 )));
572 }
573 if self.max_sqrt_iters == 0 {
574 return Err(SolverError::InternalError(
575 "logm: max_sqrt_iters must be > 0".into(),
576 ));
577 }
578 Ok(())
579 }
580}
581
/// Validated plan for matrix-logarithm kernel generation.
#[derive(Debug, Clone)]
pub struct MatrixLogPlan {
    // The validated configuration this plan was built from.
    config: MatrixLogConfig,
}
592
impl MatrixLogPlan {
    /// Validates `config` and wraps it in a plan.
    pub fn new(config: MatrixLogConfig) -> SolverResult<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// The configured cap on square-root iterations.
    pub fn max_sqrt_iters(&self) -> u32 {
        self.config.max_sqrt_iters
    }

    /// Emits the PTX for all four logm kernels (shift, square-root step,
    /// Padé/series evaluation, scale-back) joined with newlines.
    ///
    /// # Errors
    /// Fails if the configured precision is unsupported or any kernel build
    /// fails.
    pub fn generate_ptx(&self) -> SolverResult<String> {
        let n = self.config.n;
        let float_ty = precision_to_ptx_type(&self.config.precision)?;
        // NOTE(review): SM target hard-coded to 7.5, same as the expm plan.
        let sm = SmVersion::Sm75;

        let mut all_ptx = Vec::new();

        let shift_ptx = self.emit_shift_kernel(n, float_ty, sm)?;
        all_ptx.push(shift_ptx);

        let sqrt_step_ptx = self.emit_sqrt_step_kernel(n, float_ty, sm)?;
        all_ptx.push(sqrt_step_ptx);

        let pade_log_ptx = self.emit_pade_log_kernel(n, float_ty, sm)?;
        all_ptx.push(pade_log_ptx);

        let scale_ptx = self.emit_scale_back_kernel(n, float_ty, sm)?;
        all_ptx.push(scale_ptx);

        Ok(all_ptx.join("\n"))
    }

    /// Emits an element-wise kernel computing `out = A - I`: subtracts 1
    /// from diagonal entries (row == col) and copies the rest unchanged.
    fn emit_shift_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let name = format!("solver_logm_shift_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("out_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                // One thread per matrix entry: guard against gid >= n*n.
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let out_ptr = b.load_param_u64("out_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    let src_addr = b.byte_offset_addr(a_ptr, gid_r.clone(), elem_size);
                    let val = load_float(b, float_ty, src_addr);

                    // Decompose the flat id into (row, col) to detect the diagonal.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));

                    // Branch-free: subtract 1.0 on the diagonal, 0.0 elsewhere.
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);
                    let diag_sub = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {diag_sub}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    let result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "sub{} {result}, {val}, {diag_sub};",
                        float_ty.as_ptx_str()
                    ));

                    let dst_addr = b.byte_offset_addr(out_ptr, gid_r, elem_size);
                    store_float(b, float_ty, dst_addr, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits an element-wise kernel computing one averaging step on two
    /// buffers: `y_next = (Y + I) / 2` and `z_next = (Z + I) / 2`.
    ///
    /// NOTE(review): only the element-wise add-identity-and-halve is done
    /// here; any matrix products/inverses the iteration needs are presumably
    /// performed by the host between launches — confirm against the caller.
    fn emit_sqrt_step_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_logm_sqrt_step_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("y_ptr", PtxType::U64)
            .param("z_ptr", PtxType::U64)
            .param("y_next_ptr", PtxType::U64)
            .param("z_next_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let y_ptr = b.load_param_u64("y_ptr");
                    let z_ptr = b.load_param_u64("z_ptr");
                    let y_next_ptr = b.load_param_u64("y_next_ptr");
                    let z_next_ptr = b.load_param_u64("z_next_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // (row, col) from the flat id, for the diagonal test.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));

                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);

                    // Branch-free identity add: 1.0 on the diagonal, else 0.0.
                    let diag_add = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {diag_add}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    let half = half_const(b, float_ty);

                    // y_next = (y + diag_add) * 0.5
                    let y_src = b.byte_offset_addr(y_ptr, gid_r.clone(), elem_size);
                    let y_val = load_float(b, float_ty, y_src);
                    let y_sum = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "add{} {y_sum}, {y_val}, {diag_add};",
                        float_ty.as_ptx_str()
                    ));
                    let y_result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {y_result}, {y_sum}, {half};",
                        float_ty.as_ptx_str()
                    ));
                    let y_dst = b.byte_offset_addr(y_next_ptr, gid_r.clone(), elem_size);
                    store_float(b, float_ty, y_dst, y_result);

                    // z_next = (z + diag_add) * 0.5
                    let z_src = b.byte_offset_addr(z_ptr, gid_r.clone(), elem_size);
                    let z_val = load_float(b, float_ty, z_src);
                    let z_sum = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "add{} {z_sum}, {z_val}, {diag_add};",
                        float_ty.as_ptx_str()
                    ));
                    let z_result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {z_result}, {z_sum}, {half};",
                        float_ty.as_ptx_str()
                    ));
                    let z_dst = b.byte_offset_addr(z_next_ptr, gid_r, elem_size);
                    store_float(b, float_ty, z_dst, z_result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits an element-wise kernel evaluating the truncated log(1+x) series
    /// per entry of `x_ptr` via Horner's scheme:
    /// sum over k=1..num_terms of (-1)^(k+1) x^k / k, written to `result_ptr`.
    ///
    /// The loop accumulates `acc = x*acc + sign/k` from k = num_terms down
    /// to 1 (sign positive for odd k), then multiplies once more by x.
    fn emit_pade_log_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_logm_pade_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("x_ptr", PtxType::U64)
            .param("result_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("num_terms", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let x_ptr = b.load_param_u64("x_ptr");
                    let result_ptr = b.load_param_u64("result_ptr");
                    let num_terms = b.load_param_u32("num_terms");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    let src = b.byte_offset_addr(x_ptr, gid_r.clone(), elem_size);
                    let x_val = load_float(b, float_ty, src);

                    // Series accumulator, starts at 0.0.
                    let acc_reg = b.alloc_reg(float_ty);
                    let zero = zero_const(b, float_ty);
                    b.raw_ptx(&format!("mov{} {acc_reg}, {zero};", float_ty.as_ptx_str()));

                    // Loop counter k, runs num_terms down to 1.
                    let k_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mov.u32 {k_reg}, {num_terms};"));

                    let log_loop = b.fresh_label("log_loop");
                    let log_exit = b.fresh_label("log_exit");

                    b.raw_ptx(&format!("{log_loop}:"));
                    let done_pred = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {done_pred}, {k_reg}, 0;"));
                    b.raw_ptx(&format!("@{done_pred} bra {log_exit};"));

                    // k as a float, then 1/k via reciprocal.
                    let k_f = b.alloc_reg(float_ty);
                    if float_ty == PtxType::F64 {
                        b.raw_ptx(&format!("cvt.rn.f64.u32 {k_f}, {k_reg};"));
                    } else {
                        b.raw_ptx(&format!("cvt.rn.f32.u32 {k_f}, {k_reg};"));
                    }

                    let inv_k = if float_ty == PtxType::F64 {
                        b.rcp_f64(k_f)
                    } else {
                        b.rcp_f32(k_f)
                    };

                    // Alternating sign: +1/k for odd k, -1/k for even k.
                    let odd_pred = b.alloc_reg(PtxType::Pred);
                    let lsb = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("and.b32 {lsb}, {k_reg}, 1;"));
                    b.raw_ptx(&format!("setp.ne.u32 {odd_pred}, {lsb}, 0;"));

                    let neg_inv_k = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "neg{} {neg_inv_k}, {inv_k};",
                        float_ty.as_ptx_str()
                    ));
                    let signed_inv_k = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {signed_inv_k}, {inv_k}, {neg_inv_k}, {odd_pred};",
                        float_ty.as_ptx_str()
                    ));

                    // Horner step: acc = x * acc + sign/k.
                    let new_acc = if float_ty == PtxType::F64 {
                        b.fma_f64(x_val.clone(), acc_reg.clone(), signed_inv_k)
                    } else {
                        b.fma_f32(x_val.clone(), acc_reg.clone(), signed_inv_k)
                    };
                    b.raw_ptx(&format!(
                        "mov{} {acc_reg}, {new_acc};",
                        float_ty.as_ptx_str()
                    ));

                    b.raw_ptx(&format!("sub.u32 {k_reg}, {k_reg}, 1;"));
                    b.raw_ptx(&format!("bra {log_loop};"));
                    b.raw_ptx(&format!("{log_exit}:"));

                    // Final factor of x so the lowest-order term is x^1.
                    let result = if float_ty == PtxType::F64 {
                        let r = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mul.rn.f64 {r}, {x_val}, {acc_reg};"));
                        r
                    } else {
                        let r = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mul.rn.f32 {r}, {x_val}, {acc_reg};"));
                        r
                    };

                    let dst = b.byte_offset_addr(result_ptr, gid_r, elem_size);
                    store_float(b, float_ty, dst, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits an in-place kernel multiplying every entry of `result_ptr` by
    /// 2^scale_exp, undoing the earlier square-root scaling.
    ///
    /// The factor is built as an IEEE-754 bit pattern exactly like
    /// `emit_scale_kernel` on the expm plan (bias 1023/127, shift 52/23),
    /// but applied as a multiplication instead of a division.
    fn emit_scale_back_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!(
            "solver_logm_scale_back_{}_n{}",
            ptx_type_suffix(float_ty),
            n
        );

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("result_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("scale_exp", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let result_ptr = b.load_param_u64("result_ptr");
                    let scale_exp = b.load_param_u32("scale_exp");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Same address is read and written: in-place update.
                    let addr = b.byte_offset_addr(result_ptr, gid_r, elem_size);
                    let val = load_float(b, float_ty, addr.clone());

                    let result = if float_ty == PtxType::F64 {
                        // Construct 2^scale_exp via the f64 exponent field.
                        let se64 = b.cvt_u32_to_u64(scale_exp);
                        let biased = b.alloc_reg(PtxType::U64);
                        b.raw_ptx(&format!("add.u64 {biased}, {se64}, 1023;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 52;"));
                        let bits = b.shl_b64(biased, shift_amt);
                        let factor = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mov.b64 {factor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mul.rn.f64 {res}, {val}, {factor};"));
                        res
                    } else {
                        // Construct 2^scale_exp via the f32 exponent field.
                        let biased = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("add.u32 {biased}, {scale_exp}, 127;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 23;"));
                        let bits = b.shl_b32(biased, shift_amt);
                        let factor = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mov.b32 {factor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mul.rn.f32 {res}, {val}, {factor};"));
                        res
                    };

                    store_float(b, float_ty, addr, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }
}
1006
/// Configuration for planning matrix-square-root (sqrtm) PTX kernels.
#[derive(Debug, Clone)]
pub struct MatrixSqrtConfig {
    // Matrix dimension; the input is treated as n x n.
    pub n: u32,
    // Element precision: "f32" or "f64" (validated in `validate`).
    pub precision: String,
    // Upper bound on iterations; must be > 0.
    pub max_iters: u32,
    // Convergence tolerance; must be positive and finite.
    pub tol: f64,
}
1023
1024impl MatrixSqrtConfig {
1025 pub fn new(n: u32, precision: &str) -> Self {
1027 Self {
1028 n,
1029 precision: precision.to_string(),
1030 max_iters: 50,
1031 tol: 1e-12,
1032 }
1033 }
1034
1035 pub fn with_max_iters(mut self, iters: u32) -> Self {
1037 self.max_iters = iters;
1038 self
1039 }
1040
1041 pub fn with_tol(mut self, tol: f64) -> Self {
1043 self.tol = tol;
1044 self
1045 }
1046
1047 fn validate(&self) -> SolverResult<()> {
1049 if self.n == 0 {
1050 return Err(SolverError::DimensionMismatch(
1051 "sqrtm: matrix dimension must be > 0".into(),
1052 ));
1053 }
1054 if self.precision != "f32" && self.precision != "f64" {
1055 return Err(SolverError::InternalError(format!(
1056 "sqrtm: unsupported precision '{}'; use 'f32' or 'f64'",
1057 self.precision
1058 )));
1059 }
1060 if self.max_iters == 0 {
1061 return Err(SolverError::InternalError(
1062 "sqrtm: max_iters must be > 0".into(),
1063 ));
1064 }
1065 if self.tol <= 0.0 || !self.tol.is_finite() {
1066 return Err(SolverError::InternalError(format!(
1067 "sqrtm: tolerance must be positive and finite, got {}",
1068 self.tol
1069 )));
1070 }
1071 Ok(())
1072 }
1073}
1074
/// Validated plan for matrix-square-root kernel generation.
#[derive(Debug, Clone)]
pub struct MatrixSqrtPlan {
    // The validated configuration this plan was built from.
    config: MatrixSqrtConfig,
}
1090
impl MatrixSqrtPlan {
    /// Validates `config` and wraps it in a plan.
    pub fn new(config: MatrixSqrtConfig) -> SolverResult<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// The configured convergence tolerance.
    pub fn tolerance(&self) -> f64 {
        self.config.tol
    }

    /// The configured iteration cap.
    pub fn max_iters(&self) -> u32 {
        self.config.max_iters
    }

    /// Emits the PTX for all three sqrtm kernels (initialization,
    /// iteration step, convergence check) joined with newlines.
    ///
    /// # Errors
    /// Fails if the configured precision is unsupported or any kernel build
    /// fails.
    pub fn generate_ptx(&self) -> SolverResult<String> {
        let n = self.config.n;
        let float_ty = precision_to_ptx_type(&self.config.precision)?;
        // NOTE(review): SM target hard-coded to 7.5, same as the other plans.
        let sm = SmVersion::Sm75;

        let mut all_ptx = Vec::new();

        let init_ptx = self.emit_init_kernel(n, float_ty, sm)?;
        all_ptx.push(init_ptx);

        let iter_ptx = self.emit_iteration_kernel(n, float_ty, sm)?;
        all_ptx.push(iter_ptx);

        let conv_ptx = self.emit_convergence_kernel(n, float_ty, sm)?;
        all_ptx.push(conv_ptx);

        Ok(all_ptx.join("\n"))
    }

    /// Emits a kernel initializing the iteration state: `Y = A` (copy) and
    /// `Z = I` (1.0 on the diagonal, 0.0 elsewhere).
    fn emit_init_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let name = format!("solver_sqrtm_init_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("y_ptr", PtxType::U64)
            .param("z_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                // One thread per matrix entry: guard against gid >= n*n.
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let y_ptr = b.load_param_u64("y_ptr");
                    let z_ptr = b.load_param_u64("z_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    // Y <- A, straight element copy.
                    let a_addr = b.byte_offset_addr(a_ptr, gid_r.clone(), elem_size);
                    let val = load_float(b, float_ty, a_addr);
                    let y_addr = b.byte_offset_addr(y_ptr, gid_r.clone(), elem_size);
                    store_float(b, float_ty, y_addr, val);

                    // (row, col) from the flat id to build the identity in Z.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));
                    let z_addr = b.byte_offset_addr(z_ptr, gid_r, elem_size);

                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);

                    // Z <- I, branch-free diagonal select.
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let z_val = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {z_val}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    store_float(b, float_ty, z_addr, z_val);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits an element-wise kernel computing `out = (M + I) / 2`.
    ///
    /// NOTE(review): only the element-wise half-sum is done here; the matrix
    /// products/inverses the iteration needs are presumably performed by the
    /// host between launches — confirm against the caller.
    fn emit_iteration_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_sqrtm_iter_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("m_ptr", PtxType::U64)
            .param("out_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());

                b.if_lt_u32(gid, total, |b| {
                    let m_ptr = b.load_param_u64("m_ptr");
                    let out_ptr = b.load_param_u64("out_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    let m_addr = b.byte_offset_addr(m_ptr, gid_r.clone(), elem_size);
                    let m_val = load_float(b, float_ty, m_addr);

                    // (row, col) from the flat id, for the diagonal test.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));

                    // Branch-free identity add: 1.0 on the diagonal, else 0.0.
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);
                    let diag_add = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {diag_add}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));

                    let sum = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "add{} {sum}, {m_val}, {diag_add};",
                        float_ty.as_ptx_str()
                    ));

                    // Halve via multiplication by the 0.5 constant.
                    let half = half_const(b, float_ty);
                    let result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {result}, {sum}, {half};",
                        float_ty.as_ptx_str()
                    ));

                    let out_addr = b.byte_offset_addr(out_ptr, gid_r, elem_size);
                    store_float(b, float_ty, out_addr, result);
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }

    /// Emits a kernel accumulating the squared Frobenius distance between
    /// `y_new` and `y_old` into `*norm_ptr` via a global atomic add.
    ///
    /// The host must zero `*norm_ptr` before launch and compare the result
    /// against the tolerance afterwards; this kernel only accumulates.
    fn emit_convergence_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_sqrtm_conv_{}_n{}", ptx_type_suffix(float_ty), n);

        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("y_new_ptr", PtxType::U64)
            .param("y_old_ptr", PtxType::U64)
            .param("norm_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);

                b.if_lt_u32(gid, total, |b| {
                    let y_new_ptr = b.load_param_u64("y_new_ptr");
                    let y_old_ptr = b.load_param_u64("y_old_ptr");
                    let gid_r = b.global_thread_id_x();

                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };

                    let new_addr = b.byte_offset_addr(y_new_ptr, gid_r.clone(), elem_size);
                    let old_addr = b.byte_offset_addr(y_old_ptr, gid_r, elem_size);
                    let new_val = load_float(b, float_ty, new_addr);
                    let old_val = load_float(b, float_ty, old_addr);

                    // Per-element squared difference (new - old)^2.
                    let diff = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "sub{} {diff}, {new_val}, {old_val};",
                        float_ty.as_ptx_str()
                    ));

                    let diff_sq = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {diff_sq}, {diff}, {diff};",
                        float_ty.as_ptx_str()
                    ));

                    // Every thread atomically adds its contribution; the
                    // previous value of the accumulator is not needed.
                    let norm_ptr = b.load_param_u64("norm_ptr");
                    if float_ty == PtxType::F64 {
                        let _old = b.atom_global_add_f64(norm_ptr, diff_sq);
                    } else {
                        let _old = b.atom_global_add_f32(norm_ptr, diff_sq);
                    }
                });

                b.ret();
            })
            .build()?;

        Ok(ptx)
    }
}
1345
1346fn precision_to_ptx_type(precision: &str) -> SolverResult<PtxType> {
1352 match precision {
1353 "f32" => Ok(PtxType::F32),
1354 "f64" => Ok(PtxType::F64),
1355 other => Err(SolverError::InternalError(format!(
1356 "unsupported precision '{other}'"
1357 ))),
1358 }
1359}
1360
1361fn ptx_type_suffix(ty: PtxType) -> &'static str {
1363 match ty {
1364 PtxType::F32 => "f32",
1365 PtxType::F64 => "f64",
1366 _ => "unknown",
1367 }
1368}
1369
1370fn load_float(b: &mut BodyBuilder<'_>, float_ty: PtxType, addr: Register) -> Register {
1372 let dst = b.alloc_reg(float_ty);
1373 b.raw_ptx(&format!(
1374 "ld.global{} {dst}, [{addr}];",
1375 float_ty.as_ptx_str()
1376 ));
1377 dst
1378}
1379
1380fn store_float(b: &mut BodyBuilder<'_>, float_ty: PtxType, addr: Register, val: Register) {
1382 b.raw_ptx(&format!(
1383 "st.global{} [{addr}], {val};",
1384 float_ty.as_ptx_str()
1385 ));
1386}
1387
1388fn zero_const(b: &mut BodyBuilder<'_>, float_ty: PtxType) -> Register {
1390 let dst = b.alloc_reg(float_ty);
1391 if float_ty == PtxType::F32 {
1392 let bits = b.alloc_reg(PtxType::U32);
1393 b.raw_ptx(&format!("mov.u32 {bits}, 0;"));
1394 b.raw_ptx(&format!("mov.b32 {dst}, {bits};"));
1395 } else {
1396 let bits = b.alloc_reg(PtxType::U64);
1397 b.raw_ptx(&format!("mov.u64 {bits}, 0;"));
1398 b.raw_ptx(&format!("mov.b64 {dst}, {bits};"));
1399 }
1400 dst
1401}
1402
1403fn one_const(b: &mut BodyBuilder<'_>, float_ty: PtxType) -> Register {
1405 let dst = b.alloc_reg(float_ty);
1406 if float_ty == PtxType::F32 {
1407 let bits = b.alloc_reg(PtxType::U32);
1409 b.raw_ptx(&format!("mov.u32 {bits}, 1065353216;"));
1410 b.raw_ptx(&format!("mov.b32 {dst}, {bits};"));
1411 } else {
1412 let bits = b.alloc_reg(PtxType::U64);
1414 b.raw_ptx(&format!("mov.u64 {bits}, 4607182418800017408;"));
1415 b.raw_ptx(&format!("mov.b64 {dst}, {bits};"));
1416 }
1417 dst
1418}
1419
1420fn half_const(b: &mut BodyBuilder<'_>, float_ty: PtxType) -> Register {
1422 let dst = b.alloc_reg(float_ty);
1423 if float_ty == PtxType::F32 {
1424 let bits = b.alloc_reg(PtxType::U32);
1426 b.raw_ptx(&format!("mov.u32 {bits}, 1056964608;"));
1427 b.raw_ptx(&format!("mov.b32 {dst}, {bits};"));
1428 } else {
1429 let bits = b.alloc_reg(PtxType::U64);
1431 b.raw_ptx(&format!("mov.u64 {bits}, 4602678819172646912;"));
1432 b.raw_ptx(&format!("mov.b64 {dst}, {bits};"));
1433 }
1434 dst
1435}
1436
1437#[cfg(test)]
1442#[path = "matrix_functions_tests.rs"]
1443mod tests;