Skip to main content

burn_cubecl/ops/
transaction.rs

1use burn_backend::{
2    DType, TensorData,
3    backend::ExecutionError,
4    ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData},
5};
6use burn_std::{Shape, Strides};
7use cubecl::server::{CopyDescriptor, Handle};
8
9use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};
10
11impl<R, F, I, BT> TransactionOps<Self> for CubeBackend<R, F, I, BT>
12where
13    R: CubeRuntime,
14    F: FloatElement,
15    I: IntElement,
16    BT: BoolElement,
17{
18    async fn tr_execute(
19        transaction: TransactionPrimitive<Self>,
20    ) -> Result<TransactionPrimitiveData, ExecutionError> {
21        let mut client = None;
22
23        enum Kind {
24            Float,
25            Int,
26            Bool,
27        }
28
29        #[derive(new)]
30        struct BindingData {
31            index: usize,
32            kind: Kind,
33            handle: Option<Handle>,
34            shape: Shape,
35            strides: Strides,
36            dtype: DType,
37        }
38
39        let mut num_bindings = 0;
40
41        let mut kinds = Vec::new();
42
43        for t in transaction.read_floats.into_iter() {
44            if client.is_none() {
45                client = Some(t.client.clone());
46            }
47
48            let t = crate::kernel::into_contiguous_aligned(t);
49            let binding = BindingData::new(
50                num_bindings,
51                Kind::Float,
52                Some(t.handle.clone()),
53                t.meta.shape.clone(),
54                t.meta.strides.clone(),
55                t.dtype,
56            );
57
58            kinds.push(binding);
59            num_bindings += 1;
60        }
61        for t in transaction.read_ints.into_iter() {
62            if client.is_none() {
63                client = Some(t.client.clone());
64            }
65
66            let t = crate::kernel::into_contiguous_aligned(t);
67            let binding = BindingData::new(
68                num_bindings,
69                Kind::Int,
70                Some(t.handle.clone()),
71                t.meta.shape.clone(),
72                t.meta.strides.clone(),
73                t.dtype,
74            );
75
76            kinds.push(binding);
77            num_bindings += 1;
78        }
79        for t in transaction.read_bools.into_iter() {
80            if client.is_none() {
81                client = Some(t.client.clone());
82            }
83
84            let t = crate::kernel::into_contiguous_aligned(t);
85            let binding = BindingData::new(
86                num_bindings,
87                Kind::Bool,
88                Some(t.handle.clone()),
89                t.meta.shape.clone(),
90                t.meta.strides.clone(),
91                t.dtype,
92            );
93
94            kinds.push(binding);
95            num_bindings += 1;
96        }
97
98        let client = client.unwrap();
99
100        let bindings = kinds
101            .iter_mut()
102            .map(|b| {
103                CopyDescriptor::new(
104                    b.handle.take().unwrap().binding(),
105                    b.shape.clone(),
106                    b.strides.clone(),
107                    b.dtype.size(),
108                )
109            })
110            .collect();
111
112        let mut data: Vec<Option<_>> = client
113            .read_tensor_async(bindings)
114            .await
115            .map_err(|err| ExecutionError::WithContext {
116                reason: format!("{err:?}"),
117            })?
118            .into_iter()
119            .map(Some)
120            .collect::<Vec<Option<_>>>();
121
122        let mut result = TransactionPrimitiveData::default();
123
124        for binding in kinds {
125            let bytes = data.get_mut(binding.index).unwrap().take().unwrap();
126            let t_data = TensorData::from_bytes(bytes, binding.shape, binding.dtype);
127
128            match binding.kind {
129                Kind::Float => {
130                    result.read_floats.push(t_data);
131                }
132                Kind::Int => {
133                    result.read_ints.push(t_data);
134                }
135                Kind::Bool => {
136                    result.read_bools.push(t_data);
137                }
138            }
139        }
140
141        Ok(result)
142    }
143}