burn_backend/backend/ops/
transaction.rs1use alloc::vec::Vec;
2use core::future::Future;
3
4use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
5use crate::{Backend, ExecutionError, TensorData, TensorPrimitive};
6
/// Location of one tensor inside a [`TransactionPrimitive`]: which `read_*` vector it
/// was stored in, and its index within that vector.
///
/// The sequence of `Order`s kept by a transaction determines the order in which
/// results are handed back to the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Order {
    /// Index into `read_floats`.
    Float(usize),
    /// Index into `read_qfloats`.
    QFloat(usize),
    /// Index into `read_ints`.
    Int(usize),
    /// Index into `read_bools`.
    Bool(usize),
}
13
14#[derive(Default)]
15pub struct TransactionPrimitive<B: Backend> {
17 pub read_floats: Vec<FloatTensor<B>>,
19 pub read_qfloats: Vec<QuantizedTensor<B>>,
21 pub read_ints: Vec<IntTensor<B>>,
23 pub read_bools: Vec<BoolTensor<B>>,
25 orders: Vec<Order>,
26}
27
28#[derive(Default)]
29pub struct TransactionPrimitiveData {
31 pub read_floats: Vec<TensorData>,
33 pub read_qfloats: Vec<TensorData>,
35 pub read_ints: Vec<TensorData>,
37 pub read_bools: Vec<TensorData>,
39}
40
41pub trait TransactionOps<B: Backend> {
44 fn tr_execute(
47 transaction: TransactionPrimitive<B>,
48 ) -> impl Future<Output = Result<TransactionPrimitiveData, ExecutionError>> + Send {
49 async move {
50 let mut floats = Vec::new();
51 let mut qfloats = Vec::new();
52 let mut ints = Vec::new();
53 let mut bools = Vec::new();
54
55 for t in transaction.read_floats {
56 floats.push(B::float_into_data(t).await?);
57 }
58 for t in transaction.read_qfloats {
59 qfloats.push(B::q_into_data(t).await?);
60 }
61 for t in transaction.read_ints {
62 ints.push(B::int_into_data(t).await?);
63 }
64 for t in transaction.read_bools {
65 bools.push(B::bool_into_data(t).await?);
66 }
67
68 Ok(TransactionPrimitiveData {
69 read_floats: floats,
70 read_qfloats: qfloats,
71 read_ints: ints,
72 read_bools: bools,
73 })
74 }
75 }
76}
77
78impl<B: Backend> TransactionPrimitive<B> {
79 pub fn new(
81 read_floats: Vec<FloatTensor<B>>,
82 read_qfloats: Vec<QuantizedTensor<B>>,
83 read_ints: Vec<IntTensor<B>>,
84 read_bools: Vec<BoolTensor<B>>,
85 ) -> Self {
86 Self {
87 read_floats,
88 read_qfloats,
89 read_ints,
90 read_bools,
91 orders: Vec::default(),
92 }
93 }
94 pub async fn execute_async(mut self) -> Result<Vec<TensorData>, ExecutionError> {
97 let mut orders = Vec::new();
98 core::mem::swap(&mut orders, &mut self.orders);
99 let result = B::tr_execute(self).await?;
100
101 let mut floats: Vec<_> = result.read_floats.into_iter().map(Some).collect();
102 let mut qfloats: Vec<_> = result.read_qfloats.into_iter().map(Some).collect();
103 let mut ints: Vec<_> = result.read_ints.into_iter().map(Some).collect();
104 let mut bools: Vec<_> = result.read_bools.into_iter().map(Some).collect();
105
106 Ok(orders
107 .into_iter()
108 .map(|order| match order {
109 Order::Float(index) => floats.get_mut(index).unwrap().take().unwrap(),
110 Order::QFloat(index) => qfloats.get_mut(index).unwrap().take().unwrap(),
111 Order::Int(index) => ints.get_mut(index).unwrap().take().unwrap(),
112 Order::Bool(index) => bools.get_mut(index).unwrap().take().unwrap(),
113 })
114 .collect::<Vec<_>>())
115 }
116
117 pub(crate) fn register_float(&mut self, tensor: TensorPrimitive<B>) {
118 match tensor {
119 TensorPrimitive::Float(tensor) => {
120 self.orders.push(Order::Float(self.read_floats.len()));
121 self.read_floats.push(tensor);
122 }
123 TensorPrimitive::QFloat(tensor) => {
124 self.orders.push(Order::QFloat(self.read_qfloats.len()));
125 self.read_qfloats.push(tensor);
126 }
127 }
128 }
129
130 pub(crate) fn register_int(&mut self, tensor: IntTensor<B>) {
131 self.orders.push(Order::Int(self.read_ints.len()));
132 self.read_ints.push(tensor);
133 }
134
135 pub(crate) fn register_bool(&mut self, tensor: BoolTensor<B>) {
136 self.orders.push(Order::Bool(self.read_bools.len()));
137 self.read_bools.push(tensor);
138 }
139}