use std::cell::RefCell;
use std::collections::{BTreeSet, BTreeMap};
use std::fmt;
use std::rc::Rc;
use super::collection::generational_index::*;
use super::collection::graph::Graph;
use super::tensor::Tensor;
use super::op::*;
pub struct Module {
net: Rc<RefCell<Net>>,
}
impl Module {
pub fn new() -> Module {
Module {
net: Rc::new(RefCell::new(Net::new())),
}
}
pub fn var(&mut self) -> Var {
let mut new_var = Var::new();
{
self.net.borrow_mut().init_var(&mut new_var);
new_var.net = Rc::clone(&self.net);
}
new_var
}
pub fn eval(&self) {
self.net.borrow_mut().eval().expect("");
}
pub fn forward(&self) {
self.net.borrow_mut().eval().expect("");
}
pub fn backward_vector(&self, og: &BTreeMap<NetIndex, Tensor>) {
self.net.borrow_mut().bptt(og);
}
pub fn backward(&self, og: f32) {
self.net.borrow_mut().bptt_scale(og);
}
pub fn _visit_data<F>(&self, closure: F)
where F: Fn(&Op) {
}
pub fn _visit_op<F>(&self, closure: F)
where F: Fn(&Op) {
self.net.borrow_mut().visit_op(closure);
}
}
/// Handle to one tensor slot inside a (possibly shared) `Net`.
///
/// Variables created via `Module::var` or `Var::new_attached` share the
/// network of their creator; arithmetic methods append new nodes to it.
pub struct Var {
    // Index of this variable's tensor slot within `net`.
    id: NetIndex,
    // The network this variable lives in.
    net: Rc<RefCell<Net>>,
}
// Generates a binary method `$a` on `Var` that feeds `self` and `o` through a
// new `$b` op and returns the freshly attached output variable.
macro_rules! var_op_method {
    ($a:ident, $b:ident) => {
        /// Connects `self` and `o` through a new op and returns the output variable.
        ///
        /// # Panics
        /// Panics if `o` belongs to a different network than `self`; indices from
        /// separate networks are not comparable and would be wired incorrectly.
        pub fn $a(&self, o: &Var) -> Var {
            assert!(
                Rc::ptr_eq(&self.net, &o.net),
                "Var binary op: operands belong to different networks"
            );
            let result = self.new_attached();
            self.net.borrow_mut().connect(
                &[self.id, o.id],
                Op::new(Box::new($b::new())),
                &[result.id],
            );
            result
        }
    };
}
impl Var {
    /// Creates a detached variable with its own empty network and the
    /// placeholder id `NetIndex::new(0, 0)`, which is not registered in that
    /// network. `Module::var` and `Var::new_attached` overwrite both fields.
    pub fn new() -> Var {
        Var {
            id: NetIndex::new(0, 0),
            net: Rc::new(RefCell::new(Net::new())),
        }
    }

    /// Registers a new variable in the same network as `self` and returns it.
    pub fn new_attached(&self) -> Var {
        let mut new_var = Var::new();
        self.net.borrow_mut().init_var(&mut new_var);
        // Replace the throwaway net from `Var::new` with the shared one.
        new_var.net = Rc::clone(&self.net);
        new_var
    }

    /// Returns this variable's index within its network.
    pub fn _id(&self) -> &NetIndex {
        &self.id
    }

    /// Stores `v` as this variable's value and marks it as a forward-pass input.
    ///
    /// # Panics
    /// Panics if the variable has no tensor slot in its network.
    pub fn set(&self, v: Tensor) {
        // One mutable borrow covers both the store and the mark.
        let mut net = self.net.borrow_mut();
        net.data
            .replace(&self.id, v)
            .expect("Var::set: variable has no tensor slot in its network");
        net.set_mark(&self.id);
    }

    /// Returns a clone of this variable's current value.
    ///
    /// # Panics
    /// Panics if the variable has no tensor slot in its network.
    pub fn get(&self) -> Tensor {
        self.net
            .borrow()
            .data
            .get(&self.id)
            .expect("Var::get: variable has no tensor slot in its network")
            .clone()
    }

    /// Returns a clone of this variable's gradient from the last backward pass.
    ///
    /// # Panics
    /// Panics if no gradient entry exists for this variable (e.g. before any
    /// backward pass has run).
    pub fn get_grad(&self) -> Tensor {
        self.net
            .borrow()
            .data_grad
            .get(&self.id)
            .expect("Var::get_grad: no gradient recorded for this variable")
            .clone()
    }

