burn_cubecl/ops/transaction.rs

1use burn_backend::{
2    DType, TensorData,
3    backend::ExecutionError,
4    ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData},
5};
6use cubecl::server::{Binding, CopyDescriptor};
7
8use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};
9
10impl<R, F, I, BT> TransactionOps<Self> for CubeBackend<R, F, I, BT>
11where
12    R: CubeRuntime,
13    F: FloatElement,
14    I: IntElement,
15    BT: BoolElement,
16{
17    async fn tr_execute(
18        transaction: TransactionPrimitive<Self>,
19    ) -> Result<TransactionPrimitiveData, ExecutionError> {
20        let mut client = None;
21
22        enum Kind {
23            Float,
24            Int,
25            Bool,
26        }
27
28        #[derive(new)]
29        struct BindingData {
30            index: usize,
31            kind: Kind,
32            handle: Option<Binding>,
33            shape: Vec<usize>,
34            strides: Vec<usize>,
35            dtype: DType,
36        }
37
38        let mut num_bindings = 0;
39
40        let mut kinds = Vec::new();
41
42        for t in transaction.read_floats.into_iter() {
43            if client.is_none() {
44                client = Some(t.client.clone());
45            }
46
47            let t = crate::kernel::into_contiguous_aligned(t);
48            let binding = BindingData::new(
49                num_bindings,
50                Kind::Float,
51                Some(t.handle.binding()),
52                t.shape.into(),
53                t.strides,
54                t.dtype,
55            );
56
57            kinds.push(binding);
58            num_bindings += 1;
59        }
60        for t in transaction.read_ints.into_iter() {
61            if client.is_none() {
62                client = Some(t.client.clone());
63            }
64
65            let t = crate::kernel::into_contiguous_aligned(t);
66            let binding = BindingData::new(
67                num_bindings,
68                Kind::Int,
69                Some(t.handle.binding()),
70                t.shape.into(),
71                t.strides,
72                t.dtype,
73            );
74
75            kinds.push(binding);
76            num_bindings += 1;
77        }
78        for t in transaction.read_bools.into_iter() {
79            if client.is_none() {
80                client = Some(t.client.clone());
81            }
82
83            let t = crate::kernel::into_contiguous_aligned(t);
84            let binding = BindingData::new(
85                num_bindings,
86                Kind::Bool,
87                Some(t.handle.binding()),
88                t.shape.into(),
89                t.strides,
90                t.dtype,
91            );
92
93            kinds.push(binding);
94            num_bindings += 1;
95        }
96
97        let client = client.unwrap();
98
99        let bindings = kinds
100            .iter_mut()
101            .map(|b| {
102                CopyDescriptor::new(
103                    b.handle.take().unwrap(),
104                    &b.shape,
105                    &b.strides,
106                    b.dtype.size(),
107                )
108            })
109            .collect();
110
111        let mut data: Vec<Option<_>> = client
112            .read_tensor_async(bindings)
113            .await
114            .map_err(|err| ExecutionError::WithContext {
115                reason: format!("{err:?}"),
116            })?
117            .into_iter()
118            .map(Some)
119            .collect::<Vec<Option<_>>>();
120
121        let mut result = TransactionPrimitiveData::default();
122
123        for binding in kinds {
124            let bytes = data.get_mut(binding.index).unwrap().take().unwrap();
125            let t_data = TensorData::from_bytes(bytes, binding.shape, binding.dtype);
126
127            match binding.kind {
128                Kind::Float => {
129                    result.read_floats.push(t_data);
130                }
131                Kind::Int => {
132                    result.read_ints.push(t_data);
133                }
134                Kind::Bool => {
135                    result.read_bools.push(t_data);
136                }
137            }
138        }
139
140        Ok(result)
141    }
142}