use std::sync::Arc;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

/// A trainable tensor tracked by the autograd engine.
///
/// The wrapped [`Variable`] lives behind an `Arc<RwLock<_>>`, so cloning a
/// `Parameter` is cheap and every clone refers to the same storage: modules
/// and optimizers can hold separate handles to one parameter and observe
/// each other's updates.
#[derive(Clone)]
pub struct Parameter {
    data: Arc<RwLock<Variable>>,
    name: String,
}

impl Parameter {
    /// Creates an unnamed parameter from raw tensor data.
    pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
        Self {
            data: Arc::new(RwLock::new(Variable::new(data, requires_grad))),
            name: String::new(),
        }
    }

    /// Creates a named parameter; the name appears in `Debug` output and can
    /// be used by containers to address the parameter.
    pub fn named(name: impl Into<String>, data: Tensor<f32>, requires_grad: bool) -> Self {
        Self {
            data: Arc::new(RwLock::new(Variable::new(data, requires_grad))),
            name: name.into(),
        }
    }

    /// Wraps an existing autograd [`Variable`] as an unnamed parameter.
    pub fn from_variable(var: Variable) -> Self {
        Self {
            data: Arc::new(RwLock::new(var)),
            name: String::new(),
        }
    }

    /// Returns the parameter's name (empty if it was never named).
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Sets or replaces the parameter's name.
    pub fn set_name(&mut self, name: impl Into<String>) {
        self.name = name.into();
    }

    /// Returns a clone of the underlying autograd variable.
    pub fn variable(&self) -> Variable {
        self.data.read().clone()
    }

    /// Returns a copy of the current tensor data.
    pub fn data(&self) -> Tensor<f32> {
        self.data.read().data()
    }

    /// Returns the shape of the underlying tensor.
    pub fn shape(&self) -> Vec<usize> {
        self.data.read().shape()
    }

    /// Returns the total number of elements.
    pub fn numel(&self) -> usize {
        self.data.read().numel()
    }

    /// Returns whether this parameter participates in gradient computation.
    pub fn requires_grad(&self) -> bool {
        self.data.read().requires_grad()
    }

    /// Returns the accumulated gradient, if any.
    pub fn grad(&self) -> Option<Tensor<f32>> {
        self.data.read().grad()
    }

    /// Clears the accumulated gradient. A read lock suffices here because
    /// `Variable::zero_grad` takes `&self` and mutates the gradient through
    /// interior mutability.
    pub fn zero_grad(&self) {
        self.data.read().zero_grad();
    }

    /// Replaces the tensor data in place, rebuilding the variable as a fresh
    /// leaf while preserving the `requires_grad` flag.
    pub fn update_data(&self, new_data: Tensor<f32>) {
        let mut guard = self.data.write();
        let requires_grad = guard.requires_grad();
        *guard = Variable::new(new_data, requires_grad);
    }

    /// Applies `f` to the current data and stores the result. Note that the
    /// read and the write are two separate lock acquisitions, so the update
    /// is not atomic with respect to concurrent writers.
    pub fn apply_update<F>(&self, f: F)
    where
        F: FnOnce(&Tensor<f32>) -> Tensor<f32>,
    {
        let current = self.data();
        let updated = f(&current);
        self.update_data(updated);
    }
}

impl std::fmt::Debug for Parameter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Parameter")
            .field("name", &self.name)
            .field("shape", &self.shape())
            .field("requires_grad", &self.requires_grad())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parameter_creation() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let param = Parameter::new(data, true);
        assert!(param.requires_grad());
        assert_eq!(param.shape(), vec![3]);
        assert_eq!(param.numel(), 3);
    }

    #[test]
    fn test_parameter_named() {
        let data = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let param = Parameter::named("weight", data, true);
        assert_eq!(param.name(), "weight");
    }

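    // `set_name` renames a parameter after construction, e.g. when a
    // container registers it under a module path. A minimal sketch using
    // only this file's own API.
    #[test]
    fn test_parameter_set_name() {
        let data = Tensor::from_vec(vec![1.0], &[1]).unwrap();
        let mut param = Parameter::new(data, true);
        param.set_name("bias");
        assert_eq!(param.name(), "bias");
    }
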
    #[test]
    fn test_parameter_update() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let param = Parameter::new(data, true);

        let new_data = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
        param.update_data(new_data);

        assert_eq!(param.data().to_vec(), vec![4.0, 5.0, 6.0]);
    }

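    // `update_data` rebuilds the Variable but carries the requires_grad flag
    // over, so an optimizer step does not silently freeze a parameter. A
    // minimal sketch exercising that guarantee.
    #[test]
    fn test_update_preserves_requires_grad() {
        let data = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let param = Parameter::new(data, true);

        param.update_data(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap());

        assert!(param.requires_grad());
    }
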
    #[test]
    fn test_parameter_apply_update() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let param = Parameter::new(data, true);

        param.apply_update(|d| d.mul_scalar(2.0));

        assert_eq!(param.data().to_vec(), vec![2.0, 4.0, 6.0]);
    }
}