1use super::{RouterTensor, RunnerChannel, RunnerClient, get_client};
2use alloc::{format, string::String};
3use burn_backend::{
4 Backend, BackendTypes, DType, ExecutionError, QTensorPrimitive, quantization::QuantScheme,
5};
6use core::marker::PhantomData;
7
8pub struct BackendRouter<R: RunnerChannel> {
10 r: PhantomData<R>,
11}
12
13impl<R: RunnerChannel> core::fmt::Debug for BackendRouter<R> {
14 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
15 f.write_fmt(format_args!("router"))
16 }
17}
18
19impl<R: RunnerChannel> Clone for BackendRouter<R> {
20 fn clone(&self) -> Self {
21 Self { r: PhantomData }
22 }
23}
24
25impl<R: RunnerChannel> Default for BackendRouter<R> {
26 fn default() -> Self {
27 Self { r: PhantomData }
28 }
29}
30
31impl<R: RunnerClient> QTensorPrimitive for RouterTensor<R> {
32 fn scheme(&self) -> &QuantScheme {
33 if let DType::QFloat(scheme) = &self.dtype {
34 scheme
35 } else {
36 panic!("Expected quantized float dtype, got {:?}", self.dtype)
38 }
39 }
40}
41
42impl<R: RunnerChannel> BackendTypes for BackendRouter<R> {
43 type Device = R::Device;
44
45 type FloatTensorPrimitive = RouterTensor<R::Client>;
46
47 type FloatElem = R::FloatElem;
48
49 type IntTensorPrimitive = RouterTensor<R::Client>;
50
51 type IntElem = R::IntElem;
52
53 type BoolTensorPrimitive = RouterTensor<R::Client>;
54
55 type BoolElem = R::BoolElem;
56
57 type QuantizedTensorPrimitive = RouterTensor<R::Client>;
58}
59
60impl<R: RunnerChannel> Backend for BackendRouter<R> {
61 fn name(device: &Self::Device) -> String {
62 format!("router<{}>", R::name(device))
63 }
64
65 fn seed(device: &Self::Device, seed: u64) {
66 let client = get_client::<R>(device);
67 client.seed(seed);
68 }
69
70 fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
71 let client = get_client::<R>(device);
72 client.sync()
73 }
74
75 fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
76 let client = get_client::<R>(device);
77 client.dtype_usage(dtype)
78 }
79
80 fn device_count(_: u16) -> usize {
81 1
83 }
84}