use std::fmt::{Debug, Error, Formatter};

use crate::{AsView, Layout, MatrixLayout, MutLayout, NdTensorView, Storage, TensorBase};
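
/// Wrapper around either a tensor element or a marker representing elided
/// elements.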
enum Entry<T: Debug> {
    /// A formatted element from the tensor.
    Value(T),

    /// Marker indicating elements that have been elided from the output.
    Ellipsis,
}

impl<T: Debug> Debug for Entry<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        match self {
            Entry::Value(val) => write!(f, "{:?}", val),
            Entry::Ellipsis => write!(f, "..."),
        }
    }
}
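
/// Options controlling how tensors are formatted for debug output.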
#[derive(Clone, Debug)]
struct FormatOptions {
    /// Maximum number of columns to show per row before eliding the middle.
    pub max_columns: usize,

    /// Maximum number of rows to show per matrix before eliding the rest.
    pub max_rows: usize,

    /// Maximum number of matrices to show before eliding the rest.
    pub max_matrices: usize,
}

impl Default for FormatOptions {
    fn default() -> Self {
        FormatOptions {
            max_columns: 10,
            max_rows: 10,
            max_matrices: 10,
        }
    }
}
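
/// Adapter which formats a tensor for debug output using a given set of
/// `FormatOptions`.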
struct FormatTensor<'a, S: Storage, L: MutLayout> {
    tensor: &'a TensorBase<S, L>,
    opts: FormatOptions,
}

impl<'a, S: Storage, L: MutLayout> FormatTensor<'a, S, L> {
    fn new(tensor: &'a TensorBase<S, L>, opts: FormatOptions) -> Self {
        Self { tensor, opts }
    }
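
    /// Format a row of elements, eliding the middle of the row if it is
    /// longer than `max_columns`.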
    fn write_vector<T: Debug>(
        &self,
        f: &mut Formatter<'_>,
        row: impl ExactSizeIterator<Item = T> + Clone,
    ) -> Result<(), Error> {
        let len = row.len();

        // Show the first and last `max_columns / 2` elements. If the row is
        // short enough, `tail` covers everything that `head` does not.
        let head = row.clone().take(self.opts.max_columns / 2);
        let tail = row
            .clone()
            .skip(self.opts.max_columns / 2)
            .skip(len.saturating_sub(self.opts.max_columns));

        let mut data_fmt = f.debug_list();
        data_fmt.entries(head.map(Entry::Value));
        if len > self.opts.max_columns {
            data_fmt.entry(&Entry::<T>::Ellipsis);
        }
        data_fmt.entries(tail.map(Entry::Value));
        data_fmt.finish()
    }
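
    /// Format a matrix, eliding trailing rows if there are more than
    /// `max_rows`.
    ///
    /// `extra_indent` specifies how many extra spaces of indentation to add
    /// before each row after the first, to account for enclosing brackets
    /// written by the caller.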
    fn write_matrix<T: Debug>(
        &self,
        f: &mut Formatter<'_>,
        mat: NdTensorView<T, 2>,
        extra_indent: usize,
    ) -> Result<(), Error> {
        write!(f, "[")?;
        for row in 0..mat.rows().min(self.opts.max_rows) {
            self.write_vector(f, mat.slice(row).iter())?;

            if row < mat.rows().min(self.opts.max_rows) - 1 {
                write!(f, ",\n{:>width$}", ' ', width = extra_indent + 1)?;
            } else if mat.rows() > self.opts.max_rows {
                write!(f, ",\n{}...", " ".repeat(extra_indent + 1))?;
            }
        }
        write!(f, "]")?;
        Ok(())
    }
}

impl<S: Storage, L: MutLayout> Debug for FormatTensor<'_, S, L>
where
    S::Elem: Debug,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        let tensor = self.tensor;

        match tensor.ndim() {
            0 => write!(f, "({:?})", tensor.item().unwrap())?,
            1 => self.write_vector(f, tensor.iter())?,
            n => {
                // Tensors with two or more dims are formatted as a sequence
                // of matrices, wrapped in one bracket per outer dimension.
                let outer_dims = n - 2;
                write!(f, "{}", "[".repeat(outer_dims))?;

                let n_matrices: usize = tensor.shape().as_ref().iter().take(outer_dims).product();

                for (i, mat) in tensor
                    .inner_iter::<2>()
                    .enumerate()
                    .take(self.opts.max_matrices)
                {
                    if i > 0 {
                        write!(f, "{}", " ".repeat(outer_dims))?;
                    }

                    self.write_matrix(f, mat, outer_dims)?;

                    if i < n_matrices.min(self.opts.max_matrices) - 1 {
                        write!(f, ",\n\n")?;
                    } else if n_matrices > self.opts.max_matrices {
                        write!(f, "\n\n{}...\n\n", " ".repeat(outer_dims))?;
                    }
                }

                write!(f, "{}", "]".repeat(outer_dims))?;
            }
        }

        write!(
            f,
            ", shape={:?}, strides={:?}",
            tensor.shape(),
            tensor.strides()
        )
    }
}
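
/// Format a tensor for display, eliding rows, columns and matrices if the
/// tensor is large.
///
/// For example, a 2x2 matrix formats as:
///
/// ```text
/// [[1.0, 2.0],
///  [3.0, 4.0]], shape=[2, 2], strides=[2, 1]
/// ```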
impl<S: Storage, L: MutLayout> Debug for TensorBase<S, L>
where
    S::Elem: Debug,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        write!(f, "{:?}", FormatTensor::new(self, FormatOptions::default()))
    }
}

#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{FormatOptions, FormatTensor};
    use crate::Tensor;

    #[test]
    fn test_debug() {
        #[derive(Clone, Debug)]
        struct Case<'a> {
            tensor: Tensor,
            opts: FormatOptions,
            expected: &'a str,
        }

        let cases = [
            // Scalar
            Case {
                tensor: Tensor::from(2.),
                opts: FormatOptions::default(),
                expected: "(2.0), shape=[], strides=[]",
            },
            // Empty vector
            Case {
                tensor: Tensor::from([0.; 0]),
                opts: FormatOptions::default(),
                expected: "[], shape=[0], strides=[1]",
            },
            // Vector shorter than `max_columns`
            Case {
                tensor: Tensor::from([1., 2., 3., 4.]),
                opts: FormatOptions::default(),
                expected: "[1.0, 2.0, 3.0, 4.0], shape=[4], strides=[1]",
            },
            // Small values formatted in exponential notation
            Case {
                tensor: Tensor::from([1e-8, 1e-7]),
                opts: FormatOptions::default(),
                expected: "[1e-8, 1e-7], shape=[2], strides=[1]",
            },
            // Vector longer than `max_columns`, with elided columns
            Case {
                tensor: Tensor::arange(1., 22., None),
                opts: FormatOptions {
                    max_columns: 10,
                    ..Default::default()
                },
                expected: "[1.0, 2.0, 3.0, 4.0, 5.0, ..., 17.0, 18.0, 19.0, 20.0, 21.0], shape=[21], strides=[1]",
            },
            // Matrix
            Case {
                tensor: Tensor::from([[1., 2.], [3., 4.]]),
                opts: FormatOptions::default(),
                expected: "
[[1.0, 2.0],
 [3.0, 4.0]], shape=[2, 2], strides=[2, 1]"
                    .trim(),
            },
            // Matrix with elided rows
            Case {
                tensor: Tensor::from([[1., 2.], [3., 4.], [5., 6.]]),
                opts: FormatOptions {
                    max_rows: 2,
                    ..Default::default()
                },
                expected: "
[[1.0, 2.0],
 [3.0, 4.0],
 ...], shape=[3, 2], strides=[2, 1]"
                    .trim(),
            },
            // 3D tensor
            Case {
                tensor: Tensor::from([[[1., 2.], [3., 4.]]]),
                opts: FormatOptions::default(),
                expected: "
[[[1.0, 2.0],
  [3.0, 4.0]]], shape=[1, 2, 2], strides=[4, 2, 1]"
                    .trim(),
            },
            // 3D tensor with elided matrices
            Case {
                tensor: Tensor::from([
                    [[1., 2.], [3., 4.]],
                    [[1., 2.], [3., 4.]],
                    [[1., 2.], [3., 4.]],
                ]),
                opts: FormatOptions {
                    max_matrices: 2,
                    ..Default::default()
                },
                expected: "
[[[1.0, 2.0],
  [3.0, 4.0]],

 [[1.0, 2.0],
  [3.0, 4.0]]

 ...

], shape=[3, 2, 2], strides=[4, 2, 1]"
                    .trim(),
            },
            // 4D tensor
            Case {
                tensor: Tensor::from([[[1., 2.], [3., 4.]]]).into_shape([1, 1, 2, 2].as_slice()),
                opts: FormatOptions::default(),
                expected: "
[[[[1.0, 2.0],
   [3.0, 4.0]]]], shape=[1, 1, 2, 2], strides=[4, 4, 2, 1]"
                    .trim(),
            },
        ];

        cases.test_each_clone(|case| {
            let debug_str = format!("{:?}", FormatTensor::new(&case.tensor, case.opts));
            assert_eq!(debug_str, case.expected);
        })
    }
}