ferrotorch-nn 0.6.1

Neural network modules for ferrotorch — layers, losses, initialization
Documentation
//! `Parameter<T>` — a trainable tensor wrapper, the Rust analog of
//! `torch.nn.Parameter`.
//!
//! Thin newtype around `Tensor<T>` that enforces `requires_grad = true` on
//! construction, derefs to `Tensor<T>` for all tensor operations, and
//! supports the optimizer-facing API (`set_data`, `set_requires_grad`,
//! `to(device)`).
//!
//! ## REQ status (per `.design/ferrotorch-nn/parameter.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | `pub struct Parameter<T: Float> { data: Tensor<T> }` with `#[derive(Debug, Clone)]` mirrors `torch/nn/parameter.py:30-70` via R-DEV-7 (newtype replacing Tensor subclass); consumed by `ferrotorch-nn/src/linear.rs:46` `pub weight: Parameter<T>` and every other weight-bearing layer field. |
//! | REQ-2 | SHIPPED | `Parameter::new(tensor)` enforces `requires_grad = true` mirroring `torch/nn/parameter.py:51-70`; consumed by `ferrotorch-nn/src/linear.rs:83` `Parameter::zeros(...)` (which calls `Parameter::new`) and every layer constructor. |
//! | REQ-3 | SHIPPED | `Parameter::zeros`, `::ones`, `::from_slice` factories; consumed by `ferrotorch-nn/src/linear.rs:83, 88` (weight + bias initialization) and `conv.rs` / `embedding.rs` / `norm.rs`. |
//! | REQ-4 | SHIPPED | `tensor(&self) -> &Tensor<T>` and `into_tensor(self)` accessors; consumed by `ferrotorch-nn/src/module.rs:74` `param.tensor().clone()` inside the default `state_dict`. |
//! | REQ-5 | SHIPPED | `set_data` re-enforces `requires_grad = true`; consumed by every optimizer step (`ferrotorch-optim/src/adam.rs`, …) writing updated weights. |
//! | REQ-6 | SHIPPED | `set_requires_grad(bool)` mirrors `torch/nn/parameter.py`'s `requires_grad_` (Tensor-inherited); consumed by `ferrotorch-nn/src/module.rs` `Module::requires_grad_` default impl. |
//! | REQ-7 | SHIPPED | `to(device) -> FerrotorchResult<Self>`; consumed by `ferrotorch-nn/src/module.rs` `Module::to_device` default impl calling `param.to(device)?` for each parameter. |
//! | REQ-8 | SHIPPED | `impl Deref<Target = Tensor<T>>` (R-DEV-7 — Rust analog of Python class-subclass inheritance); consumed by every callsite invoking a Tensor method on a Parameter (`param.shape()`, `param.device()`, `param.zero_grad()` in `Module::zero_grad`). |
//! | REQ-9 | SHIPPED | `#[derive(Debug, Clone)]` with shallow Arc-backed clone semantics; consumed by `Module::state_dict` calling `param.tensor().clone()` for serialization. |

use ferrotorch_core::{Device, FerrotorchResult, Float, Tensor};

/// A trainable tensor: the unit that `Module` implementations store and hand
/// to optimizers.
///
/// Every constructor routes through [`Parameter::new`], so a freshly built
/// parameter has `requires_grad = true`. The flag can be switched afterwards
/// with [`Parameter::set_requires_grad`] to freeze or unfreeze the parameter.
///
/// This is a thin newtype over `Tensor<T>`: it derefs to the wrapped tensor
/// for all tensor operations, and `clone()` clones the wrapped tensor, which
/// shares the same underlying identity (Arc-based, like `Tensor`).
#[derive(Clone, Debug)]
pub struct Parameter<T>
where
    T: Float,
{
    /// The wrapped tensor holding the parameter's values and autograd flag.
    data: Tensor<T>,
}

impl<T: Float> Parameter<T> {
    /// Wrap `tensor` as a trainable parameter.
    ///
    /// The stored tensor is `tensor.requires_grad_(true)`, so the parameter
    /// starts out trainable regardless of how the input was flagged.
    pub fn new(tensor: Tensor<T>) -> Self {
        Self {
            data: tensor.requires_grad_(true),
        }
    }

    /// Create a parameter of the given `shape` initialized with zeros.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ferrotorch_core::zeros`.
    pub fn zeros(shape: &[usize]) -> FerrotorchResult<Self> {
        let t = ferrotorch_core::zeros::<T>(shape)?;
        Ok(Self::new(t))
    }

    /// Create a parameter of the given `shape` initialized with ones.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ferrotorch_core::ones`.
    pub fn ones(shape: &[usize]) -> FerrotorchResult<Self> {
        let t = ferrotorch_core::ones::<T>(shape)?;
        Ok(Self::new(t))
    }

