use crate::asg::{Asg, DType, NodeId, NodeType, Shape, Value};
use crate::nn::init::Initializer;
use ndarray::ArrayD;
use std::cell::RefCell;
use std::collections::HashMap;
use std::ops::{Add, Div, Mul, Sub};
use std::rc::Rc;
/// Description of a parameter registered in a [`GraphContext`]: its shape,
/// element type, and the initializer used to materialize its starting value.
#[derive(Clone, Debug, PartialEq)]
pub struct ParameterMeta {
    /// Shape of the parameter tensor.
    pub shape: Shape,
    /// Element type of the parameter.
    pub dtype: DType,
    /// Initializer sampled by `GraphContext::init_parameters` to create the value.
    pub initializer: Initializer,
}
/// Shared builder state: the main computation graph plus the metadata of every
/// parameter registered against it.
#[derive(Clone, Debug)]
pub struct GraphContext {
    /// Graph that `Tensor` builders append their nodes to.
    main_graph: Asg,
    /// Parameter metadata keyed by parameter name.
    parameter_meta: HashMap<String, ParameterMeta>,
}
impl GraphContext {
    /// Creates a context with an empty main graph (id `0`, named `"main"`) and no
    /// registered parameter metadata.
    pub fn new() -> Self {
        Self {
            main_graph: Asg::new(0, Some("main".to_string())),
            parameter_meta: HashMap::new(),
        }
    }

    /// Mutable access to the main graph, used by `Tensor` builders to append nodes.
    pub fn main_graph_mut(&mut self) -> &mut Asg {
        &mut self.main_graph
    }

    /// Read-only access to the main graph.
    pub fn main_graph(&self) -> &Asg {
        &self.main_graph
    }

    /// Records `meta` under `name`, replacing any metadata previously stored for
    /// that name.
    pub fn register_parameter_meta(&mut self, name: &str, meta: ParameterMeta) {
        self.parameter_meta.insert(name.to_string(), meta);
    }

    /// Looks up the metadata registered for parameter `name`, if any.
    pub fn parameter_meta(&self, name: &str) -> Option<&ParameterMeta> {
        self.parameter_meta.get(name)
    }

    /// All registered parameter metadata, keyed by parameter name.
    pub fn parameter_registry(&self) -> &HashMap<String, ParameterMeta> {
        &self.parameter_meta
    }

    /// Returns a copy of `input_shapes` extended with the `(shape, dtype)` of every
    /// registered parameter.
    ///
    /// When a parameter shares a name with an entry of `input_shapes`, the
    /// parameter's registered shape wins.
    pub fn build_shape_map(
        &self,
        input_shapes: &HashMap<String, (Shape, DType)>,
    ) -> HashMap<String, (Shape, DType)> {
        let mut map = input_shapes.clone();
        map.extend(
            self.parameter_meta
                .iter()
                .map(|(name, meta)| (name.clone(), (meta.shape.clone(), meta.dtype))),
        );
        map
    }

