1use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use thiserror::Error;
9
/// Errors produced while detecting or evaluating device capabilities.
#[derive(Error, Debug)]
pub enum DeviceError {
    /// Probing the host failed, e.g. a system file could not be read or parsed.
    #[error("Failed to detect device capabilities: {0}")]
    DetectionFailed(String),

    /// The device type is not supported. (Not constructed in the code visible here.)
    #[error("Unsupported device type: {0}")]
    UnsupportedDevice(String),

    /// The device lacks the resources an operation needs.
    /// (Not constructed in the code visible here.)
    #[error("Insufficient resources: {0}")]
    InsufficientResources(String),
}
21
/// Coarse class of host; drives the worker count chosen by
/// `DeviceCapabilities::recommended_workers`.
///
/// `DeviceCapabilities::classify_device` only yields `Edge`, `Consumer` or `Server`.
/// `GpuAccelerated` and `Cloud` must be set explicitly by whoever builds the value.
// Variant order is part of the serialized form in index-based serde formats; keep it stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeviceType {
    /// Small host: few cores or little memory.
    Edge,
    /// General-purpose machine; the fallback classification.
    Consumer,
    /// Many cores and plenty of memory.
    Server,
    /// GPU-accelerated host.
    GpuAccelerated,
    /// Cloud-hosted instance.
    Cloud,
}
36
/// CPU instruction-set architecture.
// Variant order is part of the serialized form in index-based serde formats; keep it stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeviceArch {
    /// 64-bit x86 (`target_arch = "x86_64"`).
    X86_64,
    /// 64-bit ARM (`target_arch = "aarch64"`).
    Aarch64,
    /// ARM other than AArch64 — presumably 32-bit; confirm intended mapping.
    Arm,
    /// RISC-V.
    Riscv,
    /// Any architecture without a dedicated variant.
    Other,
}
46
/// Snapshot of system memory taken at detection time.
///
/// On Linux detection reads `/proc/meminfo`; other platforms currently report fixed
/// placeholder values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryInfo {
    /// Total memory, in bytes.
    pub total_bytes: u64,
    /// Memory available for new work, in bytes.
    pub available_bytes: u64,
    /// Memory pressure as a fraction in `0.0..=1.0`; detection computes it as
    /// `1 - available / total`. Note this is a fraction, whereas `utilization()` returns a
    /// percentage.
    pub pressure: f32,
}
57
58impl MemoryInfo {
59 pub fn has_capacity(&self, required_bytes: u64) -> bool {
61 self.available_bytes >= required_bytes
62 }
63
64 pub fn utilization(&self) -> f32 {
66 if self.total_bytes == 0 {
67 return 0.0;
68 }
69 ((self.total_bytes - self.available_bytes) as f32 / self.total_bytes as f32) * 100.0
70 }
71}
72
/// Summary of the host CPU.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CpuInfo {
    /// Number of logical CPUs; detection fills it from `num_cpus::get`.
    pub logical_cores: usize,
    /// Number of physical cores; detection fills it from `num_cpus::get_physical`.
    pub physical_cores: usize,
    /// Architecture the binary was compiled for (chosen via `cfg(target_arch)`).
    pub arch: DeviceArch,
    /// Clock frequency in MHz, if known. Detection currently never fills this in.
    pub frequency_mhz: Option<u32>,
}
85
86impl CpuInfo {
87 pub fn recommended_threads(&self) -> usize {
89 (self.logical_cores as f32 * 0.8).ceil() as usize
91 }
92}
93
/// Snapshot of the host's hardware, normally built by `DeviceCapabilities::detect`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceCapabilities {
    /// Coarse host class derived from the CPU and memory figures.
    pub device_type: DeviceType,
    /// CPU core counts and architecture.
    pub cpu: CpuInfo,
    /// Memory totals and pressure.
    pub memory: MemoryInfo,
    /// Whether a GPU device node was found. This is a file-existence heuristic on Linux and
    /// always `false` on other platforms.
    pub has_gpu: bool,
    /// Whether non-rotational (SSD-like) storage was detected. Linux only; `false` elsewhere.
    pub has_fast_storage: bool,
    /// Network bandwidth, presumably in megabits per second, if known. Detection currently
    /// always leaves this `None`.
    pub network_bandwidth_mbps: Option<u32>,
}
110
111impl DeviceCapabilities {
112 pub fn detect() -> Result<Self, DeviceError> {
114 let cpu = Self::detect_cpu()?;
115 let memory = Self::detect_memory()?;
116 let device_type = Self::classify_device(&cpu, &memory);
117
118 Ok(DeviceCapabilities {
119 device_type,
120 cpu,
121 memory,
122 has_gpu: Self::detect_gpu(),
123 has_fast_storage: Self::detect_fast_storage(),
124 network_bandwidth_mbps: None, })
126 }
127
128 #[cfg(target_arch = "x86_64")]
129 fn detect_cpu() -> Result<CpuInfo, DeviceError> {
130 let logical_cores = num_cpus::get();
131 let physical_cores = num_cpus::get_physical();
132
133 Ok(CpuInfo {
134 logical_cores,
135 physical_cores,
136 arch: DeviceArch::X86_64,
137 frequency_mhz: None,
138 })
139 }
140
141 #[cfg(target_arch = "aarch64")]
142 fn detect_cpu() -> Result<CpuInfo, DeviceError> {
143 let logical_cores = num_cpus::get();
144 let physical_cores = num_cpus::get_physical();
145
146 Ok(CpuInfo {
147 logical_cores,
148 physical_cores,
149 arch: DeviceArch::Aarch64,
150 frequency_mhz: None,
151 })
152 }
153
154 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
155 fn detect_cpu() -> Result<CpuInfo, DeviceError> {
156 let logical_cores = num_cpus::get();
157 let physical_cores = num_cpus::get_physical();
158
159 Ok(CpuInfo {
160 logical_cores,
161 physical_cores,
162 arch: DeviceArch::Other,
163 frequency_mhz: None,
164 })
165 }
166
167 #[cfg(target_os = "linux")]
168 fn detect_memory() -> Result<MemoryInfo, DeviceError> {
169 use std::fs;
170
171 let meminfo = fs::read_to_string("/proc/meminfo")
172 .map_err(|e| DeviceError::DetectionFailed(format!("Failed to read meminfo: {}", e)))?;
173
174 let mut total_kb = 0u64;
175 let mut available_kb = 0u64;
176
177 for line in meminfo.lines() {
178 if line.starts_with("MemTotal:") {
179 total_kb = Self::parse_meminfo_line(line)?;
180 } else if line.starts_with("MemAvailable:") {
181 available_kb = Self::parse_meminfo_line(line)?;
182 }
183 }
184
185 let total_bytes = total_kb * 1024;
186 let available_bytes = available_kb * 1024;
187 let pressure = if total_bytes > 0 {
188 1.0 - (available_bytes as f32 / total_bytes as f32)
189 } else {
190 0.0
191 };
192
193 Ok(MemoryInfo {
194 total_bytes,
195 available_bytes,
196 pressure,
197 })
198 }
199
200 #[cfg(not(target_os = "linux"))]
201 fn detect_memory() -> Result<MemoryInfo, DeviceError> {
202 Ok(MemoryInfo {
205 total_bytes: 8 * 1024 * 1024 * 1024, available_bytes: 4 * 1024 * 1024 * 1024, pressure: 0.5,
208 })
209 }
210
211 #[cfg(target_os = "linux")]
212 fn parse_meminfo_line(line: &str) -> Result<u64, DeviceError> {
213 let parts: Vec<&str> = line.split_whitespace().collect();
214 if parts.len() >= 2 {
215 parts[1].parse().map_err(|e| {
216 DeviceError::DetectionFailed(format!("Failed to parse meminfo: {}", e))
217 })
218 } else {
219 Err(DeviceError::DetectionFailed(
220 "Invalid meminfo format".to_string(),
221 ))
222 }
223 }
224
225 fn detect_gpu() -> bool {
226 #[cfg(target_os = "linux")]
228 {
229 std::path::Path::new("/dev/dri").exists()
230 || std::path::Path::new("/dev/nvidia0").exists()
231 }
232
233 #[cfg(not(target_os = "linux"))]
234 false
235 }
236
237 fn detect_fast_storage() -> bool {
238 #[cfg(target_os = "linux")]
240 {
241 if let Ok(contents) = std::fs::read_to_string("/sys/block/sda/queue/rotational") {
242 contents.trim() == "0"
243 } else {
244 false
245 }
246 }
247
248 #[cfg(not(target_os = "linux"))]
249 false
250 }
251
252 fn classify_device(cpu: &CpuInfo, memory: &MemoryInfo) -> DeviceType {
253 let total_gb = memory.total_bytes / (1024 * 1024 * 1024);
254
255 match (cpu.logical_cores, total_gb) {
256 (cores, gb) if cores >= 16 && gb >= 32 => DeviceType::Server,
257 (cores, gb) if cores >= 8 && gb >= 16 => DeviceType::Consumer,
258 (cores, gb) if cores <= 4 || gb <= 4 => DeviceType::Edge,
259 _ => DeviceType::Consumer,
260 }
261 }
262
263 pub fn optimal_batch_size(&self, model_size_bytes: u64, item_size_bytes: u64) -> usize {
265 let usable_memory = (self.memory.available_bytes as f32 * 0.8) as u64;
267
268 let memory_for_batch = usable_memory.saturating_sub(model_size_bytes);
270
271 if memory_for_batch == 0 || item_size_bytes == 0 {
272 return 1;
273 }
274
275 let batch_size = (memory_for_batch / item_size_bytes) as usize;
277
278 batch_size.clamp(1, 1024)
280 }
281
282 pub fn recommended_workers(&self) -> usize {
284 match self.device_type {
285 DeviceType::Edge => 1.max(self.cpu.logical_cores / 2),
286 DeviceType::Consumer => self.cpu.logical_cores,
287 DeviceType::Server | DeviceType::Cloud => self.cpu.logical_cores * 2,
288 DeviceType::GpuAccelerated => self.cpu.logical_cores,
289 }
290 }
291}
292
293pub struct AdaptiveBatchSizer {
295 capabilities: Arc<DeviceCapabilities>,
296 min_batch_size: usize,
297 max_batch_size: usize,
298 target_memory_utilization: f32,
299}
300
301impl AdaptiveBatchSizer {
302 pub fn new(capabilities: Arc<DeviceCapabilities>) -> Self {
304 Self {
305 capabilities,
306 min_batch_size: 1,
307 max_batch_size: 1024,
308 target_memory_utilization: 0.7, }
310 }
311
312 pub fn with_min_batch_size(mut self, size: usize) -> Self {
314 self.min_batch_size = size;
315 self
316 }
317
318 pub fn with_max_batch_size(mut self, size: usize) -> Self {
320 self.max_batch_size = size;
321 self
322 }
323
324 pub fn with_target_utilization(mut self, utilization: f32) -> Self {
326 self.target_memory_utilization = utilization.clamp(0.1, 0.9);
327 self
328 }
329
330 pub fn calculate(&self, item_size_bytes: u64, model_size_bytes: u64) -> usize {
332 let available = (self.capabilities.memory.available_bytes as f32
333 * self.target_memory_utilization) as u64;
334 let memory_for_batch = available.saturating_sub(model_size_bytes);
335
336 if memory_for_batch == 0 || item_size_bytes == 0 {
337 return self.min_batch_size;
338 }
339
340 let batch_size = (memory_for_batch / item_size_bytes) as usize;
341 batch_size.clamp(self.min_batch_size, self.max_batch_size)
342 }
343
344 pub fn adjust_for_pressure(&self, current_batch_size: usize) -> usize {
346 let pressure = self.capabilities.memory.pressure;
347
348 if pressure > 0.9 {
349 (current_batch_size / 2).max(self.min_batch_size)
351 } else if pressure > 0.7 {
352 ((current_batch_size as f32 * 0.75) as usize).max(self.min_batch_size)
354 } else if pressure < 0.3 && current_batch_size < self.max_batch_size {
355 ((current_batch_size as f32 * 1.25) as usize).min(self.max_batch_size)
357 } else {
358 current_batch_size
359 }
360 }
361}
362
363pub struct DeviceProfiler {
365 capabilities: Arc<DeviceCapabilities>,
366}
367
368impl DeviceProfiler {
369 pub fn new(capabilities: Arc<DeviceCapabilities>) -> Self {
371 Self { capabilities }
372 }
373
374 pub fn profile_memory_bandwidth(&self) -> f64 {
376 use std::time::Instant;
377
378 let size = 10 * 1024 * 1024;
380 let mut buffer = vec![0u8; size];
381
382 let start = Instant::now();
384 for (i, item) in buffer.iter_mut().enumerate().take(size) {
385 *item = (i & 0xFF) as u8;
386 }
387 let write_duration = start.elapsed();
388
389 let start = Instant::now();
391 let mut _sum: u64 = 0;
392 for &byte in &buffer {
393 _sum += byte as u64;
394 }
395 let read_duration = start.elapsed();
396
397 let write_bandwidth = (size as f64) / write_duration.as_secs_f64() / 1e9;
399 let read_bandwidth = (size as f64) / read_duration.as_secs_f64() / 1e9;
400
401 (write_bandwidth + read_bandwidth) / 2.0
403 }
404
405 pub fn profile_compute_throughput(&self) -> f64 {
407 use std::time::Instant;
408
409 let iterations = 10_000_000;
411 let mut result = 1.0f32;
412
413 let start = Instant::now();
414 for i in 0..iterations {
415 result = result * 1.0001 + (i as f32) * 0.0001;
416 }
417 let duration = start.elapsed();
418
419 let flops = (iterations * 2) as f64 / duration.as_secs_f64();
421
422 if result < 0.0 {
424 println!("Unexpected result: {}", result);
425 }
426
427 flops
428 }
429
430 pub fn performance_tier(&self) -> DevicePerformanceTier {
432 let memory_gb = self.capabilities.memory.total_bytes / (1024 * 1024 * 1024);
433 let cores = self.capabilities.cpu.logical_cores;
434
435 match (cores, memory_gb) {
436 (c, m) if c >= 32 && m >= 64 => DevicePerformanceTier::High,
437 (c, m) if c >= 8 && m >= 16 => DevicePerformanceTier::Medium,
438 _ => DevicePerformanceTier::Low,
439 }
440 }
441}
442
/// Coarse performance bucket returned by `DeviceProfiler::performance_tier`.
// Variant order is part of the serialized form in index-based serde formats; keep it stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DevicePerformanceTier {
    /// Below the `Medium` thresholds.
    Low,
    /// At least 8 logical cores and about 16 GiB of memory.
    Medium,
    /// At least 32 logical cores and about 64 GiB of memory.
    High,
}
450
#[cfg(test)]
mod tests {
    use super::*;

