1use std::collections::HashSet;
7use std::sync::Arc;
8
9use anyhow::{Result, anyhow};
10use rayon::{ThreadPool, ThreadPoolBuilder};
11use ronn_core::{
12 CompiledKernel, DataType, ExecutionProvider, MemoryType, OperatorSpec, PerformanceProfile,
13 ProviderCapability, ProviderConfig, ProviderId, ResourceRequirements, SubGraph,
14 TensorAllocator,
15};
16use tracing::{debug, info, warn};
17
18use super::{
19 allocator::{create_cpu_allocator, create_numa_cpu_allocator},
20 kernels::CpuKernel,
21 simd::{SimdCapabilities, detect_simd_capabilities},
22};
23
/// CPU implementation of the `ExecutionProvider` trait.
///
/// Owns a dedicated rayon thread pool, a tensor allocator (NUMA-aware when
/// `config.numa_node >= 0`), the SIMD capabilities handed to compiled kernels,
/// and the set of operator type names it accepts.
pub struct CpuExecutionProvider {
    /// Configuration the provider was built with; `configure` updates some fields in place.
    config: CpuProviderConfig,
    /// SIMD features passed to `CpuKernel::compile` and reported by `get_capability`;
    /// `SimdCapabilities::default()` when SIMD was disabled at construction.
    simd_capabilities: SimdCapabilities,
    /// Dedicated worker pool whose threads are named `{thread_pool_name}-worker-{i}`.
    thread_pool: ThreadPool,
    /// Allocator shared out (via `Arc` clone) by `get_allocator`.
    allocator: Arc<dyn TensorAllocator>,
    /// Operator type names accepted by `supports_operation` (exact, case-sensitive match).
    supported_ops: HashSet<String>,
}
37
/// User-facing settings for building a `CpuExecutionProvider`.
#[derive(Debug, Clone)]
pub struct CpuProviderConfig {
    /// Worker-thread count for the rayon pool; `None` lets the provider pick a size.
    pub thread_count: Option<usize>,
    /// Optional memory limit in bytes.
    pub memory_limit: Option<usize>,
    /// NUMA node to allocate on; any negative value selects the plain (non-NUMA) allocator.
    pub numa_node: i32,
    /// Whether SIMD capabilities are detected and handed to compiled kernels.
    pub enable_simd: bool,
    /// Operator-fusion toggle (stored here; NOTE(review): confirm where it is consumed).
    pub enable_fusion: bool,
    /// Prefix used when naming pool threads (`{name}-worker-{i}`).
    pub thread_pool_name: String,
}

