1use ndarray::ArrayD;
2use crate::compute::{MatMulUnit, ConvUnit};
3use crate::device::NpuDevice;
4use crate::error::Result;
5use crate::perf_monitor::PerformanceMonitor;
6use std::sync::Arc;
7
/// Execution context that dispatches compute operations to an NPU device
/// and records per-operation statistics with the device's performance monitor.
pub struct ExecutionContext {
    // Shared handle to the underlying NPU device; cloned out via `get_device`.
    device: Arc<NpuDevice>,
    // Compute unit handling plain and batched matrix multiplication.
    matmul_unit: MatMulUnit,
    // Compute unit handling 1x1 convolutions.
    conv_unit: ConvUnit,
    // Monitor shared with the device; accumulates op counts and throughput.
    perf_monitor: Arc<PerformanceMonitor>,
}
15
16impl ExecutionContext {
17 pub fn new(device: Arc<NpuDevice>) -> Self {
19 let info = device.get_info();
20 Self {
21 device: device.clone(),
22 matmul_unit: MatMulUnit::new(info.peak_throughput_tops),
23 conv_unit: ConvUnit::new(info.peak_throughput_tops),
24 perf_monitor: device.get_perf_monitor(),
25 }
26 }
27
28 pub fn execute_matmul(&self, a: &ArrayD<f32>, b: &ArrayD<f32>) -> Result<ArrayD<f32>> {
30 if !self.device.is_ready() {
31 return Err(crate::error::NpuError::DeviceError(
32 "Device not ready".to_string(),
33 ));
34 }
35
36 let result = self.matmul_unit.gemm(a, b)?;
37
38 let m = a.shape()[0];
39 let k = a.shape()[1];
40 let n = b.shape()[1];
41 let ops = (2 * m * k * n) as u64;
42
43 self.perf_monitor.record_operation(ops);
44
45 Ok(result)
46 }
47
48 pub fn execute_batched_matmul(&self, a: &ArrayD<f32>, b: &ArrayD<f32>) -> Result<ArrayD<f32>> {
50 if !self.device.is_ready() {
51 return Err(crate::error::NpuError::DeviceError(
52 "Device not ready".to_string(),
53 ));
54 }
55
56 let result = self.matmul_unit.batched_gemm(a, b)?;
57
58 let batch = a.shape()[0];
59 let m = a.shape()[1];
60 let k = a.shape()[2];
61 let n = b.shape()[2];
62 let ops = (2 * batch * m * k * n) as u64;
63
64 self.perf_monitor.record_operation(ops);
65
66 Ok(result)
67 }
68
69 pub fn execute_conv1x1(
71 &self,
72 input: &ArrayD<f32>,
73 kernel: &ArrayD<f32>,
74 ) -> Result<ArrayD<f32>> {
75 if !self.device.is_ready() {
76 return Err(crate::error::NpuError::DeviceError(
77 "Device not ready".to_string(),
78 ));
79 }
80
81 let result = self.conv_unit.conv1x1(input, kernel)?;
82
83 let batch = input.shape()[0];
84 let height = input.shape()[1];
85 let width = input.shape()[2];
86 let c_in = input.shape()[3];
87 let c_out = kernel.shape()[3];
88 let ops = (2 * batch * height * width * c_in * c_out) as u64;
89
90 self.perf_monitor.record_operation(ops);
91
92 Ok(result)
93 }
94
95 pub fn get_current_throughput_gops(&self) -> f64 {
97 self.perf_monitor.get_throughput_gops()
98 }
99
100 pub fn get_metrics(&self) -> crate::perf_monitor::PerformanceMetrics {
102 self.perf_monitor.get_metrics()
103 }
104
105 pub fn get_device(&self) -> Arc<NpuDevice> {
107 self.device.clone()
108 }
109}
110
/// Schedules batches of matmul operations against a single execution context.
pub struct BatchScheduler {
    // Owned context through which every submitted operation executes.
    context: ExecutionContext,
    // Configured batch size; exposed via `get_batch_size`.
    batch_size: usize,
}
116
117impl BatchScheduler {
118 pub fn new(device: Arc<NpuDevice>, batch_size: usize) -> Self {
120 Self {
121 context: ExecutionContext::new(device),
122 batch_size,
123 }
124 }
125
126 pub fn submit_batch(&self, operations: Vec<(&ArrayD<f32>, &ArrayD<f32>)>) -> Result<Vec<ArrayD<f32>>> {
128 let mut results = Vec::new();
129
130 for (a, b) in operations {
131 let result = self.context.execute_matmul(a, b)?;
132 results.push(result);
133 }
134
135 Ok(results)
136 }
137
138 pub fn get_context(&self) -> &ExecutionContext {
140 &self.context
141 }
142
143 pub fn get_batch_size(&self) -> usize {
145 self.batch_size
146 }
147}