Skip to main content

lumen_core/tensor/
display.rs

1//! Pretty printing of tensors
2//!
3//! adapted from [candle](https://github.com/huggingface/candle)
4//!
5use crate::{FloatCategory, FloatDType, IntCategory, IntDType, NumCategory, NumDType, Result, Tensor, WithDType};
6
7// =================================================================================== //
8//                      Debug 
9// =================================================================================== //
10
11impl<T: WithDType> std::fmt::Debug for Tensor<T> {
12    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
13        write!(f, "Tensor[")?;
14        if self.is_meta() {
15            return write!(f, "Meta; {:?}, {}]", self.dims(), self.dtype())
16        }
17        match self.dims() {
18            [] => {
19                if let Ok(v) = self.to_scalar() {
20                    write!(f, "{v}")?
21                }
22            }
23            [s] if *s < 10 => {
24                for (i, v) in self.to_vec().expect("Meta Tensor").iter().enumerate() {
25                    if i > 0 {
26                        write!(f, ", ")?;
27                    }
28                    write!(f, "{v}")?;
29                }
30            }
31            dims => {
32                write!(f, "dims ")?;
33                for (i, d) in dims.iter().enumerate() {
34                    if i > 0 {
35                        write!(f, ", ")?;
36                    }
37                    write!(f, "{d}")?;
38                }
39            }
40        }
41        write!(f, "; {}]", self.dtype())
42    }
43}
44
45// =================================================================================== //
46//                  Options for Tensor pretty printing
47// =================================================================================== //
48
/// Options controlling how tensors are pretty-printed through `Display`.
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    /// Digits after the decimal point for float elements (fixed or scientific notation).
    pub precision: usize,
    /// Tensors with more than this many elements are printed in summarized form.
    pub threshold: usize,
    /// Number of leading and trailing entries kept per axis when summarizing.
    pub edge_items: usize,
    /// Target line width, used to decide how many elements of a 1-D row share a line.
    pub line_width: usize,
    /// `Some(true)` forces scientific notation for floats, `Some(false)` forbids it,
    /// `None` lets the formatter decide from the data.
    pub sci_mode: Option<bool>,
}
57
/// Process-wide printer options, read by `Tensor`'s `Display` impl and updated by
/// the `set_*` functions of this module.
// `Mutex::new` and `const_default` are both `const`, so no lazy initialization is needed.
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());
60
61impl PrinterOptions {
62    // We cannot use the default trait as it's not const.
63    const fn const_default() -> Self {
64        Self {
65            precision: 4,
66            threshold: 1000,
67            edge_items: 3,
68            line_width: 80,
69            sci_mode: None,
70        }
71    }
72}
73
/// Returns the global printer-options mutex.
///
/// Lock it to read or modify the options in place; the `set_*` functions are
/// shortcuts for common updates. Note that formatting a `Tensor` with `{}` locks
/// this same mutex, so do not keep it locked while displaying a tensor on the same
/// thread (std `Mutex` is not reentrant: that may deadlock or panic).
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
    &PRINT_OPTS
}
77
78pub fn set_print_options(options: PrinterOptions) {
79    *PRINT_OPTS.lock().unwrap() = options
80}
81
82pub fn set_print_options_default() {
83    *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
84}
85
86pub fn set_print_options_short() {
87    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
88        precision: 2,
89        threshold: 1000,
90        edge_items: 2,
91        line_width: 80,
92        sci_mode: None,
93    }
94}
95
96pub fn set_print_options_full() {
97    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
98        precision: 4,
99        threshold: usize::MAX,
100        edge_items: 3,
101        line_width: 80,
102        sci_mode: None,
103    }
104}
105
106pub fn set_line_width(line_width: usize) {
107    PRINT_OPTS.lock().unwrap().line_width = line_width
108}
109
110pub fn set_precision(precision: usize) {
111    PRINT_OPTS.lock().unwrap().precision = precision
112}
113
114pub fn set_edge_items(edge_items: usize) {
115    PRINT_OPTS.lock().unwrap().edge_items = edge_items
116}
117
118pub fn set_threshold(threshold: usize) {
119    PRINT_OPTS.lock().unwrap().threshold = threshold
120}
121
122pub fn set_sci_mode(sci_mode: Option<bool>) {
123    PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
124}
125
/// A `fmt::Write` sink that discards its input and only records how wide it was.
///
/// Used to measure the formatted width of an element before column padding.
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    /// Creates an empty measuring sink.
    fn new() -> Self {
        Self { current_size: 0 }
    }

    /// Number of characters written so far.
    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        // Count `char`s, not bytes: `std::fmt` width/padding is measured in chars,
        // so this keeps the measured width consistent with the padding applied when
        // the element is actually printed.
        self.current_size += s.chars().count();
        Ok(())
    }
}
146
/// Renders tensors element by element; implementors only decide how a single
/// element is formatted (`fmt`), the layout logic lives in the default methods.
trait TensorFormatter {
    /// Element type of the tensors this formatter handles.
    type Elem: WithDType;

