cubek_test_utils/test_tensor/
host_data.rs1use cubecl::{
2 CubeElement, TestRuntime,
3 client::ComputeClient,
4 prelude::CubePrimitive,
5 std::tensor::TensorHandle,
6 zspace::{Shape, Strides},
7};
8
9use crate::test_tensor::{cast::copy_casted, strides::physical_extent};
10
11#[derive(Debug, Clone)]
12pub struct HostData {
13 pub data: HostDataVec,
14 pub shape: Shape,
15 pub strides: Strides,
16}
17
/// Element type a tensor is converted to when it is read back to the host.
///
/// The declaration order of the variants is significant: it defines the derived
/// `PartialOrd` ordering (`F32 < I32 < Bool`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
pub enum HostDataType {
    /// Read as 32-bit floats.
    F32,
    /// Read as 32-bit signed integers.
    I32,
    /// Read as booleans.
    Bool,
}
24
/// Owned host buffer whose variant records the element type of its contents.
#[derive(Debug, Clone)]
pub enum HostDataVec {
    /// 32-bit float elements.
    F32(Vec<f32>),
    /// 32-bit signed integer elements.
    I32(Vec<i32>),
    /// Boolean elements.
    Bool(Vec<bool>),
}
31
32impl HostDataVec {
33 pub fn get_f32(&self, i: usize) -> f32 {
34 match self {
35 HostDataVec::F32(items) => items[i],
36 _ => panic!("Can't get as f32"),
37 }
38 }
39
40 pub fn get_bool(&self, i: usize) -> bool {
41 match self {
42 HostDataVec::Bool(items) => items[i],
43 _ => panic!("Can't get as bool"),
44 }
45 }
46
47 pub fn get_i32(&self, i: usize) -> i32 {
48 match self {
49 HostDataVec::I32(items) => items[i],
50 _ => panic!("Can't get as i32"),
51 }
52 }
53
54 pub fn try_get_f32(&self, i: usize) -> Option<f32> {
55 match self {
56 HostDataVec::F32(items) => items.get(i).copied(),
57 _ => None,
58 }
59 }
60
61 pub fn try_get_i32(&self, i: usize) -> Option<i32> {
62 match self {
63 HostDataVec::I32(items) => items.get(i).copied(),
64 _ => None,
65 }
66 }
67
68 pub fn try_get_bool(&self, i: usize) -> Option<bool> {
69 match self {
70 HostDataVec::Bool(items) => items.get(i).copied(),
71 _ => None,
72 }
73 }
74}
75
76impl HostData {
77 pub fn from_tensor_handle(
78 client: &ComputeClient<TestRuntime>,
79 mut tensor_handle: TensorHandle<TestRuntime>,
80 host_data_type: HostDataType,
81 ) -> Self {
82 let shape = tensor_handle.shape().clone();
83 let strides = tensor_handle.strides().clone();
84
85 let physical_len = physical_extent(&shape, &strides);
92 tensor_handle.metadata.shape = Shape::from(vec![physical_len]);
93 tensor_handle.metadata.strides = Strides::new(&[1]);
94
95 let data = match host_data_type {
96 HostDataType::F32 => {
97 let handle = copy_casted(
98 client,
99 tensor_handle,
100 f32::as_type_native_unchecked().storage_type(),
101 );
102 let data = f32::from_bytes(
103 &client.read_one_unchecked_tensor(handle.into_copy_descriptor()),
104 )
105 .to_owned();
106
107 HostDataVec::F32(data)
108 }
109 HostDataType::I32 => {
110 let handle = copy_casted(
111 client,
112 tensor_handle,
113 i32::as_type_native_unchecked().storage_type(),
114 );
115 let data = i32::from_bytes(
116 &client.read_one_unchecked_tensor(handle.into_copy_descriptor()),
117 )
118 .to_owned();
119
120 HostDataVec::I32(data)
121 }
122 HostDataType::Bool => {
123 let handle = copy_casted(
124 client,
125 tensor_handle,
126 u32::as_type_native_unchecked().storage_type(),
127 );
128 let data = u32::from_bytes(
129 &client.read_one_unchecked_tensor(handle.into_copy_descriptor()),
130 )
131 .to_owned();
132
133 HostDataVec::Bool(data.iter().map(|&x| x > 0).collect())
134 }
135 };
136
137 Self {
138 data,
139 shape,
140 strides,
141 }
142 }
143
144 pub fn get_f32(&self, index: &[usize]) -> f32 {
145 self.data.get_f32(self.strided_index(index))
146 }
147
148 pub fn get_bool(&self, index: &[usize]) -> bool {
149 self.data.get_bool(self.strided_index(index))
150 }
151
152 pub fn get_i32(&self, index: &[usize]) -> i32 {
153 self.data.get_i32(self.strided_index(index))
154 }
155
156 pub fn try_get_f32(&self, index: &[usize]) -> Option<f32> {
159 self.data.try_get_f32(self.strided_index(index))
160 }
161
162 pub fn try_get_i32(&self, index: &[usize]) -> Option<i32> {
163 self.data.try_get_i32(self.strided_index(index))
164 }
165
166 pub fn try_get_bool(&self, index: &[usize]) -> Option<bool> {
167 self.data.try_get_bool(self.strided_index(index))
168 }
169
170 pub fn iter_indices(&self) -> impl Iterator<Item = Vec<usize>> + '_ {
175 IndexIter::new(self.shape.as_slice().to_vec())
176 }
177
178 pub fn iter_indexed_f32(&self) -> impl Iterator<Item = (Vec<usize>, f32)> + '_ {
181 self.iter_indices().map(move |idx| {
182 let v = self.get_f32(&idx);
183 (idx, v)
184 })
185 }
186
187 pub fn iter_indexed_i32(&self) -> impl Iterator<Item = (Vec<usize>, i32)> + '_ {
190 self.iter_indices().map(move |idx| {
191 let v = self.get_i32(&idx);
192 (idx, v)
193 })
194 }
195
196 pub fn iter_indexed_bool(&self) -> impl Iterator<Item = (Vec<usize>, bool)> + '_ {
199 self.iter_indices().map(move |idx| {
200 let v = self.get_bool(&idx);
201 (idx, v)
202 })
203 }
204
205 fn strided_index(&self, index: &[usize]) -> usize {
206 let mut i = 0usize;
207 for (d, idx) in index.iter().enumerate() {
208 i += idx * self.strides[d];
209 }
210 i
211 }
212
213 pub fn pretty_print(&self) -> String {
220 self.pretty_print_filtered(None)
221 }
222
223 pub fn pretty_print_slice<I>(&self, filter: I) -> String
229 where
230 I: IntoIterator,
231 I::Item: Into<crate::DimFilter>,
232 {
233 let f: crate::TensorFilter = filter.into_iter().map(Into::into).collect();
234 assert_eq!(
235 f.len(),
236 self.shape.rank(),
237 "pretty_print_slice: filter rank ({}) must match tensor rank ({})",
238 f.len(),
239 self.shape.rank(),
240 );
241 self.pretty_print_filtered(Some(f))
242 }
243
244 fn pretty_print_filtered(&self, filter: Option<crate::TensorFilter>) -> String {
245 let rank = self.shape.rank();
246 match rank {
247 0 => String::new(),
248 1 => {
249 let col_filter = filter.as_ref().and_then(|f| f.first());
251 let cols = axis_indices(col_filter, self.shape[0]);
252 let rows = vec![0usize];
253 pretty_print_table(&rows, &cols, |_row_label, col_label| {
254 self.cell_string(self.strided_index(&[col_label]))
255 })
256 }
257 2 => {
258 let row_filter = filter.as_ref().and_then(|f| f.first());
261 let col_filter = filter.as_ref().and_then(|f| f.get(1));
262 let rows = axis_indices(row_filter, self.shape[0]);
263 let cols = axis_indices(col_filter, self.shape[1]);
264 pretty_print_table(&rows, &cols, |row_label, col_label| {
265 self.cell_string(self.strided_index(&[row_label, col_label]))
266 })
267 }
268 _ => self.print_higher_rank(filter.as_ref()),
269 }
270 }
271
272 fn cell_string(&self, idx: usize) -> String {
273 match &self.data {
274 HostDataVec::I32(_) => self.data.get_i32(idx).to_string(),
275 HostDataVec::F32(_) => format!("{:.3}", self.data.get_f32(idx)),
276 HostDataVec::Bool(_) => self.data.get_bool(idx).to_string(),
277 }
278 }
279
280 fn print_higher_rank(&self, filter: Option<&crate::TensorFilter>) -> String {
281 let rank = self.shape.rank();
282 let leading_dims = rank - 2;
283 let row_dim = self.shape[rank - 2];
284 let col_dim = self.shape[rank - 1];
285
286 let row_filter = filter.and_then(|f| f.get(rank - 2));
288 let col_filter = filter.and_then(|f| f.get(rank - 1));
289 let row_indices = axis_indices(row_filter, row_dim);
290 let col_indices = axis_indices(col_filter, col_dim);
291
292 let mut out = String::new();
293 let mut leading = vec![0usize; leading_dims];
294
295 loop {
297 let print_this = match filter {
298 None => true,
299 Some(f) => leading_indices_match(&leading, f),
300 };
301
302 if print_this {
303 if !out.is_empty() {
304 out.push('\n');
305 }
306 out.push_str(&format!("{}:\n", format_leading_label(&leading, rank)));
307
308 let table = pretty_print_table(&row_indices, &col_indices, |row, col| {
309 let mut full = leading.clone();
310 full.push(row);
311 full.push(col);
312 self.cell_string(self.strided_index(&full))
313 });
314 out.push_str(&table);
315 }
316
317 if !increment_lex(&mut leading, &self.shape.as_slice()[..leading_dims]) {
319 break;
320 }
321 }
322
323 out
324 }
325}
326
327pub fn pretty_print_zip(tensors: &[&HostData]) -> String {
328 assert!(!tensors.is_empty(), "Need at least one tensor");
329
330 let dims = tensors[0].shape.as_slice();
331
332 for t in tensors {
333 assert_eq!(t.shape.as_slice(), dims, "All tensors must have same shape");
334 }
335
336 let rank = tensors[0].shape.rank();
337
338 let cell = |full: &[usize]| -> String {
339 let mut parts = Vec::with_capacity(tensors.len());
340 for t in tensors {
341 let idx = t.strided_index(full);
342 parts.push(t.cell_string(idx));
343 }
344 parts.join("/")
345 };
346
347 match rank {
348 0 => String::new(),
349 1 => {
350 let cols: Vec<usize> = (0..dims[0]).collect();
351 pretty_print_table(&[0], &cols, |_, col| cell(&[col]))
352 }
353 2 => {
354 let rows: Vec<usize> = (0..dims[0]).collect();
355 let cols: Vec<usize> = (0..dims[1]).collect();
356 pretty_print_table(&rows, &cols, |row, col| cell(&[row, col]))
357 }
358 _ => {
359 let leading_dims = rank - 2;
360 let rows: Vec<usize> = (0..dims[rank - 2]).collect();
361 let cols: Vec<usize> = (0..dims[rank - 1]).collect();
362 let mut out = String::new();
363 let mut leading = vec![0usize; leading_dims];
364 loop {
365 if !out.is_empty() {
366 out.push('\n');
367 }
368 out.push_str(&format!("{}:\n", format_leading_label(&leading, rank)));
369 let table = pretty_print_table(&rows, &cols, |row, col| {
370 let mut full = leading.clone();
371 full.push(row);
372 full.push(col);
373 cell(&full)
374 });
375 out.push_str(&table);
376
377 if !increment_lex(&mut leading, &dims[..leading_dims]) {
378 break;
379 }
380 }
381 out
382 }
383 }
384}
385
386fn leading_indices_match(leading: &[usize], filter: &crate::TensorFilter) -> bool {
390 use crate::DimFilter::*;
391 for (dim, &idx) in leading.iter().enumerate() {
392 let f = filter.get(dim).unwrap_or(&Any);
393 match f {
394 Any => {}
395 Exact(v) => {
396 if idx != *v {
397 return false;
398 }
399 }
400 Range { start, end } => {
401 if idx < *start || idx > *end {
402 return false;
403 }
404 }
405 }
406 }
407 true
408}
409
/// Advances `idx` to the next multi-index in row-major order within `bounds`.
///
/// Returns `false` (leaving `idx` all zeros) once the last index has been passed,
/// and also for an empty `idx`.
fn increment_lex(idx: &mut [usize], bounds: &[usize]) -> bool {
    // Bump the fastest-varying (last) position first and carry leftwards on overflow.
    for pos in (0..idx.len()).rev() {
        let bumped = idx[pos] + 1;
        if bumped < bounds[pos] {
            idx[pos] = bumped;
            return true;
        }
        idx[pos] = 0;
    }
    false
}
425
426struct IndexIter {
430 shape: Vec<usize>,
431 next: Option<Vec<usize>>,
432}
433
434impl IndexIter {
435 fn new(shape: Vec<usize>) -> Self {
436 let next = if shape.contains(&0) {
438 None
439 } else {
440 Some(vec![0; shape.len()])
441 };
442 Self { shape, next }
443 }
444}
445
446impl Iterator for IndexIter {
447 type Item = Vec<usize>;
448
449 fn next(&mut self) -> Option<Self::Item> {
450 let current = self.next.clone()?;
451
452 let mut tentative = current.clone();
456 if !increment_lex(&mut tentative, &self.shape) {
457 self.next = None;
458 } else {
459 self.next = Some(tentative);
460 }
461
462 Some(current)
463 }
464}
465
/// Builds a slice label such as `[1, 0, *, *]`.
///
/// The label lists the `leading` indices and pads them with `*` up to `rank` entries.
fn format_leading_label(leading: &[usize], rank: usize) -> String {
    let wildcards = rank - leading.len();
    let parts: Vec<String> = leading
        .iter()
        .map(ToString::to_string)
        .chain((0..wildcards).map(|_| String::from("*")))
        .collect();
    format!("[{}]", parts.join(", "))
}
475
476fn axis_indices(f: Option<&crate::DimFilter>, dim_size: usize) -> Vec<usize> {
480 use crate::DimFilter::*;
481 match f {
482 None | Some(Any) => (0..dim_size).collect(),
483 Some(Exact(v)) => {
484 if *v < dim_size {
485 vec![*v]
486 } else {
487 Vec::new()
488 }
489 }
490 Some(Range { start, end }) => {
491 if *start >= dim_size {
492 Vec::new()
493 } else {
494 (*start..=(*end).min(dim_size.saturating_sub(1))).collect()
495 }
496 }
497 }
498}
499
/// Renders a right-aligned text table with the given row and column labels.
///
/// `cell(row, col)` is called exactly once per `(row, col)` pair, in row-major
/// order. All value columns share one width: the widest cell or column label, with
/// a minimum of 2. The row-label column is at least 3 wide. The output is a header
/// line, a `----+----` separator, then one line per row, each ending in `\n`.
fn pretty_print_table<F>(rows: &[usize], cols: &[usize], mut cell: F) -> String
where
    F: FnMut(usize, usize) -> String,
{
    // Render every cell once. The strings are needed both to size the columns and
    // to emit the body. Previously each closure call (and its allocation) happened
    // twice.
    let grid: Vec<Vec<String>> = rows
        .iter()
        .map(|&r| cols.iter().map(|&c| cell(r, c)).collect())
        .collect();

    // Widths are measured in chars because `{:>width$}` pads by char count, not bytes.
    let cell_width = grid
        .iter()
        .flatten()
        .map(|s| s.chars().count())
        .max()
        .unwrap_or(0);
    let label_width = cols.iter().map(|c| c.to_string().len()).max().unwrap_or(0);
    let max_width = cell_width.max(label_width).max(2);

    let row_label_width = rows
        .iter()
        .map(|r| r.to_string().len())
        .max()
        .unwrap_or(0)
        .max(3);

    let mut s = String::new();

    // Header: an empty row-label slot followed by the column labels.
    s.push_str(&format!("{:>width$} |", "", width = row_label_width));
    for &col in cols {
        s.push_str(&format!(" {:>width$}", col, width = max_width));
    }
    s.push('\n');

    // Separator line under the header.
    s.push_str(&"-".repeat(row_label_width + 1));
    s.push('+');
    s.push_str(&"-".repeat((max_width + 1) * cols.len()));
    s.push('\n');

    // Body: row label, then the pre-rendered cells of that row.
    for (&row, values) in rows.iter().zip(&grid) {
        s.push_str(&format!("{:>width$} |", row, width = row_label_width));
        for value in values {
            s.push_str(&format!(" {:>width$}", value, width = max_width));
        }
        s.push('\n');
    }

    s
}