// burn_tensor/tensor/api/transaction.rs
use core::future::Future;

use super::{BasicOps, Tensor, TensorPrimitive};
use crate::{
    backend::Backend,
    ops::{BoolTensor, IntTensor, TransactionPrimitive},
    TensorData,
};
use alloc::vec::Vec;
10
11#[derive(Default)]
12pub struct Transaction<B: Backend> {
27 op: TransactionPrimitive<B>,
28 orders: Vec<Order>,
29}
30
/// Location of one registered tensor inside the per-kind read lists.
///
/// The payload is the position the tensor occupies in the list matching the
/// variant; the same position is used to look up its data once the
/// transaction has been executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Order {
    /// Position in the float read list.
    Float(usize),
    /// Position in the quantized float read list.
    QFloat(usize),
    /// Position in the int read list.
    Int(usize),
    /// Position in the bool read list.
    Bool(usize),
}
37
38impl<B: Backend> Transaction<B> {
39 pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {
41 K::register_transaction(&mut self, tensor.into_primitive());
42 self
43 }
44
45 pub fn execute(self) -> Vec<TensorData> {
48 burn_common::future::block_on(self.execute_async())
49 }
50
51 pub fn execute_async(self) -> impl Future<Output = Vec<TensorData>> {
54 let fut = B::tr_execute(self.op);
55
56 async move {
57 let result = fut.await;
58
59 let mut floats: Vec<_> = result.read_floats.into_iter().map(Some).collect();
60 let mut qfloats: Vec<_> = result.read_qfloats.into_iter().map(Some).collect();
61 let mut ints: Vec<_> = result.read_ints.into_iter().map(Some).collect();
62 let mut bools: Vec<_> = result.read_bools.into_iter().map(Some).collect();
63
64 self.orders
65 .into_iter()
66 .map(|order| match order {
67 Order::Float(index) => floats.get_mut(index).unwrap().take().unwrap(),
68 Order::QFloat(index) => qfloats.get_mut(index).unwrap().take().unwrap(),
69 Order::Int(index) => ints.get_mut(index).unwrap().take().unwrap(),
70 Order::Bool(index) => bools.get_mut(index).unwrap().take().unwrap(),
71 })
72 .collect::<Vec<_>>()
73 }
74 }
75
76 pub(crate) fn register_float(&mut self, tensor: TensorPrimitive<B>) {
77 match tensor {
78 TensorPrimitive::Float(tensor) => {
79 self.orders.push(Order::Float(self.op.read_floats.len()));
80 self.op.read_floats.push(tensor);
81 }
82 TensorPrimitive::QFloat(tensor) => {
83 self.orders.push(Order::QFloat(self.op.read_qfloats.len()));
84 self.op.read_qfloats.push(tensor);
85 }
86 }
87 }
88
89 pub(crate) fn register_int(&mut self, tensor: IntTensor<B>) {
90 self.orders.push(Order::Int(self.op.read_ints.len()));
91 self.op.read_ints.push(tensor);
92 }
93
94 pub(crate) fn register_bool(&mut self, tensor: BoolTensor<B>) {
95 self.orders.push(Order::Bool(self.op.read_bools.len()));
96 self.op.read_bools.push(tensor);
97 }
98}