#![cfg(feature = "cudnn")]
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::graph::{GraphOpRecord, GraphRecordCtx};
use crate::kernel::{ActivationKind, ConvParams};
/// Descriptor for a cuDNN convolution-forward operation that can be recorded
/// into a graph through its `GraphOpRecord` implementation.
///
/// NOTE(review): the role of each field is inferred from its name and the
/// usual cuDNN conventions; this file only validates the dimension arrays and
/// calls `access()` on the three buffers — confirm against the kernel code.
pub struct ConvForwardOp {
/// Device buffer, presumably the input tensor. `record` calls `access()` on it.
pub x: GpuRef<f32>,
/// Four extents associated with `x`. `record` rejects any zero or negative
/// value. Axis order is not established here (presumably NCHW — TODO confirm).
pub x_dims: [i32; 4],
/// Device buffer, presumably the filter weights. `record` calls `access()` on it.
pub w: GpuRef<f32>,
/// Four extents associated with `w`. `record` rejects any zero or negative
/// value. Axis order is not established here — TODO confirm.
pub w_dims: [i32; 4],
/// Device buffer, presumably the output tensor. `record` calls `access()` on it.
pub y: GpuRef<f32>,
/// Four extents associated with `y`. `record` rejects any zero or negative
/// value. Axis order is not established here — TODO confirm.
pub y_dims: [i32; 4],
/// Convolution settings (see `ConvParams`). Not read by `record` in this file.
pub conv: ConvParams,
/// Scaling factor, presumably the cuDNN-style multiplier applied to the
/// freshly computed result — TODO confirm. Not read by `record` in this file.
pub alpha: f32,
/// Scaling factor, presumably the cuDNN-style multiplier applied to the
/// previous contents of `y` — TODO confirm. Not read by `record` in this file.
pub beta: f32,
}
/// Descriptor for a cuDNN activation operation that can be recorded into a
/// graph through its `GraphOpRecord` implementation.
///
/// NOTE(review): the role of each field is inferred from its name and the
/// usual cuDNN conventions; this file only validates `dims` and calls
/// `access()` on the two buffers — confirm against the kernel code.
pub struct ActivationOp {
/// Which activation to apply (see `ActivationKind`). Not read by `record`
/// in this file.
pub kind: ActivationKind,
/// Device buffer, presumably the input tensor. `record` calls `access()` on it.
pub x: GpuRef<f32>,
/// Device buffer, presumably the output tensor. `record` calls `access()` on it.
pub y: GpuRef<f32>,
/// Four extents; `record` rejects any zero or negative value. Presumably
/// shared by `x` and `y`; axis order is not established here — TODO confirm.
pub dims: [i32; 4],
/// Scaling factor, presumably the cuDNN-style multiplier applied to the
/// freshly computed result — TODO confirm. Not read by `record` in this file.
pub alpha: f32,
/// Scaling factor, presumably the cuDNN-style multiplier applied to the
/// previous contents of `y` — TODO confirm. Not read by `record` in this file.
pub beta: f32,
}
/// Descriptor for a cuDNN softmax operation that can be recorded into a graph
/// through its `GraphOpRecord` implementation.
///
/// NOTE(review): the role of each field is inferred from its name and the
/// usual cuDNN conventions; this file only validates `dims` and calls
/// `access()` on the two buffers — confirm against the kernel code.
pub struct SoftmaxOp {
/// Device buffer, presumably the input tensor. `record` calls `access()` on it.
pub x: GpuRef<f32>,
/// Device buffer, presumably the output tensor. `record` calls `access()` on it.
pub y: GpuRef<f32>,
/// Four extents; `record` rejects any zero or negative value. Presumably
/// shared by `x` and `y`; axis order is not established here — TODO confirm.
pub dims: [i32; 4],
/// Scaling factor, presumably the cuDNN-style multiplier applied to the
/// freshly computed result — TODO confirm. Not read by `record` in this file.
pub alpha: f32,
/// Scaling factor, presumably the cuDNN-style multiplier applied to the
/// previous contents of `y` — TODO confirm. Not read by `record` in this file.
pub beta: f32,
}
impl GraphOpRecord for ConvForwardOp {
    /// Runs the pre-flight checks for a convolution-forward node and then
    /// reports that graph capture is not available for it yet.
    ///
    /// The checks run in a fixed order and stop at the first failure:
    /// 1. `x_dims`, `w_dims`, `y_dims` must contain only strictly positive
    ///    extents;
    /// 2. `access()` is called on `x`, `w`, `y` in that order, and the value
    ///    it yields is discarded.
    ///
    /// `ctx` is accepted to satisfy the trait signature but is not used.
    ///
    /// # Errors
    ///
    /// Propagates the first error from dimension validation or from
    /// `access()`. When every check passes the call still fails, returning
    /// `GpuError::Unrecoverable` with a "not yet wired" message; this
    /// function never returns `Ok`.
    fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
        // Nothing is recorded yet, so the capture context is intentionally ignored.
        let _ = ctx;

        let labelled_dims = [
            (&self.x_dims, "conv: x_dims"),
            (&self.w_dims, "conv: w_dims"),
            (&self.y_dims, "conv: y_dims"),
        ];
        for (extents, label) in labelled_dims {
            validate_dims(extents, label)?;
        }

