zyx 0.15.5

Zyx machine learning library
Documentation
#![allow(unused)]

use crate::{
    DType,
    dtype::Constant,
    graph::{
        Node,
        compiled::{CachedGraph, CompiledGraph},
    },
    kernel::{BOp, UOp},
    shape::{Dim, UAxis},
    slab::{Slab, SlabId},
};

/// Identifier of a node stored in the [`EGraph`] slab, wrapping a raw `u32` index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct NodeId(u32);

impl SlabId for NodeId {
    const ZERO: Self = Self(0);
    const NULL: Self = Self(u32::MAX);

    fn inc(&mut self) {
        self.0 += 1;
    }
}

/// Converts a slab index into a [`NodeId`].
impl From<usize> for NodeId {
    /// # Panics
    ///
    /// Panics if `value` is larger than `u32::MAX`, since it cannot be represented as a `NodeId`.
    fn from(value: usize) -> Self {
        // A plain `as` cast would silently truncate, making two distinct slab slots share one id.
        // An index that large is a broken invariant, so fail loudly instead.
        let raw = u32::try_from(value).expect("NodeId index does not fit in u32");
        Self(raw)
    }
}

impl From<NodeId> for usize {
    fn from(val: NodeId) -> Self {
        val.0 as usize
    }
}

/// One node of the [`EGraph`], translated from a node of the source `CachedGraph`.
///
/// Operands are referenced by [`NodeId`]. Leaf, movement and reduce variants also carry the
/// node's `shape`, as reported by the source graph when the e-graph was built.
#[derive(Debug)]
enum ENode {
    /// A constant value.
    Const { value: Constant },
    /// A leaf node with its shape and dtype.
    Leaf { shape: Box<[Dim]>, dtype: DType },
    /// Casts `x` to `dtype`.
    Cast { x: NodeId, dtype: DType },
    /// Applies the unary op `uop` to `x`.
    Unary { x: NodeId, uop: UOp },
    /// Combines `x` and `y` with the binary op `bop`.
    Binary { x: NodeId, y: NodeId, bop: BOp },
    /// Reshapes `x` to `shape`.
    Reshape { x: NodeId, shape: Box<[Dim]> },
    /// Expands `x` to `shape`.
    Expand { x: NodeId, shape: Box<[Dim]> },
    /// Permutes `x` according to `axes`; `shape` is this node's shape.
    Permute { x: NodeId, axes: Box<[UAxis]>, shape: Box<[Dim]> },
    /// Pads `x` with the `(i64, i64)` pairs in `padding`; `shape` is this node's shape.
    Pad { x: NodeId, padding: Box<[(i64, i64)]>, shape: Box<[Dim]> },
    /// Reduces `x` over `axes` using `rop`; `shape` is this node's shape.
    Reduce { x: NodeId, axes: Box<[UAxis]>, rop: BOp, shape: Box<[Dim]> },
    /// A node produced by a [`FusedOp`] rewrite.
    Fused(Box<dyn FusedOp>),
}

/// A fused operation that can be stored as [`ENode::Fused`].
///
/// `Debug` is a supertrait because [`ENode`] derives `Debug` and holds `Box<dyn FusedOp>`.
trait FusedOp: std::fmt::Debug {
    /// Tries to build this fused op starting from node `nid` of `graph`.
    ///
    /// Returns `None` when the op cannot be formed at `nid`. The `Self: Sized` bound keeps the
    /// trait usable as `dyn FusedOp`.
    fn try_fuse(graph: &mut EGraph, nid: NodeId) -> Option<Self>
    where
        Self: Sized;
}

/// Graph of [`ENode`]s built from a `CachedGraph`, rewritten by `saturate` and intended to be
/// turned into a `CompiledGraph` by `extract`.
pub struct EGraph {
    /// All nodes, addressed by [`NodeId`].
    nodes: Slab<NodeId, ENode>,
}