    /// Writes one element `v` into `f`, using `max_w` as the minimum field width.
    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;

    /// Width of the widest element of `to_display` when formatted by `fmt`.
    ///
    /// Never less than 1. Returns 1 if flattening `to_display` fails.
    fn max_width(&self, to_display: &Tensor<Self::Elem>) -> usize {
        let mut max_width = 1;
        if let Ok(t) = to_display.flatten_all() {
            let vs = t.to_vec().expect("Meta Tensor");
            for &v in vs.iter() {
                // Format into a measuring sink with the minimal width so that only
                // the element's natural width is counted; write errors are ignored.
                let mut fmt_size = FmtSize::new();
                let _res = self.fmt(v, 1, &mut fmt_size);
                max_width = usize::max(max_width, fmt_size.final_size())
            }
        }
        max_width
    }

    /// Writes a newline followed by `i` spaces of indentation.
    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        writeln!(f)?;
        for _ in 0..i {
            write!(f, " ")?
        }
        Ok(())
    }

    /// Recursively writes `t` as nested bracketed lists.
    ///
    /// * `indent`: number of spaces written after each line break at this level.
    /// * `max_w`: field width passed to `fmt` for every element.
    /// * `summarize`: when true, axes longer than `2 * po.edge_items` show only the
    ///   first and last `po.edge_items` entries, separated by `...`.
    ///
    /// Errors from `get` on sub-tensors are written inline in `{:?}` form; failures
    /// of `to_scalar`/`narrow` silently omit the corresponding values.
    fn fmt_tensor(
        &self,
        t: &Tensor<Self::Elem>,
        indent: usize,
        max_w: usize,
        summarize: bool,
        po: &PrinterOptions,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let dims = t.dims();
        let edge_items = po.edge_items;
        write!(f, "[")?;
        match dims {
            // Scalar: a single element inside the brackets.
            [] => {
                if let Ok(v) = t.to_scalar() {
                    self.fmt(v, max_w, f)?
                }
            }
            // Summarized 1-D row: `a, b, ..., y, z` on one line.
            [v] if summarize && *v > 2 * edge_items => {
                if let Ok(t) = t
                    .narrow(0, 0, edge_items)
                {
                    for v in t.to_vec().expect("Meta Tensor").into_iter() {
                        self.fmt(v, max_w, f)?;
                        write!(f, ", ")?;
                    }
                }
                write!(f, "...")?;
                if let Ok(t) = t
                    .narrow(0, v - edge_items, edge_items)
                {
                    for v in t.to_vec().expect("Meta Tensor").into_iter() {
                        write!(f, ", ")?;
                        self.fmt(v, max_w, f)?;
                    }
                }
            }
            // Full 1-D row, wrapped so each line holds about `line_width` columns.
            [_] => {
                // Each element takes `max_w` columns plus the 2-char ", " separator.
                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
                let vs = t.to_vec().expect("Meta Tensor");
                for (i, v) in vs.into_iter().enumerate() {
                    if i > 0 {
                        if i % elements_per_line == 0 {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        } else {
                            write!(f, ", ")?;
                        }
                    }
                    self.fmt(v, max_w, f)?
                }
            }
            // Rank >= 2: one sub-tensor per line, nested one level deeper.
            _ => {
                if summarize && dims[0] > 2 * edge_items {
                    // Leading sub-tensors, each followed by a separator line break.
                    for i in 0..edge_items {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        write!(f, ",")?;
                        Self::write_newline_indent(indent, f)?
                    }
                    write!(f, "...")?;
                    Self::write_newline_indent(indent, f)?;
                    // Trailing sub-tensors; no separator after the last one.
                    for i in dims[0] - edge_items..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                } else {
                    for i in 0..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                }
            }
        }
        write!(f, "]")?;
        Ok(())
    }
}
265
266struct FloatFormatter<S: FloatDType> {
267    int_mode: bool,
268    sci_mode: bool,
269    precision: usize,
270    _phantom: std::marker::PhantomData<S>,
271}
272
273impl<S: FloatDType> FloatFormatter<S> {
274    fn new(t: &Tensor<S>, po: &PrinterOptions) -> Result<Self> {
275        let mut int_mode = true;
276        let mut sci_mode = false;
277
278        // Rather than containing all values, this should only include
279        // values that end up being displayed according to [threshold].
280        let values = t
281            .flatten_all()?
282            .to_vec().expect("Meta Tensor")
283            .into_iter()
284            .filter(|v: &S| v.is_finite() && !v.is_zero())
285            .collect::<Vec<_>>();
286        if !values.is_empty() {
287            let mut nonzero_finite_min = S::MAX_VALUE;
288            let mut nonzero_finite_max = S::MIN_VALUE;
289            for &v in values.iter() {
290                let v = v.abs();
291                if v < nonzero_finite_min {
292                    nonzero_finite_min = v
293                }
294                if v > nonzero_finite_max {
295                    nonzero_finite_max = v
296                }
297            }
298
299            for &value in values.iter() {
300                if value.ceil() != value {
301                    int_mode = false;
302                    break;
303                }
304            }
305            if let Some(v1) = S::from(1000.) {
306                if let Some(v2) = S::from(1e8) {
307                    if let Some(v3) = S::from(1e-4) {
308                        sci_mode = nonzero_finite_max / nonzero_finite_min > v1
309                            || nonzero_finite_max > v2
310                            || nonzero_finite_min < v3
311                    }
312                }
313            }
314        }
315
316        match po.sci_mode {
317            None => {}
318            Some(v) => sci_mode = v,
319        }
320        Ok(Self {
321            int_mode,
322            sci_mode,
323            precision: po.precision,
324            _phantom: std::marker::PhantomData,
325        })
326    }
327}
328
329impl<S: FloatDType> TensorFormatter for FloatFormatter<S> {
330    type Elem = S;
331
332    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
333        if self.sci_mode {
334            write!(
335                f,
336                "{v:width$.prec$e}",
337                v = v,
338                width = max_w,
339                prec = self.precision
340            )
341        } else if self.int_mode {
342            if v.is_finite() {
343                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
344            } else {
345                write!(f, "{v:max_w$.0}")
346            }
347        } else {
348            write!(
349                f,
350                "{v:width$.prec$}",
351                v = v,
352                width = max_w,
353                prec = self.precision
354            )
355        }
356    }
357}
358
359struct IntFormatter<S: WithDType> {
360    _phantom: std::marker::PhantomData<S>,
361}
362
363impl<S: WithDType> IntFormatter<S> {
364    fn new() -> Self {
365        Self {
366            _phantom: std::marker::PhantomData,
367        }
368    }
369}
370
371impl<S: WithDType> TensorFormatter for IntFormatter<S> {
372    type Elem = S;
373
374    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
375        write!(f, "{v:max_w$}")
376    }
377}
378
379fn get_summarized_data<T: WithDType>(t: &Tensor<T>, edge_items: usize) -> Result<Tensor<T>> {
380    let dims = t.dims();
381    if dims.is_empty() {
382        Ok(t.clone())
383    } else if dims.len() == 1 {
384        if dims[0] > 2 * edge_items {
385            Tensor::cat(
386                &[
387                    t.narrow(0, 0, edge_items)?,
388                    t.narrow(0, dims[0] - edge_items, edge_items)?,
389                ],
390                0,
391            )
392        } else {
393            Ok(t.clone())
394        }
395    } else if dims[0] > 2 * edge_items {
396        let mut vs: Vec<_> = (0..edge_items)
397            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
398            .collect::<Result<Vec<_>>>()?;
399        for i in (dims[0] - edge_items)..dims[0] {
400            vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
401        }
402        Tensor::cat(&vs, 0)
403    } else {
404        let vs: Vec<_> = (0..dims[0])
405            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
406            .collect::<Result<Vec<_>>>()?;
407        Tensor::cat(&vs, 0)
408    }
409}
410
411trait Display<T: WithDType> {
412    fn fmt_display(tensor: &Tensor<T>, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        if tensor.is_meta() {
414            return write!(
415                f,
416                "Tensor[Meta; {:?}, {}]",
417                tensor.dims(),
418                tensor.dtype(),
419            )
420        }
421
422        let po: std::sync::MutexGuard<'_, PrinterOptions> = PRINT_OPTS.lock().unwrap();
423        let summarize = tensor.element_count() > po.threshold;
424        let to_display = if summarize {
425            match get_summarized_data(tensor, po.edge_items) {
426                Ok(v) => v,
427                Err(err) => return write!(f, "{err:?}"),
428            }
429        } else {
430            tensor.clone()
431        };
432
433        Self::do_fmt_display(tensor, to_display, f, summarize, po)?;
434
435        write!(
436            f,
437            "Tensor[{:?}, {}]",
438            tensor.dims(),
439            tensor.dtype(),
440        )
441    }
442
443    fn do_fmt_display(
444        tensor: &Tensor<T>, 
445        to_display: Tensor<T>, 
446        f: &mut std::fmt::Formatter,
447        summarize: bool,
448        po: std::sync::MutexGuard<'_, PrinterOptions>
449    ) -> std::fmt::Result {
450        let tf = IntFormatter::new();
451        let max_w = tf.max_width(&to_display);
452        tf.fmt_tensor(tensor, 1, max_w, summarize, &po, f)?;
453        writeln!(f)
454    }
455}
456
/// Zero-sized type that carries the `Display` / `DisplayCategory` impls selecting
/// the element formatter for each element type.
struct DisplayOp;
458
459impl<T: NumDType> Display<T> for DisplayOp 
460where 
461    Self: DisplayCategory<T>,
462{
463    fn do_fmt_display(
464        tensor: &Tensor<T>, 
465        to_display: Tensor<T>, 
466        f: &mut std::fmt::Formatter,
467        summarize: bool,
468        po: std::sync::MutexGuard<'_, PrinterOptions>
469    ) -> std::fmt::Result {
470        DisplayOp::do_fmt_display_category(tensor, to_display, f, summarize, po)
471    }
472}
473
/// Per-category formatting hook for numeric element types.
///
/// `C` defaults to `T`'s `NumDType::Category`, so the impl selected for a dtype is
/// the one written for its category (`IntCategory` / `FloatCategory` in this file).
trait DisplayCategory<T: NumDType, C: NumCategory = <T as NumDType>::Category> {
    /// Writes the element data of `tensor`; the impls in this file end it with a
    /// newline. `to_display` holds the values that will be shown (used for
    /// measuring) and `po` is the locked global printer options.
    fn do_fmt_display_category(
        tensor: &Tensor<T>, 
        to_display: Tensor<T>, 
        f: &mut std::fmt::Formatter,
        summarize: bool,
        po: std::sync::MutexGuard<'_, PrinterOptions>
    ) -> std::fmt::Result;
}
483
484impl<T: IntDType> DisplayCategory<T, IntCategory> for DisplayOp {
485    fn do_fmt_display_category(
486        tensor: &Tensor<T>, 
487        to_display: Tensor<T>, 
488        f: &mut std::fmt::Formatter,
489        summarize: bool,
490        po: std::sync::MutexGuard<'_, PrinterOptions>
491    ) -> std::fmt::Result {
492        let tf: IntFormatter<T> = IntFormatter::new();
493        let max_w = tf.max_width(&to_display);
494        tf.fmt_tensor(tensor, 1, max_w, summarize, &po, f)?;
495        writeln!(f)
496    }
497} 
498
499impl<T: FloatDType> DisplayCategory<T, FloatCategory> for DisplayOp {
500    fn do_fmt_display_category(
501        tensor: &Tensor<T>, 
502        to_display: Tensor<T>, 
503        f: &mut std::fmt::Formatter,
504        summarize: bool,
505        po: std::sync::MutexGuard<'_, PrinterOptions>
506    ) -> std::fmt::Result {
507        if let Ok(tf) = FloatFormatter::new(&to_display, &po) {
508            let max_w = tf.max_width(&to_display);
509            tf.fmt_tensor(tensor, 1, max_w, summarize, &po, f)?;
510            writeln!(f)?;
511        }
512        Ok(())
513    }
514} 
515
/// `bool` tensors rely entirely on the default methods of `Display`, so their
/// elements are rendered through `IntFormatter` (plain `{}` formatting).
impl Display<bool> for DisplayOp {
    // Intentionally empty: no overrides.
}
519
520impl<T: WithDType> std::fmt::Display for Tensor<T> 
521where 
522    DisplayOp: Display<T>,
523{
524    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
525        DisplayOp::fmt_display(self, f)
526    }
527}