Skip to main content

hanzo_ml/
display.rs

1//! Pretty printing of tensors
2//!
3//! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py).
4//!
5use crate::{DType, Result, Tensor, WithDType};
6use half::{bf16, f16};
7
8impl Tensor {
9    fn fmt_dt<T: WithDType + std::fmt::Display>(
10        &self,
11        f: &mut std::fmt::Formatter,
12    ) -> std::fmt::Result {
13        let device_str = match self.device().location() {
14            crate::DeviceLocation::Cpu => "".to_owned(),
15            crate::DeviceLocation::Cuda { gpu_id } => {
16                format!(", cuda:{gpu_id}")
17            }
18            crate::DeviceLocation::Metal { gpu_id } => {
19                format!(", metal:{gpu_id}")
20            }
21            #[cfg(feature = "rocm")]
22            crate::DeviceLocation::Rocm { gpu_id } => {
23                format!(", rocm:{gpu_id}")
24            }
25            #[cfg(feature = "vulkan")]
26            crate::DeviceLocation::Vulkan { gpu_id } => {
27                format!(", vulkan:{gpu_id}")
28            }
29        };
30
31        write!(f, "Tensor[")?;
32        match self.dims() {
33            [] => {
34                if let Ok(v) = self.to_scalar::<T>() {
35                    write!(f, "{v}")?
36                }
37            }
38            [s] if *s < 10 => {
39                if let Ok(vs) = self.to_vec1::<T>() {
40                    for (i, v) in vs.iter().enumerate() {
41                        if i > 0 {
42                            write!(f, ", ")?;
43                        }
44                        write!(f, "{v}")?;
45                    }
46                }
47            }
48            dims => {
49                write!(f, "dims ")?;
50                for (i, d) in dims.iter().enumerate() {
51                    if i > 0 {
52                        write!(f, ", ")?;
53                    }
54                    write!(f, "{d}")?;
55                }
56            }
57        }
58        write!(f, "; {}{}]", self.dtype().as_str(), device_str)
59    }
60}
61
62impl std::fmt::Debug for Tensor {
63    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
64        match self.dtype() {
65            DType::U8 => self.fmt_dt::<u8>(f),
66            DType::U32 => self.fmt_dt::<u32>(f),
67            DType::I16 => self.fmt_dt::<i16>(f),
68            DType::I32 => self.fmt_dt::<i32>(f),
69            DType::I64 => self.fmt_dt::<i64>(f),
70            DType::BF16 => self.fmt_dt::<bf16>(f),
71            DType::F16 => self.fmt_dt::<f16>(f),
72            DType::F32 => self.fmt_dt::<f32>(f),
73            DType::F64 => self.fmt_dt::<f64>(f),
74            DType::F8E4M3 => self.fmt_dt::<float8::F8E4M3>(f),
75            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
76                write!(
77                    f,
78                    "Tensor[{:?}; dtype={}, unsupported dummy type]",
79                    self.shape(),
80                    self.dtype().as_str()
81                )
82            }
83        }
84    }
85}
86
/// Options for Tensor pretty printing
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    /// Number of digits after the decimal point for floating-point values.
    pub precision: usize,
    /// Tensors with more elements than this are summarized (printed with `...`).
    pub threshold: usize,
    /// Number of leading and trailing entries kept along each summarized dimension.
    pub edge_items: usize,
    /// Approximate line width, in characters, used to wrap the innermost rows.
    pub line_width: usize,
    /// `Some(true)` forces scientific notation, `Some(false)` disables it, and `None` lets
    /// the formatter decide from the values being printed.
    pub sci_mode: Option<bool>,
}

/// Process-wide print options used when displaying tensors.
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());

impl PrinterOptions {
    // We cannot use the default trait as it's not const.
    /// Default options: precision 4, threshold 1000, 3 edge items, line width 80, and
    /// automatic scientific notation.
    const fn const_default() -> Self {
        Self {
            precision: 4,
            threshold: 1000,
            edge_items: 3,
            line_width: 80,
            sci_mode: None,
        }
    }
}

/// Locks the global print options, recovering the guard if the mutex was poisoned.
///
/// Every writer in this block assigns a whole field or the whole struct. The options
/// therefore remain usable even if a previous holder panicked, and a single panic must
/// not turn every later setter into a panic.
fn lock_print_options() -> std::sync::MutexGuard<'static, PrinterOptions> {
    PRINT_OPTS
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Returns the mutex guarding the global print options.
///
/// Callers locking it directly should be prepared for a poisoned lock.
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
    &PRINT_OPTS
}

