rogue_net/
linear.rs

1use ndarray::prelude::*;
2
3use crate::msgpack::TensorDict;
4#[derive(Debug, Clone)]
5pub struct Linear {
6    weight: Array<f32, Ix2>,
7    bias: Array<f32, Ix2>,
8}
9
10impl<'a> From<&'a TensorDict> for Linear {
11    fn from(state_dict: &TensorDict) -> Self {
12        let dict = state_dict.as_dict();
13        let weight = dict["weight"].as_tensor().to_ndarray_f32();
14        let bias = dict["bias"].as_tensor().to_ndarray_f32();
15        Linear {
16            weight: weight.reversed_axes().into_dimensionality().unwrap(),
17            bias: bias.insert_axis(Axis(0)).into_dimensionality().unwrap(),
18        }
19    }
20}
21
22impl Linear {
23    pub fn forward(&self, x: ArrayView2<f32>) -> Array2<f32> {
24        x.dot(&self.weight) + &self.bias
25    }
26}