use super::{BasicOps, Tensor, TensorPrimitive};
use crate::{
    TensorData,
    backend::Backend,
    ops::{BoolTensor, IntTensor, TransactionPrimitive},
};
use alloc::vec::Vec;

9#[derive(Default)]
10pub struct Transaction<B: Backend> {
25 op: TransactionPrimitive<B>,
26 orders: Vec<Order>,
27}
28
/// Where one registered tensor lives inside a transaction: which per-kind
/// result list it was pushed to, and its index within that list.
///
/// The enum is small and `Copy`, so it is cheap to pass around, and `Debug`
/// lets it be inspected when diagnosing a mismatched transaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Order {
    /// Index into the `read_floats` list.
    Float(usize),
    /// Index into the `read_qfloats` list.
    QFloat(usize),
    /// Index into the `read_ints` list.
    Int(usize),
    /// Index into the `read_bools` list.
    Bool(usize),
}

36impl<B: Backend> Transaction<B> {
37 pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {
39 K::register_transaction(&mut self, tensor.into_primitive());
40 self
41 }
42
43 pub fn execute(self) -> Vec<TensorData> {
46 burn_common::future::block_on(self.execute_async())
47 }
48
49 pub async fn execute_async(self) -> Vec<TensorData> {
52 let result = B::tr_execute(self.op).await;
53
54 let mut floats: Vec<_> = result.read_floats.into_iter().map(Some).collect();
55 let mut qfloats: Vec<_> = result.read_qfloats.into_iter().map(Some).collect();
56 let mut ints: Vec<_> = result.read_ints.into_iter().map(Some).collect();
57 let mut bools: Vec<_> = result.read_bools.into_iter().map(Some).collect();
58
59 self.orders
60 .into_iter()
61 .map(|order| match order {
62 Order::Float(index) => floats.get_mut(index).unwrap().take().unwrap(),
63 Order::QFloat(index) => qfloats.get_mut(index).unwrap().take().unwrap(),
64 Order::Int(index) => ints.get_mut(index).unwrap().take().unwrap(),
65 Order::Bool(index) => bools.get_mut(index).unwrap().take().unwrap(),
66 })
67 .collect::<Vec<_>>()
68 }
69
70 pub(crate) fn register_float(&mut self, tensor: TensorPrimitive<B>) {
71 match tensor {
72 TensorPrimitive::Float(tensor) => {
73 self.orders.push(Order::Float(self.op.read_floats.len()));
74 self.op.read_floats.push(tensor);
75 }
76 TensorPrimitive::QFloat(tensor) => {
77 self.orders.push(Order::QFloat(self.op.read_qfloats.len()));
78 self.op.read_qfloats.push(tensor);
79 }
80 }
81 }
82
83 pub(crate) fn register_int(&mut self, tensor: IntTensor<B>) {
84 self.orders.push(Order::Int(self.op.read_ints.len()));
85 self.op.read_ints.push(tensor);
86 }
87
88 pub(crate) fn register_bool(&mut self, tensor: BoolTensor<B>) {
89 self.orders.push(Order::Bool(self.op.read_bools.len()));
90 self.op.read_bools.push(tensor);
91 }
92}