candle_core/
display.rs

1//! Pretty printing of tensors
2//!
3//! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py).
4//!
5use crate::{DType, Result, Tensor, WithDType};
6use half::{bf16, f16};
7
8impl Tensor {
9    fn fmt_dt<T: WithDType + std::fmt::Display>(
10        &self,
11        f: &mut std::fmt::Formatter,
12    ) -> std::fmt::Result {
13        let device_str = match self.device().location() {
14            crate::DeviceLocation::Cpu => "".to_owned(),
15            crate::DeviceLocation::Cuda { gpu_id } => {
16                format!(", cuda:{gpu_id}")
17            }
18            crate::DeviceLocation::Metal { gpu_id } => {
19                format!(", metal:{gpu_id}")
20            }
21        };
22
23        write!(f, "Tensor[")?;
24        match self.dims() {
25            [] => {
26                if let Ok(v) = self.to_scalar::<T>() {
27                    write!(f, "{v}")?
28                }
29            }
30            [s] if *s < 10 => {
31                if let Ok(vs) = self.to_vec1::<T>() {
32                    for (i, v) in vs.iter().enumerate() {
33                        if i > 0 {
34                            write!(f, ", ")?;
35                        }
36                        write!(f, "{v}")?;
37                    }
38                }
39            }
40            dims => {
41                write!(f, "dims ")?;
42                for (i, d) in dims.iter().enumerate() {
43                    if i > 0 {
44                        write!(f, ", ")?;
45                    }
46                    write!(f, "{d}")?;
47                }
48            }
49        }
50        write!(f, "; {}{}]", self.dtype().as_str(), device_str)
51    }
52}
53
54impl std::fmt::Debug for Tensor {
55    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
56        match self.dtype() {
57            DType::U8 => self.fmt_dt::<u8>(f),
58            DType::U32 => self.fmt_dt::<u32>(f),
59            DType::I16 => self.fmt_dt::<i16>(f),
60            DType::I32 => self.fmt_dt::<i32>(f),
61            DType::I64 => self.fmt_dt::<i64>(f),
62            DType::BF16 => self.fmt_dt::<bf16>(f),
63            DType::F16 => self.fmt_dt::<f16>(f),
64            DType::F32 => self.fmt_dt::<f32>(f),
65            DType::F64 => self.fmt_dt::<f64>(f),
66            DType::F8E4M3 => self.fmt_dt::<float8::F8E4M3>(f),
67            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
68                write!(
69                    f,
70                    "Tensor[{:?}; dtype={}, unsupported dummy type]",
71                    self.shape(),
72                    self.dtype().as_str()
73                )
74            }
75        }
76    }
77}
78
/// Options for Tensor pretty printing
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    /// Number of digits printed after the decimal point for floating point
    /// values (also used for the mantissa in scientific notation).
    pub precision: usize,
    /// Tensors with more elements than this are summarized when displayed.
    pub threshold: usize,
    /// Number of leading and trailing items kept along each dimension when a
    /// tensor is summarized.
    pub edge_items: usize,
    /// Target line width, used to decide how many elements of a 1D tensor are
    /// written per line.
    pub line_width: usize,
    /// `Some(true)` / `Some(false)` force scientific notation on or off for
    /// floating point values, `None` lets the formatter decide from the data.
    pub sci_mode: Option<bool>,
}

/// Process-wide printing options read by the `Display` implementation.
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());

impl PrinterOptions {
    /// Default options, usable in `const` context.
    ///
    /// `Default::default` is not `const`, so the `PRINT_OPTS` static
    /// initializer needs this dedicated constructor.
    const fn const_default() -> Self {
        Self {
            precision: 4,
            threshold: 1000,
            edge_items: 3,
            line_width: 80,
            sci_mode: None,
        }
    }
}

