burn_cubecl/ops/
transaction.rs1use burn_backend::{
2 DType, TensorData,
3 backend::ExecutionError,
4 ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData},
5};
6use burn_std::{Shape, Strides};
7use cubecl::server::{CopyDescriptor, Handle};
8
9use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};
10
11impl<R, F, I, BT> TransactionOps<Self> for CubeBackend<R, F, I, BT>
12where
13 R: CubeRuntime,
14 F: FloatElement,
15 I: IntElement,
16 BT: BoolElement,
17{
18 async fn tr_execute(
19 transaction: TransactionPrimitive<Self>,
20 ) -> Result<TransactionPrimitiveData, ExecutionError> {
21 let mut client = None;
22
23 enum Kind {
24 Float,
25 Int,
26 Bool,
27 }
28
29 #[derive(new)]
30 struct BindingData {
31 index: usize,
32 kind: Kind,
33 handle: Option<Handle>,
34 shape: Shape,
35 strides: Strides,
36 dtype: DType,
37 }
38
39 let mut num_bindings = 0;
40
41 let mut kinds = Vec::new();
42
43 for t in transaction.read_floats.into_iter() {
44 if client.is_none() {
45 client = Some(t.client.clone());
46 }
47
48 let t = crate::kernel::into_contiguous_aligned(t);
49 let binding = BindingData::new(
50 num_bindings,
51 Kind::Float,
52 Some(t.handle.clone()),
53 t.meta.shape.clone(),
54 t.meta.strides.clone(),
55 t.dtype,
56 );
57
58 kinds.push(binding);
59 num_bindings += 1;
60 }
61 for t in transaction.read_ints.into_iter() {
62 if client.is_none() {
63 client = Some(t.client.clone());
64 }
65
66 let t = crate::kernel::into_contiguous_aligned(t);
67 let binding = BindingData::new(
68 num_bindings,
69 Kind::Int,
70 Some(t.handle.clone()),
71 t.meta.shape.clone(),
72 t.meta.strides.clone(),
73 t.dtype,
74 );
75
76 kinds.push(binding);
77 num_bindings += 1;
78 }
79 for t in transaction.read_bools.into_iter() {
80 if client.is_none() {
81 client = Some(t.client.clone());
82 }
83
84 let t = crate::kernel::into_contiguous_aligned(t);
85 let binding = BindingData::new(
86 num_bindings,
87 Kind::Bool,
88 Some(t.handle.clone()),
89 t.meta.shape.clone(),
90 t.meta.strides.clone(),
91 t.dtype,
92 );
93
94 kinds.push(binding);
95 num_bindings += 1;
96 }
97
98 let client = client.unwrap();
99
100 let bindings = kinds
101 .iter_mut()
102 .map(|b| {
103 CopyDescriptor::new(
104 b.handle.take().unwrap().binding(),
105 b.shape.clone(),
106 b.strides.clone(),
107 b.dtype.size(),
108 )
109 })
110 .collect();
111
112 let mut data: Vec<Option<_>> = client
113 .read_tensor_async(bindings)
114 .await
115 .map_err(|err| ExecutionError::WithContext {
116 reason: format!("{err:?}"),
117 })?
118 .into_iter()
119 .map(Some)
120 .collect::<Vec<Option<_>>>();
121
122 let mut result = TransactionPrimitiveData::default();
123
124 for binding in kinds {
125 let bytes = data.get_mut(binding.index).unwrap().take().unwrap();
126 let t_data = TensorData::from_bytes(bytes, binding.shape, binding.dtype);
127
128 match binding.kind {
129 Kind::Float => {
130 result.read_floats.push(t_data);
131 }
132 Kind::Int => {
133 result.read_ints.push(t_data);
134 }
135 Kind::Bool => {
136 result.read_bools.push(t_data);
137 }
138 }
139 }
140
141 Ok(result)
142 }
143}