Skip to main content

burn_fusion/ops/transaction.rs

1use burn_backend::{
2    backend::ExecutionError,
3    ops::{TransactionOps, TransactionPrimitive},
4};
5
6use crate::{Fusion, FusionBackend};
7
8impl<B: FusionBackend> TransactionOps<Fusion<B>> for Fusion<B> {
9    async fn tr_execute(
10        transaction: TransactionPrimitive<Self>,
11    ) -> Result<burn_backend::ops::TransactionPrimitiveData, ExecutionError> {
12        B::tr_execute(TransactionPrimitive::new(
13            transaction
14                .read_floats
15                .into_iter()
16                .map(|t| t.client.clone().resolve_tensor_float::<B>(t))
17                .collect(),
18            transaction
19                .read_qfloats
20                .into_iter()
21                .map(|_t| todo!("Quantization not supported yet"))
22                .collect(),
23            transaction
24                .read_ints
25                .into_iter()
26                .map(|t| t.client.clone().resolve_tensor_int::<B>(t))
27                .collect(),
28            transaction
29                .read_bools
30                .into_iter()
31                .map(|t| t.client.clone().resolve_tensor_bool::<B>(t))
32                .collect(),
33        ))
34        .await
35    }
36}