1use crate::{FloatCategory, FloatDType, IntCategory, IntDType, NumCategory, NumDType, Result, Tensor, WithDType};
6
// Compact `Debug` rendering: `Tensor[<contents or dims>; <dtype>]`.
// Scalars and short (< 10 element) vectors print their values inline;
// anything larger prints only its dimensions. Meta tensors (no backing
// data) print `Tensor[Meta; dims, dtype]`.
impl<T: WithDType> std::fmt::Debug for Tensor<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Tensor[")?;
        // Meta tensors carry no data, so only shape/dtype can be shown.
        if self.is_meta() {
            return write!(f, "Meta; {:?}, {}]", self.dims(), self.dtype())
        }
        match self.dims() {
            // Rank 0: print the single scalar (read errors are silently skipped).
            [] => {
                if let Ok(v) = self.to_scalar() {
                    write!(f, "{v}")?
                }
            }
            // Short rank-1 tensors: print every element, comma separated.
            // `expect` is safe here: the meta case returned above.
            [s] if *s < 10 => {
                for (i, v) in self.to_vec().expect("Meta Tensor").iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{v}")?;
                }
            }
            // Everything else: just list the dimensions.
            dims => {
                write!(f, "dims ")?;
                for (i, d) in dims.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{d}")?;
                }
            }
        }
        write!(f, "; {}]", self.dtype())
    }
}
44
/// Options controlling how tensors are rendered by `Display`.
///
/// The knobs mirror PyTorch's `set_printoptions`.
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    /// Digits printed after the decimal point for float elements.
    pub precision: usize,
    /// Element count above which output is summarized with `...`.
    pub threshold: usize,
    /// Leading/trailing items shown per dimension when summarizing.
    pub edge_items: usize,
    /// Target maximum characters per printed line.
    pub line_width: usize,
    /// Force scientific notation on (`Some(true)`), off (`Some(false)`),
    /// or decide automatically from the data (`None`).
    pub sci_mode: Option<bool>,
}
57
// Process-wide printing options, shared by all `Display` rendering below.
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
    std::sync::Mutex::new(PrinterOptions::const_default());
60
61impl PrinterOptions {
62 const fn const_default() -> Self {
64 Self {
65 precision: 4,
66 threshold: 1000,
67 edge_items: 3,
68 line_width: 80,
69 sci_mode: None,
70 }
71 }
72}
73
/// Returns the global printing-options mutex, for callers that want to
/// inspect or tweak individual fields directly.
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
    &PRINT_OPTS
}
77
78pub fn set_print_options(options: PrinterOptions) {
79 *PRINT_OPTS.lock().unwrap() = options
80}
81
82pub fn set_print_options_default() {
83 *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
84}
85
86pub fn set_print_options_short() {
87 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
88 precision: 2,
89 threshold: 1000,
90 edge_items: 2,
91 line_width: 80,
92 sci_mode: None,
93 }
94}
95
96pub fn set_print_options_full() {
97 *PRINT_OPTS.lock().unwrap() = PrinterOptions {
98 precision: 4,
99 threshold: usize::MAX,
100 edge_items: 3,
101 line_width: 80,
102 sci_mode: None,
103 }
104}
105
106pub fn set_line_width(line_width: usize) {
107 PRINT_OPTS.lock().unwrap().line_width = line_width
108}
109
110pub fn set_precision(precision: usize) {
111 PRINT_OPTS.lock().unwrap().precision = precision
112}
113
114pub fn set_edge_items(edge_items: usize) {
115 PRINT_OPTS.lock().unwrap().edge_items = edge_items
116}
117
118pub fn set_threshold(threshold: usize) {
119 PRINT_OPTS.lock().unwrap().threshold = threshold
120}
121
122pub fn set_sci_mode(sci_mode: Option<bool>) {
123 PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
124}
125
/// A `fmt::Write` sink that discards output and only counts bytes written,
/// used to measure how wide an element will render.
struct FmtSize {
    // Total bytes "written" so far.
    current_size: usize,
}
129
130impl FmtSize {
131 fn new() -> Self {
132 Self { current_size: 0 }
133 }
134
135 fn final_size(self) -> usize {
136 self.current_size
137 }
138}
139
impl std::fmt::Write for FmtSize {
    // Accumulate length instead of storing text; "formatting into" a
    // `FmtSize` therefore computes the rendered width with no allocation.
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.current_size += s.len();
        Ok(())
    }
}
146
// Shared layout machinery: an implementor knows how to print one element,
// and the provided methods handle alignment, line wrapping, and `...`
// summarization of long dimensions.
trait TensorFormatter {
    type Elem: WithDType;

