1use candle::{Result, Tensor};
3use std::sync::Arc;
4
5#[derive(Clone)]
7pub struct Func<'a> {
8 #[allow(clippy::type_complexity)]
9 f: Arc<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync>,
10}
11
12impl std::fmt::Debug for Func<'_> {
13 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
14 write!(f, "func")
15 }
16}
17
18pub fn func<'a, F>(f: F) -> Func<'a>
19where
20 F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
21{
22 Func { f: Arc::new(f) }
23}
24
25impl super::Module for Func<'_> {
26 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
27 (*self.f)(xs)
28 }
29}
30
31impl<'a> Func<'a> {
32 pub fn new<F>(f: F) -> Self
33 where
34 F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
35 {
36 Self { f: Arc::new(f) }
37 }
38}
39
40#[derive(Clone)]
42pub struct FuncT<'a> {
43 #[allow(clippy::type_complexity)]
44 f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
45}
46
47impl std::fmt::Debug for FuncT<'_> {
48 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
49 write!(f, "func")
50 }
51}
52
53pub fn func_t<'a, F>(f: F) -> FuncT<'a>
54where
55 F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
56{
57 FuncT { f: Arc::new(f) }
58}
59
60impl super::ModuleT for FuncT<'_> {
61 fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
62 (*self.f)(xs, train)
63 }
64}
65
66impl<'a> FuncT<'a> {
67 pub fn new<F>(f: F) -> Self
68 where
69 F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
70 {
71 Self { f: Arc::new(f) }
72 }
73}