        // Only the success or failure of `access()` matters here.
        for buffer in [&self.x, &self.w, &self.y] {
            let _ = buffer.access()?;
        }

        let reason = "graph::record::cudnn::ConvForward: cuDNN capture-mode \
                      entry not yet wired (Phase 3 surface only)";
        Err(GpuError::Unrecoverable(reason.into()))
    }
}
impl GraphOpRecord for ActivationOp {
    /// Runs the pre-flight checks for an activation node and then reports
    /// that graph capture is not available for it yet.
    ///
    /// `dims` must contain only strictly positive extents; after that,
    /// `access()` is called on `x` and then `y`, discarding what it yields.
    /// The first failing check ends the call. `ctx` is accepted to satisfy
    /// the trait signature but is not used.
    ///
    /// # Errors
    ///
    /// Propagates the first error from dimension validation or from
    /// `access()`. When every check passes the call still fails, returning
    /// `GpuError::Unrecoverable` with a "not yet wired" message; this
    /// function never returns `Ok`.
    fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
        // Nothing is recorded yet, so the capture context is intentionally ignored.
        let _ = ctx;

        validate_dims(&self.dims, "activation: dims")?;

        // Only the success or failure of `access()` matters here.
        for buffer in [&self.x, &self.y] {
            let _ = buffer.access()?;
        }

        let reason = "graph::record::cudnn::Activation: cuDNN capture-mode \
                      entry not yet wired";
        Err(GpuError::Unrecoverable(reason.into()))
    }
}
impl GraphOpRecord for SoftmaxOp {
    /// Runs the pre-flight checks for a softmax node and then reports that
    /// graph capture is not available for it yet.
    ///
    /// `dims` must contain only strictly positive extents; after that,
    /// `access()` is called on `x` and then `y`, discarding what it yields.
    /// The first failing check ends the call. `ctx` is accepted to satisfy
    /// the trait signature but is not used.
    ///
    /// # Errors
    ///
    /// Propagates the first error from dimension validation or from
    /// `access()`. When every check passes the call still fails, returning
    /// `GpuError::Unrecoverable` with a "not yet wired" message; this
    /// function never returns `Ok`.
    fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
        // Nothing is recorded yet, so the capture context is intentionally ignored.
        let _ = ctx;

        validate_dims(&self.dims, "softmax: dims")?;

        // Only the success or failure of `access()` matters here.
        for buffer in [&self.x, &self.y] {
            let _ = buffer.access()?;
        }

        let reason = "graph::record::cudnn::Softmax: cuDNN capture-mode \
                      entry not yet wired";
        Err(GpuError::Unrecoverable(reason.into()))
    }
}
/// Checks that all four extents in `d` are strictly positive.
///
/// `who` is a caller-supplied label placed at the front of the error message
/// so the failing operand can be identified.
///
/// # Errors
///
/// Returns `GpuError::Unrecoverable` carrying
/// `"{who}: non-positive dim in {d:?}"` when any extent is zero or negative.
fn validate_dims(d: &[i32; 4], who: &str) -> Result<(), GpuError> {
    let every_extent_positive = d.iter().all(|&extent| extent > 0);
    if every_extent_positive {
        return Ok(());
    }

    let message = format!("{who}: non-positive dim in {d:?}");
    Err(GpuError::Unrecoverable(message))
}
#[cfg(test)]
mod tests {
    // Host-only tests. Building a real `GpuRef` needs a device, so the
    // `record` implementations cannot be driven from here; these tests cover
    // the dimension validation that every `record` call runs first. The test
    // names are kept so existing test filters keep matching.
    use super::*;
    use crate::error::GpuError;
    use crate::graph::MockGraphRecordCtx;
    use cudarc::driver::sys as driver_sys;

    /// Builds a mock record context over a null graph handle (smoke test of
    /// the mock itself) and checks the dimension rules used by the conv op.
    #[test]
    fn conv_op_records() {
        let null_graph: driver_sys::CUgraph = std::ptr::null_mut();
        let mock = MockGraphRecordCtx::new(null_graph);
        // Held until the end of the test; nothing is recorded through it.
        let _ctx = mock.as_ctx();

        assert!(validate_dims(&[1, 1, 1, 1], "ok").is_ok());
        assert!(validate_dims(&[0, 1, 1, 1], "bad").is_err());

        // The error must name the offending operand and echo the extents.
        match validate_dims(&[4, 0, 2, 2], "conv: w_dims") {
            Err(GpuError::Unrecoverable(msg)) => {
                assert_eq!(msg, "conv: w_dims: non-positive dim in [4, 0, 2, 2]");
            }
            other => panic!("expected Unrecoverable, got ok={}", other.is_ok()),
        }
    }

    /// Dimension rules used by the activation op: a negative extent is rejected.
    #[test]
    fn activation_op_records() {
        assert!(validate_dims(&[1, 2, 3, 4], "ok").is_ok());
        assert!(validate_dims(&[1, 2, 3, -1], "bad").is_err());
    }

    /// Dimension rules used by the softmax op, including the rejection path.
    #[test]
    fn softmax_op_records() {
        assert!(validate_dims(&[2, 4, 1, 1], "ok").is_ok());
        assert!(validate_dims(&[2, 4, 0, 1], "softmax: dims").is_err());
        assert!(validate_dims(&[-2, 4, 1, 1], "softmax: dims").is_err());
    }
}