rai_core/
dispatch.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
use crate::{
    primitives, CandleBackend, Cpu, Device, Eval, Primitive, Tensor, BF16, F16, F32, F64, I64, U32,
    U8,
};
use once_cell::sync::Lazy;
use std::{any::TypeId, collections::HashMap, sync::RwLock};

/// Adapter that hides a backend implementing `Eval<D, P>` behind the
/// type-erased `Eval<dyn Device, dyn Primitive>` / `Eval<Box<dyn Device>, Box<dyn Primitive>>`
/// interfaces used by the dispatch table.
///
/// The `fn(D, P)` marker records the device and primitive types without the
/// wrapper holding any `D` or `P` value.
#[derive(Debug)]
#[derive(Clone)]
pub struct BackendWrapper<D, P, B> {
    /// Concrete backend that performs the evaluation.
    backend: B,
    /// Type-level record of the device (`D`) and primitive (`P`) this backend serves.
    phantom: std::marker::PhantomData<fn(D, P)>,
}

impl<D, P, B> Eval<dyn Device, dyn Primitive> for BackendWrapper<D, P, B>
where
    D: Device + 'static + Clone,
    P: Primitive + 'static + Clone,
    B: Eval<D, P> + 'static + Clone,
{
    /// Downcasts `device` to `D` and `primitive` to `P`, then forwards the call
    /// to the wrapped backend.
    ///
    /// # Panics
    /// Panics, naming the expected type, if either argument is not of the
    /// concrete type this wrapper was built for. The dispatch table keys each
    /// wrapper by the `TypeId`s of its `D` and `P`, so a mismatch means the
    /// wrapper was invoked with arguments other than those it was looked up for.
    #[inline]
    fn eval(
        &self,
        device: &dyn Device,
        primitive: &dyn Primitive,
        inputs: &[Tensor],
        output: &Tensor,
    ) {
        let device = device.as_any().downcast_ref::<D>().unwrap_or_else(|| {
            panic!(
                "BackendWrapper: device is not a `{}`",
                std::any::type_name::<D>()
            )
        });
        let primitive = primitive.as_any().downcast_ref::<P>().unwrap_or_else(|| {
            panic!(
                "BackendWrapper: primitive is not a `{}`",
                std::any::type_name::<P>()
            )
        });
        self.backend.eval(device, primitive, inputs, output);
    }
}

impl<D, P, B> Eval<Box<dyn Device>, Box<dyn Primitive>> for BackendWrapper<D, P, B>
where
    D: Device + 'static + Clone,
    P: Primitive + 'static + Clone,
    B: Eval<D, P> + 'static + Clone,
{
    /// Boxed-argument entry point: unwraps the boxes and defers to the
    /// `Eval<dyn Device, dyn Primitive>` implementation, so the downcasting
    /// (and its panic behavior on a type mismatch) lives in one place.
    #[inline]
    fn eval(
        &self,
        device: &Box<dyn Device>,
        primitive: &Box<dyn Primitive>,
        inputs: &[Tensor],
        output: &Tensor,
    ) {
        // `&**` reaches the trait object stored inside each box. Calling
        // `as_any()` on `&Box<dyn _>` directly would resolve to a
        // `Box<dyn _>` trait impl if one exists. The downcast would then depend
        // on how that impl forwards `as_any`.
        <Self as Eval<dyn Device, dyn Primitive>>::eval(
            self,
            &**device,
            &**primitive,
            inputs,
            output,
        );
    }
}

/// Type-erased evaluator stored in the dispatch table: both the device and the
/// primitive are seen only as trait objects.
type DynBackend = Box<dyn Eval<dyn Device, dyn Primitive> + 'static>;

