use crate::tensor::Tensor;
use crate::utils::TorchError;
use crate::{Device, Kind};
use std::collections::HashMap;
use std::ops::Div;
use std::sync::Mutex;
/// A store for named tensors, all placed on a single device.
///
/// Variables are registered through a [`Path`] and can then be enumerated
/// (e.g. to hand trainable variables to an optimizer), saved to a file,
/// and loaded back by name.
pub struct VarStore {
    // Map from dot-separated path names to tensors; guarded by a mutex so
    // several `Path` handles can register variables concurrently.
    variables: Mutex<HashMap<String, Tensor>>,
    // Device on which every tensor created through this store lives.
    device: Device,
}
/// A hierarchical name under which variables are registered in a
/// [`VarStore`].
///
/// Components accumulated via `sub` are joined with '.' to form the final
/// variable name.
pub struct Path<'a> {
    // Name components collected so far, joined with '.' when a variable
    // is created.
    path: Vec<String>,
    // The store in which variables created through this path are kept.
    var_store: &'a VarStore,
}
impl VarStore {
    /// Creates a new, empty variable store whose tensors will live on
    /// `device`.
    pub fn new(device: Device) -> VarStore {
        VarStore {
            variables: Mutex::new(HashMap::new()),
            device,
        }
    }

    /// Returns the device used for tensors created through this store.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Returns shallow clones of all variables that require a gradient.
    ///
    /// This is the set typically handed to an optimizer.
    pub fn trainable_variables(&self) -> Vec<Tensor> {
        let variables = self.variables.lock().unwrap();
        variables
            .values()
            // filter + map reads more directly than a filter_map returning
            // Some/None by hand.
            .filter(|tensor| tensor.requires_grad())
            .map(|tensor| tensor.shallow_clone())
            .collect()
    }

    /// Returns the root path, from which sub-paths and variables are
    /// created.
    pub fn root(&self) -> Path {
        Path {
            path: vec![],
            var_store: self,
        }
    }

    /// Saves all currently registered variables to the given file.
    ///
    /// # Errors
    /// Propagates any failure from `Tensor::save_multi`.
    pub fn save(&self, path: &std::path::Path) -> Result<(), TorchError> {
        let variables = self.variables.lock().unwrap();
        let named_tensors = variables
            .iter()
            .map(|(name, tensor)| (name.as_str(), tensor))
            .collect::<Vec<_>>();
        Tensor::save_multi(named_tensors.as_slice(), path)
    }

