// candle_core_temp/display.rs

1/// Pretty printing of tensors
2/// This implementation should be in line with the PyTorch version.
3/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py
4use crate::{DType, Result, Tensor, WithDType};
5use half::{bf16, f16};
6
7impl Tensor {
8    fn fmt_dt<T: WithDType + std::fmt::Display>(
9        &self,
10        f: &mut std::fmt::Formatter,
11    ) -> std::fmt::Result {
12        let device_str = match self.device().location() {
13            crate::DeviceLocation::Cpu => "".to_owned(),
14            crate::DeviceLocation::Cuda { gpu_id } => {
15                format!(", cuda:{}", gpu_id)
16            }
17        };
18
19        write!(f, "Tensor[")?;
20        match self.dims() {
21            [] => {
22                if let Ok(v) = self.to_scalar::<T>() {
23                    write!(f, "{v}")?
24                }
25            }
26            [s] if *s < 10 => {
27                if let Ok(vs) = self.to_vec1::<T>() {
28                    for (i, v) in vs.iter().enumerate() {
29                        if i > 0 {
30                            write!(f, ", ")?;
31                        }
32                        write!(f, "{v}")?;
33                    }
34                }
35            }
36            dims => {
37                write!(f, "dims ")?;
38                for (i, d) in dims.iter().enumerate() {
39                    if i > 0 {
40                        write!(f, ", ")?;
41                    }
42                    write!(f, "{d}")?;
43                }
44            }
45        }
46        write!(f, "; {}{}]", self.dtype().as_str(), device_str)
47    }
48}
49
impl std::fmt::Debug for Tensor {
    // Exhaustive dispatch on dtype so `fmt_dt` is monomorphized with the
    // matching element type; adding a new `DType` variant fails to compile
    // here until it is handled.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self.dtype() {
            DType::U8 => self.fmt_dt::<u8>(f),
            DType::U32 => self.fmt_dt::<u32>(f),
            DType::I64 => self.fmt_dt::<i64>(f),
            DType::BF16 => self.fmt_dt::<bf16>(f),
            DType::F16 => self.fmt_dt::<f16>(f),
            DType::F32 => self.fmt_dt::<f32>(f),
            DType::F64 => self.fmt_dt::<f64>(f),
        }
    }
}
63
/// Options for Tensor pretty printing, mirroring PyTorch's
/// `torch.set_printoptions` parameters.
///
/// The fields are public so callers can actually build a value to pass to
/// [`set_print_options`] (previously the type was exported but had no public
/// constructor and private fields, making that function unusable outside
/// this crate).
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    /// Digits shown after the decimal point for float dtypes.
    pub precision: usize,
    /// Element count above which the output is summarized with `...`.
    pub threshold: usize,
    /// Leading/trailing elements shown per dimension when summarizing.
    pub edge_items: usize,
    /// Target maximum number of characters per output line.
    pub line_width: usize,
    /// Force scientific notation on (`Some(true)`) or off (`Some(false)`),
    /// or let the formatter decide from the value range (`None`).
    pub sci_mode: Option<bool>,
}

/// Process-wide printing options shared by every `Display` invocation.
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());

impl PrinterOptions {
    // We cannot rely on the `Default` trait alone as it's not const, and the
    // static above needs a const initializer.
    const fn const_default() -> Self {
        Self {
            precision: 4,
            threshold: 1000,
            edge_items: 3,
            line_width: 80,
            sci_mode: None,
        }
    }
}

