burn_tensor/tensor/api/transaction.rs
1use super::{BasicOps, Tensor};
2use crate::{
3 TensorData,
4 backend::{Backend, ExecutionError},
5 ops::TransactionPrimitive,
6};
7use alloc::vec::Vec;
8
9#[derive(Default)]
10/// A transaction can [read](Self::register) multiple tensors at once with a single operation improving
11/// compute utilization with optimized laziness.
12///
13/// # Example
14///
15/// ```rust,ignore
16/// let [output_data, loss_data, targets_data] = Transaction::default()
17/// .register(output)
18/// .register(loss)
19/// .register(targets)
20/// .execute()
21/// .try_into()
22/// .expect("Correct amount of tensor data");
23/// ```
24pub struct Transaction<B: Backend> {
25 op: TransactionPrimitive<B>,
26}
27
28impl<B: Backend> Transaction<B> {
29 /// Add a [tensor](Tensor) to the transaction to be read.
30 pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {
31 K::register_transaction(&mut self.op, tensor.into_primitive());
32 self
33 }
34
35 /// Executes the transaction synchronously and returns the [data](TensorData) in the same order
36 /// in which they were [registered](Self::register).
37 pub fn execute(self) -> Vec<TensorData> {
38 burn_std::future::block_on(self.execute_async())
39 .expect("Error while reading data: use `try_execute` to handle error at runtime")
40 }
41
42 /// Executes the transaction synchronously and returns the [data](TensorData) in the same
43 /// order in which they were [registered](Self::register).
44 ///
45 /// # Returns
46 ///
47 /// Any error that might have occurred since the last time the device was synchronized.
48 pub fn try_execute(self) -> Result<Vec<TensorData>, ExecutionError> {
49 burn_std::future::block_on(self.execute_async())
50 }
51
52 /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order
53 /// in which they were [registered](Self::register).
54 pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> {
55 self.op.execute_async().await
56 }
57}