1use hpt_common::utils::pointer::Pointer;
2use hpt_traits::tensor::{CommonBounds, TensorInfo};
3use hpt_types::into_scalar::Cast;
4use std::fmt::Formatter;
5
6use crate::formats::format_val;
7
8fn main_loop_push_str<U, T>(
11 tensor: &U,
12 lr_elements_size: usize,
13 inner_loop: usize,
14 last_stride: i64,
15 string: &mut String,
16 precision: usize,
17 col_width: &mut Vec<usize>,
18 prg: &mut Vec<i64>,
19 shape: &Vec<i64>,
20 mut ptr: Pointer<T>,
21) where
22 U: TensorInfo<T>,
23 T: CommonBounds + Cast<f64>,
24{
25 let print = |string: &mut String, ptr: Pointer<T>, offset: &mut i64, col: usize| {
26 let val = format_val(ptr[*offset], precision);
27 string.push_str(&format!("{:>width$}", val, width = col_width[col]));
28 if col < inner_loop - 1 {
29 string.push(' ');
30 }
31 *offset += last_stride;
32 };
33 let mut outer_loop = 1;
34 for i in tensor.shape().iter().take(tensor.ndim() - 1) {
35 if i > &(2 * (lr_elements_size as i64)) {
36 outer_loop *= 2 * (lr_elements_size as i64);
37 } else {
38 outer_loop *= i;
39 }
40 }
41 for _ in 0..outer_loop {
42 let mut offset = 0;
43 if inner_loop >= 2 * lr_elements_size {
44 for i in 0..2 {
45 for j in 0..lr_elements_size {
46 print(string, ptr.clone(), &mut offset, j);
47 }
48 if i == 0 {
49 string.push_str("... ");
50 offset += last_stride * ((inner_loop as i64) - 2 * (lr_elements_size as i64));
51 }
52 }
53 } else {
54 for j in 0..inner_loop {
55 print(string, ptr.clone(), &mut offset, j);
56 }
57 }
58 string.push_str("]");
59 for k in (0..tensor.ndim() - 1).rev() {
60 if prg[k] < shape[k] {
61 prg[k] += 1;
62 ptr.offset(tensor.strides()[k]);
63 if tensor.shape()[k] > 2 * (lr_elements_size as i64)
64 && prg[k] == (lr_elements_size as i64)
65 {
66 string.push_str("\n");
67 string.push_str(&" ".repeat(k + 1 + "Tensor(".len()));
68 string.push_str("...");
69 string.push_str("\n\n");
70 string.push_str(&" ".repeat(k + 1 + "Tensor(".len()));
71 string.push_str(&"[".repeat(tensor.ndim() - (k + 1)));
72 ptr.offset(
73 tensor.strides()[k] * (tensor.shape()[k] - 2 * (lr_elements_size as i64)),
74 );
75 prg[k] += tensor.shape()[k] - 2 * (lr_elements_size as i64);
76 assert!(prg[k] < tensor.shape()[k]);
77 break;
78 }
79
80 string.push_str("\n");
81 string.push_str(&" ".repeat(k + 1 + "Tensor(".len()));
82 string.push_str(&"[".repeat(tensor.ndim() - (k + 1)));
83 assert!(prg[k] < tensor.shape()[k]);
84 break;
85 } else {
86 prg[k] = 0;
87 string.push_str("]");
88 if k >= 1 && prg[k - 1] < shape[k - 1] {
89 string.push_str(&"\n".repeat(tensor.ndim() - (k + 1)));
90 }
91 ptr.offset(-tensor.strides()[k] * shape[k]);
92 }
93 }
94 }
95}
96
97fn main_loop_get_width<U, T>(
100 tensor: &U,
101 lr_elements_size: usize,
102 inner_loop: usize,
103 last_stride: i64,
104 precision: usize,
105 col_width: &mut Vec<usize>,
106 prg: &mut Vec<i64>,
107 shape: &Vec<i64>,
108 mut ptr: Pointer<T>,
109) where
110 U: TensorInfo<T>,
111 T: CommonBounds + Cast<f64>,
112{
113 let mut outer_loop = 1;
114 for i in tensor.shape().iter().take(tensor.ndim() - 1) {
115 if i > &(2 * (lr_elements_size as i64)) {
116 outer_loop *= 2 * (lr_elements_size as i64);
117 } else {
118 outer_loop *= i;
119 }
120 }
121 for _ in 0..outer_loop {
122 let mut offset: i64 = 0;
123 if inner_loop >= 2 * lr_elements_size {
124 for i in 0..2 {
125 for j in 0..lr_elements_size {
126 let val = format_val(ptr[offset], precision);
127 col_width[j] = std::cmp::max(col_width[j], val.len());
128 offset += last_stride;
129 }
130 if i == 0 {
131 offset += last_stride * ((inner_loop as i64) - 2 * (lr_elements_size as i64));
132 }
133 }
134 } else {
135 for j in 0..inner_loop {
136 let val = format_val(ptr[offset], precision);
137 col_width[j] = std::cmp::max(col_width[j], val.len());
138 offset += last_stride;
139 }
140 }
141 for k in (0..tensor.ndim() - 1).rev() {
142 if prg[k] < shape[k] {
143 prg[k] += 1;
144 ptr.offset(tensor.strides()[k]);
145 if tensor.shape()[k] > 2 * (lr_elements_size as i64)
146 && prg[k] == (lr_elements_size as i64)
147 {
148 ptr.offset(
149 tensor.strides()[k] * (tensor.shape()[k] - 2 * (lr_elements_size as i64)),
150 );
151 prg[k] += tensor.shape()[k] - 2 * (lr_elements_size as i64);
152 assert!(prg[k] < tensor.shape()[k]);
153 break;
154 }
155 assert!(prg[k] < tensor.shape()[k]);
156 break;
157 } else {
158 prg[k] = 0;
159 ptr.offset(-tensor.strides()[k] * shape[k]);
160 }
161 }
162 }
163}
164
165pub fn display<U, T>(
174 tensor: U,
175 f: &mut Formatter<'_>,
176 lr_elements_size: usize,
177 precision: usize,
178 show_backward: bool,
179) -> std::fmt::Result
180where
181 U: TensorInfo<T>,
182 T: CommonBounds + Cast<f64>,
183{
184 let mut string: String = String::new();
185 if tensor.size() == 0 {
186 write!(f, "{}", "Tensor([])\n".to_string())
187 } else if tensor.ndim() == 0 {
188 let val = format_val(unsafe { tensor.ptr().ptr.read() }, precision);
189 write!(f, "{}", format!("Tensor({})\n", val))
190 } else {
191 let ptr: Pointer<T> = tensor.ptr();
192 if !ptr.ptr.is_null() {
193 let inner_loop: usize = tensor.shape()[tensor.ndim() - 1] as usize;
194 let mut prg: Vec<i64> = vec![0; tensor.ndim()];
195 let mut shape: Vec<i64> = tensor.shape().to_vec();
196 shape.iter_mut().for_each(|x: &mut i64| {
197 *x -= 1;
198 });
199 let mut strides: Vec<i64> = tensor.strides().to_vec();
200 shape.iter().enumerate().for_each(|(i, x)| {
201 if *x == 0 {
202 strides[i] = 0;
203 }
204 });
205 let last_stride = strides[tensor.ndim() - 1];
206 string.push_str("Tensor(");
207 for _ in 0..tensor.ndim() {
208 string.push_str("[");
209 }
210 let mut col_width: Vec<usize> = vec![0; inner_loop];
211 main_loop_get_width(
212 &tensor,
213 lr_elements_size,
214 inner_loop,
215 last_stride,
216 precision,
217 &mut col_width,
218 &mut prg,
219 &shape,
220 ptr.clone(),
221 );
222 main_loop_push_str(
223 &tensor,
224 lr_elements_size,
225 inner_loop,
226 last_stride,
227 &mut string,
228 precision,
229 &mut col_width,
230 &mut prg,
231 &shape,
232 ptr.clone(),
233 );
234 }
235 let shape_str = tensor
236 .shape()
237 .iter()
238 .map(|x| x.to_string())
239 .collect::<Vec<String>>()
240 .join(", ");
241 let strides_str = tensor
242 .strides()
243 .iter()
244 .map(|x| x.to_string())
245 .collect::<Vec<String>>()
246 .join(", ");
247 if !show_backward {
248 string.push_str(&format!(
249 ", shape=({}), strides=({}), dtype={})\n",
250 shape_str,
251 strides_str,
252 T::STR
253 ));
254 } else {
255 string.push_str(&format!(
256 ", shape=({}), strides=({}), dtype={}, grad_fn={})\n",
257 shape_str,
258 strides_str,
259 T::STR,
260 "None"
261 ));
262 }
263 write!(f, "{}", format!("{}", string))
264 }
265}