/// Classification of how an array's elements are arranged in memory.
///
/// The value is a small, freely copyable tag that can be compared, hashed and
/// debug-printed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    /// C order (row-major): the last axis varies fastest.
    C,
    /// Fortran order (column-major): the first axis varies fastest.
    Fortran,
    /// Any stride pattern that is neither C nor Fortran contiguous.
    Custom,
}
13
14impl MemoryLayout {
15 #[inline]
17 pub fn is_c_contiguous(self) -> bool {
18 self == Self::C
19 }
20
21 #[inline]
23 pub fn is_f_contiguous(self) -> bool {
24 self == Self::Fortran
25 }
26
27 #[inline]
29 pub fn is_custom(self) -> bool {
30 self == Self::Custom
31 }
32}
33
34impl core::fmt::Display for MemoryLayout {
35 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
36 match self {
37 Self::C => write!(f, "C"),
38 Self::Fortran => write!(f, "F"),
39 Self::Custom => write!(f, "Custom"),
40 }
41 }
42}
43
44#[cfg(feature = "std")]
50#[inline]
51pub(crate) fn classify_layout(
52 is_standard: bool,
53 shape: &[usize],
54 strides: &[isize],
55) -> MemoryLayout {
56 if is_standard {
57 MemoryLayout::C
58 } else {
59 detect_layout(shape, strides)
60 }
61}
62
63#[cfg(feature = "std")]
65pub(crate) fn detect_layout(shape: &[usize], strides: &[isize]) -> MemoryLayout {
66 if shape.is_empty() {
67 return MemoryLayout::C; }
69
70 let is_c = is_c_contiguous(shape, strides);
71 let is_f = is_f_contiguous(shape, strides);
72
73 if is_c {
74 MemoryLayout::C
75 } else if is_f {
76 MemoryLayout::Fortran
77 } else {
78 MemoryLayout::Custom
79 }
80}
81
82#[cfg(feature = "std")]
/// Reports whether `strides` describe a C-order (row-major) contiguous layout
/// for `shape`.
///
/// Walking the axes from last to first, every axis whose length is not 1 must
/// have a stride equal to the product of the lengths of the axes after it
/// (so the last such axis has stride 1).
///
/// * Returns `false` when `shape` and `strides` differ in length.
/// * Returns `true` for zero dimensions.
/// * Returns `true` whenever any axis has length 0, regardless of where that
///   axis sits: there are no elements whose placement could be wrong.
/// * The stride of a length-1 axis is never stepped over, so it is ignored
///   entirely and does not influence the stride expected of other axes.
fn is_c_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    if shape.len() != strides.len() {
        return false;
    }
    // Decide the empty-array case before any stride is examined, so the answer
    // does not depend on the position of the zero-length axis.
    if shape.contains(&0) {
        return true;
    }

    let mut expected: isize = 1;
    for (&len, &stride) in shape.iter().zip(strides.iter()).rev() {
        if len == 1 {
            // Arbitrary stride allowed; `expected` must stay untouched.
            continue;
        }
        if stride != expected {
            return false;
        }
        // Grow the running element count from the verified value rather than
        // from the caller's stride. Saturate instead of overflowing: a
        // saturated value simply fails to match any further real stride.
        expected = expected.saturating_mul(len.min(isize::MAX as usize) as isize);
    }
    true
}
103
104#[cfg(feature = "std")]
/// Reports whether `strides` describe a Fortran-order (column-major)
/// contiguous layout for `shape`.
///
/// Walking the axes from first to last, every axis whose length is not 1 must
/// have a stride equal to the product of the lengths of the axes before it
/// (so the first such axis has stride 1).
///
/// * Returns `false` when `shape` and `strides` differ in length.
/// * Returns `true` for zero dimensions.
/// * Returns `true` whenever any axis has length 0, regardless of where that
///   axis sits: there are no elements whose placement could be wrong.
/// * The stride of a length-1 axis is never stepped over, so it is ignored
///   entirely and does not influence the stride expected of other axes.
fn is_f_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    if shape.len() != strides.len() {
        return false;
    }
    // Decide the empty-array case before any stride is examined, so the answer
    // does not depend on the position of the zero-length axis.
    if shape.contains(&0) {
        return true;
    }

    let mut expected: isize = 1;
    for (&len, &stride) in shape.iter().zip(strides.iter()) {
        if len == 1 {
            // Arbitrary stride allowed; `expected` must stay untouched.
            continue;
        }
        if stride != expected {
            return false;
        }
        // Grow the running element count from the verified value rather than
        // from the caller's stride. Saturate instead of overflowing: a
        // saturated value simply fails to match any further real stride.
        expected = expected.saturating_mul(len.min(isize::MAX as usize) as isize);
    }
    true
}
125
#[cfg(test)]
mod tests {
    use super::*;

    // `detect_layout` is only compiled when the `std` feature is enabled, so
    // every test that calls it carries the same gate. Without it, a test build
    // with the feature disabled fails to compile on the missing function.

    /// Row-major strides for a 3x4 array are classified as C order.
    #[cfg(feature = "std")]
    #[test]
    fn detect_c_contiguous() {
        assert_eq!(detect_layout(&[3, 4], &[4, 1]), MemoryLayout::C);
    }

    /// Column-major strides for a 3x4 array are classified as Fortran order.
    #[cfg(feature = "std")]
    #[test]
    fn detect_f_contiguous() {
        assert_eq!(detect_layout(&[3, 4], &[1, 3]), MemoryLayout::Fortran);
    }

    /// Strides matching neither order are classified as custom.
    #[cfg(feature = "std")]
    #[test]
    fn detect_custom() {
        assert_eq!(detect_layout(&[3, 4], &[8, 2]), MemoryLayout::Custom);
    }

    /// A zero-dimensional shape is reported as C order.
    #[cfg(feature = "std")]
    #[test]
    fn detect_empty() {
        assert_eq!(detect_layout(&[], &[]), MemoryLayout::C);
    }

    /// `Display` yields the short tags "C", "F" and "Custom".
    #[test]
    fn display() {
        assert_eq!(MemoryLayout::C.to_string(), "C");
        assert_eq!(MemoryLayout::Fortran.to_string(), "F");
        assert_eq!(MemoryLayout::Custom.to_string(), "Custom");
    }
}