1use crate::{DType, Result, Tensor, WithDType};
6use half::{bf16, f16};
7
8impl Tensor {
9 fn fmt_dt<T: WithDType + std::fmt::Display>(
10 &self,
11 f: &mut std::fmt::Formatter,
12 ) -> std::fmt::Result {
13 let device_str = match self.device().location() {
14 crate::DeviceLocation::Cpu => "".to_owned(),
15 crate::DeviceLocation::Cuda { gpu_id } => {
16 format!(", cuda:{gpu_id}")
17 }
18 crate::DeviceLocation::Metal { gpu_id } => {
19 format!(", metal:{gpu_id}")
20 }
21 };
22
23 write!(f, "Tensor[")?;
24 match self.dims() {
25 [] => {
26 if let Ok(v) = self.to_scalar::<T>() {
27 write!(f, "{v}")?
28 }
29 }
30 [s] if *s < 10 => {
31 if let Ok(vs) = self.to_vec1::<T>() {
32 for (i, v) in vs.iter().enumerate() {
33 if i > 0 {
34 write!(f, ", ")?;
35 }
36 write!(f, "{v}")?;
37 }
38 }
39 }
40 dims => {
41 write!(f, "dims ")?;
42 for (i, d) in dims.iter().enumerate() {
43 if i > 0 {
44 write!(f, ", ")?;
45 }
46 write!(f, "{d}")?;
47 }
48 }
49 }
50 write!(f, "; {}{}]", self.dtype().as_str(), device_str)
51 }
52}
53
impl std::fmt::Debug for Tensor {
    // Dispatches to `fmt_dt` with the Rust element type matching the runtime
    // dtype. The packed/dummy dtypes at the bottom cannot be read out
    // element-wise, so they get a placeholder message instead.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self.dtype() {
            DType::U8 => self.fmt_dt::<u8>(f),
            DType::U32 => self.fmt_dt::<u32>(f),
            DType::I16 => self.fmt_dt::<i16>(f),
            DType::I32 => self.fmt_dt::<i32>(f),
            DType::I64 => self.fmt_dt::<i64>(f),
            DType::BF16 => self.fmt_dt::<bf16>(f),
            DType::F16 => self.fmt_dt::<f16>(f),
            DType::F32 => self.fmt_dt::<f32>(f),
            DType::F64 => self.fmt_dt::<f64>(f),
            DType::F8E4M3 => self.fmt_dt::<float8::F8E4M3>(f),
            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                write!(
                    f,
                    "Tensor[{:?}; dtype={}, unsupported dummy type]",
                    self.shape(),
                    self.dtype().as_str()
                )
            }
        }
    }
}
78
/// Options controlling how tensor contents are pretty-printed by `Display`.
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    /// Digits after the decimal point for float values.
    pub precision: usize,
    /// Element count above which output is summarized with `...`.
    pub threshold: usize,
    /// Leading/trailing items shown per dimension when summarizing.
    pub edge_items: usize,
    /// Target maximum characters per line for runs of 1-D elements.
    pub line_width: usize,
    /// Force scientific notation on/off; `None` picks a per-tensor heuristic.
    pub sci_mode: Option<bool>,
}
88
// Process-wide printer options, read by the `Display` impl and mutated by
// the public `set_*` helpers in this module.
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());
91
impl PrinterOptions {
    // A `const` constructor so the defaults can initialize a `static` Mutex.
    const fn const_default() -> Self {
        Self {
            precision: 4,
            threshold: 1000,
            edge_items: 3,
            line_width: 80,
            sci_mode: None,
        }
    }
}
104
/// Returns the global mutex guarding the printer options, for callers that
/// want to read or tweak several fields at once.
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
    &PRINT_OPTS
}
108
109pub fn set_print_options(options: PrinterOptions) {
110 *PRINT_OPTS.lock().unwrap() = options
111}
112
113pub fn set_print_options_default() {
114 *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
115}
116
117pub fn set_print_options_short() {
118 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
119 precision: 2,
120 threshold: 1000,
121 edge_items: 2,
122 line_width: 80,
123 sci_mode: None,
124 }
125}
126
127pub fn set_print_options_full() {
128 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
129 precision: 4,
130 threshold: usize::MAX,
131 edge_items: 3,
132 line_width: 80,
133 sci_mode: None,
134 }
135}
136
137pub fn set_line_width(line_width: usize) {
138 PRINT_OPTS.lock().unwrap().line_width = line_width
139}
140
141pub fn set_precision(precision: usize) {
142 PRINT_OPTS.lock().unwrap().precision = precision
143}
144
145pub fn set_edge_items(edge_items: usize) {
146 PRINT_OPTS.lock().unwrap().edge_items = edge_items
147}
148
149pub fn set_threshold(threshold: usize) {
150 PRINT_OPTS.lock().unwrap().threshold = threshold
151}
152
153pub fn set_sci_mode(sci_mode: Option<bool>) {
154 PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
155}
156
/// A `fmt::Write` sink that discards its input and only counts how many
/// bytes would have been written. Used to measure rendered element widths.
struct FmtSize {
    // Running total of bytes "written" so far.
    current_size: usize,
}

