cubecl-zspace 0.10.0

CubeCL ZSpace Library.
Documentation
//! # Common Burn Errors

use alloc::string::String;
use core::{
    error::Error,
    fmt::{Display, Formatter},
    ops::Range,
};

use crate::{Shape, Strides};

/// Describes the kind of an index.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum IndexKind {
    /// The index of an element in a dimension.
    Element,

    /// The index of a dimension.
    Dimension,
}

impl IndexKind {
    /// Get the display name of the kind.
    pub fn name(&self) -> &'static str {
        match self {
            IndexKind::Element => "element",
            IndexKind::Dimension => "dimension",
        }
    }
}

/// Access Bounds Error.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum BoundsError {
    /// Generic bounds error.
    Generic(String),

    /// Index out of bounds.
    Index {
        /// The kind of index that was out of bounds.
        kind: IndexKind,

        /// The index that was out of bounds.
        index: isize,

        /// The range of valid indices.
        bounds: Range<isize>,
    },
}

impl BoundsError {
    /// Create a new index error.
    pub fn index(kind: IndexKind, index: isize, bounds: Range<isize>) -> Self {
        Self::Index {
            kind,
            index,
            bounds,
        }
    }
}

impl core::fmt::Display for BoundsError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Generic(msg) => write!(f, "BoundsError: {}", msg),
            Self::Index {
                kind,
                index,
                bounds: range,
            } => write!(
                f,
                "BoundsError: {} {} out of bounds: {:?}",
                kind.name(),
                index,
                range
            ),
        }
    }
}

impl core::error::Error for BoundsError {}

/// Common Expression Error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExpressionError {
    /// Parse Error.
    ParseError {
        /// The error message.
        message: String,
        /// The source expression.
        source: String,
    },

    /// Invalid Expression.
    InvalidExpression {
        /// The error message.
        message: String,
        /// The source expression.
        source: String,
    },
}

impl core::fmt::Display for ExpressionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::ParseError { message, source } => {
                write!(f, "ExpressionError: ParseError: {} ({})", message, source)
            }
            Self::InvalidExpression { message, source } => write!(
                f,
                "ExpressionError: InvalidExpression: {} ({})",
                message, source
            ),
        }
    }
}

impl core::error::Error for ExpressionError {}

impl ExpressionError {
    /// Constructs a new [`ExpressionError::ParseError`].
    ///
    /// This function is a utility for creating instances where a parsing error needs to be represented,
    /// encapsulating a descriptive error message and the source of the error.
    ///
    /// # Parameters
    ///
    /// - `message`: A value that can be converted into a `String`, representing a human-readable description
    ///   of the parsing error.
    /// - `source`: A value that can be converted into a `String`, typically identifying the origin or
    ///   input that caused the parsing error.
    pub fn parse_error(message: impl Into<String>, source: impl Into<String>) -> Self {
        Self::ParseError {
            message: message.into(),
            source: source.into(),
        }
    }

    /// Creates a new [`ExpressionError::InvalidExpression`].
    ///
    /// # Parameters
    /// - `message`: A detailed message describing the nature of the invalid expression.
    ///   Accepts any type that can be converted into a `String`.
    /// - `source`: The source or context in which the invalid expression occurred.
    ///   Accepts any type that can be converted into a `String`.
    pub fn invalid_expression(message: impl Into<String>, source: impl Into<String>) -> Self {
        Self::InvalidExpression {
            message: message.into(),
            source: source.into(),
        }
    }
}

/// Collected shape/stride record.
///
/// As this is used for error messages, there is no expectation that this is valid,
/// or that the ranks match.
#[derive(Debug, Clone, PartialEq)]
pub struct StrideRecord {
    pub shape: Shape,
    pub strides: Strides,
}

impl StrideRecord {
    /// Create a new [`StrideRecord`] from a slice of usize strides.
    pub fn from_usize_strides(shape: &[usize], strides: &[usize]) -> StrideRecord {
        StrideRecord {
            shape: shape.into(),
            strides: strides.iter().map(|s| *s as isize).collect(),
        }
    }

    /// Create a new [`StrideRecord`] from a slice of isize strides.
    pub fn from_isize_strides(shape: &[usize], strides: &[isize]) -> StrideRecord {
        StrideRecord {
            shape: shape.into(),
            strides: strides.into(),
        }
    }
}

/// Error describing striding issues.
#[derive(Debug, Clone, PartialEq)]
pub enum StrideError {
    /// The ranks of the shape and strides do not match.
    MalformedRanks { record: StrideRecord },

    /// This is an unsupported rank.
    UnsupportedRank { rank: usize, record: StrideRecord },

    /// The strides violate a constraint.
    Invalid {
        message: String,
        record: StrideRecord,
    },
}

impl Display for StrideError {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        match self {
            StrideError::MalformedRanks { record } => write!(f, "Malformed strides: {:?}", record),
            StrideError::UnsupportedRank { rank, record } => {
                write!(f, "Unsupported rank {}: {:?}", rank, record)
            }
            StrideError::Invalid { message, record } => {
                write!(f, "Invalid strides: {}: {:?}", message, record)
            }
        }
    }
}

impl Error for StrideError {}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::string::ToString;

    #[test]
    fn test_bounds_error_display() {
        assert_eq!(
            format!("{}", BoundsError::Generic("test".to_string())),
            "BoundsError: test"
        );
        assert_eq!(
            format!(
                "{}",
                BoundsError::Index {
                    kind: IndexKind::Element,
                    index: 1,
                    bounds: 0..2
                }
            ),
            "BoundsError: element 1 out of bounds: 0..2"
        );
    }

    #[test]
    fn test_parse_error() {
        let err = ExpressionError::parse_error("test", "source");
        assert_eq!(
            format!("{:?}", err),
            "ParseError { message: \"test\", source: \"source\" }"
        );
    }

    #[test]
    fn test_invalid_expression() {
        let err = ExpressionError::invalid_expression("test", "source");
        assert_eq!(
            format!("{:?}", err),
            "InvalidExpression { message: \"test\", source: \"source\" }"
        );
    }
}