1use std::fmt::{Debug, Error, Formatter};
2
3use crate::layout::{Layout, MatrixLayout};
4use crate::{AsView, NdTensorView, Storage, TensorBase};
5
/// A single printed item in a formatted list of tensor elements: either an
/// actual element value or an ellipsis standing in for elided elements.
///
/// The `T: Debug` bound lives on the `Debug` impl rather than the type
/// definition, per Rust API conventions (bounds on type definitions are
/// needlessly restrictive).
enum Entry<T> {
    /// An element of the tensor.
    Value(T),

    /// Marker printed as "..." for elements omitted from the output.
    Ellipsis,
}

impl<T: Debug> Debug for Entry<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        match self {
            Entry::Value(val) => write!(f, "{:?}", val),
            Entry::Ellipsis => write!(f, "..."),
        }
    }
}
22
/// Limits controlling how much of a tensor's content is printed before
/// sections are elided with "...".
#[derive(Clone, Debug)]
struct FormatOptions {
    /// Maximum number of elements shown per row before the middle is elided.
    pub max_columns: usize,

    /// Maximum number of rows shown per matrix before the rest are elided.
    pub max_rows: usize,

    /// Maximum number of inner matrices shown before the rest are elided.
    pub max_matrices: usize,
}

impl Default for FormatOptions {
    /// Default limits: show at most 10 columns, rows and matrices.
    fn default() -> Self {
        Self {
            max_columns: 10,
            max_rows: 10,
            max_matrices: 10,
        }
    }
}
45
/// Adapter that pairs a tensor with formatting options so the `Debug` impl
/// below can control how much of the tensor is printed.
struct FormatTensor<'a, S: Storage, L: Layout> {
    // Tensor whose contents are formatted.
    tensor: &'a TensorBase<S, L>,
    // Limits on how many columns/rows/matrices are printed.
    opts: FormatOptions,
}
52
53impl<'a, S: Storage, L: Layout> FormatTensor<'a, S, L> {
54 fn new(tensor: &'a TensorBase<S, L>, opts: FormatOptions) -> Self {
55 Self { tensor, opts }
56 }
57
58 fn write_vector<T: Debug>(
60 &self,
61 f: &mut Formatter<'_>,
62 row: impl ExactSizeIterator<Item = T> + Clone,
63 ) -> Result<(), Error> {
64 let len = row.len();
65
66 let head = row.clone().take(self.opts.max_columns / 2);
67 let tail = row
68 .clone()
69 .skip(self.opts.max_columns / 2)
70 .skip(len.saturating_sub(self.opts.max_columns));
71
72 let mut data_fmt = f.debug_list();
73 data_fmt.entries(head.map(Entry::Value));
74 if len > self.opts.max_columns {
75 data_fmt.entry(&Entry::<T>::Ellipsis);
76 }
77 data_fmt.entries(tail.map(Entry::Value));
78 data_fmt.finish()
79 }
80
81 fn write_matrix<T: Debug>(
87 &self,
88 f: &mut Formatter<'_>,
89 mat: NdTensorView<T, 2>,
90 extra_indent: usize,
91 ) -> Result<(), Error> {
92 write!(f, "[")?;
93 for row in 0..mat.rows().min(self.opts.max_rows) {
94 self.write_vector(f, mat.slice(row).iter())?;
95
96 if row < mat.rows().min(self.opts.max_rows) - 1 {
97 write!(f, ",\n{:>width$}", ' ', width = extra_indent + 1)?;
98 } else if mat.rows() > self.opts.max_rows {
99 write!(f, ",\n{}...", " ".repeat(extra_indent + 1))?;
100 }
101 }
102 write!(f, "]")?;
103 Ok(())
104 }
105}
106
impl<S: Storage, L: Layout + Clone> Debug for FormatTensor<'_, S, L>
where
    S::Elem: Debug,
{
    /// Format the tensor's contents followed by its shape and strides.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        let tensor = self.tensor;

        match tensor.ndim() {
            // Scalars are printed as `(value)`.
            0 => write!(f, "({:?})", tensor.item().unwrap())?,
            // Vectors are printed as a single, possibly elided, list.
            1 => self.write_vector(f, tensor.iter())?,
            // Tensors with >= 2 dims are printed as a sequence of inner
            // matrices, wrapped in one pair of brackets per outer dim.
            n => {
                let outer_dims = n - 2;
                write!(f, "{}", "[".repeat(outer_dims))?;

                // Total number of inner matrices, used to decide where the
                // matrix separators and trailing "..." go.
                let n_matrices: usize = tensor.shape().as_ref().iter().take(outer_dims).product();

                for (i, mat) in tensor
                    .inner_iter::<2>()
                    .enumerate()
                    .take(self.opts.max_matrices)
                {
                    if i > 0 {
                        // Indent to align with the first matrix, which sits
                        // after the `outer_dims` opening brackets.
                        write!(f, "{}", " ".repeat(outer_dims))?;
                    }

                    self.write_matrix(f, mat, outer_dims)?;

                    if i < n_matrices.min(self.opts.max_matrices) - 1 {
                        // More matrices follow: separate with a blank line.
                        write!(f, ",\n\n")?;
                    } else if n_matrices > self.opts.max_matrices {
                        // Last printed matrix, but some were elided.
                        write!(f, "\n\n{}...\n\n", " ".repeat(outer_dims))?;
                    }
                }

                write!(f, "{}", "]".repeat(outer_dims))?;
            }
        }

        // Append layout info so output is unambiguous even when elided.
        write!(
            f,
            ", shape={:?}, strides={:?}",
            tensor.shape(),
            tensor.strides()
        )
    }
}
154
155impl<S: Storage, L: Layout + Clone> Debug for TensorBase<S, L>
156where
157 S::Elem: Debug,
158{
159 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
160 write!(f, "{:?}", FormatTensor::new(self, FormatOptions::default()))
161 }
162}
163
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{FormatOptions, FormatTensor};
    use crate::Tensor;

    #[test]
    fn test_debug() {
        // Each case formats `tensor` with `opts` and compares the complete
        // debug string (including the shape/strides suffix) to `expected`.
        #[derive(Clone, Debug)]
        struct Case<'a> {
            tensor: Tensor,
            opts: FormatOptions,
            expected: &'a str,
        }

        let cases = [
            // Scalar (0D tensor), printed as `(value)`.
            Case {
                tensor: Tensor::from(2.),
                opts: FormatOptions::default(),
                expected: "(2.0), shape=[], strides=[]",
            },
            // Empty vector.
            Case {
                tensor: Tensor::from([0.; 0]),
                opts: FormatOptions::default(),
                expected: "[], shape=[0], strides=[1]",
            },
            // Short vector which fits within `max_columns`.
            Case {
                tensor: Tensor::from([1., 2., 3., 4.]),
                opts: FormatOptions::default(),
                expected: "[1.0, 2.0, 3.0, 4.0], shape=[4], strides=[1]",
            },
            // Values that format using scientific notation.
            Case {
                tensor: Tensor::from([1e-8, 1e-7]),
                opts: FormatOptions::default(),
                expected: "[1e-8, 1e-7], shape=[2], strides=[1]",
            },
            // Long vector: middle elements elided with "...".
            Case {
                tensor: Tensor::arange(1., 22., None),
                opts: FormatOptions {
                    max_columns: 10,
                    ..Default::default()
                },
                expected: "[1.0, 2.0, 3.0, 4.0, 5.0, ..., 17.0, 18.0, 19.0, 20.0, 21.0], shape=[21], strides=[1]",
            },
            // Matrix: one row per line, later rows indented to align.
            Case {
                tensor: Tensor::from([[1., 2.], [3., 4.]]),
                opts: FormatOptions::default(),
                expected: "
[[1.0, 2.0],
 [3.0, 4.0]], shape=[2, 2], strides=[2, 1]"
                .trim(),
            },
            // Matrix with rows elided by `max_rows`.
            Case {
                tensor: Tensor::from([[1., 2.], [3., 4.], [5., 6.]]),
                opts: FormatOptions {
                    max_rows: 2,
                    ..Default::default()
                },
                expected: "
[[1.0, 2.0],
 [3.0, 4.0],
 ...], shape=[3, 2], strides=[2, 1]"
                .trim(),
            },
            // 3D tensor: one extra level of brackets and indentation.
            Case {
                tensor: Tensor::from([[[1., 2.], [3., 4.]]]),
                opts: FormatOptions::default(),
                expected: "
[[[1.0, 2.0],
  [3.0, 4.0]]], shape=[1, 2, 2], strides=[4, 2, 1]"
                .trim(),
            },
            // 3D tensor with inner matrices elided by `max_matrices`.
            Case {
                tensor: Tensor::from([
                    [[1., 2.], [3., 4.]],
                    [[1., 2.], [3., 4.]],
                    [[1., 2.], [3., 4.]],
                ]),
                opts: FormatOptions {
                    max_matrices: 2,
                    ..Default::default()
                },
                expected: "
[[[1.0, 2.0],
  [3.0, 4.0]],

 [[1.0, 2.0],
  [3.0, 4.0]]

 ...

], shape=[3, 2, 2], strides=[4, 2, 1]"
                .trim(),
            },
            // 4D tensor: two extra levels of brackets.
            Case {
                tensor: Tensor::from([[[1., 2.], [3., 4.]]]).into_shape([1, 1, 2, 2].as_slice()),
                opts: FormatOptions::default(),
                expected: "
[[[[1.0, 2.0],
   [3.0, 4.0]]]], shape=[1, 1, 2, 2], strides=[4, 4, 2, 1]"
                .trim(),
            },
        ];

        cases.test_each_clone(|case| {
            let debug_str = format!("{:?}", FormatTensor::new(&case.tensor, case.opts));
            assert_eq!(debug_str, case.expected);
        })
    }
}