candle_core/
display.rs

1//! Pretty printing of tensors
2//!
3//! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py).
4//!
5use crate::{DType, Result, Tensor, WithDType};
6use half::{bf16, f16};
7
8impl Tensor {
9    fn fmt_dt<T: WithDType + std::fmt::Display>(
10        &self,
11        f: &mut std::fmt::Formatter,
12    ) -> std::fmt::Result {
13        let device_str = match self.device().location() {
14            crate::DeviceLocation::Cpu => "".to_owned(),
15            crate::DeviceLocation::Cuda { gpu_id } => {
16                format!(", cuda:{}", gpu_id)
17            }
18            crate::DeviceLocation::Metal { gpu_id } => {
19                format!(", metal:{}", gpu_id)
20            }
21        };
22
23        write!(f, "Tensor[")?;
24        match self.dims() {
25            [] => {
26                if let Ok(v) = self.to_scalar::<T>() {
27                    write!(f, "{v}")?
28                }
29            }
30            [s] if *s < 10 => {
31                if let Ok(vs) = self.to_vec1::<T>() {
32                    for (i, v) in vs.iter().enumerate() {
33                        if i > 0 {
34                            write!(f, ", ")?;
35                        }
36                        write!(f, "{v}")?;
37                    }
38                }
39            }
40            dims => {
41                write!(f, "dims ")?;
42                for (i, d) in dims.iter().enumerate() {
43                    if i > 0 {
44                        write!(f, ", ")?;
45                    }
46                    write!(f, "{d}")?;
47                }
48            }
49        }
50        write!(f, "; {}{}]", self.dtype().as_str(), device_str)
51    }
52}
53
54impl std::fmt::Debug for Tensor {
55    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
56        match self.dtype() {
57            DType::U8 => self.fmt_dt::<u8>(f),
58            DType::U32 => self.fmt_dt::<u32>(f),
59            DType::I64 => self.fmt_dt::<i64>(f),
60            DType::BF16 => self.fmt_dt::<bf16>(f),
61            DType::F16 => self.fmt_dt::<f16>(f),
62            DType::F32 => self.fmt_dt::<f32>(f),
63            DType::F64 => self.fmt_dt::<f64>(f),
64        }
65    }
66}
67
/// Options for Tensor pretty printing, as used by the `Display` implementation of `Tensor`.
///
/// The process-wide values are reachable through [`print_options`] and can be changed with the
/// `set_*` functions of this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrinterOptions {
    /// Number of digits printed after the decimal point for floating-point values.
    pub precision: usize,
    /// Tensors with more elements than this are summarized.
    pub threshold: usize,
    /// Number of entries kept at the start and at the end of each summarized dimension.
    pub edge_items: usize,
    /// Line width used to decide how many values are printed on a single line.
    pub line_width: usize,
    /// `Some(true)` forces and `Some(false)` disables scientific notation for floats;
    /// `None` lets the formatter decide from the values.
    pub sci_mode: Option<bool>,
}

/// Process-wide printer options read by the `Display` implementation of `Tensor`.
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());

impl PrinterOptions {
    // `Default::default` is not a `const fn`, so the static initializer above needs this
    // const constructor.
    const fn const_default() -> Self {
        Self {
            precision: 4,
            threshold: 1000,
            edge_items: 3,
            line_width: 80,
            sci_mode: None,
        }
    }
}

