1use crate::error::{NpuError, Result};
2use crate::memory::MemoryPool;
3use crate::perf_monitor::PerformanceMonitor;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::SystemTime;
7
/// Lifecycle state of an NPU device.
///
/// The `Debug` name of each variant is what `NpuDevice::get_status_json`
/// reports in its `"state"` field, so renaming a variant changes that output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeviceState {
    /// Not yet initialized, or returned here by `shutdown`.
    Uninitialized,
    /// Initialized and idle; the only state for which `is_ready` is `true`.
    Initialized,
    /// Busy with work. Not set by any `NpuDevice` method in this file;
    /// presumably entered by the execution path (TODO confirm).
    Computing,
    /// Faulted; `initialize` refuses to run from here, while `reset` still
    /// forces the device back to `Initialized`.
    Error,
}
16
/// Static description of an NPU device: identity plus headline hardware
/// figures. Units follow the suffix of each field name.
#[derive(Clone, Debug)]
pub struct DeviceInfo {
    /// Numeric identifier of the device.
    pub device_id: u32,
    /// Peak compute throughput, in TOPS.
    pub peak_throughput_tops: f32,
    /// Device memory capacity in megabytes; the device's memory pool is
    /// created with this size.
    pub memory_mb: usize,
    /// Number of compute units.
    pub compute_units: usize,
    /// Clock frequency, in MHz.
    pub frequency_mhz: u32,
    /// Thermal design power, in watts.
    pub power_tdp_watts: f32,
    /// Vendor name.
    pub vendor: String,
    /// Human-readable device name (reported by the status JSON).
    pub device_name: String,
}
29
30impl Default for DeviceInfo {
31 fn default() -> Self {
32 Self {
33 device_id: 0,
34 peak_throughput_tops: 20.0,
35 memory_mb: 512,
36 compute_units: 4,
37 frequency_mhz: 800,
38 power_tdp_watts: 5.0,
39 vendor: "RISC NPU Vendor".to_string(),
40 device_name: "20-TOPS NPU Accelerator".to_string(),
41 }
42 }
43}
44
45pub struct NpuDevice {
47 info: DeviceInfo,
48 state: Arc<Mutex<DeviceState>>,
49 memory_pool: MemoryPool,
50 perf_monitor: Arc<PerformanceMonitor>,
51 initialized_at: Arc<Mutex<Option<SystemTime>>>,
52}
53
54impl NpuDevice {
55 pub fn new() -> Self {
57 Self::with_config(DeviceInfo::default())
58 }
59
60 pub fn with_config(info: DeviceInfo) -> Self {
62 Self {
63 info: info.clone(),
64 state: Arc::new(Mutex::new(DeviceState::Uninitialized)),
65 memory_pool: MemoryPool::new(info.memory_mb),
66 perf_monitor: Arc::new(PerformanceMonitor::new()),
67 initialized_at: Arc::new(Mutex::new(None)),
68 }
69 }
70
71 pub fn initialize(&self) -> Result<()> {
73 let mut state = self.state.lock();
74
75 match *state {
76 DeviceState::Uninitialized => {
77 *state = DeviceState::Initialized;
78 *self.initialized_at.lock() = Some(SystemTime::now());
79 Ok(())
80 }
81 DeviceState::Initialized => {
82 Err(NpuError::InitializationFailed(
83 "Device already initialized".to_string(),
84 ))
85 }
86 DeviceState::Error => {
87 Err(NpuError::InitializationFailed(
88 "Device in error state".to_string(),
89 ))
90 }
91 _ => Err(NpuError::InitializationFailed(
92 "Invalid state transition".to_string(),
93 )),
94 }
95 }
96
97 pub fn reset(&self) -> Result<()> {
99 let mut state = self.state.lock();
100 self.perf_monitor.reset();
101 *state = DeviceState::Initialized;
102 Ok(())
103 }
104
105 pub fn get_info(&self) -> DeviceInfo {
107 self.info.clone()
108 }
109
110 pub fn get_state(&self) -> DeviceState {
112 *self.state.lock()
113 }
114
115 pub fn get_memory_pool(&self) -> MemoryPool {
117 self.memory_pool.clone()
118 }
119
120 pub fn get_perf_monitor(&self) -> Arc<PerformanceMonitor> {
122 Arc::clone(&self.perf_monitor)
123 }
124
125 pub fn is_ready(&self) -> bool {
127 matches!(*self.state.lock(), DeviceState::Initialized)
128 }
129
130 pub fn get_status_json(&self) -> serde_json::Value {
132 let state = self.get_state();
133 let memory_stats = self.memory_pool.get_manager().get_stats();
134 let perf_metrics = self.perf_monitor.get_metrics();
135
136 serde_json::json!({
137 "device_id": self.info.device_id,
138 "device_name": self.info.device_name,
139 "state": format!("{:?}", state),
140 "peak_throughput_tops": self.info.peak_throughput_tops,
141 "current_memory_mb": memory_stats.allocated_bytes / 1024 / 1024,
142 "peak_memory_mb": memory_stats.peak_bytes / 1024 / 1024,
143 "total_memory_mb": self.info.memory_mb,
144 "performance": {
145 "total_operations": perf_metrics.total_operations,
146 "total_time_ms": perf_metrics.total_time_ms,
147 "peak_power_watts": perf_metrics.peak_power_watts,
148 "throughput_gops": self.perf_monitor.get_throughput_gops(),
149 }
150 })
151 }
152
153 pub fn shutdown(&self) -> Result<()> {
155 let mut state = self.state.lock();
156 match *state {
157 DeviceState::Initialized | DeviceState::Computing => {
158 *state = DeviceState::Uninitialized;
159 Ok(())
160 }
161 _ => Err(NpuError::DeviceError(
162 "Cannot shutdown device not in valid state".to_string(),
163 )),
164 }
165 }
166}
167
168impl Default for NpuDevice {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174pub struct DeviceRegistry {
176 devices: Vec<Arc<NpuDevice>>,
177}
178
179impl DeviceRegistry {
180 pub fn new() -> Self {
182 Self {
183 devices: Vec::new(),
184 }
185 }
186
187 pub fn register(&mut self, device: Arc<NpuDevice>) -> Result<u32> {
189 if self.devices.len() >= 16 {
190 return Err(NpuError::DeviceError(
191 "Maximum number of devices reached".to_string(),
192 ));
193 }
194 let device_id = self.devices.len() as u32;
195 self.devices.push(device);
196 Ok(device_id)
197 }
198
199 pub fn get_device(&self, device_id: u32) -> Result<Arc<NpuDevice>> {
201 self.devices
202 .get(device_id as usize)
203 .cloned()
204 .ok_or_else(|| NpuError::DeviceError(format!("Device {} not found", device_id)))
205 }
206
207 pub fn num_devices(&self) -> usize {
209 self.devices.len()
210 }
211}
212
213impl Default for DeviceRegistry {
214 fn default() -> Self {
215 Self::new()
216 }
217}