1use candle_core::*;
2use candle_nn::*;
3
4pub struct Ltc {
5 pub dense_layer: Linear,
6 activation: Activation,
7 tau: Tensor,
8 a: Tensor,
9}
10
11impl Ltc {
12 pub fn new(state_dim: usize, input_dim: usize, vb: VarBuilder) -> Result<Self> {
13 let a = Ltc {
14 activation: Activation::GeluPytorchTanh,
15 dense_layer: linear(state_dim + input_dim, state_dim, vb.clone())?,
16 tau: Tensor::from_slice(&[1.0f32], 1, vb.clone().device())?,
17 a: Tensor::from_slice(&[1.0f32], 1, vb.device())?,
18 };
19 Ok(a)
20 }
21
22 fn fused_state(&self, x: &Tensor, dt: &Tensor, i: &Tensor) -> Result<Tensor> {
23 let one = Tensor::ones(1, x.dtype(), x.device())?;
24 let tau_inv = (one.clone()/self.tau.clone())?;
25 let h = if i.dims()[0] > 0 {
26 Tensor::cat(&[x.clone(), i.clone()], 1)?
27 } else {
28 x.clone()
29 };
30
31 let f = self.dense_layer.forward(&h)?;
32 let f = self.activation.forward(&f)?;
33
34 let denom = (f.clone().broadcast_add(&tau_inv)?).broadcast_mul(&(one + dt.clone())?)?;
35
36 let x = ((x.broadcast_div(&denom))?.broadcast_add(&((f.broadcast_mul(dt)?.broadcast_mul(&self.a)?)/(denom.clone()))?))?;
37 Ok(x)
38 }
39
40 pub fn solve_ode(&self, x: &Tensor, dt: Tensor, i: Tensor) -> Result<Tensor> {
41 const RANGE: u8 = 10;
42
43 let mut dt = dt;
44 let mut x = x.clone();
45 let sub_dt = (dt.clone()/Tensor::from_slice(&[RANGE as f32], &[1], x.device()))?;
46
47 for _ in 0..10 {
48 x = self.fused_state(&x, &dt, &i)?;
49 dt = (dt + sub_dt.clone())?;
50 }
51 Ok(x)
52 }
53}
54
55impl Module for Ltc {
56 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
57 let x = xs;
58 let dt = Tensor::from_slice(&[0.1f32], 1, x.device())?;
59 let i = Tensor::zeros(0, DType::F32, x.device())?;
60 let x = self.solve_ode(x, dt, i)?;
61 Ok(x)
62 }
63}
64
65#[cfg(test)]
66mod tests {
67 use super::*;
68
69 #[test]
70 fn test_ltc() -> Result<()> {
71 let device = Device::cuda_if_available(0)?;
72 let varmap = VarMap::new();
73 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
74
75 let ltc = Ltc::new(2, 1, vb)?;
76 let x = Tensor::from_slice(&[1.0f32, 1.0f32], &[1,2], &device.clone())?;
77 let y = ltc.forward(&x)?;
78 Ok(())
79 }
80}