impl Default for PrinterOptions {
    /// Same values as the const initializer used for the global options.
    fn default() -> Self {
        Self::const_default()
    }
}
88
89pub fn set_print_options(options: PrinterOptions) {
90    *PRINT_OPTS.lock().unwrap() = options
91}
92
93pub fn set_print_options_default() {
94    *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
95}
96
97pub fn set_print_options_short() {
98    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
99        precision: 2,
100        threshold: 1000,
101        edge_items: 2,
102        line_width: 80,
103        sci_mode: None,
104    }
105}
106
107pub fn set_print_options_full() {
108    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
109        precision: 4,
110        threshold: usize::MAX,
111        edge_items: 3,
112        line_width: 80,
113        sci_mode: None,
114    }
115}
116
/// A `std::fmt::Write` sink that discards the text and only counts how many
/// bytes would have been written. Used to measure how wide each element
/// renders before the real formatting pass.
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    /// Creates a sink with nothing written yet.
    fn new() -> Self {
        FmtSize { current_size: 0 }
    }

    /// Consumes the sink, returning the total byte count recorded.
    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, piece: &str) -> std::fmt::Result {
        self.current_size += piece.len();
        Ok(())
    }
}
137
/// Shared pretty-printing engine, PyTorch style. Implementors only provide
/// [`TensorFormatter::fmt`] for a single element; the default methods handle
/// width measurement, line wrapping, indentation, and summarization.
trait TensorFormatter {
    /// Scalar type extracted from the tensor before formatting.
    type Elem: WithDType;

    /// Writes one element `v` into `f`, padded to `max_w` characters.
    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;

