// burn_cubecl/ops/transaction.rs

1use burn_tensor::{
2    DType, TensorData,
3    ops::{TransactionOps, TransactionPrimitiveResult},
4};
5
6use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};
7
8impl<R, F, I, BT> TransactionOps<Self> for CubeBackend<R, F, I, BT>
9where
10    R: CubeRuntime,
11    F: FloatElement,
12    I: IntElement,
13    BT: BoolElement,
14{
15    async fn tr_execute(
16        transaction: burn_tensor::ops::TransactionPrimitive<Self>,
17    ) -> burn_tensor::ops::TransactionPrimitiveResult {
18        let mut bindings = Vec::new();
19        let mut client = None;
20
21        enum Kind {
22            Float(usize, Vec<usize>, DType),
23            Int(usize, Vec<usize>, DType),
24            Bool(usize, Vec<usize>, DType),
25        }
26
27        let mut num_bindings = 0;
28
29        let mut kinds = Vec::new();
30
31        transaction.read_floats.into_iter().for_each(|t| {
32            if client.is_none() {
33                client = Some(t.client.clone());
34            }
35
36            kinds.push(Kind::Float(num_bindings, t.shape.into(), F::dtype()));
37            num_bindings += 1;
38            bindings.push(t.handle)
39        });
40        transaction.read_ints.into_iter().for_each(|t| {
41            if client.is_none() {
42                client = Some(t.client.clone());
43            }
44
45            kinds.push(Kind::Int(num_bindings, t.shape.into(), I::dtype()));
46            num_bindings += 1;
47            bindings.push(t.handle)
48        });
49        transaction.read_bools.into_iter().for_each(|t| {
50            if client.is_none() {
51                client = Some(t.client.clone());
52            }
53
54            kinds.push(Kind::Bool(num_bindings, t.shape.into(), BT::dtype()));
55            num_bindings += 1;
56            bindings.push(t.handle)
57        });
58
59        let client = client.unwrap();
60
61        let mut data: Vec<Option<_>> = client
62            .read_async(bindings)
63            .await
64            .into_iter()
65            .map(Some)
66            .collect::<Vec<Option<_>>>();
67
68        let mut result = TransactionPrimitiveResult::default();
69
70        for kind in kinds {
71            match kind {
72                Kind::Float(index, shape, dtype) => {
73                    let bytes = data.get_mut(index).unwrap().take().unwrap();
74                    result
75                        .read_floats
76                        .push(TensorData::from_bytes(bytes, shape, dtype));
77                }
78                Kind::Int(index, shape, dtype) => {
79                    let bytes = data.get_mut(index).unwrap().take().unwrap();
80                    result
81                        .read_ints
82                        .push(TensorData::from_bytes(bytes, shape, dtype));
83                }
84                Kind::Bool(index, shape, dtype) => {
85                    let bytes = data.get_mut(index).unwrap().take().unwrap();
86                    result
87                        .read_bools
88                        .push(TensorData::from_bytes(bytes, shape, dtype));
89                }
90            }
91        }
92
93        result
94    }
95}