use super::{BasicOps, Tensor};
use crate::{
TensorData,
backend::{Backend, ExecutionError},
ops::TransactionPrimitive,
};
use alloc::vec::Vec;
/// A set of tensors registered for reading, executed to obtain their [`TensorData`].
///
/// Tensors are added with [`Transaction::register`]; the data is then retrieved with
/// [`Transaction::execute`], [`Transaction::try_execute`] or
/// [`Transaction::execute_async`].
#[derive(Default)]
pub struct Transaction<B: Backend> {
/// Underlying primitive that tensors are registered into and that performs the execution.
op: TransactionPrimitive<B>,
}
impl<B: Backend> Transaction<B> {
    /// Adds `tensor` to this transaction and returns the transaction for chaining.
    ///
    /// The tensor is consumed: it is converted into its primitive, which is then
    /// handed to `K::register_transaction` together with the underlying operation.
    pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {
        let primitive = tensor.into_primitive();
        K::register_transaction(&mut self.op, primitive);
        self
    }

    /// Runs the transaction, blocking the current thread until the data is available.
    ///
    /// # Panics
    ///
    /// Panics if the execution reports an [`ExecutionError`]. Call
    /// [`Transaction::try_execute`] instead to handle the error at runtime.
    pub fn execute(self) -> Vec<TensorData> {
        // Same blocking path as `try_execute`; only the error handling differs.
        self.try_execute()
            .expect("Error while reading data: use `try_execute` to handle error at runtime")
    }

    /// Runs the transaction, blocking the current thread until it completes.
    ///
    /// # Errors
    ///
    /// Returns the [`ExecutionError`] produced by [`Transaction::execute_async`].
    pub fn try_execute(self) -> Result<Vec<TensorData>, ExecutionError> {
        let pending = self.execute_async();
        burn_std::future::block_on(pending)
    }

    /// Runs the transaction asynchronously by awaiting the underlying operation.
    ///
    /// # Errors
    ///
    /// Returns the [`ExecutionError`] reported by the underlying operation.
    pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> {
        self.op.execute_async().await
    }
}