1use crate::{DType, Result, Tensor, WithDType};
6use half::{bf16, f16};
7
8impl Tensor {
9 fn fmt_dt<T: WithDType + std::fmt::Display>(
10 &self,
11 f: &mut std::fmt::Formatter,
12 ) -> std::fmt::Result {
13 let device_str = match self.device().location() {
14 crate::DeviceLocation::Cpu => "".to_owned(),
15 crate::DeviceLocation::Cuda { gpu_id } => {
16 format!(", cuda:{gpu_id}")
17 }
18 crate::DeviceLocation::Metal { gpu_id } => {
19 format!(", metal:{gpu_id}")
20 }
21 #[cfg(feature = "rocm")]
22 crate::DeviceLocation::Rocm { gpu_id } => {
23 format!(", rocm:{gpu_id}")
24 }
25 #[cfg(feature = "vulkan")]
26 crate::DeviceLocation::Vulkan { gpu_id } => {
27 format!(", vulkan:{gpu_id}")
28 }
29 };
30
31 write!(f, "Tensor[")?;
32 match self.dims() {
33 [] => {
34 if let Ok(v) = self.to_scalar::<T>() {
35 write!(f, "{v}")?
36 }
37 }
38 [s] if *s < 10 => {
39 if let Ok(vs) = self.to_vec1::<T>() {
40 for (i, v) in vs.iter().enumerate() {
41 if i > 0 {
42 write!(f, ", ")?;
43 }
44 write!(f, "{v}")?;
45 }
46 }
47 }
48 dims => {
49 write!(f, "dims ")?;
50 for (i, d) in dims.iter().enumerate() {
51 if i > 0 {
52 write!(f, ", ")?;
53 }
54 write!(f, "{d}")?;
55 }
56 }
57 }
58 write!(f, "; {}{}]", self.dtype().as_str(), device_str)
59 }
60}
61
62impl std::fmt::Debug for Tensor {
63 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
64 match self.dtype() {
65 DType::U8 => self.fmt_dt::<u8>(f),
66 DType::U32 => self.fmt_dt::<u32>(f),
67 DType::I16 => self.fmt_dt::<i16>(f),
68 DType::I32 => self.fmt_dt::<i32>(f),
69 DType::I64 => self.fmt_dt::<i64>(f),
70 DType::BF16 => self.fmt_dt::<bf16>(f),
71 DType::F16 => self.fmt_dt::<f16>(f),
72 DType::F32 => self.fmt_dt::<f32>(f),
73 DType::F64 => self.fmt_dt::<f64>(f),
74 DType::F8E4M3 => self.fmt_dt::<float8::F8E4M3>(f),
75 DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
76 write!(
77 f,
78 "Tensor[{:?}; dtype={}, unsupported dummy type]",
79 self.shape(),
80 self.dtype().as_str()
81 )
82 }
83 }
84 }
85}
86
/// Options controlling how tensors are pretty-printed by `Display`.
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    /// Number of digits printed after the decimal point for floating-point values.
    pub precision: usize,
    /// Tensors with more elements than this are summarized instead of printed in full.
    pub threshold: usize,
    /// Number of entries kept at each end of a long dimension when summarizing.
    pub edge_items: usize,
    /// Approximate line width used to decide how many elements of a 1-D row fit on a line.
    pub line_width: usize,
    /// `Some(true)` forces scientific notation, `Some(false)` disables it, and
    /// `None` lets the float formatter decide from the data.
    pub sci_mode: Option<bool>,
}

/// Process-wide printer options used by `Tensor`'s `Display` implementation.
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());

impl PrinterOptions {
    /// Default options: 4 digits, summarize above 1000 elements, 3 edge items,
    /// 80-column lines, automatic scientific notation.
    // `Default::default` is not a `const fn`, so it cannot initialize the static above.
    const fn const_default() -> Self {
        Self {
            precision: 4,
            threshold: 1000,
            edge_items: 3,
            line_width: 80,
            sci_mode: None,
        }
    }
}

/// Locks the global options.
///
/// If a previous lock holder panicked, the poisoned lock is recovered instead
/// of panicking again. The options are plain data, so a poisoned value is still
/// meaningful, and a failed print must not permanently break every setter.
fn lock_print_opts() -> std::sync::MutexGuard<'static, PrinterOptions> {
    PRINT_OPTS
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Returns the mutex guarding the global printer options.
///
/// Note that `lock()` on it returns `Err` if the mutex has been poisoned; the
/// `set_*` functions in this module recover from that case themselves.
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
    &PRINT_OPTS
}

