use crate::cuda::error::{CudaError, CudaResult};
use torsh_core::DType;
use super::handle::CudnnHandle;
#[cfg(feature = "cudnn")]
use cudnn_sys::*;
#[cfg(feature = "cudnn")]
use super::compat::{
cudnnCreateDropoutDescriptor, cudnnCreateRNNDataDescriptor, cudnnCreateRNNDescriptor,
cudnnDataType_t as CompatDataType_t, cudnnDestroyDropoutDescriptor,
cudnnDestroyRNNDataDescriptor, cudnnDestroyRNNDescriptor, cudnnDirectionMode_t,
cudnnDropoutDescriptor_t, cudnnDropoutGetStatesSize, cudnnForwardMode_t, cudnnMathType_t,
cudnnRNNAlgo_t, cudnnRNNDataDescriptor_t, cudnnRNNDataLayout_t, cudnnRNNDescriptor_t,
cudnnRNNInputMode_t, cudnnRNNMode_t, cudnnSetDropoutDescriptor, cudnnSetRNNDataDescriptor,
cudnnSetRNNDescriptor_v8,
};
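/// How the first RNN layer consumes its input: through a learned linear
/// projection (`LinearInput`) or directly (`SkipInput`, which requires each
/// input vector to match the hidden size).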
#[derive(Debug, Clone, Copy)]
pub enum RNNInputMode {
LinearInput,
SkipInput,
}
impl RNNInputMode {
#[cfg(feature = "cudnn")]
pub(crate) fn to_cudnn(self) -> cudnnRNNInputMode_t {
match self {
RNNInputMode::LinearInput => cudnnRNNInputMode_t::CUDNN_LINEAR_INPUT,
RNNInputMode::SkipInput => cudnnRNNInputMode_t::CUDNN_SKIP_INPUT,
}
}
}
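/// Whether the RNN scans the sequence in a single direction or runs forward
/// and backward passes whose outputs are concatenated.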
#[derive(Debug, Clone, Copy)]
pub enum RNNDirectionMode {
Unidirectional,
Bidirectional,
}
impl RNNDirectionMode {
#[cfg(feature = "cudnn")]
pub(crate) fn to_cudnn(self) -> cudnnDirectionMode_t {
match self {
RNNDirectionMode::Unidirectional => cudnnDirectionMode_t::CUDNN_UNIDIRECTIONAL,
RNNDirectionMode::Bidirectional => cudnnDirectionMode_t::CUDNN_BIDIRECTIONAL,
}
}
}
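/// The recurrent cell type: LSTM, GRU, or a vanilla RNN with ReLU or tanh
/// activation.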
#[derive(Debug, Clone, Copy)]
pub enum RNNMode {
LSTM,
GRU,
RNNRelu,
RNNTanh,
}
impl RNNMode {
#[cfg(feature = "cudnn")]
pub(crate) fn to_cudnn(self) -> cudnnRNNMode_t {
match self {
RNNMode::LSTM => cudnnRNNMode_t::CUDNN_LSTM,
RNNMode::GRU => cudnnRNNMode_t::CUDNN_GRU,
RNNMode::RNNRelu => cudnnRNNMode_t::CUDNN_RNN_RELU,
RNNMode::RNNTanh => cudnnRNNMode_t::CUDNN_RNN_TANH,
}
}
}
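/// cuDNN algorithm selection: the standard algorithm, or the persistent
/// variants that keep weights resident on-chip and can be faster for small
/// batch sizes.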
#[derive(Debug, Clone, Copy)]
pub enum RNNAlgorithm {
Standard,
PersistStatic,
PersistDynamic,
}
impl RNNAlgorithm {
#[cfg(feature = "cudnn")]
pub(crate) fn to_cudnn(self) -> cudnnRNNAlgo_t {
match self {
RNNAlgorithm::Standard => cudnnRNNAlgo_t::CUDNN_STANDARD,
RNNAlgorithm::PersistStatic => cudnnRNNAlgo_t::CUDNN_STATIC_PERSISTENT,
RNNAlgorithm::PersistDynamic => cudnnRNNAlgo_t::CUDNN_DYNAMIC_PERSISTENT,
}
}
}
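/// Math mode hint for cuDNN: default FP32 arithmetic, or Tensor Core
/// acceleration where the hardware supports it.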
#[derive(Debug, Clone, Copy)]
pub enum MathType {
Default,
TensorOp,
}
impl MathType {
#[cfg(feature = "cudnn")]
pub(crate) fn to_cudnn(self) -> cudnnMathType_t {
match self {
MathType::Default => cudnnMathType_t::CUDNN_DEFAULT_MATH,
MathType::TensorOp => cudnnMathType_t::CUDNN_TENSOR_OP_MATH,
}
}
}
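/// Whether a forward pass should retain the intermediate results needed for a
/// later backward pass (`Training`) or may discard them (`Inference`).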
#[derive(Debug, Clone, Copy)]
pub enum RNNForwardMode {
Training,
Inference,
}
impl RNNForwardMode {
#[cfg(feature = "cudnn")]
pub(crate) fn to_cudnn(self) -> cudnnForwardMode_t {
match self {
RNNForwardMode::Training => cudnnForwardMode_t::CUDNN_FWD_MODE_TRAINING,
RNNForwardMode::Inference => cudnnForwardMode_t::CUDNN_FWD_MODE_INFERENCE,
}
}
}
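/// Memory layout of RNN sequence data: time-major with padding, time-major
/// packed, or batch-major with padding.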
#[derive(Debug, Clone, Copy)]
pub enum RNNDataLayout {
SeqMajorUnpacked,
SeqMajorPacked,
BatchMajorUnpacked,
}
impl RNNDataLayout {
#[cfg(feature = "cudnn")]
pub(crate) fn to_cudnn(self) -> cudnnRNNDataLayout_t {
match self {
RNNDataLayout::SeqMajorUnpacked => {
cudnnRNNDataLayout_t::CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED
}
RNNDataLayout::SeqMajorPacked => {
cudnnRNNDataLayout_t::CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED
}
RNNDataLayout::BatchMajorUnpacked => {
cudnnRNNDataLayout_t::CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED
}
}
}
}
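/// Safe owner of a `cudnnRNNDescriptor_t`, destroyed automatically on drop.
/// Without the `cudnn` feature this is a stub whose constructor always errors.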
pub struct RNNDescriptor {
#[cfg(feature = "cudnn")]
descriptor: cudnnRNNDescriptor_t,
#[cfg(not(feature = "cudnn"))]
_phantom: std::marker::PhantomData<()>,
}
impl RNNDescriptor {
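    /// Allocates a new, unconfigured RNN descriptor; configure it with
    /// [`Self::set_lstm`] before use.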
pub fn new() -> CudaResult<Self> {
#[cfg(feature = "cudnn")]
{
let mut descriptor: cudnnRNNDescriptor_t = std::ptr::null_mut();
let status = unsafe { cudnnCreateRNNDescriptor(&mut descriptor) };
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"Failed to create RNN descriptor: {:?}",
status
)));
}
Ok(Self { descriptor })
}
#[cfg(not(feature = "cudnn"))]
{
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
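    /// Configures the descriptor via `cudnnSetRNNDescriptor_v8`. Despite the
    /// name, the cell type comes from `mode`, so this also covers GRU and
    /// vanilla RNN cells; `dropout_desc` applies between stacked layers and
    /// therefore has no effect when `num_layers` is 1.
    ///
    /// A minimal sketch of configuring a two-layer LSTM, assuming a live
    /// `CudnnHandle` named `handle` is already available (its construction is
    /// project-specific, so the example is not compiled):
    ///
    /// ```ignore
    /// let dropout = DropoutDescriptor::new(&handle, 0.0, 42)?;
    /// let mut rnn = RNNDescriptor::new()?;
    /// rnn.set_lstm(
    ///     512,                              // hidden_size
    ///     2,                                // num_layers
    ///     &dropout,
    ///     RNNInputMode::LinearInput,
    ///     RNNDirectionMode::Unidirectional,
    ///     RNNMode::LSTM,
    ///     RNNAlgorithm::Standard,
    ///     MathType::Default,
    /// )?;
    /// ```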
pub fn set_lstm(
&mut self,
hidden_size: i32,
num_layers: i32,
dropout_desc: &DropoutDescriptor,
input_mode: RNNInputMode,
direction: RNNDirectionMode,
mode: RNNMode,
algorithm: RNNAlgorithm,
math_precision: MathType,
) -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
let status = unsafe {
cudnnSetRNNDescriptor_v8(
self.descriptor,
algorithm.to_cudnn(),
direction.to_cudnn(),
mode.to_cudnn(),
input_mode.to_cudnn(),
hidden_size,
num_layers,
dropout_desc.raw(),
                    0,
                    math_precision.to_cudnn(),
                    std::ptr::null_mut(),
                )
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"Failed to set RNN descriptor: {:?}",
status
)));
}
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (
hidden_size,
num_layers,
dropout_desc,
input_mode,
direction,
mode,
algorithm,
math_precision,
);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
#[cfg(feature = "cudnn")]
pub fn raw(&self) -> cudnnRNNDescriptor_t {
self.descriptor
}
}
impl Drop for RNNDescriptor {
fn drop(&mut self) {
#[cfg(feature = "cudnn")]
{
if !self.descriptor.is_null() {
unsafe {
let _status = cudnnDestroyRNNDescriptor(self.descriptor);
}
}
}
}
}
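/// Safe owner of a `cudnnDropoutDescriptor_t`, destroyed automatically on
/// drop. No device-side RNG state buffer is allocated by this wrapper.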
pub struct DropoutDescriptor {
#[cfg(feature = "cudnn")]
descriptor: cudnnDropoutDescriptor_t,
#[cfg(not(feature = "cudnn"))]
_phantom: std::marker::PhantomData<()>,
}
impl DropoutDescriptor {
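    /// Creates a dropout descriptor with probability `dropout` and RNG seed
    /// `seed`. A null `states` pointer is passed to cuDNN, so the RNG states
    /// are left uninitialized; that is sufficient for a zero dropout rate,
    /// but an active rate would need a device buffer this wrapper does not
    /// currently allocate.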
pub fn new(handle: &CudnnHandle, dropout: f32, seed: u64) -> CudaResult<Self> {
#[cfg(feature = "cudnn")]
{
let mut descriptor: cudnnDropoutDescriptor_t = std::ptr::null_mut();
let status = unsafe { cudnnCreateDropoutDescriptor(&mut descriptor) };
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"Failed to create dropout descriptor: {:?}",
status
)));
}
            // Query the size cuDNN would need for its RNG state buffer.
            let mut states_size: usize = 0;
            let status = unsafe { cudnnDropoutGetStatesSize(handle.raw(), &mut states_size) };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                // Destroy the freshly created descriptor so this error path does not leak it.
                let _ = unsafe { cudnnDestroyDropoutDescriptor(descriptor) };
                return Err(CudaError::CudnnError(format!(
                    "Failed to get dropout states size: {:?}",
                    status
                )));
            }
            // No device buffer is allocated here: cuDNN accepts a null `states`
            // pointer, in which case the RNG states are simply left uninitialized.
            let states = std::ptr::null_mut();
let status = unsafe {
cudnnSetDropoutDescriptor(
descriptor,
handle.raw(),
dropout,
states,
states_size,
seed,
)
};
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                // Destroy the freshly created descriptor so this error path does not leak it.
                let _ = unsafe { cudnnDestroyDropoutDescriptor(descriptor) };
                return Err(CudaError::CudnnError(format!(
                    "Failed to set dropout descriptor: {:?}",
                    status
                )));
            }
Ok(Self { descriptor })
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (handle, dropout, seed);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
#[cfg(feature = "cudnn")]
pub fn raw(&self) -> cudnnDropoutDescriptor_t {
self.descriptor
}
}
impl Drop for DropoutDescriptor {
fn drop(&mut self) {
#[cfg(feature = "cudnn")]
{
if !self.descriptor.is_null() {
unsafe {
let _status = cudnnDestroyDropoutDescriptor(self.descriptor);
}
}
}
}
}
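/// Safe owner of a `cudnnRNNDataDescriptor_t` describing the shape and layout
/// of RNN input/output sequences; destroyed automatically on drop.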
pub struct RNNDataDescriptor {
#[cfg(feature = "cudnn")]
descriptor: cudnnRNNDataDescriptor_t,
#[cfg(not(feature = "cudnn"))]
_phantom: std::marker::PhantomData<()>,
}
impl RNNDataDescriptor {
pub fn new() -> CudaResult<Self> {
#[cfg(feature = "cudnn")]
{
let mut descriptor: cudnnRNNDataDescriptor_t = std::ptr::null_mut();
let status = unsafe { cudnnCreateRNNDataDescriptor(&mut descriptor) };
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"Failed to create RNN data descriptor: {:?}",
status
)));
}
Ok(Self { descriptor })
}
#[cfg(not(feature = "cudnn"))]
{
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
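    /// Describes a batch of sequences: `max_seq_length` is the longest
    /// sequence, `batch_size` the number of sequences, `vector_size` the
    /// feature width, and `seq_length_array` holds one length per batch
    /// element. `padding_fill` optionally sets the value written into padded
    /// positions. Only `F32`, `F64`, and `F16` element types are supported.
    ///
    /// A minimal sketch for three padded, time-major sequences of width 128
    /// (not compiled, since it needs a configured CUDA environment):
    ///
    /// ```ignore
    /// let mut data = RNNDataDescriptor::new()?;
    /// data.set(
    ///     DType::F32,
    ///     RNNDataLayout::SeqMajorUnpacked,
    ///     10,          // max_seq_length
    ///     3,           // batch_size
    ///     128,         // vector_size
    ///     &[10, 7, 4], // one length per batch element
    ///     Some(0.0),   // value for padded positions
    /// )?;
    /// ```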
pub fn set(
&mut self,
data_type: DType,
layout: RNNDataLayout,
max_seq_length: i32,
batch_size: i32,
vector_size: i32,
seq_length_array: &[i32],
padding_fill: Option<f32>,
) -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
let cudnn_data_type = match data_type {
DType::F32 => CompatDataType_t::CUDNN_DATA_FLOAT,
DType::F64 => CompatDataType_t::CUDNN_DATA_DOUBLE,
DType::F16 => CompatDataType_t::CUDNN_DATA_HALF,
_ => {
return Err(CudaError::CudnnError(format!(
"Unsupported data type for RNN: {:?}",
data_type
)))
}
};
            // cuDNN expects one sequence length per batch element.
            debug_assert_eq!(seq_length_array.len(), batch_size as usize);
            // Hold the optional padding value in a local that outlives the FFI
            // call so the pointer handed to cuDNN stays valid.
            let padding_value = padding_fill.unwrap_or(0.0);
            let padding_ptr: *const f32 = if padding_fill.is_some() {
                &padding_value
            } else {
                std::ptr::null()
            };
            let status = unsafe {
                cudnnSetRNNDataDescriptor(
                    self.descriptor,
                    cudnn_data_type,
                    layout.to_cudnn(),
                    max_seq_length,
                    batch_size,
                    vector_size,
                    seq_length_array.as_ptr(),
                    padding_ptr,
                )
            };
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"Failed to set RNN data descriptor: {:?}",
status
)));
}
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (
data_type,
layout,
max_seq_length,
batch_size,
vector_size,
seq_length_array,
padding_fill,
);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
#[cfg(feature = "cudnn")]
pub fn raw(&self) -> cudnnRNNDataDescriptor_t {
self.descriptor
}
}
impl Drop for RNNDataDescriptor {
fn drop(&mut self) {
#[cfg(feature = "cudnn")]
{
if !self.descriptor.is_null() {
unsafe {
let _status = cudnnDestroyRNNDataDescriptor(self.descriptor);
}
}
}
}
}
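// SAFETY: each wrapper exclusively owns its opaque cuDNN descriptor and only
// mutates it through `&mut self`. cuDNN descriptor handles carry no thread
// affinity, so sending or sharing them across threads is sound as long as the
// cuDNN calls themselves are externally synchronized.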
unsafe impl Send for RNNDescriptor {}
unsafe impl Sync for RNNDescriptor {}
unsafe impl Send for DropoutDescriptor {}
unsafe impl Sync for DropoutDescriptor {}
unsafe impl Send for RNNDataDescriptor {}
unsafe impl Sync for RNNDataDescriptor {}
#[cfg(test)]
mod tests {
use super::*;
    #[test]
    fn test_rnn_descriptor_creation() {
        #[cfg(feature = "cudnn")]
        {
            // Creation can legitimately fail on hosts without a CUDA device,
            // so exercise the call without asserting a particular outcome.
            let _ = RNNDescriptor::new();
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let result = RNNDescriptor::new();
            assert!(result.is_err());
        }
    }
    #[test]
    fn test_rnn_data_descriptor_creation() {
        #[cfg(feature = "cudnn")]
        {
            // Creation can legitimately fail on hosts without a CUDA device,
            // so exercise the call without asserting a particular outcome.
            let _ = RNNDataDescriptor::new();
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let result = RNNDataDescriptor::new();
            assert!(result.is_err());
        }
    }
    #[test]
    fn test_rnn_enum_discriminants() {
        // Distinct variants of each configuration enum should map to
        // distinct discriminants.
        assert_ne!(
            std::mem::discriminant(&RNNInputMode::LinearInput),
            std::mem::discriminant(&RNNInputMode::SkipInput)
        );
        assert_ne!(
            std::mem::discriminant(&RNNDirectionMode::Unidirectional),
            std::mem::discriminant(&RNNDirectionMode::Bidirectional)
        );
        assert_ne!(
            std::mem::discriminant(&RNNMode::LSTM),
            std::mem::discriminant(&RNNMode::GRU)
        );
        assert_ne!(
            std::mem::discriminant(&RNNAlgorithm::Standard),
            std::mem::discriminant(&RNNAlgorithm::PersistStatic)
        );
        assert_ne!(
            std::mem::discriminant(&MathType::Default),
            std::mem::discriminant(&MathType::TensorOp)
        );
        assert_ne!(
            std::mem::discriminant(&RNNForwardMode::Training),
            std::mem::discriminant(&RNNForwardMode::Inference)
        );
        assert_ne!(
            std::mem::discriminant(&RNNDataLayout::SeqMajorUnpacked),
            std::mem::discriminant(&RNNDataLayout::BatchMajorUnpacked)
        );
    }
#[test]
fn test_send_sync() {
fn assert_send<T: Send>() {}
fn assert_sync<T: Sync>() {}
assert_send::<RNNDescriptor>();
assert_sync::<RNNDescriptor>();
assert_send::<DropoutDescriptor>();
assert_sync::<DropoutDescriptor>();
assert_send::<RNNDataDescriptor>();
assert_sync::<RNNDataDescriptor>();
}
}