burn_cubecl/ops/
transaction.rs1use burn_backend::{
2 DType, TensorData,
3 backend::ExecutionError,
4 ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData},
5};
6use cubecl::server::{Binding, CopyDescriptor};
7
8use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};
9
10impl<R, F, I, BT> TransactionOps<Self> for CubeBackend<R, F, I, BT>
11where
12 R: CubeRuntime,
13 F: FloatElement,
14 I: IntElement,
15 BT: BoolElement,
16{
17 async fn tr_execute(
18 transaction: TransactionPrimitive<Self>,
19 ) -> Result<TransactionPrimitiveData, ExecutionError> {
20 let mut client = None;
21
22 enum Kind {
23 Float,
24 Int,
25 Bool,
26 }
27
28 #[derive(new)]
29 struct BindingData {
30 index: usize,
31 kind: Kind,
32 handle: Option<Binding>,
33 shape: Vec<usize>,
34 strides: Vec<usize>,
35 dtype: DType,
36 }
37
38 let mut num_bindings = 0;
39
40 let mut kinds = Vec::new();
41
42 for t in transaction.read_floats.into_iter() {
43 if client.is_none() {
44 client = Some(t.client.clone());
45 }
46
47 let t = crate::kernel::into_contiguous_aligned(t);
48 let binding = BindingData::new(
49 num_bindings,
50 Kind::Float,
51 Some(t.handle.binding()),
52 t.shape.into(),
53 t.strides,
54 t.dtype,
55 );
56
57 kinds.push(binding);
58 num_bindings += 1;
59 }
60 for t in transaction.read_ints.into_iter() {
61 if client.is_none() {
62 client = Some(t.client.clone());
63 }
64
65 let t = crate::kernel::into_contiguous_aligned(t);
66 let binding = BindingData::new(
67 num_bindings,
68 Kind::Int,
69 Some(t.handle.binding()),
70 t.shape.into(),
71 t.strides,
72 t.dtype,
73 );
74
75 kinds.push(binding);
76 num_bindings += 1;
77 }
78 for t in transaction.read_bools.into_iter() {
79 if client.is_none() {
80 client = Some(t.client.clone());
81 }
82
83 let t = crate::kernel::into_contiguous_aligned(t);
84 let binding = BindingData::new(
85 num_bindings,
86 Kind::Bool,
87 Some(t.handle.binding()),
88 t.shape.into(),
89 t.strides,
90 t.dtype,
91 );
92
93 kinds.push(binding);
94 num_bindings += 1;
95 }
96
97 let client = client.unwrap();
98
99 let bindings = kinds
100 .iter_mut()
101 .map(|b| {
102 CopyDescriptor::new(
103 b.handle.take().unwrap(),
104 &b.shape,
105 &b.strides,
106 b.dtype.size(),
107 )
108 })
109 .collect();
110
111 let mut data: Vec<Option<_>> = client
112 .read_tensor_async(bindings)
113 .await
114 .map_err(|err| ExecutionError::WithContext {
115 reason: format!("{err:?}"),
116 })?
117 .into_iter()
118 .map(Some)
119 .collect::<Vec<Option<_>>>();
120
121 let mut result = TransactionPrimitiveData::default();
122
123 for binding in kinds {
124 let bytes = data.get_mut(binding.index).unwrap().take().unwrap();
125 let t_data = TensorData::from_bytes(bytes, binding.shape, binding.dtype);
126
127 match binding.kind {
128 Kind::Float => {
129 result.read_floats.push(t_data);
130 }
131 Kind::Int => {
132 result.read_ints.push(t_data);
133 }
134 Kind::Bool => {
135 result.read_bools.push(t_data);
136 }
137 }
138 }
139
140 Ok(result)
141 }
142}