impl EGraph {
    /// Builds an e-graph by translating every node of `graph` into an [`ENode`].
    ///
    /// Nodes are pushed in the order of `graph.nodes`. Operand references are rebuilt as
    /// `NodeId(x.0)`.
    /// NOTE(review): this assumes `Slab::push` hands out ids sequentially from zero, so that a
    /// node's slab id equals its position in `graph.nodes`. Confirm against `Slab`.
    ///
    /// Leaf and movement/reduce nodes copy their shape from `graph.shape(tid)`. Permute/reduce
    /// axes come from `graph.axes`, and pad widths come from `graph.paddings`.
    ///
    /// # Panics
    ///
    /// - On `Node::Custom`, which is not supported yet.
    /// - Presumably when a `Permute`/`Reduce` node has no entry in `graph.axes`, or a `Pad` node
    ///   has no entry in `graph.paddings`, since those are indexed directly.
    pub fn new(graph: &CachedGraph) -> EGraph {
        let mut nodes: Slab<NodeId, ENode> = Slab::new();

        for (tid, node) in graph.nodes.iter().enumerate() {
            let enode = match *node {
                Node::Const { value } => ENode::Const { value },
                Node::Leaf { dtype } => ENode::Leaf { shape: graph.shape(tid).into(), dtype },
                Node::Expand { x } => ENode::Expand { x: NodeId(x.0), shape: graph.shape(tid).into() },
                Node::Permute { x } => {
                    ENode::Permute { x: NodeId(x.0), axes: graph.axes[&tid].clone(), shape: graph.shape(tid).into() }
                }
                Node::Reshape { x } => ENode::Reshape { x: NodeId(x.0), shape: graph.shape(tid).into() },
                Node::Pad { x } => {
                    ENode::Pad { x: NodeId(x.0), padding: graph.paddings[&tid].clone(), shape: graph.shape(tid).into() }
                }
                Node::Reduce { x, rop } => {
                    ENode::Reduce { x: NodeId(x.0), axes: graph.axes[&tid].clone(), rop, shape: graph.shape(tid).into() }
                }
                Node::Cast { x, dtype } => ENode::Cast { x: NodeId(x.0), dtype },
                Node::Unary { x, uop } => ENode::Unary { x: NodeId(x.0), uop },
                // The second operand must come from `y`; using `x` here would turn every binary
                // op into `x op x`.
                Node::Binary { x, y, bop } => ENode::Binary { x: NodeId(x.0), y: NodeId(y.0), bop },
                Node::Custom(_) => todo!("custom ops are not yet supported by EGraph"),
            };

            nodes.push(enode);
        }

        EGraph { nodes }
    }

    /// Runs one pass of fusion rewrites over every node currently in the graph.
    ///
    /// Ids are snapshotted before the loop because `try_fuse` needs `&mut self`. Nodes appended
    /// during this pass are therefore not visited. Each successful fusion is appended as a new
    /// `ENode::Fused` node.
    /// NOTE(review): the fused node is not linked back to `nid`, so nothing yet records what it
    /// replaces.
    ///
    /// # Panics
    ///
    /// Currently panics on any non-empty graph, because `Matmul::try_fuse` is `todo!()`.
    pub fn saturate(&mut self) {
        let ids: Vec<NodeId> = self.nodes.ids().collect();
        for nid in ids {
            if let Some(fused_op) = Matmul::try_fuse(self, nid) {
                self.nodes.push(ENode::Fused(Box::new(fused_op)));
            }
        }
    }

    /// Converts the e-graph into a `CompiledGraph`.
    ///
    /// # Panics
    ///
    /// Not implemented yet; always panics.
    pub fn extract(self) -> CompiledGraph {
        todo!()
    }
}

/// Fused-op placeholder named `Matmul`. Its matching logic is not implemented yet.
#[derive(Debug)]
struct Matmul {}

impl FusedOp for Matmul {
    /// Not implemented yet: always panics via `todo!()`.
    fn try_fuse(_graph: &mut EGraph, _nid: NodeId) -> Option<Self> {
        todo!()
    }
}