    const MIB: u64 = 1024 * 1024;
    const GIB: u64 = 1024 * MIB;

    /// 8 logical / 4 physical x86-64 cores at 3000 MHz, shared by the hand-built fixtures.
    fn sample_cpu() -> CpuInfo {
        CpuInfo {
            logical_cores: 8,
            physical_cores: 4,
            arch: DeviceArch::X86_64,
            frequency_mhz: Some(3000),
        }
    }

    /// A 16 GiB consumer machine with the given available memory and pressure.
    fn sample_caps(available_bytes: u64, pressure: f32) -> DeviceCapabilities {
        DeviceCapabilities {
            device_type: DeviceType::Consumer,
            cpu: sample_cpu(),
            memory: MemoryInfo {
                total_bytes: 16 * GIB,
                available_bytes,
                pressure,
            },
            has_gpu: false,
            has_fast_storage: true,
            network_bandwidth_mbps: Some(1000),
        }
    }

    /// A sizer bounded to `4..=256` over the given snapshot.
    fn bounded_sizer(caps: DeviceCapabilities) -> AdaptiveBatchSizer {
        AdaptiveBatchSizer::new(Arc::new(caps))
            .with_min_batch_size(4)
            .with_max_batch_size(256)
    }

    #[test]
    fn test_device_detection() {
        let caps = DeviceCapabilities::detect().expect("device detection should succeed");
        assert!(caps.cpu.logical_cores > 0);
        assert!(caps.memory.total_bytes > 0);
    }

