burn_cubecl/ops/
transaction.rs1use burn_tensor::{
2 DType, TensorData,
3 ops::{TransactionOps, TransactionPrimitiveResult},
4};
5
6use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};
7
8impl<R, F, I, BT> TransactionOps<Self> for CubeBackend<R, F, I, BT>
9where
10 R: CubeRuntime,
11 F: FloatElement,
12 I: IntElement,
13 BT: BoolElement,
14{
15 async fn tr_execute(
16 transaction: burn_tensor::ops::TransactionPrimitive<Self>,
17 ) -> burn_tensor::ops::TransactionPrimitiveResult {
18 let mut bindings = Vec::new();
19 let mut client = None;
20
21 enum Kind {
22 Float(usize, Vec<usize>, DType),
23 Int(usize, Vec<usize>, DType),
24 Bool(usize, Vec<usize>, DType),
25 }
26
27 let mut num_bindings = 0;
28
29 let mut kinds = Vec::new();
30
31 transaction.read_floats.into_iter().for_each(|t| {
32 if client.is_none() {
33 client = Some(t.client.clone());
34 }
35
36 kinds.push(Kind::Float(num_bindings, t.shape.into(), F::dtype()));
37 num_bindings += 1;
38 bindings.push(t.handle)
39 });
40 transaction.read_ints.into_iter().for_each(|t| {
41 if client.is_none() {
42 client = Some(t.client.clone());
43 }
44
45 kinds.push(Kind::Int(num_bindings, t.shape.into(), I::dtype()));
46 num_bindings += 1;
47 bindings.push(t.handle)
48 });
49 transaction.read_bools.into_iter().for_each(|t| {
50 if client.is_none() {
51 client = Some(t.client.clone());
52 }
53
54 kinds.push(Kind::Bool(num_bindings, t.shape.into(), BT::dtype()));
55 num_bindings += 1;
56 bindings.push(t.handle)
57 });
58
59 let client = client.unwrap();
60
61 let mut data: Vec<Option<_>> = client
62 .read_async(bindings)
63 .await
64 .into_iter()
65 .map(Some)
66 .collect::<Vec<Option<_>>>();
67
68 let mut result = TransactionPrimitiveResult::default();
69
70 for kind in kinds {
71 match kind {
72 Kind::Float(index, shape, dtype) => {
73 let bytes = data.get_mut(index).unwrap().take().unwrap();
74 result
75 .read_floats
76 .push(TensorData::from_bytes(bytes, shape, dtype));
77 }
78 Kind::Int(index, shape, dtype) => {
79 let bytes = data.get_mut(index).unwrap().take().unwrap();
80 result
81 .read_ints
82 .push(TensorData::from_bytes(bytes, shape, dtype));
83 }
84 Kind::Bool(index, shape, dtype) => {
85 let bytes = data.get_mut(index).unwrap().take().unwrap();
86 result
87 .read_bools
88 .push(TensorData::from_bytes(bytes, shape, dtype));
89 }
90 }
91 }
92
93 result
94 }
95}