/// Replaces all global print options at once.
pub fn set_print_options(options: PrinterOptions) {
    *lock_print_options() = options
}

/// Restores the default print options.
pub fn set_print_options_default() {
    *lock_print_options() = PrinterOptions::const_default()
}

/// Compact preset: 2 digits of precision and 2 edge items; everything else at its default.
pub fn set_print_options_short() {
    *lock_print_options() = PrinterOptions {
        precision: 2,
        edge_items: 2,
        ..PrinterOptions::const_default()
    }
}

/// Full preset: the threshold is `usize::MAX`, so tensors are never summarized; everything
/// else is at its default.
pub fn set_print_options_full() {
    *lock_print_options() = PrinterOptions {
        threshold: usize::MAX,
        ..PrinterOptions::const_default()
    }
}

/// Sets the approximate line width used to wrap rows.
pub fn set_line_width(line_width: usize) {
    lock_print_options().line_width = line_width
}

/// Sets the number of digits printed after the decimal point.
pub fn set_precision(precision: usize) {
    lock_print_options().precision = precision
}

/// Sets how many leading/trailing entries are kept when summarizing.
pub fn set_edge_items(edge_items: usize) {
    lock_print_options().edge_items = edge_items
}

/// Sets the element count above which tensors are summarized.
pub fn set_threshold(threshold: usize) {
    lock_print_options().threshold = threshold
}