    /// Writes a single element, right-aligned to `max_w` columns.
    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;

    /// Measures the widest rendering among all elements of `to_display` by
    /// formatting each into a byte-counting sink (`FmtSize`).
    fn max_width(&self, to_display: &Tensor<Self::Elem>) -> usize {
        let mut max_width = 1;
        if let Ok(t) = to_display.flatten_all() {
            let vs = t.to_vec().expect("Meta Tensor");
            for &v in vs.iter() {
                let mut fmt_size = FmtSize::new();
                // Width 1 so padding never inflates the measurement.
                let _res = self.fmt(v, 1, &mut fmt_size);
                max_width = usize::max(max_width, fmt_size.final_size())
            }
        }
        max_width
    }

    /// Starts a new line and indents by `i` two-space steps.
    fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        writeln!(f)?;
        for _ in 0..i {
            write!(f, "  ")?
        }
        Ok(())
    }

    /// Recursively prints `t` in bracketed, NumPy-like form.
    ///
    /// `indent` is the nesting depth (for continuation lines), `max_w` the
    /// per-element column width, and `summarize` enables `...` elision of
    /// the middle of any dimension longer than `2 * po.edge_items`.
    fn fmt_tensor(
        &self,
        t: &Tensor<Self::Elem>,
        indent: usize,
        max_w: usize,
        summarize: bool,
        po: &PrinterOptions,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let dims = t.dims();
        let edge_items = po.edge_items;
        write!(f, "[")?;
        match dims {
            // Rank 0: single scalar (read errors are silently skipped).
            [] => {
                if let Ok(v) = t.to_scalar() {
                    self.fmt(v, max_w, f)?
                }
            }
            // Long rank-1 dimension, summarized: first and last edges
            // separated by "...".
            [v] if summarize && *v > 2 * edge_items => {
                if let Ok(t) = t
                    .narrow(0, 0, edge_items)
                {
                    for v in t.to_vec().expect("Meta Tensor").into_iter() {
                        self.fmt(v, max_w, f)?;
                        write!(f, ", ")?;
                    }
                }
                write!(f, "...")?;
                if let Ok(t) = t
                    .narrow(0, v - edge_items, edge_items)
                {
                    for v in t.to_vec().expect("Meta Tensor").into_iter() {
                        write!(f, ", ")?;
                        self.fmt(v, max_w, f)?;
                    }
                }
            }
            // Rank-1, printed in full with wrapping at `line_width`
            // (`max_w + 2` accounts for the ", " separator).
            [_] => {
                let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
                let vs = t.to_vec().expect("Meta Tensor");
                for (i, v) in vs.into_iter().enumerate() {
                    if i > 0 {
                        if i % elements_per_line == 0 {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        } else {
                            write!(f, ", ")?;
                        }
                    }
                    self.fmt(v, max_w, f)?
                }
            }
            // Rank >= 2: recurse on sub-tensors along the leading dim.
            _ => {
                if summarize && dims[0] > 2 * edge_items {
                    // Leading rows, "...", then trailing rows.
                    for i in 0..edge_items {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        write!(f, ",")?;
                        Self::write_newline_indent(indent, f)?
                    }
                    write!(f, "...")?;
                    Self::write_newline_indent(indent, f)?;
                    for i in dims[0] - edge_items..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        // No trailing separator after the last row.
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                } else {
                    for i in 0..dims[0] {
                        match t.get(i) {
                            Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
                            Err(e) => write!(f, "{e:?}")?,
                        }
                        if i + 1 != dims[0] {
                            write!(f, ",")?;
                            Self::write_newline_indent(indent, f)?
                        }
                    }
                }
            }
        }
        write!(f, "]")?;
        Ok(())
    }
}
265
/// Element formatter for float dtypes; the rendering mode is chosen once
/// per tensor from its contents (see `FloatFormatter::new`).
struct FloatFormatter<S: FloatDType> {
    // Every finite non-zero value is integral: print "123." style.
    int_mode: bool,
    // Use scientific notation ("1.2346e4").
    sci_mode: bool,
    // Fractional digits for sci / fixed-point output.
    precision: usize,
    _phantom: std::marker::PhantomData<S>,
}
272
273impl<S: FloatDType> FloatFormatter<S> {
274 fn new(t: &Tensor<S>, po: &PrinterOptions) -> Result<Self> {
275 let mut int_mode = true;
276 let mut sci_mode = false;
277
278 let values = t
281 .flatten_all()?
282 .to_vec().expect("Meta Tensor")
283 .into_iter()
284 .filter(|v: &S| v.is_finite() && !v.is_zero())
285 .collect::<Vec<_>>();
286 if !values.is_empty() {
287 let mut nonzero_finite_min = S::MAX_VALUE;
288 let mut nonzero_finite_max = S::MIN_VALUE;
289 for &v in values.iter() {
290 let v = v.abs();
291 if v < nonzero_finite_min {
292 nonzero_finite_min = v
293 }
294 if v > nonzero_finite_max {
295 nonzero_finite_max = v
296 }
297 }
298
299 for &value in values.iter() {
300 if value.ceil() != value {
301 int_mode = false;
302 break;
303 }
304 }
305 if let Some(v1) = S::from(1000.) {
306 if let Some(v2) = S::from(1e8) {
307 if let Some(v3) = S::from(1e-4) {
308 sci_mode = nonzero_finite_max / nonzero_finite_min > v1
309 || nonzero_finite_max > v2
310 || nonzero_finite_min < v3
311 }
312 }
313 }
314 }
315
316 match po.sci_mode {
317 None => {}
318 Some(v) => sci_mode = v,
319 }
320 Ok(Self {
321 int_mode,
322 sci_mode,
323 precision: po.precision,
324 _phantom: std::marker::PhantomData,
325 })
326 }
327}
328
impl<S: FloatDType> TensorFormatter for FloatFormatter<S> {
    type Elem = S;

