use super::{Attribute, AttributeLike};
use crate::{Context, Error};
use mlir_sys::{
MlirAttribute, mlirStridedLayoutAttrGet, mlirStridedLayoutAttrGetNumStrides,
mlirStridedLayoutAttrGetOffset, mlirStridedLayoutAttrGetStride,
};
/// A strided layout attribute, described by an integer offset and a list of
/// integer strides.
///
/// This is a thin, `Copy`-able wrapper around a generic [`Attribute`]; the
/// lifetime `'c` ties it to the [`Context`] it was created in.
#[derive(Clone, Copy, Hash)]
pub struct StridedLayoutAttribute<'c> {
// The underlying generic attribute handle.
attribute: Attribute<'c>,
}
impl<'c> StridedLayoutAttribute<'c> {
pub fn new(context: &'c Context, offset: i64, strides: &[i64]) -> Self {
unsafe {
Self::from_raw(mlirStridedLayoutAttrGet(
context.to_raw(),
offset,
strides.len() as isize,
strides.as_ptr(),
))
}
}
pub fn offset(&self) -> i64 {
unsafe { mlirStridedLayoutAttrGetOffset(self.to_raw()) }
}
pub fn stride_count(&self) -> usize {
(unsafe { mlirStridedLayoutAttrGetNumStrides(self.to_raw()) }) as usize
}
pub fn stride(&self, index: usize) -> Result<i64, Error> {
if index < self.stride_count() {
Ok(unsafe { mlirStridedLayoutAttrGetStride(self.to_raw(), index as isize) })
} else {
Err(Error::PositionOutOfBounds {
name: "stride",
value: self.to_string(),
index,
})
}
}
}
// NOTE(review): `attribute_traits!` is defined elsewhere. Judging by its arguments, it
// presumably wires `StridedLayoutAttribute` to the `is_strided_layout` type check and
// the "strided layout" description. Confirm this against the macro definition.
// The inherent impl above calls `Self::from_raw` and `self.to_string()` (which needs
// `Display`) without defining either itself, so they presumably come from this macro.
attribute_traits!(StridedLayoutAttribute, is_strided_layout, "strided layout");
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::create_test_context;

    #[test]
    fn new() {
        let context = create_test_context();
        // Construction alone must succeed; nothing else is inspected here.
        let _attribute = StridedLayoutAttribute::new(&context, 0, &[1, 2, 3]);
    }

    #[test]
    fn offset() {
        let context = create_test_context();
        let offset = StridedLayoutAttribute::new(&context, 42, &[]).offset();
        assert_eq!(offset, 42);
    }

    #[test]
    fn stride_count() {
        let context = create_test_context();
        let strides = [1, 2, 3];
        let attribute = StridedLayoutAttribute::new(&context, 0, &strides);
        assert_eq!(attribute.stride_count(), strides.len());
    }

    #[test]
    fn stride() {
        let context = create_test_context();
        let strides = [4, 5, 6];
        let attribute = StridedLayoutAttribute::new(&context, 0, &strides);

        for (index, &expected) in strides.iter().enumerate() {
            assert_eq!(attribute.stride(index).unwrap(), expected);
        }

        // One past the last stride must be rejected rather than read out of bounds.
        assert!(matches!(
            attribute.stride(strides.len()),
            Err(Error::PositionOutOfBounds { .. })
        ));
    }
}