// mlx_nn/linear.rs

//! Linear (fully-connected) layer.

use mlx_core::{Result, Tensor};

use crate::Module;
6
/// A linear (fully-connected) layer: `y = x @ W^T + b`.
///
/// Weight has shape `[out_features, in_features]`. Bias (optional) has shape
/// `[out_features]`.
pub struct Linear {
    // Weight matrix, shape [out_features, in_features].
    weight: Tensor,
    // Optional bias vector, shape [out_features]; None means no bias is added.
    bias: Option<Tensor>,
}
15
16impl Linear {
17    /// Create a new Linear layer from pre-existing weight and bias tensors.
18    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
19        Self { weight, bias }
20    }
21
22    /// Get a reference to the weight tensor.
23    pub fn weight(&self) -> &Tensor {
24        &self.weight
25    }
26
27    /// Get a reference to the bias tensor (if any).
28    pub fn bias(&self) -> Option<&Tensor> {
29        self.bias.as_ref()
30    }
31}
32
33impl Module for Linear {
34    fn forward(&self, input: &Tensor) -> Result<Tensor> {
35        // y = input @ weight^T
36        let wt = self.weight.transpose(None)?;
37        let mut y = input.matmul(&wt)?;
38        if let Some(ref bias) = self.bias {
39            // Broadcast bias [out_features] to match output [batch, out_features]
40            y = y.add(&bias.broadcast_to(y.shape())?)?;
41        }
42        Ok(y)
43    }
44}