    #[test]
    fn test_memory_info() {
        let mem = MemoryInfo {
            total_bytes: 8 * GIB,
            available_bytes: 4 * GIB,
            pressure: 0.5,
        };

        assert!(mem.has_capacity(GIB));
        assert!(!mem.has_capacity(5 * GIB));
        assert_eq!(mem.utilization(), 50.0);
    }

    #[test]
    fn test_cpu_info() {
        assert_eq!(sample_cpu().recommended_threads(), 7);
    }

    #[test]
    fn test_optimal_batch_size() {
        let caps = sample_caps(8 * GIB, 0.5);
        let batch_size = caps.optimal_batch_size(GIB, MIB);
        assert!((1..=1024).contains(&batch_size));
    }

    #[test]
    fn test_adaptive_batch_sizer() {
        let batch_size = bounded_sizer(sample_caps(8 * GIB, 0.5)).calculate(MIB, 512 * MIB);
        assert!((4..=256).contains(&batch_size));
    }

    #[test]
    fn test_pressure_adjustment() {
        let relaxed = bounded_sizer(sample_caps(12 * GIB, 0.25));
        assert!(relaxed.adjust_for_pressure(32) >= 32);

        let strained = bounded_sizer(sample_caps(GIB, 0.95));
        assert!(strained.adjust_for_pressure(32) < 32);
    }

    #[test]
    fn test_device_profiler() {
        let caps = Arc::new(DeviceCapabilities::detect().unwrap());
        let profiler = DeviceProfiler::new(caps);

        assert!(profiler.profile_memory_bandwidth() > 0.0);
        assert!(profiler.profile_compute_throughput() > 0.0);

        let tier = profiler.performance_tier();
        assert!(matches!(
            tier,
            DevicePerformanceTier::Low | DevicePerformanceTier::Medium | DevicePerformanceTier::High
        ));
    }
}