    /// Feeds this variable through a clone of `op` and returns the output variable.
    pub fn to(&self, op: &Op) -> Var {
        let result = self.new_attached();
        self.net
            .borrow_mut()
            .connect(&[self.id], op.clone(), &[result.id]);
        result
    }

    /// Returns the size of this variable's tensor.
    ///
    /// # Panics
    /// Panics if the variable has no tensor slot in its network.
    pub fn size(&self) -> Vec<usize> {
        self.net
            .borrow()
            .data
            .get(&self.id)
            .expect("Var::size: variable has no tensor slot in its network")
            .size()
    }

    /// Returns the element count of this variable's tensor.
    ///
    /// # Panics
    /// Panics if the variable has no tensor slot in its network.
    pub fn numel(&self) -> usize {
        self.net
            .borrow()
            .data
            .get(&self.id)
            .expect("Var::numel: variable has no tensor slot in its network")
            .numel()
    }

    var_op_method!(add, Add);
    var_op_method!(sub, Sub);
    var_op_method!(mul, Mul);
    var_op_method!(div, Div);
}
impl fmt::Display for Var {
    /// Formats the variable as `(<id>, <tensor value>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let net = self.net.borrow();
        let value = net.data.get(&self.id).expect("");
        write!(f, "({}, {})", self.id, value)
    }
}
pub fn mseloss(a: &Var, b: &Var) -> Var {
let result = a.new_attached();
a.net.borrow_mut().connect(&vec![a.id, b.id], Op::new(Box::new(MSELoss::new())), &vec![result.id]);
result
}
pub fn bcewithlogitsloss(predict: &Var, label: &Var) -> Var {
let result = predict.new_attached();
predict.net.borrow_mut().connect(&vec![predict.id, label.id], Op::new(Box::new(BCEWithLogitsLoss::new())), &vec![result.id]);
result
}
/// Computation graph together with the tensor and op storage it indexes into.
struct Net {
    // Tensor storage for every variable, indexed by `NetIndex`.
    data: GenIndex<Tensor>,
    // Op storage, indexed by `NetIndex`.
    ops: GenIndex<Op>,
    // Variables whose value was set explicitly (see `Var::set`); `eval` starts
    // its forward walk from these.
    set_mark: BTreeSet<NetIndex>,
    // Connectivity between data nodes and op nodes.
    graph: Graph,
    // Per-data-node gradients, rebuilt on every call to `bptt`.
    data_grad: BTreeMap<NetIndex, Tensor>,
}
impl Net {
    /// Creates an empty network with no data, ops, marks or gradients.
    fn new() -> Net {
        Net {
            data: GenIndex::new(),
            ops: GenIndex::new(),
            set_mark: BTreeSet::new(),
            graph: Graph::new(),
            data_grad: BTreeMap::new(),
        }
    }

    /// Allocates an empty tensor slot, registers it as a graph data node and
    /// writes the new index into `var.id`.
    fn init_var(&mut self, var: &mut Var) {
        let id = self.data.insert(Tensor::new());
        self.graph
            .add_data(&id)
            .expect("Net::init_var: graph rejected new data node");
        var.id = id;
    }

    /// Removes `var`'s tensor slot and its graph data node.
    fn del_var(&mut self, var: &Var) {
        self.data
            .remove(&var.id)
            .expect("Net::del_var: variable has no tensor slot");
        self.graph
            .del_data(&var.id)
            .expect("Net::del_var: graph has no such data node");
    }

    /// Stores `op`, registers it as a graph op node and returns its index.
    fn init_op(&mut self, op: Op) -> NetIndex {
        // `op` is owned and not used afterwards, so it can be moved in directly.
        let id = self.ops.insert(op);
        self.graph
            .add_op(&id)
            .expect("Net::init_op: graph rejected new op node");
        id
    }

    /// Stores `op` and connects the `input` data nodes to the `output` data
    /// nodes through it.
    fn connect(&mut self, input: &[NetIndex], op: Op, output: &[NetIndex]) {
        let opid = self.init_op(op);
        self.graph
            .connect(input, output, &opid)
            .expect("Net::connect: graph rejected connection");
    }

    /// Marks `did` as a forward-pass starting point.
    fn set_mark(&mut self, did: &NetIndex) {
        self.set_mark.insert(*did);
    }

    /// Removes `did` from the forward-pass starting points.
    fn unset_mark(&mut self, did: &NetIndex) {
        self.set_mark.remove(did);
    }

    /// Placeholder for merging two networks: currently ignores its argument
    /// and returns a new, empty `Net`.
    fn merge(&self, _o: &Net) -> Net {
        Net::new()
    }

