1use super::{RouterTensor, RunnerChannel, RunnerClient, get_client};
2use alloc::{format, string::String};
3use burn_backend::{Backend, DType, ExecutionError, QTensorPrimitive, quantization::QuantScheme};
4use core::marker::PhantomData;
5
6pub struct BackendRouter<R: RunnerChannel> {
8 r: PhantomData<R>,
9}
10
11impl<R: RunnerChannel> core::fmt::Debug for BackendRouter<R> {
12 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
13 f.write_fmt(format_args!("router"))
14 }
15}
16
17impl<R: RunnerChannel> Clone for BackendRouter<R> {
18 fn clone(&self) -> Self {
19 Self { r: PhantomData }
20 }
21}
22
23impl<R: RunnerChannel> Default for BackendRouter<R> {
24 fn default() -> Self {
25 Self { r: PhantomData }
26 }
27}
28
29impl<R: RunnerClient> QTensorPrimitive for RouterTensor<R> {
30 fn scheme(&self) -> &QuantScheme {
31 if let DType::QFloat(scheme) = &self.dtype {
32 scheme
33 } else {
34 panic!("Expected quantized float dtype, got {:?}", self.dtype)
36 }
37 }
38}
39
40impl<R: RunnerChannel> Backend for BackendRouter<R> {
41 type Device = R::Device;
42
43 type FloatTensorPrimitive = RouterTensor<R::Client>;
44
45 type FloatElem = R::FloatElem;
46
47 type IntTensorPrimitive = RouterTensor<R::Client>;
48
49 type IntElem = R::IntElem;
50
51 type BoolTensorPrimitive = RouterTensor<R::Client>;
52
53 type BoolElem = R::BoolElem;
54
55 type QuantizedTensorPrimitive = RouterTensor<R::Client>;
56
57 fn name(device: &Self::Device) -> String {
58 format!("router<{}>", R::name(device))
59 }
60
61 fn seed(device: &Self::Device, seed: u64) {
62 let client = get_client::<R>(device);
63 client.seed(seed);
64 }
65
66 fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
67 let client = get_client::<R>(device);
68 client.sync()
69 }
70
71 fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {
72 let client = get_client::<R>(device);
73 client.supports_dtype(dtype)
74 }
75}