1use mlx_core::{Result, Tensor};
4
5use crate::Module;
6
7pub struct Linear {
12 weight: Tensor,
13 bias: Option<Tensor>,
14}
15
16impl Linear {
17 pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
19 Self { weight, bias }
20 }
21
22 pub fn weight(&self) -> &Tensor {
24 &self.weight
25 }
26
27 pub fn bias(&self) -> Option<&Tensor> {
29 self.bias.as_ref()
30 }
31}
32
33impl Module for Linear {
34 fn forward(&self, input: &Tensor) -> Result<Tensor> {
35 let wt = self.weight.transpose(None)?;
37 let mut y = input.matmul(&wt)?;
38 if let Some(ref bias) = self.bias {
39 y = y.add(&bias.broadcast_to(y.shape())?)?;
41 }
42 Ok(y)
43 }
44}