rten_tensor/
impl_debug.rs

1use std::fmt::{Debug, Error, Formatter};
2
3use crate::{AsView, Layout, MatrixLayout, MutLayout, NdTensorView, Storage, TensorBase};
4
/// Entry in the formatted representation of a tensor's data.
enum Entry<T: Debug> {
    /// A single element of the tensor.
    Value(T),

    /// "..." used to elide long dimensions.
    Ellipsis,
}

impl<T: Debug> Debug for Entry<T> {
    /// Format the entry.
    ///
    /// Values delegate to `T`'s own `Debug` impl using the *same* formatter,
    /// so any flags set by the caller (eg. the precision in `{:.2?}`) are
    /// honored. Re-formatting via `write!(f, "{:?}", ..)` would discard them.
    /// The ellipsis is written verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        match self {
            Entry::Value(val) => Debug::fmt(val, f),
            Entry::Ellipsis => f.write_str("..."),
        }
    }
}
21
/// Limits controlling how much of a tensor is written when it is debug-formatted.
///
/// Any dimension longer than its limit is elided with `...`.
#[derive(Debug, Clone)]
struct FormatOptions {
    /// Number of entries written from a single row (the innermost dimension)
    /// before the rest are elided.
    pub max_columns: usize,

    /// Number of rows written from a single matrix before the rest are elided.
    pub max_rows: usize,

    /// Number of matrices written from a tensor with 3 or more dimensions
    /// before the rest are elided.
    pub max_matrices: usize,
}

