burn_jit/ops/
transaction.rs1use burn_tensor::{
2 ops::{TransactionOps, TransactionPrimitiveResult},
3 DType, TensorData,
4};
5
6use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime};
7
8impl<R, F, I, BT> TransactionOps<Self> for JitBackend<R, F, I, BT>
9where
10 R: JitRuntime,
11 F: FloatElement,
12 I: IntElement,
13 BT: BoolElement,
14{
15 fn tr_execute(
16 transaction: burn_tensor::ops::TransactionPrimitive<Self>,
17 ) -> impl std::future::Future<Output = burn_tensor::ops::TransactionPrimitiveResult> + 'static + Send
18 {
19 let mut bindings = Vec::new();
20 let mut client = None;
21
22 enum Kind {
23 Float(usize, Vec<usize>, DType),
24 Int(usize, Vec<usize>, DType),
25 Bool(usize, Vec<usize>, DType),
26 }
27
28 let mut num_bindings = 0;
29
30 let mut kinds = Vec::new();
31
32 transaction.read_floats.into_iter().for_each(|t| {
33 if client.is_none() {
34 client = Some(t.client.clone());
35 }
36
37 kinds.push(Kind::Float(num_bindings, t.shape.into(), F::dtype()));
38 num_bindings += 1;
39 bindings.push(t.handle.binding())
40 });
41 transaction.read_ints.into_iter().for_each(|t| {
42 if client.is_none() {
43 client = Some(t.client.clone());
44 }
45
46 kinds.push(Kind::Int(num_bindings, t.shape.into(), I::dtype()));
47 num_bindings += 1;
48 bindings.push(t.handle.binding())
49 });
50 transaction.read_bools.into_iter().for_each(|t| {
51 if client.is_none() {
52 client = Some(t.client.clone());
53 }
54
55 kinds.push(Kind::Bool(num_bindings, t.shape.into(), BT::dtype()));
56 num_bindings += 1;
57 bindings.push(t.handle.binding())
58 });
59
60 let client = client.unwrap();
61
62 async move {
63 let mut data: Vec<Option<_>> = client
64 .read_async(bindings)
65 .await
66 .into_iter()
67 .map(Some)
68 .collect::<Vec<Option<_>>>();
69
70 let mut result = TransactionPrimitiveResult::default();
71
72 for kind in kinds {
73 match kind {
74 Kind::Float(index, shape, dtype) => {
75 let bytes = data.get_mut(index).unwrap().take().unwrap();
76 result
77 .read_floats
78 .push(TensorData::from_bytes(bytes, shape, dtype));
79 }
80 Kind::Int(index, shape, dtype) => {
81 let bytes = data.get_mut(index).unwrap().take().unwrap();
82 result
83 .read_ints
84 .push(TensorData::from_bytes(bytes, shape, dtype));
85 }
86 Kind::Bool(index, shape, dtype) => {
87 let bytes = data.get_mut(index).unwrap().take().unwrap();
88 result
89 .read_bools
90 .push(TensorData::from_bytes(bytes, shape, dtype));
91 }
92 }
93 }
94
95 result
96 }
97 }
98}