use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::graph::{GraphOp, GraphRecordCtx};
use crate::kernel::record::{BlasRecorder, BlasSgemmOp, RecordMode};
/// Graph operation that wraps a single [`BlasSgemmOp`].
///
/// The payload is stored in an `Option` because `GraphOp::record` only gets
/// `&mut self` yet has to hand the payload over by value; `record` moves it
/// out with `Option::take`, so an instance can be recorded at most once.
pub struct SgemmOp {
// `Some` until the op has been recorded, `None` afterwards.
inner: Option<BlasSgemmOp>,
}
impl SgemmOp {
    /// Creates a not-yet-recorded SGEMM graph operation.
    ///
    /// Every argument is stored verbatim in the same-named field of the
    /// underlying [`BlasSgemmOp`]; nothing is validated here.
    ///
    /// NOTE(review): the names suggest the usual GEMM convention
    /// (`c = alpha * a * b + beta * c`, with `m`/`n`/`k` as the problem
    /// dimensions), but the exact layout and transpose semantics are decided
    /// by `BlasSgemmOp` — confirm there before relying on it.
    pub fn new(
        a: GpuRef<f32>,
        b: GpuRef<f32>,
        c: GpuRef<f32>,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        beta: f32,
    ) -> Self {
        let payload = BlasSgemmOp {
            a,
            b,
            c,
            m,
            n,
            k,
            alpha,
            beta,
        };

        // Wrapped in `Some` so that recording can later move it out.
        Self {
            inner: Some(payload),
        }
    }
}
impl GraphOp for SgemmOp {
    /// Hands the wrapped SGEMM to a [`BlasRecorder`] built from the context's
    /// cuBLAS handle, together with the context's stream.
    ///
    /// # Errors
    ///
    /// * Whatever `ctx.require_stream()` returns on failure.
    /// * `GpuError::Unrecoverable` when `ctx.blas` is `None`.
    /// * `GpuError::Unrecoverable` when this op has already been recorded
    ///   (the payload was taken by an earlier call).
    /// * Whatever `enqueue_record` returns on failure.
    fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
        let stream = ctx.require_stream()?;

        let handle = match ctx.blas {
            Some(handle) => handle,
            None => {
                return Err(GpuError::Unrecoverable(
                    "SgemmOp::record: cuBLAS handle not available in ctx".into(),
                ));
            }
        };

        // The payload is taken only after both context lookups succeeded, so
        // a failed lookup leaves this op intact. Once taken it is moved into
        // the recorder and is gone even if enqueueing fails.
        let payload = match self.inner.take() {
            Some(payload) => payload,
            None => {
                return Err(GpuError::Unrecoverable(
                    "SgemmOp::record: already consumed".into(),
                ));
            }
        };

        let mut recorder = BlasRecorder { handle };
        recorder.enqueue_record(stream, payload)
    }

    /// Static identifier for this operation kind.
    fn op_name(&self) -> &'static str {
        "graph::sgemm"
    }
}