/// Registers every primitive implementation that `$backend` provides for
/// `$device` into the rule table `$rules`.
///
/// `$backend` must name a unit struct: the identifier is used both as the
/// backend type parameter and as the value handed to `_register`.
macro_rules! register_backend {
    // Internal arm: register each listed primitive type, in the given order.
    (@each $backend:ident, $device:ty, $rules:expr; $($prim:ty),+ $(,)?) => {
        $(
            _register::<$backend, $device, $prim>($backend, &mut $rules);
        )+
    };
    ($backend:ident, $device:ty, $rules:expr) => {
        register_backend!(@each $backend, $device, $rules;
            // creation
            primitives::Full<U8>,
            primitives::Full<U32>,
            primitives::Full<BF16>,
            primitives::Full<F16>,
            primitives::Full<F32>,
            primitives::Full<F64>,
            primitives::Full<I64>,
            primitives::Random<BF16>,
            primitives::Random<F16>,
            primitives::Random<F32>,
            primitives::Random<F64>,
            primitives::Normal<F16>,
            primitives::Normal<F32>,
            primitives::Normal<F64>,
            primitives::Arange<U8>,
            primitives::Arange<U32>,
            primitives::Arange<BF16>,
            primitives::Arange<F16>,
            primitives::Arange<F32>,
            primitives::Arange<F64>,
            primitives::Arange<I64>,
            primitives::FromArray<U8>,
            primitives::FromArray<U32>,
            primitives::FromArray<BF16>,
            primitives::FromArray<F16>,
            primitives::FromArray<F32>,
            primitives::FromArray<F64>,
            primitives::FromArray<I64>,
            primitives::Concatenate,
            // binary
            primitives::Add,
            primitives::Sub,
            primitives::Mul,
            primitives::Div,
            primitives::MatMul,
            primitives::Equal,
            primitives::NotEqual,
            primitives::Greater,
            primitives::GreaterEqual,
            primitives::Less,
            primitives::LessEqual,
            primitives::Maximum,
            primitives::Minimum,
            // unary
            primitives::Sin,
            primitives::Cos,
            primitives::Tanh,
            primitives::Negative,
            primitives::Square,
            primitives::PowerFloat,
            primitives::Sqrt,
            primitives::Rsqrt,
            primitives::Sign,
            primitives::Abs,
            primitives::Exp,
            primitives::Log,
            primitives::Log2,
            primitives::Log10,
            primitives::ToDType<U8>,
            primitives::ToDType<U32>,
            primitives::ToDType<BF16>,
            primitives::ToDType<F16>,
            primitives::ToDType<F32>,
            primitives::ToDType<F64>,
            primitives::ToDType<I64>,
            primitives::Softmax,
            primitives::LogSoftmax,
            primitives::Erf,
            primitives::ToDevice<Cpu>,
        );

        // Device-transfer targets that only exist with the matching feature.
        #[cfg(feature = "cuda")]
        _register::<$backend, $device, primitives::ToDevice<crate::Cuda>>($backend, &mut $rules);
        #[cfg(feature = "metal")]
        _register::<$backend, $device, primitives::ToDevice<crate::Metal>>($backend, &mut $rules);

        register_backend!(@each $backend, $device, $rules;
            // indexing
            primitives::Gather,
            primitives::IndexSelect,
            primitives::Narrow,
            primitives::Where,
            primitives::ScatterAdd,
            primitives::IndexAdd,
            // transform
            primitives::Transpose,
            primitives::Reshape,
            primitives::Permute,
            primitives::Broadcast,
            primitives::ToContiguous,
            // reduce
            primitives::ReduceSum,
            primitives::ReduceMax,
            primitives::ReduceMin,
            primitives::ArgMax,
            primitives::ArgMin,
            // others
            primitives::Conv1d,
            primitives::Conv2d,
            primitives::ConvTranspose1d,
            primitives::ConvTranspose2d,
            primitives::MaxPool1d,
            primitives::MaxPool2d,
            primitives::AvgPool1d,
            primitives::AvgPool2d,
            primitives::UpsampleNearest1d,
            primitives::UpsampleNearest2d,
        );
    };
}

/// Global dispatch table mapping `(device TypeId, primitive TypeId)` to a
/// type-erased backend.
///
/// Built lazily on first access with the backends enabled by Cargo features;
/// [`register`] can add or replace entries afterwards.
static EVAL_DISPATCHER: Lazy<RwLock<HashMap<(TypeId, TypeId), DynBackend>>> = Lazy::new(|| {
    // `rules` is only mutated when at least one backend feature is enabled.
    #[cfg_attr(not(feature = "candle-backend"), allow(unused_mut))]
    let mut rules: HashMap<(TypeId, TypeId), DynBackend> = HashMap::new();

    #[cfg(feature = "candle-backend")]
    register_backend!(CandleBackend, Cpu, rules);

    #[cfg(all(feature = "candle-backend", feature = "cuda"))]
    register_backend!(CandleBackend, crate::Cuda, rules);

    #[cfg(all(feature = "candle-backend", feature = "metal"))]
    register_backend!(CandleBackend, crate::Metal, rules);

    // `_register` takes the table by `&mut`; passing `rules` by value here would
    // not type-check and would also move the map out before `RwLock::new`.
    #[cfg(all(
        feature = "candle-backend",
        feature = "cuda",
        feature = "candle-flash-attn"
    ))]
    _register::<CandleBackend, crate::Cuda, primitives::FlashAttention>(CandleBackend, &mut rules);

    RwLock::new(rules)
});

/// Installs `backend` as the evaluator for primitive `P` on device `D` in the
/// global dispatch table, replacing any rule previously stored for that pair.
///
/// # Panics
/// Panics if the dispatch table's lock is poisoned.
pub fn register<D, P, B>(backend: B)
where
    D: Device + 'static + Clone,
    P: Primitive + 'static + Clone,
    B: Eval<D, P> + 'static + Clone,
{
    // The write guard is a temporary of this statement, so the lock is held
    // only for the duration of the insertion.
    _register::<B, D, P>(backend, &mut EVAL_DISPATCHER.write().unwrap());
}

fn _register<B, D, P>(backend: B, dispatcher: &mut HashMap<(TypeId, TypeId), DynBackend>)
where
    D: Device + 'static + Clone,
    P: Primitive + 'static + Clone,
    B: Eval<D, P> + 'static + Clone,
{
    dispatcher.insert(
        (TypeId::of::<D>(), TypeId::of::<P>()),
        Box::new(BackendWrapper {
            backend,
            phantom: std::marker::PhantomData::<fn(D, P)>,
        }),
    );
}

/// Looks up the evaluator registered for the concrete types behind `device`
/// and `primitive`.
///
/// Returns `None` when no rule exists for that `(device, primitive)` pair.
/// The returned backend is a clone of the stored one, so no lock is held once
/// this function returns.
///
/// # Panics
/// Panics if the dispatch table's lock is poisoned.
#[inline]
pub fn eval_rule(device: &dyn Device, primitive: &dyn Primitive) -> Option<DynBackend> {
    // Build the key before taking the lock so the critical section covers only
    // the map lookup and the clone.
    let key = (device.as_any().type_id(), primitive.as_any().type_id());
    let dispatcher = EVAL_DISPATCHER.read().unwrap();
    dispatcher.get(&key).cloned()
}