// burn_jit/ops/transaction.rs

use burn_tensor::{
    ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveResult},
    DType, TensorData,
};

use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime};

8impl<R, F, I, BT> TransactionOps<Self> for JitBackend<R, F, I, BT>
9where
10    R: JitRuntime,
11    F: FloatElement,
12    I: IntElement,
13    BT: BoolElement,
14{
15    fn tr_execute(
16        transaction: burn_tensor::ops::TransactionPrimitive<Self>,
17    ) -> impl std::future::Future<Output = burn_tensor::ops::TransactionPrimitiveResult> + 'static + Send
18    {
19        let mut bindings = Vec::new();
20        let mut client = None;
21
22        enum Kind {
23            Float(usize, Vec<usize>, DType),
24            Int(usize, Vec<usize>, DType),
25            Bool(usize, Vec<usize>, DType),
26        }
27
28        let mut num_bindings = 0;
29
30        let mut kinds = Vec::new();
31
32        transaction.read_floats.into_iter().for_each(|t| {
33            if client.is_none() {
34                client = Some(t.client.clone());
35            }
36
37            kinds.push(Kind::Float(num_bindings, t.shape.into(), F::dtype()));
38            num_bindings += 1;
39            bindings.push(t.handle.binding())
40        });
41        transaction.read_ints.into_iter().for_each(|t| {
42            if client.is_none() {
43                client = Some(t.client.clone());
44            }
45
46            kinds.push(Kind::Int(num_bindings, t.shape.into(), I::dtype()));
47            num_bindings += 1;
48            bindings.push(t.handle.binding())
49        });
50        transaction.read_bools.into_iter().for_each(|t| {
51            if client.is_none() {
52                client = Some(t.client.clone());
53            }
54
55            kinds.push(Kind::Bool(num_bindings, t.shape.into(), BT::dtype()));
56            num_bindings += 1;
57            bindings.push(t.handle.binding())
58        });
59
60        let client = client.unwrap();
61
62        async move {
63            let mut data: Vec<Option<_>> = client
64                .read_async(bindings)
65                .await
66                .into_iter()
67                .map(Some)
68                .collect::<Vec<Option<_>>>();
69
70            let mut result = TransactionPrimitiveResult::default();
71
72            for kind in kinds {
73                match kind {
74                    Kind::Float(index, shape, dtype) => {
75                        let bytes = data.get_mut(index).unwrap().take().unwrap();
76                        result
77                            .read_floats
78                            .push(TensorData::from_bytes(bytes, shape, dtype));
79                    }
80                    Kind::Int(index, shape, dtype) => {
81                        let bytes = data.get_mut(index).unwrap().take().unwrap();
82                        result
83                            .read_ints
84                            .push(TensorData::from_bytes(bytes, shape, dtype));
85                    }
86                    Kind::Bool(index, shape, dtype) => {
87                        let bytes = data.get_mut(index).unwrap().take().unwrap();
88                        result
89                            .read_bools
90                            .push(TensorData::from_bytes(bytes, shape, dtype));
91                    }
92                }
93            }
94
95            result
96        }
97    }
98}