use std::cell::RefCell;
use super::node::NodeBuilder;
use super::var::Var;
use super::node::Node;
use super::tape::Tape;
use crate::tensor::shape::ToShape;
use crate::operators::input::Input;
use crate::optimizers::Optimizer;
use crate::initializers::Initializer;
/// Incrementally assembles a [`Tape`] by appending nodes through `&self` methods.
///
/// Every node-creating method funnels into `extend`, which pushes one node and
/// one operator onto the wrapped tape and hands back a `Var` borrowing this builder.
pub struct TapeBuilder {
    /// The tape under construction. It lives in a `RefCell` because the builder
    /// methods take `&self` and still need to append to it.
    pub tape: RefCell<Tape>,
    /// Optimizer queried by `parameter` (through `to_operator`) for the operator
    /// attached to each parameter node.
    pub opt: Box<dyn Optimizer>,
    /// Initializer; `parameter` passes a `_clone()` of it along with every
    /// parameter node it creates.
    pub init: Box<dyn Initializer>,
}
impl TapeBuilder {
    /// Registers the input node of the tape and returns a handle to it.
    ///
    /// The requested shape is converted with `to_shape()` and then passed through
    /// `add_batch(1)`. The node is created with the `Input` operator, no
    /// dependencies, no initializer and `is_batched: true`.
    ///
    /// # Panics
    ///
    /// Panics if the tape already holds at least one node: the input has to be
    /// the very first node that is added.
    pub fn input<'t>(&'t self, shape: impl ToShape) -> Var<'t> {
        let batched_shape = shape.to_shape().add_batch(1);

        // The shared borrow ends with this statement, which leaves the `RefCell`
        // free for the mutable borrow taken inside `extend`.
        let existing_nodes = self.tape.borrow().nodes.len();
        if existing_nodes > 0 {
            panic!("Tapes cannot have multiple inputs! (Input Node Index was non-Zero)");
        }

        let builder = NodeBuilder {
            op: Box::new(Input),
            deps: Vec::new(),
            shape: batched_shape,
            skip: false,
            init: None,
            is_batched: true,
        };
        self.extend(builder)
    }

    /// Registers a parameter node of the given shape and returns a handle to it.
    ///
    /// The node's operator is obtained from the builder's optimizer via
    /// `to_operator(shape)`, and a clone of the builder's initializer is attached.
    /// The node has no dependencies and is created with `is_batched: false`.
    pub fn parameter<'t>(&'t self, shape: impl ToShape) -> Var<'t> {
        let shape = shape.to_shape();
        // Ask the optimizer first, then clone the initializer (same order as the
        // fields are evaluated in the node description below).
        let optimizer_op = self.opt.to_operator(shape);
        let builder = NodeBuilder {
            op: optimizer_op,
            deps: Vec::new(),
            shape,
            skip: false,
            init: Some(self.init._clone()),
            is_batched: false,
        };
        self.extend(builder)
    }

    /// Appends the node described by `builder` to the tape and returns a `Var`
    /// pointing at it.
    ///
    /// The node's output/gradient pair is obtained in one of three ways:
    ///
    /// * `init` is `Some`: `arena.alloc_parameter(shape, init, is_batched)`.
    ///   In this case the `skip` flag is not consulted at all.
    /// * `init` is `None` and `skip` is set: the `y` / `gy` of the single
    ///   dependency are passed through `clone_reshape(shape, is_batched)`.
    /// * otherwise: `arena.alloc(shape, is_batched)`.
    ///
    /// The node's `x` / `gx` are clones of the `y` / `gy` of every dependency, in
    /// dependency order. The node and its operator are then pushed onto the tape.
    ///
    /// # Panics
    ///
    /// * If a skip node (without initializer) does not have exactly one dependency.
    /// * If the tape's `RefCell` is already borrowed when this is called.
    /// * Presumably also if a dependency index does not refer to an existing node,
    ///   since dependencies are looked up by plain indexing into `tape.nodes`.
    pub fn extend<'t>(&'t self, builder: NodeBuilder) -> Var<'t> {
        let NodeBuilder {
            op,
            deps,
            shape,
            skip,
            init,
            is_batched,
        } = builder;

        let mut tape = self.tape.borrow_mut();

        let (y, gy) = match init {
            Some(initializer) => tape.arena.alloc_parameter(shape, initializer, is_batched),
            None if skip => {
                if deps.len() != 1 {
                    panic!("A Skip node must have exactly 1 dependency!");
                }
                let source = deps[0];
                (
                    tape.nodes[source].y.clone_reshape(shape, is_batched),
                    tape.nodes[source].gy.clone_reshape(shape, is_batched),
                )
            }
            None => tape.arena.alloc(shape, is_batched),
        };

        // Inputs of the new node are the outputs (and gradients) of its dependencies.
        let x = deps.iter().map(|&dep| tape.nodes[dep].y.clone()).collect();
        let gx = deps.iter().map(|&dep| tape.nodes[dep].gy.clone()).collect();

        // The new node lands at the current end of the node list.
        let index = tape.nodes.len();
        tape.nodes.push(Node {
            y,
            gy,
            x,
            gx,
            is_batched,
        });
        tape.ops.push(op);

        Var {
            tape: self,
            shape,
            index,
            is_batched,
        }
    }

    /// Consumes the builder and returns the tape it has been assembling.
    pub fn finish(self) -> Tape {
        let Self { tape, .. } = self;
        tape.into_inner()
    }
}