// ndarray_layout/fmt.rs

1use crate::ArrayLayout;
2use std::fmt;
3
4impl<const N: usize> ArrayLayout<N> {
5    /// 高维数组格式化。
6    ///
7    /// # Safety
8    ///
9    /// 这个函数从对裸指针解引用以获得要格式化的数组元素。
10    pub unsafe fn write_array<T: fmt::Display + Copy>(
11        &self,
12        f: &mut fmt::Formatter,
13        ptr: *const T,
14    ) -> fmt::Result {
15        match self.ndim() {
16            0 => {
17                write!(f, "array<> = [{}]", unsafe {
18                    ptr.byte_offset(self.offset()).read_unaligned()
19                })
20            }
21            1 => {
22                let &[n] = self.shape() else { unreachable!() };
23                let &[s] = self.strides() else { unreachable!() };
24
25                writeln!(f, "array<{n}>[")?;
26                let ptr = unsafe { ptr.byte_offset(self.offset()) };
27                for i in 0..n as isize {
28                    writeln!(f, "    {}", unsafe {
29                        ptr.byte_offset(i * s).read_unaligned()
30                    })?
31                }
32                writeln!(f, "]")?;
33                Ok(())
34            }
35            _ => {
36                let mut title = "array<".to_string();
37                for d in self.shape() {
38                    title.push_str(&format!("{d}x"))
39                }
40                assert_eq!(title.pop(), Some('x'));
41                title.push('>');
42
43                let mut stack = Vec::with_capacity(self.ndim() - 2);
44                self.write_recursive(f, ptr, &title, &mut stack)
45            }
46        }
47    }
48
49    fn write_recursive<T: fmt::Display>(
50        &self,
51        f: &mut fmt::Formatter,
52        ptr: *const T,
53        title: &str,
54        indices: &mut Vec<usize>,
55    ) -> fmt::Result {
56        match *self.shape() {
57            [] | [_] => unreachable!(),
58            [rows, cols] => {
59                write!(f, "{title}[")?;
60                for i in indices {
61                    write!(f, "{i}, ")?
62                }
63                writeln!(f, "..]")?;
64
65                let &[rs, cs] = self.strides() else {
66                    unreachable!()
67                };
68
69                let ptr = unsafe { ptr.byte_offset(self.offset()) };
70                for r in 0..rows as isize {
71                    for c in 0..cols as isize {
72                        write!(f, "{} ", unsafe {
73                            ptr.byte_offset(r * rs + c * cs).read_unaligned()
74                        })?
75                    }
76                    writeln!(f)?
77                }
78            }
79            [batch, ..] => {
80                for i in 0..batch {
81                    indices.push(i);
82                    self.index(0, i).write_recursive(f, ptr, title, indices)?;
83                    indices.pop();
84                }
85            }
86        }
87        Ok(())
88    }
89}
90
#[test]
fn test() {
    const DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0];

    // Display adapter that renders `DATA` through a layout via `write_array`.
    struct Tensor(ArrayLayout<4>);

    impl fmt::Display for Tensor {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // SAFETY: every layout below is derived from a contiguous layout over
            // `DATA` (element size 1) by tiling / broadcasting, so it only addresses
            // bytes of `DATA`; the 0-d layout is assumed to have offset 0.
            unsafe { self.0.write_array(f, DATA.as_ptr()) }
        }
    }

    // 1-d: `array<10>[`, one indented element per line, then `]`.
    let tensor = Tensor(ArrayLayout::<4>::new_contiguous(
        &[DATA.len()],
        crate::Endian::BigEndian,
        1,
    ));
    let text = tensor.to_string();
    println!("{text}");
    let items: String = DATA.iter().map(|x| format!("    {x}\n")).collect();
    assert_eq!(text, format!("array<10>[\n{items}]\n"));

    // 2-d: [10] -> [1, 10] -> axis 0 broadcast to 6 identical rows.
    let tensor = Tensor(tensor.0.tile_be(0, &[1, DATA.len()]).broadcast(0, 6));
    let text = tensor.to_string();
    println!("{text}");
    let row: String = DATA.iter().map(|x| format!("{x} ")).collect::<String>() + "\n";
    assert_eq!(text, format!("array<6x10>[..]\n{}", row.repeat(6)));

    // 4-d: [6, 10] -> [2, 3, 10] -> [2, 3, 5, 2]; six identical 5x2 slices of DATA.
    let tensor = Tensor(tensor.0.tile_be(0, &[2, 3]).tile_be(2, &[5, 2]));
    let text = tensor.to_string();
    println!("{text}");
    let slice: String = DATA
        .chunks(2)
        .map(|pair| format!("{} {} \n", pair[0], pair[1]))
        .collect();
    let mut expected = String::new();
    for i in 0..2 {
        for j in 0..3 {
            expected += &format!("array<2x3x5x2>[{i}, {j}, ..]\n{slice}");
        }
    }
    assert_eq!(text, expected);

    // 0-d: a single element and no trailing newline.
    let tensor = Tensor(ArrayLayout::<4>::with_ndim(0));
    let text = tensor.to_string();
    println!("{text}");
    assert_eq!(text, "array<> = [1]");
}