impl Default for CpuProviderConfig {
    /// Defaults: auto-sized pool, no memory limit, no NUMA pinning (`-1`),
    /// SIMD and fusion both on, and threads prefixed with `cpu-provider`.
    fn default() -> Self {
        let thread_pool_name = String::from("cpu-provider");
        Self {
            thread_count: None,
            memory_limit: None,
            numa_node: -1,
            enable_simd: true,
            enable_fusion: true,
            thread_pool_name,
        }
    }
}
67
68impl CpuExecutionProvider {
69 pub fn new() -> Result<Self> {
71 Self::with_config(CpuProviderConfig::default())
72 }
73
74 pub fn with_config(config: CpuProviderConfig) -> Result<Self> {
76 let simd_capabilities = if config.enable_simd {
77 detect_simd_capabilities()
78 } else {
79 SimdCapabilities::default() };
81
82 info!("Detected SIMD capabilities: {:?}", simd_capabilities);
83
84 let thread_count = config.thread_count.unwrap_or_else(|| {
86 let cores = num_cpus::get();
87 (cores - 1).max(1)
89 });
90
91 let thread_pool_name = config.thread_pool_name.clone();
93 let thread_pool = ThreadPoolBuilder::new()
94 .num_threads(thread_count)
95 .thread_name(move |i| format!("{}-worker-{}", thread_pool_name, i))
96 .build()
97 .map_err(|e| anyhow!("Failed to create thread pool: {}", e))?;
98
99 info!("Created CPU thread pool with {} threads", thread_count);
100
101 let allocator: Arc<dyn TensorAllocator> = if config.numa_node >= 0 {
103 create_numa_cpu_allocator(config.numa_node)
104 } else {
105 create_cpu_allocator()
106 };
107
108 let mut supported_ops = HashSet::new();
110
111 supported_ops.insert("Add".to_string());
113 supported_ops.insert("Sub".to_string());
114 supported_ops.insert("Mul".to_string());
115 supported_ops.insert("Div".to_string());
116
117 supported_ops.insert("MatMul".to_string());
119 supported_ops.insert("Gemm".to_string());
120
121 supported_ops.insert("Reshape".to_string());
123 supported_ops.insert("Transpose".to_string());
124 supported_ops.insert("Flatten".to_string());
125 supported_ops.insert("Squeeze".to_string());
126 supported_ops.insert("Unsqueeze".to_string());
127
128 supported_ops.insert("Sum".to_string());
130 supported_ops.insert("Mean".to_string());
131 supported_ops.insert("Max".to_string());
132 supported_ops.insert("Min".to_string());
133 supported_ops.insert("ArgMax".to_string());
134 supported_ops.insert("ArgMin".to_string());
135
136 supported_ops.insert("ReLU".to_string());
138 supported_ops.insert("Sigmoid".to_string());
139 supported_ops.insert("Tanh".to_string());
140 supported_ops.insert("Softmax".to_string());
141
142 supported_ops.insert("Conv".to_string());
144 supported_ops.insert("MaxPool".to_string());
145 supported_ops.insert("AveragePool".to_string());
146
147 supported_ops.insert("BatchNormalization".to_string());
149
150 supported_ops.insert("Concat".to_string());
152 supported_ops.insert("Split".to_string());
153 supported_ops.insert("Slice".to_string());
154 supported_ops.insert("Gather".to_string());
155
156 info!(
157 "CPU provider supports {} operation types",
158 supported_ops.len()
159 );
160
161 Ok(Self {
162 config,
163 simd_capabilities,
164 thread_pool,
165 allocator,
166 supported_ops,
167 })
168 }
169
170 pub fn get_config(&self) -> &CpuProviderConfig {
172 &self.config
173 }
174
175 pub fn get_simd_capabilities(&self) -> &SimdCapabilities {
177 &self.simd_capabilities
178 }
179
180 pub fn get_thread_pool(&self) -> &ThreadPool {
182 &self.thread_pool
183 }
184
185 pub fn supports_operation(&self, op_type: &str) -> bool {
187 self.supported_ops.contains(op_type)
188 }
189
190 pub fn estimate_cost(&self, op_spec: &OperatorSpec) -> f64 {
192 match op_spec.op_type.as_str() {
195 "Add" | "Sub" | "Mul" | "Div" => 1.0, "ReLU" | "Sigmoid" | "Tanh" => 2.0, "MatMul" | "Gemm" => 10.0, "Conv" => 20.0, "BatchNormalization" => 5.0, "Softmax" => 8.0, _ => 1.0, }
203 }
204}
205
206impl Default for CpuExecutionProvider {
207 fn default() -> Self {
208 Self::new().expect("Failed to create default CPU provider")
209 }
210}
211
212impl ExecutionProvider for CpuExecutionProvider {
213 fn provider_id(&self) -> ProviderId {
214 ProviderId::CPU
215 }
216
217 fn get_capability(&self) -> ProviderCapability {
218 let mut cpu_features = Vec::new();
220
221 if self.simd_capabilities.sse2 {
222 cpu_features.push("sse2".to_string());
223 }
224 if self.simd_capabilities.sse41 {
225 cpu_features.push("sse4.1".to_string());
226 }
227 if self.simd_capabilities.avx {
228 cpu_features.push("avx".to_string());
229 }
230 if self.simd_capabilities.avx2 {
231 cpu_features.push("avx2".to_string());
232 }
233 if self.simd_capabilities.avx512f {
234 cpu_features.push("avx512f".to_string());
235 }
236 if self.simd_capabilities.fma {
237 cpu_features.push("fma".to_string());
238 }
239
240 ProviderCapability {
241 supported_ops: self.supported_ops.clone(),
242 data_types: vec![
243 DataType::F32,
244 DataType::F16,
245 DataType::F64,
246 DataType::I8,
247 DataType::I32,
248 DataType::U8,
249 DataType::U32,
250 DataType::Bool,
251 ],
252 memory_types: vec![MemoryType::SystemRAM],
253 performance_profile: PerformanceProfile::CPU,
254 resource_requirements: ResourceRequirements {
255 min_memory_bytes: Some(64 * 1024 * 1024), cpu_features,
257 gpu_memory_bytes: None,
258 },
259 }
260 }
261
262 fn can_handle(&self, operators: &[OperatorSpec]) -> Vec<bool> {
263 operators
264 .iter()
265 .map(|op| self.supports_operation(&op.op_type))
266 .collect()
267 }
268
269 fn compile_subgraph(&self, subgraph: SubGraph) -> Result<Box<dyn CompiledKernel>> {
270 debug!("Compiling subgraph with {} nodes", subgraph.nodes.len());
271
272 for node in &subgraph.nodes {
274 if !self.supports_operation(&node.op_type) {
275 return Err(anyhow!(
276 "Unsupported operation '{}' in subgraph",
277 node.op_type
278 ));
279 }
280 }
281
282 let kernel = CpuKernel::compile(subgraph, self.simd_capabilities.clone())?;
284
285 debug!("Successfully compiled CPU kernel");
286
287 Ok(Box::new(kernel))
288 }
289
290 fn get_allocator(&self) -> Arc<dyn TensorAllocator> {
291 self.allocator.clone()
292 }
293
294 fn configure(&mut self, config: ProviderConfig) -> Result<()> {
295 if let Some(thread_count) = config.thread_count {
297 if thread_count != self.thread_pool.current_num_threads() {
298 warn!(
299 "Thread count change requested ({} -> {}), but requires provider recreation",
300 self.thread_pool.current_num_threads(),
301 thread_count
302 );
303 }
305 }
306
307 if let Some(memory_limit) = config.memory_limit {
309 self.config.memory_limit = Some(memory_limit);
310 info!("Updated memory limit to {} bytes", memory_limit);
311 }
312
313 for (key, value) in &config.custom_options {
315 match key.as_str() {
316 "numa_node" => {
317 if let Ok(numa_node) = value.parse::<i32>() {
318 self.config.numa_node = numa_node;
319 info!("Updated NUMA node preference to {}", numa_node);
320 }
322 }
323 "enable_simd" => {
324 if let Ok(enable_simd) = value.parse::<bool>() {
325 self.config.enable_simd = enable_simd;
326 info!("Updated SIMD enablement to {}", enable_simd);
327 }
328 }
329 "enable_fusion" => {
330 if let Ok(enable_fusion) = value.parse::<bool>() {
331 self.config.enable_fusion = enable_fusion;
332 info!("Updated fusion enablement to {}", enable_fusion);
333 }
334 }
335 _ => {
336 warn!("Unknown configuration option: {}", key);
337 }
338 }
339 }
340
341 Ok(())
342 }
343
344 fn shutdown(&self) -> Result<()> {
345 info!("Shutting down CPU execution provider");
346
347 debug!("CPU provider shutdown complete");
351
352 Ok(())
353 }
354}
355
356pub fn create_cpu_provider() -> Result<Arc<dyn ExecutionProvider>> {
358 Ok(Arc::new(CpuExecutionProvider::new()?))
359}
360
361pub fn create_cpu_provider_with_config(
363 config: CpuProviderConfig,
364) -> Result<Arc<dyn ExecutionProvider>> {
365 Ok(Arc::new(CpuExecutionProvider::with_config(config)?))
366}
367
368pub fn create_numa_cpu_provider(numa_node: i32) -> Result<Arc<dyn ExecutionProvider>> {
370 let config = CpuProviderConfig {
371 numa_node,
372 ..Default::default()
373 };
374 create_cpu_provider_with_config(config)
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use ronn_core::{AttributeValue, GraphNode};
    use std::collections::HashMap;

    /// One-input/one-output F32 operator spec with no attributes.
    fn f32_spec(op_type: &str) -> OperatorSpec {
        OperatorSpec {
            op_type: op_type.to_string(),
            input_types: vec![DataType::F32],
            output_types: vec![DataType::F32],
            attributes: HashMap::new(),
        }
    }

    #[test]
    fn test_provider_creation() -> Result<()> {
        let provider = CpuExecutionProvider::new()?;
        assert_eq!(provider.provider_id(), ProviderId::CPU);

        let capability = provider.get_capability();
        assert_eq!(capability.performance_profile, PerformanceProfile::CPU);
        assert!(!capability.supported_ops.is_empty());
        assert!(capability.data_types.contains(&DataType::F32));

        Ok(())
    }

    #[test]
    fn test_provider_with_config() -> Result<()> {
        let provider = CpuExecutionProvider::with_config(CpuProviderConfig {
            thread_count: Some(2),
            numa_node: 0,
            enable_simd: false,
            ..Default::default()
        })?;

        assert_eq!(provider.get_thread_pool().current_num_threads(), 2);
        let config = provider.get_config();
        assert_eq!(config.numa_node, 0);
        assert!(!config.enable_simd);

        Ok(())
    }

    #[test]
    fn test_operation_support() -> Result<()> {
        let provider = CpuExecutionProvider::new()?;

        for op in ["Add", "MatMul", "ReLU"] {
            assert!(provider.supports_operation(op), "expected {} to be supported", op);
        }
        assert!(!provider.supports_operation("NonexistentOp"));

        let ops = [f32_spec("Add"), f32_spec("InvalidOp")];
        assert_eq!(provider.can_handle(&ops), vec![true, false]);

        Ok(())
    }

    #[test]
    fn test_subgraph_compilation() -> Result<()> {
        let provider = CpuExecutionProvider::new()?;

        let graph_inputs = vec!["input1".to_string(), "input2".to_string()];
        let graph_outputs = vec!["output1".to_string()];
        let subgraph = SubGraph {
            nodes: vec![GraphNode {
                id: 0,
                op_type: "Add".to_string(),
                attributes: HashMap::new(),
                inputs: graph_inputs.clone(),
                outputs: graph_outputs.clone(),
                name: Some("test_add".to_string()),
            }],
            edges: vec![],
            inputs: graph_inputs,
            outputs: graph_outputs,
        };

        let kernel = provider.compile_subgraph(subgraph)?;
        // A freshly compiled kernel has not executed yet.
        assert_eq!(kernel.get_performance_stats().execution_count, 0);

        Ok(())
    }

    #[test]
    fn test_configuration_update() -> Result<()> {
        let mut provider = CpuExecutionProvider::new()?;
        let limit = 128 * 1024 * 1024;

        provider.configure(ProviderConfig {
            thread_count: Some(4),
            memory_limit: Some(limit),
            optimization_level: ronn_core::OptimizationLevel::Aggressive,
            custom_options: [("enable_simd", "false"), ("numa_node", "1")]
                .iter()
                .map(|&(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        })?;

        let config = provider.get_config();
        assert_eq!(config.memory_limit, Some(limit));
        assert!(!config.enable_simd);
        assert_eq!(config.numa_node, 1);

        Ok(())
    }

    #[test]
    fn test_cost_estimation() -> Result<()> {
        let provider = CpuExecutionProvider::new()?;

        let add_cost = provider.estimate_cost(&f32_spec("Add"));
        let conv_cost = provider.estimate_cost(&f32_spec("Conv"));
        assert!(conv_cost > add_cost);

        Ok(())
    }

    #[test]
    fn test_provider_shutdown() -> Result<()> {
        let provider = CpuExecutionProvider::new()?;
        provider.shutdown()
    }

    #[test]
    fn test_allocator() -> Result<()> {
        let provider = CpuExecutionProvider::new()?;
        let allocator = provider.get_allocator();

        let buffer = allocator.allocate(&[100], DataType::F32)?;
        // 100 f32 elements * 4 bytes each.
        assert_eq!(buffer.size, 400);
        assert_eq!(buffer.memory_type, MemoryType::SystemRAM);

        allocator.deallocate(buffer)?;

        Ok(())
    }

    #[test]
    fn test_factory_functions() -> Result<()> {
        let providers = [
            create_cpu_provider()?,
            create_cpu_provider_with_config(CpuProviderConfig {
                thread_count: Some(1),
                ..Default::default()
            })?,
            create_numa_cpu_provider(0)?,
        ];

        for provider in &providers {
            assert_eq!(provider.provider_id(), ProviderId::CPU);
        }

        Ok(())
    }
}