impl Default for FormatOptions {
    /// Every limit defaults to 10.
    fn default() -> Self {
        const DEFAULT_LIMIT: usize = 10;

        Self {
            max_columns: DEFAULT_LIMIT,
            max_rows: DEFAULT_LIMIT,
            max_matrices: DEFAULT_LIMIT,
        }
    }
}
44
45/// A [`Debug`]-implementing wrapper around a tensor reference with custom
46/// formatting options.
47struct FormatTensor<'a, S: Storage, L: MutLayout> {
48    tensor: &'a TensorBase<S, L>,
49    opts: FormatOptions,
50}
51
52impl<'a, S: Storage, L: MutLayout> FormatTensor<'a, S, L> {
53    fn new(tensor: &'a TensorBase<S, L>, opts: FormatOptions) -> Self {
54        Self { tensor, opts }
55    }
56
57    /// Format a single vector of a tensor as a list (`[0, 1, 2, ... n]`).
58    fn write_vector<T: Debug>(
59        &self,
60        f: &mut Formatter<'_>,
61        row: impl ExactSizeIterator<Item = T> + Clone,
62    ) -> Result<(), Error> {
63        let len = row.len();
64
65        let head = row.clone().take(self.opts.max_columns / 2);
66        let tail = row
67            .clone()
68            .skip(self.opts.max_columns / 2)
69            .skip(len.saturating_sub(self.opts.max_columns));
70
71        let mut data_fmt = f.debug_list();
72        data_fmt.entries(head.map(Entry::Value));
73        if len > self.opts.max_columns {
74            data_fmt.entry(&Entry::<T>::Ellipsis);
75        }
76        data_fmt.entries(tail.map(Entry::Value));
77        data_fmt.finish()
78    }
79
80    /// Format a single sub-matrix from a tensor.
81    ///
82    /// `extra_indent` specifies the amount of additional indentation to
83    /// apply to rows after the first one. The first row is assumed not to
84    /// require any indentation.
85    fn write_matrix<T: Debug>(
86        &self,
87        f: &mut Formatter<'_>,
88        mat: NdTensorView<T, 2>,
89        extra_indent: usize,
90    ) -> Result<(), Error> {
91        write!(f, "[")?;
92        for row in 0..mat.rows().min(self.opts.max_rows) {
93            self.write_vector(f, mat.slice(row).iter())?;
94
95            if row < mat.rows().min(self.opts.max_rows) - 1 {
96                write!(f, ",\n{:>width$}", ' ', width = extra_indent + 1)?;
97            } else if mat.rows() > self.opts.max_rows {
98                write!(f, ",\n{}...", " ".repeat(extra_indent + 1))?;
99            }
100        }
101        write!(f, "]")?;
102        Ok(())
103    }
104}
105
106impl<S: Storage, L: MutLayout> Debug for FormatTensor<'_, S, L>
107where
108    S::Elem: Debug,
109{
110    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
111        let tensor = self.tensor;
112
113        match tensor.ndim() {
114            0 => write!(f, "({:?})", tensor.item().unwrap())?,
115            1 => self.write_vector(f, tensor.iter())?,
116            n => {
117                // Format tensors with >= 2 dims as a sequence of matrices.
118                let outer_dims = n - 2;
119                write!(f, "{}", "[".repeat(outer_dims))?;
120
121                let n_matrices: usize = tensor.shape().as_ref().iter().take(outer_dims).product();
122
123                for (i, mat) in tensor
124                    .inner_iter::<2>()
125                    .enumerate()
126                    .take(self.opts.max_matrices)
127                {
128                    if i > 0 {
129                        write!(f, "{}", " ".repeat(outer_dims))?;
130                    }
131
132                    self.write_matrix(f, mat, outer_dims)?;
133
134                    if i < n_matrices.min(self.opts.max_matrices) - 1 {
135                        write!(f, ",\n\n")?;
136                    } else if n_matrices > self.opts.max_matrices {
137                        write!(f, "\n\n{}...\n\n", " ".repeat(outer_dims))?;
138                    }
139                }
140
141                write!(f, "{}", "]".repeat(outer_dims))?;
142            }
143        }
144
145        write!(
146            f,
147            ", shape={:?}, strides={:?}",
148            tensor.shape(),
149            tensor.strides()
150        )
151    }
152}
153
154impl<S: Storage, L: MutLayout> Debug for TensorBase<S, L>
155where
156    S::Elem: Debug,
157{
158    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
159        write!(f, "{:?}", FormatTensor::new(self, FormatOptions::default()))
160    }
161}
162
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{FormatOptions, FormatTensor};
    use crate::Tensor;

    #[test]
    fn test_debug() {
        #[derive(Clone, Debug)]
        struct Case<'a> {
            tensor: Tensor,
            opts: FormatOptions,
            expected: &'a str,
        }

        let cases = [
            // Scalar
            Case {
                tensor: Tensor::from(2.),
                opts: FormatOptions::default(),
                expected: "(2.0), shape=[], strides=[]",
            },
            // Empty vector
            Case {
                tensor: Tensor::from([0.; 0]),
                opts: FormatOptions::default(),
                expected: "[], shape=[0], strides=[1]",
            },
            // Short vector
            Case {
                tensor: Tensor::from([1., 2., 3., 4.]),
                opts: FormatOptions::default(),
                expected: "[1.0, 2.0, 3.0, 4.0], shape=[4], strides=[1]",
            },
            // Small and large values
            Case {
                tensor: Tensor::from([1e-8, 1e-7]),
                opts: FormatOptions::default(),
                expected: "[1e-8, 1e-7], shape=[2], strides=[1]",
            },
            // Long vector
            Case {
                tensor: Tensor::arange(1., 22., None),
                opts: FormatOptions {
                    max_columns: 10,
                    ..Default::default()
                },
                expected: "[1.0, 2.0, 3.0, 4.0, 5.0, ..., 17.0, 18.0, 19.0, 20.0, 21.0], shape=[21], strides=[1]",
            },
            // Long vector with an odd column limit. The tail gets the extra column.
            Case {
                tensor: Tensor::arange(1., 12., None),
                opts: FormatOptions {
                    max_columns: 5,
                    ..Default::default()
                },
                expected: "[1.0, 2.0, ..., 9.0, 10.0, 11.0], shape=[11], strides=[1]",
            },
            // Matrix
            Case {
                tensor: Tensor::from([[1., 2.], [3., 4.]]),
                opts: FormatOptions::default(),
                expected: "
[[1.0, 2.0],
 [3.0, 4.0]], shape=[2, 2], strides=[2, 1]"
                    .trim(),
            },
            // Matrix with elided columns
            Case {
                tensor: Tensor::from([[1., 2., 3.], [4., 5., 6.]]),
                opts: FormatOptions {
                    max_columns: 2,
                    ..Default::default()
                },
                expected: "
[[1.0, ..., 3.0],
 [4.0, ..., 6.0]], shape=[2, 3], strides=[3, 1]"
                    .trim(),
            },
            // Matrix with elided rows
            Case {
                tensor: Tensor::from([[1., 2.], [3., 4.], [5., 6.]]),
                opts: FormatOptions {
                    max_rows: 2,
                    ..Default::default()
                },
                expected: "
[[1.0, 2.0],
 [3.0, 4.0],
 ...], shape=[3, 2], strides=[2, 1]"
                    .trim(),
            },
            // 3D
            Case {
                tensor: Tensor::from([[[1., 2.], [3., 4.]]]),
                opts: FormatOptions::default(),
                expected: "
[[[1.0, 2.0],
  [3.0, 4.0]]], shape=[1, 2, 2], strides=[4, 2, 1]"
                    .trim(),
            },
            // 3D with elided rows. The row ellipsis follows the extra indentation.
            Case {
                tensor: Tensor::from([[[1., 2.], [3., 4.], [5., 6.]]]),
                opts: FormatOptions {
                    max_rows: 2,
                    ..Default::default()
                },
                expected: "
[[[1.0, 2.0],
  [3.0, 4.0],
  ...]], shape=[1, 3, 2], strides=[6, 2, 1]"
                    .trim(),
            },
            // 3D
            Case {
                tensor: Tensor::from([
                    [[1., 2.], [3., 4.]],
                    [[1., 2.], [3., 4.]],
                    [[1., 2.], [3., 4.]],
                ]),
                opts: FormatOptions {
                    max_matrices: 2,
                    ..Default::default()
                },
                expected: "
[[[1.0, 2.0],
  [3.0, 4.0]],

 [[1.0, 2.0],
  [3.0, 4.0]]

 ...

], shape=[3, 2, 2], strides=[4, 2, 1]"
                    .trim(),
            },
            // 4D
            Case {
                tensor: Tensor::from([[[1., 2.], [3., 4.]]]).into_shape([1, 1, 2, 2].as_slice()),
                opts: FormatOptions::default(),
                expected: "
[[[[1.0, 2.0],
   [3.0, 4.0]]]], shape=[1, 1, 2, 2], strides=[4, 4, 2, 1]"
                    .trim(),
            },
        ];

        cases.test_each_clone(|case| {
            let debug_str = format!("{:?}", FormatTensor::new(&case.tensor, case.opts));
            assert_eq!(debug_str, case.expected);
        })
    }
}