Skip to main content

burn_tensor/tensor/api/transaction.rs

1use super::{BasicOps, Tensor};
2use crate::{
3    TensorData,
4    backend::{Backend, ExecutionError},
5    ops::TransactionPrimitive,
6};
7use alloc::vec::Vec;
8use burn_backend::tensor::TransactionOp;
9
10#[derive(Default)]
11/// A transaction can [read](Self::register) multiple tensors at once with a single operation improving
12/// compute utilization with optimized laziness.
13///
14/// # Example
15///
16/// ```rust,ignore
17///  let [output_data, loss_data, targets_data] = Transaction::default()
18///    .register(output)
19///    .register(loss)
20///    .register(targets)
21///    .execute()
22///    .try_into()
23///    .expect("Correct amount of tensor data");
24/// ```
25pub struct Transaction<B: Backend> {
26    op: TransactionPrimitive<B>,
27}
28
29impl<B: Backend> Transaction<B> {
30    /// Add a [tensor](Tensor) to the transaction to be read.
31    pub fn register<const D: usize, K: TransactionOp<B> + BasicOps<B>>(
32        mut self,
33        tensor: Tensor<B, D, K>,
34    ) -> Self {
35        K::register_transaction(&mut self.op, tensor.into_primitive());
36        self
37    }
38
39    /// Executes the transaction synchronously and returns the [data](TensorData) in the same order
40    /// in which they were [registered](Self::register).
41    pub fn execute(self) -> Vec<TensorData> {
42        burn_std::future::block_on(self.execute_async())
43            .expect("Error while reading data: use `try_execute` to handle error at runtime")
44    }
45
46    /// Executes the transaction synchronously and returns the [data](TensorData) in the same
47    /// order in which they were [registered](Self::register).
48    ///
49    /// # Returns
50    ///
51    /// Any error that might have occurred since the last time the device was synchronized.
52    pub fn try_execute(self) -> Result<Vec<TensorData>, ExecutionError> {
53        burn_std::future::block_on(self.execute_async())
54    }
55
56    /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order
57    /// in which they were [registered](Self::register).
58    pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> {
59        self.op.execute_async().await
60    }
61}