// burn_tensor/tensor/api/transaction.rs

1use super::{BasicOps, Tensor};
2use crate::{
3    TensorData,
4    backend::{Backend, ExecutionError},
5    ops::TransactionPrimitive,
6};
7use alloc::vec::Vec;
8
9#[derive(Default)]
10/// A transaction can [read](Self::register) multiple tensors at once with a single operation improving
11/// compute utilization with optimized laziness.
12///
13/// # Example
14///
15/// ```rust,ignore
16///  let [output_data, loss_data, targets_data] = Transaction::default()
17///    .register(output)
18///    .register(loss)
19///    .register(targets)
20///    .execute()
21///    .try_into()
22///    .expect("Correct amount of tensor data");
23/// ```
24pub struct Transaction<B: Backend> {
25    op: TransactionPrimitive<B>,
26}
27
28impl<B: Backend> Transaction<B> {
29    /// Add a [tensor](Tensor) to the transaction to be read.
30    pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {
31        K::register_transaction(&mut self.op, tensor.into_primitive());
32        self
33    }
34
35    /// Executes the transaction synchronously and returns the [data](TensorData) in the same order
36    /// in which they were [registered](Self::register).
37    pub fn execute(self) -> Vec<TensorData> {
38        burn_std::future::block_on(self.execute_async())
39            .expect("Error while reading data: use `try_execute` to handle error at runtime")
40    }
41
42    /// Executes the transaction synchronously and returns the [data](TensorData) in the same
43    /// order in which they were [registered](Self::register).
44    ///
45    /// # Returns
46    ///
47    /// Any error that might have occurred since the last time the device was synchronized.
48    pub fn try_execute(self) -> Result<Vec<TensorData>, ExecutionError> {
49        burn_std::future::block_on(self.execute_async())
50    }
51
52    /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order
53    /// in which they were [registered](Self::register).
54    pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> {
55        self.op.execute_async().await
56    }
57}