axonml_nn/parameter.rs

//! Parameter - Learnable Parameter Wrapper
//!
//! Wraps Variables that act as the learnable parameters of a module.
//! Parameters are registered with their parent module and updated by
//! optimizers during training.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::sync::Arc;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

// =============================================================================
// Parameter
// =============================================================================

/// A learnable parameter of a neural network module.
///
/// Parameters wrap Variables and provide additional functionality:
/// - Explicit gradient-tracking control via `requires_grad`
/// - Registration with parent modules
/// - Convenient access to data and gradients
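///
/// # Examples
///
/// A minimal usage sketch, mirroring this file's tests (the
/// `axonml_nn::Parameter` import path assumes a crate-root re-export):
///
/// ```
/// use axonml_nn::Parameter;
/// use axonml_tensor::Tensor;
///
/// let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
/// let weight = Parameter::named("weight", data, true);
/// assert_eq!(weight.name(), "weight");
/// assert_eq!(weight.shape(), vec![3]);
/// assert!(weight.requires_grad());
/// ```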
#[derive(Clone)]
pub struct Parameter {
    /// The underlying variable.
    data: Arc<RwLock<Variable>>,
    /// Parameter name (for debugging and serialization).
    name: String,
}

impl Parameter {
    /// Creates a new parameter from a tensor.
    ///
    /// # Arguments
    /// * `data` - The tensor data
    /// * `requires_grad` - Whether to track gradients (typically `true` for learnable parameters)
    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
        Self {
            data: Arc::new(RwLock::new(Variable::new(data, requires_grad))),
            name: String::new(),
        }
    }

    /// Creates a new parameter with a name.
    pub fn named(name: impl Into<String>, data: Tensor<f32>, requires_grad: bool) -> Self {
        Self {
            data: Arc::new(RwLock::new(Variable::new(data, requires_grad))),
            name: name.into(),
        }
    }

    /// Creates a parameter from an existing Variable.
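    ///
    /// # Examples
    ///
    /// A sketch of wrapping a hand-built `Variable`; the import paths
    /// assume crate-root re-exports:
    ///
    /// ```
    /// use axonml_autograd::Variable;
    /// use axonml_nn::Parameter;
    /// use axonml_tensor::Tensor;
    ///
    /// let var = Variable::new(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap(), true);
    /// let param = Parameter::from_variable(var);
    /// assert!(param.requires_grad());
    /// assert_eq!(param.numel(), 2);
    /// ```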
    pub fn from_variable(var: Variable) -> Self {
        Self {
            data: Arc::new(RwLock::new(var)),
            name: String::new(),
        }
    }

    /// Returns the parameter name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Sets the parameter name.
    pub fn set_name(&mut self, name: impl Into<String>) {
        self.name = name.into();
    }

    /// Returns a clone of the underlying Variable.
    pub fn variable(&self) -> Variable {
        self.data.read().clone()
    }

    /// Returns a clone of the tensor data.
    pub fn data(&self) -> Tensor<f32> {
        self.data.read().data()
    }

    /// Returns the shape of the parameter.
    pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape()
    }

    /// Returns the number of elements.
    pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    /// Returns whether this parameter requires gradients.
    pub fn requires_grad(&self) -> bool {
        self.data.read().requires_grad()
    }

    /// Returns the gradient if available.
    pub fn grad(&self) -> Option<Tensor<f32>> {
        self.data.read().grad()
    }

    /// Zeros the gradient.
    ///
    /// A read lock suffices here: `Variable::zero_grad` takes `&self`,
    /// so gradient state is managed through interior mutability.
    pub fn zero_grad(&self) {
        self.data.read().zero_grad();
    }

    /// Replaces the parameter data with a new tensor.
    ///
    /// Used by optimizers to update weights. Note that this constructs a
    /// fresh `Variable` (preserving `requires_grad`), so the parameter is
    /// detached from any autograd graph built on the old value.
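    ///
    /// # Examples
    ///
    /// A minimal optimizer-style sketch. It only uses `mul_scalar`, the
    /// one tensor op exercised by this file's tests; a real optimizer
    /// would also fold in the gradient from `grad()`:
    ///
    /// ```
    /// use axonml_nn::Parameter;
    /// use axonml_tensor::Tensor;
    ///
    /// let param = Parameter::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
    /// // Shrink the weights by 10% (a weight-decay-like step).
    /// param.update_data(param.data().mul_scalar(0.9));
    /// assert_eq!(param.data().to_vec(), vec![0.9, 1.8]);
    /// ```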
    pub fn update_data(&self, new_data: Tensor<f32>) {
        let mut guard = self.data.write();
        let requires_grad = guard.requires_grad();
        *guard = Variable::new(new_data, requires_grad);
    }

    /// Applies a function to the parameter data.
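    ///
    /// # Examples
    ///
    /// The same doubling update as `test_parameter_apply_update` below
    /// (the import path assumes a crate-root re-export):
    ///
    /// ```
    /// use axonml_nn::Parameter;
    /// use axonml_tensor::Tensor;
    ///
    /// let param = Parameter::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
    /// param.apply_update(|d| d.mul_scalar(2.0));
    /// assert_eq!(param.data().to_vec(), vec![2.0, 4.0, 6.0]);
    /// ```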
    pub fn apply_update<F>(&self, f: F)
    where
        F: FnOnce(&Tensor<f32>) -> Tensor<f32>,
    {
        let current = self.data();
        let updated = f(&current);
        self.update_data(updated);
    }
}

impl std::fmt::Debug for Parameter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Parameter")
            .field("name", &self.name)
            .field("shape", &self.shape())
            .field("requires_grad", &self.requires_grad())
            .finish()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parameter_creation() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let param = Parameter::new(data, true);
        assert!(param.requires_grad());
        assert_eq!(param.shape(), vec![3]);
        assert_eq!(param.numel(), 3);
    }

    #[test]
    fn test_parameter_named() {
        let data = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let param = Parameter::named("weight", data, true);
        assert_eq!(param.name(), "weight");
    }

    #[test]
    fn test_parameter_update() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let param = Parameter::new(data, true);

        let new_data = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        param.update_data(new_data);

        assert_eq!(param.data().to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_parameter_apply_update() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let param = Parameter::new(data, true);

        param.apply_update(|d| d.mul_scalar(2.0));

        assert_eq!(param.data().to_vec(), vec![2.0, 4.0, 6.0]);
    }
}