impl FmtSize {
    /// Creates a counter starting at zero.
    fn new() -> Self {
        FmtSize { current_size: 0 }
    }

    /// Consumes the counter and returns the total byte count.
    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    // Accumulate the length instead of storing the text; never fails.
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.current_size += s.len();
        Ok(())
    }
}
177
/// Shared layout logic for tensor `Display`: implementors render a single
/// element via `fmt`, while the default methods handle alignment, line
/// wrapping, nesting, and `...` summarization.
trait TensorFormatter {
    /// Element type this formatter reads out of the tensor.
    type Elem: WithDType;

    /// Writes one element `v`, padded to `max_w` characters.
    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;

    /// Measures the widest rendered element of `to_display` so that every
    /// element can later be aligned to the same column width.
    fn max_width(&self, to_display: &Tensor) -> usize {
        let mut max_width = 1;
        if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
            for &v in vs.iter() {
                // Render into a byte-counting sink; only the length matters.
                let mut fmt_size = FmtSize::new();
                let _res = self.fmt(v, 1, &mut fmt_size);
                max_width = usize::max(max_width, fmt_size.final_size())
            }
        }
        max_width
    }

    /// Starts a new line and indents it by `i` spaces.
    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        writeln!(f)?;
        for _ in 0..i {
            write!(f, " ")?
        }
        Ok(())
    }

    /// Recursively renders `t` as nested `[...]` lists.
    ///
    /// When `summarize` is set, dimensions longer than `2 * po.edge_items`
    /// only show their leading/trailing `edge_items` entries around a `...`
    /// placeholder. Read errors are printed inline rather than propagated.
    fn fmt_tensor(
        &self,
        t: &Tensor,
        indent: usize,
        max_w: usize,
        summarize: bool,
        po: &PrinterOptions,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let dims = t.dims();
        let edge_items = po.edge_items;
        write!(f, "[")?;
        match dims {
            // Rank 0: a single scalar.
            [] => {
                if let Ok(v) = t.to_scalar::<Self::Elem>() {
                    self.fmt(v, max_w, f)?
                }
            }
            // Rank 1, summarized: head elements, `...`, tail elements.
            [v] if summarize && *v > 2 * edge_items => {
                if let Ok(vs) = t
                    .narrow(0, 0, edge_items)
                    .and_then(|t| t.to_vec1::<Self::Elem>())
                {
                    for v in vs.into_iter() {
                        self.fmt(v, max_w, f)?;
                        write!(f, ", ")?;
                    }
                }
                write!(f, "...")?;
                if let Ok(vs) = t
                    .narrow(0, v - edge_items, edge_items)
                    .and_then(|t| t.to_vec1::<Self::Elem>())
                {
                    for v in vs.into_iter() {
                        write!(f, ", ")?;
                        self.fmt(v, max_w, f)?;
                    }
                }
            }
            // Rank 1, full: wrap so each line stays under `po.line_width`
            // (each element costs `max_w` chars plus a ", " separator).
            [_] => {
                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
                if let Ok(vs) = t.to_vec1::<Self::Elem>() {
                    for (i, v) in vs.into_iter().enumerate() {
                        if i > 0 {
                            if i % elements_per_line == 0 {
                                write!(f, ",")?;
                                Self::write_newline_indent(indent, f)?
                            } else {
                                write!(f, ", ")?;
                            }
                        }
                        self.fmt(v, max_w, f)?
                    }
                }
            }
            // Rank >= 2: recurse over the outermost dimension, one row per
            // line, optionally eliding the middle rows.
            _ => {
                if summarize && dims[0] > 2 * edge_items {
                    for i in 0..edge_items {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        write!(f, ",")?;
                        Self::write_newline_indent(indent, f)?
                    }
                    write!(f, "...")?;
                    Self::write_newline_indent(indent, f)?;
                    for i in dims[0] - edge_items..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                } else {
                    for i in 0..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                }
            }
        }
        write!(f, "]")?;
        Ok(())
    }
}
298
/// Formatter for float dtypes; the display mode is chosen once per tensor
/// by `FloatFormatter::new`.
struct FloatFormatter<S: WithDType> {
    /// All values are integral: print as `123.` with no fractional digits.
    int_mode: bool,
    /// Print values in scientific notation.
    sci_mode: bool,
    /// Digits after the decimal point (from `PrinterOptions::precision`).
    precision: usize,
    /// Zero-sized marker tying the formatter to its element type `S`.
    _phantom: std::marker::PhantomData<S>,
}
305
306impl<S> FloatFormatter<S>
307where
308 S: WithDType + num_traits::Float + std::fmt::Display,
309{
310 fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
311 let mut int_mode = true;
312 let mut sci_mode = false;
313
314 let values = t
317 .flatten_all()?
318 .to_vec1()?
319 .into_iter()
320 .filter(|v: &S| v.is_finite() && !v.is_zero())
321 .collect::<Vec<_>>();
322 if !values.is_empty() {
323 let mut nonzero_finite_min = S::max_value();
324 let mut nonzero_finite_max = S::min_value();
325 for &v in values.iter() {
326 let v = v.abs();
327 if v < nonzero_finite_min {
328 nonzero_finite_min = v
329 }
330 if v > nonzero_finite_max {
331 nonzero_finite_max = v
332 }
333 }
334
335 for &value in values.iter() {
336 if value.ceil() != value {
337 int_mode = false;
338 break;
339 }
340 }
341 if let Some(v1) = S::from(1000.) {
342 if let Some(v2) = S::from(1e8) {
343 if let Some(v3) = S::from(1e-4) {
344 sci_mode = nonzero_finite_max / nonzero_finite_min > v1
345 || nonzero_finite_max > v2
346 || nonzero_finite_min < v3
347 }
348 }
349 }
350 }
351
352 match po.sci_mode {
353 None => {}
354 Some(v) => sci_mode = v,
355 }
356 Ok(Self {
357 int_mode,
358 sci_mode,
359 precision: po.precision,
360 _phantom: std::marker::PhantomData,
361 })
362 }
363}
364
impl<S> TensorFormatter for FloatFormatter<S>
where
    S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
{
    type Elem = S;

    // Renders one float honoring the mode chosen in `new`: scientific,
    // integer-like (`123.`), or fixed-point with `precision` digits.
    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
        if self.sci_mode {
            write!(
                f,
                "{v:width$.prec$e}",
                v = v,
                width = max_w,
                prec = self.precision
            )
        } else if self.int_mode {
            if v.is_finite() {
                // Width is reduced by one to make room for the trailing '.'.
                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
            } else {
                // inf/NaN have no trailing point; use the full width.
                write!(f, "{v:max_w$.0}")
            }
        } else {
            write!(
                f,
                "{v:width$.prec$}",
                v = v,
                width = max_w,
                prec = self.precision
            )
        }
    }
}
397
/// Formatter for integer dtypes: plain right-aligned `Display` output.
struct IntFormatter<S: WithDType> {
    /// Zero-sized marker tying the formatter to its element type `S`.
    _phantom: std::marker::PhantomData<S>,
}

