Skip to main content

ronn_providers/cpu/
provider.rs

1//! CPU execution provider implementation.
2//!
//! This module provides a CPU execution provider with SIMD-aware kernel compilation,
//! a dedicated multi-threaded worker pool, and optional NUMA-aware memory allocation.
5
6use std::collections::HashSet;
7use std::sync::Arc;
8
9use anyhow::{Result, anyhow};
10use rayon::{ThreadPool, ThreadPoolBuilder};
11use ronn_core::{
12    CompiledKernel, DataType, ExecutionProvider, MemoryType, OperatorSpec, PerformanceProfile,
13    ProviderCapability, ProviderConfig, ProviderId, ResourceRequirements, SubGraph,
14    TensorAllocator,
15};
16use tracing::{debug, info, warn};
17
18use super::{
19    allocator::{create_cpu_allocator, create_numa_cpu_allocator},
20    kernels::CpuKernel,
21    simd::{SimdCapabilities, detect_simd_capabilities},
22};
23
/// CPU execution provider with SIMD optimizations and multi-threading.
///
/// Compiles subgraphs made only of supported operator types into CPU kernels and
/// hands out a shared tensor allocator. Build it with `CpuExecutionProvider::new`
/// or `CpuExecutionProvider::with_config`.
pub struct CpuExecutionProvider {
    /// Settings the provider was built with; some fields may later be overwritten
    /// through `ExecutionProvider::configure`.
    config: CpuProviderConfig,
    /// SIMD feature flags handed to every compiled kernel. Detected at construction
    /// when `config.enable_simd` is true, otherwise `SimdCapabilities::default()`.
    simd_capabilities: SimdCapabilities,
    /// Dedicated rayon worker pool sized at construction; exposed via `get_thread_pool`.
    thread_pool: ThreadPool,
    /// Shared tensor allocator: NUMA-bound when `config.numa_node >= 0`, otherwise
    /// the plain CPU allocator.
    allocator: Arc<dyn TensorAllocator>,
    /// Operator type names accepted by `supports_operation` and `compile_subgraph`.
    supported_ops: HashSet<String>,
}
37
/// Tunable settings for the CPU execution provider.
///
/// Start from `CpuProviderConfig::default()` and override individual fields with
/// struct-update syntax, e.g. `CpuProviderConfig { numa_node: 0, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct CpuProviderConfig {
    /// Worker thread count; `None` lets the provider pick one automatically.
    pub thread_count: Option<usize>,
    /// Memory budget in bytes; `None` means unlimited.
    pub memory_limit: Option<usize>,
    /// Preferred NUMA node for allocations; a negative value (the default is `-1`)
    /// means "no preference".
    pub numa_node: i32,
    /// Whether SIMD capability detection and SIMD kernels are enabled.
    pub enable_simd: bool,
    /// Whether operator fusion is enabled.
    pub enable_fusion: bool,
    /// Name prefix given to the worker threads (useful when debugging).
    pub thread_pool_name: String,
}

