use std::collections::HashMap;
/// Element data-type tag carried by a [`TensorSpec`].
///
/// The enum is `#[non_exhaustive]`, so code in other crates must keep a
/// wildcard arm when matching on it. [`DType::F64`] is the [`Default`] value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum DType {
    /// Presumably 32-bit floating point (inferred from the name only).
    F32,
    /// Presumably 64-bit floating point (inferred from the name only).
    /// This is the variant returned by `DType::default()`.
    #[default]
    F64,
    /// Presumably 8-bit signed integer (inferred from the name only).
    I8,
    /// Presumably 32-bit signed integer (inferred from the name only).
    I32,
}
/// Static description of a tensor: its dimension extents and element type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorSpec {
    /// Extent of each dimension, in the order supplied by the caller.
    /// An empty vector gives an element count of `1` (see
    /// [`TensorSpec::num_elements`]).
    pub shape: Vec<usize>,
    /// Element type tag.
    pub dtype: DType,
}

impl TensorSpec {
    /// Builds a spec from a shape and an element type.
    ///
    /// No validation is performed on either argument.
    pub fn new(shape: Vec<usize>, dtype: DType) -> Self {
        Self { shape, dtype }
    }

    /// Returns the total number of elements, i.e. the product of all
    /// dimension extents.
    ///
    /// * An empty `shape` yields `1` (the empty product).
    /// * Any zero-length dimension yields `0`, whatever the other extents are.
    /// * If the product does not fit in `usize`, the result saturates at
    ///   `usize::MAX` rather than panicking (debug builds) or silently
    ///   wrapping to a small value (release builds).
    pub fn num_elements(&self) -> usize {
        // Check for an empty tensor first: otherwise a huge sibling dimension
        // could saturate the running product before the zero is reached.
        if self.shape.contains(&0) {
            return 0;
        }
        self.shape
            .iter()
            .fold(1usize, |acc, &dim| acc.saturating_mul(dim))
    }
}
/// A single attribute value stored in an [`OpNode`]'s `attrs` map.
///
/// Only `PartialEq` is derived (not `Eq`) because two variants carry `f64`
/// payloads, and `f64` does not implement `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub enum OpAttr {
    /// A single signed 64-bit integer.
    Int(i64),
    /// A single 64-bit float.
    Float(f64),
    /// A list of signed 64-bit integers.
    IntList(Vec<i64>),
    /// A list of 64-bit floats.
    FloatList(Vec<f64>),
    /// An owned string.
    String(String),
}
/// Operation kind stored in [`OpNode`]'s `op_type` field.
///
/// The enum is `#[non_exhaustive]`, so code in other crates must keep a
/// wildcard arm when matching on it. The precise semantics of each operation
/// are defined by whatever consumes the graph; they are not visible in this
/// file, so the grouping comments below are based on the variant names only.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum OpType {
    // Names suggest parameterised layers — confirm against the executor.
    Linear,
    Conv1d,

    // Names suggest element-wise activations — confirm against the executor.
    ReLU,
    Sigmoid,
    Tanh,

    // Names suggest element-wise binary arithmetic — confirm.
    Add,
    Mul,

    // Names suggest layout/shape manipulation — confirm.
    Reshape,
    Transpose,

    // Names suggest normalisation-style operations — confirm.
    Softmax,
    LayerNorm,
    BatchNorm,

    // Presumably produced by an optimisation pass that merges `Linear` and
    // `ReLU` — TODO confirm; no such pass is visible here.
    FusedLinearReLU,

    // Presumably a node holding a fixed value — TODO confirm.
    Constant,
}
/// One operation in a [`StaticGraph`].
#[derive(Debug, Clone)]
pub struct OpNode {
    /// Identifier of this node. `StaticGraph::get_node` looks nodes up by id.
    pub id: usize,
    /// Which operation this node represents.
    pub op_type: OpType,
    /// Presumably the ids of the nodes feeding this one — TODO confirm.
    pub inputs: Vec<usize>,
    /// Presumably the ids of the nodes consuming this one — TODO confirm.
    /// Always empty right after [`OpNode::new`].
    pub outputs: Vec<usize>,
    /// Attribute values keyed by attribute name.
    pub attrs: HashMap<String, OpAttr>,
    /// Shape and element type recorded for this node's output.
    pub output_spec: TensorSpec,
    /// Optional label; `None` unless set via [`OpNode::with_name`].
    pub name: Option<String>,
}

impl OpNode {
    /// Creates a node from the supplied parts.
    ///
    /// `outputs` starts out empty and `name` starts out as `None`; every
    /// other field is taken from the corresponding argument unchanged.
    pub fn new(
        id: usize,
        op_type: OpType,
        inputs: Vec<usize>,
        attrs: HashMap<String, OpAttr>,
        output_spec: TensorSpec,
    ) -> Self {
        OpNode {
            id,
            op_type,
            inputs,
            attrs,
            output_spec,
            // Fields that have no constructor argument.
            outputs: Vec::new(),
            name: None,
        }
    }

