ipfrs_tensorlogic/
device.rs

1//! Heterogeneous Device Support
2//!
3//! This module provides device capability detection and adaptive resource management
4//! for running tensor operations across diverse hardware (edge to cloud).
5
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11pub enum DeviceError {
12    #[error("Failed to detect device capabilities: {0}")]
13    DetectionFailed(String),
14
15    #[error("Unsupported device type: {0}")]
16    UnsupportedDevice(String),
17
18    #[error("Insufficient resources: {0}")]
19    InsufficientResources(String),
20}
21
22/// Device type classification
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24pub enum DeviceType {
25    /// Edge device (IoT, mobile)
26    Edge,
27    /// Consumer device (laptop, desktop)
28    Consumer,
29    /// Server-class device
30    Server,
31    /// GPU-accelerated device
32    GpuAccelerated,
33    /// Cloud instance
34    Cloud,
35}
36
37/// Device architecture
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
39pub enum DeviceArch {
40    X86_64,
41    Aarch64,
42    Arm,
43    Riscv,
44    Other,
45}
46
47/// Memory tier information
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct MemoryInfo {
50    /// Total system memory in bytes
51    pub total_bytes: u64,
52    /// Available memory in bytes
53    pub available_bytes: u64,
54    /// Memory pressure (0.0 = no pressure, 1.0 = critical)
55    pub pressure: f32,
56}
57
58impl MemoryInfo {
59    /// Check if device has sufficient memory for operation
60    pub fn has_capacity(&self, required_bytes: u64) -> bool {
61        self.available_bytes >= required_bytes
62    }
63
64    /// Get memory utilization percentage
65    pub fn utilization(&self) -> f32 {
66        if self.total_bytes == 0 {
67            return 0.0;
68        }
69        ((self.total_bytes - self.available_bytes) as f32 / self.total_bytes as f32) * 100.0
70    }
71}
72
73/// CPU information
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct CpuInfo {
76    /// Number of logical cores
77    pub logical_cores: usize,
78    /// Number of physical cores
79    pub physical_cores: usize,
80    /// CPU architecture
81    pub arch: DeviceArch,
82    /// CPU frequency in MHz (if available)
83    pub frequency_mhz: Option<u32>,
84}
85
86impl CpuInfo {
87    /// Get recommended thread count for parallel operations
88    pub fn recommended_threads(&self) -> usize {
89        // Use 80% of logical cores to leave room for system
90        (self.logical_cores as f32 * 0.8).ceil() as usize
91    }
92}
93
94/// Device capabilities
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct DeviceCapabilities {
97    /// Device type
98    pub device_type: DeviceType,
99    /// CPU information
100    pub cpu: CpuInfo,
101    /// Memory information
102    pub memory: MemoryInfo,
103    /// Has GPU acceleration
104    pub has_gpu: bool,
105    /// Has fast storage (SSD)
106    pub has_fast_storage: bool,
107    /// Network bandwidth estimate (Mbps)
108    pub network_bandwidth_mbps: Option<u32>,
109}
110
111impl DeviceCapabilities {
112    /// Detect device capabilities
113    pub fn detect() -> Result<Self, DeviceError> {
114        let cpu = Self::detect_cpu()?;
115        let memory = Self::detect_memory()?;
116        let device_type = Self::classify_device(&cpu, &memory);
117
118        Ok(DeviceCapabilities {
119            device_type,
120            cpu,
121            memory,
122            has_gpu: Self::detect_gpu(),
123            has_fast_storage: Self::detect_fast_storage(),
124            network_bandwidth_mbps: None, // Would need network probing
125        })
126    }
127
128    #[cfg(target_arch = "x86_64")]
129    fn detect_cpu() -> Result<CpuInfo, DeviceError> {
130        let logical_cores = num_cpus::get();
131        let physical_cores = num_cpus::get_physical();
132
133        Ok(CpuInfo {
134            logical_cores,
135            physical_cores,
136            arch: DeviceArch::X86_64,
137            frequency_mhz: None,
138        })
139    }
140
141    #[cfg(target_arch = "aarch64")]
142    fn detect_cpu() -> Result<CpuInfo, DeviceError> {
143        let logical_cores = num_cpus::get();
144        let physical_cores = num_cpus::get_physical();
145
146        Ok(CpuInfo {
147            logical_cores,
148            physical_cores,
149            arch: DeviceArch::Aarch64,
150            frequency_mhz: None,
151        })
152    }
153
154    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
155    fn detect_cpu() -> Result<CpuInfo, DeviceError> {
156        let logical_cores = num_cpus::get();
157        let physical_cores = num_cpus::get_physical();
158
159        Ok(CpuInfo {
160            logical_cores,
161            physical_cores,
162            arch: DeviceArch::Other,
163            frequency_mhz: None,
164        })
165    }
166
167    #[cfg(target_os = "linux")]
168    fn detect_memory() -> Result<MemoryInfo, DeviceError> {
169        use std::fs;
170
171        let meminfo = fs::read_to_string("/proc/meminfo")
172            .map_err(|e| DeviceError::DetectionFailed(format!("Failed to read meminfo: {}", e)))?;
173
174        let mut total_kb = 0u64;
175        let mut available_kb = 0u64;
176
177        for line in meminfo.lines() {
178            if line.starts_with("MemTotal:") {
179                total_kb = Self::parse_meminfo_line(line)?;
180            } else if line.starts_with("MemAvailable:") {
181                available_kb = Self::parse_meminfo_line(line)?;
182            }
183        }
184
185        let total_bytes = total_kb * 1024;
186        let available_bytes = available_kb * 1024;
187        let pressure = if total_bytes > 0 {
188            1.0 - (available_bytes as f32 / total_bytes as f32)
189        } else {
190            0.0
191        };
192
193        Ok(MemoryInfo {
194            total_bytes,
195            available_bytes,
196            pressure,
197        })
198    }
199
200    #[cfg(not(target_os = "linux"))]
201    fn detect_memory() -> Result<MemoryInfo, DeviceError> {
202        // Fallback for non-Linux systems
203        // Use sysinfo crate or platform-specific APIs
204        Ok(MemoryInfo {
205            total_bytes: 8 * 1024 * 1024 * 1024,     // Default 8GB
206            available_bytes: 4 * 1024 * 1024 * 1024, // Default 4GB available
207            pressure: 0.5,
208        })
209    }
210
211    #[cfg(target_os = "linux")]
212    fn parse_meminfo_line(line: &str) -> Result<u64, DeviceError> {
213        let parts: Vec<&str> = line.split_whitespace().collect();
214        if parts.len() >= 2 {
215            parts[1].parse().map_err(|e| {
216                DeviceError::DetectionFailed(format!("Failed to parse meminfo: {}", e))
217            })
218        } else {
219            Err(DeviceError::DetectionFailed(
220                "Invalid meminfo format".to_string(),
221            ))
222        }
223    }
224
225    fn detect_gpu() -> bool {
226        // Simple heuristic: check for common GPU driver files
227        #[cfg(target_os = "linux")]
228        {
229            std::path::Path::new("/dev/dri").exists()
230                || std::path::Path::new("/dev/nvidia0").exists()
231        }
232
233        #[cfg(not(target_os = "linux"))]
234        false
235    }
236
237    fn detect_fast_storage() -> bool {
238        // Heuristic: assume SSD if rotational is 0 on Linux
239        #[cfg(target_os = "linux")]
240        {
241            if let Ok(contents) = std::fs::read_to_string("/sys/block/sda/queue/rotational") {
242                contents.trim() == "0"
243            } else {
244                false
245            }
246        }
247
248        #[cfg(not(target_os = "linux"))]
249        false
250    }
251
252    fn classify_device(cpu: &CpuInfo, memory: &MemoryInfo) -> DeviceType {
253        let total_gb = memory.total_bytes / (1024 * 1024 * 1024);
254
255        match (cpu.logical_cores, total_gb) {
256            (cores, gb) if cores >= 16 && gb >= 32 => DeviceType::Server,
257            (cores, gb) if cores >= 8 && gb >= 16 => DeviceType::Consumer,
258            (cores, gb) if cores <= 4 || gb <= 4 => DeviceType::Edge,
259            _ => DeviceType::Consumer,
260        }
261    }
262
263    /// Calculate optimal batch size based on available memory and model size
264    pub fn optimal_batch_size(&self, model_size_bytes: u64, item_size_bytes: u64) -> usize {
265        // Reserve 20% of available memory for overhead
266        let usable_memory = (self.memory.available_bytes as f32 * 0.8) as u64;
267
268        // Account for model size
269        let memory_for_batch = usable_memory.saturating_sub(model_size_bytes);
270
271        if memory_for_batch == 0 || item_size_bytes == 0 {
272            return 1;
273        }
274
275        // Calculate batch size
276        let batch_size = (memory_for_batch / item_size_bytes) as usize;
277
278        // Clamp to reasonable range
279        batch_size.clamp(1, 1024)
280    }
281
282    /// Get recommended worker count for parallel processing
283    pub fn recommended_workers(&self) -> usize {
284        match self.device_type {
285            DeviceType::Edge => 1.max(self.cpu.logical_cores / 2),
286            DeviceType::Consumer => self.cpu.logical_cores,
287            DeviceType::Server | DeviceType::Cloud => self.cpu.logical_cores * 2,
288            DeviceType::GpuAccelerated => self.cpu.logical_cores,
289        }
290    }
291}
292
293/// Adaptive batch size calculator
294pub struct AdaptiveBatchSizer {
295    capabilities: Arc<DeviceCapabilities>,
296    min_batch_size: usize,
297    max_batch_size: usize,
298    target_memory_utilization: f32,
299}
300
301impl AdaptiveBatchSizer {
302    /// Create a new adaptive batch sizer
303    pub fn new(capabilities: Arc<DeviceCapabilities>) -> Self {
304        Self {
305            capabilities,
306            min_batch_size: 1,
307            max_batch_size: 1024,
308            target_memory_utilization: 0.7, // Target 70% memory utilization
309        }
310    }
311
312    /// Set minimum batch size
313    pub fn with_min_batch_size(mut self, size: usize) -> Self {
314        self.min_batch_size = size;
315        self
316    }
317
318    /// Set maximum batch size
319    pub fn with_max_batch_size(mut self, size: usize) -> Self {
320        self.max_batch_size = size;
321        self
322    }
323
324    /// Set target memory utilization (0.0-1.0)
325    pub fn with_target_utilization(mut self, utilization: f32) -> Self {
326        self.target_memory_utilization = utilization.clamp(0.1, 0.9);
327        self
328    }
329
330    /// Calculate adaptive batch size
331    pub fn calculate(&self, item_size_bytes: u64, model_size_bytes: u64) -> usize {
332        let available = (self.capabilities.memory.available_bytes as f32
333            * self.target_memory_utilization) as u64;
334        let memory_for_batch = available.saturating_sub(model_size_bytes);
335
336        if memory_for_batch == 0 || item_size_bytes == 0 {
337            return self.min_batch_size;
338        }
339
340        let batch_size = (memory_for_batch / item_size_bytes) as usize;
341        batch_size.clamp(self.min_batch_size, self.max_batch_size)
342    }
343
344    /// Adjust batch size based on current memory pressure
345    pub fn adjust_for_pressure(&self, current_batch_size: usize) -> usize {
346        let pressure = self.capabilities.memory.pressure;
347
348        if pressure > 0.9 {
349            // Critical pressure: halve batch size
350            (current_batch_size / 2).max(self.min_batch_size)
351        } else if pressure > 0.7 {
352            // High pressure: reduce by 25%
353            ((current_batch_size as f32 * 0.75) as usize).max(self.min_batch_size)
354        } else if pressure < 0.3 && current_batch_size < self.max_batch_size {
355            // Low pressure: increase by 25%
356            ((current_batch_size as f32 * 1.25) as usize).min(self.max_batch_size)
357        } else {
358            current_batch_size
359        }
360    }
361}
362
363/// Device profiler for performance optimization
364pub struct DeviceProfiler {
365    capabilities: Arc<DeviceCapabilities>,
366}
367
368impl DeviceProfiler {
369    /// Create a new device profiler
370    pub fn new(capabilities: Arc<DeviceCapabilities>) -> Self {
371        Self { capabilities }
372    }
373
374    /// Profile memory bandwidth (GB/s)
375    pub fn profile_memory_bandwidth(&self) -> f64 {
376        use std::time::Instant;
377
378        // Allocate test buffer (10 MB)
379        let size = 10 * 1024 * 1024;
380        let mut buffer = vec![0u8; size];
381
382        // Sequential write test
383        let start = Instant::now();
384        for (i, item) in buffer.iter_mut().enumerate().take(size) {
385            *item = (i & 0xFF) as u8;
386        }
387        let write_duration = start.elapsed();
388
389        // Sequential read test
390        let start = Instant::now();
391        let mut _sum: u64 = 0;
392        for &byte in &buffer {
393            _sum += byte as u64;
394        }
395        let read_duration = start.elapsed();
396
397        // Calculate bandwidth (GB/s)
398        let write_bandwidth = (size as f64) / write_duration.as_secs_f64() / 1e9;
399        let read_bandwidth = (size as f64) / read_duration.as_secs_f64() / 1e9;
400
401        // Return average
402        (write_bandwidth + read_bandwidth) / 2.0
403    }
404
405    /// Profile compute throughput (FLOPS)
406    pub fn profile_compute_throughput(&self) -> f64 {
407        use std::time::Instant;
408
409        // Simple FP32 FLOPS test
410        let iterations = 10_000_000;
411        let mut result = 1.0f32;
412
413        let start = Instant::now();
414        for i in 0..iterations {
415            result = result * 1.0001 + (i as f32) * 0.0001;
416        }
417        let duration = start.elapsed();
418
419        // Calculate FLOPS (2 operations per iteration: multiply and add)
420        let flops = (iterations * 2) as f64 / duration.as_secs_f64();
421
422        // Prevent optimization from removing computation
423        if result < 0.0 {
424            println!("Unexpected result: {}", result);
425        }
426
427        flops
428    }
429
430    /// Get device performance tier
431    pub fn performance_tier(&self) -> DevicePerformanceTier {
432        let memory_gb = self.capabilities.memory.total_bytes / (1024 * 1024 * 1024);
433        let cores = self.capabilities.cpu.logical_cores;
434
435        match (cores, memory_gb) {
436            (c, m) if c >= 32 && m >= 64 => DevicePerformanceTier::High,
437            (c, m) if c >= 8 && m >= 16 => DevicePerformanceTier::Medium,
438            _ => DevicePerformanceTier::Low,
439        }
440    }
441}
442
443/// Device performance tier
444#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
445pub enum DevicePerformanceTier {
446    Low,
447    Medium,
448    High,
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    const GIB: u64 = 1024 * 1024 * 1024;
    const MIB: u64 = 1024 * 1024;

    /// Builds the fixture shared by the tests below: an 8-core consumer
    /// machine with 16 GiB of memory, varying only in how much memory is
    /// available and how much pressure is reported.
    fn consumer_caps(available_bytes: u64, pressure: f32) -> DeviceCapabilities {
        DeviceCapabilities {
            device_type: DeviceType::Consumer,
            cpu: CpuInfo {
                logical_cores: 8,
                physical_cores: 4,
                arch: DeviceArch::X86_64,
                frequency_mhz: Some(3000),
            },
            memory: MemoryInfo {
                total_bytes: 16 * GIB,
                available_bytes,
                pressure,
            },
            has_gpu: false,
            has_fast_storage: true,
            network_bandwidth_mbps: Some(1000),
        }
    }

    /// Builds a sizer bounded to 4..=256 over the shared fixture.
    fn bounded_sizer(available_bytes: u64, pressure: f32) -> AdaptiveBatchSizer {
        AdaptiveBatchSizer::new(Arc::new(consumer_caps(available_bytes, pressure)))
            .with_min_batch_size(4)
            .with_max_batch_size(256)
    }

    #[test]
    fn test_device_detection() {
        let caps = DeviceCapabilities::detect().expect("detection should succeed on the test host");

        assert!(caps.cpu.logical_cores > 0);
        assert!(caps.memory.total_bytes > 0);
    }

    #[test]
    fn test_memory_info() {
        let mem = MemoryInfo {
            total_bytes: 8 * GIB,
            available_bytes: 4 * GIB,
            pressure: 0.5,
        };

        assert!(mem.has_capacity(GIB));
        assert!(!mem.has_capacity(5 * GIB));
        assert_eq!(mem.utilization(), 50.0);
    }

    #[test]
    fn test_cpu_info() {
        let cpu = consumer_caps(8 * GIB, 0.5).cpu;

        assert_eq!(cpu.recommended_threads(), 7); // 80% of 8 = 6.4, ceil to 7
    }

    #[test]
    fn test_optimal_batch_size() {
        let caps = consumer_caps(8 * GIB, 0.5);

        // 80% of 8 GiB minus a 1 GiB model leaves room for thousands of
        // 1 MiB items, so the upper clamp of 1024 applies.
        assert_eq!(caps.optimal_batch_size(GIB, MIB), 1024);
        // A model larger than the usable memory leaves nothing for the batch.
        assert_eq!(caps.optimal_batch_size(64 * GIB, MIB), 1);
    }

    #[test]
    fn test_adaptive_batch_sizer() {
        let sizer = bounded_sizer(8 * GIB, 0.5);

        // The memory budget allows far more than 256 items, so the maximum wins.
        assert_eq!(sizer.calculate(MIB, 512 * MIB), 256);
        // No budget left after the model: fall back to the minimum.
        assert_eq!(sizer.calculate(MIB, 64 * GIB), 4);
    }

    #[test]
    fn test_pressure_adjustment() {
        // Low pressure grows the batch by 25%.
        let relaxed = bounded_sizer(12 * GIB, 0.25);
        assert_eq!(relaxed.adjust_for_pressure(32), 40);

        // Critical pressure halves the batch.
        let stressed = bounded_sizer(GIB, 0.95);
        assert_eq!(stressed.adjust_for_pressure(32), 16);
    }

    #[test]
    fn test_device_profiler() {
        let caps = Arc::new(DeviceCapabilities::detect().unwrap());
        let profiler = DeviceProfiler::new(caps);

        let bandwidth = profiler.profile_memory_bandwidth();
        assert!(bandwidth > 0.0);

        let throughput = profiler.profile_compute_throughput();
        assert!(throughput > 0.0);

        // The tier depends on the host; only check that it can be computed
        // and is stable across calls.
        assert_eq!(profiler.performance_tier(), profiler.performance_tier());
    }
}