use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::error::PtxGenError;
use oxicuda_ptx::ir::PtxType;
use crate::error::{BlasError, BlasResult};
use crate::types::FillMode;
/// Block (tile) edge used when the caller does not override it via
/// `BatchedCholeskyConfig::with_block_size`.
const DEFAULT_BLOCK_SIZE: u32 = 32;
/// Largest matrix dimension `n` accepted by `validate_batched_cholesky`.
const MAX_DIMENSION: u32 = 32768;
/// Largest batch count accepted (2^20 matrices).
const MAX_BATCH_COUNT: u32 = 1 << 20;
/// One step of a blocked (right-looking) Cholesky schedule.
///
/// `k` is the element offset of the block column the step operates on;
/// steps are executed in the order produced by `plan_batched_cholesky`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CholeskyStep {
    /// Factorize the diagonal block starting at (k, k).
    DiagonalBlock {
        k: u32,
    },
    /// Triangular solve of the panel adjacent to the diagonal block.
    PanelSolve {
        k: u32,
    },
    /// Rank-k update of the trailing submatrix.
    SchurUpdate {
        k: u32,
    },
}
/// Parameters for one batched Cholesky factorization launch.
#[derive(Debug, Clone, Copy)]
pub struct BatchedCholeskyConfig {
    /// Matrix dimension; each matrix in the batch is `n` x `n`.
    pub n: u32,
    /// Number of matrices factorized in one launch.
    pub batch_count: u32,
    /// Which triangle holds the data (`Upper` or `Lower`; `Full` is rejected
    /// by `validate_batched_cholesky`).
    pub fill_mode: FillMode,
    /// Target GPU architecture for PTX generation.
    pub sm_version: SmVersion,
    /// Blocking factor for the right-looking algorithm.
    pub block_size: u32,
}

impl BatchedCholeskyConfig {
    /// Creates a config with the default block size (`DEFAULT_BLOCK_SIZE`).
    ///
    /// Made `const` for consistency with [`Self::with_block_size`], so a
    /// fully-specified config can be built in const context.
    #[must_use]
    pub const fn new(n: u32, batch_count: u32, fill_mode: FillMode, sm_version: SmVersion) -> Self {
        Self {
            n,
            batch_count,
            fill_mode,
            sm_version,
            block_size: DEFAULT_BLOCK_SIZE,
        }
    }

    /// Returns the config with `block_size` overridden (builder style).
    #[must_use]
    pub const fn with_block_size(mut self, block_size: u32) -> Self {
        self.block_size = block_size;
        self
    }
}
/// A precomputed step schedule for a batched Cholesky factorization.
#[derive(Debug, Clone)]
pub struct BatchedCholeskyPlan {
    /// The configuration this plan was derived from.
    pub config: BatchedCholeskyConfig,
    /// Ordered kernel steps; execute front to back for each batch.
    pub steps: Vec<CholeskyStep>,
    /// Rough floating-point operation count for the whole batch.
    pub estimated_flops: f64,
}

impl BatchedCholeskyPlan {
    /// Estimated work in GFLOP (10^9 floating-point operations).
    #[must_use]
    pub fn estimated_gflops(&self) -> f64 {
        self.estimated_flops / 1e9
    }
}
/// Per-batch outcome bookkeeping for a batched Cholesky factorization.
#[derive(Debug, Clone)]
pub struct BatchedCholeskyResult {
    /// How many matrices factorized cleanly.
    pub successful_count: u32,
    /// Batch indices whose info code was non-zero.
    pub failed_indices: Vec<u32>,
    /// Raw LAPACK-style info code per matrix (0 = success).
    pub info: Vec<i32>,
}

impl BatchedCholeskyResult {
    /// Builds a result in which all `batch_count` matrices succeeded
    /// (every info code is zero, no failures recorded).
    #[must_use]
    pub fn all_success(batch_count: u32) -> Self {
        let zeros = std::iter::repeat(0).take(batch_count as usize).collect();
        Self {
            successful_count: batch_count,
            failed_indices: Vec::new(),
            info: zeros,
        }
    }