    /// Builder-style setter: consumes the node and returns it with `name`
    /// replaced by `Some(name.into())`. All other fields are carried over.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}
/// A graph of [`OpNode`]s together with the specs of its inputs and outputs.
#[derive(Debug, Clone)]
pub struct StaticGraph {
    /// Node storage. `id_to_idx` holds positions into this vector.
    pub nodes: Vec<OpNode>,
    /// Specs of the graph inputs, as passed to [`StaticGraph::new`].
    pub inputs: Vec<TensorSpec>,
    /// Specs of the graph outputs, as passed to [`StaticGraph::new`].
    pub outputs: Vec<TensorSpec>,
    /// Maps a node id to that node's position in `nodes`.
    ///
    /// Nothing in this type keeps the map in sync with `nodes`; whoever
    /// mutates `nodes` is responsible for updating it.
    pub(crate) id_to_idx: HashMap<usize, usize>,
    /// Presumably the ids of the nodes acting as graph inputs — TODO confirm.
    pub input_node_ids: Vec<usize>,
    /// Presumably the ids of the nodes producing graph outputs — TODO confirm.
    pub output_node_ids: Vec<usize>,
}

impl StaticGraph {
    /// Creates a graph with the given input/output specs and no nodes.
    ///
    /// The id lookup table and both node-id lists start out empty.
    pub fn new(inputs: Vec<TensorSpec>, outputs: Vec<TensorSpec>) -> Self {
        Self {
            nodes: Vec::new(),
            inputs,
            outputs,
            id_to_idx: HashMap::new(),
            input_node_ids: Vec::new(),
            output_node_ids: Vec::new(),
        }
    }

    /// Looks up a node by id.
    ///
    /// Returns `None` when the id is not registered in the lookup table, or
    /// when the registered position is past the end of `nodes`.
    pub fn get_node(&self, id: usize) -> Option<&OpNode> {
        // `nodes` is public and can be mutated independently of the
        // crate-private `id_to_idx`, so a stale index is treated as
        // "not found" instead of panicking on an out-of-bounds access.
        self.id_to_idx
            .get(&id)
            .and_then(|&idx| self.nodes.get(idx))
    }

    /// Returns the number of nodes currently stored in the graph.
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }
}
/// Boolean switches for tracing.
///
/// How the flags are consumed is not visible in this file; the field notes
/// below are inferred from the names and should be confirmed against the
/// code that reads them.
#[derive(Debug, Clone)]
pub struct TraceConfig {
    /// Presumably enables optimisation of the traced graph — TODO confirm.
    pub optimize: bool,
    /// Presumably enables constant folding — TODO confirm.
    pub fold_constants: bool,
}

impl Default for TraceConfig {
    /// Returns a configuration with both flags set to `true`.
    fn default() -> Self {
        let enabled = true;
        TraceConfig {
            optimize: enabled,
            fold_constants: enabled,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The element count is the product of every dimension extent.
    #[test]
    fn test_tensor_spec_num_elements() {
        let spec = TensorSpec::new(vec![2, 3, 4], DType::F64);
        assert_eq!(spec.num_elements(), 24);
    }

    /// A freshly constructed node keeps its arguments and starts with no
    /// outputs and no name.
    #[test]
    fn test_op_node_creation() {
        let spec = TensorSpec::new(vec![4], DType::F64);
        // `spec` is not used again, so it is moved instead of cloned.
        let node = OpNode::new(0, OpType::ReLU, vec![], HashMap::new(), spec);
        assert_eq!(node.id, 0);
        assert_eq!(node.op_type, OpType::ReLU);
        assert!(node.outputs.is_empty());
        assert!(node.name.is_none());
    }

    /// Registered ids resolve to their node; unknown ids resolve to `None`.
    #[test]
    fn test_static_graph_lookup() {
        let input_spec = TensorSpec::new(vec![4], DType::F64);
        let out_spec = TensorSpec::new(vec![2], DType::F64);
        // Only `out_spec` is needed again below, so only it is cloned.
        let mut graph = StaticGraph::new(vec![input_spec], vec![out_spec.clone()]);
        let node = OpNode::new(0, OpType::Linear, vec![], HashMap::new(), out_spec);
        graph.nodes.push(node);
        graph.id_to_idx.insert(0, 0);
        // Check the identity of the returned node, not just its presence.
        assert_eq!(graph.get_node(0).map(|n| n.id), Some(0));
        assert!(graph.get_node(99).is_none());
        assert_eq!(graph.num_nodes(), 1);
    }

    /// Both switches are on by default.
    #[test]
    fn test_trace_config_default() {
        let cfg = TraceConfig::default();
        assert!(cfg.optimize);
        assert!(cfg.fold_constants);
    }
}