use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};

impl Tensor {
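    /// Writes the compact `Tensor[..]` summary used by `Debug`: the values for
    /// scalars and vectors of fewer than 10 elements, otherwise just the
    /// dimensions, followed by the dtype and, for CUDA/Metal tensors, the device.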
    fn fmt_dt<T: WithDType + std::fmt::Display>(
        &self,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let device_str = match self.device().location() {
            crate::DeviceLocation::Cpu => "".to_owned(),
            crate::DeviceLocation::Cuda { gpu_id } => {
                format!(", cuda:{}", gpu_id)
            }
            crate::DeviceLocation::Metal { gpu_id } => {
                format!(", metal:{}", gpu_id)
            }
        };

        write!(f, "Tensor[")?;
        match self.dims() {
            [] => {
                if let Ok(v) = self.to_scalar::<T>() {
                    write!(f, "{v}")?
                }
            }
            [s] if *s < 10 => {
                if let Ok(vs) = self.to_vec1::<T>() {
                    for (i, v) in vs.iter().enumerate() {
                        if i > 0 {
                            write!(f, ", ")?;
                        }
                        write!(f, "{v}")?;
                    }
                }
            }
            dims => {
                write!(f, "dims ")?;
                for (i, d) in dims.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{d}")?;
                }
            }
        }
        write!(f, "; {}{}]", self.dtype().as_str(), device_str)
    }
}

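// The `Debug` output is deliberately compact: values only for scalars and short
// vectors, otherwise just the dimensions, plus dtype and device. Use `Display`
// for the full contents.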
impl std::fmt::Debug for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self.dtype() {
            DType::U8 => self.fmt_dt::<u8>(f),
            DType::U32 => self.fmt_dt::<u32>(f),
            DType::I64 => self.fmt_dt::<i64>(f),
            DType::BF16 => self.fmt_dt::<bf16>(f),
            DType::F16 => self.fmt_dt::<f16>(f),
            DType::F32 => self.fmt_dt::<f32>(f),
            DType::F64 => self.fmt_dt::<f64>(f),
        }
    }
}

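/// Options controlling how tensor contents are rendered by the `Display`
/// implementation below, similar in spirit to PyTorch's `torch.set_printoptions`:
/// number of printed digits, summarization threshold, edge items kept when
/// summarizing, target line width, and whether scientific notation is forced.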
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    pub precision: usize,
    pub threshold: usize,
    pub edge_items: usize,
    pub line_width: usize,
    pub sci_mode: Option<bool>,
}

static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());

impl PrinterOptions {
    const fn const_default() -> Self {
        Self {
            precision: 4,
            threshold: 1000,
            edge_items: 3,
            line_width: 80,
            sci_mode: None,
        }
    }
}

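/// Gives access to the global, mutex-protected print options so that individual
/// fields can be read or modified in place.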
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
    &PRINT_OPTS
}

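/// Replaces the global print options used by the `Display` implementation.
///
/// A minimal usage sketch (assuming these helpers are re-exported where the
/// caller can reach them; the exact path is not shown in this file):
///
/// ```ignore
/// set_print_options(PrinterOptions {
///     precision: 3,
///     threshold: 500,
///     edge_items: 2,
///     line_width: 100,
///     sci_mode: Some(false),
/// });
/// println!("{tensor}"); // subsequent `Display` calls pick up the new options
/// ```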
pub fn set_print_options(options: PrinterOptions) {
    *PRINT_OPTS.lock().unwrap() = options
}

pub fn set_print_options_default() {
    *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
}

pub fn set_print_options_short() {
    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
        precision: 2,
        threshold: 1000,
        edge_items: 2,
        line_width: 80,
        sci_mode: None,
    }
}

pub fn set_print_options_full() {
    *PRINT_OPTS.lock().unwrap() = PrinterOptions {
        precision: 4,
        threshold: usize::MAX,
        edge_items: 3,
        line_width: 80,
        sci_mode: None,
    }
}

pub fn set_line_width(line_width: usize) {
    PRINT_OPTS.lock().unwrap().line_width = line_width
}

pub fn set_precision(precision: usize) {
    PRINT_OPTS.lock().unwrap().precision = precision
}

pub fn set_edge_items(edge_items: usize) {
    PRINT_OPTS.lock().unwrap().edge_items = edge_items
}

pub fn set_threshold(threshold: usize) {
    PRINT_OPTS.lock().unwrap().threshold = threshold
}

pub fn set_sci_mode(sci_mode: Option<bool>) {
    PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
}

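/// A `std::fmt::Write` sink that discards its input and only counts the bytes
/// written; used to measure how many columns each formatted element needs.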
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    fn new() -> Self {
        Self { current_size: 0 }
    }

    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.current_size += s.len();
        Ok(())
    }
}

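/// Common driver for the per-dtype formatters below: `fmt` renders one element,
/// `max_width` measures the widest element of the (possibly summarized) tensor,
/// and `fmt_tensor` walks the tensor recursively, handling line wrapping,
/// indentation, and the `...` ellipses used when summarizing.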
trait TensorFormatter {
    type Elem: WithDType;

    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;

    fn max_width(&self, to_display: &Tensor) -> usize {
        let mut max_width = 1;
        if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
            for &v in vs.iter() {
                let mut fmt_size = FmtSize::new();
                let _res = self.fmt(v, 1, &mut fmt_size);
                max_width = usize::max(max_width, fmt_size.final_size())
            }
        }
        max_width
    }

    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        writeln!(f)?;
        for _ in 0..i {
            write!(f, " ")?
        }
        Ok(())
    }

    fn fmt_tensor(
        &self,
        t: &Tensor,
        indent: usize,
        max_w: usize,
        summarize: bool,
        po: &PrinterOptions,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let dims = t.dims();
        let edge_items = po.edge_items;
        write!(f, "[")?;
        match dims {
            [] => {
                if let Ok(v) = t.to_scalar::<Self::Elem>() {
                    self.fmt(v, max_w, f)?
                }
            }
            [v] if summarize && *v > 2 * edge_items => {
                if let Ok(vs) = t
                    .narrow(0, 0, edge_items)
                    .and_then(|t| t.to_vec1::<Self::Elem>())
                {
                    for v in vs.into_iter() {
                        self.fmt(v, max_w, f)?;
                        write!(f, ", ")?;
                    }
                }
                write!(f, "...")?;
                if let Ok(vs) = t
                    .narrow(0, v - edge_items, edge_items)
                    .and_then(|t| t.to_vec1::<Self::Elem>())
                {
                    for v in vs.into_iter() {
                        write!(f, ", ")?;
                        self.fmt(v, max_w, f)?;
                    }
                }
            }
            [_] => {
                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
                if let Ok(vs) = t.to_vec1::<Self::Elem>() {
                    for (i, v) in vs.into_iter().enumerate() {
                        if i > 0 {
                            if i % elements_per_line == 0 {
                                write!(f, ",")?;
                                Self::write_newline_indent(indent, f)?
                            } else {
                                write!(f, ", ")?;
                            }
                        }
                        self.fmt(v, max_w, f)?
                    }
                }
            }
            _ => {
                if summarize && dims[0] > 2 * edge_items {
                    for i in 0..edge_items {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        write!(f, ",")?;
                        Self::write_newline_indent(indent, f)?
                    }
                    write!(f, "...")?;
                    Self::write_newline_indent(indent, f)?;
                    for i in dims[0] - edge_items..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                } else {
                    for i in 0..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                }
            }
        }
        write!(f, "]")?;
        Ok(())
    }
}

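/// Formatter for the floating point dtypes. `new` scans the finite, non-zero
/// values to pick a mode: integer-style output when every value is a whole
/// number, scientific notation when magnitudes are very large, very small, or
/// span a wide dynamic range, and fixed-point output otherwise. An explicit
/// `sci_mode` in the print options overrides the heuristic.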
struct FloatFormatter<S: WithDType> {
    int_mode: bool,
    sci_mode: bool,
    precision: usize,
    _phantom: std::marker::PhantomData<S>,
}

impl<S> FloatFormatter<S>
where
    S: WithDType + num_traits::Float + std::fmt::Display,
{
    fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
        let mut int_mode = true;
        let mut sci_mode = false;

        let values = t
            .flatten_all()?
            .to_vec1()?
            .into_iter()
            .filter(|v: &S| v.is_finite() && !v.is_zero())
            .collect::<Vec<_>>();
        if !values.is_empty() {
            let mut nonzero_finite_min = S::max_value();
            let mut nonzero_finite_max = S::min_value();
            for &v in values.iter() {
                let v = v.abs();
                if v < nonzero_finite_min {
                    nonzero_finite_min = v
                }
                if v > nonzero_finite_max {
                    nonzero_finite_max = v
                }
            }

            for &value in values.iter() {
                if value.ceil() != value {
                    int_mode = false;
                    break;
                }
            }
            if let Some(v1) = S::from(1000.) {
                if let Some(v2) = S::from(1e8) {
                    if let Some(v3) = S::from(1e-4) {
                        sci_mode = nonzero_finite_max / nonzero_finite_min > v1
                            || nonzero_finite_max > v2
                            || nonzero_finite_min < v3
                    }
                }
            }
        }

        match po.sci_mode {
            None => {}
            Some(v) => sci_mode = v,
        }
        Ok(Self {
            int_mode,
            sci_mode,
            precision: po.precision,
            _phantom: std::marker::PhantomData,
        })
    }
}

impl<S> TensorFormatter for FloatFormatter<S>
where
    S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
{
    type Elem = S;

    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
        if self.sci_mode {
            write!(
                f,
                "{v:width$.prec$e}",
                v = v,
                width = max_w,
                prec = self.precision
            )
        } else if self.int_mode {
            if v.is_finite() {
                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
            } else {
                write!(f, "{v:max_w$.0}")
            }
        } else {
            write!(
                f,
                "{v:width$.prec$}",
                v = v,
                width = max_w,
                prec = self.precision
            )
        }
    }
}

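/// Formatter for the integer dtypes: elements are simply right-aligned to the
/// width of the widest value.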
struct IntFormatter<S: WithDType> {
    _phantom: std::marker::PhantomData<S>,
}

impl<S: WithDType> IntFormatter<S> {
    fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<S> TensorFormatter for IntFormatter<S>
where
    S: WithDType + std::fmt::Display,
{
    type Elem = S;

    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
        write!(f, "{v:max_w$}")
    }
}

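/// Builds the reduced tensor used when summarizing: along the leading dimension
/// (recursively), only the first and last `edge_items` slices are kept, so the
/// displayed width can be computed from the values that will actually be shown.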
fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
    let dims = t.dims();
    if dims.is_empty() {
        Ok(t.clone())
    } else if dims.len() == 1 {
        if dims[0] > 2 * edge_items {
            Tensor::cat(
                &[
                    t.narrow(0, 0, edge_items)?,
                    t.narrow(0, dims[0] - edge_items, edge_items)?,
                ],
                0,
            )
        } else {
            Ok(t.clone())
        }
    } else if dims[0] > 2 * edge_items {
        let mut vs: Vec<_> = (0..edge_items)
            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
            .collect::<Result<Vec<_>>>()?;
        for i in (dims[0] - edge_items)..dims[0] {
            vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
        }
        Tensor::cat(&vs, 0)
    } else {
        let vs: Vec<_> = (0..dims[0])
            .map(|i| get_summarized_data(&t.get(i)?, edge_items))
            .collect::<Result<Vec<_>>>()?;
        Tensor::cat(&vs, 0)
    }
}

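// The `Display` implementation prints the tensor contents, summarized with `...`
// once `elem_count` exceeds the configured `threshold`, followed by a trailing
// `Tensor[...]` summary line with the shape, dtype, and device.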
impl std::fmt::Display for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let po = PRINT_OPTS.lock().unwrap();
        let summarize = self.elem_count() > po.threshold;
        let to_display = if summarize {
            match get_summarized_data(self, po.edge_items) {
                Ok(v) => v,
                Err(err) => return write!(f, "{err:?}"),
            }
        } else {
            self.clone()
        };
        match self.dtype() {
            DType::U8 => {
                let tf: IntFormatter<u8> = IntFormatter::new();
                let max_w = tf.max_width(&to_display);
                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                writeln!(f)?;
            }
            DType::U32 => {
                let tf: IntFormatter<u32> = IntFormatter::new();
                let max_w = tf.max_width(&to_display);
                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                writeln!(f)?;
            }
            DType::I64 => {
                let tf: IntFormatter<i64> = IntFormatter::new();
                let max_w = tf.max_width(&to_display);
                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                writeln!(f)?;
            }
            DType::BF16 => {
                if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
                    let max_w = tf.max_width(&to_display);
                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                    writeln!(f)?;
                }
            }
            DType::F16 => {
                if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
                    let max_w = tf.max_width(&to_display);
                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                    writeln!(f)?;
                }
            }
            DType::F64 => {
                if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
                    let max_w = tf.max_width(&to_display);
                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                    writeln!(f)?;
                }
            }
            DType::F32 => {
                if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
                    let max_w = tf.max_width(&to_display);
                    tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                    writeln!(f)?;
                }
            }
        };

        let device_str = match self.device().location() {
            crate::DeviceLocation::Cpu => "".to_owned(),
            crate::DeviceLocation::Cuda { gpu_id } => {
                format!(", cuda:{}", gpu_id)
            }
            crate::DeviceLocation::Metal { gpu_id } => {
                format!(", metal:{}", gpu_id)
            }
        };

        write!(
            f,
            "Tensor[{:?}, {}{}]",
            self.dims(),
            self.dtype().as_str(),
            device_str
        )
    }
}