1use crate::{DType, Result, Tensor, WithDType};
5use half::{bf16, f16};
6
7impl Tensor {
8 fn fmt_dt<T: WithDType + std::fmt::Display>(
9 &self,
10 f: &mut std::fmt::Formatter,
11 ) -> std::fmt::Result {
12 let device_str = match self.device().location() {
13 crate::DeviceLocation::Cpu => "".to_owned(),
14 crate::DeviceLocation::Cuda { gpu_id } => {
15 format!(", cuda:{}", gpu_id)
16 }
17 };
18
19 write!(f, "Tensor[")?;
20 match self.dims() {
21 [] => {
22 if let Ok(v) = self.to_scalar::<T>() {
23 write!(f, "{v}")?
24 }
25 }
26 [s] if *s < 10 => {
27 if let Ok(vs) = self.to_vec1::<T>() {
28 for (i, v) in vs.iter().enumerate() {
29 if i > 0 {
30 write!(f, ", ")?;
31 }
32 write!(f, "{v}")?;
33 }
34 }
35 }
36 dims => {
37 write!(f, "dims ")?;
38 for (i, d) in dims.iter().enumerate() {
39 if i > 0 {
40 write!(f, ", ")?;
41 }
42 write!(f, "{d}")?;
43 }
44 }
45 }
46 write!(f, "; {}{}]", self.dtype().as_str(), device_str)
47 }
48}
49
50impl std::fmt::Debug for Tensor {
51 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
52 match self.dtype() {
53 DType::U8 => self.fmt_dt::<u8>(f),
54 DType::U32 => self.fmt_dt::<u32>(f),
55 DType::I64 => self.fmt_dt::<i64>(f),
56 DType::BF16 => self.fmt_dt::<bf16>(f),
57 DType::F16 => self.fmt_dt::<f16>(f),
58 DType::F32 => self.fmt_dt::<f32>(f),
59 DType::F64 => self.fmt_dt::<f64>(f),
60 }
61 }
62}
63
/// Options controlling how the `Display` implementation of `Tensor` lays out values.
///
/// The fields are public so that callers outside this module can build a value
/// to pass to `set_print_options`.
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    /// Digits printed after the decimal point for floating-point values.
    pub precision: usize,
    /// Tensors with more elements than this are summarized.
    pub threshold: usize,
    /// Entries kept at each end of a dimension when summarizing.
    pub edge_items: usize,
    /// Line width used to decide how many entries of a 1-D row fit on one line.
    pub line_width: usize,
    /// `Some(true)` forces scientific notation for floats and `Some(false)` disables it;
    /// `None` lets the formatter choose from the displayed values.
    pub sci_mode: Option<bool>,
}

/// Process-wide print options read by the `Display` implementation of `Tensor`.
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());

impl PrinterOptions {
    /// Default options as a `const fn`, so they can initialise the `PRINT_OPTS` static.
    const fn const_default() -> Self {
        Self {
            precision: 4,
            threshold: 1000,
            edge_items: 3,
            line_width: 80,
            sci_mode: None,
        }
    }
}

impl Default for PrinterOptions {
    /// Same values as `PrinterOptions::const_default`.
    ///
    /// This lets callers write `PrinterOptions { precision: 2, ..Default::default() }`.
    fn default() -> Self {
        Self::const_default()
    }
}
88
89pub fn set_print_options(options: PrinterOptions) {
90 *PRINT_OPTS.lock().unwrap() = options
91}
92
93pub fn set_print_options_default() {
94 *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
95}
96
97pub fn set_print_options_short() {
98 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
99 precision: 2,
100 threshold: 1000,
101 edge_items: 2,
102 line_width: 80,
103 sci_mode: None,
104 }
105}
106
107pub fn set_print_options_full() {
108 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
109 precision: 4,
110 threshold: usize::MAX,
111 edge_items: 3,
112 line_width: 80,
113 sci_mode: None,
114 }
115}
116
/// A `std::fmt::Write` sink that discards its input and only counts how many
/// characters were written. It is used to measure the width of formatted elements.
struct FmtSize {
    current_size: usize,
}

impl FmtSize {
    /// Creates a counter starting at zero.
    fn new() -> Self {
        Self { current_size: 0 }
    }

    /// Consumes the counter and returns the number of characters written to it.
    fn final_size(self) -> usize {
        self.current_size
    }
}

