burn_tensor/tensor/ops/
transaction.rs1use alloc::vec::Vec;
2use core::future::Future;
3
4use super::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
5use crate::{TensorData, backend::Backend};
6
7#[derive(Default)]
8pub struct TransactionPrimitive<B: Backend> {
10 pub read_floats: Vec<FloatTensor<B>>,
12 pub read_qfloats: Vec<QuantizedTensor<B>>,
14 pub read_ints: Vec<IntTensor<B>>,
16 pub read_bools: Vec<BoolTensor<B>>,
18}
19
20#[derive(Default)]
21pub struct TransactionPrimitiveResult {
23 pub read_floats: Vec<TensorData>,
25 pub read_qfloats: Vec<TensorData>,
27 pub read_ints: Vec<TensorData>,
29 pub read_bools: Vec<TensorData>,
31}
32
33pub trait TransactionOps<B: Backend> {
36 fn tr_execute(
39 transaction: TransactionPrimitive<B>,
40 ) -> impl Future<Output = TransactionPrimitiveResult> + Send {
41 async move {
42 let mut floats = Vec::new();
43 let mut qfloats = Vec::new();
44 let mut ints = Vec::new();
45 let mut bools = Vec::new();
46
47 for t in transaction.read_floats {
48 floats.push(B::float_into_data(t).await);
49 }
50 for t in transaction.read_qfloats {
51 qfloats.push(B::q_into_data(t).await);
52 }
53 for t in transaction.read_ints {
54 ints.push(B::int_into_data(t).await);
55 }
56 for t in transaction.read_bools {
57 bools.push(B::bool_into_data(t).await);
58 }
59
60 TransactionPrimitiveResult {
61 read_floats: floats,
62 read_qfloats: qfloats,
63 read_ints: ints,
64 read_bools: bools,
65 }
66 }
67 }
68}