/// Forces scientific notation on/off, or restores the automatic choice with `None`.
pub fn set_sci_mode(sci_mode: Option<bool>) {
    lock_print_options().sci_mode = sci_mode
}
159}
160
161pub fn set_sci_mode(sci_mode: Option<bool>) {
162    PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
163}
164
165struct FmtSize {
166    current_size: usize,
167}
168
169impl FmtSize {
170    fn new() -> Self {
171        Self { current_size: 0 }
172    }
173
174    fn final_size(self) -> usize {
175        self.current_size
176    }
177}
178
179impl std::fmt::Write for FmtSize {
180    fn write_str(&mut self, s: &str) -> std::fmt::Result {
181        self.current_size += s.len();
182        Ok(())
183    }
184}
185
186trait TensorFormatter {
187    type Elem: WithDType;
188
189    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
190
191    fn max_width(&self, to_display: &Tensor) -> usize {
192        let mut max_width = 1;
193        if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
194            for &v in vs.iter() {
195                let mut fmt_size = FmtSize::new();
196                let _res = self.fmt(v, 1, &mut fmt_size);
197                max_width = usize::max(max_width, fmt_size.final_size())
198            }
199        }
200        max_width
201    }
202
203    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
204        writeln!(f)?;
205        for _ in 0..i {
206            write!(f, " ")?
207        }
208        Ok(())
209    }
210
211    fn fmt_tensor(
212        &self,
213        t: &Tensor,
214        indent: usize,
215        max_w: usize,
216        summarize: bool,
217        po: &PrinterOptions,
218        f: &mut std::fmt::Formatter,
219    ) -> std::fmt::Result {
220        let dims = t.dims();
221        let edge_items = po.edge_items;
222        write!(f, "[")?;
223        match dims {
224            [] => {
225                if let Ok(v) = t.to_scalar::<Self::Elem>() {
226                    self.fmt(v, max_w, f)?
227                }
228            }
229            [v] if summarize && *v > 2 * edge_items => {
230                if let Ok(vs) = t
231                    .narrow(0, 0, edge_items)
232                    .and_then(|t| t.to_vec1::<Self::Elem>())
233                {
234                    for v in vs.into_iter() {
235                        self.fmt(v, max_w, f)?;
236                        write!(f, ", ")?;
237                    }
238                }
239                write!(f, "...")?;
240                if let Ok(vs) = t
241                    .narrow(0, v - edge_items, edge_items)
242                    .and_then(|t| t.to_vec1::<Self::Elem>())
243                {
244                    for v in vs.into_iter() {
245                        write!(f, ", ")?;
246                        self.fmt(v, max_w, f)?;
247                    }
248                }
249            }
250            [_] => {
251                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
252                if let Ok(vs) = t.to_vec1::<Self::Elem>() {
253                    for (i, v) in vs.into_iter().enumerate() {
254                        if i > 0 {
255                            if i % elements_per_line == 0 {
256                                write!(f, ",")?;
257                                Self::write_newline_indent(indent, f)?
258                            } else {
259                                write!(f, ", ")?;
260                            }
261                        }
262                        self.fmt(v, max_w, f)?
263                    }
264                }
265            }
266            _ => {
267                if summarize && dims[0] > 2 * edge_items {
268                    for i in 0..edge_items {
269                        match t.get(i) {
270                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
271                            Err(e) => write!(f, "{e:?}")?,
272                        }
273                        write!(f, ",")?;
274                        Self::write_newline_indent(indent, f)?
275                    }
276                    write!(f, "...")?;
277                    Self::write_newline_indent(indent, f)?;
278                    for i in dims[0] - edge_items..dims[0] {
279                        match t.get(i) {
280                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
281                            Err(e) => write!(f, "{e:?}")?,
282                        }
283                        if i + 1 != dims[0] {
284                            write!(f, ",")?;
285                            Self::write_newline_indent(indent, f)?
286                        }
287                    }
288                } else {
289                    for i in 0..dims[0] {
290                        match t.get(i) {
291                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
292                            Err(e) => write!(f, "{e:?}")?,
293                        }
294                        if i + 1 != dims[0] {
295                            write!(f, ",")?;
296                            Self::write_newline_indent(indent, f)?
297                        }
298                    }
299                }
300            }
301        }
302        write!(f, "]")?;
303        Ok(())
304    }
305}
306
307struct FloatFormatter<S: WithDType> {
308    int_mode: bool,
309    sci_mode: bool,
310    precision: usize,
311    _phantom: std::marker::PhantomData<S>,
312}
313
314impl<S> FloatFormatter<S>
315where
316    S: WithDType + num_traits::Float + std::fmt::Display,
317{
318    fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
319        let mut int_mode = true;
320        let mut sci_mode = false;
321
322        // Rather than containing all values, this should only include
323        // values that end up being displayed according to [threshold].
324        let values = t
325            .flatten_all()?
326            .to_vec1()?
327            .into_iter()
328            .filter(|v: &S| v.is_finite() && !v.is_zero())
329            .collect::<Vec<_>>();
330        if !values.is_empty() {
331            let mut nonzero_finite_min = S::max_value();
332            let mut nonzero_finite_max = S::min_value();
333            for &v in values.iter() {
334                let v = v.abs();
335                if v < nonzero_finite_min {
336                    nonzero_finite_min = v
337                }
338                if v > nonzero_finite_max {
339                    nonzero_finite_max = v
340                }
341            }
342
343            for &value in values.iter() {
344                if value.ceil() != value {
345                    int_mode = false;
346                    break;
347                }
348            }
349            if let Some(v1) = S::from(1000.) {
350                if let Some(v2) = S::from(1e8) {
351                    if let Some(v3) = S::from(1e-4) {
352                        sci_mode = nonzero_finite_max / nonzero_finite_min > v1
353                            || nonzero_finite_max > v2
354                            || nonzero_finite_min < v3
355                    }
356                }
357            }
358        }
359
360        match po.sci_mode {
361            None => {}
362            Some(v) => sci_mode = v,
363        }
364        Ok(Self {
365            int_mode,
366            sci_mode,
367            precision: po.precision,
368            _phantom: std::marker::PhantomData,
369        })
370    }
371}
372
373impl<S> TensorFormatter for FloatFormatter<S>
374where
375    S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
376{
377    type Elem = S;
378
379    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
380        if self.sci_mode {
381            write!(
382                f,
383                "{v:width$.prec$e}",
384                v = v,
385                width = max_w,
386                prec = self.precision
387            )
388        } else if self.int_mode {
389            if v.is_finite() {
390                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
391            } else {
392                write!(f, "{v:max_w$.0}")
393            }
394        } else {
395            write!(
396                f,
397                "{v:width$.prec$}",
398                v = v,
399                width = max_w,
400                prec = self.precision
401            )
402        }
403    }
404}
405
406struct IntFormatter<S: WithDType> {
407    _phantom: std::marker::PhantomData<S>,
408}
409
410impl<S: WithDType> IntFormatter<S> {
411    fn new() -> Self {
412        Self {
413            _phantom: std::marker::PhantomData,
414        }
415    }
416}
417
418impl<S> TensorFormatter for IntFormatter<S>
419where
420    S: WithDType + std::fmt::Display,
421{
422    type Elem = S;
423
424    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
425        write!(f, "{v:max_w$}")
426    }
427}
428
429fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
430    let dims = t.dims();
431    if dims.is_empty() {
432        Ok(t.clone())
433    } else if dims.len() == 1 {
434        if dims[0] > 2 * edge_items {
435            Tensor::cat(
436                &[
437                    t.narrow(0, 0, edge_items)?,
438                    t.narrow(0, dims[0] - edge_items, edge_items)?,
439                ],
440                0,
441            )
442        } else {
443            Ok(t.clone())
444        }
445    } else if dims[0] > 2 * edge_items {
446        let mut vs: Vec<_> = (0..edge_items)
447            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
448            .collect::<Result<Vec<_>>>()?;
449        for i in (dims[0] - edge_items)..dims[0] {
450            vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
451        }
452        Tensor::cat(&vs, 0)
453    } else {
454        let vs: Vec<_> = (0..dims[0])
455            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
456            .collect::<Result<Vec<_>>>()?;
457        Tensor::cat(&vs, 0)
458    }
459}
460
461impl std::fmt::Display for Tensor {
462    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
463        let po = PRINT_OPTS.lock().unwrap();
464        let summarize = self.elem_count() > po.threshold;
465        let to_display = if summarize {
466            match get_summarized_data(self, po.edge_items) {
467                Ok(v) => v,
468                Err(err) => return write!(f, "{err:?}"),
469            }
470        } else {
471            self.clone()
472        };
473        match self.dtype() {
474            DType::U8 => {
475                let tf: IntFormatter<u8> = IntFormatter::new();
476                let max_w = tf.max_width(&to_display);
477                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
478                writeln!(f)?;
479            }
480            DType::U32 => {
481                let tf: IntFormatter<u32> = IntFormatter::new();
482                let max_w = tf.max_width(&to_display);
483                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
484                writeln!(f)?;
485            }
486            DType::I16 => {
487                let tf: IntFormatter<i16> = IntFormatter::new();
488                let max_w = tf.max_width(&to_display);
489                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
490                writeln!(f)?;
491            }
492            DType::I32 => {
493                let tf: IntFormatter<i32> = IntFormatter::new();
494                let max_w = tf.max_width(&to_display);
495                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
496                writeln!(f)?;
497            }
498            DType::I64 => {
499                let tf: IntFormatter<i64> = IntFormatter::new();
500                let max_w = tf.max_width(&to_display);
501                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
502                writeln!(f)?;
503            }
504            DType::BF16 => {
505                if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
506                    let max_w = tf.max_width(&to_display);
507                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
508                    writeln!(f)?;
509                }
510            }
511            DType::F16 => {
512                if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
513                    let max_w = tf.max_width(&to_display);
514                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
515                    writeln!(f)?;
516                }
517            }
518            DType::F64 => {
519                if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
520                    let max_w = tf.max_width(&to_display);
521                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
522                    writeln!(f)?;
523                }
524            }
525            DType::F32 => {
526                if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
527                    let max_w = tf.max_width(&to_display);
528                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
529                    writeln!(f)?;
530                }
531            }
532            DType::F8E4M3 => {
533                if let Ok(tf) = FloatFormatter::<float8::F8E4M3>::new(&to_display, &po) {
534                    let max_w = tf.max_width(&to_display);
535                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
536                    writeln!(f)?;
537                }
538            }
539            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
540                writeln!(
541                    f,
542                    "Dummy type {} (not supported for display)",
543                    self.dtype().as_str()
544                )?;
545            }
546        };
547
548        let device_str = match self.device().location() {
549            crate::DeviceLocation::Cpu => "".to_owned(),
550            crate::DeviceLocation::Cuda { gpu_id } => {
551                format!(", cuda:{gpu_id}")
552            }
553            crate::DeviceLocation::Metal { gpu_id } => {
554                format!(", metal:{gpu_id}")
555            }
556            #[cfg(feature = "rocm")]
557            crate::DeviceLocation::Rocm { gpu_id } => {
558                format!(", rocm:{gpu_id}")
559            }
560            #[cfg(feature = "vulkan")]
561            crate::DeviceLocation::Vulkan { gpu_id } => {
562                format!(", vulkan:{gpu_id}")
563            }
564        };
565
566        write!(
567            f,
568            "Tensor[{:?}, {}{}]",
569            self.dims(),
570            self.dtype().as_str(),
571            device_str
572        )
573    }
574}