rten_tensor/
impl_debug.rs

1use std::fmt::{Debug, Error, Formatter};
2
3use crate::layout::{Layout, MatrixLayout};
4use crate::{AsView, NdTensorView, Storage, TensorBase};
5
/// One item in the printed list for a tensor vector: either an element of
/// the tensor or a marker standing in for elements that were left out.
enum Entry<T>
where
    T: Debug,
{
    /// A tensor element to display.
    Value(T),

    /// Placeholder for a run of elided elements in a long dimension.
    Ellipsis,
}
13
14impl<T: Debug> Debug for Entry<T> {
15    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
16        match self {
17            Entry::Value(val) => write!(f, "{:?}", val),
18            Entry::Ellipsis => write!(f, "..."),
19        }
20    }
21}
22
/// Configuration for debug formatting of a tensor.
///
/// Each field limits how many items along one level of the output are
/// printed; longer levels are elided with a `...` marker.
///
/// The struct is plain data, so it is `Copy` and can be passed by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct FormatOptions {
    /// Maximum number of columns to print before eliding.
    pub max_columns: usize,

    /// Maximum number of rows to print before eliding.
    pub max_rows: usize,

    /// Maximum number of sub-matrices to print before eliding.
    pub max_matrices: usize,
}
35
36impl Default for FormatOptions {
37    fn default() -> Self {
38        FormatOptions {
39            max_columns: 10,
40            max_rows: 10,
41            max_matrices: 10,
42        }
43    }
44}
45
46/// A [`Debug`]-implementing wrapper around a tensor reference with custom
47/// formatting options.
48struct FormatTensor<'a, S: Storage, L: Layout> {
49    tensor: &'a TensorBase<S, L>,
50    opts: FormatOptions,
51}
52
53impl<'a, S: Storage, L: Layout> FormatTensor<'a, S, L> {
54    fn new(tensor: &'a TensorBase<S, L>, opts: FormatOptions) -> Self {
55        Self { tensor, opts }
56    }
57
58    /// Format a single vector of a tensor as a list (`[0, 1, 2, ... n]`).
59    fn write_vector<T: Debug>(
60        &self,
61        f: &mut Formatter<'_>,
62        row: impl ExactSizeIterator<Item = T> + Clone,
63    ) -> Result<(), Error> {
64        let len = row.len();
65
66        let head = row.clone().take(self.opts.max_columns / 2);
67        let tail = row
68            .clone()
69            .skip(self.opts.max_columns / 2)
70            .skip(len.saturating_sub(self.opts.max_columns));
71
72        let mut data_fmt = f.debug_list();
73        data_fmt.entries(head.map(Entry::Value));
74        if len > self.opts.max_columns {
75            data_fmt.entry(&Entry::<T>::Ellipsis);
76        }
77        data_fmt.entries(tail.map(Entry::Value));
78        data_fmt.finish()
79    }
80
81    /// Format a single sub-matrix from a tensor.
82    ///
83    /// `extra_indent` specifies the amount of additional indentation to
84    /// apply to rows after the first one. The first row is assumed not to
85    /// require any indentation.
86    fn write_matrix<T: Debug>(
87        &self,
88        f: &mut Formatter<'_>,
89        mat: NdTensorView<T, 2>,
90        extra_indent: usize,
91    ) -> Result<(), Error> {
92        write!(f, "[")?;
93        for row in 0..mat.rows().min(self.opts.max_rows) {
94            self.write_vector(f, mat.slice(row).iter())?;
95
96            if row < mat.rows().min(self.opts.max_rows) - 1 {
97                write!(f, ",\n{:>width$}", ' ', width = extra_indent + 1)?;
98            } else if mat.rows() > self.opts.max_rows {
99                write!(f, ",\n{}...", " ".repeat(extra_indent + 1))?;
100            }
101        }
102        write!(f, "]")?;
103        Ok(())
104    }
105}
106
107impl<S: Storage, L: Layout + Clone> Debug for FormatTensor<'_, S, L>
108where
109    S::Elem: Debug,
110{
111    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
112        let tensor = self.tensor;
113
114        match tensor.ndim() {
115            0 => write!(f, "({:?})", tensor.item().unwrap())?,
116            1 => self.write_vector(f, tensor.iter())?,
117            n => {
118                // Format tensors with >= 2 dims as a sequence of matrices.
119                let outer_dims = n - 2;
120                write!(f, "{}", "[".repeat(outer_dims))?;
121
122                let n_matrices: usize = tensor.shape().as_ref().iter().take(outer_dims).product();
123
124                for (i, mat) in tensor
125                    .inner_iter::<2>()
126                    .enumerate()
127                    .take(self.opts.max_matrices)
128                {
129                    if i > 0 {
130                        write!(f, "{}", " ".repeat(outer_dims))?;
131                    }
132
133                    self.write_matrix(f, mat, outer_dims)?;
134
135                    if i < n_matrices.min(self.opts.max_matrices) - 1 {
136                        write!(f, ",\n\n")?;
137                    } else if n_matrices > self.opts.max_matrices {
138                        write!(f, "\n\n{}...\n\n", " ".repeat(outer_dims))?;
139                    }
140                }
141
142                write!(f, "{}", "]".repeat(outer_dims))?;
143            }
144        }
145
146        write!(
147            f,
148            ", shape={:?}, strides={:?}",
149            tensor.shape(),
150            tensor.strides()
151        )
152    }
153}
154
155impl<S: Storage, L: Layout + Clone> Debug for TensorBase<S, L>
156where
157    S::Elem: Debug,
158{
159    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
160        write!(f, "{:?}", FormatTensor::new(self, FormatOptions::default()))
161    }
162}
163
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{FormatOptions, FormatTensor};
    use crate::Tensor;

    /// Checks the exact `Debug` output of `FormatTensor` for tensors of rank
    /// 0 through 4, with and without elision of columns, rows and matrices.
    #[test]
    fn test_debug() {
        // A tensor, the options to format it with, and the expected output.
        #[derive(Clone, Debug)]
        struct Case<'a> {
            tensor: Tensor,
            opts: FormatOptions,
            expected: &'a str,
        }

        // Multi-line expectations start on the line after the opening quote,
        // flush-left, so the literal mirrors the printed layout; `.trim()`
        // removes the leading newline.
        let cases = [
            // Scalar
            Case {
                tensor: Tensor::from(2.),
                opts: FormatOptions::default(),
                expected: "(2.0), shape=[], strides=[]",
            },
            // Empty vector
            Case {
                tensor: Tensor::from([0.; 0]),
                opts: FormatOptions::default(),
                expected: "[], shape=[0], strides=[1]",
            },
            // Short vector
            Case {
                tensor: Tensor::from([1., 2., 3., 4.]),
                opts: FormatOptions::default(),
                expected: "[1.0, 2.0, 3.0, 4.0], shape=[4], strides=[1]",
            },
            // Small and large values
            Case {
                tensor: Tensor::from([1e-8, 1e-7]),
                opts: FormatOptions::default(),
                expected: "[1e-8, 1e-7], shape=[2], strides=[1]",
            },
            // Long vector: 21 values with max_columns = 10 keep the first 5
            // and the last 5.
            Case {
                tensor: Tensor::arange(1., 22., None),
                opts: FormatOptions {
                    max_columns: 10,
                    ..Default::default()
                },
                expected: "[1.0, 2.0, 3.0, 4.0, 5.0, ..., 17.0, 18.0, 19.0, 20.0, 21.0], shape=[21], strides=[1]",
            },
            // Matrix
            Case {
                tensor: Tensor::from([[1., 2.], [3., 4.]]),
                opts: FormatOptions::default(),
                expected: "
[[1.0, 2.0],
 [3.0, 4.0]], shape=[2, 2], strides=[2, 1]"
                    .trim(),
            },
            // Matrix with elided rows
            Case {
                tensor: Tensor::from([[1., 2.], [3., 4.], [5., 6.]]),
                opts: FormatOptions {
                    max_rows: 2,
                    ..Default::default()
                },
                expected: "
[[1.0, 2.0],
 [3.0, 4.0],
 ...], shape=[3, 2], strides=[2, 1]"
                    .trim(),
            },
            // 3D
            Case {
                tensor: Tensor::from([[[1., 2.], [3., 4.]]]),
                opts: FormatOptions::default(),
                expected: "
[[[1.0, 2.0],
  [3.0, 4.0]]], shape=[1, 2, 2], strides=[4, 2, 1]"
                    .trim(),
            },
            // 3D with elided matrices: 3 matrices, only 2 shown, followed by
            // a "..." line.
            Case {
                tensor: Tensor::from([
                    [[1., 2.], [3., 4.]],
                    [[1., 2.], [3., 4.]],
                    [[1., 2.], [3., 4.]],
                ]),
                opts: FormatOptions {
                    max_matrices: 2,
                    ..Default::default()
                },
                expected: "
[[[1.0, 2.0],
  [3.0, 4.0]],

 [[1.0, 2.0],
  [3.0, 4.0]]

 ...

], shape=[3, 2, 2], strides=[4, 2, 1]"
                    .trim(),
            },
            // 4D
            Case {
                tensor: Tensor::from([[[1., 2.], [3., 4.]]]).into_shape([1, 1, 2, 2].as_slice()),
                opts: FormatOptions::default(),
                expected: "
[[[[1.0, 2.0],
   [3.0, 4.0]]]], shape=[1, 1, 2, 2], strides=[4, 4, 2, 1]"
                    .trim(),
            },
        ];

        // Format each case with its own options and compare to the expected
        // string exactly.
        cases.test_each_clone(|case| {
            let debug_str = format!("{:?}", FormatTensor::new(&case.tensor, case.opts));
            assert_eq!(debug_str, case.expected);
        })
    }
}