impl std::fmt::Write for FmtSize {
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        // Width specifiers such as `{:8}` pad by character count, not byte length.
        // Measuring in the same unit keeps columns aligned even for non-ASCII output.
        self.current_size += s.chars().count();
        Ok(())
    }
}
137
138trait TensorFormatter {
139 type Elem: WithDType;
140
141 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
142
143 fn max_width(&self, to_display: &Tensor) -> usize {
144 let mut max_width = 1;
145 if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
146 for &v in vs.iter() {
147 let mut fmt_size = FmtSize::new();
148 let _res = self.fmt(v, 1, &mut fmt_size);
149 max_width = usize::max(max_width, fmt_size.final_size())
150 }
151 }
152 max_width
153 }
154
155 fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
156 writeln!(f)?;
157 for _ in 0..i {
158 write!(f, " ")?
159 }
160 Ok(())
161 }
162
163 fn fmt_tensor(
164 &self,
165 t: &Tensor,
166 indent: usize,
167 max_w: usize,
168 summarize: bool,
169 po: &PrinterOptions,
170 f: &mut std::fmt::Formatter,
171 ) -> std::fmt::Result {
172 let dims = t.dims();
173 let edge_items = po.edge_items;
174 write!(f, "[")?;
175 match dims {
176 [] => {
177 if let Ok(v) = t.to_scalar::<Self::Elem>() {
178 self.fmt(v, max_w, f)?
179 }
180 }
181 [v] if summarize && *v > 2 * edge_items => {
182 if let Ok(vs) = t
183 .narrow(0, 0, edge_items)
184 .and_then(|t| t.to_vec1::<Self::Elem>())
185 {
186 for v in vs.into_iter() {
187 self.fmt(v, max_w, f)?;
188 write!(f, ", ")?;
189 }
190 }
191 write!(f, "...")?;
192 if let Ok(vs) = t
193 .narrow(0, v - edge_items, edge_items)
194 .and_then(|t| t.to_vec1::<Self::Elem>())
195 {
196 for v in vs.into_iter() {
197 write!(f, ", ")?;
198 self.fmt(v, max_w, f)?;
199 }
200 }
201 }
202 [_] => {
203 let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
204 if let Ok(vs) = t.to_vec1::<Self::Elem>() {
205 for (i, v) in vs.into_iter().enumerate() {
206 if i > 0 {
207 if i % elements_per_line == 0 {
208 write!(f, ",")?;
209 Self::write_newline_indent(indent, f)?
210 } else {
211 write!(f, ", ")?;
212 }
213 }
214 self.fmt(v, max_w, f)?
215 }
216 }
217 }
218 _ => {
219 if summarize && dims[0] > 2 * edge_items {
220 for i in 0..edge_items {
221 match t.get(i) {
222 Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
223 Err(e) => write!(f, "{e:?}")?,
224 }
225 write!(f, ",")?;
226 Self::write_newline_indent(indent, f)?
227 }
228 write!(f, "...")?;
229 Self::write_newline_indent(indent, f)?;
230 for i in dims[0] - edge_items..dims[0] {
231 match t.get(i) {
232 Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
233 Err(e) => write!(f, "{e:?}")?,
234 }
235 if i + 1 != dims[0] {
236 write!(f, ",")?;
237 Self::write_newline_indent(indent, f)?
238 }
239 }
240 } else {
241 for i in 0..dims[0] {
242 match t.get(i) {
243 Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
244 Err(e) => write!(f, "{e:?}")?,
245 }
246 if i + 1 != dims[0] {
247 write!(f, ",")?;
248 Self::write_newline_indent(indent, f)?
249 }
250 }
251 }
252 }
253 }
254 write!(f, "]")?;
255 Ok(())
256 }
257}
258
/// Element formatter for floating-point tensors.
///
/// `S` is only tracked through `PhantomData`, so the struct carries no trait bounds.
/// The `impl` blocks that use `S` state the bounds they need.
struct FloatFormatter<S> {
    /// Print integral values as `3.` rather than with `precision` decimals.
    int_mode: bool,
    /// Print values in scientific notation.
    sci_mode: bool,
    /// Digits after the decimal point in fixed and scientific notation.
    precision: usize,
    _phantom: std::marker::PhantomData<S>,
}
265
266impl<S> FloatFormatter<S>
267where
268 S: WithDType + num_traits::Float + std::fmt::Display,
269{
270 fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
271 let mut int_mode = true;
272 let mut sci_mode = false;
273
274 let values = t
277 .flatten_all()?
278 .to_vec1()?
279 .into_iter()
280 .filter(|v: &S| v.is_finite() && !v.is_zero())
281 .collect::<Vec<_>>();
282 if !values.is_empty() {
283 let mut nonzero_finite_min = S::max_value();
284 let mut nonzero_finite_max = S::min_value();
285 for &v in values.iter() {
286 let v = v.abs();
287 if v < nonzero_finite_min {
288 nonzero_finite_min = v
289 }
290 if v > nonzero_finite_max {
291 nonzero_finite_max = v
292 }
293 }
294
295 for &value in values.iter() {
296 if value.ceil() != value {
297 int_mode = false;
298 break;
299 }
300 }
301 if let Some(v1) = S::from(1000.) {
302 if let Some(v2) = S::from(1e8) {
303 if let Some(v3) = S::from(1e-4) {
304 sci_mode = nonzero_finite_max / nonzero_finite_min > v1
305 || nonzero_finite_max > v2
306 || nonzero_finite_min < v3
307 }
308 }
309 }
310 }
311
312 match po.sci_mode {
313 None => {}
314 Some(v) => sci_mode = v,
315 }
316 Ok(Self {
317 int_mode,
318 sci_mode,
319 precision: po.precision,
320 _phantom: std::marker::PhantomData,
321 })
322 }
323}
324
325impl<S> TensorFormatter for FloatFormatter<S>
326where
327 S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
328{
329 type Elem = S;
330
331 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
332 if self.sci_mode {
333 write!(
334 f,
335 "{v:width$.prec$e}",
336 v = v,
337 width = max_w,
338 prec = self.precision
339 )
340 } else if self.int_mode {
341 if v.is_finite() {
342 write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
343 } else {
344 write!(f, "{v:max_w$.0}")
345 }
346 } else {
347 write!(
348 f,
349 "{v:width$.prec$}",
350 v = v,
351 width = max_w,
352 prec = self.precision
353 )
354 }
355 }
356}
357
/// Element formatter for integer tensors.
///
/// `S` is only a type marker, so neither the struct nor its constructor needs trait
/// bounds. The formatting impl states the bounds it uses.
struct IntFormatter<S> {
    _phantom: std::marker::PhantomData<S>,
}