    /// Inserts a freshly sampled tensor into `runtime_data` for every registered
    /// parameter that does not already have a value there. Existing values are
    /// left untouched.
    ///
    /// Parameters are initialized in ascending name order. `HashMap` iteration
    /// order is randomized per process, so without sorting the sampling order would
    /// change from run to run. That matters for reproducibility if `Initializer`
    /// draws from a shared seeded RNG (not visible here; verify in `nn::init`).
    pub fn init_parameters(&self, runtime_data: &mut HashMap<String, Value>) {
        let mut missing: Vec<(&String, &ParameterMeta)> = self
            .parameter_meta
            .iter()
            .filter(|(name, _)| !runtime_data.contains_key(name.as_str()))
            .collect();
        // Names are unique map keys, so an unstable sort is still deterministic.
        missing.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (name, meta) in missing {
            let arr = meta.initializer.sample(&meta.shape);
            runtime_data.insert(name.clone(), Value::Tensor(arr));
        }
    }
}
impl Default for GraphContext {
fn default() -> Self {
Self::new()
}
}
/// Lightweight handle to one node in the main graph of a shared [`GraphContext`].
///
/// Cloning a `Tensor` copies the node id and bumps the context's reference count;
/// no graph data is duplicated.
#[derive(Clone, Debug)]
pub struct Tensor {
    /// Id of the graph node this handle refers to.
    pub node_id: NodeId,
    /// Context whose main graph contains `node_id`.
    pub context: Rc<RefCell<GraphContext>>,
}
impl Tensor {
pub fn new_input(context: &Rc<RefCell<GraphContext>>, name: &str) -> Self {
let mut ctx = context.borrow_mut();
let graph = ctx.main_graph_mut();
let node_id = graph.add_node(
Some(name.to_string()),
NodeType::Input {
name: name.to_string(),
},
);
graph.inputs.push(node_id);
Self {
node_id,
context: Rc::clone(context),
}
}
pub fn new_parameter(context: &Rc<RefCell<GraphContext>>, name: &str) -> Self {
let node_id = context.borrow_mut().main_graph_mut().add_node(
Some(name.to_string()),
NodeType::Parameter {
name: name.to_string(),
},
);
Self {
node_id,
context: Rc::clone(context),
}
}
pub fn new_parameter_with_shape(
context: &Rc<RefCell<GraphContext>>,
name: &str,
shape: Shape,
initializer: Initializer,
) -> Self {
let mut ctx = context.borrow_mut();
let node_id = ctx.main_graph_mut().add_node(
Some(name.to_string()),
NodeType::Parameter {
name: name.to_string(),
},
);
ctx.register_parameter_meta(
name,
ParameterMeta {
shape,
dtype: DType::F32,
initializer,
},
);
drop(ctx);
Self {
node_id,
context: Rc::clone(context),
}
}
pub fn new_literal(context: &Rc<RefCell<GraphContext>>, data: ArrayD<f32>, name: &str) -> Self {
let node_id = context.borrow_mut().main_graph_mut().add_node(
Some(name.to_string()),
NodeType::Literal(Value::Tensor(data)),
);
Self {
node_id,
context: Rc::clone(context),
}
}
pub fn pow(&self, power: &Tensor) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Power(self.node_id, power.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn dot(&self, other: &Tensor) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::MatrixMultiply(self.node_id, other.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn matmul(&self, other: &Tensor) -> Self {
self.dot(other)
}
pub fn pow_scalar(&self, power: f32) -> Self {
let exp = Tensor::new_literal(&self.context, ndarray::arr0(power).into_dyn(), "pow_exp");
self.pow(&exp)
}
pub fn sqrt(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Sqrt(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn relu(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::ReLU(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn sigmoid(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Sigmoid(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn softmax(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Softmax(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn tanh(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Tanh(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn leaky_relu(&self, negative_slope: f32) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::LeakyReLU(self.node_id, negative_slope));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn gelu(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::GELU(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn silu(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::SiLU(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn elu(&self, alpha: f32) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::ELU(self.node_id, alpha));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn softplus(&self, beta: f32) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Softplus(self.node_id, beta));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn exp(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Exp(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn abs(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Abs(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn neg(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Neg(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn clamp(&self, min_val: f32, max_val: f32) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Clamp(self.node_id, min_val, max_val));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn log(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Log(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn square(&self) -> Self {
let two = Tensor::scalar(&self.context, 2.0);
self.pow(&two)
}
pub fn scalar(context: &Rc<RefCell<GraphContext>>, value: f32) -> Self {
let data = ArrayD::from_elem(ndarray::IxDyn(&[]), value);
let node_id = context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Literal(Value::Tensor(data)));
Self {
node_id,
context: Rc::clone(context),
}
}
pub fn from_vec(context: &Rc<RefCell<GraphContext>>, values: Vec<f32>) -> Self {
let len = values.len();
let data = ArrayD::from_shape_vec(ndarray::IxDyn(&[len]), values).unwrap();
let node_id = context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Literal(Value::Tensor(data)));
Self {
node_id,
context: Rc::clone(context),
}
}
pub fn sum(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Sum(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn mean(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Mean(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn variance(&self) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Variance(self.node_id));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn reshape(&self, shape: Vec<i64>) -> Self {
let shape_data_f32: Vec<f32> = shape.iter().map(|&x| x as f32).collect();
let shape_array =
ArrayD::from_shape_vec(ndarray::IxDyn(&[shape.len()]), shape_data_f32).unwrap();
let shape_node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Literal(Value::Tensor(shape_array)));
let reshape_node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Reshape(self.node_id, shape_node_id));
Self {
node_id: reshape_node_id,
context: Rc::clone(&self.context),
}
}
pub fn transpose(&self, axis1: usize, axis2: usize) -> Self {
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Transpose(self.node_id, axis1, axis2));
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn slice(&self, axis: usize, start: usize, end: usize) -> Self {
let node_id = self.context.borrow_mut().main_graph_mut().add_node(
None,
NodeType::Slice {
input: self.node_id,
axis,
start,
end,
},
);
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn concat(&self, others: &[&Tensor], axis: usize) -> Self {
let mut inputs = Vec::with_capacity(1 + others.len());
inputs.push(self.node_id);
for t in others {
inputs.push(t.node_id);
}
let node_id = self
.context
.borrow_mut()
.main_graph_mut()
.add_node(None, NodeType::Concat { inputs, axis });
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Self {
let node_id = self.context.borrow_mut().main_graph_mut().add_node(
None,
NodeType::MaxPool2d {
input: self.node_id,
kernel_size,
stride,
},
);
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn conv2d(
&self,
weight: &Tensor,
bias: Option<&Tensor>,
stride: (usize, usize),
padding: (usize, usize),
dilation: (usize, usize),
groups: usize,
) -> Self {
let node_id = self.context.borrow_mut().main_graph_mut().add_node(
None,
NodeType::Conv2d {
input: self.node_id,
weight: weight.node_id,
bias: bias.map(|b| b.node_id),
stride,
padding,
dilation,
groups,
},
);
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn conv_transpose2d(
&self,
weight: &Tensor,
bias: Option<&Tensor>,
stride: (usize, usize),
padding: (usize, usize),
output_padding: (usize, usize),
dilation: (usize, usize),
groups: usize,
) -> Self {
let node_id = self.context.borrow_mut().main_graph_mut().add_node(
None,
NodeType::ConvTranspose2d {
input: self.node_id,
weight: weight.node_id,
bias: bias.map(|b| b.node_id),
stride,
padding,
output_padding,
dilation,
groups,
},
);
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn avg_pool2d(
&self,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
) -> Self {
let node_id = self.context.borrow_mut().main_graph_mut().add_node(
None,
NodeType::AvgPool2d {
input: self.node_id,
kernel_size,
stride,
padding,
},
);
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn adaptive_avg_pool2d(&self, output_size: (usize, usize)) -> Self {
let node_id = self.context.borrow_mut().main_graph_mut().add_node(
None,
NodeType::AdaptiveAvgPool2d {
input: self.node_id,
output_size,
},
);
Self {
node_id,
context: Rc::clone(&self.context),
}
}
pub fn embedding(&self, weight: &Tensor) -> Self {
let node_id = self.context.borrow_mut().main_graph_mut().add_node(
None,
NodeType::Embedding {
indices: self.node_id,
weight: weight.node_id,
},
);
Self {
node_id,
context: Rc::clone(&self.context),
}
}
}
impl Add<&Tensor> for &Tensor {
    type Output = Tensor;

    /// Appends an `Add` node with `self` as the left operand and `rhs` as the right.
    ///
    /// Both operands must belong to the same `GraphContext` (checked in debug
    /// builds). Otherwise `rhs.node_id` would be wired into a graph it does not
    /// belong to.
    fn add(self, rhs: &Tensor) -> Self::Output {
        debug_assert!(
            Rc::ptr_eq(&self.context, &rhs.context),
            "tensors from different GraphContexts cannot be combined"
        );
        let node_id = self
            .context
            .borrow_mut()
            .main_graph_mut()
            .add_node(None, NodeType::Add(self.node_id, rhs.node_id));
        Tensor {
            node_id,
            context: Rc::clone(&self.context),
        }
    }
}
impl Sub<&Tensor> for &Tensor {
    type Output = Tensor;

    /// Appends a `Subtract` node computing `self - rhs`.
    ///
    /// Both operands must belong to the same `GraphContext` (checked in debug
    /// builds). Otherwise `rhs.node_id` would be wired into a graph it does not
    /// belong to.
    fn sub(self, rhs: &Tensor) -> Self::Output {
        debug_assert!(
            Rc::ptr_eq(&self.context, &rhs.context),
            "tensors from different GraphContexts cannot be combined"
        );
        let node_id = self
            .context
            .borrow_mut()
            .main_graph_mut()
            .add_node(None, NodeType::Subtract(self.node_id, rhs.node_id));
        Tensor {
            node_id,
            context: Rc::clone(&self.context),
        }
    }
}
impl Mul<&Tensor> for &Tensor {
    type Output = Tensor;

    /// Appends a `Multiply` node with `self` as the left operand and `rhs` as the right.
    ///
    /// Both operands must belong to the same `GraphContext` (checked in debug
    /// builds). Otherwise `rhs.node_id` would be wired into a graph it does not
    /// belong to.
    fn mul(self, rhs: &Tensor) -> Self::Output {
        debug_assert!(
            Rc::ptr_eq(&self.context, &rhs.context),
            "tensors from different GraphContexts cannot be combined"
        );
        let node_id = self
            .context
            .borrow_mut()
            .main_graph_mut()
            .add_node(None, NodeType::Multiply(self.node_id, rhs.node_id));
        Tensor {
            node_id,
            context: Rc::clone(&self.context),
        }
    }
}
impl Div<&Tensor> for &Tensor {
    type Output = Tensor;

    /// Appends a `Divide` node computing `self / rhs`.
    ///
    /// Both operands must belong to the same `GraphContext` (checked in debug
    /// builds). Otherwise `rhs.node_id` would be wired into a graph it does not
    /// belong to.
    fn div(self, rhs: &Tensor) -> Self::Output {
        debug_assert!(
            Rc::ptr_eq(&self.context, &rhs.context),
            "tensors from different GraphContexts cannot be combined"
        );
        let node_id = self
            .context
            .borrow_mut()
            .main_graph_mut()
            .add_node(None, NodeType::Divide(self.node_id, rhs.node_id));
        Tensor {
            node_id,
            context: Rc::clone(&self.context),
        }
    }
}
/// `&a + b`: forwards to the by-reference `&Tensor + &Tensor` implementation.
impl Add<Tensor> for &Tensor {
    type Output = Tensor;
    fn add(self, rhs: Tensor) -> Self::Output {
        self + &rhs
    }
}
/// `a + &b`: forwards to the by-reference `&Tensor + &Tensor` implementation.
impl Add<&Tensor> for Tensor {
    type Output = Tensor;
    fn add(self, rhs: &Tensor) -> Self::Output {
        &self + rhs
    }
}
/// `a + b`: forwards to the by-reference `&Tensor + &Tensor` implementation.
impl Add<Tensor> for Tensor {
    type Output = Tensor;
    fn add(self, rhs: Tensor) -> Self::Output {
        &self + &rhs
    }
}
/// `&a - b`: forwards to the by-reference `&Tensor - &Tensor` implementation.
impl Sub<Tensor> for &Tensor {
    type Output = Tensor;
    fn sub(self, rhs: Tensor) -> Self::Output {
        self - &rhs
    }
}
/// `a - &b`: forwards to the by-reference `&Tensor - &Tensor` implementation.
impl Sub<&Tensor> for Tensor {
    type Output = Tensor;
    fn sub(self, rhs: &Tensor) -> Self::Output {
        &self - rhs
    }
}
/// `a - b`: forwards to the by-reference `&Tensor - &Tensor` implementation.
impl Sub<Tensor> for Tensor {
    type Output = Tensor;
    fn sub(self, rhs: Tensor) -> Self::Output {
        &self - &rhs
    }
}
/// `&a * b`: forwards to the by-reference `&Tensor * &Tensor` implementation.
impl Mul<Tensor> for &Tensor {
    type Output = Tensor;
    fn mul(self, rhs: Tensor) -> Self::Output {
        self * &rhs
    }
}
/// `a * &b`: forwards to the by-reference `&Tensor * &Tensor` implementation.
impl Mul<&Tensor> for Tensor {
    type Output = Tensor;
    fn mul(self, rhs: &Tensor) -> Self::Output {
        &self * rhs
    }
}
/// `a * b`: forwards to the by-reference `&Tensor * &Tensor` implementation.
impl Mul<Tensor> for Tensor {
    type Output = Tensor;
    fn mul(self, rhs: Tensor) -> Self::Output {
        &self * &rhs
    }
}
/// `&a / b`: forwards to the by-reference `&Tensor / &Tensor` implementation.
impl Div<Tensor> for &Tensor {
    type Output = Tensor;
    fn div(self, rhs: Tensor) -> Self::Output {
        self / &rhs
    }
}
/// `a / &b`: forwards to the by-reference `&Tensor / &Tensor` implementation.
impl Div<&Tensor> for Tensor {
    type Output = Tensor;
    fn div(self, rhs: &Tensor) -> Self::Output {
        &self / rhs
    }
}
/// `a / b`: forwards to the by-reference `&Tensor / &Tensor` implementation.
impl Div<Tensor> for Tensor {
    type Output = Tensor;
    fn div(self, rhs: Tensor) -> Self::Output {
        &self / &rhs
    }
}