    /// Loads variable values from the given file.
    ///
    /// Every variable currently registered in the store must be present in
    /// the file; matching values are copied in place.
    ///
    /// # Errors
    /// Fails if the file cannot be read, or if any registered variable is
    /// missing from the file.
    pub fn load(&self, path: &std::path::Path) -> Result<(), TorchError> {
        let named_tensors = Tensor::load_multi(path)?;
        let named_tensors: HashMap<_, _> = named_tensors.into_iter().collect();
        let variables = self.variables.lock().unwrap();
        for (name, tensor) in variables.iter() {
            match named_tensors.get(name) {
                // copy_ overwrites the stored tensor in place; wrapped in
                // no_grad so the copy is not recorded for autograd.
                Some(src) => crate::no_grad(|| tensor.copy_(src)),
                None => Err(TorchError::new(format!(
                    "cannot find {} in {:?}",
                    name, path
                )))?,
            }
        }
        Ok(())
    }
}
impl<'a> Path<'a> {
    /// Returns a sub-path with `s` appended as a new component.
    ///
    /// Takes `&self` rather than `&'a self`: the returned `Path<'a>` only
    /// needs the copied `var_store` reference, so borrowing `self` for the
    /// whole `'a` lifetime would be needlessly restrictive.
    ///
    /// # Panics
    /// Panics if `s` contains a dot, since '.' separates path components.
    pub fn sub(&self, s: &str) -> Path<'a> {
        if s.contains('.') {
            panic!("sub name cannot contain a dot {}", s);
        }
        let mut path = self.path.clone();
        path.push(s.to_owned());
        Path {
            path,
            var_store: self.var_store,
        }
    }

    /// Returns the device of the underlying var-store.
    pub fn device(&self) -> Device {
        self.var_store.device
    }

    /// Builds the fully-qualified name for a variable called `name` under
    /// this path, joining the components with '.'.
    ///
    /// # Panics
    /// Panics if `name` contains a dot.
    fn path(&self, name: &str) -> String {
        if name.contains('.') {
            panic!("variable name cannot contain a dot {}", name);
        }
        if self.path.is_empty() {
            name.to_string()
        } else {
            format!("{}.{}", self.path.join("."), name)
        }
    }

    /// Registers `tensor` in the var-store under `name` and returns it.
    ///
    /// If the name is already taken, a `__<count>` suffix is appended so
    /// the existing variable is not overwritten.
    fn add(&self, name: &str, tensor: Tensor) -> Tensor {
        let path = self.path(name);
        let mut variables = self.var_store.variables.lock().unwrap();
        let path = if variables.contains_key(&path) {
            format!("{}__{}", path, variables.len())
        } else {
            path
        };
        variables.insert(path, tensor.shallow_clone());
        tensor
    }

    /// Creates a zero-initialized variable that does NOT require a
    /// gradient (e.g. for running statistics).
    pub fn zeros_no_train(&self, name: &str, dims: &[i64]) -> Tensor {
        let z = Tensor::zeros(dims, (Kind::Float, self.device()));
        self.add(name, z)
    }

    /// Creates a one-initialized variable that does NOT require a
    /// gradient.
    pub fn ones_no_train(&self, name: &str, dims: &[i64]) -> Tensor {
        let z = Tensor::ones(dims, (Kind::Float, self.device()));
        self.add(name, z)
    }

    /// Creates a zero-initialized trainable variable.
    pub fn zeros(&self, name: &str, dims: &[i64]) -> Tensor {
        let z = Tensor::zeros(dims, (Kind::Float, self.device())).set_requires_grad(true);
        self.add(name, z)
    }

    /// Creates a one-initialized trainable variable.
    pub fn ones(&self, name: &str, dims: &[i64]) -> Tensor {
        let z = Tensor::ones(dims, (Kind::Float, self.device())).set_requires_grad(true);
        self.add(name, z)
    }

    /// Creates a trainable variable initialized from a standard normal
    /// distribution (mean 0, stdev 1).
    pub fn randn_standard(&self, name: &str, dims: &[i64]) -> Tensor {
        let z = Tensor::randn(dims, (Kind::Float, self.device())).set_requires_grad(true);
        self.add(name, z)
    }

    /// Creates a trainable variable initialized from a normal distribution
    /// with the given `mean` and `stdev`.
    pub fn randn(&self, name: &str, dims: &[i64], mean: f64, stdev: f64) -> Tensor {
        let z = Tensor::randn(dims, (Kind::Float, self.device()));
        // Scale and shift a standard normal sample, then mark trainable.
        let z = (z * stdev + mean).set_requires_grad(true);
        self.add(name, z)
    }

    /// Creates a trainable variable initialized uniformly in `[lo, up]`.
    pub fn uniform(&self, name: &str, dims: &[i64], lo: f64, up: f64) -> Tensor {
        let z = Tensor::zeros(dims, (Kind::Float, self.device()))
            .uniform_(lo, up)
            .set_requires_grad(true);
        self.add(name, z)
    }

    /// Creates a trainable variable with Kaiming-uniform initialization:
    /// uniform in `[-b, b]` with `b = 1 / sqrt(fan_in)`.
    pub fn kaiming_uniform(&self, name: &str, dims: &[i64]) -> Tensor {
        // fan_in is the product of all but the first dimension (an empty
        // product is 1). NOTE(review): a zero-sized trailing dim would make
        // the bound infinite — presumably callers never pass one.
        let fan_in: i64 = dims.iter().skip(1).product();
        let bound = (1.0 / fan_in as f64).sqrt();
        self.uniform(name, dims, -bound, bound)
    }
}
impl<'a> Div<&str> for &'a mut Path<'a> {
    type Output = Path<'a>;

    /// Shorthand for [`Path::sub`]: `&mut path / "layer"` builds a
    /// sub-path.
    fn div(self, rhs: &str) -> Self::Output {
        // `rhs` is already a `&str`; the original `&rhs` passed a needless
        // `&&str` that only worked through deref coercion.
        self.sub(rhs)
    }
}