    // Renders one float, right-aligned to `max_w` columns, according to
    // the mode selected in `FloatFormatter::new`.
    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
        if self.sci_mode {
            // Scientific: "1.2346e4" with `precision` fractional digits.
            write!(
                f,
                "{v:width$.prec$e}",
                v = v,
                width = max_w,
                prec = self.precision
            )
        } else if self.int_mode {
            if v.is_finite() {
                // Integral values print as "<n>." — one column is reserved
                // for the trailing dot, hence `max_w - 1`.
                write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
            } else {
                // inf/NaN have no trailing dot and use the full width.
                write!(f, "{v:max_w$.0}")
            }
        } else {
            // Plain fixed-point with `precision` fractional digits.
            write!(
                f,
                "{v:width$.prec$}",
                v = v,
                width = max_w,
                prec = self.precision
            )
        }
    }
}
358
/// Element formatter for non-float dtypes; stateless, it simply
/// right-aligns each value's `Display` output.
struct IntFormatter<S: WithDType> {
    _phantom: std::marker::PhantomData<S>,
}
362
363impl<S: WithDType> IntFormatter<S> {
364 fn new() -> Self {
365 Self {
366 _phantom: std::marker::PhantomData,
367 }
368 }
369}
370
impl<S: WithDType> TensorFormatter for IntFormatter<S> {
    type Elem = S;

    // Right-align the element's plain `Display` output in `max_w` columns.
    fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
        write!(f, "{v:max_w$}")
    }
}
378
379fn get_summarized_data<T: WithDType>(t: &Tensor<T>, edge_items: usize) -> Result<Tensor<T>> {
380 let dims = t.dims();
381 if dims.is_empty() {
382 Ok(t.clone())
383 } else if dims.len() == 1 {
384 if dims[0] > 2 * edge_items {
385 Tensor::cat(
386 &[
387 t.narrow(0, 0, edge_items)?,
388 t.narrow(0, dims[0] - edge_items, edge_items)?,
389 ],
390 0,
391 )
392 } else {
393 Ok(t.clone())
394 }
395 } else if dims[0] > 2 * edge_items {
396 let mut vs: Vec<_> = (0..edge_items)
397 .map(|i| get_summarized_data(&t.get(i)?, edge_items))
398 .collect::<Result<Vec<_>>>()?;
399 for i in (dims[0] - edge_items)..dims[0] {
400 vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
401 }
402 Tensor::cat(&vs, 0)
403 } else {
404 let vs: Vec<_> = (0..dims[0])
405 .map(|i| get_summarized_data(&t.get(i)?, edge_items))
406 .collect::<Result<Vec<_>>>()?;
407 Tensor::cat(&vs, 0)
408 }
409}
410
// Display driver: handles the meta case, threshold check, and summarization,
// then defers element layout to `do_fmt_display`.
trait Display<T: WithDType> {
    /// Entry point used by `std::fmt::Display for Tensor<T>` below.
    fn fmt_display(tensor: &Tensor<T>, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Meta tensors have no data: print shape/dtype only.
        if tensor.is_meta() {
            return write!(
                f,
                "Tensor[Meta; {:?}, {}]",
                tensor.dims(),
                tensor.dtype(),
            )
        }

        // Hold the lock for the whole rendering so one tensor prints with
        // a consistent snapshot of the options.
        let po: std::sync::MutexGuard<'_, PrinterOptions> = PRINT_OPTS.lock().unwrap();
        let summarize = tensor.element_count() > po.threshold;
        // `to_display` is the (possibly shrunk) data used only for width
        // measurement; the full `tensor` is still what gets printed, with
        // elision handled inside `fmt_tensor`.
        let to_display = if summarize {
            match get_summarized_data(tensor, po.edge_items) {
                Ok(v) => v,
                Err(err) => return write!(f, "{err:?}"),
            }
        } else {
            tensor.clone()
        };

        Self::do_fmt_display(tensor, to_display, f, summarize, po)?;

        write!(
            f,
            "Tensor[{:?}, {}]",
            tensor.dims(),
            tensor.dtype(),
        )
    }