    /// Measures the widest rendering over all elements of `to_display` by
    /// formatting each value into a byte-counting sink ([`FmtSize`]).
    /// Extraction failures are ignored, leaving the minimum width of 1.
    fn max_width(&self, to_display: &Tensor) -> usize {
        let mut max_width = 1;
        if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
            for &v in vs.iter() {
                let mut fmt_size = FmtSize::new();
                let _res = self.fmt(v, 1, &mut fmt_size);
                max_width = usize::max(max_width, fmt_size.final_size())
            }
        }
        max_width
    }

    /// Starts a new output line indented by `i` spaces.
    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        writeln!(f)?;
        for _ in 0..i {
            write!(f, " ")?
        }
        Ok(())
    }

    /// Recursively renders `t` as a bracketed, line-wrapped block.
    ///
    /// * `indent`: nesting depth, used to indent continuation lines.
    /// * `max_w`: per-element width, normally from [`TensorFormatter::max_width`].
    /// * `summarize`: when true, dimensions longer than `2 * po.edge_items`
    ///   are elided with `...`.
    ///
    /// Extraction errors are silently skipped in the 1-D paths and printed
    /// as `{e:?}` placeholders in the n-D path.
    fn fmt_tensor(
        &self,
        t: &Tensor,
        indent: usize,
        max_w: usize,
        summarize: bool,
        po: &PrinterOptions,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let dims = t.dims();
        let edge_items = po.edge_items;
        write!(f, "[")?;
        match dims {
            // 0-D: a single scalar between the brackets.
            [] => {
                if let Ok(v) = t.to_scalar::<Self::Elem>() {
                    self.fmt(v, max_w, f)?
                }
            }
            // 1-D, summarized: `edge_items` leading values, then `...`,
            // then `edge_items` trailing values.
            [v] if summarize && *v > 2 * edge_items => {
                if let Ok(vs) = t
                    .narrow(0, 0, edge_items)
                    .and_then(|t| t.to_vec1::<Self::Elem>())
                {
                    for v in vs.into_iter() {
                        self.fmt(v, max_w, f)?;
                        write!(f, ", ")?;
                    }
                }
                write!(f, "...")?;
                if let Ok(vs) = t
                    .narrow(0, v - edge_items, edge_items)
                    .and_then(|t| t.to_vec1::<Self::Elem>())
                {
                    for v in vs.into_iter() {
                        write!(f, ", ")?;
                        self.fmt(v, max_w, f)?;
                    }
                }
            }
            // 1-D, full: all values, wrapping so roughly `line_width`
            // characters fit per line (`max_w + 2` accounts for the element
            // plus the ", " separator).
            [_] => {
                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
                if let Ok(vs) = t.to_vec1::<Self::Elem>() {
                    for (i, v) in vs.into_iter().enumerate() {
                        if i > 0 {
                            if i % elements_per_line == 0 {
                                write!(f, ",")?;
                                Self::write_newline_indent(indent, f)?
                            } else {
                                write!(f, ", ")?;
                            }
                        }
                        self.fmt(v, max_w, f)?
                    }
                }
            }
            // n-D: recurse over the leading dimension, one sub-tensor per
            // line, optionally eliding the middle rows with `...`.
            _ => {
                if summarize && dims[0] > 2 * edge_items {
                    for i in 0..edge_items {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        write!(f, ",")?;
                        Self::write_newline_indent(indent, f)?
                    }
                    write!(f, "...")?;
                    Self::write_newline_indent(indent, f)?;
                    for i in dims[0] - edge_items..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        // No separator after the last row.
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                } else {
                    for i in 0..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        // No separator after the last row.
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                }
            }
        }
        write!(f, "]")?;
        Ok(())
    }
}
258
/// Formatter for float dtypes (`bf16`, `f16`, `f32`, `f64`).
struct FloatFormatter<S: WithDType> {
    /// All displayed values are integral: render as `12.` with no decimals.
    int_mode: bool,
    /// Use scientific notation (e.g. `1.2346e4`).
    sci_mode: bool,
    /// Digits after the decimal point in fixed/scientific mode.
    precision: usize,
    /// Ties the formatter to its element type without storing a value.
    _phantom: std::marker::PhantomData<S>,
}
265
266impl<S> FloatFormatter<S>
267where
268    S: WithDType + num_traits::Float + std::fmt::Display,
269{
270    fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
271        let mut int_mode = true;
272        let mut sci_mode = false;
273
274        // Rather than containing all values, this should only include
275        // values that end up being displayed according to [threshold].
276        let values = t
277            .flatten_all()?
278            .to_vec1()?
279            .into_iter()
280            .filter(|v: &S| v.is_finite() && !v.is_zero())
281            .collect::<Vec<_>>();
282        if !values.is_empty() {
283            let mut nonzero_finite_min = S::max_value();
284            let mut nonzero_finite_max = S::min_value();
285            for &v in values.iter() {
286                let v = v.abs();
287                if v < nonzero_finite_min {
288                    nonzero_finite_min = v
289                }
290                if v > nonzero_finite_max {
291                    nonzero_finite_max = v
292                }
293            }
294
295            for &value in values.iter() {
296                if value.ceil() != value {
297                    int_mode = false;
298                    break;
299                }
300            }
301            if let Some(v1) = S::from(1000.) {
302                if let Some(v2) = S::from(1e8) {
303                    if let Some(v3) = S::from(1e-4) {
304                        sci_mode = nonzero_finite_max / nonzero_finite_min > v1
305                            || nonzero_finite_max > v2
306                            || nonzero_finite_min < v3
307                    }
308                }
309            }
310        }
311
312        match po.sci_mode {
313            None => {}
314            Some(v) => sci_mode = v,
315        }
316        Ok(Self {
317            int_mode,
318            sci_mode,
319            precision: po.precision,
320            _phantom: std::marker::PhantomData,
321        })
322    }
323}
324
325impl<S> TensorFormatter for FloatFormatter<S>
326where
327    S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
328{
329    type Elem = S;
330
331    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
332        if self.sci_mode {
333            write!(
334                f,
335                "{v:width$.prec$e}",
336                v = v,
337                width = max_w,
338                prec = self.precision
339            )
340        } else if self.int_mode {
341            if v.is_finite() {
342                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
343            } else {
344                write!(f, "{v:max_w$.0}")
345            }
346        } else {
347            write!(
348                f,
349                "{v:width$.prec$}",
350                v = v,
351                width = max_w,
352                prec = self.precision
353            )
354        }
355    }
356}
357
/// Formatter for integer dtypes (`u8`, `u32`, `i64`); no mode selection is
/// needed, values are just width-padded.
struct IntFormatter<S: WithDType> {
    /// Ties the formatter to its element type without storing a value.
    _phantom: std::marker::PhantomData<S>,
}

