burn_tensor/tensor/api/transaction.rs

1use core::future::Future;
2
3use super::{BasicOps, Tensor, TensorPrimitive};
4use crate::{
5    backend::Backend,
6    ops::{BoolTensor, IntTensor, TransactionPrimitive},
7    TensorData,
8};
9use alloc::vec::Vec;
10
11#[derive(Default)]
12/// A transaction can [read](Self::register) multiple tensors at once with a single operation improving
13/// compute utilization with optimized laziness.
14///
15/// # Example
16///
17/// ```rust,ignore
18///  let [output_data, loss_data, targets_data] = Transaction::default()
19///    .register(output)
20///    .register(loss)
21///    .register(targets)
22///    .execute()
23///    .try_into()
24///    .expect("Correct amount of tensor data");
25/// ```
26pub struct Transaction<B: Backend> {
27    op: TransactionPrimitive<B>,
28    orders: Vec<Order>,
29}
30
/// Location of one registered tensor's data in the transaction result.
///
/// The variant selects the per-kind read list, and the payload is the tensor's index in that list.
enum Order {
    /// Index into the float read list (`read_floats`).
    Float(usize),
    /// Index into the `read_qfloats` list.
    QFloat(usize),
    /// Index into the integer read list (`read_ints`).
    Int(usize),
    /// Index into the boolean read list (`read_bools`).
    Bool(usize),
}
37
38impl<B: Backend> Transaction<B> {
39    /// Add a [tensor](Tensor) to the transaction to be read.
40    pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {
41        K::register_transaction(&mut self, tensor.into_primitive());
42        self
43    }
44
45    /// Executes the transaction synchronously and returns the [data](TensorData) in the same order
46    /// in which they were [registered](Self::register).
47    pub fn execute(self) -> Vec<TensorData> {
48        burn_common::future::block_on(self.execute_async())
49    }
50
51    /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order
52    /// in which they were [registered](Self::register).
53    pub fn execute_async(self) -> impl Future<Output = Vec<TensorData>> {
54        let fut = B::tr_execute(self.op);
55
56        async move {
57            let result = fut.await;
58
59            let mut floats: Vec<_> = result.read_floats.into_iter().map(Some).collect();
60            let mut qfloats: Vec<_> = result.read_qfloats.into_iter().map(Some).collect();
61            let mut ints: Vec<_> = result.read_ints.into_iter().map(Some).collect();
62            let mut bools: Vec<_> = result.read_bools.into_iter().map(Some).collect();
63
64            self.orders
65                .into_iter()
66                .map(|order| match order {
67                    Order::Float(index) => floats.get_mut(index).unwrap().take().unwrap(),
68                    Order::QFloat(index) => qfloats.get_mut(index).unwrap().take().unwrap(),
69                    Order::Int(index) => ints.get_mut(index).unwrap().take().unwrap(),
70                    Order::Bool(index) => bools.get_mut(index).unwrap().take().unwrap(),
71                })
72                .collect::<Vec<_>>()
73        }
74    }
75
76    pub(crate) fn register_float(&mut self, tensor: TensorPrimitive<B>) {
77        match tensor {
78            TensorPrimitive::Float(tensor) => {
79                self.orders.push(Order::Float(self.op.read_floats.len()));
80                self.op.read_floats.push(tensor);
81            }
82            TensorPrimitive::QFloat(tensor) => {
83                self.orders.push(Order::QFloat(self.op.read_qfloats.len()));
84                self.op.read_qfloats.push(tensor);
85            }
86        }
87    }
88
89    pub(crate) fn register_int(&mut self, tensor: IntTensor<B>) {
90        self.orders.push(Order::Int(self.op.read_ints.len()));
91        self.op.read_ints.push(tensor);
92    }
93
94    pub(crate) fn register_bool(&mut self, tensor: BoolTensor<B>) {
95        self.orders.push(Order::Bool(self.op.read_bools.len()));
96        self.op.read_bools.push(tensor);
97    }
98}