    /// Element-layout hook. The default renders via `IntFormatter`, i.e.
    /// each element's plain `Display` output; numeric dtypes override it
    /// through `DisplayOp` / `DisplayCategory` below.
    fn do_fmt_display(
        tensor: &Tensor<T>,
        to_display: Tensor<T>,
        f: &mut std::fmt::Formatter,
        summarize: bool,
        po: std::sync::MutexGuard<'_, PrinterOptions>
    ) -> std::fmt::Result {
        let tf = IntFormatter::new();
        let max_w = tf.max_width(&to_display);
        tf.fmt_tensor(tensor, 1, max_w, summarize, &po, f)?;
        writeln!(f)
    }
}
456
/// Zero-sized dispatcher carrying the `Display`/`DisplayCategory` impls so
/// the blanket `std::fmt::Display` for `Tensor<T>` can route through it.
struct DisplayOp;
458
// For numeric dtypes, forward layout to the category-specific impl
// (integer vs float) selected statically via `DisplayCategory`.
impl<T: NumDType> Display<T> for DisplayOp
where
    Self: DisplayCategory<T>,
{
    fn do_fmt_display(
        tensor: &Tensor<T>,
        to_display: Tensor<T>,
        f: &mut std::fmt::Formatter,
        summarize: bool,
        po: std::sync::MutexGuard<'_, PrinterOptions>
    ) -> std::fmt::Result {
        DisplayOp::do_fmt_display_category(tensor, to_display, f, summarize, po)
    }
}
473
// Category-dispatched layout: the defaulted `C` parameter selects an impl
// from `T`'s `NumCategory` (integer vs float), emulating specialization.
trait DisplayCategory<T: NumDType, C: NumCategory = <T as NumDType>::Category> {
    fn do_fmt_display_category(
        tensor: &Tensor<T>,
        to_display: Tensor<T>,
        f: &mut std::fmt::Formatter,
        summarize: bool,
        po: std::sync::MutexGuard<'_, PrinterOptions>
    ) -> std::fmt::Result;
}
483
484impl<T: IntDType> DisplayCategory<T, IntCategory> for DisplayOp {
485 fn do_fmt_display_category(
486 tensor: &Tensor<T>,
487 to_display: Tensor<T>,
488 f: &mut std::fmt::Formatter,
489 summarize: bool,
490 po: std::sync::MutexGuard<'_, PrinterOptions>
491 ) -> std::fmt::Result {
492 let tf: IntFormatter<T> = IntFormatter::new();
493 let max_w = tf.max_width(&to_display);
494 tf.fmt_tensor(tensor, 1, max_w, summarize, &po, f)?;
495 writeln!(f)
496 }
497}
498
499impl<T: FloatDType> DisplayCategory<T, FloatCategory> for DisplayOp {
500 fn do_fmt_display_category(
501 tensor: &Tensor<T>,
502 to_display: Tensor<T>,
503 f: &mut std::fmt::Formatter,
504 summarize: bool,
505 po: std::sync::MutexGuard<'_, PrinterOptions>
506 ) -> std::fmt::Result {
507 if let Ok(tf) = FloatFormatter::new(&to_display, &po) {
508 let max_w = tf.max_width(&to_display);
509 tf.fmt_tensor(tensor, 1, max_w, summarize, &po, f)?;
510 writeln!(f)?;
511 }
512 Ok(())
513 }
514}
515
// `bool` has no `NumDType` category, so it opts into the trait's default
// `do_fmt_display`, which renders via `IntFormatter` (each element's
// plain `Display`, right-aligned).
impl Display<bool> for DisplayOp {

}
519
// Public `Display` for tensors, available whenever `DisplayOp` knows how
// to lay out `T` (all numeric dtypes plus `bool`).
impl<T: WithDType> std::fmt::Display for Tensor<T>
where
    DisplayOp: Display<T>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        DisplayOp::fmt_display(self, f)
    }
}