use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::autograd::AutogradError;
use crate::nn::Module;
use crate::tensor::{GradId, ParamId, Tensor};
/// Global monotonically increasing counter backing [`alloc_param_id`].
/// Relaxed ordering is sufficient here: IDs only need to be unique,
/// not ordered relative to any other memory operations.
static NEXT_PARAM_ID: AtomicUsize = AtomicUsize::new(0);

/// Hands out the next process-wide unique [`ParamId`].
fn alloc_param_id() -> ParamId {
    let raw = NEXT_PARAM_ID.fetch_add(1, Ordering::Relaxed);
    ParamId(raw)
}
/// A trainable tensor: wraps a `Tensor` with gradient tracking enabled
/// (see `Parameter::new`) and a process-wide unique identity.
#[derive(Clone, Debug)]
pub struct Parameter {
    /// Unique id assigned at construction via `alloc_param_id`.
    /// NOTE(review): the derived `Clone` copies the id, so clones share
    /// identity — `Module::parameters` below relies on this by returning
    /// clones of `self`.
    pub id: ParamId,
    /// The wrapped tensor; `new` forces `requires_grad` on before storing.
    pub tensor: Tensor,
}
impl Parameter {
pub fn new(mut tensor: Tensor) -> Self {
tensor.set_requires_grad(true);
Self {
id: alloc_param_id(),
tensor,
}
}
pub fn grad_id(&self) -> GradId {
self.tensor
.grad_id()
.expect("Parameter tensor must have a GradId")
}
}
impl Module for Parameter {
    /// A leaf parameter's parameter list is just itself. The clone
    /// shares the same `ParamId` (derived `Clone` copies the id).
    fn parameters(&self) -> Vec<Parameter> {
        vec![self.clone()]
    }

    /// Exports this parameter under `prefix` with any trailing dots
    /// stripped, so a caller passing `"layer1.weight."` stores the key
    /// `"layer1.weight"`. Mirrors the key normalization in
    /// `load_state_dict`.
    fn state_dict(&self, prefix: &str) -> HashMap<String, Tensor> {
        let key = prefix.trim_end_matches('.').to_string();
        let mut map = HashMap::new();
        map.insert(key, self.tensor.clone());
        map
    }

    /// Copies the tensor stored under the normalized `prefix` key into
    /// this parameter's existing storage in place (presumably so the
    /// tensor's identity, `GradId`, and any aliases are preserved —
    /// TODO confirm that is the intent vs. simply replacing the tensor).
    ///
    /// # Errors
    /// Returns `AutogradError::StateError` when the stored tensor's
    /// shape does not match this parameter's shape.
    ///
    /// NOTE(review): a missing key is silently ignored and `Ok(())` is
    /// returned — confirm partial loads are intended rather than an
    /// error on absent keys.
    fn load_state_dict(
        &mut self,
        dict: &HashMap<String, Tensor>,
        prefix: &str,
    ) -> Result<(), AutogradError> {
        let key = prefix.trim_end_matches('.');
        if let Some(loaded) = dict.get(key) {
            if self.tensor.shape() != loaded.shape() {
                return Err(AutogradError::StateError {
                    key: key.to_string(),
                    message: format!(
                        "shape mismatch: expected {:?}, got {:?}",
                        self.tensor.shape(),
                        loaded.shape(),
                    ),
                });
            }
            // NOTE(review): `copy_from_slice` panics if the two storage
            // slices differ in length; equal shapes imply equal lengths
            // only for contiguous, non-view tensors — confirm that
            // invariant holds for state-dict entries.
            // NOTE(review): if `loaded` aliases this parameter's own
            // storage, acquiring a read guard and then a write guard on
            // the same storage could deadlock or panic — confirm callers
            // never load a state dict built from this same parameter, or
            // add an aliasing check.
            let src_guard = loaded.storage.data();
            let mut dst_guard = self.tensor.storage.data_write();
            dst_guard.copy_from_slice(&*src_guard);
            // Both guards are released explicitly before bump_version,
            // which presumably needs to touch the same storage — TODO
            // confirm it requires the locks to be free.
            drop(dst_guard);
            drop(src_guard);
            self.tensor.storage.bump_version();
        }
        Ok(())
    }
}