use std::sync::Arc;
use std::sync::RwLock;
use serde::{Serialize, Deserialize};
use super::gtensor::GTensor;
use super::serde::GTensorSerde;
use crate::tensor::shape::Shape;
use crate::initializers::Initializer;
/// Central store of the shared `f32` buffers that tensors point into.
///
/// Every buffer is wrapped in `Arc<RwLock<..>>` so that the arena and the
/// tensor handles created from it refer to the same storage. A buffer is
/// addressed by an [`ArenaIndex`]: a one-character tag selecting one of the
/// three vectors below plus a position inside that vector.
#[derive(Debug)]
pub struct Arena {
    /// Buffers registered by `alloc` (tag `'O'`).
    pub(in super) outputs: Vec<Arc<RwLock<Vec<f32>>>>,
    /// Gradient buffers registered by both `alloc` and `alloc_parameter`
    /// (tag `'G'`).
    pub(in super) gradients: Vec<Arc<RwLock<Vec<f32>>>>,
    /// Buffers registered by `alloc_parameter` (tag `'P'`).
    pub(in super) parameters: Vec<Arc<RwLock<Vec<f32>>>>,
}
impl Arena {
pub fn new() -> Self {
Self {
outputs: Vec::new(),
gradients: Vec::new(),
parameters: Vec::new(),
}
}
pub fn clear_gradients(&mut self) {
for i in 0..self.gradients.len() {
self.gradients[i].write().unwrap().fill(0.);
}
}
pub fn alloc(&mut self, shape: Shape, batched: bool) -> (GTensor, GTensor) {
let output = Arc::new(RwLock::new(vec![0.0; shape.len()]));
let gradient = Arc::new(RwLock::new(vec![0.0; shape.len()]));
let t1 = GTensor::new(output.clone(), shape, ArenaIndex::new('O', self.outputs.len()), batched);
let t2 = GTensor::new(gradient.clone(), shape, ArenaIndex::new('G', self.gradients.len()), batched);
self.outputs.push(output);
self.gradients.push(gradient);
(t1, t2)
}
pub fn alloc_parameter(&mut self, shape: Shape, init: Box<dyn Initializer>, batched: bool) -> (GTensor, GTensor) {
let parameter = Arc::new(RwLock::new(init.init_vec(shape)));
let gradient = Arc::new(RwLock::new(vec![0.0; shape.len()]));
let t1 = GTensor::new(parameter.clone(), shape, ArenaIndex::new('P', self.parameters.len()), batched);
let t2 = GTensor::new(gradient.clone(), shape, ArenaIndex::new('G', self.gradients.len()), batched);
self.parameters.push(parameter);
self.gradients.push(gradient);
(t1, t2)
}
pub fn load(&self, tensor: GTensorSerde) -> GTensor {
let data =
match tensor.index.vec {
'O' => self.outputs[tensor.index.index].clone(),
'G' => self.gradients[tensor.index.index].clone(),
'P' => self.parameters[tensor.index.index].clone(),
_ => panic!("Invalid ArenaIndex Vec ID! id: {}", tensor.index.vec)
};
{
let len = data.read().unwrap().len();
if len != tensor.shape.len() {
panic!("Cannot load tensor data of length {} into a tensor with shape {}! (length mismatch)",
len, tensor.shape)
}
}
GTensor {
data, shape: tensor.shape, index: tensor.index, is_batched: tensor.is_batched
}
}
}
/// Serializable address of a buffer inside an [`Arena`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArenaIndex {
    /// One-character tag naming the arena vector the buffer lives in.
    /// `Arena::load` recognizes `'O'` (outputs), `'G'` (gradients) and
    /// `'P'` (parameters).
    vec: char,
    /// Position of the buffer inside the vector selected by `vec`.
    index: usize,
}
impl ArenaIndex {
pub fn new(vec: char, index: usize) -> Self {
Self {
vec, index,
}
}
}