Skip to main content

burn_central_runtime/
routine.rs

1use crate::error::RuntimeError;
2use crate::inference::{RoutineIn, RoutineInput};
3use crate::output::RoutineOutput;
4use crate::params::RoutineParam;
5use crate::type_name::fn_type_name;
6use std::marker::PhantomData;
7use variadics_please::all_tuples;
8
9#[diagnostic::on_unimplemented(message = "`{Self}` is not a routine", label = "invalid routine")]
10pub trait Routine<Ctx>: Send + Sync + 'static {
11    type In: RoutineInput;
12    type Out;
13
14    fn name(&self) -> &str;
15    fn run(
16        &self,
17        input: RoutineIn<'_, Ctx, Self>,
18        ctx: &mut Ctx,
19    ) -> anyhow::Result<Self::Out, RuntimeError>;
20}
21
/// Shorthand for the associated `Item<'ctx>` type of routine parameter `P`
/// in context `Ctx`, i.e. the value handed to a routine function for `P`.
pub type RoutineParamItem<'ctx, Ctx, P> = <P as RoutineParam<Ctx>>::Item<'ctx>;
23
24#[diagnostic::on_unimplemented(
25    message = "`{Self}` is not a valid routine",
26    label = "invalid routine"
27)]
28pub trait RoutineParamFunction<Ctx, Marker>: Send + Sync + 'static {
29    type In: RoutineInput;
30    type Out;
31    type Param: RoutineParam<Ctx>;
32
33    fn run(
34        &self,
35        input: <Self::In as RoutineInput>::Inner<'_>,
36        param_value: RoutineParamItem<Ctx, Self::Param>,
37    ) -> anyhow::Result<Self::Out, RuntimeError>;
38}
39
/// Marker placed in the `RoutineParamFunction` marker type
/// (`(HasRoutineInput, fn(In, ..) -> Out)`) of functions whose first argument
/// is a routine input. It keeps that impl from overlapping with the input-less one.
#[doc(hidden)]
pub struct HasRoutineInput;
42
// Generates two `RoutineParamFunction` impls for functions whose trailing
// arguments are the given `RoutineParam` types:
//   1. `fn(P0, .., Pn) -> Out` with no routine input (`type In = ()`), marker
//      `fn(P0, .., Pn) -> Out`;
//   2. `fn(In, P0, .., Pn) -> Out` whose first argument is a `RoutineInput`,
//      marker `(HasRoutineInput, fn(In, P0, .., Pn) -> Out)`.
// The distinct marker types keep the two impls from overlapping.
macro_rules! impl_routine_function {
    ($($param: ident),*) => {
        #[expect(
            clippy::allow_attributes,
            reason = "This is within a macro, and as such, the below lints may not always apply."
        )]
        #[allow(
            non_snake_case,
            reason = "Certain variable names are provided by the caller, not by us."
        )]
        #[allow(clippy::too_many_arguments)]
        impl<Ctx, Out, Func, $($param: RoutineParam<Ctx>),*> RoutineParamFunction<Ctx, fn($($param,)*) -> Out> for Func
        where
            Func: Send + Sync + 'static,
            // `&Func` must be callable both with the declared parameter types
            // and with their retrieved `Item` types. The second bound is the one
            // `run` actually calls; the first presumably lets the `$param`
            // types be inferred from the function's own signature.
            for <'a> &'a Func:
                Fn($($param),*) -> Out +
                Fn($(RoutineParamItem<Ctx, $param>),*) -> Out,
            Out: 'static,
            Ctx: 'static,
        {
            type In = ();
            type Out = Out;
            type Param = ($($param,)*);
            #[inline]
            fn run(&self, _input: (), param_value: RoutineParamItem<Ctx, ($($param,)*)>) -> Result<Self::Out, RuntimeError> {
                // Calling `self` through this generic helper, rather than
                // directly, presumably works around rustc being unable to pick
                // among the multiple `Fn` impls on `&Func` (TODO confirm).
                #[expect(
                    clippy::allow_attributes,
                    reason = "This is within a macro, and as such, the below lints may not always apply."
                )]
                #[allow(clippy::too_many_arguments)]
                fn call_inner<Out, $($param,)*>(
                    f: impl Fn($($param,)*)->Out,
                    $($param: $param,)*
                )->Out{
                    f($($param,)*)
                }
                // Split the retrieved tuple into one binding per parameter,
                // reusing the `$param` idents as variable names.
                let ($($param,)*) = param_value;
                // `self` is `&Func`, which satisfies the `Fn` bounds above.
                Ok(call_inner(self, $($param),*))
            }
        }

        #[expect(
            clippy::allow_attributes,
            reason = "This is within a macro, and as such, the below lints may not always apply."
        )]
        #[allow(
            non_snake_case,
            reason = "Certain variable names are provided by the caller, not by us."
        )]
        #[allow(clippy::too_many_arguments)]
        impl<Ctx, In, Out, Func, $($param: RoutineParam<Ctx>),*> RoutineParamFunction<Ctx, (HasRoutineInput, fn(In, $($param,)*) -> Out)> for Func
        where
            Func: Send + Sync + 'static,
            // Same double bound as above, with the routine input leading: as
            // its declared type `In`, and as the wrapped `In::Param<'_>` form.
            for <'a> &'a Func:
                Fn(In, $($param),*) -> Out +
                Fn(In::Param<'_>, $(RoutineParamItem<Ctx, $param>),*) -> Out,
            In: RoutineInput + 'static,
            Out: 'static,
            Ctx: 'static,
        {
            type In = In;
            type Out = Out;
            type Param = ($($param,)*);
            #[inline]
            fn run(&self, input: In::Inner<'_>, param_value: RoutineParamItem<Ctx, ($($param,)*)>) -> Result<Self::Out, RuntimeError> {
                // The `PhantomData<In>` argument pins `In`, which cannot be
                // inferred from the `In::Param<'_>` / `In::Inner<'_>`
                // projections alone.
                fn call_inner<In: RoutineInput, Out, $($param,)*>(
                    _: PhantomData<In>,
                    f: impl Fn(In::Param<'_>, $($param,)*)->Out,
                    input: In::Inner<'_>,
                    $($param: $param,)*
                )->Out{
                    // Convert the raw `Inner` input into the `Param` form the
                    // function expects.
                    f(In::wrap(input), $($param,)*)
                }
                let ($($param,)*) = param_value;
                Ok(call_inner(PhantomData::<In>, self, input, $($param),*))
            }
        }
    };
}
122
// Instantiate `impl_routine_function!` for every parameter-tuple size in the
// range 0..=16 (identifiers derived from `F`). Presumably inclusive on both
// ends; confirm against `variadics_please::all_tuples`.
all_tuples!(impl_routine_function, 0, 16, F);
124
/// Marker distinguishing the `IntoRoutine` impl for functions (marker
/// `(IsFunctionRoutine, Marker)`) from the identity impl for types that already
/// implement `Routine` (marker `()`).
#[doc(hidden)]
pub struct IsFunctionRoutine;
127
/// A [`Routine`] backed by a function `F` whose parameters are retrieved from
/// the context on every run.
pub struct FunctionRoutine<Marker, F> {
    // The wrapped function.
    func: F,
    // Name returned by `Routine::name`.
    name: String,
    // `fn() -> Marker` does not own a `Marker`, so `Send`/`Sync` of this struct
    // do not depend on `Marker`.
    _marker: PhantomData<fn() -> Marker>,
}
133
134impl<Marker, F> FunctionRoutine<Marker, F> {
135    pub fn with_name(mut self, name: impl Into<String>) -> Self {
136        self.name = name.into();
137        self
138    }
139}
140
141impl<Marker, F: Clone> Clone for FunctionRoutine<Marker, F> {
142    fn clone(&self) -> Self {
143        FunctionRoutine {
144            func: self.func.clone(),
145            name: self.name.clone(),
146            _marker: PhantomData,
147        }
148    }
149}
150
151impl<Ctx, Marker, F> IntoRoutine<Ctx, F::In, F::Out, (IsFunctionRoutine, Marker)> for F
152where
153    Marker: 'static,
154    F: RoutineParamFunction<Ctx, Marker>,
155{
156    type Routine = FunctionRoutine<Marker, F>;
157
158    fn into_routine(func: Self) -> Self::Routine {
159        FunctionRoutine {
160            func,
161            name: fn_type_name::<F>(),
162            _marker: PhantomData,
163        }
164    }
165}
166
167impl<Ctx, Marker, F> Routine<Ctx> for FunctionRoutine<Marker, F>
168where
169    Marker: 'static,
170    F: RoutineParamFunction<Ctx, Marker>,
171{
172    type In = F::In;
173    type Out = F::Out;
174
175    fn name(&self) -> &str {
176        self.name.as_str()
177    }
178
179    fn run(
180        &self,
181        input: RoutineIn<'_, Ctx, Self>,
182        ctx: &mut Ctx,
183    ) -> anyhow::Result<Self::Out, RuntimeError> {
184        let params = <F::Param as RoutineParam<Ctx>>::try_retrieve(ctx).map_err(|e| {
185            RuntimeError::HandlerFailed(anyhow::anyhow!("Failed to retrieve parameters: {}", e))
186        })?;
187        let output = self.func.run(input, params)?;
188        Ok(output)
189    }
190}
191
192impl<Ctx, T: Routine<Ctx>> IntoRoutine<Ctx, T::In, T::Out, ()> for T {
193    type Routine = T;
194    fn into_routine(this: Self) -> Self::Routine {
195        this
196    }
197}
198
199#[diagnostic::on_unimplemented(
200    message = "`{Self}` is not a valid routine with output `{Output}`",
201    label = "invalid routine"
202)]
203pub trait IntoRoutine<Ctx, Input, Output, Marker>: Sized {
204    type Routine: Routine<Ctx, In = Input, Out = Output>;
205
206    #[allow(clippy::wrong_self_convention)]
207    fn into_routine(this: Self) -> Self::Routine;
208
209    fn with_name(self, name: impl Into<String>) -> IntoNamedRoutine<Ctx, Self> {
210        IntoNamedRoutine {
211            routine: self,
212            name: name.into(),
213            marker: Default::default(),
214        }
215    }
216}
217
/// A routine (or anything convertible into one) paired with an explicit name,
/// as produced by [`IntoRoutine::with_name`].
///
/// `Ctx` is only tracked through `PhantomData<fn(Ctx)>`; no `Ctx` value is
/// stored.
pub struct IntoNamedRoutine<Ctx, S> {
    routine: S,
    name: String,
    marker: PhantomData<fn(Ctx)>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive adds a
// `Ctx: Clone` bound, even though no `Ctx` is stored. That would make the
// wrapper non-cloneable for typical, non-`Clone` contexts.
impl<Ctx, S: Clone> Clone for IntoNamedRoutine<Ctx, S> {
    fn clone(&self) -> Self {
        IntoNamedRoutine {
            routine: self.routine.clone(),
            name: self.name.clone(),
            marker: PhantomData,
        }
    }
}
224
/// A routine that delegates execution to `inner` but reports `name` from
/// `Routine::name`.
pub struct NamedRoutine<S> {
    // The wrapped routine.
    inner: S,
    // Name returned instead of the inner routine's own name.
    name: String,
}
229
230impl<Ctx, S> Routine<Ctx> for NamedRoutine<S>
231where
232    S: Routine<Ctx>,
233{
234    type In = S::In;
235    type Out = S::Out;
236
237    fn name(&self) -> &str {
238        &self.name
239    }
240
241    fn run(
242        &self,
243        input: RoutineIn<'_, Ctx, Self>,
244        ctx: &mut Ctx,
245    ) -> anyhow::Result<Self::Out, RuntimeError> {
246        self.inner.run(input, ctx)
247    }
248}
249
/// Marker used by the `IntoRoutine` impls that attach a name, both for
/// `IntoNamedRoutine` and for `(name, routine)` tuples.
#[doc(hidden)]
pub struct IsNamedRoutine;
252impl<Ctx, I, O, M, S> IntoRoutine<Ctx, I, O, (IsNamedRoutine, M)> for IntoNamedRoutine<Ctx, S>
253where
254    S: IntoRoutine<Ctx, I, O, M>,
255{
256    type Routine = NamedRoutine<S::Routine>;
257
258    fn into_routine(this: Self) -> Self::Routine {
259        NamedRoutine {
260            inner: IntoRoutine::into_routine(this.routine),
261            name: this.name,
262        }
263    }
264}
265
266impl<Ctx, I, O, M, S, N> IntoRoutine<Ctx, I, O, (IsNamedRoutine, N, M)> for (N, S)
267where
268    S: IntoRoutine<Ctx, I, O, M>,
269    N: Into<String>,
270{
271    type Routine = NamedRoutine<S::Routine>;
272
273    fn into_routine(this: Self) -> Self::Routine {
274        let (name, routines) = this;
275        NamedRoutine {
276            inner: IntoRoutine::into_routine(routines),
277            name: name.into(),
278        }
279    }
280}
281
/// Adapter around a routine `S`. When run as a `Routine`, it hands the inner
/// routine's output to `RoutineOutput::apply_output` and itself yields `()`.
///
/// `Ctx` is tracked as `PhantomData<fn() -> Ctx>` because the wrapper never
/// stores a `Ctx`. As a result its `Send`/`Sync` and drop-check behaviour do not
/// depend on `Ctx` (the same pattern `FunctionRoutine` uses for its marker),
/// while variance in `Ctx` remains covariant.
pub struct ExecutorRoutineWrapper<S, Ctx>(S, PhantomData<fn() -> Ctx>);
283
284impl<S, Ctx, Input, Output> ExecutorRoutineWrapper<S, Ctx>
285where
286    S: Routine<Ctx, In = Input, Out = Output>,
287{
288    pub fn new(routine: S) -> Self {
289        ExecutorRoutineWrapper(routine, PhantomData)
290    }
291}
292
293impl<Ctx, S, Input, Output> Routine<Ctx> for ExecutorRoutineWrapper<S, Ctx>
294where
295    S: Routine<Ctx, In = Input, Out = Output>,
296    Input: RoutineInput,
297    Output: RoutineOutput<Ctx>,
298    Ctx: Send + Sync + 'static,
299{
300    type In = Input;
301    type Out = ();
302
303    fn name(&self) -> &str {
304        self.0.name()
305    }
306
307    fn run(
308        &self,
309        input: RoutineIn<'_, Ctx, Self>,
310        ctx: &mut Ctx,
311    ) -> anyhow::Result<Self::Out, RuntimeError> {
312        match self.0.run(input, ctx) {
313            Ok(output) => {
314                output.apply_output(ctx).map_err(|e| {
315                    RuntimeError::HandlerFailed(anyhow::anyhow!("Failed to apply output: {}", e))
316                })?;
317                Ok(())
318            }
319            Err(e) => Err(e),
320        }
321    }
322}
323
/// A heap-allocated, type-erased [`Routine`] with fixed input and output types.
pub type BoxedRoutine<Ctx, In, Out> = Box<dyn Routine<Ctx, In = In, Out = Out>>;