impl Default for CpuProviderConfig {
    /// Auto-sized threads, no memory cap, no NUMA binding, SIMD and fusion on,
    /// and worker threads named after `"cpu-provider"`.
    fn default() -> Self {
        let thread_pool_name = String::from("cpu-provider");
        CpuProviderConfig {
            thread_pool_name,
            enable_fusion: true,
            enable_simd: true,
            numa_node: -1,
            memory_limit: None,
            thread_count: None,
        }
    }
}
67
68impl CpuExecutionProvider {
69    /// Create a new CPU execution provider with default configuration.
70    pub fn new() -> Result<Self> {
71        Self::with_config(CpuProviderConfig::default())
72    }
73
74    /// Create a CPU execution provider with custom configuration.
75    pub fn with_config(config: CpuProviderConfig) -> Result<Self> {
76        let simd_capabilities = if config.enable_simd {
77            detect_simd_capabilities()
78        } else {
79            SimdCapabilities::default() // Disabled SIMD
80        };
81
82        info!("Detected SIMD capabilities: {:?}", simd_capabilities);
83
84        // Determine thread count
85        let thread_count = config.thread_count.unwrap_or_else(|| {
86            let cores = num_cpus::get();
87            // Leave one core for system tasks
88            (cores - 1).max(1)
89        });
90
91        // Create thread pool
92        let thread_pool_name = config.thread_pool_name.clone();
93        let thread_pool = ThreadPoolBuilder::new()
94            .num_threads(thread_count)
95            .thread_name(move |i| format!("{}-worker-{}", thread_pool_name, i))
96            .build()
97            .map_err(|e| anyhow!("Failed to create thread pool: {}", e))?;
98
99        info!("Created CPU thread pool with {} threads", thread_count);
100
101        // Create allocator (NUMA-aware if specified)
102        let allocator: Arc<dyn TensorAllocator> = if config.numa_node >= 0 {
103            create_numa_cpu_allocator(config.numa_node)
104        } else {
105            create_cpu_allocator()
106        };
107
108        // Define supported operations
109        let mut supported_ops = HashSet::new();
110
111        // Basic arithmetic operations
112        supported_ops.insert("Add".to_string());
113        supported_ops.insert("Sub".to_string());
114        supported_ops.insert("Mul".to_string());
115        supported_ops.insert("Div".to_string());
116
117        // Matrix operations
118        supported_ops.insert("MatMul".to_string());
119        supported_ops.insert("Gemm".to_string());
120
121        // Shape operations
122        supported_ops.insert("Reshape".to_string());
123        supported_ops.insert("Transpose".to_string());
124        supported_ops.insert("Flatten".to_string());
125        supported_ops.insert("Squeeze".to_string());
126        supported_ops.insert("Unsqueeze".to_string());
127
128        // Reduction operations
129        supported_ops.insert("Sum".to_string());
130        supported_ops.insert("Mean".to_string());
131        supported_ops.insert("Max".to_string());
132        supported_ops.insert("Min".to_string());
133        supported_ops.insert("ArgMax".to_string());
134        supported_ops.insert("ArgMin".to_string());
135
136        // Activation functions
137        supported_ops.insert("ReLU".to_string());
138        supported_ops.insert("Sigmoid".to_string());
139        supported_ops.insert("Tanh".to_string());
140        supported_ops.insert("Softmax".to_string());
141
142        // Convolution operations (basic support)
143        supported_ops.insert("Conv".to_string());
144        supported_ops.insert("MaxPool".to_string());
145        supported_ops.insert("AveragePool".to_string());
146
147        // Normalization
148        supported_ops.insert("BatchNormalization".to_string());
149
150        // Utility operations
151        supported_ops.insert("Concat".to_string());
152        supported_ops.insert("Split".to_string());
153        supported_ops.insert("Slice".to_string());
154        supported_ops.insert("Gather".to_string());
155
156        info!(
157            "CPU provider supports {} operation types",
158            supported_ops.len()
159        );
160
161        Ok(Self {
162            config,
163            simd_capabilities,
164            thread_pool,
165            allocator,
166            supported_ops,
167        })
168    }
169
170    /// Get the current configuration.
171    pub fn get_config(&self) -> &CpuProviderConfig {
172        &self.config
173    }
174
175    /// Get SIMD capabilities.
176    pub fn get_simd_capabilities(&self) -> &SimdCapabilities {
177        &self.simd_capabilities
178    }
179
180    /// Get the thread pool.
181    pub fn get_thread_pool(&self) -> &ThreadPool {
182        &self.thread_pool
183    }
184
185    /// Check if an operation type is supported.
186    pub fn supports_operation(&self, op_type: &str) -> bool {
187        self.supported_ops.contains(op_type)
188    }
189
190    /// Estimate execution cost for an operation (for provider selection).
191    pub fn estimate_cost(&self, op_spec: &OperatorSpec) -> f64 {
192        // Simple cost estimation based on operation type
193        // In practice, this would consider input sizes, CPU load, etc.
194        match op_spec.op_type.as_str() {
195            "Add" | "Sub" | "Mul" | "Div" => 1.0, // Very fast
196            "ReLU" | "Sigmoid" | "Tanh" => 2.0,   // Fast
197            "MatMul" | "Gemm" => 10.0,            // Medium cost
198            "Conv" => 20.0,                       // Higher cost
199            "BatchNormalization" => 5.0,          // Medium-low cost
200            "Softmax" => 8.0,                     // Medium cost
201            _ => 1.0,                             // Default cost
202        }
203    }
204}
205
206impl Default for CpuExecutionProvider {
207    fn default() -> Self {
208        Self::new().expect("Failed to create default CPU provider")
209    }
210}
211
212impl ExecutionProvider for CpuExecutionProvider {
213    fn provider_id(&self) -> ProviderId {
214        ProviderId::CPU
215    }
216
217    fn get_capability(&self) -> ProviderCapability {
218        // Build CPU features list
219        let mut cpu_features = Vec::new();
220
221        if self.simd_capabilities.sse2 {
222            cpu_features.push("sse2".to_string());
223        }
224        if self.simd_capabilities.sse41 {
225            cpu_features.push("sse4.1".to_string());
226        }
227        if self.simd_capabilities.avx {
228            cpu_features.push("avx".to_string());
229        }
230        if self.simd_capabilities.avx2 {
231            cpu_features.push("avx2".to_string());
232        }
233        if self.simd_capabilities.avx512f {
234            cpu_features.push("avx512f".to_string());
235        }
236        if self.simd_capabilities.fma {
237            cpu_features.push("fma".to_string());
238        }
239
240        ProviderCapability {
241            supported_ops: self.supported_ops.clone(),
242            data_types: vec![
243                DataType::F32,
244                DataType::F16,
245                DataType::F64,
246                DataType::I8,
247                DataType::I32,
248                DataType::U8,
249                DataType::U32,
250                DataType::Bool,
251            ],
252            memory_types: vec![MemoryType::SystemRAM],
253            performance_profile: PerformanceProfile::CPU,
254            resource_requirements: ResourceRequirements {
255                min_memory_bytes: Some(64 * 1024 * 1024), // 64MB minimum
256                cpu_features,
257                gpu_memory_bytes: None,
258            },
259        }
260    }
261
262    fn can_handle(&self, operators: &[OperatorSpec]) -> Vec<bool> {
263        operators
264            .iter()
265            .map(|op| self.supports_operation(&op.op_type))
266            .collect()
267    }
268
269    fn compile_subgraph(&self, subgraph: SubGraph) -> Result<Box<dyn CompiledKernel>> {
270        debug!("Compiling subgraph with {} nodes", subgraph.nodes.len());
271
272        // Validate that all operations are supported
273        for node in &subgraph.nodes {
274            if !self.supports_operation(&node.op_type) {
275                return Err(anyhow!(
276                    "Unsupported operation '{}' in subgraph",
277                    node.op_type
278                ));
279            }
280        }
281
282        // Compile the kernel
283        let kernel = CpuKernel::compile(subgraph, self.simd_capabilities.clone())?;
284
285        debug!("Successfully compiled CPU kernel");
286
287        Ok(Box::new(kernel))
288    }
289
290    fn get_allocator(&self) -> Arc<dyn TensorAllocator> {
291        self.allocator.clone()
292    }
293
294    fn configure(&mut self, config: ProviderConfig) -> Result<()> {
295        // Update thread count if specified
296        if let Some(thread_count) = config.thread_count {
297            if thread_count != self.thread_pool.current_num_threads() {
298                warn!(
299                    "Thread count change requested ({} -> {}), but requires provider recreation",
300                    self.thread_pool.current_num_threads(),
301                    thread_count
302                );
303                // Would need to recreate the thread pool in a real implementation
304            }
305        }
306
307        // Update memory limit
308        if let Some(memory_limit) = config.memory_limit {
309            self.config.memory_limit = Some(memory_limit);
310            info!("Updated memory limit to {} bytes", memory_limit);
311        }
312
313        // Handle custom options
314        for (key, value) in &config.custom_options {
315            match key.as_str() {
316                "numa_node" => {
317                    if let Ok(numa_node) = value.parse::<i32>() {
318                        self.config.numa_node = numa_node;
319                        info!("Updated NUMA node preference to {}", numa_node);
320                        // Would need to recreate allocator in a real implementation
321                    }
322                }
323                "enable_simd" => {
324                    if let Ok(enable_simd) = value.parse::<bool>() {
325                        self.config.enable_simd = enable_simd;
326                        info!("Updated SIMD enablement to {}", enable_simd);
327                    }
328                }
329                "enable_fusion" => {
330                    if let Ok(enable_fusion) = value.parse::<bool>() {
331                        self.config.enable_fusion = enable_fusion;
332                        info!("Updated fusion enablement to {}", enable_fusion);
333                    }
334                }
335                _ => {
336                    warn!("Unknown configuration option: {}", key);
337                }
338            }
339        }
340
341        Ok(())
342    }
343
344    fn shutdown(&self) -> Result<()> {
345        info!("Shutting down CPU execution provider");
346
347        // The thread pool will be dropped automatically
348        // Memory allocator cleanup is handled by Drop traits
349
350        debug!("CPU provider shutdown complete");
351
352        Ok(())
353    }
354}
355
356/// Create a default CPU execution provider.
357pub fn create_cpu_provider() -> Result<Arc<dyn ExecutionProvider>> {
358    Ok(Arc::new(CpuExecutionProvider::new()?))
359}
360
361/// Create a CPU execution provider with custom configuration.
362pub fn create_cpu_provider_with_config(
363    config: CpuProviderConfig,
364) -> Result<Arc<dyn ExecutionProvider>> {
365    Ok(Arc::new(CpuExecutionProvider::with_config(config)?))
366}
367
368/// Create a NUMA-aware CPU execution provider.
369pub fn create_numa_cpu_provider(numa_node: i32) -> Result<Arc<dyn ExecutionProvider>> {
370    let config = CpuProviderConfig {
371        numa_node,
372        ..Default::default()
373    };
374    create_cpu_provider_with_config(config)
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use ronn_core::GraphNode;
    use std::collections::HashMap;

    /// Single-input, single-output `F32` operator spec with no attributes.
    fn f32_spec(op_type: &str) -> OperatorSpec {
        OperatorSpec {
            op_type: op_type.to_string(),
            input_types: vec![DataType::F32],
            output_types: vec![DataType::F32],
            attributes: HashMap::new(),
        }
    }

    /// Converts string slices into owned `String`s.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_provider_creation() -> Result<()> {
        let cpu = CpuExecutionProvider::new()?;
        assert_eq!(cpu.provider_id(), ProviderId::CPU);

        let caps = cpu.get_capability();
        assert_eq!(caps.performance_profile, PerformanceProfile::CPU);
        assert!(!caps.supported_ops.is_empty());
        assert!(caps.data_types.contains(&DataType::F32));
        Ok(())
    }

    #[test]
    fn test_provider_with_config() -> Result<()> {
        let cpu = CpuExecutionProvider::with_config(CpuProviderConfig {
            thread_count: Some(2),
            numa_node: 0,
            enable_simd: false,
            ..Default::default()
        })?;

        let cfg = cpu.get_config();
        assert_eq!(cpu.get_thread_pool().current_num_threads(), 2);
        assert_eq!(cfg.numa_node, 0);
        assert!(!cfg.enable_simd);
        Ok(())
    }

    #[test]
    fn test_operation_support() -> Result<()> {
        let cpu = CpuExecutionProvider::new()?;

        for op in ["Add", "MatMul", "ReLU"] {
            assert!(cpu.supports_operation(op), "expected {} to be supported", op);
        }
        assert!(!cpu.supports_operation("NonexistentOp"));

        // can_handle answers per operator, preserving input order.
        let specs = vec![f32_spec("Add"), f32_spec("InvalidOp")];
        assert_eq!(cpu.can_handle(&specs), vec![true, false]);
        Ok(())
    }

    #[test]
    fn test_subgraph_compilation() -> Result<()> {
        let cpu = CpuExecutionProvider::new()?;

        let add_node = GraphNode {
            id: 0,
            op_type: "Add".to_string(),
            attributes: HashMap::new(),
            inputs: strings(&["input1", "input2"]),
            outputs: strings(&["output1"]),
            name: Some("test_add".to_string()),
        };
        let subgraph = SubGraph {
            nodes: vec![add_node],
            edges: vec![],
            inputs: strings(&["input1", "input2"]),
            outputs: strings(&["output1"]),
        };

        let kernel = cpu.compile_subgraph(subgraph)?;

        // A freshly compiled kernel has not been executed yet.
        assert_eq!(kernel.get_performance_stats().execution_count, 0);
        Ok(())
    }

    #[test]
    fn test_configuration_update() -> Result<()> {
        let mut cpu = CpuExecutionProvider::new()?;
        let limit = 128 * 1024 * 1024; // 128MB

        cpu.configure(ProviderConfig {
            thread_count: Some(4),
            memory_limit: Some(limit),
            optimization_level: ronn_core::OptimizationLevel::Aggressive,
            custom_options: [("enable_simd", "false"), ("numa_node", "1")]
                .iter()
                .map(|&(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        })?;

        let cfg = cpu.get_config();
        assert_eq!(cfg.memory_limit, Some(limit));
        assert!(!cfg.enable_simd);
        assert_eq!(cfg.numa_node, 1);
        Ok(())
    }

    #[test]
    fn test_cost_estimation() -> Result<()> {
        let cpu = CpuExecutionProvider::new()?;

        let add_cost = cpu.estimate_cost(&f32_spec("Add"));
        let conv_cost = cpu.estimate_cost(&f32_spec("Conv"));

        assert!(conv_cost > add_cost, "Conv should cost more than Add");
        Ok(())
    }

    #[test]
    fn test_provider_shutdown() -> Result<()> {
        // Shutdown must succeed on a freshly built provider.
        CpuExecutionProvider::new()?.shutdown()
    }

    #[test]
    fn test_allocator() -> Result<()> {
        let cpu = CpuExecutionProvider::new()?;
        let alloc = cpu.get_allocator();

        let buffer = alloc.allocate(&[100], DataType::F32)?;
        assert_eq!(buffer.size, 400); // 100 elements * 4 bytes
        assert_eq!(buffer.memory_type, MemoryType::SystemRAM);

        alloc.deallocate(buffer)?;
        Ok(())
    }

    #[test]
    fn test_factory_functions() -> Result<()> {
        let providers = [
            create_cpu_provider()?,
            create_cpu_provider_with_config(CpuProviderConfig {
                thread_count: Some(1),
                ..Default::default()
            })?,
            create_numa_cpu_provider(0)?,
        ];

        for provider in &providers {
            assert_eq!(provider.provider_id(), ProviderId::CPU);
        }
        Ok(())
    }
}