/// Replaces all global printer options at once.
pub fn set_print_options(options: PrinterOptions) {
    *lock_print_opts() = options
}

/// Restores the default printer options.
pub fn set_print_options_default() {
    *lock_print_opts() = PrinterOptions::const_default()
}

/// Compact preset: the defaults with 2 digits of precision and 2 edge items.
pub fn set_print_options_short() {
    *lock_print_opts() = PrinterOptions {
        precision: 2,
        edge_items: 2,
        ..PrinterOptions::const_default()
    }
}

/// Full preset: the defaults, but tensors are never summarized.
pub fn set_print_options_full() {
    *lock_print_opts() = PrinterOptions {
        threshold: usize::MAX,
        ..PrinterOptions::const_default()
    }
}

/// Sets the approximate line width used to wrap 1-D rows.
pub fn set_line_width(line_width: usize) {
    lock_print_opts().line_width = line_width
}

/// Sets the number of digits printed after the decimal point.
pub fn set_precision(precision: usize) {
    lock_print_opts().precision = precision
}

/// Sets how many entries are kept at each end of a summarized dimension.
pub fn set_edge_items(edge_items: usize) {
    lock_print_opts().edge_items = edge_items
}

/// Sets the element count above which tensors are summarized.
pub fn set_threshold(threshold: usize) {
    lock_print_opts().threshold = threshold
}

