Skip to main content

burn_router/
backend.rs

1use super::{RouterTensor, RunnerChannel, RunnerClient, get_client};
2use alloc::{format, string::String};
3use burn_backend::{Backend, DType, ExecutionError, QTensorPrimitive, quantization::QuantScheme};
4use core::marker::PhantomData;
5
6/// A backend that forwards the tensor operations to the appropriate backend (given multiple backends).
7pub struct BackendRouter<R: RunnerChannel> {
8    r: PhantomData<R>,
9}
10
11impl<R: RunnerChannel> core::fmt::Debug for BackendRouter<R> {
12    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
13        f.write_fmt(format_args!("router"))
14    }
15}
16
17impl<R: RunnerChannel> Clone for BackendRouter<R> {
18    fn clone(&self) -> Self {
19        Self { r: PhantomData }
20    }
21}
22
23impl<R: RunnerChannel> Default for BackendRouter<R> {
24    fn default() -> Self {
25        Self { r: PhantomData }
26    }
27}
28
29impl<R: RunnerClient> QTensorPrimitive for RouterTensor<R> {
30    fn scheme(&self) -> &QuantScheme {
31        if let DType::QFloat(scheme) = &self.dtype {
32            scheme
33        } else {
34            // TODO: maybe `tensor.scheme()` should return an option
35            panic!("Expected quantized float dtype, got {:?}", self.dtype)
36        }
37    }
38}
39
40impl<R: RunnerChannel> Backend for BackendRouter<R> {
41    type Device = R::Device;
42
43    type FloatTensorPrimitive = RouterTensor<R::Client>;
44
45    type FloatElem = R::FloatElem;
46
47    type IntTensorPrimitive = RouterTensor<R::Client>;
48
49    type IntElem = R::IntElem;
50
51    type BoolTensorPrimitive = RouterTensor<R::Client>;
52
53    type BoolElem = R::BoolElem;
54
55    type QuantizedTensorPrimitive = RouterTensor<R::Client>;
56
57    fn name(device: &Self::Device) -> String {
58        format!("router<{}>", R::name(device))
59    }
60
61    fn seed(device: &Self::Device, seed: u64) {
62        let client = get_client::<R>(device);
63        client.seed(seed);
64    }
65
66    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
67        let client = get_client::<R>(device);
68        client.sync()
69    }
70
71    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
72        let client = get_client::<R>(device);
73        client.supports_dtype(dtype)
74    }
75}