burn_tensor/tensor/api/transaction.rs
1use super::{BasicOps, Tensor};
2use crate::{
3 TensorData,
4 backend::{Backend, ExecutionError},
5 ops::TransactionPrimitive,
6};
7use alloc::vec::Vec;
8use burn_backend::tensor::TransactionOp;
9
10#[derive(Default)]
11/// A transaction can [read](Self::register) multiple tensors at once with a single operation improving
12/// compute utilization with optimized laziness.
13///
14/// # Example
15///
16/// ```rust,ignore
17/// let [output_data, loss_data, targets_data] = Transaction::default()
18/// .register(output)
19/// .register(loss)
20/// .register(targets)
21/// .execute()
22/// .try_into()
23/// .expect("Correct amount of tensor data");
24/// ```
25pub struct Transaction<B: Backend> {
26 op: TransactionPrimitive<B>,
27}
28
29impl<B: Backend> Transaction<B> {
30 /// Add a [tensor](Tensor) to the transaction to be read.
31 pub fn register<const D: usize, K: TransactionOp<B> + BasicOps<B>>(
32 mut self,
33 tensor: Tensor<B, D, K>,
34 ) -> Self {
35 K::register_transaction(&mut self.op, tensor.into_primitive());
36 self
37 }
38
39 /// Executes the transaction synchronously and returns the [data](TensorData) in the same order
40 /// in which they were [registered](Self::register).
41 pub fn execute(self) -> Vec<TensorData> {
42 burn_std::future::block_on(self.execute_async())
43 .expect("Error while reading data: use `try_execute` to handle error at runtime")
44 }
45
46 /// Executes the transaction synchronously and returns the [data](TensorData) in the same
47 /// order in which they were [registered](Self::register).
48 ///
49 /// # Returns
50 ///
51 /// Any error that might have occurred since the last time the device was synchronized.
52 pub fn try_execute(self) -> Result<Vec<TensorData>, ExecutionError> {
53 burn_std::future::block_on(self.execute_async())
54 }
55
56 /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order
57 /// in which they were [registered](Self::register).
58 pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> {
59 self.op.execute_async().await
60 }
61}