hpt_display/
display.rs

use hpt_common::utils::pointer::Pointer;
use hpt_traits::tensor::{CommonBounds, TensorInfo};
use hpt_types::into_scalar::Cast;
use std::fmt::Formatter;

use crate::formats::format_val;

/// # Internal Function
/// Writes the string representation of the tensor's elements into `string`.
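///
/// Illustrative fragment (an assumption about `format_val`'s exact output, not a
/// guaranteed format): with `lr_elements_size == 3` on a row of ten integers, the
/// appended row looks roughly like `1 2 3 ... 8 9 10]`, followed by the closing
/// brackets and indentation for the next row.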
fn main_loop_push_str<U, T>(
    tensor: &U,
    lr_elements_size: usize,
    inner_loop: usize,
    last_stride: i64,
    string: &mut String,
    precision: usize,
    col_width: &mut Vec<usize>,
    prg: &mut Vec<i64>,
    shape: &Vec<i64>,
    mut ptr: Pointer<T>,
) where
    U: TensorInfo<T>,
    T: CommonBounds + Cast<f64>,
{
    // Formats one element, right-aligned to its column width, and advances the offset.
    let print = |string: &mut String, ptr: Pointer<T>, offset: &mut i64, col: usize| {
        let val = format_val(ptr[*offset], precision);
        string.push_str(&format!("{:>width$}", val, width = col_width[col]));
        if col < inner_loop - 1 {
            string.push(' ');
        }
        *offset += last_stride;
    };
    // Number of rows actually printed: dimensions longer than 2 * lr_elements_size
    // are truncated to their first and last lr_elements_size entries.
    let mut outer_loop = 1;
    for i in tensor.shape().iter().take(tensor.ndim() - 1) {
        if *i > 2 * (lr_elements_size as i64) {
            outer_loop *= 2 * (lr_elements_size as i64);
        } else {
            outer_loop *= *i;
        }
    }
    for _ in 0..outer_loop {
        let mut offset = 0;
        if inner_loop >= 2 * lr_elements_size {
            // Truncated row: first lr_elements_size elements, "...", last lr_elements_size elements.
            for i in 0..2 {
                for j in 0..lr_elements_size {
                    print(string, ptr.clone(), &mut offset, j);
                }
                if i == 0 {
                    string.push_str("... ");
                    offset += last_stride * ((inner_loop as i64) - 2 * (lr_elements_size as i64));
                }
            }
        } else {
            // Short row: print every element.
            for j in 0..inner_loop {
                print(string, ptr.clone(), &mut offset, j);
            }
        }
        string.push(']');
        // Advance the multi-dimensional index (`prg`) and emit the brackets,
        // newlines, indentation, and "..." separators between rows.
        for k in (0..tensor.ndim() - 1).rev() {
            if prg[k] < shape[k] {
                prg[k] += 1;
                ptr.offset(tensor.strides()[k]);
                if tensor.shape()[k] > 2 * (lr_elements_size as i64)
                    && prg[k] == (lr_elements_size as i64)
                {
                    // Skip the middle of a truncated dimension and mark it with "...".
                    string.push('\n');
                    string.push_str(&" ".repeat(k + 1 + "Tensor(".len()));
                    string.push_str("...");
                    string.push_str("\n\n");
                    string.push_str(&" ".repeat(k + 1 + "Tensor(".len()));
                    string.push_str(&"[".repeat(tensor.ndim() - (k + 1)));
                    ptr.offset(
                        tensor.strides()[k] * (tensor.shape()[k] - 2 * (lr_elements_size as i64)),
                    );
                    prg[k] += tensor.shape()[k] - 2 * (lr_elements_size as i64);
                    assert!(prg[k] < tensor.shape()[k]);
                    break;
                }

                string.push('\n');
                string.push_str(&" ".repeat(k + 1 + "Tensor(".len()));
                string.push_str(&"[".repeat(tensor.ndim() - (k + 1)));
                assert!(prg[k] < tensor.shape()[k]);
                break;
            } else {
                // This dimension is exhausted: close its bracket, reset, and carry to the next one.
                prg[k] = 0;
                string.push(']');
                if k >= 1 && prg[k - 1] < shape[k - 1] {
                    string.push_str(&"\n".repeat(tensor.ndim() - (k + 1)));
                }
                ptr.offset(-tensor.strides()[k] * shape[k]);
            }
        }
    }
}

/// # Internal Function
/// First pass over the tensor: records the maximum formatted width of each
/// column in `col_width`, so the printing pass can right-align every value.
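///
/// Illustrative example (the concrete values are assumptions, not a guaranteed
/// format): if a column holds `3.0` and `-12.5`, its `col_width` entry becomes
/// the length of whichever of the two `format_val` renderings is longer.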
fn main_loop_get_width<U, T>(
    tensor: &U,
    lr_elements_size: usize,
    inner_loop: usize,
    last_stride: i64,
    precision: usize,
    col_width: &mut Vec<usize>,
    prg: &mut Vec<i64>,
    shape: &Vec<i64>,
    mut ptr: Pointer<T>,
) where
    U: TensorInfo<T>,
    T: CommonBounds + Cast<f64>,
{
    // Same truncated row count as in `main_loop_push_str`.
    let mut outer_loop = 1;
    for i in tensor.shape().iter().take(tensor.ndim() - 1) {
        if *i > 2 * (lr_elements_size as i64) {
            outer_loop *= 2 * (lr_elements_size as i64);
        } else {
            outer_loop *= *i;
        }
    }
    for _ in 0..outer_loop {
        let mut offset: i64 = 0;
        if inner_loop >= 2 * lr_elements_size {
            // Truncated row: only the first and last `lr_elements_size` columns are measured.
            for i in 0..2 {
                for j in 0..lr_elements_size {
                    let val = format_val(ptr[offset], precision);
                    col_width[j] = std::cmp::max(col_width[j], val.len());
                    offset += last_stride;
                }
                if i == 0 {
                    offset += last_stride * ((inner_loop as i64) - 2 * (lr_elements_size as i64));
                }
            }
        } else {
            // Short row: measure every column.
            for j in 0..inner_loop {
                let val = format_val(ptr[offset], precision);
                col_width[j] = std::cmp::max(col_width[j], val.len());
                offset += last_stride;
            }
        }
        // Advance the multi-dimensional index exactly as the printing pass does,
        // skipping the truncated middle of long dimensions.
        for k in (0..tensor.ndim() - 1).rev() {
            if prg[k] < shape[k] {
                prg[k] += 1;
                ptr.offset(tensor.strides()[k]);
                if tensor.shape()[k] > 2 * (lr_elements_size as i64)
                    && prg[k] == (lr_elements_size as i64)
                {
                    ptr.offset(
                        tensor.strides()[k] * (tensor.shape()[k] - 2 * (lr_elements_size as i64)),
                    );
                    prg[k] += tensor.shape()[k] - 2 * (lr_elements_size as i64);
                    assert!(prg[k] < tensor.shape()[k]);
                    break;
                }
                assert!(prg[k] < tensor.shape()[k]);
                break;
            } else {
                prg[k] = 0;
                ptr.offset(-tensor.strides()[k] * shape[k]);
            }
        }
    }
}

/// Display a tensor.
///
/// # Arguments
/// - `tensor`: The tensor to be displayed.
/// - `f`: The formatter to write into.
/// - `lr_elements_size`: Number of elements to keep on each side when a long row or dimension is truncated with `...`.
/// - `precision`: Number of decimal places to display for floating-point numbers.
/// - `show_backward`: Whether to display the tensor's gradient function; currently only used by `DiffTensor`.
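///
/// # Example
/// A minimal sketch of how a `Display` impl might call this function, assuming a
/// hypothetical `MyTensor` type whose reference implements `TensorInfo<f32>`:
/// ```ignore
/// use std::fmt;
///
/// impl fmt::Display for MyTensor {
///     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
///         // keep 3 elements on each side of a truncation, print 4 decimal places,
///         // and do not show the gradient function
///         display(self, f, 3, 4, false)
///     }
/// }
/// ```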
pub fn display<U, T>(
    tensor: U,
    f: &mut Formatter<'_>,
    lr_elements_size: usize,
    precision: usize,
    show_backward: bool,
) -> std::fmt::Result
where
    U: TensorInfo<T>,
    T: CommonBounds + Cast<f64>,
{
    let mut string: String = String::new();
    if tensor.size() == 0 {
        // Empty tensor.
        writeln!(f, "Tensor([])")
    } else if tensor.ndim() == 0 {
        // Scalar (0-dimensional) tensor.
        let val = format_val(unsafe { tensor.ptr().ptr.read() }, precision);
        writeln!(f, "Tensor({})", val)
    } else {
        let ptr: Pointer<T> = tensor.ptr();
        if !ptr.ptr.is_null() {
            let inner_loop: usize = tensor.shape()[tensor.ndim() - 1] as usize;
            let mut prg: Vec<i64> = vec![0; tensor.ndim()];
            // `shape` holds the maximum index per dimension (shape - 1).
            let mut shape: Vec<i64> = tensor.shape().to_vec();
            shape.iter_mut().for_each(|x: &mut i64| {
                *x -= 1;
            });
            // Dimensions of size 1 never advance, so zero out their strides.
            let mut strides: Vec<i64> = tensor.strides().to_vec();
            shape.iter().enumerate().for_each(|(i, x)| {
                if *x == 0 {
                    strides[i] = 0;
                }
            });
            let last_stride = strides[tensor.ndim() - 1];
            string.push_str("Tensor(");
            for _ in 0..tensor.ndim() {
                string.push('[');
            }
            // First pass measures column widths; second pass builds the aligned output.
            let mut col_width: Vec<usize> = vec![0; inner_loop];
            main_loop_get_width(
                &tensor,
                lr_elements_size,
                inner_loop,
                last_stride,
                precision,
                &mut col_width,
                &mut prg,
                &shape,
                ptr.clone(),
            );
            main_loop_push_str(
                &tensor,
                lr_elements_size,
                inner_loop,
                last_stride,
                &mut string,
                precision,
                &mut col_width,
                &mut prg,
                &shape,
                ptr.clone(),
            );
        }
        let shape_str = tensor
            .shape()
            .iter()
            .map(|x| x.to_string())
            .collect::<Vec<String>>()
            .join(", ");
        let strides_str = tensor
            .strides()
            .iter()
            .map(|x| x.to_string())
            .collect::<Vec<String>>()
            .join(", ");
        if !show_backward {
            string.push_str(&format!(
                ", shape=({}), strides=({}), dtype={})\n",
                shape_str,
                strides_str,
                T::STR
            ));
        } else {
            string.push_str(&format!(
                ", shape=({}), strides=({}), dtype={}, grad_fn={})\n",
                shape_str,
                strides_str,
                T::STR,
                "None"
            ));
        }
        write!(f, "{}", string)
    }
}