bunsen 0.21.0-pre.2

bunsen is a community companion library for burn.
Documentation
use burn::prelude;
use strum;

/// A meta-descriptor for [`burn::tensor::TensorKind`].
///
/// Variants are ordered by declaration (`Bool < Float < Int`) through the derived
/// [`Ord`]. The derived `strum` impls render a variant as its name via
/// [`std::fmt::Display`] (e.g. `"Float"`) and parse that same name back via
/// [`std::str::FromStr`].
///
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate must include a
/// wildcard arm.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, strum::EnumString, strum::Display,
)]
#[non_exhaustive]
pub enum TensorKindDesc {
    /// A Bool tensor.
    ///
    /// Equivalent to [`burn::tensor::Bool`].
    Bool,

    /// A Float tensor.
    ///
    /// Equivalent to [`burn::tensor::Float`].
    Float,

    /// An Int tensor.
    ///
    /// Equivalent to [`burn::tensor::Int`].
    Int,
}

impl TensorKindDesc {
    /// Get the [`TensorKindDesc`] for the tensor kind marker type `K`
    /// (e.g. [`burn::tensor::Float`]).
    ///
    /// This simply reads [`ParamKindBinding::KIND`]. Because it is a `const fn`, it
    /// can be evaluated in constant contexts, e.g.
    /// `const K: TensorKindDesc = TensorKindDesc::for_kind::<burn::tensor::Float>();`.
    ///
    /// # Type Parameters
    ///
    /// - `K`: a kind marker type implementing [`ParamKindBinding`].
    #[must_use]
    pub const fn for_kind<K: ParamKindBinding>() -> Self {
        K::KIND
    }
}

/// Binds a tensor kind marker type (such as [`burn::tensor::Float`]) to its
/// [`TensorKindDesc`].
///
/// Use [`TensorKindDesc::for_kind`] to look up the descriptor for a kind type.
pub trait ParamKindBinding {
    /// The [`TensorKindDesc`] describing the implementing kind.
    const KIND: TensorKindDesc;
}

impl ParamKindBinding for prelude::Bool {
    const KIND: TensorKindDesc = TensorKindDesc::Bool;
}

impl ParamKindBinding for prelude::Float {
    const KIND: TensorKindDesc = TensorKindDesc::Float;
}

impl ParamKindBinding for prelude::Int {
    const KIND: TensorKindDesc = TensorKindDesc::Int;
}

#[cfg(test)]
mod tests {
    use burn::tensor;

    use crate::kit::descriptors::TensorKindDesc;

    /// Each burn tensor kind marker resolves to the descriptor variant of the same name.
    #[test]
    fn test_tensor_kinds() {
        assert_eq!(
            TensorKindDesc::for_kind::<tensor::Bool>(),
            TensorKindDesc::Bool
        );
        assert_eq!(
            TensorKindDesc::for_kind::<tensor::Float>(),
            TensorKindDesc::Float
        );
        assert_eq!(
            TensorKindDesc::for_kind::<tensor::Int>(),
            TensorKindDesc::Int
        );
    }

    /// `for_kind` is a `const fn`; the lookup must also work in a const context.
    #[test]
    fn test_for_kind_in_const_context() {
        const FLOAT: TensorKindDesc = TensorKindDesc::for_kind::<tensor::Float>();
        assert_eq!(FLOAT, TensorKindDesc::Float);
    }

    /// The derived strum `Display` / `FromStr` impls round-trip through variant names.
    #[test]
    fn test_string_round_trip() {
        for kind in [
            TensorKindDesc::Bool,
            TensorKindDesc::Float,
            TensorKindDesc::Int,
        ] {
            let name = kind.to_string();
            assert_eq!(name.parse::<TensorKindDesc>().unwrap(), kind);
        }
        assert_eq!(TensorKindDesc::Float.to_string(), "Float");
        assert!("Half".parse::<TensorKindDesc>().is_err());
    }

    /// The derived `Ord` follows declaration order.
    #[test]
    fn test_ordering() {
        assert!(TensorKindDesc::Bool < TensorKindDesc::Float);
        assert!(TensorKindDesc::Float < TensorKindDesc::Int);
    }
}