use std::sync::Arc;
use oxicuda_driver::{CUdeviceptr, Module};
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::prelude::*;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{DiagType, FillMode, GpuFloat, MatrixDesc, Transpose};
/// Largest `n` that `trsv` solves with a single kernel launch; larger systems
/// are routed to `trsv_blocked`.
const TRSV_SINGLE_BLOCK_MAX: u32 = 4096;
/// Width of each diagonal block processed per iteration by `trsv_blocked`.
const TRSV_BLOCK_SIZE: u32 = 1024;
/// Threads per block used when launching the `trsv_update_gemv` kernel.
const TRSV_GEMV_BLOCK: u32 = 256;
/// Solves the triangular system `op(A) * x = b` in place on the device.
///
/// On entry `x` holds the right-hand side `b`; the generated kernel overwrites
/// each element with the corresponding solution entry. The kernel addresses
/// `A` as `row * lda + col`.
///
/// * `handle` - supplies the SM version used for PTX generation and the stream
///   the kernel is launched on.
/// * `uplo` / `trans` / `diag` - select which triangle is read, whether the
///   transposed indexing is used (`Trans` and `ConjTrans` are treated alike),
///   and whether the division by the diagonal is skipped (`Unit`).
/// * `n` - order of the system; `n == 0` returns immediately.
/// * `a` - descriptor of the matrix; must be at least `n x n`.
/// * `x` - vector buffer, at least `1 + (n - 1) * incx` elements.
/// * `incx` - element stride of `x`; must be positive.
///
/// Systems with `n > TRSV_SINGLE_BLOCK_MAX` are delegated to `trsv_blocked`.
///
/// # Errors
///
/// Returns `BlasError::InvalidArgument`, `BlasError::InvalidDimension` or
/// `BlasError::BufferTooSmall` when validation fails, and propagates any error
/// from PTX generation, module loading, kernel lookup or kernel launch.
// The argument list mirrors the BLAS TRSV signature.
#[allow(clippy::too_many_arguments)]
pub fn trsv<T: GpuFloat>(
    handle: &BlasHandle,
    uplo: FillMode,
    trans: Transpose,
    diag: DiagType,
    n: u32,
    a: &MatrixDesc<T>,
    x: &mut DeviceBuffer<T>,
    incx: i32,
) -> BlasResult<()> {
    if n == 0 {
        return Ok(());
    }
    validate_trsv_args(n, a, x, incx)?;
    if n > TRSV_SINGLE_BLOCK_MAX {
        return trsv_blocked(handle, uplo, trans, diag, n, a, x, incx);
    }
    let ptx = generate_trsv_ptx::<T>(handle.sm_version(), uplo, trans, diag, n)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, "trsv")?;
    let block_size = n.min(256);
    let params = LaunchParams::new(1u32, block_size);
    kernel.launch(
        &params,
        handle.stream(),
        // `incx` was validated as positive above, so the cast is lossless.
        &(a.ptr, x.as_device_ptr(), n, a.ld, incx as u32),
    )?;
    Ok(())
}
/// Blocked triangular solve used when `n` exceeds the single-launch limit.
///
/// The system is processed in diagonal blocks of at most `TRSV_BLOCK_SIZE`
/// rows. For each block the `trsv` kernel solves the diagonal block, then the
/// `trsv_update_gemv` kernel subtracts that block's contribution from every
/// entry of `x` that is still unsolved. The sweep runs top-down when
/// `(uplo == Upper) == (trans is Trans/ConjTrans)` and bottom-up otherwise.
/// All launches are issued on `handle.stream()`.
///
/// Arguments are expected to have been validated by the caller (`trsv`).
///
/// # Errors
///
/// Propagates errors from PTX generation, module loading, kernel lookup and
/// kernel launches.
// The argument list mirrors the public `trsv` entry point.
#[allow(clippy::too_many_arguments)]
fn trsv_blocked<T: GpuFloat>(
    handle: &BlasHandle,
    uplo: FillMode,
    trans: Transpose,
    diag: DiagType,
    n: u32,
    a: &MatrixDesc<T>,
    x: &mut DeviceBuffer<T>,
    incx: i32,
) -> BlasResult<()> {
    // Build both kernels up front: the diagonal-block solver first, then the
    // off-diagonal update.
    let solve_kernel = {
        let ptx = generate_trsv_ptx::<T>(handle.sm_version(), uplo, trans, diag, n)?;
        Kernel::from_module(Arc::new(Module::from_ptx(&ptx)?), "trsv")?
    };
    let update_kernel = {
        let ptx = generate_trsv_update_gemv_ptx::<T>(handle.sm_version(), trans)?;
        Kernel::from_module(Arc::new(Module::from_ptx(&ptx)?), "trsv_update_gemv")?
    };

    let elem_bytes = u64::from(T::size_u32());
    let row_stride = u64::from(a.ld);
    let stride_x = incx as u32;
    let stride_x64 = u64::from(stride_x);
    let use_trans = matches!(trans, Transpose::Trans | Transpose::ConjTrans);
    let forward = matches!(uplo, FillMode::Upper) == use_trans;
    let max_solver_threads = TRSV_BLOCK_SIZE.min(256);
    let x_base = x.as_device_ptr();

    // `solved` counts the rows already finished; the current block sits right
    // after them (forward sweep) or right before them (backward sweep).
    let mut solved: u32 = 0;
    while solved < n {
        let width = TRSV_BLOCK_SIZE.min(n - solved);
        let start = if forward { solved } else { n - solved - width };
        let first = u64::from(start);

        let diag_block_ptr = offset_ptr(a.ptr, first * row_stride + first, elem_bytes);
        let x_block_ptr = offset_ptr(x_base, first * stride_x64, elem_bytes);
        let solver_threads = width.min(max_solver_threads).max(1);
        let solver_params = LaunchParams::new(1u32, solver_threads);
        solve_kernel.launch(
            &solver_params,
            handle.stream(),
            &(diag_block_ptr, x_block_ptr, width, a.ld, stride_x),
        )?;

        let pending = n - solved - width;
        if pending > 0 {
            // Forward: the unsolved entries follow the block.
            // Backward: they are the leading `start` entries of `x`.
            let (x_pending_ptr, panel_ptr) = if forward {
                (
                    offset_ptr(x_base, u64::from(start + width) * stride_x64, elem_bytes),
                    trsv_panel_below_ptr(a, start, width, use_trans, elem_bytes),
                )
            } else {
                (
                    offset_ptr(x_base, 0u64, elem_bytes),
                    trsv_panel_above_ptr(a, start, width, use_trans, elem_bytes),
                )
            };
            launch_trsv_gemv_update(
                &update_kernel,
                handle,
                panel_ptr,
                x_block_ptr,
                x_pending_ptr,
                pending,
                width,
                a.ld,
                stride_x,
            )?;
        }
        solved += width;
    }
    Ok(())
}
/// Returns `base` advanced by `elements` items of `elem_bytes` bytes each.
///
/// The final addition wraps on overflow; the byte-offset multiplication uses
/// the ordinary `*` operator.
#[inline]
fn offset_ptr(base: CUdeviceptr, elements: u64, elem_bytes: u64) -> CUdeviceptr {
    let byte_offset = elements * elem_bytes;
    base.wrapping_add(byte_offset)
}
/// Device address of the off-diagonal panel used by the forward sweep after
/// the diagonal block starting at row/column `i` with width `k` was solved.
///
/// With `use_trans == false` the panel starts at element `(i + k, i)`, i.e.
/// `(i + k) * ld + i`; with `use_trans == true` it starts at element
/// `(i, i + k)`, i.e. `i * ld + i + k`.
#[inline]
fn trsv_panel_below_ptr<T: GpuFloat>(
    a: &MatrixDesc<T>,
    i: u32,
    k: u32,
    use_trans: bool,
    elem_bytes: u64,
) -> CUdeviceptr {
    let stride = u64::from(a.ld);
    let (start, width) = (u64::from(i), u64::from(k));
    let element_offset = if use_trans {
        start * stride + start + width
    } else {
        (start + width) * stride + start
    };
    offset_ptr(a.ptr, element_offset, elem_bytes)
}
/// Device address of the off-diagonal panel used by the backward sweep after
/// the diagonal block starting at row/column `i` was solved.
///
/// With `use_trans == false` the panel starts at element `(0, i)`, i.e. offset
/// `i`; with `use_trans == true` it starts at element `(i, 0)`, i.e. `i * ld`.
/// The block width `k` does not affect the start address; it is accepted so
/// the signature parallels `trsv_panel_below_ptr`.
#[inline]
fn trsv_panel_above_ptr<T: GpuFloat>(
    a: &MatrixDesc<T>,
    i: u32,
    k: u32,
    use_trans: bool,
    elem_bytes: u64,
) -> CUdeviceptr {
    // Intentionally unused: the panel's origin depends only on `i`.
    let _ = k;
    let start = u64::from(i);
    let element_offset = if use_trans {
        start * u64::from(a.ld)
    } else {
        start
    };
    offset_ptr(a.ptr, element_offset, elem_bytes)
}
/// Launches the `trsv_update_gemv` kernel that subtracts a solved block's
/// contribution from the still-unsolved part of the vector.
///
/// * `kernel` - the compiled `trsv_update_gemv` kernel.
/// * `handle` - supplies the stream the launch is issued on.
/// * `a_panel_ptr` - device address of the first panel element.
/// * `x_panel_ptr` - device address of the first solved entry of the block.
/// * `y_target_ptr` - device address of the first entry to update.
/// * `output_len` - number of entries to update (one thread per entry).
/// * `inner_len` - number of solved entries accumulated per output entry.
/// * `lda` - leading dimension of the matrix, in elements.
/// * `incx` - element stride shared by the solved and the updated entries.
///
/// The grid is sized with `grid_size_for(output_len, TRSV_GEMV_BLOCK)`.
///
/// # Errors
///
/// Propagates any error returned by the kernel launch.
// One argument per kernel parameter plus the launch context.
#[allow(clippy::too_many_arguments)]
fn launch_trsv_gemv_update(
    kernel: &Kernel,
    handle: &BlasHandle,
    a_panel_ptr: CUdeviceptr,
    x_panel_ptr: CUdeviceptr,
    y_target_ptr: CUdeviceptr,
    output_len: u32,
    inner_len: u32,
    lda: u32,
    incx: u32,
) -> BlasResult<()> {
    let block_size = TRSV_GEMV_BLOCK;
    let grid_size = grid_size_for(output_len, block_size);
    let params = LaunchParams::new(grid_size, block_size);
    kernel.launch(
        &params,
        handle.stream(),
        // Tuple order must match the parameter order declared by
        // `generate_trsv_update_gemv_ptx`.
        &(
            a_panel_ptr,
            x_panel_ptr,
            y_target_ptr,
            lda,
            incx,
            output_len,
            inner_len,
        ),
    )?;
    Ok(())
}
/// Checks the host-side preconditions of `trsv`.
///
/// # Errors
///
/// * `BlasError::InvalidArgument` when `incx` is zero or negative.
/// * `BlasError::InvalidDimension` when `a` has fewer than `n` rows or columns.
/// * `BlasError::BufferTooSmall` when `x` holds fewer elements than
///   `required_elements(n, incx)`.
fn validate_trsv_args<T: GpuFloat>(
    n: u32,
    a: &MatrixDesc<T>,
    x: &DeviceBuffer<T>,
    incx: i32,
) -> BlasResult<()> {
    if incx < 1 {
        return Err(BlasError::InvalidArgument(String::from(
            "incx must be positive",
        )));
    }

    let matrix_covers_n = a.rows >= n && a.cols >= n;
    if !matrix_covers_n {
        let message = format!("A must be at least {n}x{n}, got {}x{}", a.rows, a.cols);
        return Err(BlasError::InvalidDimension(message));
    }

    let needed = required_elements(n, incx);
    let available = x.len();
    if available < needed {
        return Err(BlasError::BufferTooSmall {
            expected: needed,
            actual: available,
        });
    }
    Ok(())
}
/// Generates the PTX source of the `trsv` kernel for the given configuration.
///
/// Kernel parameters, in order: `a_ptr` (u64), `x_ptr` (u64), `n` (u32),
/// `lda` (u32), `incx` (u32).
///
/// The emitted code performs a sequential substitution: for each row `i` it
/// accumulates `sum = Σ a(i, j) * x[j * incx]` over the already-solved `j`,
/// then stores `(x[i * incx] - sum)` back to `x[i * incx]`, divided by the
/// diagonal entry unless `diag` is `Unit`. Matrix elements are addressed as
/// `row * lda + col`; with `Trans`/`ConjTrans` the row and column roles are
/// swapped. Rows are visited in ascending order when
/// `(uplo == Upper) == (trans is Trans/ConjTrans)` and descending otherwise.
///
/// The whole solve is emitted inside `if_lt_u32(tid, 1, ...)`, so it is
/// expected to run on thread 0 of the block only (assuming `if_lt_u32` guards
/// its body on `tid < 1`).
///
/// NOTE(review): element indices are formed with 32-bit `mul.lo.u32`/add
/// before being turned into addresses; confirm callers keep
/// `row * lda + col` below 2^32.
///
/// `_n` is currently unused; the system size is read from the `n` kernel
/// parameter at run time.
///
/// # Errors
///
/// Returns `BlasError::PtxGeneration` when the kernel builder fails.
fn generate_trsv_ptx<T: GpuFloat>(
    sm: SmVersion,
    uplo: FillMode,
    trans: Transpose,
    diag: DiagType,
    _n: u32,
) -> BlasResult<String> {
    let is_f64 = T::SIZE == 8;
    let elem_bytes = T::size_u32();
    let ptx_ty = T::PTX_TYPE;
    let is_upper = matches!(uplo, FillMode::Upper);
    let use_trans = matches!(trans, Transpose::Trans | Transpose::ConjTrans);
    let is_unit = matches!(diag, DiagType::Unit);
    let forward = is_upper == use_trans;
    KernelBuilder::new("trsv")
        .target(sm)
        .param("a_ptr", PtxType::U64)
        .param("x_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("lda", PtxType::U32)
        .param("incx", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let one_reg = b.alloc_reg(PtxType::U32);
            b.raw_ptx(&format!("mov.u32 {one_reg}, 1;"));
            b.if_lt_u32(tid, one_reg, |b| {
                let a_ptr = b.load_param_u64("a_ptr");
                let x_ptr = b.load_param_u64("x_ptr");
                let n_reg = b.load_param_u32("n");
                let lda = b.load_param_u32("lda");
                let incx = b.load_param_u32("incx");

                // ---- outer loop over rows -------------------------------
                let outer_label = b.fresh_label("trsv_outer");
                let outer_done = b.fresh_label("trsv_outer_done");
                let i = b.alloc_reg(PtxType::U32);
                if forward {
                    b.raw_ptx(&format!("mov.u32 {i}, 0;"));
                } else {
                    b.raw_ptx(&format!("sub.u32 {i}, {n_reg}, 1;"));
                }
                b.label(&outer_label);
                let outer_pred = b.alloc_reg(PtxType::Pred);
                // One unsigned `i < n` test serves both directions: the
                // ascending sweep stops at `i == n`, and the descending sweep
                // stops when `i` wraps below zero to a value that is not
                // lower than `n`.
                b.raw_ptx(&format!("setp.lo.u32 {outer_pred}, {i}, {n_reg};"));
                b.raw_ptx(&format!("@!{outer_pred} bra {outer_done};"));

                // Current right-hand-side entry x[i * incx].
                let xi_idx = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mul.lo.u32 {xi_idx}, {i}, {incx};"));
                let xi_addr = b.byte_offset_addr(x_ptr.clone(), xi_idx, elem_bytes);
                let xi_val = load_float(b, xi_addr.clone(), is_f64);

                // ---- inner loop over already-solved entries -------------
                let inner_label = b.fresh_label("trsv_inner");
                let inner_done = b.fresh_label("trsv_inner_done");
                let j = b.alloc_reg(PtxType::U32);
                let sum = b.alloc_reg(ptx_ty);
                emit_zero(b, sum.clone(), is_f64);
                if forward {
                    // Ascending sweep: j runs over 0..i.
                    b.raw_ptx(&format!("mov.u32 {j}, 0;"));
                } else {
                    // Descending sweep: j runs over i+1..n.
                    let i_plus1 = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("add.u32 {i_plus1}, {i}, 1;"));
                    b.raw_ptx(&format!("mov.u32 {j}, {i_plus1};"));
                }
                b.label(&inner_label);
                let inner_pred = b.alloc_reg(PtxType::Pred);
                if forward {
                    b.raw_ptx(&format!("setp.lo.u32 {inner_pred}, {j}, {i};"));
                } else {
                    b.raw_ptx(&format!("setp.lo.u32 {inner_pred}, {j}, {n_reg};"));
                }
                b.raw_ptx(&format!("@!{inner_pred} bra {inner_done};"));

                // Transposed access swaps the roles of i and j.
                let (row, col) = if !use_trans {
                    (i.clone(), j.clone())
                } else {
                    (j.clone(), i.clone())
                };
                let row_off = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mul.lo.u32 {row_off}, {row}, {lda};"));
                let a_idx = b.add_u32(row_off, col);
                let a_addr = b.byte_offset_addr(a_ptr.clone(), a_idx, elem_bytes);
                let a_val = load_float(b, a_addr, is_f64);
                let xj_idx = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mul.lo.u32 {xj_idx}, {j}, {incx};"));
                let xj_addr = b.byte_offset_addr(x_ptr.clone(), xj_idx, elem_bytes);
                let xj_val = load_float(b, xj_addr, is_f64);
                let new_sum = if is_f64 {
                    b.fma_f64(a_val, xj_val, sum.clone())
                } else {
                    b.fma_f32(a_val, xj_val, sum.clone())
                };
                // Copy the FMA result back into the loop-carried accumulator.
                emit_mov_float(b, sum.clone(), new_sum, is_f64);
                b.raw_ptx(&format!("add.u32 {j}, {j}, 1;"));
                b.branch(&inner_label);
                b.label(&inner_done);

                // ---- finish row i ---------------------------------------
                let diff = if is_f64 {
                    b.sub_f64(xi_val, sum)
                } else {
                    b.sub_f32(xi_val, sum)
                };
                let result = if is_unit {
                    // Unit diagonal: no division, the diagonal is not read.
                    diff
                } else {
                    let diag_off = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mul.lo.u32 {diag_off}, {i}, {lda};"));
                    let diag_idx = b.add_u32(diag_off, i.clone());
                    let diag_addr = b.byte_offset_addr(a_ptr.clone(), diag_idx, elem_bytes);
                    let diag_val = load_float(b, diag_addr, is_f64);
                    let r = b.alloc_reg(ptx_ty);
                    if is_f64 {
                        b.raw_ptx(&format!("div.rn.f64 {r}, {diff}, {diag_val};"));
                    } else {
                        b.raw_ptx(&format!("div.rn.f32 {r}, {diff}, {diag_val};"));
                    }
                    r
                };
                store_float(b, xi_addr, result, is_f64);
                if forward {
                    b.raw_ptx(&format!("add.u32 {i}, {i}, 1;"));
                } else {
                    b.raw_ptx(&format!("sub.u32 {i}, {i}, 1;"));
                }
                b.branch(&outer_label);
                b.label(&outer_done);
            });
            b.ret();
        })
        .build()
        .map_err(|e| BlasError::PtxGeneration(e.to_string()))
}
/// Emits a `mov` that loads a floating-point zero literal into `reg`,
/// using the 64-bit form when `is_f64` is set and the 32-bit form otherwise.
fn emit_zero(b: &mut BodyBuilder<'_>, reg: Register, is_f64: bool) {
    let instruction = if is_f64 {
        format!("mov.b64 {reg}, 0d0000000000000000;")
    } else {
        format!("mov.b32 {reg}, 0f00000000;")
    };
    b.raw_ptx(&instruction);
}
/// Emits `mov.f64 dst, src;` or `mov.f32 dst, src;` depending on `is_f64`.
fn emit_mov_float(b: &mut BodyBuilder<'_>, dst: Register, src: Register, is_f64: bool) {
    let bits = if is_f64 { 64 } else { 32 };
    b.raw_ptx(&format!("mov.f{bits} {dst}, {src};"));
}
/// Emits a global-memory load from `addr` with the builder's f64 or f32
/// variant, chosen by `is_f64`, and returns the destination register.
fn load_float(b: &mut BodyBuilder<'_>, addr: Register, is_f64: bool) -> Register {
    if !is_f64 {
        return b.load_global_f32(addr);
    }
    b.load_global_f64(addr)
}
/// Emits a global-memory store of `val` to `addr` with the builder's f64 or
/// f32 variant, chosen by `is_f64`.
fn store_float(b: &mut BodyBuilder<'_>, addr: Register, val: Register, is_f64: bool) {
    if !is_f64 {
        b.store_global_f32(addr, val);
        return;
    }
    b.store_global_f64(addr, val);
}
/// Minimum number of buffer elements needed to hold `n` logical entries
/// spaced `|inc|` elements apart: `0` for an empty vector, otherwise
/// `1 + (n - 1) * |inc|`.
fn required_elements(n: u32, inc: i32) -> usize {
    match n.checked_sub(1) {
        // n == 0: nothing is accessed.
        None => 0,
        Some(gaps) => {
            let stride = inc.unsigned_abs() as usize;
            1 + gaps as usize * stride
        }
    }
}
/// Generates the PTX source of the `trsv_update_gemv` kernel.
///
/// Kernel parameters, in order: `a_ptr`, `x_ptr`, `y_ptr` (u64), then `lda`,
/// `incx`, `output_len`, `inner_len` (u32).
///
/// Each thread whose global id `g` is below `output_len` accumulates
/// `acc = Σ_k a(g, k) * x[k * incx]` for `k` in `0..inner_len` and then
/// stores `y[g * incx] - acc` back to `y[g * incx]`. The panel element is
/// read at index `g * lda + k`, or at `k * lda + g` when `trans` is
/// `Trans`/`ConjTrans`.
///
/// NOTE(review): the panel index is formed with 32-bit `mul.lo.u32`/add;
/// confirm callers keep it below 2^32.
///
/// # Errors
///
/// Returns `BlasError::PtxGeneration` when the kernel builder fails.
fn generate_trsv_update_gemv_ptx<T: GpuFloat>(
    sm: SmVersion,
    trans: Transpose,
) -> BlasResult<String> {
    let wide = T::SIZE == 8;
    let item_bytes = T::size_u32();
    let float_ty = T::PTX_TYPE;
    let transposed = matches!(trans, Transpose::Trans | Transpose::ConjTrans);

    let built = KernelBuilder::new("trsv_update_gemv")
        .target(sm)
        .param("a_ptr", PtxType::U64)
        .param("x_ptr", PtxType::U64)
        .param("y_ptr", PtxType::U64)
        .param("lda", PtxType::U32)
        .param("incx", PtxType::U32)
        .param("output_len", PtxType::U32)
        .param("inner_len", PtxType::U32)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let output_len = b.load_param_u32("output_len");
            let out_index = gid.clone();
            b.if_lt_u32(gid, output_len, move |b| {
                let gid = out_index;
                let a_ptr = b.load_param_u64("a_ptr");
                let x_ptr = b.load_param_u64("x_ptr");
                let y_ptr = b.load_param_u64("y_ptr");
                let inner_len = b.load_param_u32("inner_len");
                let lda = b.load_param_u32("lda");
                let incx = b.load_param_u32("incx");

                // Loop-carried accumulator, starts at zero.
                let acc = b.alloc_reg(float_ty);
                emit_zero(b, acc.clone(), wide);

                let loop_top = b.fresh_label("trsv_gemv_loop");
                let loop_exit = b.fresh_label("trsv_gemv_done");
                let k = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mov.u32 {k}, 0;"));
                b.label(&loop_top);
                let keep_going = b.alloc_reg(PtxType::Pred);
                b.raw_ptx(&format!("setp.lo.u32 {keep_going}, {k}, {inner_len};"));
                b.raw_ptx(&format!("@!{keep_going} bra {loop_exit};"));

                // Only the panel index differs between the two modes: the
                // register scaled by `lda` and the one added afterwards swap.
                let line_off = b.alloc_reg(PtxType::U32);
                let (scaled, added) = if transposed {
                    (k.clone(), gid.clone())
                } else {
                    (gid.clone(), k.clone())
                };
                b.raw_ptx(&format!("mul.lo.u32 {line_off}, {scaled}, {lda};"));
                let a_index = b.add_u32(line_off, added);
                let a_addr = b.byte_offset_addr(a_ptr.clone(), a_index, item_bytes);
                let a_val = load_float(b, a_addr, wide);

                let x_index = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mul.lo.u32 {x_index}, {k}, {incx};"));
                let x_addr = b.byte_offset_addr(x_ptr.clone(), x_index, item_bytes);
                let x_val = load_float(b, x_addr, wide);

                let fused = if wide {
                    b.fma_f64(a_val, x_val, acc.clone())
                } else {
                    b.fma_f32(a_val, x_val, acc.clone())
                };
                emit_mov_float(b, acc.clone(), fused, wide);

                b.raw_ptx(&format!("add.u32 {k}, {k}, 1;"));
                b.branch(&loop_top);
                b.label(&loop_exit);

                // y[g * incx] -= acc
                let y_index = b.alloc_reg(PtxType::U32);
                b.raw_ptx(&format!("mul.lo.u32 {y_index}, {gid}, {incx};"));
                let y_addr = b.byte_offset_addr(y_ptr, y_index, item_bytes);
                let y_before = load_float(b, y_addr.clone(), wide);
                let y_after = if wide {
                    b.sub_f64(y_before, acc)
                } else {
                    b.sub_f32(y_before, acc)
                };
                store_float(b, y_addr, y_after, wide);
            });
            b.ret();
        })
        .build();

    built.map_err(|e| BlasError::PtxGeneration(e.to_string()))
}
// Unit tests for the TRSV host logic and PTX generators. None of them launch
// a kernel: they check PTX text, pointer arithmetic, and host-side models of
// the blocked algorithm.
#[cfg(test)]
mod tests {
use super::*;
// PTX generation succeeds for f32 / lower / no-transpose / non-unit and the
// output declares the `trsv` entry point.
#[test]
fn trsv_ptx_generation_lower_notrans_nonunit() {
let ptx = generate_trsv_ptx::<f32>(
SmVersion::Sm80,
FillMode::Lower,
Transpose::NoTrans,
DiagType::NonUnit,
64,
);
assert!(ptx.is_ok());
let ptx = ptx.expect("test: PTX generation should succeed");
assert!(ptx.contains(".entry trsv"));
}
// PTX generation succeeds for f64 / upper / transpose / unit-diagonal and the
// output declares the `trsv` entry point.
#[test]
fn trsv_ptx_generation_upper_trans_unit() {
let ptx = generate_trsv_ptx::<f64>(
SmVersion::Sm80,
FillMode::Upper,
Transpose::Trans,
DiagType::Unit,
128,
);
assert!(ptx.is_ok());
let ptx = ptx.expect("test: PTX generation should succeed");
assert!(ptx.contains(".entry trsv"));
}
// PTX generation succeeds for several values of the size argument.
#[test]
fn trsv_ptx_generation_various_sizes() {
for &sz in &[1, 32, 256, 512] {
let ptx = generate_trsv_ptx::<f32>(
SmVersion::Sm80,
FillMode::Lower,
Transpose::NoTrans,
DiagType::NonUnit,
sz,
);
assert!(ptx.is_ok(), "failed for n={sz}");
}
}
// The update kernel generates for every transpose mode in both precisions and
// contains the expected load/store/FMA/subtract instructions.
#[test]
fn trsv_update_gemv_ptx_compiles_for_all_modes() {
for &trans in &[Transpose::NoTrans, Transpose::Trans, Transpose::ConjTrans] {
let f32_ptx = generate_trsv_update_gemv_ptx::<f32>(SmVersion::Sm80, trans);
assert!(f32_ptx.is_ok(), "f32 TRSV update GEMV failed for {trans:?}");
let f32_ptx = f32_ptx.expect("test: PTX generation should succeed");
assert!(f32_ptx.contains(".entry trsv_update_gemv"));
assert!(f32_ptx.contains("ld.global.f32"));
assert!(f32_ptx.contains("st.global.f32"));
assert!(f32_ptx.contains("fma.rn.f32"));
assert!(f32_ptx.contains("sub.f32"));
let f64_ptx = generate_trsv_update_gemv_ptx::<f64>(SmVersion::Sm80, trans);
assert!(f64_ptx.is_ok(), "f64 TRSV update GEMV failed for {trans:?}");
let f64_ptx = f64_ptx.expect("test: PTX generation should succeed");
assert!(f64_ptx.contains(".entry trsv_update_gemv"));
assert!(f64_ptx.contains("fma.rn.f64"));
assert!(f64_ptx.contains("sub.f64"));
}
}
// Non-transposed "below" panel: expected element offset is (i + k) * lda + i.
#[test]
fn trsv_panel_below_ptr_layout_lower_notrans() {
let lda = 8192u32;
let n = 8192u32;
let elem_bytes = u64::from(<f32 as GpuFloat>::size_u32());
let base: CUdeviceptr = 0x1000_0000;
let a = MatrixDesc::<f32>::from_raw(base, n, n, lda, crate::types::Layout::RowMajor);
let i = TRSV_BLOCK_SIZE; let k = TRSV_BLOCK_SIZE;
let p = trsv_panel_below_ptr(&a, i, k, false, elem_bytes);
let expected_offset = (u64::from(i + k) * u64::from(lda) + u64::from(i)) * elem_bytes;
assert_eq!(p - base, expected_offset);
}
// Transposed "below" panel: expected element offset is i * lda + i + k.
#[test]
fn trsv_panel_below_ptr_layout_upper_trans() {
let lda = 8192u32;
let n = 8192u32;
let elem_bytes = u64::from(<f32 as GpuFloat>::size_u32());
let base: CUdeviceptr = 0x2000_0000;
let a = MatrixDesc::<f32>::from_raw(base, n, n, lda, crate::types::Layout::RowMajor);
let i = TRSV_BLOCK_SIZE * 2;
let k = TRSV_BLOCK_SIZE;
let p = trsv_panel_below_ptr(&a, i, k, true, elem_bytes);
let expected_offset =
(u64::from(i) * u64::from(lda) + u64::from(i) + u64::from(k)) * elem_bytes;
assert_eq!(p - base, expected_offset);
}
// Replays the forward block-partitioning loop on the host (it does not call
// `trsv_blocked`): n = 8192 gives 8 diagonal solves and 7 updates.
#[test]
fn trsv_blocked_iteration_count_8192() {
let n: u32 = 8192;
let block = TRSV_BLOCK_SIZE;
let mut diag_solves = 0u32;
let mut off_updates = 0u32;
let mut i: u32 = 0;
while i < n {
let k = block.min(n - i);
diag_solves += 1;
if i + k < n {
off_updates += 1;
}
i += k;
}
assert_eq!(diag_solves, 8);
assert_eq!(off_updates, 7);
}
// Same host-side replay with a size that is not a multiple of the block
// width: n = 8200 gives 9 diagonal solves and 8 updates.
#[test]
fn trsv_blocked_iteration_count_partial_last() {
let n: u32 = 8200;
let block = TRSV_BLOCK_SIZE;
let mut diag_solves = 0u32;
let mut off_updates = 0u32;
let mut i: u32 = 0;
while i < n {
let k = block.min(n - i);
diag_solves += 1;
if i + k < n {
off_updates += 1;
}
i += k;
}
assert_eq!(diag_solves, 9);
assert_eq!(off_updates, 8);
}
// Host-only f64 model of the blocked forward sweep on a row-major lower
// triangular matrix, compared against plain forward substitution.
#[test]
fn cpu_reference_blocked_lower_notrans_matches_dense_solve() {
let n: usize = 256;
let lda = n;
let block: usize = 32;
// Build a lower-triangular matrix with a diagonal close to 1.
let mut l = vec![0.0f64; n * n];
for i in 0..n {
for j in 0..=i {
l[i * lda + j] = if i == j {
1.0 + (i as f64) * 0.001
} else {
-0.01 + 0.001 * ((i + j) as f64).cos()
};
}
}
let b: Vec<f64> = (0..n).map(|i| 1.0 + (i as f64).sin()).collect();
// Reference: unblocked forward substitution.
let mut x_ref = b.clone();
for i in 0..n {
let mut s = 0.0f64;
for j in 0..i {
s += l[i * lda + j] * x_ref[j];
}
x_ref[i] = (x_ref[i] - s) / l[i * lda + i];
}
// Blocked variant: solve the diagonal block, then update the rows below it.
let mut x_blk = b.clone();
let mut i = 0usize;
while i < n {
let k = block.min(n - i);
for ii in i..(i + k) {
let mut s = 0.0f64;
for jj in i..ii {
s += l[ii * lda + jj] * x_blk[jj];
}
x_blk[ii] = (x_blk[ii] - s) / l[ii * lda + ii];
}
if i + k < n {
for r in (i + k)..n {
let mut s = 0.0f64;
for c in i..(i + k) {
s += l[r * lda + c] * x_blk[c];
}
x_blk[r] -= s;
}
}
i += k;
}
for idx in 0..n {
let diff = (x_blk[idx] - x_ref[idx]).abs();
assert!(
diff < 1e-10,
"blocked algorithm diverged at index {idx}: diff={diff}"
);
}
}
// Host-only f64 model of the blocked backward sweep on a row-major upper
// triangular matrix, compared against plain backward substitution.
#[test]
fn cpu_reference_blocked_upper_notrans_matches_dense_solve() {
let n: usize = 256;
let lda = n;
let block: usize = 32;
// Build an upper-triangular matrix with a diagonal close to 1.
let mut u = vec![0.0f64; n * n];
for i in 0..n {
for j in i..n {
u[i * lda + j] = if i == j {
1.0 + (i as f64) * 0.001
} else {
-0.01 + 0.001 * ((i + j) as f64).sin()
};
}
}
let b: Vec<f64> = (0..n).map(|i| 0.5 + (i as f64).cos()).collect();
// Reference: unblocked backward substitution.
let mut x_ref = b.clone();
for i in (0..n).rev() {
let mut s = 0.0f64;
for j in (i + 1)..n {
s += u[i * lda + j] * x_ref[j];
}
x_ref[i] = (x_ref[i] - s) / u[i * lda + i];
}
// Blocked variant: solve the last unsolved diagonal block, then update the
// rows above it.
let mut x_blk = b.clone();
let mut i_end = n;
while i_end > 0 {
let k = block.min(i_end);
let i = i_end - k;
for ii in (i..(i + k)).rev() {
let mut s = 0.0f64;
for jj in (ii + 1)..(i + k) {
s += u[ii * lda + jj] * x_blk[jj];
}
x_blk[ii] = (x_blk[ii] - s) / u[ii * lda + ii];
}
if i > 0 {
for r in 0..i {
let mut s = 0.0f64;
for c in i..(i + k) {
s += u[r * lda + c] * x_blk[c];
}
x_blk[r] -= s;
}
}
i_end = i;
}
for idx in 0..n {
let diff = (x_blk[idx] - x_ref[idx]).abs();
assert!(
diff < 1e-10,
"blocked upper-triangular algorithm diverged at index {idx}: diff={diff}"
);
}
}
// Compile-time check (evaluated when this test module is built) that the
// blocked path's block width does not exceed the single-launch size limit.
const _: () = assert!(TRSV_BLOCK_SIZE <= TRSV_SINGLE_BLOCK_MAX);
}