1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
use crate::{ops::TchOps, LibTorch, QuantElement, TchElement, TchTensor};
use burn_tensor::{backend::BackendBridge, ops::FloatTensor, Device};
use std::marker::PhantomData;

/// Handle precision conversion for the tch (LibTorch) backend.
///
/// The type parameter `E` is the float element type of the bridge's target
/// backend; the struct itself only carries `E` at the type level.
#[derive(Debug)]
pub struct PrecisionBridge<E: TchElement> {
    // Marker field: ties `E` to the struct without storing a value of it.
    _e: PhantomData<E>,
}

/// Bridge between a `LibTorch<OElem, QElem>` backend and a `LibTorch<TElem>`
/// backend, converting float tensors from one element kind to the other.
impl<TElem, OElem, QElem> BackendBridge<LibTorch<OElem, QElem>> for PrecisionBridge<TElem>
where
    TElem: TchElement,
    OElem: TchElement,
    QElem: QuantElement,
{
    type Target = LibTorch<TElem>;

    /// Converts a float tensor of the origin backend into the target backend.
    ///
    /// The underlying tensor is cast to `TElem::KIND`, re-wrapped together with
    /// a clone of the input's storage handle, and moved to `device` when one is
    /// given. With `device == None` the converted tensor is returned as is.
    fn into_target<const D: usize>(
        tensor: FloatTensor<LibTorch<OElem>, D>,
        device: Option<Device<Self::Target>>,
    ) -> FloatTensor<Self::Target, D> {
        // Grab the storage handle first so it can be passed along with the
        // freshly cast tensor.
        let parent_storage = tensor.storage.clone();
        let converted =
            TchTensor::from_existing(tensor.tensor.to_kind(TElem::KIND), parent_storage);

        match &device {
            Some(destination) => TchOps::to_device(converted, destination),
            None => converted,
        }
    }

    /// Converts a float tensor of the target backend back into the origin
    /// backend.
    ///
    /// The underlying tensor is cast to `OElem::KIND`, re-wrapped together with
    /// a clone of the input's storage handle, and moved to `device` when one is
    /// given. With `device == None` the converted tensor is returned as is.
    fn from_target<const D: usize>(
        tensor: FloatTensor<Self::Target, D>,
        device: Option<Device<LibTorch<OElem>>>,
    ) -> FloatTensor<LibTorch<OElem>, D> {
        // Grab the storage handle first so it can be passed along with the
        // freshly cast tensor.
        let parent_storage = tensor.storage.clone();
        let converted =
            TchTensor::from_existing(tensor.tensor.to_kind(OElem::KIND), parent_storage);

        match &device {
            Some(destination) => TchOps::to_device(converted, destination),
            None => converted,
        }
    }
}