impl Default for PrinterOptions {
    /// Same values as the initial global options, which makes struct update
    /// syntax (`PrinterOptions { precision: 2, ..Default::default() }`)
    /// available to callers.
    fn default() -> Self {
        Self::const_default()
    }
}
104
105pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
106    &PRINT_OPTS
107}
108
109pub fn set_print_options(options: PrinterOptions) {
110    *PRINT_OPTS.lock().unwrap() = options
111}
112
113pub fn set_print_options_default() {
114    *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
115}
116
117pub fn set_print_options_short() {
118    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
119        precision: 2,
120        threshold: 1000,
121        edge_items: 2,
122        line_width: 80,
123        sci_mode: None,
124    }
125}
126
127pub fn set_print_options_full() {
128    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
129        precision: 4,
130        threshold: usize::MAX,
131        edge_items: 3,
132        line_width: 80,
133        sci_mode: None,
134    }
135}
136
137pub fn set_line_width(line_width: usize) {
138    PRINT_OPTS.lock().unwrap().line_width = line_width
139}
140
141pub fn set_precision(precision: usize) {
142    PRINT_OPTS.lock().unwrap().precision = precision
143}
144
145pub fn set_edge_items(edge_items: usize) {
146    PRINT_OPTS.lock().unwrap().edge_items = edge_items
147}
148
149pub fn set_threshold(threshold: usize) {
150    PRINT_OPTS.lock().unwrap().threshold = threshold
151}
152
153pub fn set_sci_mode(sci_mode: Option<bool>) {
154    PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
155}
156
/// A `std::fmt::Write` sink that throws the text away and only keeps track of
/// how many bytes were written to it.
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    /// Creates an empty counter.
    fn new() -> Self {
        FmtSize { current_size: 0 }
    }

    /// Consumes the counter and returns the number of bytes written so far.
    fn final_size(self) -> usize {
        let FmtSize { current_size } = self;
        current_size
    }
}