impl<S: WithDType> IntFormatter<S> {
    // Creates the (stateless) integer formatter.
    fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
409
impl<S> TensorFormatter for IntFormatter<S>
where
    S: WithDType + std::fmt::Display,
{
    type Elem = S;

    // Right-aligns the value in a field of `max_w` characters.
    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
        write!(f, "{v:max_w$}")
    }
}
420
421fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
422 let dims = t.dims();
423 if dims.is_empty() {
424 Ok(t.clone())
425 } else if dims.len() == 1 {
426 if dims[0] > 2 * edge_items {
427 Tensor::cat(
428 &[
429 t.narrow(0, 0, edge_items)?,
430 t.narrow(0, dims[0] - edge_items, edge_items)?,
431 ],
432 0,
433 )
434 } else {
435 Ok(t.clone())
436 }
437 } else if dims[0] > 2 * edge_items {
438 let mut vs: Vec<_> = (0..edge_items)
439 .map(|i| get_summarized_data(&t.get(i)?, edge_items))
440 .collect::<Result<Vec<_>>>()?;
441 for i in (dims[0] - edge_items)..dims[0] {
442 vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
443 }
444 Tensor::cat(&vs, 0)
445 } else {
446 let vs: Vec<_> = (0..dims[0])
447 .map(|i| get_summarized_data(&t.get(i)?, edge_items))
448 .collect::<Result<Vec<_>>>()?;
449 Tensor::cat(&vs, 0)
450 }
451}
452
453impl std::fmt::Display for Tensor {
454 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
455 let po = PRINT_OPTS.lock().unwrap();
456 let summarize = self.elem_count() > po.threshold;
457 let to_display = if summarize {
458 match get_summarized_data(self, po.edge_items) {
459 Ok(v) => v,
460 Err(err) => return write!(f, "{err:?}"),
461 }
462 } else {
463 self.clone()
464 };
465 match self.dtype() {
466 DType::U8 => {
467 let tf: IntFormatter<u8> = IntFormatter::new();
468 let max_w = tf.max_width(&to_display);
469 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
470 writeln!(f)?;
471 }
472 DType::U32 => {
473 let tf: IntFormatter<u32> = IntFormatter::new();
474 let max_w = tf.max_width(&to_display);
475 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
476 writeln!(f)?;
477 }
478 DType::I16 => {
479 let tf: IntFormatter<i16> = IntFormatter::new();
480 let max_w = tf.max_width(&to_display);
481 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
482 writeln!(f)?;
483 }
484 DType::I32 => {
485 let tf: IntFormatter<i32> = IntFormatter::new();
486 let max_w = tf.max_width(&to_display);
487 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
488 writeln!(f)?;
489 }
490 DType::I64 => {
491 let tf: IntFormatter<i64> = IntFormatter::new();
492 let max_w = tf.max_width(&to_display);
493 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
494 writeln!(f)?;
495 }
496 DType::BF16 => {
497 if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
498 let max_w = tf.max_width(&to_display);
499 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
500 writeln!(f)?;
501 }
502 }
503 DType::F16 => {
504 if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
505 let max_w = tf.max_width(&to_display);
506 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
507 writeln!(f)?;
508 }
509 }
510 DType::F64 => {
511 if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
512 let max_w = tf.max_width(&to_display);
513 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
514 writeln!(f)?;
515 }
516 }
517 DType::F32 => {
518 if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
519 let max_w = tf.max_width(&to_display);
520 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
521 writeln!(f)?;
522 }
523 }
524 DType::F8E4M3 => {
525 if let Ok(tf) = FloatFormatter::<float8::F8E4M3>::new(&to_display, &po) {
526 let max_w = tf.max_width(&to_display);
527 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
528 writeln!(f)?;
529 }
530 }
531 DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
532 writeln!(
533 f,
534 "Dummy type {} (not supported for display)",
535 self.dtype().as_str()
536 )?;
537 }
538 };
539
540 let device_str = match self.device().location() {
541 crate::DeviceLocation::Cpu => "".to_owned(),
542 crate::DeviceLocation::Cuda { gpu_id } => {
543 format!(", cuda:{gpu_id}")
544 }
545 crate::DeviceLocation::Metal { gpu_id } => {
546 format!(", metal:{gpu_id}")
547 }
548 };
549
550 write!(
551 f,
552 "Tensor[{:?}, {}{}]",
553 self.dims(),
554 self.dtype().as_str(),
555 device_str
556 )
557 }
558}