use std::sync::Arc;
use axonml_autograd::Variable;
use axonml_core::Device;
use axonml_tensor::Tensor;
use parking_lot::RwLock;
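/// A shared, mutable learnable parameter.
///
/// Wraps an autograd [`Variable`] in `Arc<RwLock<_>>` so that modules and
/// optimizers can hold cheap clones of the same parameter and update it in
/// place. An optional name identifies the parameter in debug output.
///
/// # Example
///
/// A minimal usage sketch; the `axonml_nn` crate path is an assumption and
/// may differ in your workspace, hence `ignore`:
///
/// ```ignore
/// use axonml_nn::Parameter;
/// use axonml_tensor::Tensor;
///
/// let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
/// let weight = Parameter::named("weight", data, true);
/// weight.apply_update(|d| d.mul_scalar(0.5));
/// assert_eq!(weight.data().to_vec(), vec![0.5, 1.0, 1.5]);
/// ```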
#[derive(Clone)]
pub struct Parameter {
data: Arc<RwLock<Variable>>,
name: String,
}
impl Parameter {
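/// Creates an unnamed parameter from a tensor.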
pub fn new(data: Tensor<f32>, requires_grad: bool) -> Self {
Self {
data: Arc::new(RwLock::new(Variable::new(data, requires_grad))),
name: String::new(),
}
}
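/// Creates a named parameter from a tensor.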
pub fn named(name: impl Into<String>, data: Tensor<f32>, requires_grad: bool) -> Self {
Self {
data: Arc::new(RwLock::new(Variable::new(data, requires_grad))),
name: name.into(),
}
}
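/// Wraps an existing [`Variable`] as an unnamed parameter.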
pub fn from_variable(var: Variable) -> Self {
Self {
data: Arc::new(RwLock::new(var)),
name: String::new(),
}
}
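/// Returns the parameter's name (empty if unnamed).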
pub fn name(&self) -> &str {
&self.name
}
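/// Sets the parameter's name.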
pub fn set_name(&mut self, name: impl Into<String>) {
self.name = name.into();
}
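/// Returns a clone of the wrapped [`Variable`].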
pub fn variable(&self) -> Variable {
self.data.read().clone()
}
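/// Returns the underlying tensor data.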
pub fn data(&self) -> Tensor<f32> {
self.data.read().data()
}
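/// Returns the shape of the underlying tensor.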
pub fn shape(&self) -> Vec<usize> {
self.data.read().shape()
}
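/// Returns the number of elements in the underlying tensor.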
pub fn numel(&self) -> usize {
self.data.read().numel()
}
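/// Returns whether this parameter tracks gradients.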
pub fn requires_grad(&self) -> bool {
self.data.read().requires_grad()
}
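/// Returns the accumulated gradient, if any.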
pub fn grad(&self) -> Option<Tensor<f32>> {
self.data.read().grad()
}
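/// Clears the accumulated gradient. A read lock suffices here because
/// `Variable` mutates its gradient through `&self`.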
pub fn zero_grad(&self) {
self.data.read().zero_grad();
}
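/// Replaces the accumulated gradient with `grad`.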
pub fn set_grad(&self, grad: Tensor<f32>) {
self.data.read().set_grad(grad);
}
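/// Replaces the underlying tensor, preserving the `requires_grad` flag.
/// A fresh [`Variable`] is constructed, so any autograd history attached
/// to the old value is dropped.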
pub fn update_data(&self, new_data: Tensor<f32>) {
let mut guard = self.data.write();
let requires_grad = guard.requires_grad();
*guard = Variable::new(new_data, requires_grad);
}
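/// Applies `f` to the current tensor and stores the result, as an
/// optimizer step would.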
pub fn apply_update<F>(&self, f: F)
where
F: FnOnce(&Tensor<f32>) -> Tensor<f32>,
{
let current = self.data();
let updated = f(&current);
self.update_data(updated);
}
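/// Moves the parameter's data to `device`, doing nothing if it is
/// already there.
///
/// # Panics
///
/// Panics if the device transfer fails.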
pub fn to_device(&self, device: Device) {
let current = self.data();
if current.device() == device {
return;
}
let moved = current
.to_device(device)
.expect("Failed to move parameter to device");
self.update_data(moved);
}
}
impl std::fmt::Debug for Parameter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Parameter")
.field("name", &self.name)
.field("shape", &self.shape())
.field("requires_grad", &self.requires_grad())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parameter_creation() {
let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let param = Parameter::new(data, true);
assert!(param.requires_grad());
assert_eq!(param.shape(), vec![3]);
assert_eq!(param.numel(), 3);
}
#[test]
fn test_parameter_named() {
let data = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
let param = Parameter::named("weight", data, true);
assert_eq!(param.name(), "weight");
}
#[test]
fn test_parameter_update() {
let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let param = Parameter::new(data, true);
let new_data = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();
param.update_data(new_data);
assert_eq!(param.data().to_vec(), vec![4.0, 5.0, 6.0]);
}
#[test]
fn test_parameter_apply_update() {
let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let param = Parameter::new(data, true);
param.apply_update(|d| d.mul_scalar(2.0));
assert_eq!(param.data().to_vec(), vec![2.0, 4.0, 6.0]);
}
}