impl std::fmt::Write for FmtSize {
    /// Accounts for the byte length of `s`; never fails.
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        let written = s.len();
        self.current_size += written;
        Ok(())
    }
}
177
178trait TensorFormatter {
179    type Elem: WithDType;
180
181    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
182
183    fn max_width(&self, to_display: &Tensor) -> usize {
184        let mut max_width = 1;
185        if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
186            for &v in vs.iter() {
187                let mut fmt_size = FmtSize::new();
188                let _res = self.fmt(v, 1, &mut fmt_size);
189                max_width = usize::max(max_width, fmt_size.final_size())
190            }
191        }
192        max_width
193    }
194
195    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
196        writeln!(f)?;
197        for _ in 0..i {
198            write!(f, " ")?
199        }
200        Ok(())
201    }
202
203    fn fmt_tensor(
204        &self,
205        t: &Tensor,
206        indent: usize,
207        max_w: usize,
208        summarize: bool,
209        po: &PrinterOptions,
210        f: &mut std::fmt::Formatter,
211    ) -> std::fmt::Result {
212        let dims = t.dims();
213        let edge_items = po.edge_items;
214        write!(f, "[")?;
215        match dims {
216            [] => {
217                if let Ok(v) = t.to_scalar::<Self::Elem>() {
218                    self.fmt(v, max_w, f)?
219                }
220            }
221            [v] if summarize && *v > 2 * edge_items => {
222                if let Ok(vs) = t
223                    .narrow(0, 0, edge_items)
224                    .and_then(|t| t.to_vec1::<Self::Elem>())
225                {
226                    for v in vs.into_iter() {
227                        self.fmt(v, max_w, f)?;
228                        write!(f, ", ")?;
229                    }
230                }
231                write!(f, "...")?;
232                if let Ok(vs) = t
233                    .narrow(0, v - edge_items, edge_items)
234                    .and_then(|t| t.to_vec1::<Self::Elem>())
235                {
236                    for v in vs.into_iter() {
237                        write!(f, ", ")?;
238                        self.fmt(v, max_w, f)?;
239                    }
240                }
241            }
242            [_] => {
243                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
244                if let Ok(vs) = t.to_vec1::<Self::Elem>() {
245                    for (i, v) in vs.into_iter().enumerate() {
246                        if i > 0 {
247                            if i % elements_per_line == 0 {
248                                write!(f, ",")?;
249                                Self::write_newline_indent(indent, f)?
250                            } else {
251                                write!(f, ", ")?;
252                            }
253                        }
254                        self.fmt(v, max_w, f)?
255                    }
256                }
257            }
258            _ => {
259                if summarize && dims[0] > 2 * edge_items {
260                    for i in 0..edge_items {
261                        match t.get(i) {
262                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
263                            Err(e) => write!(f, "{e:?}")?,
264                        }
265                        write!(f, ",")?;
266                        Self::write_newline_indent(indent, f)?
267                    }
268                    write!(f, "...")?;
269                    Self::write_newline_indent(indent, f)?;
270                    for i in dims[0] - edge_items..dims[0] {
271                        match t.get(i) {
272                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
273                            Err(e) => write!(f, "{e:?}")?,
274                        }
275                        if i + 1 != dims[0] {
276                            write!(f, ",")?;
277                            Self::write_newline_indent(indent, f)?
278                        }
279                    }
280                } else {
281                    for i in 0..dims[0] {
282                        match t.get(i) {
283                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
284                            Err(e) => write!(f, "{e:?}")?,
285                        }
286                        if i + 1 != dims[0] {
287                            write!(f, ",")?;
288                            Self::write_newline_indent(indent, f)?
289                        }
290                    }
291                }
292            }
293        }
294        write!(f, "]")?;
295        Ok(())
296    }
297}
298
299struct FloatFormatter<S: WithDType> {
300    int_mode: bool,
301    sci_mode: bool,
302    precision: usize,
303    _phantom: std::marker::PhantomData<S>,
304}
305
306impl<S> FloatFormatter<S>
307where
308    S: WithDType + num_traits::Float + std::fmt::Display,
309{
310    fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
311        let mut int_mode = true;
312        let mut sci_mode = false;
313
314        // Rather than containing all values, this should only include
315        // values that end up being displayed according to [threshold].
316        let values = t
317            .flatten_all()?
318            .to_vec1()?
319            .into_iter()
320            .filter(|v: &S| v.is_finite() && !v.is_zero())
321            .collect::<Vec<_>>();
322        if !values.is_empty() {
323            let mut nonzero_finite_min = S::max_value();
324            let mut nonzero_finite_max = S::min_value();
325            for &v in values.iter() {
326                let v = v.abs();
327                if v < nonzero_finite_min {
328                    nonzero_finite_min = v
329                }
330                if v > nonzero_finite_max {
331                    nonzero_finite_max = v
332                }
333            }
334
335            for &value in values.iter() {
336                if value.ceil() != value {
337                    int_mode = false;
338                    break;
339                }
340            }
341            if let Some(v1) = S::from(1000.) {
342                if let Some(v2) = S::from(1e8) {
343                    if let Some(v3) = S::from(1e-4) {
344                        sci_mode = nonzero_finite_max / nonzero_finite_min > v1
345                            || nonzero_finite_max > v2
346                            || nonzero_finite_min < v3
347                    }
348                }
349            }
350        }
351
352        match po.sci_mode {
353            None => {}
354            Some(v) => sci_mode = v,
355        }
356        Ok(Self {
357            int_mode,
358            sci_mode,
359            precision: po.precision,
360            _phantom: std::marker::PhantomData,
361        })
362    }
363}
364
365impl<S> TensorFormatter for FloatFormatter<S>
366where
367    S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
368{
369    type Elem = S;
370
371    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
372        if self.sci_mode {
373            write!(
374                f,
375                "{v:width$.prec$e}",
376                v = v,
377                width = max_w,
378                prec = self.precision
379            )
380        } else if self.int_mode {
381            if v.is_finite() {
382                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
383            } else {
384                write!(f, "{v:max_w$.0}")
385            }
386        } else {
387            write!(
388                f,
389                "{v:width$.prec$}",
390                v = v,
391                width = max_w,
392                prec = self.precision
393            )
394        }
395    }
396}
397
398struct IntFormatter<S: WithDType> {
399    _phantom: std::marker::PhantomData<S>,
400}
401
402impl<S: WithDType> IntFormatter<S> {
403    fn new() -> Self {
404        Self {
405            _phantom: std::marker::PhantomData,
406        }
407    }
408}
409
410impl<S> TensorFormatter for IntFormatter<S>
411where
412    S: WithDType + std::fmt::Display,
413{
414    type Elem = S;
415
416    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
417        write!(f, "{v:max_w$}")
418    }
419}
420
421fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
422    let dims = t.dims();
423    if dims.is_empty() {
424        Ok(t.clone())
425    } else if dims.len() == 1 {
426        if dims[0] > 2 * edge_items {
427            Tensor::cat(
428                &[
429                    t.narrow(0, 0, edge_items)?,
430                    t.narrow(0, dims[0] - edge_items, edge_items)?,
431                ],
432                0,
433            )
434        } else {
435            Ok(t.clone())
436        }
437    } else if dims[0] > 2 * edge_items {
438        let mut vs: Vec<_> = (0..edge_items)
439            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
440            .collect::<Result<Vec<_>>>()?;
441        for i in (dims[0] - edge_items)..dims[0] {
442            vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
443        }
444        Tensor::cat(&vs, 0)
445    } else {
446        let vs: Vec<_> = (0..dims[0])
447            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
448            .collect::<Result<Vec<_>>>()?;
449        Tensor::cat(&vs, 0)
450    }
451}
452
453impl std::fmt::Display for Tensor {
454    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
455        let po = PRINT_OPTS.lock().unwrap();
456        let summarize = self.elem_count() > po.threshold;
457        let to_display = if summarize {
458            match get_summarized_data(self, po.edge_items) {
459                Ok(v) => v,
460                Err(err) => return write!(f, "{err:?}"),
461            }
462        } else {
463            self.clone()
464        };
465        match self.dtype() {
466            DType::U8 => {
467                let tf: IntFormatter<u8> = IntFormatter::new();
468                let max_w = tf.max_width(&to_display);
469                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
470                writeln!(f)?;
471            }
472            DType::U32 => {
473                let tf: IntFormatter<u32> = IntFormatter::new();
474                let max_w = tf.max_width(&to_display);
475                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
476                writeln!(f)?;
477            }
478            DType::I16 => {
479                let tf: IntFormatter<i16> = IntFormatter::new();
480                let max_w = tf.max_width(&to_display);
481                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
482                writeln!(f)?;
483            }
484            DType::I32 => {
485                let tf: IntFormatter<i32> = IntFormatter::new();
486                let max_w = tf.max_width(&to_display);
487                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
488                writeln!(f)?;
489            }
490            DType::I64 => {
491                let tf: IntFormatter<i64> = IntFormatter::new();
492                let max_w = tf.max_width(&to_display);
493                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
494                writeln!(f)?;
495            }
496            DType::BF16 => {
497                if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
498                    let max_w = tf.max_width(&to_display);
499                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
500                    writeln!(f)?;
501                }
502            }
503            DType::F16 => {
504                if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
505                    let max_w = tf.max_width(&to_display);
506                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
507                    writeln!(f)?;
508                }
509            }
510            DType::F64 => {
511                if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
512                    let max_w = tf.max_width(&to_display);
513                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
514                    writeln!(f)?;
515                }
516            }
517            DType::F32 => {
518                if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
519                    let max_w = tf.max_width(&to_display);
520                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
521                    writeln!(f)?;
522                }
523            }
524            DType::F8E4M3 => {
525                if let Ok(tf) = FloatFormatter::<float8::F8E4M3>::new(&to_display, &po) {
526                    let max_w = tf.max_width(&to_display);
527                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
528                    writeln!(f)?;
529                }
530            }
531            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
532                writeln!(
533                    f,
534                    "Dummy type {} (not supported for display)",
535                    self.dtype().as_str()
536                )?;
537            }
538        };
539
540        let device_str = match self.device().location() {
541            crate::DeviceLocation::Cpu => "".to_owned(),
542            crate::DeviceLocation::Cuda { gpu_id } => {
543                format!(", cuda:{gpu_id}")
544            }
545            crate::DeviceLocation::Metal { gpu_id } => {
546                format!(", metal:{gpu_id}")
547            }
548        };
549
550        write!(
551            f,
552            "Tensor[{:?}, {}{}]",
553            self.dims(),
554            self.dtype().as_str(),
555            device_str
556        )
557    }
558}