// burn_jit/ops/base.rs

use crate::{element::JitElement, kernel, tensor::JitTensor, BoolElement, JitRuntime};
use burn_tensor::{Shape, TensorData};
use core::mem::size_of;
use cubecl::tensor_vectorization_factor;

/// Create a tensor on the given device from host [`TensorData`], converting it to element type `E`.
pub(crate) fn from_data<R: JitRuntime, E: JitElement>(
    data: TensorData,
    device: &R::Device,
) -> JitTensor<R> {
    let shape: Shape = (&data.shape).into();
    let client = R::client(device);
    let buffer = client.create(data.convert::<E>().as_bytes());

    JitTensor::new_contiguous(client, device.clone(), shape, buffer, E::dtype())
}
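
// Usage sketch (illustrative, not part of the original source; `Wgpu` and the
// default device are hypothetical stand-ins for any concrete `JitRuntime`):
//
//     let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2]);
//     let tensor = from_data::<Wgpu, f32>(data, &Default::default());
//     assert_eq!(tensor.shape.dims, vec![2, 2]);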

/// Read a tensor back to host memory asynchronously as [`TensorData`].
pub(crate) async fn into_data<R: JitRuntime, E: JitElement>(tensor: JitTensor<R>) -> TensorData {
    let tensor = kernel::into_contiguous(tensor);

    let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
    // The backing buffer may be larger than the tensor; keep only the valid prefix.
    let actual_len = tensor.shape.num_elements() * size_of::<E>();
    TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape)
}

/// Read data from a [`JitTensor`] synchronously.
#[allow(unused, reason = "useful for debugging kernels")]
pub fn into_data_sync<R: JitRuntime, E: JitElement>(tensor: JitTensor<R>) -> TensorData {
    let tensor = kernel::into_contiguous(tensor);

    let bytes = tensor.client.read_one(tensor.handle.binding());
    let actual_len = tensor.shape.num_elements() * size_of::<E>();
    TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape)
}
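
// Debugging sketch (illustrative): dump a tensor's contents while developing a
// kernel. `Wgpu` is a hypothetical stand-in for any concrete `JitRuntime`, and
// this assumes the tensor handle is cheap to clone:
//
//     let data = into_data_sync::<Wgpu, f32>(tensor.clone());
//     println!("{data:?}");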

/// Read a boolean tensor back to host memory, decoding the stored `BT` values into `bool`s.
pub(crate) async fn bool_into_data<R: JitRuntime, BT: BoolElement>(
    tensor: JitTensor<R>,
) -> TensorData {
    let tensor = kernel::into_contiguous(tensor);
    let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
    let actual_len = tensor.shape.num_elements() * size_of::<BT>();
    TensorData::new(
        // Any value other than `BT::false_val()` decodes to `true`.
        BT::from_bytes(&bytes[..actual_len])
            .iter()
            .map(|i| *i != BT::false_val())
            .collect(),
        tensor.shape,
    )
}

/// Move a tensor to the given device, returning it unchanged if it is already there.
pub(crate) fn to_device<R: JitRuntime>(tensor: JitTensor<R>, device: &R::Device) -> JitTensor<R> {
    if &tensor.device == device {
        return tensor;
    }

    let client = R::client(device);
    tensor.to_client(client, device.clone())
}

/// Allocate an uninitialized contiguous tensor of the given shape on the device.
pub(crate) fn empty<R: JitRuntime, E: JitElement>(
    shape: Shape,
    device: &R::Device,
) -> JitTensor<R> {
    let client = R::client(device);
    let buffer = client.empty(shape.num_elements() * size_of::<E>());

    JitTensor::new_contiguous(client, device.clone(), shape, buffer, E::dtype())
}
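
// Usage sketch (illustrative; `Wgpu` and `device` are hypothetical): allocate a
// [128, 128] f32 tensor, i.e. a 128 * 128 * 4 byte buffer whose contents are
// arbitrary until a kernel writes to it:
//
//     let out = empty::<Wgpu, f32>(Shape::from([128, 128]), &device);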

/// Swap two dimensions by exchanging their shape and stride entries.
/// This is a zero-copy view change; no data is moved.
pub(crate) fn swap_dims<R: JitRuntime>(
    mut tensor: JitTensor<R>,
    dim1: usize,
    dim2: usize,
) -> JitTensor<R> {
    tensor.strides.swap(dim1, dim2);
    tensor.shape.dims.swap(dim1, dim2);

    tensor
}
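
// Worked example (illustrative): a contiguous row-major [2, 3] tensor has strides
// [3, 1]; after `swap_dims(tensor, 0, 1)` the shape is [3, 2] and the strides are
// [1, 3], while the underlying buffer is left untouched.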

/// Permute the dimensions according to `axes` by remapping shape and strides.
/// Like [`swap_dims`], this is a zero-copy view change.
pub fn permute<R: JitRuntime>(mut tensor: JitTensor<R>, axes: &[usize]) -> JitTensor<R> {
    // Remap strides.
    tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect();

    // Remap shape.
    tensor.shape.dims = axes.iter().map(|i| tensor.shape.dims[*i]).collect();

    tensor
}
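
// Worked example (illustrative): a contiguous [2, 3, 4, 5] tensor has strides
// [60, 20, 5, 1]; `permute(tensor, &[0, 2, 3, 1])` moves the second axis last,
// giving shape [2, 4, 5, 3] and strides [60, 5, 1, 20] without touching the buffer.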

/// Broadcast a tensor to `target_shape` without copying data, by giving every
/// broadcast dimension a stride of zero.
pub(crate) fn expand<R: JitRuntime>(tensor: JitTensor<R>, target_shape: Shape) -> JitTensor<R> {
    let ndims_in = tensor.shape.num_dims();
    let ndims_out = target_shape.num_dims();

    // Initialize new strides with zeros.
    let mut new_strides = vec![0usize; ndims_out];

    // Calculate the difference in dimensions.
    let dim_diff = ndims_out.saturating_sub(ndims_in);

    // Compare dimensions from the end, copying strides for matching dimensions
    // and zeroing them for broadcast ones.
    let mut tensor_dim_iter = tensor.shape.dims.iter().rev();
    for i in (0..ndims_out).rev() {
        if i >= dim_diff {
            if let Some(&tensor_dim) = tensor_dim_iter.next() {
                if tensor_dim == target_shape.dims[i] || tensor_dim == 1 {
                    // Copy the stride for non-broadcast dimensions; set it to 0 for broadcast ones.
                    new_strides[i] = if tensor_dim == target_shape.dims[i] {
                        tensor.strides[i - dim_diff]
                    } else {
                        0
                    };
                } else {
                    // Dimension mismatch: broadcasting is impossible.
                    panic!(
                        "Dimension mismatch: cannot broadcast dimension {} of tensor to target shape",
                        tensor_dim
                    );
                }
            } else {
                // If the input tensor has fewer dimensions, treat missing dimensions as 1
                // and set the stride to 0 (broadcasting).
                new_strides[i] = 0;
            }
        } else {
            // Extra leading dimensions in the target shape are broadcast with stride 0.
            new_strides[i] = 0;
        }
    }

    JitTensor {
        client: tensor.client,
        device: tensor.device,
        shape: target_shape,
        strides: new_strides,
        handle: tensor.handle,
        dtype: tensor.dtype,
    }
}
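
// Worked example (illustrative): expanding a [1, 3] tensor with strides [3, 1] to
// target shape [2, 3] keeps the original handle and yields strides [0, 1]; both
// rows of the result read the same three elements, so nothing is copied.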

/// Reshape the tensor to `shape`.
pub(crate) fn reshape<R: JitRuntime>(tensor: JitTensor<R>, shape: Shape) -> JitTensor<R> {
    // TODO: Avoid forcing the standard layout all the time (improve performance).
    let tensor = kernel::into_contiguous(tensor);

    JitTensor::new_contiguous(
        tensor.client,
        tensor.device,
        shape,
        tensor.handle,
        tensor.dtype,
    )
}
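
// Performance note (illustrative): because `reshape` materializes a contiguous
// copy first, a pipeline like
//
//     let t = swap_dims(t, 0, 1);           // free: metadata only
//     let t = reshape(t, Shape::from([6])); // pays for one full copy
//
// pays the copy here even though the transpose itself was free; relaxing this is
// what the TODO above refers to.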

/// Compute the widest vectorization factor (line size) the runtime supports for
/// this tensor, based on its innermost dimension and strides.
pub(crate) fn max_vectorization<R: JitRuntime>(tensor: &JitTensor<R>) -> u8 {
    tensor_vectorization_factor(
        R::supported_line_sizes(),
        &tensor.shape.dims,
        &tensor.strides,
        tensor.shape.num_dims() - 1,
    )
}
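
// Sketch (illustrative, assuming the usual semantics of cubecl's
// `tensor_vectorization_factor`): if the runtime supports line sizes [4, 2, 1] and
// a contiguous tensor's innermost dimension holds 6 elements, the factor is 2
// (the widest supported size dividing 6); a non-unit innermost stride, e.g. after
// `swap_dims`, forces the factor down to 1.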