Skip to main content

burn_router/
backend.rs

1use super::{RouterTensor, RunnerChannel, RunnerClient, get_client};
2use alloc::{format, string::String};
3use burn_backend::{
4    Backend, BackendTypes, DType, ExecutionError, QTensorPrimitive, quantization::QuantScheme,
5};
6use core::marker::PhantomData;
7
8/// A backend that forwards the tensor operations to the appropriate backend (given multiple backends).
9pub struct BackendRouter<R: RunnerChannel> {
10    r: PhantomData<R>,
11}
12
13impl<R: RunnerChannel> core::fmt::Debug for BackendRouter<R> {
14    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
15        f.write_fmt(format_args!("router"))
16    }
17}
18
19impl<R: RunnerChannel> Clone for BackendRouter<R> {
20    fn clone(&self) -> Self {
21        Self { r: PhantomData }
22    }
23}
24
25impl<R: RunnerChannel> Default for BackendRouter<R> {
26    fn default() -> Self {
27        Self { r: PhantomData }
28    }
29}
30
31impl<R: RunnerClient> QTensorPrimitive for RouterTensor<R> {
32    fn scheme(&self) -> &QuantScheme {
33        if let DType::QFloat(scheme) = &self.dtype {
34            scheme
35        } else {
36            // TODO: maybe `tensor.scheme()` should return an option
37            panic!("Expected quantized float dtype, got {:?}", self.dtype)
38        }
39    }
40}
41
42impl<R: RunnerChannel> BackendTypes for BackendRouter<R> {
43    type Device = R::Device;
44
45    type FloatTensorPrimitive = RouterTensor<R::Client>;
46
47    type FloatElem = R::FloatElem;
48
49    type IntTensorPrimitive = RouterTensor<R::Client>;
50
51    type IntElem = R::IntElem;
52
53    type BoolTensorPrimitive = RouterTensor<R::Client>;
54
55    type BoolElem = R::BoolElem;
56
57    type QuantizedTensorPrimitive = RouterTensor<R::Client>;
58}
59
60impl<R: RunnerChannel> Backend for BackendRouter<R> {
61    fn name(device: &Self::Device) -> String {
62        format!("router<{}>", R::name(device))
63    }
64
65    fn seed(device: &Self::Device, seed: u64) {
66        let client = get_client::<R>(device);
67        client.seed(seed);
68    }
69
70    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
71        let client = get_client::<R>(device);
72        client.sync()
73    }
74
75    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
76        let client = get_client::<R>(device);
77        client.dtype_usage(dtype)
78    }
79
80    fn device_count(_: u16) -> usize {
81        // This is what was there before, not sure if it's actually correct
82        1
83    }
84}