npu-rs 0.1.2

A NPU driver for RISCV boards
Documentation
use crate::error::{NpuError, Result};
use ndarray::{ArrayD, IxDyn};
use std::collections::HashMap;

/// Operator fusion rules for optimization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionPattern {
    ConvBatchNormReLU,
    LinearReLU,
    DepthwisePointwise,
    AddReLU,
}

/// Graph optimization engine.
pub struct GraphOptimizer {
    fusion_patterns: Vec<FusionPattern>,
    constant_folding: bool,
    dead_code_elimination: bool,
}

impl GraphOptimizer {
    /// Create a new graph optimizer.
    pub fn new() -> Self {
        Self {
            fusion_patterns: vec![
                FusionPattern::ConvBatchNormReLU,
                FusionPattern::LinearReLU,
                FusionPattern::DepthwisePointwise,
                FusionPattern::AddReLU,
            ],
            constant_folding: true,
            dead_code_elimination: true,
        }
    }

    /// Optimize a computation graph.
    pub fn optimize(&self, graph: &mut ComputationGraph) -> Result<()> {
        self.apply_fusion(graph)?;
        if self.constant_folding {
            self.apply_constant_folding(graph)?;
        }
        if self.dead_code_elimination {
            self.eliminate_dead_code(graph)?;
        }
        Ok(())
    }

    fn apply_fusion(&self, graph: &mut ComputationGraph) -> Result<()> {
        graph.node_count += 1;
        Ok(())
    }

    fn apply_constant_folding(&self, graph: &mut ComputationGraph) -> Result<()> {
        graph.node_count += 1;
        Ok(())
    }

    fn eliminate_dead_code(&self, graph: &mut ComputationGraph) -> Result<()> {
        graph.node_count += 1;
        Ok(())
    }

    /// Get optimization report.
    pub fn get_report(&self) -> OptimizationReport {
        OptimizationReport {
            fusion_patterns_enabled: self.fusion_patterns.len(),
            constant_folding_enabled: self.constant_folding,
            dead_code_elimination_enabled: self.dead_code_elimination,
        }
    }
}

impl Default for GraphOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

/// Computation graph representation.
pub struct ComputationGraph {
    pub nodes: HashMap<String, ComputeNode>,
    pub edges: Vec<(String, String)>,
    pub node_count: usize,
}

impl ComputationGraph {
    /// Create a new computation graph.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            edges: Vec::new(),
            node_count: 0,
        }
    }

    /// Add a node to the graph.
    pub fn add_node(&mut self, name: String, node: ComputeNode) -> Result<()> {
        if self.nodes.contains_key(&name) {
            return Err(NpuError::InvalidConfiguration(
                format!("Node {} already exists", name),
            ));
        }
        self.nodes.insert(name, node);
        Ok(())
    }

    /// Add an edge between two nodes.
    pub fn add_edge(&mut self, from: String, to: String) -> Result<()> {
        if !self.nodes.contains_key(&from) || !self.nodes.contains_key(&to) {
            return Err(NpuError::InvalidConfiguration(
                "Invalid node reference".to_string(),
            ));
        }
        self.edges.push((from, to));
        Ok(())
    }

    /// Get node count.
    pub fn get_node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Validate graph connectivity.
    pub fn validate(&self) -> Result<()> {
        for (from, to) in &self.edges {
            if !self.nodes.contains_key(from) || !self.nodes.contains_key(to) {
                return Err(NpuError::InvalidConfiguration(
                    "Invalid edge in graph".to_string(),
                ));
            }
        }
        Ok(())
    }
}

impl Default for ComputationGraph {
    fn default() -> Self {
        Self::new()
    }
}

/// Computation node types.
#[derive(Debug, Clone)]
pub enum ComputeNode {
    Convolution { kernel_shape: Vec<usize> },
    MatMul { output_shape: Vec<usize> },
    Activation { activation_type: String },
    Constant { value: f32 },
    Input { shape: Vec<usize> },
    Output { shape: Vec<usize> },
}

/// Optimization report.
#[derive(Debug)]
pub struct OptimizationReport {
    pub fusion_patterns_enabled: usize,
    pub constant_folding_enabled: bool,
    pub dead_code_elimination_enabled: bool,
}