// npu_rs/device.rs

use std::sync::Arc;
use std::time::SystemTime;

use parking_lot::Mutex;

use crate::error::{NpuError, Result};
use crate::memory::MemoryPool;
use crate::perf_monitor::PerformanceMonitor;

/// NPU device lifecycle state.
///
/// `NpuDevice::initialize` moves `Uninitialized` to `Initialized`. `NpuDevice::reset`
/// forces `Initialized` from any state. `NpuDevice::shutdown` moves `Initialized` or
/// `Computing` back to `Uninitialized`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceState {
    /// Not yet initialized, or returned here by `NpuDevice::shutdown`.
    Uninitialized,
    /// Ready for computation; the only state in which `NpuDevice::is_ready` is true.
    Initialized,
    /// A computation is in progress.
    // NOTE(review): no method in device.rs sets this state; presumably set elsewhere — confirm.
    Computing,
    /// The device has faulted. `NpuDevice::initialize` rejects this state; `reset` leaves it.
    Error,
}

/// Static description of an NPU device.
///
/// `NpuDevice::with_config` takes one of these. Its `memory_mb` field sizes the
/// device's memory pool.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceInfo {
    /// Device identifier, reported as `device_id` in the device status JSON.
    pub device_id: u32,
    /// Peak throughput in TOPS.
    pub peak_throughput_tops: f32,
    /// Device memory in megabytes; used to size the memory pool.
    pub memory_mb: usize,
    /// Number of compute units.
    pub compute_units: usize,
    /// Clock frequency in MHz.
    pub frequency_mhz: u32,
    /// Thermal design power in watts.
    pub power_tdp_watts: f32,
    /// Vendor name.
    pub vendor: String,
    /// Human-readable device name.
    pub device_name: String,
}

impl Default for DeviceInfo {
    /// Describes device 0: 20 TOPS, 512 MB, 4 compute units, 800 MHz, 5 W TDP.
    fn default() -> Self {
        Self {
            device_id: 0,
            peak_throughput_tops: 20.0,
            memory_mb: 512,
            compute_units: 4,
            frequency_mhz: 800,
            power_tdp_watts: 5.0,
            vendor: "RISC NPU Vendor".to_string(),
            device_name: "20-TOPS NPU Accelerator".to_string(),
        }
    }
}