impl<S: WithDType> IntFormatter<S> {
    fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
369
impl<S> TensorFormatter for IntFormatter<S>
where
    S: WithDType + std::fmt::Display,
{
    type Elem = S;

    /// Writes `v` padded to `max_w` characters via its `Display` impl.
    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
        write!(f, "{v:max_w$}")
    }
}
380
/// Collects the values that will actually be displayed when summarizing:
/// the first and last `edge_items` entries per dimension, recursively.
///
/// NOTE(review): for rank >= 2 the recursively summarized sub-tensors are
/// concatenated along dim 0 without re-stacking, so the result's shape is
/// flattened relative to the input. The current callers in this file only
/// use the result for width estimation and sci-mode detection (both flatten
/// anyway) — confirm before reusing this where shape matters.
fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
    let dims = t.dims();
    if dims.is_empty() {
        // Scalars have nothing to elide.
        Ok(t.clone())
    } else if dims.len() == 1 {
        if dims[0] > 2 * edge_items {
            // Keep only the leading and trailing edges.
            Tensor::cat(
                &[
                    t.narrow(0, 0, edge_items)?,
                    t.narrow(0, dims[0] - edge_items, edge_items)?,
                ],
                0,
            )
        } else {
            Ok(t.clone())
        }
    } else if dims[0] > 2 * edge_items {
        // Summarize the edge rows recursively; middle rows are dropped.
        let mut vs: Vec<_> = (0..edge_items)
            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
            .collect::<Result<Vec<_>>>()?;
        for i in (dims[0] - edge_items)..dims[0] {
            vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
        }
        Tensor::cat(&vs, 0)
    } else {
        // Leading dimension is short: keep every row, summarize inner dims.
        let vs: Vec<_> = (0..dims[0])
            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
            .collect::<Result<Vec<_>>>()?;
        Tensor::cat(&vs, 0)
    }
}
412
impl std::fmt::Display for Tensor {
    /// Pretty-prints the tensor values (PyTorch style) followed by a
    /// `Tensor[shape; dtype, device]` trailer.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Global options; the lock is held for the whole formatting pass.
        let po = PRINT_OPTS.lock().unwrap();
        let summarize = self.elem_count() > po.threshold;
        // `to_display` contains only the values that will be shown. It feeds
        // the width / sci-mode estimation below; the layout itself runs on
        // `self`, with `fmt_tensor` re-applying summarization via the flag.
        let to_display = if summarize {
            match get_summarized_data(self, po.edge_items) {
                Ok(v) => v,
                Err(err) => return write!(f, "{err:?}"),
            }
        } else {
            self.clone()
        };
        // One arm per dtype so each formatter is monomorphized with the
        // matching element type. If a float formatter fails to build (value
        // extraction error), the value block is skipped and only the trailer
        // below is printed.
        match self.dtype() {
            DType::U8 => {
                let tf: IntFormatter<u8> = IntFormatter::new();
                let max_w = tf.max_width(&to_display);
                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                writeln!(f)?;
            }
            DType::U32 => {
                let tf: IntFormatter<u32> = IntFormatter::new();
                let max_w = tf.max_width(&to_display);
                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                writeln!(f)?;
            }
            DType::I64 => {
                let tf: IntFormatter<i64> = IntFormatter::new();
                let max_w = tf.max_width(&to_display);
                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                writeln!(f)?;
            }
            DType::BF16 => {
                if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
                    let max_w = tf.max_width(&to_display);
                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                    writeln!(f)?;
                }
            }
            DType::F16 => {
                if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
                    let max_w = tf.max_width(&to_display);
                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                    writeln!(f)?;
                }
            }
            DType::F64 => {
                if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
                    let max_w = tf.max_width(&to_display);
                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                    writeln!(f)?;
                }
            }
            DType::F32 => {
                if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
                    let max_w = tf.max_width(&to_display);
                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                    writeln!(f)?;
                }
            }
        };

        // Device suffix: empty for CPU, ", cuda:<id>" for CUDA (same format
        // as the `Debug` impl above).
        let device_str = match self.device().location() {
            crate::DeviceLocation::Cpu => "".to_owned(),
            crate::DeviceLocation::Cuda { gpu_id } => {
                format!(", cuda:{}", gpu_id)
            }
        };

        write!(
            f,
            "Tensor[{:?}, {}{}]",
            self.dims(),
            self.dtype().as_str(),
            device_str
        )
    }
}