impl Default for PrinterOptions {
    /// The same values the global options start with: precision 4, threshold 1000,
    /// 3 edge items, line width 80 and automatic scientific notation.
    ///
    /// Handy with struct-update syntax, e.g. `PrinterOptions { precision: 2, ..Default::default() }`.
    fn default() -> Self {
        Self::const_default()
    }
}
93
94pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
95    &PRINT_OPTS
96}
97
98pub fn set_print_options(options: PrinterOptions) {
99    *PRINT_OPTS.lock().unwrap() = options
100}
101
102pub fn set_print_options_default() {
103    *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
104}
105
106pub fn set_print_options_short() {
107    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
108        precision: 2,
109        threshold: 1000,
110        edge_items: 2,
111        line_width: 80,
112        sci_mode: None,
113    }
114}
115
116pub fn set_print_options_full() {
117    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
118        precision: 4,
119        threshold: usize::MAX,
120        edge_items: 3,
121        line_width: 80,
122        sci_mode: None,
123    }
124}
125
126pub fn set_line_width(line_width: usize) {
127    PRINT_OPTS.lock().unwrap().line_width = line_width
128}
129
130pub fn set_precision(precision: usize) {
131    PRINT_OPTS.lock().unwrap().precision = precision
132}
133
134pub fn set_edge_items(edge_items: usize) {
135    PRINT_OPTS.lock().unwrap().edge_items = edge_items
136}
137
138pub fn set_threshold(threshold: usize) {
139    PRINT_OPTS.lock().unwrap().threshold = threshold
140}
141
142pub fn set_sci_mode(sci_mode: Option<bool>) {
143    PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
144}
145
/// A `std::fmt::Write` sink that throws the text away and only counts its length in bytes,
/// used to measure how wide a value is once formatted.
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    fn new() -> Self {
        FmtSize { current_size: 0 }
    }

    /// Consumes the counter and returns the total number of bytes written into it.
    fn final_size(self) -> usize {
        let FmtSize { current_size } = self;
        current_size
    }
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        let written = s.len();
        self.current_size += written;
        Ok(())
    }
}
166
167trait TensorFormatter {
168    type Elem: WithDType;
169
170    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
171
172    fn max_width(&self, to_display: &Tensor) -> usize {
173        let mut max_width = 1;
174        if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
175            for &v in vs.iter() {
176                let mut fmt_size = FmtSize::new();
177                let _res = self.fmt(v, 1, &mut fmt_size);
178                max_width = usize::max(max_width, fmt_size.final_size())
179            }
180        }
181        max_width
182    }
183
184    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
185        writeln!(f)?;
186        for _ in 0..i {
187            write!(f, " ")?
188        }
189        Ok(())
190    }
191
192    fn fmt_tensor(
193        &self,
194        t: &Tensor,
195        indent: usize,
196        max_w: usize,
197        summarize: bool,
198        po: &PrinterOptions,
199        f: &mut std::fmt::Formatter,
200    ) -> std::fmt::Result {
201        let dims = t.dims();
202        let edge_items = po.edge_items;
203        write!(f, "[")?;
204        match dims {
205            [] => {
206                if let Ok(v) = t.to_scalar::<Self::Elem>() {
207                    self.fmt(v, max_w, f)?
208                }
209            }
210            [v] if summarize && *v > 2 * edge_items => {
211                if let Ok(vs) = t
212                    .narrow(0, 0, edge_items)
213                    .and_then(|t| t.to_vec1::<Self::Elem>())
214                {
215                    for v in vs.into_iter() {
216                        self.fmt(v, max_w, f)?;
217                        write!(f, ", ")?;
218                    }
219                }
220                write!(f, "...")?;
221                if let Ok(vs) = t
222                    .narrow(0, v - edge_items, edge_items)
223                    .and_then(|t| t.to_vec1::<Self::Elem>())
224                {
225                    for v in vs.into_iter() {
226                        write!(f, ", ")?;
227                        self.fmt(v, max_w, f)?;
228                    }
229                }
230            }
231            [_] => {
232                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
233                if let Ok(vs) = t.to_vec1::<Self::Elem>() {
234                    for (i, v) in vs.into_iter().enumerate() {
235                        if i > 0 {
236                            if i % elements_per_line == 0 {
237                                write!(f, ",")?;
238                                Self::write_newline_indent(indent, f)?
239                            } else {
240                                write!(f, ", ")?;
241                            }
242                        }
243                        self.fmt(v, max_w, f)?
244                    }
245                }
246            }
247            _ => {
248                if summarize && dims[0] > 2 * edge_items {
249                    for i in 0..edge_items {
250                        match t.get(i) {
251                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
252                            Err(e) => write!(f, "{e:?}")?,
253                        }
254                        write!(f, ",")?;
255                        Self::write_newline_indent(indent, f)?
256                    }
257                    write!(f, "...")?;
258                    Self::write_newline_indent(indent, f)?;
259                    for i in dims[0] - edge_items..dims[0] {
260                        match t.get(i) {
261                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
262                            Err(e) => write!(f, "{e:?}")?,
263                        }
264                        if i + 1 != dims[0] {
265                            write!(f, ",")?;
266                            Self::write_newline_indent(indent, f)?
267                        }
268                    }
269                } else {
270                    for i in 0..dims[0] {
271                        match t.get(i) {
272                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
273                            Err(e) => write!(f, "{e:?}")?,
274                        }
275                        if i + 1 != dims[0] {
276                            write!(f, ",")?;
277                            Self::write_newline_indent(indent, f)?
278                        }
279                    }
280                }
281            }
282        }
283        write!(f, "]")?;
284        Ok(())
285    }
286}
287
288struct FloatFormatter<S: WithDType> {
289    int_mode: bool,
290    sci_mode: bool,
291    precision: usize,
292    _phantom: std::marker::PhantomData<S>,
293}
294
295impl<S> FloatFormatter<S>
296where
297    S: WithDType + num_traits::Float + std::fmt::Display,
298{
299    fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
300        let mut int_mode = true;
301        let mut sci_mode = false;
302
303        // Rather than containing all values, this should only include
304        // values that end up being displayed according to [threshold].
305        let values = t
306            .flatten_all()?
307            .to_vec1()?
308            .into_iter()
309            .filter(|v: &S| v.is_finite() && !v.is_zero())
310            .collect::<Vec<_>>();
311        if !values.is_empty() {
312            let mut nonzero_finite_min = S::max_value();
313            let mut nonzero_finite_max = S::min_value();
314            for &v in values.iter() {
315                let v = v.abs();
316                if v < nonzero_finite_min {
317                    nonzero_finite_min = v
318                }
319                if v > nonzero_finite_max {
320                    nonzero_finite_max = v
321                }
322            }
323
324            for &value in values.iter() {
325                if value.ceil() != value {
326                    int_mode = false;
327                    break;
328                }
329            }
330            if let Some(v1) = S::from(1000.) {
331                if let Some(v2) = S::from(1e8) {
332                    if let Some(v3) = S::from(1e-4) {
333                        sci_mode = nonzero_finite_max / nonzero_finite_min > v1
334                            || nonzero_finite_max > v2
335                            || nonzero_finite_min < v3
336                    }
337                }
338            }
339        }
340
341        match po.sci_mode {
342            None => {}
343            Some(v) => sci_mode = v,
344        }
345        Ok(Self {
346            int_mode,
347            sci_mode,
348            precision: po.precision,
349            _phantom: std::marker::PhantomData,
350        })
351    }
352}
353
354impl<S> TensorFormatter for FloatFormatter<S>
355where
356    S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
357{
358    type Elem = S;
359
360    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
361        if self.sci_mode {
362            write!(
363                f,
364                "{v:width$.prec$e}",
365                v = v,
366                width = max_w,
367                prec = self.precision
368            )
369        } else if self.int_mode {
370            if v.is_finite() {
371                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
372            } else {
373                write!(f, "{v:max_w$.0}")
374            }
375        } else {
376            write!(
377                f,
378                "{v:width$.prec$}",
379                v = v,
380                width = max_w,
381                prec = self.precision
382            )
383        }
384    }
385}
386
387struct IntFormatter<S: WithDType> {
388    _phantom: std::marker::PhantomData<S>,
389}
390
391impl<S: WithDType> IntFormatter<S> {
392    fn new() -> Self {
393        Self {
394            _phantom: std::marker::PhantomData,
395        }
396    }
397}
398
399impl<S> TensorFormatter for IntFormatter<S>
400where
401    S: WithDType + std::fmt::Display,
402{
403    type Elem = S;
404
405    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
406        write!(f, "{v:max_w$}")
407    }
408}
409
410fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
411    let dims = t.dims();
412    if dims.is_empty() {
413        Ok(t.clone())
414    } else if dims.len() == 1 {
415        if dims[0] > 2 * edge_items {
416            Tensor::cat(
417                &[
418                    t.narrow(0, 0, edge_items)?,
419                    t.narrow(0, dims[0] - edge_items, edge_items)?,
420                ],
421                0,
422            )
423        } else {
424            Ok(t.clone())
425        }
426    } else if dims[0] > 2 * edge_items {
427        let mut vs: Vec<_> = (0..edge_items)
428            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
429            .collect::<Result<Vec<_>>>()?;
430        for i in (dims[0] - edge_items)..dims[0] {
431            vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
432        }
433        Tensor::cat(&vs, 0)
434    } else {
435        let vs: Vec<_> = (0..dims[0])
436            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
437            .collect::<Result<Vec<_>>>()?;
438        Tensor::cat(&vs, 0)
439    }
440}
441
442impl std::fmt::Display for Tensor {
443    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
444        let po = PRINT_OPTS.lock().unwrap();
445        let summarize = self.elem_count() > po.threshold;
446        let to_display = if summarize {
447            match get_summarized_data(self, po.edge_items) {
448                Ok(v) => v,
449                Err(err) => return write!(f, "{err:?}"),
450            }
451        } else {
452            self.clone()
453        };
454        match self.dtype() {
455            DType::U8 => {
456                let tf: IntFormatter<u8> = IntFormatter::new();
457                let max_w = tf.max_width(&to_display);
458                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
459                writeln!(f)?;
460            }
461            DType::U32 => {
462                let tf: IntFormatter<u32> = IntFormatter::new();
463                let max_w = tf.max_width(&to_display);
464                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
465                writeln!(f)?;
466            }
467            DType::I64 => {
468                let tf: IntFormatter<i64> = IntFormatter::new();
469                let max_w = tf.max_width(&to_display);
470                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
471                writeln!(f)?;
472            }
473            DType::BF16 => {
474                if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
475                    let max_w = tf.max_width(&to_display);
476                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
477                    writeln!(f)?;
478                }
479            }
480            DType::F16 => {
481                if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
482                    let max_w = tf.max_width(&to_display);
483                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
484                    writeln!(f)?;
485                }
486            }
487            DType::F64 => {
488                if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
489                    let max_w = tf.max_width(&to_display);
490                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
491                    writeln!(f)?;
492                }
493            }
494            DType::F32 => {
495                if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
496                    let max_w = tf.max_width(&to_display);
497                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
498                    writeln!(f)?;
499                }
500            }
501        };
502
503        let device_str = match self.device().location() {
504            crate::DeviceLocation::Cpu => "".to_owned(),
505            crate::DeviceLocation::Cuda { gpu_id } => {
506                format!(", cuda:{}", gpu_id)
507            }
508            crate::DeviceLocation::Metal { gpu_id } => {
509                format!(", metal:{}", gpu_id)
510            }
511        };
512
513        write!(
514            f,
515            "Tensor[{:?}, {}{}]",
516            self.dims(),
517            self.dtype().as_str(),
518            device_str
519        )
520    }
521}