45/// Main NPU device driver.
46pub struct NpuDevice {
47    info: DeviceInfo,
48    state: Arc<Mutex<DeviceState>>,
49    memory_pool: MemoryPool,
50    perf_monitor: Arc<PerformanceMonitor>,
51    initialized_at: Arc<Mutex<Option<SystemTime>>>,
52}
53
54impl NpuDevice {
55    /// Create a new NPU device with default configuration.
56    pub fn new() -> Self {
57        Self::with_config(DeviceInfo::default())
58    }
59
60    /// Create a new NPU device with custom configuration.
61    pub fn with_config(info: DeviceInfo) -> Self {
62        Self {
63            info: info.clone(),
64            state: Arc::new(Mutex::new(DeviceState::Uninitialized)),
65            memory_pool: MemoryPool::new(info.memory_mb),
66            perf_monitor: Arc::new(PerformanceMonitor::new()),
67            initialized_at: Arc::new(Mutex::new(None)),
68        }
69    }
70
71    /// Initialize the NPU device.
72    pub fn initialize(&self) -> Result<()> {
73        let mut state = self.state.lock();
74
75        match *state {
76            DeviceState::Uninitialized => {
77                *state = DeviceState::Initialized;
78                *self.initialized_at.lock() = Some(SystemTime::now());
79                Ok(())
80            }
81            DeviceState::Initialized => {
82                Err(NpuError::InitializationFailed(
83                    "Device already initialized".to_string(),
84                ))
85            }
86            DeviceState::Error => {
87                Err(NpuError::InitializationFailed(
88                    "Device in error state".to_string(),
89                ))
90            }
91            _ => Err(NpuError::InitializationFailed(
92                "Invalid state transition".to_string(),
93            )),
94        }
95    }
96
97    /// Reset the device.
98    pub fn reset(&self) -> Result<()> {
99        let mut state = self.state.lock();
100        self.perf_monitor.reset();
101        *state = DeviceState::Initialized;
102        Ok(())
103    }
104
105    /// Get device information.
106    pub fn get_info(&self) -> DeviceInfo {
107        self.info.clone()
108    }
109
110    /// Get current device state.
111    pub fn get_state(&self) -> DeviceState {
112        *self.state.lock()
113    }
114
115    /// Get memory pool.
116    pub fn get_memory_pool(&self) -> MemoryPool {
117        self.memory_pool.clone()
118    }
119
120    /// Get performance monitor.
121    pub fn get_perf_monitor(&self) -> Arc<PerformanceMonitor> {
122        Arc::clone(&self.perf_monitor)
123    }
124
125    /// Check if device is ready for computation.
126    pub fn is_ready(&self) -> bool {
127        matches!(*self.state.lock(), DeviceState::Initialized)
128    }
129
130    /// Get device status as JSON.
131    pub fn get_status_json(&self) -> serde_json::Value {
132        let state = self.get_state();
133        let memory_stats = self.memory_pool.get_manager().get_stats();
134        let perf_metrics = self.perf_monitor.get_metrics();
135
136        serde_json::json!({
137            "device_id": self.info.device_id,
138            "device_name": self.info.device_name,
139            "state": format!("{:?}", state),
140            "peak_throughput_tops": self.info.peak_throughput_tops,
141            "current_memory_mb": memory_stats.allocated_bytes / 1024 / 1024,
142            "peak_memory_mb": memory_stats.peak_bytes / 1024 / 1024,
143            "total_memory_mb": self.info.memory_mb,
144            "performance": {
145                "total_operations": perf_metrics.total_operations,
146                "total_time_ms": perf_metrics.total_time_ms,
147                "peak_power_watts": perf_metrics.peak_power_watts,
148                "throughput_gops": self.perf_monitor.get_throughput_gops(),
149            }
150        })
151    }
152
153    /// Shutdown the device.
154    pub fn shutdown(&self) -> Result<()> {
155        let mut state = self.state.lock();
156        match *state {
157            DeviceState::Initialized | DeviceState::Computing => {
158                *state = DeviceState::Uninitialized;
159                Ok(())
160            }
161            _ => Err(NpuError::DeviceError(
162                "Cannot shutdown device not in valid state".to_string(),
163            )),
164        }
165    }
166}
167
168impl Default for NpuDevice {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
174/// Global device registry for multi-device support.
175pub struct DeviceRegistry {
176    devices: Vec<Arc<NpuDevice>>,
177}
178
179impl DeviceRegistry {
180    /// Create a new device registry.
181    pub fn new() -> Self {
182        Self {
183            devices: Vec::new(),
184        }
185    }
186
187    /// Register a device.
188    pub fn register(&mut self, device: Arc<NpuDevice>) -> Result<u32> {
189        if self.devices.len() >= 16 {
190            return Err(NpuError::DeviceError(
191                "Maximum number of devices reached".to_string(),
192            ));
193        }
194        let device_id = self.devices.len() as u32;
195        self.devices.push(device);
196        Ok(device_id)
197    }
198
199    /// Get device by ID.
200    pub fn get_device(&self, device_id: u32) -> Result<Arc<NpuDevice>> {
201        self.devices
202            .get(device_id as usize)
203            .cloned()
204            .ok_or_else(|| NpuError::DeviceError(format!("Device {} not found", device_id)))
205    }
206
207    /// Get total devices.
208    pub fn num_devices(&self) -> usize {
209        self.devices.len()
210    }
211}
212
213impl Default for DeviceRegistry {
214    fn default() -> Self {
215        Self::new()
216    }
217}