// npu_rs/execution.rs

1use ndarray::ArrayD;
2use crate::compute::{MatMulUnit, ConvUnit};
3use crate::device::NpuDevice;
4use crate::error::Result;
5use crate::perf_monitor::PerformanceMonitor;
6use std::sync::Arc;
7
/// Execution context for NPU operations.
///
/// Bundles a device handle with its compute units and a shared
/// performance monitor; dispatches matmul/conv workloads to the units
/// and records their operation counts.
pub struct ExecutionContext {
    // Shared handle to the device; checked for readiness before each op.
    device: Arc<NpuDevice>,
    // Matrix-multiply unit, sized from the device's reported peak TOPS.
    matmul_unit: MatMulUnit,
    // Convolution unit, sized from the device's reported peak TOPS.
    conv_unit: ConvUnit,
    // Monitor obtained from the device; every executed op records its
    // FLOP count here, so metrics are shared device-wide.
    perf_monitor: Arc<PerformanceMonitor>,
}
15
16impl ExecutionContext {
17    /// Create a new execution context.
18    pub fn new(device: Arc<NpuDevice>) -> Self {
19        let info = device.get_info();
20        Self {
21            device: device.clone(),
22            matmul_unit: MatMulUnit::new(info.peak_throughput_tops),
23            conv_unit: ConvUnit::new(info.peak_throughput_tops),
24            perf_monitor: device.get_perf_monitor(),
25        }
26    }
27
28    /// Execute matrix multiplication operation.
29    pub fn execute_matmul(&self, a: &ArrayD<f32>, b: &ArrayD<f32>) -> Result<ArrayD<f32>> {
30        if !self.device.is_ready() {
31            return Err(crate::error::NpuError::DeviceError(
32                "Device not ready".to_string(),
33            ));
34        }
35
36        let result = self.matmul_unit.gemm(a, b)?;
37
38        let m = a.shape()[0];
39        let k = a.shape()[1];
40        let n = b.shape()[1];
41        let ops = (2 * m * k * n) as u64;
42
43        self.perf_monitor.record_operation(ops);
44
45        Ok(result)
46    }
47
48    /// Execute batched matrix multiplication.
49    pub fn execute_batched_matmul(&self, a: &ArrayD<f32>, b: &ArrayD<f32>) -> Result<ArrayD<f32>> {
50        if !self.device.is_ready() {
51            return Err(crate::error::NpuError::DeviceError(
52                "Device not ready".to_string(),
53            ));
54        }
55
56        let result = self.matmul_unit.batched_gemm(a, b)?;
57
58        let batch = a.shape()[0];
59        let m = a.shape()[1];
60        let k = a.shape()[2];
61        let n = b.shape()[2];
62        let ops = (2 * batch * m * k * n) as u64;
63
64        self.perf_monitor.record_operation(ops);
65
66        Ok(result)
67    }
68
69    /// Execute 1x1 convolution.
70    pub fn execute_conv1x1(
71        &self,
72        input: &ArrayD<f32>,
73        kernel: &ArrayD<f32>,
74    ) -> Result<ArrayD<f32>> {
75        if !self.device.is_ready() {
76            return Err(crate::error::NpuError::DeviceError(
77                "Device not ready".to_string(),
78            ));
79        }
80
81        let result = self.conv_unit.conv1x1(input, kernel)?;
82
83        let batch = input.shape()[0];
84        let height = input.shape()[1];
85        let width = input.shape()[2];
86        let c_in = input.shape()[3];
87        let c_out = kernel.shape()[3];
88        let ops = (2 * batch * height * width * c_in * c_out) as u64;
89
90        self.perf_monitor.record_operation(ops);
91
92        Ok(result)
93    }
94
95    /// Get current throughput in GOPS.
96    pub fn get_current_throughput_gops(&self) -> f64 {
97        self.perf_monitor.get_throughput_gops()
98    }
99
100    /// Get performance metrics.
101    pub fn get_metrics(&self) -> crate::perf_monitor::PerformanceMetrics {
102        self.perf_monitor.get_metrics()
103    }
104
105    /// Get underlying device.
106    pub fn get_device(&self) -> Arc<NpuDevice> {
107        self.device.clone()
108    }
109}
110
/// Batch execution scheduler for efficient workload distribution.
pub struct BatchScheduler {
    // Context used to run each submitted operation.
    context: ExecutionContext,
    // Configured batch size, exposed via `get_batch_size`.
    // NOTE(review): not consulted by `submit_batch` in the visible code —
    // confirm whether chunking by batch size is intended.
    batch_size: usize,
}
116
117impl BatchScheduler {
118    /// Create a new batch scheduler.
119    pub fn new(device: Arc<NpuDevice>, batch_size: usize) -> Self {
120        Self {
121            context: ExecutionContext::new(device),
122            batch_size,
123        }
124    }
125
126    /// Submit a batch of operations.
127    pub fn submit_batch(&self, operations: Vec<(&ArrayD<f32>, &ArrayD<f32>)>) -> Result<Vec<ArrayD<f32>>> {
128        let mut results = Vec::new();
129
130        for (a, b) in operations {
131            let result = self.context.execute_matmul(a, b)?;
132            results.push(result);
133        }
134
135        Ok(results)
136    }
137
138    /// Get the execution context.
139    pub fn get_context(&self) -> &ExecutionContext {
140        &self.context
141    }
142
143    /// Get batch size.
144    pub fn get_batch_size(&self) -> usize {
145        self.batch_size
146    }
147}