// burn_tensor/tensor/ops/transaction.rs

use alloc::vec::Vec;
use core::future::Future;

use super::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use crate::{TensorData, backend::Backend};

7#[derive(Default)]
8/// Contains all tensor primitives that are going to be read.
9pub struct TransactionPrimitive<B: Backend> {
10    /// Float tensors.
11    pub read_floats: Vec<FloatTensor<B>>,
12    /// Quantized tensors.
13    pub read_qfloats: Vec<QuantizedTensor<B>>,
14    /// Int tensors.
15    pub read_ints: Vec<IntTensor<B>>,
16    /// Bool tensors.
17    pub read_bools: Vec<BoolTensor<B>>,
18}
19
20#[derive(Default)]
21/// Contains all [data](TensorData) related to a [transaction](TransactionPrimitive).
22pub struct TransactionPrimitiveResult {
23    /// Float tensor data.
24    pub read_floats: Vec<TensorData>,
25    /// Quantized tensor data.
26    pub read_qfloats: Vec<TensorData>,
27    /// Int tensor data.
28    pub read_ints: Vec<TensorData>,
29    /// Bool tensor data.
30    pub read_bools: Vec<TensorData>,
31}
32
33/// Operations that are sync by nature and that can be batch together in transactions to improve
34/// compute utilization with efficient laziness.
35pub trait TransactionOps<B: Backend> {
36    /// Executes a [transaction](TransactionPrimitive) and return its
37    /// [result](TransactionPrimitiveResult).
38    fn tr_execute(
39        transaction: TransactionPrimitive<B>,
40    ) -> impl Future<Output = TransactionPrimitiveResult> + Send {
41        async move {
42            let mut floats = Vec::new();
43            let mut qfloats = Vec::new();
44            let mut ints = Vec::new();
45            let mut bools = Vec::new();
46
47            for t in transaction.read_floats {
48                floats.push(B::float_into_data(t).await);
49            }
50            for t in transaction.read_qfloats {
51                qfloats.push(B::q_into_data(t).await);
52            }
53            for t in transaction.read_ints {
54                ints.push(B::int_into_data(t).await);
55            }
56            for t in transaction.read_bools {
57                bools.push(B::bool_into_data(t).await);
58            }
59
60            TransactionPrimitiveResult {
61                read_floats: floats,
62                read_qfloats: qfloats,
63                read_ints: ints,
64                read_bools: bools,
65            }
66        }
67    }
68}