    /// Runs a forward pass starting from every marked variable.
    ///
    /// For each op reached by the graph walk, the op's input and output tensors
    /// are looked up and passed to `Op::apply`.
    ///
    /// # Errors
    /// Propagates the error returned by `Graph::walk`. Its meaning is defined
    /// by `Graph::walk`, which is not visible here.
    fn eval(&mut self) -> Result<(), BTreeSet<NetIndex>> {
        let all_input: Vec<NetIndex> = self.set_mark.iter().copied().collect();
        self.graph.walk(&all_input[..], true, |input, output, op| {
            let mut inputs: Vec<&Tensor> = Vec::new();
            for input_id in input {
                inputs.push(
                    self.data
                        .get(input_id)
                        .expect("Net::eval: missing input tensor"),
                );
            }
            let mut outputs: Vec<&Tensor> = Vec::new();
            for output_id in output {
                outputs.push(
                    self.data
                        .get(output_id)
                        .expect("Net::eval: missing output tensor"),
                );
            }
            self.ops
                .get(op)
                .expect("Net::eval: graph references an op that is not stored")
                .apply(&inputs, &outputs);
        })?;
        Ok(())
    }

    /// Backpropagates with every cached graph output seeded by a tensor of that
    /// output's size filled with `r`.
    fn bptt_scale(&mut self, r: f32) {
        let output = self.graph.get_output_cache();
        let mut output_grad = BTreeMap::new();
        for i in &output {
            let size = self
                .data
                .get(i)
                .expect("Net::bptt_scale: missing tensor for graph output")
                .size();
            output_grad.insert(*i, Tensor::fill(&size, r));
        }
        self.bptt(&output_grad);
    }

    /// Backpropagates from the variables in `output_grad`, seeded with the
    /// given gradients.
    ///
    /// `data_grad` is reset to a copy of `output_grad`. Every other data node
    /// gets an empty tensor. Then, for each op reached by the backward walk,
    /// `Op::grad` is called.
    ///
    /// # Panics
    /// Panics if the graph walk fails or references missing tensors/ops.
    fn bptt(&mut self, output_grad: &BTreeMap<NetIndex, Tensor>) {
        self.data_grad.clone_from(output_grad);
        let output: Vec<NetIndex> = output_grad.keys().copied().collect();
        // Give every other data node a gradient slot so the walk can look it up.
        for i in self.graph.list_data() {
            self.data_grad.entry(i).or_insert_with(Tensor::new);
        }
        self.graph
            .walk(&output[..], false, |output_ids, input_ids, op| {
                let mut inputs: Vec<&Tensor> = Vec::new();
                for input_id in input_ids {
                    inputs.push(
                        self.data
                            .get(input_id)
                            .expect("Net::bptt: missing input tensor"),
                    );
                }
                let mut output_grad: Vec<&Tensor> = Vec::new();
                for output_id in output_ids {
                    output_grad.push(
                        self.data_grad
                            .get(output_id)
                            .expect("Net::bptt: missing output gradient"),
                    );
                }
                let mut input_grad: Vec<&Tensor> = Vec::new();
                for input_id in input_ids {
                    input_grad.push(
                        self.data_grad
                            .get(input_id)
                            .expect("Net::bptt: missing input gradient slot"),
                    );
                }
                self.ops
                    .get(op)
                    .expect("Net::bptt: graph references an op that is not stored")
                    .grad(&inputs, &output_grad, &input_grad);
            })
            .expect("Net::bptt: backward graph walk failed");
    }

    /// Calls `closure` on every op listed by the graph.
    pub fn visit_op<F>(&mut self, closure: F)
    where
        F: Fn(&Op),
    {
        for i in self.graph.list_op() {
            closure(
                self.ops
                    .get(&i)
                    .expect("Net::visit_op: graph lists an op that is not stored"),
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Variables created from one `Module` share its network and get distinct slots.
    #[test]
    fn module_vars_share_net_and_have_distinct_ids() {
        let mut m = Module::new();
        let a = m.var();
        let b = m.var();
        assert!(std::rc::Rc::ptr_eq(&a.net, &m.net));
        assert!(std::rc::Rc::ptr_eq(&b.net, &m.net));
        assert!(a._id() != b._id());
    }

    /// `new_attached` registers the new variable in its parent's network.
    #[test]
    fn new_attached_shares_net_and_has_distinct_id() {
        let mut m = Module::new();
        let a = m.var();
        let c = a.new_attached();
        assert!(std::rc::Rc::ptr_eq(&a.net, &c.net));
        assert!(a._id() != c._id());
    }
}