Skip to main content

cubek_test_utils/test_tensor/
host_data.rs

1use cubecl::{
2    CubeElement, TestRuntime,
3    client::ComputeClient,
4    prelude::CubePrimitive,
5    std::tensor::TensorHandle,
6    zspace::{Shape, Strides},
7};
8
9use crate::test_tensor::{cast::copy_casted, strides::physical_extent};
10
11#[derive(Debug, Clone)]
12pub struct HostData {
13    pub data: HostDataVec,
14    pub shape: Shape,
15    pub strides: Strides,
16}
17
/// Element type a device tensor is read back as.
///
/// Comparison follows declaration order (`F32 < I32 < Bool`). `Ord` and `Hash`
/// are derived alongside `Eq`/`PartialOrd` so the type can key hash or ordered
/// collections and be sorted directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum HostDataType {
    F32,
    I32,
    Bool,
}
24
/// Host-side element storage; which variant is present depends on the
/// element type the tensor was read back as.
#[derive(Debug, Clone)]
pub enum HostDataVec {
    /// Elements read back as `f32`.
    F32(Vec<f32>),
    /// Elements read back as `i32`.
    I32(Vec<i32>),
    /// Elements read back as booleans.
    Bool(Vec<bool>),
}
31
32impl HostDataVec {
33    pub fn get_f32(&self, i: usize) -> f32 {
34        match self {
35            HostDataVec::F32(items) => items[i],
36            _ => panic!("Can't get as f32"),
37        }
38    }
39
40    pub fn get_bool(&self, i: usize) -> bool {
41        match self {
42            HostDataVec::Bool(items) => items[i],
43            _ => panic!("Can't get as bool"),
44        }
45    }
46
47    pub fn get_i32(&self, i: usize) -> i32 {
48        match self {
49            HostDataVec::I32(items) => items[i],
50            _ => panic!("Can't get as i32"),
51        }
52    }
53
54    pub fn try_get_f32(&self, i: usize) -> Option<f32> {
55        match self {
56            HostDataVec::F32(items) => items.get(i).copied(),
57            _ => None,
58        }
59    }
60
61    pub fn try_get_i32(&self, i: usize) -> Option<i32> {
62        match self {
63            HostDataVec::I32(items) => items.get(i).copied(),
64            _ => None,
65        }
66    }
67
68    pub fn try_get_bool(&self, i: usize) -> Option<bool> {
69        match self {
70            HostDataVec::Bool(items) => items.get(i).copied(),
71            _ => None,
72        }
73    }
74}
75
76impl HostData {
77    pub fn from_tensor_handle(
78        client: &ComputeClient<TestRuntime>,
79        mut tensor_handle: TensorHandle<TestRuntime>,
80        host_data_type: HostDataType,
81    ) -> Self {
82        let shape = tensor_handle.shape().clone();
83        let strides = tensor_handle.strides().clone();
84
85        // Reshape to a flat 1D view of the full physical buffer so the read
86        // covers every offset the jumpy strides might reach. Without this, a
87        // shape like [256,256] with strides [512,1] would only read the
88        // shape.product() (65536) elements that `copy_casted`'s contiguous
89        // rewrite walks, and HostData.get_f32 would then index out-of-bounds
90        // when the logical walk crosses the padding.
91        let physical_len = physical_extent(&shape, &strides);
92        tensor_handle.metadata.shape = Shape::from(vec![physical_len]);
93        tensor_handle.metadata.strides = Strides::new(&[1]);
94
95        let data = match host_data_type {
96            HostDataType::F32 => {
97                let handle = copy_casted(
98                    client,
99                    tensor_handle,
100                    f32::as_type_native_unchecked().storage_type(),
101                );
102                let data = f32::from_bytes(
103                    &client.read_one_unchecked_tensor(handle.into_copy_descriptor()),
104                )
105                .to_owned();
106
107                HostDataVec::F32(data)
108            }
109            HostDataType::I32 => {
110                let handle = copy_casted(
111                    client,
112                    tensor_handle,
113                    i32::as_type_native_unchecked().storage_type(),
114                );
115                let data = i32::from_bytes(
116                    &client.read_one_unchecked_tensor(handle.into_copy_descriptor()),
117                )
118                .to_owned();
119
120                HostDataVec::I32(data)
121            }
122            HostDataType::Bool => {
123                let handle = copy_casted(
124                    client,
125                    tensor_handle,
126                    u32::as_type_native_unchecked().storage_type(),
127                );
128                let data = u32::from_bytes(
129                    &client.read_one_unchecked_tensor(handle.into_copy_descriptor()),
130                )
131                .to_owned();
132
133                HostDataVec::Bool(data.iter().map(|&x| x > 0).collect())
134            }
135        };
136
137        Self {
138            data,
139            shape,
140            strides,
141        }
142    }
143
144    pub fn get_f32(&self, index: &[usize]) -> f32 {
145        self.data.get_f32(self.strided_index(index))
146    }
147
148    pub fn get_bool(&self, index: &[usize]) -> bool {
149        self.data.get_bool(self.strided_index(index))
150    }
151
152    pub fn get_i32(&self, index: &[usize]) -> i32 {
153        self.data.get_i32(self.strided_index(index))
154    }
155
156    /// Like [`get_f32`], but returns `None` if the underlying data isn't `F32`
157    /// (or the index is out of bounds), instead of panicking.
158    pub fn try_get_f32(&self, index: &[usize]) -> Option<f32> {
159        self.data.try_get_f32(self.strided_index(index))
160    }
161
162    pub fn try_get_i32(&self, index: &[usize]) -> Option<i32> {
163        self.data.try_get_i32(self.strided_index(index))
164    }
165
166    pub fn try_get_bool(&self, index: &[usize]) -> Option<bool> {
167        self.data.try_get_bool(self.strided_index(index))
168    }
169
170    /// Iterate every logical index in row-major order, yielding the index vector.
171    ///
172    /// Useful when callers want to walk a non-contiguous tensor without
173    /// re-implementing the rank recursion themselves.
174    pub fn iter_indices(&self) -> impl Iterator<Item = Vec<usize>> + '_ {
175        IndexIter::new(self.shape.as_slice().to_vec())
176    }
177
178    /// Iterate `(index, f32 value)` pairs in row-major order.
179    /// Panics if the underlying data isn't `F32`.
180    pub fn iter_indexed_f32(&self) -> impl Iterator<Item = (Vec<usize>, f32)> + '_ {
181        self.iter_indices().map(move |idx| {
182            let v = self.get_f32(&idx);
183            (idx, v)
184        })
185    }
186
187    /// Iterate `(index, i32 value)` pairs in row-major order.
188    /// Panics if the underlying data isn't `I32`.
189    pub fn iter_indexed_i32(&self) -> impl Iterator<Item = (Vec<usize>, i32)> + '_ {
190        self.iter_indices().map(move |idx| {
191            let v = self.get_i32(&idx);
192            (idx, v)
193        })
194    }
195
196    /// Iterate `(index, bool value)` pairs in row-major order.
197    /// Panics if the underlying data isn't `Bool`.
198    pub fn iter_indexed_bool(&self) -> impl Iterator<Item = (Vec<usize>, bool)> + '_ {
199        self.iter_indices().map(move |idx| {
200            let v = self.get_bool(&idx);
201            (idx, v)
202        })
203    }
204
205    fn strided_index(&self, index: &[usize]) -> usize {
206        let mut i = 0usize;
207        for (d, idx) in index.iter().enumerate() {
208            i += idx * self.strides[d];
209        }
210        i
211    }
212
213    /// Render the tensor as one or more 2-D tables.
214    ///
215    /// - rank 1: a single row.
216    /// - rank 2: a table.
217    /// - rank ≥ 3: one labeled table per combination of leading-dim indices
218    ///   (the last two dims are always the row/col axes).
219    pub fn pretty_print(&self) -> String {
220        self.pretty_print_filtered(None)
221    }
222
223    /// Like [`pretty_print`], but only prints slices whose leading-dim indices
224    /// match the filter. Wildcards (`DimFilter::Any`) iterate every value.
225    ///
226    /// `filter` accepts both `Vec<std::ops::Range<usize>>` and the canonical
227    /// `TensorFilter` (the `CUBE_TEST_MODE` `M-K` syntax).
228    pub fn pretty_print_slice<I>(&self, filter: I) -> String
229    where
230        I: IntoIterator,
231        I::Item: Into<crate::DimFilter>,
232    {
233        let f: crate::TensorFilter = filter.into_iter().map(Into::into).collect();
234        assert_eq!(
235            f.len(),
236            self.shape.rank(),
237            "pretty_print_slice: filter rank ({}) must match tensor rank ({})",
238            f.len(),
239            self.shape.rank(),
240        );
241        self.pretty_print_filtered(Some(f))
242    }
243
244    fn pretty_print_filtered(&self, filter: Option<crate::TensorFilter>) -> String {
245        let rank = self.shape.rank();
246        match rank {
247            0 => String::new(),
248            1 => {
249                // Single-row table; the only filter entry filters the col axis.
250                let col_filter = filter.as_ref().and_then(|f| f.first());
251                let cols = axis_indices(col_filter, self.shape[0]);
252                let rows = vec![0usize];
253                pretty_print_table(&rows, &cols, |_row_label, col_label| {
254                    self.cell_string(self.strided_index(&[col_label]))
255                })
256            }
257            2 => {
258                // Last two filter entries (here filter[0], filter[1]) drive
259                // row and col selection respectively.
260                let row_filter = filter.as_ref().and_then(|f| f.first());
261                let col_filter = filter.as_ref().and_then(|f| f.get(1));
262                let rows = axis_indices(row_filter, self.shape[0]);
263                let cols = axis_indices(col_filter, self.shape[1]);
264                pretty_print_table(&rows, &cols, |row_label, col_label| {
265                    self.cell_string(self.strided_index(&[row_label, col_label]))
266                })
267            }
268            _ => self.print_higher_rank(filter.as_ref()),
269        }
270    }
271
272    fn cell_string(&self, idx: usize) -> String {
273        match &self.data {
274            HostDataVec::I32(_) => self.data.get_i32(idx).to_string(),
275            HostDataVec::F32(_) => format!("{:.3}", self.data.get_f32(idx)),
276            HostDataVec::Bool(_) => self.data.get_bool(idx).to_string(),
277        }
278    }
279
280    fn print_higher_rank(&self, filter: Option<&crate::TensorFilter>) -> String {
281        let rank = self.shape.rank();
282        let leading_dims = rank - 2;
283        let row_dim = self.shape[rank - 2];
284        let col_dim = self.shape[rank - 1];
285
286        // Filter entries for the row and col axes (the last two), if any.
287        let row_filter = filter.and_then(|f| f.get(rank - 2));
288        let col_filter = filter.and_then(|f| f.get(rank - 1));
289        let row_indices = axis_indices(row_filter, row_dim);
290        let col_indices = axis_indices(col_filter, col_dim);
291
292        let mut out = String::new();
293        let mut leading = vec![0usize; leading_dims];
294
295        // Iterate every leading-index combination, lexicographically.
296        loop {
297            let print_this = match filter {
298                None => true,
299                Some(f) => leading_indices_match(&leading, f),
300            };
301
302            if print_this {
303                if !out.is_empty() {
304                    out.push('\n');
305                }
306                out.push_str(&format!("{}:\n", format_leading_label(&leading, rank)));
307
308                let table = pretty_print_table(&row_indices, &col_indices, |row, col| {
309                    let mut full = leading.clone();
310                    full.push(row);
311                    full.push(col);
312                    self.cell_string(self.strided_index(&full))
313                });
314                out.push_str(&table);
315            }
316
317            // Increment the leading-index counter.
318            if !increment_lex(&mut leading, &self.shape.as_slice()[..leading_dims]) {
319                break;
320            }
321        }
322
323        out
324    }
325}
326
327pub fn pretty_print_zip(tensors: &[&HostData]) -> String {
328    assert!(!tensors.is_empty(), "Need at least one tensor");
329
330    let dims = tensors[0].shape.as_slice();
331
332    for t in tensors {
333        assert_eq!(t.shape.as_slice(), dims, "All tensors must have same shape");
334    }
335
336    let rank = tensors[0].shape.rank();
337
338    let cell = |full: &[usize]| -> String {
339        let mut parts = Vec::with_capacity(tensors.len());
340        for t in tensors {
341            let idx = t.strided_index(full);
342            parts.push(t.cell_string(idx));
343        }
344        parts.join("/")
345    };
346
347    match rank {
348        0 => String::new(),
349        1 => {
350            let cols: Vec<usize> = (0..dims[0]).collect();
351            pretty_print_table(&[0], &cols, |_, col| cell(&[col]))
352        }
353        2 => {
354            let rows: Vec<usize> = (0..dims[0]).collect();
355            let cols: Vec<usize> = (0..dims[1]).collect();
356            pretty_print_table(&rows, &cols, |row, col| cell(&[row, col]))
357        }
358        _ => {
359            let leading_dims = rank - 2;
360            let rows: Vec<usize> = (0..dims[rank - 2]).collect();
361            let cols: Vec<usize> = (0..dims[rank - 1]).collect();
362            let mut out = String::new();
363            let mut leading = vec![0usize; leading_dims];
364            loop {
365                if !out.is_empty() {
366                    out.push('\n');
367                }
368                out.push_str(&format!("{}:\n", format_leading_label(&leading, rank)));
369                let table = pretty_print_table(&rows, &cols, |row, col| {
370                    let mut full = leading.clone();
371                    full.push(row);
372                    full.push(col);
373                    cell(&full)
374                });
375                out.push_str(&table);
376
377                if !increment_lex(&mut leading, &dims[..leading_dims]) {
378                    break;
379                }
380            }
381            out
382        }
383    }
384}
385
386/// Match leading indices against the leading slice of a tensor filter. The
387/// trailing two filter entries (covering the row/col axes) are ignored — we
388/// always print the full row × col table for the slices we keep.
389fn leading_indices_match(leading: &[usize], filter: &crate::TensorFilter) -> bool {
390    use crate::DimFilter::*;
391    for (dim, &idx) in leading.iter().enumerate() {
392        let f = filter.get(dim).unwrap_or(&Any);
393        match f {
394            Any => {}
395            Exact(v) => {
396                if idx != *v {
397                    return false;
398                }
399            }
400            Range { start, end } => {
401                if idx < *start || idx > *end {
402                    return false;
403                }
404            }
405        }
406    }
407    true
408}
409
/// Advance `idx` like an odometer over `idx[i] in 0..bounds[i]`, with the
/// last position varying fastest.
///
/// Returns `true` if `idx` now holds the next position. Returns `false` once
/// every digit has rolled over back to zero, meaning iteration is done. An
/// empty `idx` has nothing to advance and returns `false` immediately.
fn increment_lex(idx: &mut [usize], bounds: &[usize]) -> bool {
    let mut pos = idx.len();
    while pos > 0 {
        pos -= 1;
        idx[pos] += 1;
        if idx[pos] < bounds[pos] {
            return true;
        }
        // This digit overflowed: reset it and carry into the one to its left.
        idx[pos] = 0;
    }
    false
}
425
426/// Row-major index iterator. Yields every position in a tensor of the given
427/// shape, lexicographically (last dim varies fastest). For a rank-0 tensor
428/// (empty shape) the iterator yields a single empty index vector and stops.
429struct IndexIter {
430    shape: Vec<usize>,
431    next: Option<Vec<usize>>,
432}
433
434impl IndexIter {
435    fn new(shape: Vec<usize>) -> Self {
436        // Empty dim → no indices.
437        let next = if shape.contains(&0) {
438            None
439        } else {
440            Some(vec![0; shape.len()])
441        };
442        Self { shape, next }
443    }
444}
445
446impl Iterator for IndexIter {
447    type Item = Vec<usize>;
448
449    fn next(&mut self) -> Option<Self::Item> {
450        let current = self.next.clone()?;
451
452        // Advance the counter for the next call. `increment_lex` returns
453        // `false` when we've passed the last position; rank-0 also lands
454        // here on the first call.
455        let mut tentative = current.clone();
456        if !increment_lex(&mut tentative, &self.shape) {
457            self.next = None;
458        } else {
459            self.next = Some(tentative);
460        }
461
462        Some(current)
463    }
464}
465
/// Build the `[i, j, *, *]` label for a slice. The label lists the leading
/// indices, then one `*` for each of the remaining `rank - leading.len()`
/// dimensions (the row/col axes).
fn format_leading_label(leading: &[usize], rank: usize) -> String {
    let wildcards = rank - leading.len();
    let parts: Vec<String> = leading
        .iter()
        .map(ToString::to_string)
        .chain(std::iter::repeat("*".to_string()).take(wildcards))
        .collect();
    format!("[{}]", parts.join(", "))
}
475
476/// Resolve which indices along a single dim should be rendered, given a
477/// per-dim filter entry. `None` means "render everything", which is the
478/// default for unfiltered prints.
479fn axis_indices(f: Option<&crate::DimFilter>, dim_size: usize) -> Vec<usize> {
480    use crate::DimFilter::*;
481    match f {
482        None | Some(Any) => (0..dim_size).collect(),
483        Some(Exact(v)) => {
484            if *v < dim_size {
485                vec![*v]
486            } else {
487                Vec::new()
488            }
489        }
490        Some(Range { start, end }) => {
491            if *start >= dim_size {
492                Vec::new()
493            } else {
494                (*start..=(*end).min(dim_size.saturating_sub(1))).collect()
495            }
496        }
497    }
498}
499
/// Render a right-aligned text table.
///
/// `rows` and `cols` are the index labels printed down the left edge and
/// across the header; `cell(row, col)` produces the text of each body cell.
/// All body columns share one width: the widest cell text or column label,
/// and at least 2. The row-label column is at least 3 wide.
///
/// `cell` is called exactly once per `(row, col)` pair, in row-major order.
fn pretty_print_table<F>(rows: &[usize], cols: &[usize], mut cell: F) -> String
where
    F: FnMut(usize, usize) -> String,
{
    // Render every cell once up front. Both the width pass and the output pass
    // need the text, and the callers' closures are not free (they clone index
    // vectors and format floats per call).
    let mut grid: Vec<Vec<String>> = Vec::with_capacity(rows.len());
    for &r in rows {
        grid.push(cols.iter().map(|&c| cell(r, c)).collect());
    }

    let widest_cell = grid.iter().flatten().map(String::len).max().unwrap_or(0);

    // Also account for the column-label width (so a tensor sliced down to
    // `[10-12]` renders header `10 11 12` without crowding).
    let widest_label = cols.iter().map(|c| c.to_string().len()).max().unwrap_or(0);
    let max_width = widest_cell.max(widest_label).max(2);

    let row_label_width = rows
        .iter()
        .map(|r| r.to_string().len())
        .max()
        .unwrap_or(0)
        .max(3);

    let mut s = String::new();

    // header
    s.push_str(&format!("{:>width$} |", "", width = row_label_width));
    for &col in cols {
        s.push_str(&format!(" {:>width$}", col, width = max_width));
    }
    s.push('\n');

    // separator
    s.push_str(&"-".repeat(row_label_width + 1));
    s.push('+');
    s.push_str(&"-".repeat((max_width + 1) * cols.len()));
    s.push('\n');

    // rows
    for (&row, values) in rows.iter().zip(&grid) {
        s.push_str(&format!("{:>width$} |", row, width = row_label_width));
        for value in values {
            s.push_str(&format!(" {:>width$}", value, width = max_width));
        }
        s.push('\n');
    }

    s
}