impl<S> IntFormatter<S> {
    /// Creates the (stateless) formatter.
    fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
369
370impl<S> TensorFormatter for IntFormatter<S>
371where
372 S: WithDType + std::fmt::Display,
373{
374 type Elem = S;
375
376 fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
377 write!(f, "{v:max_w$}")
378 }
379}
380
381fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
382 let dims = t.dims();
383 if dims.is_empty() {
384 Ok(t.clone())
385 } else if dims.len() == 1 {
386 if dims[0] > 2 * edge_items {
387 Tensor::cat(
388 &[
389 t.narrow(0, 0, edge_items)?,
390 t.narrow(0, dims[0] - edge_items, edge_items)?,
391 ],
392 0,
393 )
394 } else {
395 Ok(t.clone())
396 }
397 } else if dims[0] > 2 * edge_items {
398 let mut vs: Vec<_> = (0..edge_items)
399 .map(|i| get_summarized_data(&t.get(i)?, edge_items))
400 .collect::<Result<Vec<_>>>()?;
401 for i in (dims[0] - edge_items)..dims[0] {
402 vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
403 }
404 Tensor::cat(&vs, 0)
405 } else {
406 let vs: Vec<_> = (0..dims[0])
407 .map(|i| get_summarized_data(&t.get(i)?, edge_items))
408 .collect::<Result<Vec<_>>>()?;
409 Tensor::cat(&vs, 0)
410 }
411}
412
413impl std::fmt::Display for Tensor {
414 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
415 let po = PRINT_OPTS.lock().unwrap();
416 let summarize = self.elem_count() > po.threshold;
417 let to_display = if summarize {
418 match get_summarized_data(self, po.edge_items) {
419 Ok(v) => v,
420 Err(err) => return write!(f, "{err:?}"),
421 }
422 } else {
423 self.clone()
424 };
425 match self.dtype() {
426 DType::U8 => {
427 let tf: IntFormatter<u8> = IntFormatter::new();
428 let max_w = tf.max_width(&to_display);
429 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
430 writeln!(f)?;
431 }
432 DType::U32 => {
433 let tf: IntFormatter<u32> = IntFormatter::new();
434 let max_w = tf.max_width(&to_display);
435 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
436 writeln!(f)?;
437 }
438 DType::I64 => {
439 let tf: IntFormatter<i64> = IntFormatter::new();
440 let max_w = tf.max_width(&to_display);
441 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
442 writeln!(f)?;
443 }
444 DType::BF16 => {
445 if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
446 let max_w = tf.max_width(&to_display);
447 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
448 writeln!(f)?;
449 }
450 }
451 DType::F16 => {
452 if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
453 let max_w = tf.max_width(&to_display);
454 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
455 writeln!(f)?;
456 }
457 }
458 DType::F64 => {
459 if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
460 let max_w = tf.max_width(&to_display);
461 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
462 writeln!(f)?;
463 }
464 }
465 DType::F32 => {
466 if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
467 let max_w = tf.max_width(&to_display);
468 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
469 writeln!(f)?;
470 }
471 }
472 };
473
474 let device_str = match self.device().location() {
475 crate::DeviceLocation::Cpu => "".to_owned(),
476 crate::DeviceLocation::Cuda { gpu_id } => {
477 format!(", cuda:{}", gpu_id)
478 }
479 };
480
481 write!(
482 f,
483 "Tensor[{:?}, {}{}]",
484 self.dims(),
485 self.dtype().as_str(),
486 device_str
487 )
488 }
489}