#![cfg(feature = "cusparse")]
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::graph::{GraphOpRecord, GraphRecordCtx};
use crate::kernel::CsrMatrix;
pub struct SpMvOp {
pub csr: CsrMatrix,
pub x: GpuRef<f32>,
pub y: GpuRef<f32>,
pub alpha: f32,
pub beta: f32,
}
pub struct SpMmOp {
pub csr: CsrMatrix,
pub b: GpuRef<f32>,
pub c: GpuRef<f32>,
pub b_cols: i64,
pub ldb: i64,
pub ldc: i64,
pub alpha: f32,
pub beta: f32,
}
impl GraphOpRecord for SpMvOp {
fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
validate_csr(&self.csr)?;
let _ = self.x.access()?;
let _ = self.y.access()?;
let _ = ctx;
Err(GpuError::Unrecoverable(
"graph::record::cusparse::SpMv: cuSPARSE capture-mode \
entry not yet wired (Phase 4 will revisit when the \
actor surface expands)"
.into(),
))
}
}
impl GraphOpRecord for SpMmOp {
fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
validate_csr(&self.csr)?;
let _ = self.b.access()?;
let _ = self.c.access()?;
if self.b_cols <= 0 || self.ldb <= 0 || self.ldc <= 0 {
return Err(GpuError::Unrecoverable(format!(
"SpMm: non-positive (b_cols, ldb, ldc) = ({}, {}, {})",
self.b_cols, self.ldb, self.ldc
)));
}
let _ = ctx;
Err(GpuError::Unrecoverable(
"graph::record::cusparse::SpMm: cuSPARSE capture-mode \
entry not yet wired"
.into(),
))
}
}
fn validate_csr(c: &CsrMatrix) -> Result<(), GpuError> {
if c.rows <= 0 || c.cols <= 0 || c.nnz < 0 {
return Err(GpuError::Unrecoverable(format!(
"CsrMatrix: invalid dims (rows={}, cols={}, nnz={})",
c.rows, c.cols, c.nnz
)));
}
let _ = c.row_offsets.access()?;
let _ = c.col_indices.access()?;
let _ = c.values.access()?;
Ok(())
}
#[cfg(test)]
mod tests {
#[test]
fn spmv_op_records() {
struct Dims {
rows: i64,
cols: i64,
nnz: i64,
}
let bad = Dims {
rows: 4,
cols: 4,
nnz: -1,
};
assert!(bad.nnz < 0);
let good = Dims {
rows: 4,
cols: 4,
nnz: 0,
};
assert!(good.nnz >= 0);
}
#[test]
fn spmm_dim_validation_rejects_zero() {
let bad: (i64, i64, i64) = (0, 1, 1);
assert!(bad.0 <= 0 || bad.1 <= 0 || bad.2 <= 0);
let good: (i64, i64, i64) = (4, 4, 4);
assert!(good.0 > 0 && good.1 > 0 && good.2 > 0);
}
}