candle_nn/
func.rs

1//! Layers defined by closures.
2use candle::{Result, Tensor};
3use std::sync::Arc;
4
5/// A layer defined by a simple closure.
6#[derive(Clone)]
7pub struct Func<'a> {
8    #[allow(clippy::type_complexity)]
9    f: Arc<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync>,
10}
11
12impl std::fmt::Debug for Func<'_> {
13    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
14        write!(f, "func")
15    }
16}
17
18pub fn func<'a, F>(f: F) -> Func<'a>
19where
20    F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
21{
22    Func { f: Arc::new(f) }
23}
24
25impl super::Module for Func<'_> {
26    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
27        (*self.f)(xs)
28    }
29}
30
31impl<'a> Func<'a> {
32    pub fn new<F>(f: F) -> Self
33    where
34        F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
35    {
36        Self { f: Arc::new(f) }
37    }
38}
39
40/// A layer defined by a simple closure.
41#[derive(Clone)]
42pub struct FuncT<'a> {
43    #[allow(clippy::type_complexity)]
44    f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
45}
46
47impl std::fmt::Debug for FuncT<'_> {
48    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
49        write!(f, "func")
50    }
51}
52
53pub fn func_t<'a, F>(f: F) -> FuncT<'a>
54where
55    F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
56{
57    FuncT { f: Arc::new(f) }
58}
59
60impl super::ModuleT for FuncT<'_> {
61    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
62        (*self.f)(xs, train)
63    }
64}
65
66impl<'a> FuncT<'a> {
67    pub fn new<F>(f: F) -> Self
68    where
69        F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
70    {
71        Self { f: Arc::new(f) }
72    }
73}