    /// Derives a result from per-matrix status codes: zero marks a
    /// success, anything else records the index as failed.
    #[must_use]
    pub fn from_info(info: Vec<i32>) -> Self {
        let mut successful_count = 0u32;
        let mut failed_indices = Vec::new();
        for (idx, &code) in info.iter().enumerate() {
            if code != 0 {
                failed_indices.push(idx as u32);
            } else {
                // Saturating keeps the count well-defined even for an
                // absurdly long info vector.
                successful_count = successful_count.saturating_add(1);
            }
        }
        Self {
            successful_count,
            failed_indices,
            info,
        }
    }

    /// True when no matrix in the batch reported a failure.
    #[must_use]
    pub fn all_succeeded(&self) -> bool {
        self.failed_indices.is_empty()
    }
}
/// Checks that `config` describes a launch the Cholesky kernels accept.
///
/// Checks run in a fixed order (dimension, batch, block size, fill
/// mode), so when several fields are invalid the first violation wins.
///
/// # Errors
///
/// `BlasError::InvalidDimension` for a bad `n`;
/// `BlasError::InvalidArgument` for a bad batch count, block size, or
/// `FillMode::Full`.
pub fn validate_batched_cholesky(config: &BatchedCholeskyConfig) -> BlasResult<()> {
    let n = config.n;
    let batch_count = config.batch_count;
    let block_size = config.block_size;
    if n == 0 {
        return Err(BlasError::InvalidDimension(
            "matrix dimension n must be positive".into(),
        ));
    }
    if n > MAX_DIMENSION {
        return Err(BlasError::InvalidDimension(format!(
            "matrix dimension n ({n}) exceeds maximum ({MAX_DIMENSION})"
        )));
    }
    if batch_count == 0 {
        return Err(BlasError::InvalidArgument(
            "batch_count must be positive".into(),
        ));
    }
    if batch_count > MAX_BATCH_COUNT {
        return Err(BlasError::InvalidArgument(format!(
            "batch_count ({batch_count}) exceeds maximum ({MAX_BATCH_COUNT})"
        )));
    }
    if block_size == 0 {
        return Err(BlasError::InvalidArgument(
            "block_size must be positive".into(),
        ));
    }
    if block_size > n {
        return Err(BlasError::InvalidArgument(format!(
            "block_size ({block_size}) must not exceed n ({n})"
        )));
    }
    if let FillMode::Full = config.fill_mode {
        return Err(BlasError::InvalidArgument(
            "FillMode::Full is not valid for Cholesky factorization; use Upper or Lower".into(),
        ));
    }
    Ok(())
}
/// Estimates the floating-point operation count for factorizing
/// `batch_count` matrices of dimension `n`: roughly `n^3 / 3` per
/// matrix, scaled by the batch size.
#[must_use]
pub fn estimate_cholesky_flops(n: u32, batch_count: u32) -> f64 {
    let dim = f64::from(n);
    // Same left-to-right evaluation order as before to keep the exact
    // floating-point result unchanged.
    f64::from(batch_count) * dim * dim * dim / 3.0
}
/// Builds the blocked-Cholesky step schedule for `config`.
///
/// Each block column `k` (in strides of `block_size`) contributes a
/// diagonal factorization, plus a panel solve and Schur update whenever
/// a trailing submatrix remains below it.
///
/// # Errors
///
/// Fails with the error from `validate_batched_cholesky` when the
/// configuration is invalid.
pub fn plan_batched_cholesky(config: &BatchedCholeskyConfig) -> BlasResult<BatchedCholeskyPlan> {
    validate_batched_cholesky(config)?;
    let (n, bs) = (config.n, config.block_size);
    let mut steps = Vec::new();
    // Validation guarantees 1 <= bs <= n <= MAX_DIMENSION, so `bs as usize`
    // is a valid step and `k + bs` cannot overflow a u32.
    for k in (0..n).step_by(bs as usize) {
        steps.push(CholeskyStep::DiagonalBlock { k });
        // The final block has nothing trailing it: no solve/update needed.
        if k + bs < n {
            steps.push(CholeskyStep::PanelSolve { k });
            steps.push(CholeskyStep::SchurUpdate { k });
        }
    }
    Ok(BatchedCholeskyPlan {
        config: *config,
        steps,
        estimated_flops: estimate_cholesky_flops(n, config.batch_count),
    })
}
/// Kernel-name fragment identifying the fill mode.
fn fill_suffix(fill: FillMode) -> &'static str {
    match fill {
        FillMode::Full => "full",
        FillMode::Upper => "upper",
        FillMode::Lower => "lower",
    }
}
/// Upper bound on compile-time loop unrolling inside generated kernels;
/// dimensions beyond this are truncated by the generators below.
const UNROLL_LIMIT: u32 = 16;
/// Emits the address computation for element (`row`, `col`) of a matrix
/// at register `base` with leading dimension `ld`: the flat index is
/// `row * ld + col`, then converted to an f32 element address.
///
/// The builder records instructions in call order, so the emission
/// sequence here is part of the generated PTX and is kept as-is.
fn emit_element_addr(
    b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
    base: oxicuda_ptx::ir::Register,
    ld: oxicuda_ptx::ir::Register,
    row: u32,
    col: u32,
) -> oxicuda_ptx::ir::Register {
    let row_imm = b.mov_imm_u32(row);
    let row_offset = b.mul_lo_u32(row_imm, ld);
    let col_imm = b.mov_imm_u32(col);
    let linear = b.add_u32(row_offset, col_imm);
    b.f32_elem_addr(base, linear)
}
/// Generates a PTX kernel that Cholesky-factorizes the leading
/// `min(n, UNROLL_LIMIT)`-square block of each batched matrix, fully
/// unrolled at compile time. One CUDA block handles one batch entry.
///
/// NOTE(review): for n > UNROLL_LIMIT (16) only the leading 16x16 block
/// is factorized here — confirm callers schedule follow-up work.
///
/// # Errors
///
/// Propagates any `PtxGenError` from the kernel builder.
pub fn generate_diagonal_cholesky_ptx(
    n: u32,
    fill: FillMode,
    sm: SmVersion,
) -> Result<String, PtxGenError> {
    // NOTE(review): `fill` only affects the kernel name and comments; the
    // element-access pattern below is the same for Upper and Lower. Verify
    // upper-fill matrices are handled correctly (e.g. transposed storage).
    let suffix = fill_suffix(fill);
    let kernel_name = format!("batched_chol_diag_{suffix}_n{n}");
    let unroll_n = n.min(UNROLL_LIMIT) as usize;
    let ptx = KernelBuilder::new(&kernel_name)
        .target(sm)
        .param("matrix_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("ld", PtxType::U32)
        .param("k_offset", PtxType::U32)
        .param("block_dim", PtxType::U32)
        .param("info_ptr", PtxType::U64)
        .body(move |b| {
            b.comment(&format!(
                "Diagonal Cholesky: fill_mode={suffix}, n={n}, unrolled {unroll_n} steps"
            ));
            // One CUDA block per batch entry.
            let batch_idx = b.block_id_x();
            let mat_ptr_base = b.load_param_u64("matrix_ptr");
            let n_reg = b.load_param_u32("n");
            let ld = b.load_param_u32("ld");
            // NOTE(review): k_offset, block_dim and info_ptr are loaded but
            // never used — the kernel always factors from element (0,0) and
            // never writes a failure code through info_ptr, so non-SPD
            // matrices go unreported. Confirm this is intentional.
            let _k_off = b.load_param_u32("k_offset");
            let _blk_dim = b.load_param_u32("block_dim");
            let _info_ptr = b.load_param_u64("info_ptr");
            // Per-matrix element offset: batch_idx * (n * ld).
            let stride_elems = b.mul_lo_u32(n_reg, ld.clone());
            let batch_elem_off = b.mul_lo_u32(stride_elems, batch_idx);
            let base_ptr = b.f32_elem_addr(mat_ptr_base, batch_elem_off);
            b.comment("=== Cholesky factorization (compile-time unrolled) ===");
            for k in 0..unroll_n {
                let ku = k as u32;
                b.comment(&format!("--- k={k}: pivot ---"));
                // L[k,k] = sqrt(A[k,k]). No guard against a non-positive
                // pivot: a non-SPD input silently produces NaN.
                let pivot_addr = emit_element_addr(b, base_ptr.clone(), ld.clone(), ku, ku);
                let pivot = b.load_global_f32(pivot_addr.clone());
                let pivot_sqrt = b.sqrt_rn_f32(pivot);
                b.store_global_f32(pivot_addr, pivot_sqrt.clone());
                let pivot_rcp = b.rcp_f32(pivot_sqrt);
                if k + 1 < unroll_n {
                    b.comment(&format!(" column normalisation (k={k})"));
                }
                // Scale the column below the pivot: A[i,k] *= 1/L[k,k].
                for i in (k + 1)..unroll_n {
                    let iu = i as u32;
                    let aik_addr = emit_element_addr(b, base_ptr.clone(), ld.clone(), iu, ku);
                    let a_ik = b.load_global_f32(aik_addr.clone());
                    // a_ik - a_ik synthesizes 0.0 so fma(a_ik, rcp, 0) acts as
                    // a plain multiply — presumably the builder lacks mul_f32.
                    // NOTE(review): x - x is NaN for non-finite x; confirm
                    // acceptable, or use a real multiply if one exists.
                    let zero = b.sub_f32(a_ik.clone(), a_ik.clone());
                    let a_ik_scaled = b.fma_f32(a_ik, pivot_rcp.clone(), zero);
                    b.store_global_f32(aik_addr, a_ik_scaled);
                }
                if k + 1 < unroll_n {
                    b.comment(&format!(" rank-1 update (k={k})"));
                }
                // Trailing update A[i,j] -= A[i,k] * A[j,k], with i >= j so
                // only the lower triangle (incl. diagonal) is touched.
                for j in (k + 1)..unroll_n {
                    let ju = j as u32;
                    let ajk_addr = emit_element_addr(b, base_ptr.clone(), ld.clone(), ju, ku);
                    let a_jk = b.load_global_f32(ajk_addr);
                    for i in j..unroll_n {
                        let iu = i as u32;
                        let aij_addr = emit_element_addr(b, base_ptr.clone(), ld.clone(), iu, ju);
                        let a_ij = b.load_global_f32(aij_addr.clone());
                        let aik_addr = emit_element_addr(b, base_ptr.clone(), ld.clone(), iu, ku);
                        let a_ik = b.load_global_f32(aik_addr);
                        // Subtract via negate + fma: A[i,j] + (-A[i,k])*A[j,k].
                        let neg_a_ik = b.neg_f32(a_ik);
                        let a_ij_new = b.fma_f32(neg_a_ik, a_jk.clone(), a_ij);
                        b.store_global_f32(aij_addr, a_ij_new);
                    }
                }
            }
            b.ret();
        })
        .build()?;
    Ok(ptx)
}
/// Generates a PTX kernel for the panel triangular-solve (TRSM) stage
/// of the blocked Cholesky. One CUDA block handles one batch entry.
///
/// NOTE(review): this kernel is largely a scaffold — the inner loop
/// emits comments plus dead address arithmetic (results discarded via
/// `let _ =`), and no loads/stores implement the actual subtract or
/// normalise steps. Only the `rcp` of the diagonal is really computed.
/// Confirm this is work-in-progress.
///
/// # Errors
///
/// Propagates any `PtxGenError` from the kernel builder.
pub fn generate_panel_trsm_ptx(
    n: u32,
    block_size: u32,
    fill: FillMode,
    sm: SmVersion,
) -> Result<String, PtxGenError> {
    let suffix = fill_suffix(fill);
    let kernel_name = format!("batched_chol_trsm_{suffix}_n{n}_bs{block_size}");
    let bs_unroll = block_size.min(UNROLL_LIMIT) as usize;
    let ptx = KernelBuilder::new(&kernel_name)
        .target(sm)
        .param("matrix_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("ld", PtxType::U32)
        .param("k_offset", PtxType::U32)
        .param("block_dim", PtxType::U32)
        .body(move |b| {
            b.comment(&format!(
                "Panel TRSM: fill={suffix}, n={n}, block_size={block_size} (bs_unroll={bs_unroll})"
            ));
            // One CUDA block per batch entry.
            let batch_idx = b.block_id_x();
            let mat_ptr_base = b.load_param_u64("matrix_ptr");
            let n_reg = b.load_param_u32("n");
            let ld = b.load_param_u32("ld");
            let k_off = b.load_param_u32("k_offset");
            let _blk_dim = b.load_param_u32("block_dim");
            // Per-matrix element offset: batch_idx * (n * ld).
            let stride_elems = b.mul_lo_u32(n_reg, ld.clone());
            let batch_elem_off = b.mul_lo_u32(stride_elems, batch_idx);
            let base_ptr = b.f32_elem_addr(mat_ptr_base, batch_elem_off);
            b.comment("=== Panel TRSM: forward substitution over unrolled block columns ===");
            for j in 0..bs_unroll {
                let ju = j as u32;
                let ju_reg = b.mov_imm_u32(ju);
                let k_off_plus_j = b.add_u32(k_off.clone(), ju_reg);
                // NOTE(review): the diagonal is addressed at (j, j) relative
                // to the batch base rather than (k_off + j, k_off + j) —
                // likely wrong once k_offset > 0.
                let diag_addr = emit_element_addr(b, base_ptr.clone(), ld.clone(), ju, ju);
                let l_jj = b.load_global_f32(diag_addr);
                let l_jj_rcp = b.rcp_f32(l_jj);
                b.comment(&format!(
                    " block column j={j}: row loop (k_off+j register={k_off_plus_j:?})"
                ));
                let i_row = k_off.clone();
                for p in 0..j {
                    let pu = p as u32;
                    // NOTE(review): `i_row.name.len() as u32` passes the
                    // LENGTH OF THE REGISTER'S NAME STRING as the column
                    // index — almost certainly a bug. The computed address
                    // is discarded below, so the emitted arithmetic is dead.
                    let aip_addr = emit_element_addr(
                        b,
                        base_ptr.clone(),
                        ld.clone(),
                        pu,
                        i_row.name.len() as u32,
                    );
                    b.comment(&format!(
                        " subtract p={p}: A[i,k+{j}] -= A[i,k+{p}] * L[k+{p},k+{j}]"
                    ));
                    let _ = aip_addr;
                }
                b.comment(&format!(
                    " normalise j={j}: A[i,k+{j}] *= rcp(L[k+{j},k+{j}])"
                ));
                // Reciprocal is computed (so the PTX contains `rcp`) but the
                // normalising multiply/store is never emitted.
                let _ = l_jj_rcp;
                let _ = k_off_plus_j;
            }
            b.ret();
        })
        .build()?;
    Ok(ptx)
}
/// Generates a PTX kernel for the trailing Schur-complement update:
/// each thread owns one element and computes
/// `A[i,j] -= sum_p A[i,k+p] * A[j,k+p]` with the inner sum unrolled
/// `min(block_size, UNROLL_LIMIT)` times.
///
/// # Errors
///
/// Propagates any `PtxGenError` from the kernel builder.
pub fn generate_schur_update_ptx(
    n: u32,
    block_size: u32,
    sm: SmVersion,
) -> Result<String, PtxGenError> {
    let kernel_name = format!("batched_chol_schur_n{n}_bs{block_size}");
    let bs_unroll = block_size.min(UNROLL_LIMIT) as usize;
    let ptx = KernelBuilder::new(&kernel_name)
        .target(sm)
        .param("matrix_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("ld", PtxType::U32)
        .param("k_offset", PtxType::U32)
        .param("block_dim", PtxType::U32)
        .body(move |b| {
            b.comment(&format!(
                "Schur update: n={n}, block_size={block_size}, inner unroll={bs_unroll}"
            ));
            let batch_idx = b.block_id_x();
            let mat_ptr_base = b.load_param_u64("matrix_ptr");
            let n_reg = b.load_param_u32("n");
            let ld = b.load_param_u32("ld");
            let k_off = b.load_param_u32("k_offset");
            let _blk_dim = b.load_param_u32("block_dim");
            // Per-matrix element offset: batch_idx * (n * ld).
            let stride_elems = b.mul_lo_u32(n_reg, ld.clone());
            let batch_elem_off = b.mul_lo_u32(stride_elems, batch_idx);
            let base_ptr = b.f32_elem_addr(mat_ptr_base, batch_elem_off);
            b.comment("=== Schur rank-k update: unrolled inner product over block columns ===");
            b.comment("A[i,j] -= sum_{p=0}^{bs-1} A[i, k_off+p] * A[j, k_off+p]");
            // NOTE(review): i_tile reads block_id_x, the SAME special
            // register already used above for batch_idx — so the batch index
            // and the row tile coincide. This looks like it should be a
            // different grid axis (block_id_y?); confirm the launch mapping.
            let i_tile = b.block_id_x();
            let j_tile = b.thread_id_x();
            // Absolute row/column of the element this thread updates.
            let i_abs = b.add_u32(k_off.clone(), i_tile);
            let j_abs = b.add_u32(k_off.clone(), j_tile);
            let aij_flat = b.mul_lo_u32(i_abs.clone(), ld.clone());
            let aij_flat2 = b.add_u32(aij_flat, j_abs.clone());
            let aij_addr = b.f32_elem_addr(base_ptr.clone(), aij_flat2);
            // Accumulator starts at the current A[i,j]; each product is
            // subtracted via negate + fma.
            let mut acc = b.load_global_f32(aij_addr.clone());
            for p in 0..bs_unroll {
                let pu = p as u32;
                b.comment(&format!(" p={p}: acc -= A[i,k+{p}] * A[j,k+{p}]"));
                let pu_reg = b.mov_imm_u32(pu);
                let p_col = b.add_u32(k_off.clone(), pu_reg);
                let aip_flat = b.mul_lo_u32(i_abs.clone(), ld.clone());
                let aip_flat2 = b.add_u32(aip_flat, p_col.clone());
                let aip_addr = b.f32_elem_addr(base_ptr.clone(), aip_flat2);
                let a_ip = b.load_global_f32(aip_addr);
                let ajp_flat = b.mul_lo_u32(j_abs.clone(), ld.clone());
                let ajp_flat2 = b.add_u32(ajp_flat, p_col);
                let ajp_addr = b.f32_elem_addr(base_ptr.clone(), ajp_flat2);
                let a_jp = b.load_global_f32(ajp_addr);
                let neg_aip = b.neg_f32(a_ip);
                acc = b.fma_f32(neg_aip, a_jp, acc);
            }
            b.store_global_f32(aij_addr, acc);
            b.ret();
        })
        .build()?;
    Ok(ptx)
}
#[cfg(test)]
mod tests {
    //! Unit tests: config validation, plan shape, flop estimation,
    //! result bookkeeping, and structural checks on the generated PTX
    //! text (substring assertions only — no PTX execution).
    use super::*;

    // --- validation ---

    #[test]
    fn validate_rejects_zero_dimension() {
        let cfg = BatchedCholeskyConfig::new(0, 1, FillMode::Lower, SmVersion::Sm80);
        let res = validate_batched_cholesky(&cfg);
        assert!(res.is_err());
        let msg = res.err().map_or_else(String::new, |e| e.to_string());
        assert!(msg.contains("positive"));
    }

    #[test]
    fn validate_rejects_zero_batch() {
        let cfg = BatchedCholeskyConfig::new(64, 0, FillMode::Lower, SmVersion::Sm80);
        assert!(validate_batched_cholesky(&cfg).is_err());
    }

    #[test]
    fn validate_rejects_full_fill_mode() {
        let cfg = BatchedCholeskyConfig::new(64, 8, FillMode::Full, SmVersion::Sm80);
        let res = validate_batched_cholesky(&cfg);
        assert!(res.is_err());
        let msg = res.err().map_or_else(String::new, |e| e.to_string());
        assert!(msg.contains("Full"));
    }

    #[test]
    fn validate_rejects_block_size_exceeding_n() {
        let cfg =
            BatchedCholeskyConfig::new(16, 4, FillMode::Lower, SmVersion::Sm80).with_block_size(64);
        assert!(validate_batched_cholesky(&cfg).is_err());
    }

    #[test]
    fn validate_accepts_valid_config() {
        let cfg = BatchedCholeskyConfig::new(128, 32, FillMode::Upper, SmVersion::Sm90);
        assert!(validate_batched_cholesky(&cfg).is_ok());
    }

    // --- planning ---

    // n == block_size: a single diagonal step, no panel/Schur work.
    #[test]
    fn plan_single_block_matrix() {
        let cfg = BatchedCholeskyConfig::new(32, 10, FillMode::Lower, SmVersion::Sm80);
        let plan = plan_batched_cholesky(&cfg).expect("plan should succeed");
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0], CholeskyStep::DiagonalBlock { k: 0 });
    }

    // Two block columns: diag + panel + schur for k=0, then trailing diag.
    #[test]
    fn plan_two_block_matrix() {
        let cfg = BatchedCholeskyConfig::new(64, 4, FillMode::Lower, SmVersion::Sm80);
        let plan = plan_batched_cholesky(&cfg).expect("plan should succeed");
        assert_eq!(plan.steps.len(), 4);
        assert_eq!(plan.steps[0], CholeskyStep::DiagonalBlock { k: 0 });
        assert_eq!(plan.steps[1], CholeskyStep::PanelSolve { k: 0 });
        assert_eq!(plan.steps[2], CholeskyStep::SchurUpdate { k: 0 });
        assert_eq!(plan.steps[3], CholeskyStep::DiagonalBlock { k: 32 });
    }

    #[test]
    fn plan_n_equals_1() {
        let cfg =
            BatchedCholeskyConfig::new(1, 100, FillMode::Lower, SmVersion::Sm80).with_block_size(1);
        let plan = plan_batched_cholesky(&cfg).expect("plan should succeed");
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0], CholeskyStep::DiagonalBlock { k: 0 });
    }

    // 256/32 = 8 block columns: 8 diagonals + 7 * (panel + schur) = 22.
    #[test]
    fn plan_large_batch() {
        let cfg = BatchedCholeskyConfig::new(256, 1024, FillMode::Upper, SmVersion::Sm90);
        let plan = plan_batched_cholesky(&cfg).expect("plan should succeed");
        assert_eq!(plan.steps.len(), 22);
    }

    // --- flop estimation ---

    #[test]
    fn flop_estimation_basic() {
        let flops = estimate_cholesky_flops(64, 1);
        let expected = 64.0_f64.powi(3) / 3.0;
        assert!((flops - expected).abs() < 1e-6);
    }

    #[test]
    fn flop_estimation_batch_scales_linearly() {
        let single = estimate_cholesky_flops(128, 1);
        let batch = estimate_cholesky_flops(128, 100);
        assert!((batch - single * 100.0).abs() < 1e-6);
    }

    #[test]
    fn plan_gflops_matches_flops() {
        let cfg = BatchedCholeskyConfig::new(512, 64, FillMode::Lower, SmVersion::Sm80);
        let plan = plan_batched_cholesky(&cfg).expect("plan should succeed");
        let expected_gflops = plan.estimated_flops / 1e9;
        assert!((plan.estimated_gflops() - expected_gflops).abs() < 1e-15);
    }

    // --- PTX generation (structural substring checks) ---

    #[test]
    fn diagonal_ptx_lower_generates_valid_kernel() {
        let ptx = generate_diagonal_cholesky_ptx(64, FillMode::Lower, SmVersion::Sm80)
            .expect("PTX generation should succeed");
        assert!(ptx.contains(".entry batched_chol_diag_lower_n64"));
        assert!(ptx.contains(".target sm_80"));
        assert!(ptx.contains("ret"));
    }

    #[test]
    fn diagonal_ptx_upper_generates_valid_kernel() {
        let ptx = generate_diagonal_cholesky_ptx(128, FillMode::Upper, SmVersion::Sm90)
            .expect("PTX generation should succeed");
        assert!(ptx.contains(".entry batched_chol_diag_upper_n128"));
        assert!(ptx.contains(".target sm_90"));
    }

    #[test]
    fn panel_trsm_ptx_generates_valid_kernel() {
        let ptx = generate_panel_trsm_ptx(256, 32, FillMode::Lower, SmVersion::Sm80)
            .expect("PTX generation should succeed");
        assert!(ptx.contains(".entry batched_chol_trsm_lower_n256_bs32"));
        assert!(ptx.contains(".target sm_80"));
        assert!(ptx.contains("ret"));
    }

    #[test]
    fn schur_update_ptx_generates_valid_kernel() {
        let ptx = generate_schur_update_ptx(512, 64, SmVersion::Sm90)
            .expect("PTX generation should succeed");
        assert!(ptx.contains(".entry batched_chol_schur_n512_bs64"));
        assert!(ptx.contains(".target sm_90"));
        assert!(ptx.contains("ret"));
    }

    // --- result bookkeeping ---

    #[test]
    fn result_all_success() {
        let res = BatchedCholeskyResult::all_success(16);
        assert!(res.all_succeeded());
        assert_eq!(res.successful_count, 16);
        assert!(res.failed_indices.is_empty());
        assert_eq!(res.info.len(), 16);
    }

    #[test]
    fn result_from_info_with_failures() {
        let info = vec![0, 0, 3, 0, 5, 0];
        let res = BatchedCholeskyResult::from_info(info);
        assert!(!res.all_succeeded());
        assert_eq!(res.successful_count, 4);
        assert_eq!(res.failed_indices, vec![2, 4]);
        assert_eq!(res.info[2], 3);
        assert_eq!(res.info[4], 5);
    }

    // --- limit boundaries ---

    #[test]
    fn validate_rejects_exceeding_max_dimension() {
        let cfg =
            BatchedCholeskyConfig::new(MAX_DIMENSION + 1, 1, FillMode::Lower, SmVersion::Sm80);
        assert!(validate_batched_cholesky(&cfg).is_err());
    }

    #[test]
    fn validate_rejects_exceeding_max_batch() {
        let cfg =
            BatchedCholeskyConfig::new(64, MAX_BATCH_COUNT + 1, FillMode::Lower, SmVersion::Sm80);
        assert!(validate_batched_cholesky(&cfg).is_err());
    }

    // --- kernel instruction-content checks ---

    #[test]
    fn test_cholesky_ptx_contains_sqrt() {
        let ptx = generate_diagonal_cholesky_ptx(4, FillMode::Lower, SmVersion::Sm80)
            .expect("PTX generation should succeed for n=4");
        assert!(
            ptx.contains("sqrt") || ptx.contains("ex2") || ptx.contains("rsqrt"),
            "Cholesky diagonal kernel needs sqrt/ex2/rsqrt: {ptx}"
        );
    }

    #[test]
    fn test_cholesky_ptx_contains_div_or_rcp() {
        let ptx = generate_diagonal_cholesky_ptx(4, FillMode::Lower, SmVersion::Sm80)
            .expect("PTX generation should succeed");
        assert!(
            ptx.contains("rcp") || ptx.contains("div"),
            "Cholesky diagonal kernel needs rcp or div for column normalisation: {ptx}"
        );
    }

    #[test]
    fn test_cholesky_kernel_name() {
        let ptx = generate_diagonal_cholesky_ptx(8, FillMode::Lower, SmVersion::Sm80)
            .expect("PTX generation should succeed");
        assert!(
            ptx.contains(".entry batched_chol_diag_lower_n8"),
            "entry name mismatch: {ptx}"
        );
    }

    #[test]
    fn test_cholesky_ptx_structural_validity() {
        let ptx = generate_diagonal_cholesky_ptx(4, FillMode::Lower, SmVersion::Sm80)
            .expect("PTX generation should succeed");
        assert!(ptx.contains(".version"), "missing .version directive");
        assert!(ptx.contains(".target"), "missing .target directive");
        assert!(ptx.contains(".entry"), "missing .entry directive");
        assert!(ptx.contains("ret;"), "missing ret instruction");
    }

    #[test]
    fn test_cholesky_batch_size_1() {
        let cfg =
            BatchedCholeskyConfig::new(4, 1, FillMode::Lower, SmVersion::Sm80).with_block_size(4);
        let plan = plan_batched_cholesky(&cfg).expect("plan should succeed for n=4 batch=1");
        assert!(!plan.steps.is_empty());
        let ptx = generate_diagonal_cholesky_ptx(4, FillMode::Lower, SmVersion::Sm80)
            .expect("PTX generation should succeed");
        assert!(ptx.contains(".entry batched_chol_diag_lower_n4"));
    }

    #[test]
    fn test_schur_ptx_contains_fma() {
        let ptx = generate_schur_update_ptx(32, 4, SmVersion::Sm80)
            .expect("PTX generation should succeed");
        assert!(
            ptx.contains("fma") || ptx.contains("mad"),
            "Schur update kernel needs fma/mad for rank-k subtraction: {ptx}"
        );
    }

    #[test]
    fn test_schur_ptx_contains_neg() {
        let ptx = generate_schur_update_ptx(32, 4, SmVersion::Sm80)
            .expect("PTX generation should succeed");
        assert!(
            ptx.contains("neg"),
            "Schur update kernel should use neg for the subtraction pattern: {ptx}"
        );
    }

    #[test]
    fn test_trsm_ptx_contains_rcp() {
        let ptx = generate_panel_trsm_ptx(32, 4, FillMode::Lower, SmVersion::Sm80)
            .expect("PTX generation should succeed");
        assert!(
            ptx.contains("rcp") || ptx.contains("div"),
            "TRSM kernel needs rcp or div for triangular solve: {ptx}"
        );
    }

    #[test]
    fn test_cholesky_n1_is_valid() {
        let ptx = generate_diagonal_cholesky_ptx(1, FillMode::Lower, SmVersion::Sm80)
            .expect("PTX generation should succeed for n=1");
        assert!(ptx.contains(".entry batched_chol_diag_lower_n1"));
        assert!(ptx.contains("sqrt"), "n=1 Cholesky: just sqrt the pivot");
    }
}