    /// Create a parameter from the values in `data`, laid out as `shape`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ferrotorch_core::from_slice`.
    pub fn from_slice(data: &[T], shape: &[usize]) -> FerrotorchResult<Self> {
        let t = ferrotorch_core::from_slice(data, shape)?;
        Ok(Self::new(t))
    }

    /// Borrow the underlying tensor.
    #[inline]
    pub fn tensor(&self) -> &Tensor<T> {
        &self.data
    }

    /// Consume the parameter and return the underlying tensor.
    pub fn into_tensor(self) -> Tensor<T> {
        self.data
    }

    /// Replace the underlying tensor.
    ///
    /// Used by optimizers to write updated parameter values. The new tensor
    /// is stored with `requires_grad = true` regardless of its own flag *and*
    /// regardless of this parameter's previous flag: a parameter frozen via
    /// [`Parameter::set_requires_grad`] becomes trainable again. Call
    /// `set_requires_grad(false)` afterwards to keep it frozen.
    pub fn set_data(&mut self, tensor: Tensor<T>) {
        self.data = tensor.requires_grad_(true);
    }

    /// Toggle whether this parameter participates in autograd (#583).
    ///
    /// Setting `false` "freezes" the parameter — backward passes will not
    /// produce a gradient for it; optimizer steps that consult
    /// `requires_grad` will skip it. Mirrors `torch.nn.Parameter.requires_grad_`.
    pub fn set_requires_grad(&mut self, requires_grad: bool) {
        // `Tensor::requires_grad_` takes `self` by value and the field cannot
        // be moved out from behind `&mut self`, so clone the handle once.
        let cloned = self.data.clone();
        self.data = cloned.requires_grad_(requires_grad);
    }

    /// Return a copy of this parameter on `device`.
    ///
    /// The current `requires_grad` flag is carried over to the result, so a
    /// frozen parameter stays frozen after the move.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Tensor::to` for the requested device.
    pub fn to(&self, device: Device) -> FerrotorchResult<Self> {
        let moved = self.data.to(device)?;
        // Build the struct directly instead of going through `Self::new`:
        // `new` forces `requires_grad = true`, which would silently unfreeze
        // a parameter that was frozen with `set_requires_grad(false)`.
        Ok(Self {
            data: moved.requires_grad_(self.data.requires_grad()),
        })
    }
}

/// Lets a `Parameter<T>` be used wherever a `&Tensor<T>` method is called,
/// by handing out a reference to the wrapped tensor.
impl<T> std::ops::Deref for Parameter<T>
where
    T: Float,
{
    type Target = Tensor<T>;

    /// Borrow the wrapped tensor.
    #[inline]
    fn deref(&self) -> &Tensor<T> {
        &self.data
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A factory-built parameter reports `requires_grad() == true`.
    #[test]
    fn test_parameter_requires_grad() {
        let param: Parameter<f32> = Parameter::zeros(&[3, 4]).unwrap();
        assert!(param.requires_grad());
    }

    /// Tensor methods (`shape`, `numel`) are reachable through `Deref`.
    #[test]
    fn test_parameter_deref_to_tensor() {
        let param: Parameter<f32> = Parameter::zeros(&[2, 3]).unwrap();
        assert_eq!(param.shape(), &[2, 3]);
        assert_eq!(param.numel(), 6);
    }

    /// Cloning a parameter yields a handle to the same underlying tensor.
    #[test]
    fn test_parameter_clone_shares_identity() {
        let original: Parameter<f32> = Parameter::zeros(&[4]).unwrap();
        let duplicate = original.clone();
        assert!(original.tensor().is_same(duplicate.tensor()));
    }

    /// Moving to the CPU keeps shape, values, and the trainable flag.
    #[test]
    fn test_parameter_to_cpu_preserves_data() {
        let values = [1.0_f32, 2.0, 3.0];
        let source = Parameter::<f32>::from_slice(&values, &[3]).unwrap();
        let moved = source.to(ferrotorch_core::Device::Cpu).unwrap();
        assert_eq!(moved.shape(), &[3]);
        assert_eq!(moved.data().unwrap(), &values);
        assert!(moved.requires_grad());
    }

    /// Requesting a CUDA device in this test setup yields an error rather
    /// than a parameter.
    #[test]
    fn test_parameter_to_cuda_without_backend() {
        let param: Parameter<f32> = Parameter::zeros(&[2]).unwrap();
        let outcome = param.to(ferrotorch_core::Device::Cuda(0));
        assert!(outcome.is_err());
    }
}