candle_ltc/
lib.rs

1use candle_core::*;
2use candle_nn::*;
3
4pub struct Ltc {
5    pub dense_layer: Linear,
6    activation: Activation,
7    tau: Tensor,
8    a: Tensor,
9}
10
11impl Ltc {
12    pub fn new(state_dim: usize, input_dim: usize, vb: VarBuilder) -> Result<Self> {
13        let a = Ltc {
14            activation: Activation::GeluPytorchTanh,
15            dense_layer: linear(state_dim + input_dim, state_dim, vb.clone())?,
16            tau: Tensor::from_slice(&[1.0f32], 1, vb.clone().device())?,
17            a: Tensor::from_slice(&[1.0f32], 1, vb.device())?,
18        };
19        Ok(a)
20    }
21
22    fn fused_state(&self, x: &Tensor, dt: &Tensor, i: &Tensor) -> Result<Tensor> {
23        let one = Tensor::ones(1, x.dtype(), x.device())?;
24        let tau_inv = (one.clone()/self.tau.clone())?;
25        let h = if i.dims()[0] > 0 { 
26            Tensor::cat(&[x.clone(), i.clone()], 1)?
27        } else {
28            x.clone()
29        };
30
31        let f = self.dense_layer.forward(&h)?;
32        let f = self.activation.forward(&f)?;
33
34        let denom = (f.clone().broadcast_add(&tau_inv)?).broadcast_mul(&(one + dt.clone())?)?;
35
36        let x = ((x.broadcast_div(&denom))?.broadcast_add(&((f.broadcast_mul(dt)?.broadcast_mul(&self.a)?)/(denom.clone()))?))?;
37        Ok(x)
38    }
39    
40    pub fn solve_ode(&self, x: &Tensor, dt: Tensor, i: Tensor) -> Result<Tensor> {
41        const RANGE: u8 = 10;
42
43        let mut dt = dt;
44        let mut x = x.clone();
45        let sub_dt = (dt.clone()/Tensor::from_slice(&[RANGE as f32], &[1], x.device()))?;
46
47        for _ in 0..10 {
48            x = self.fused_state(&x, &dt, &i)?;
49            dt =  (dt + sub_dt.clone())?;
50        }
51        Ok(x)
52    }
53}
54
55impl Module for Ltc {
56    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
57        let x = xs;
58        let dt = Tensor::from_slice(&[0.1f32], 1, x.device())?;
59        let i = Tensor::zeros(0, DType::F32, x.device())?;
60        let x = self.solve_ode(x, dt, i)?;
61        Ok(x)
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cell with 2 state units and 1 external input on `device`.
    fn build_ltc(device: &Device) -> Result<Ltc> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
        Ltc::new(2, 1, vb)
    }

    #[test]
    fn test_ltc() -> Result<()> {
        let device = Device::cuda_if_available(0)?;
        let ltc = build_ltc(&device)?;
        let x = Tensor::from_slice(&[1.0f32, 1.0f32], &[1, 2], &device)?;
        let y = ltc.forward(&x)?;
        assert_eq!(y.dims(), &[1, 2]);
        Ok(())
    }

    #[test]
    fn solve_ode_with_explicit_input_keeps_state_shape() -> Result<()> {
        let device = Device::cuda_if_available(0)?;
        let ltc = build_ltc(&device)?;
        let x = Tensor::from_slice(&[1.0f32, -1.0], (1, 2), &device)?;
        let dt = Tensor::from_slice(&[0.1f32], 1, &device)?;
        let i = Tensor::from_slice(&[0.5f32], (1, 1), &device)?;
        let y = ltc.solve_ode(&x, dt, i)?;
        assert_eq!(y.dims(), &[1, 2]);
        let values = y.flatten_all()?.to_vec1::<f32>()?;
        assert!(values.iter().all(|v| v.is_finite()));
        Ok(())
    }
}
80}