burn_backend/backend/ops/transaction.rs

1use alloc::vec::Vec;
2use core::future::Future;
3
4use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
5use crate::{Backend, ExecutionError, TensorData, TensorPrimitive};
6
/// Records where a registered tensor was stored in a [`TransactionPrimitive`], so that results
/// can be handed back in registration order.
///
/// Each variant names the `read_*` vector the tensor was pushed to and holds its index there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Order {
    /// Index into `read_floats`.
    Float(usize),
    /// Index into `read_qfloats`.
    QFloat(usize),
    /// Index into `read_ints`.
    Int(usize),
    /// Index into `read_bools`.
    Bool(usize),
}
13
14#[derive(Default)]
15/// Contains all tensor primitives that are going to be read.
16pub struct TransactionPrimitive<B: Backend> {
17    /// Float tensors.
18    pub read_floats: Vec<FloatTensor<B>>,
19    /// Quantized tensors.
20    pub read_qfloats: Vec<QuantizedTensor<B>>,
21    /// Int tensors.
22    pub read_ints: Vec<IntTensor<B>>,
23    /// Bool tensors.
24    pub read_bools: Vec<BoolTensor<B>>,
25    orders: Vec<Order>,
26}
27
28#[derive(Default)]
29/// Contains all [data](TensorData) related to a [transaction](TransactionPrimitive).
30pub struct TransactionPrimitiveData {
31    /// Float tensor data.
32    pub read_floats: Vec<TensorData>,
33    /// Quantized tensor data.
34    pub read_qfloats: Vec<TensorData>,
35    /// Int tensor data.
36    pub read_ints: Vec<TensorData>,
37    /// Bool tensor data.
38    pub read_bools: Vec<TensorData>,
39}
40
41/// Operations that are sync by nature and that can be batch together in transactions to improve
42/// compute utilization with efficient laziness.
43pub trait TransactionOps<B: Backend> {
44    /// Executes a [transaction](TransactionPrimitive) and return its
45    /// [data](TransactionPrimitiveData).
46    fn tr_execute(
47        transaction: TransactionPrimitive<B>,
48    ) -> impl Future<Output = Result<TransactionPrimitiveData, ExecutionError>> + Send {
49        async move {
50            let mut floats = Vec::new();
51            let mut qfloats = Vec::new();
52            let mut ints = Vec::new();
53            let mut bools = Vec::new();
54
55            for t in transaction.read_floats {
56                floats.push(B::float_into_data(t).await?);
57            }
58            for t in transaction.read_qfloats {
59                qfloats.push(B::q_into_data(t).await?);
60            }
61            for t in transaction.read_ints {
62                ints.push(B::int_into_data(t).await?);
63            }
64            for t in transaction.read_bools {
65                bools.push(B::bool_into_data(t).await?);
66            }
67
68            Ok(TransactionPrimitiveData {
69                read_floats: floats,
70                read_qfloats: qfloats,
71                read_ints: ints,
72                read_bools: bools,
73            })
74        }
75    }
76}
77
78impl<B: Backend> TransactionPrimitive<B> {
79    /// Creates a new transaction.
80    pub fn new(
81        read_floats: Vec<FloatTensor<B>>,
82        read_qfloats: Vec<QuantizedTensor<B>>,
83        read_ints: Vec<IntTensor<B>>,
84        read_bools: Vec<BoolTensor<B>>,
85    ) -> Self {
86        Self {
87            read_floats,
88            read_qfloats,
89            read_ints,
90            read_bools,
91            orders: Vec::default(),
92        }
93    }
94    /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order
95    /// in which they were [registered](crate::tensor::BasicOps::register_transaction).
96    pub async fn execute_async(mut self) -> Result<Vec<TensorData>, ExecutionError> {
97        let mut orders = Vec::new();
98        core::mem::swap(&mut orders, &mut self.orders);
99        let result = B::tr_execute(self).await?;
100
101        let mut floats: Vec<_> = result.read_floats.into_iter().map(Some).collect();
102        let mut qfloats: Vec<_> = result.read_qfloats.into_iter().map(Some).collect();
103        let mut ints: Vec<_> = result.read_ints.into_iter().map(Some).collect();
104        let mut bools: Vec<_> = result.read_bools.into_iter().map(Some).collect();
105
106        Ok(orders
107            .into_iter()
108            .map(|order| match order {
109                Order::Float(index) => floats.get_mut(index).unwrap().take().unwrap(),
110                Order::QFloat(index) => qfloats.get_mut(index).unwrap().take().unwrap(),
111                Order::Int(index) => ints.get_mut(index).unwrap().take().unwrap(),
112                Order::Bool(index) => bools.get_mut(index).unwrap().take().unwrap(),
113            })
114            .collect::<Vec<_>>())
115    }
116
117    pub(crate) fn register_float(&mut self, tensor: TensorPrimitive<B>) {
118        match tensor {
119            TensorPrimitive::Float(tensor) => {
120                self.orders.push(Order::Float(self.read_floats.len()));
121                self.read_floats.push(tensor);
122            }
123            TensorPrimitive::QFloat(tensor) => {
124                self.orders.push(Order::QFloat(self.read_qfloats.len()));
125                self.read_qfloats.push(tensor);
126            }
127        }
128    }
129
130    pub(crate) fn register_int(&mut self, tensor: IntTensor<B>) {
131        self.orders.push(Order::Int(self.read_ints.len()));
132        self.read_ints.push(tensor);
133    }
134
135    pub(crate) fn register_bool(&mut self, tensor: BoolTensor<B>) {
136        self.orders.push(Order::Bool(self.read_bools.len()));
137        self.read_bools.push(tensor);
138    }
139}