/// Forces (`Some(true)`), disables (`Some(false)`) or auto-selects (`None`)
/// scientific notation for floats.
pub fn set_sci_mode(sci_mode: Option<bool>) {
    lock_print_opts().sci_mode = sci_mode
}
164
/// A `fmt::Write` sink that discards everything written to it and only
/// measures the width of the text.
///
/// Used to find the widest formatted element so that columns line up.
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    /// Creates an empty measuring sink.
    fn new() -> Self {
        Self { current_size: 0 }
    }

    /// Consumes the sink and returns the total width written, in `char`s.
    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        // `{:width$}` padding counts `char`s, not bytes, so measure in the same
        // unit; byte length would over-estimate the width of non-ASCII text.
        self.current_size += s.chars().count();
        Ok(())
    }
}
185
186trait TensorFormatter {
187 type Elem: WithDType;
188
189 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
190
191 fn max_width(&self, to_display: &Tensor) -> usize {
192 let mut max_width = 1;
193 if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
194 for &v in vs.iter() {
195 let mut fmt_size = FmtSize::new();
196 let _res = self.fmt(v, 1, &mut fmt_size);
197 max_width = usize::max(max_width, fmt_size.final_size())
198 }
199 }
200 max_width
201 }
202
203 fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
204 writeln!(f)?;
205 for _ in 0..i {
206 write!(f, " ")?
207 }
208 Ok(())
209 }
210
211 fn fmt_tensor(
212 &self,
213 t: &Tensor,
214 indent: usize,
215 max_w: usize,
216 summarize: bool,
217 po: &PrinterOptions,
218 f: &mut std::fmt::Formatter,
219 ) -> std::fmt::Result {
220 let dims = t.dims();
221 let edge_items = po.edge_items;
222 write!(f, "[")?;
223 match dims {
224 [] => {
225 if let Ok(v) = t.to_scalar::<Self::Elem>() {
226 self.fmt(v, max_w, f)?
227 }
228 }
229 [v] if summarize && *v > 2 * edge_items => {
230 if let Ok(vs) = t
231 .narrow(0, 0, edge_items)
232 .and_then(|t| t.to_vec1::<Self::Elem>())
233 {
234 for v in vs.into_iter() {
235 self.fmt(v, max_w, f)?;
236 write!(f, ", ")?;
237 }
238 }
239 write!(f, "...")?;
240 if let Ok(vs) = t
241 .narrow(0, v - edge_items, edge_items)
242 .and_then(|t| t.to_vec1::<Self::Elem>())
243 {
244 for v in vs.into_iter() {
245 write!(f, ", ")?;
246 self.fmt(v, max_w, f)?;
247 }
248 }
249 }
250 [_] => {
251 let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
252 if let Ok(vs) = t.to_vec1::<Self::Elem>() {
253 for (i, v) in vs.into_iter().enumerate() {
254 if i > 0 {
255 if i % elements_per_line == 0 {
256 write!(f, ",")?;
257 Self::write_newline_indent(indent, f)?
258 } else {
259 write!(f, ", ")?;
260 }
261 }
262 self.fmt(v, max_w, f)?
263 }
264 }
265 }
266 _ => {
267 if summarize && dims[0] > 2 * edge_items {
268 for i in 0..edge_items {
269 match t.get(i) {
270 Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
271 Err(e) => write!(f, "{e:?}")?,
272 }
273 write!(f, ",")?;
274 Self::write_newline_indent(indent, f)?
275 }
276 write!(f, "...")?;
277 Self::write_newline_indent(indent, f)?;
278 for i in dims[0] - edge_items..dims[0] {
279 match t.get(i) {
280 Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
281 Err(e) => write!(f, "{e:?}")?,
282 }
283 if i + 1 != dims[0] {
284 write!(f, ",")?;
285 Self::write_newline_indent(indent, f)?
286 }
287 }
288 } else {
289 for i in 0..dims[0] {
290 match t.get(i) {
291 Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
292 Err(e) => write!(f, "{e:?}")?,
293 }
294 if i + 1 != dims[0] {
295 write!(f, ",")?;
296 Self::write_newline_indent(indent, f)?
297 }
298 }
299 }
300 }
301 }
302 write!(f, "]")?;
303 Ok(())
304 }
305}
306
307struct FloatFormatter<S: WithDType> {
308 int_mode: bool,
309 sci_mode: bool,
310 precision: usize,
311 _phantom: std::marker::PhantomData<S>,
312}
313
314impl<S> FloatFormatter<S>
315where
316 S: WithDType + num_traits::Float + std::fmt::Display,
317{
318 fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
319 let mut int_mode = true;
320 let mut sci_mode = false;
321
322 let values = t
325 .flatten_all()?
326 .to_vec1()?
327 .into_iter()
328 .filter(|v: &S| v.is_finite() && !v.is_zero())
329 .collect::<Vec<_>>();
330 if !values.is_empty() {
331 let mut nonzero_finite_min = S::max_value();
332 let mut nonzero_finite_max = S::min_value();
333 for &v in values.iter() {
334 let v = v.abs();
335 if v < nonzero_finite_min {
336 nonzero_finite_min = v
337 }
338 if v > nonzero_finite_max {
339 nonzero_finite_max = v
340 }
341 }
342
343 for &value in values.iter() {
344 if value.ceil() != value {
345 int_mode = false;
346 break;
347 }
348 }
349 if let Some(v1) = S::from(1000.) {
350 if let Some(v2) = S::from(1e8) {
351 if let Some(v3) = S::from(1e-4) {
352 sci_mode = nonzero_finite_max / nonzero_finite_min > v1
353 || nonzero_finite_max > v2
354 || nonzero_finite_min < v3
355 }
356 }
357 }
358 }
359
360 match po.sci_mode {
361 None => {}
362 Some(v) => sci_mode = v,
363 }
364 Ok(Self {
365 int_mode,
366 sci_mode,
367 precision: po.precision,
368 _phantom: std::marker::PhantomData,
369 })
370 }
371}
372
373impl<S> TensorFormatter for FloatFormatter<S>
374where
375 S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
376{
377 type Elem = S;
378
379 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
380 if self.sci_mode {
381 write!(
382 f,
383 "{v:width$.prec$e}",
384 v = v,
385 width = max_w,
386 prec = self.precision
387 )
388 } else if self.int_mode {
389 if v.is_finite() {
390 write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
391 } else {
392 write!(f, "{v:max_w$.0}")
393 }
394 } else {
395 write!(
396 f,
397 "{v:width$.prec$}",
398 v = v,
399 width = max_w,
400 prec = self.precision
401 )
402 }
403 }
404}
405
406struct IntFormatter<S: WithDType> {
407 _phantom: std::marker::PhantomData<S>,
408}
409
410impl<S: WithDType> IntFormatter<S> {
411 fn new() -> Self {
412 Self {
413 _phantom: std::marker::PhantomData,
414 }
415 }
416}
417
418impl<S> TensorFormatter for IntFormatter<S>
419where
420 S: WithDType + std::fmt::Display,
421{
422 type Elem = S;
423
424 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
425 write!(f, "{v:max_w$}")
426 }
427}
428
429fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
430 let dims = t.dims();
431 if dims.is_empty() {
432 Ok(t.clone())
433 } else if dims.len() == 1 {
434 if dims[0] > 2 * edge_items {
435 Tensor::cat(
436 &[
437 t.narrow(0, 0, edge_items)?,
438 t.narrow(0, dims[0] - edge_items, edge_items)?,
439 ],
440 0,
441 )
442 } else {
443 Ok(t.clone())
444 }
445 } else if dims[0] > 2 * edge_items {
446 let mut vs: Vec<_> = (0..edge_items)
447 .map(|i| get_summarized_data(&t.get(i)?, edge_items))
448 .collect::<Result<Vec<_>>>()?;
449 for i in (dims[0] - edge_items)..dims[0] {
450 vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
451 }
452 Tensor::cat(&vs, 0)
453 } else {
454 let vs: Vec<_> = (0..dims[0])
455 .map(|i| get_summarized_data(&t.get(i)?, edge_items))
456 .collect::<Result<Vec<_>>>()?;
457 Tensor::cat(&vs, 0)
458 }
459}
460
461impl std::fmt::Display for Tensor {
462 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
463 let po = PRINT_OPTS.lock().unwrap();
464 let summarize = self.elem_count() > po.threshold;
465 let to_display = if summarize {
466 match get_summarized_data(self, po.edge_items) {
467 Ok(v) => v,
468 Err(err) => return write!(f, "{err:?}"),
469 }
470 } else {
471 self.clone()
472 };
473 match self.dtype() {
474 DType::U8 => {
475 let tf: IntFormatter<u8> = IntFormatter::new();
476 let max_w = tf.max_width(&to_display);
477 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
478 writeln!(f)?;
479 }
480 DType::U32 => {
481 let tf: IntFormatter<u32> = IntFormatter::new();
482 let max_w = tf.max_width(&to_display);
483 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
484 writeln!(f)?;
485 }
486 DType::I16 => {
487 let tf: IntFormatter<i16> = IntFormatter::new();
488 let max_w = tf.max_width(&to_display);
489 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
490 writeln!(f)?;
491 }
492 DType::I32 => {
493 let tf: IntFormatter<i32> = IntFormatter::new();
494 let max_w = tf.max_width(&to_display);
495 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
496 writeln!(f)?;
497 }
498 DType::I64 => {
499 let tf: IntFormatter<i64> = IntFormatter::new();
500 let max_w = tf.max_width(&to_display);
501 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
502 writeln!(f)?;
503 }
504 DType::BF16 => {
505 if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
506 let max_w = tf.max_width(&to_display);
507 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
508 writeln!(f)?;
509 }
510 }
511 DType::F16 => {
512 if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
513 let max_w = tf.max_width(&to_display);
514 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
515 writeln!(f)?;
516 }
517 }
518 DType::F64 => {
519 if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
520 let max_w = tf.max_width(&to_display);
521 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
522 writeln!(f)?;
523 }
524 }
525 DType::F32 => {
526 if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
527 let max_w = tf.max_width(&to_display);
528 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
529 writeln!(f)?;
530 }
531 }
532 DType::F8E4M3 => {
533 if let Ok(tf) = FloatFormatter::<float8::F8E4M3>::new(&to_display, &po) {
534 let max_w = tf.max_width(&to_display);
535 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
536 writeln!(f)?;
537 }
538 }
539 DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
540 writeln!(
541 f,
542 "Dummy type {} (not supported for display)",
543 self.dtype().as_str()
544 )?;
545 }
546 };
547
548 let device_str = match self.device().location() {
549 crate::DeviceLocation::Cpu => "".to_owned(),
550 crate::DeviceLocation::Cuda { gpu_id } => {
551 format!(", cuda:{gpu_id}")
552 }
553 crate::DeviceLocation::Metal { gpu_id } => {
554 format!(", metal:{gpu_id}")
555 }
556 #[cfg(feature = "rocm")]
557 crate::DeviceLocation::Rocm { gpu_id } => {
558 format!(", rocm:{gpu_id}")
559 }
560 #[cfg(feature = "vulkan")]
561 crate::DeviceLocation::Vulkan { gpu_id } => {
562 format!(", vulkan:{gpu_id}")
563 }
564 };
565
566 write!(
567 f,
568 "Tensor[{:?}, {}{}]",
569 self.dims(),
570 self.dtype().as_str(),
571 device_str
572 )
573 }
574}