1use ndarray::prelude::*;
2
3use crate::msgpack::TensorDict;
4#[derive(Debug, Clone)]
5pub struct Linear {
6 weight: Array<f32, Ix2>,
7 bias: Array<f32, Ix2>,
8}
9
10impl<'a> From<&'a TensorDict> for Linear {
11 fn from(state_dict: &TensorDict) -> Self {
12 let dict = state_dict.as_dict();
13 let weight = dict["weight"].as_tensor().to_ndarray_f32();
14 let bias = dict["bias"].as_tensor().to_ndarray_f32();
15 Linear {
16 weight: weight.reversed_axes().into_dimensionality().unwrap(),
17 bias: bias.insert_axis(Axis(0)).into_dimensionality().unwrap(),
18 }
19 }
20}
21
22impl Linear {
23 pub fn forward(&self, x: ArrayView2<f32>) -> Array2<f32> {
24 x.dot(&self.weight) + &self.bias
25 }
26}