1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
use crate::{
    element::{FloatCandleElement, IntCandleElement},
    Candle, CandleTensor,
};
use burn_tensor::{backend::BackendBridge, ops::FloatTensor, Device};
use std::marker::PhantomData;

/// Precision-conversion bridge for the candle backend.
///
/// This is a marker type: it holds no data, and `PhantomData` only records the
/// type parameter `E`. `E` is the float element type that tensors are converted
/// *to* when crossing the bridge.
#[derive(Debug)]
pub struct PrecisionBridge<E>
where
    E: FloatCandleElement,
{
    _e: PhantomData<E>,
}

/// Bridges a `Candle<OElem, IntElem>` backend to a `Candle<TElem, IntElem>` backend.
///
/// Only the float element type changes (`OElem` <-> `TElem`). The integer element
/// type `IntElem` is carried over unchanged. Both directions are a dtype cast of
/// the wrapped candle tensor.
///
/// The optional `device` argument is not consulted, and no explicit device
/// transfer is performed by this bridge.
impl<TElem, OElem, IntElem> BackendBridge<Candle<OElem, IntElem>> for PrecisionBridge<TElem>
where
    TElem: FloatCandleElement,
    OElem: FloatCandleElement,
    IntElem: IntCandleElement,
{
    type Target = Candle<TElem, IntElem>;

    /// Casts a float tensor from the origin element type `OElem` to the target
    /// element type `TElem` (`TElem::DTYPE`).
    ///
    /// # Panics
    ///
    /// Panics if candle's `to_dtype` returns an error for the requested cast.
    fn into_target<const D: usize>(
        tensor: FloatTensor<Candle<OElem, IntElem>, D>,
        // Intentionally unused: this bridge only changes precision. The leading
        // underscore silences `unused_variables` without altering the trait API.
        _device: Option<Device<Self::Target>>,
    ) -> FloatTensor<Self::Target, D> {
        let converted = tensor
            .tensor
            .to_dtype(TElem::DTYPE)
            .expect("candle failed to cast tensor to the bridge's target float dtype");
        CandleTensor::new(converted)
    }

    /// Casts a float tensor from the target element type `TElem` back to the
    /// origin element type `OElem` (`OElem::DTYPE`).
    ///
    /// # Panics
    ///
    /// Panics if candle's `to_dtype` returns an error for the requested cast.
    fn from_target<const D: usize>(
        tensor: FloatTensor<Self::Target, D>,
        // Intentionally unused: see `into_target`.
        _device: Option<Device<Candle<OElem, IntElem>>>,
    ) -> FloatTensor<Candle<OElem, IntElem>, D> {
        let converted = tensor
            .tensor
            .to_dtype(OElem::DTYPE)
            .expect("candle failed to cast tensor back to the bridge's origin float dtype");
        CandleTensor::new(converted)
    }
}