use std::collections::HashSet;
use std::sync::Arc;
use anyhow::{Result, anyhow};
use rayon::{ThreadPool, ThreadPoolBuilder};
use ronn_core::{
CompiledKernel, DataType, ExecutionProvider, MemoryType, OperatorSpec, PerformanceProfile,
ProviderCapability, ProviderConfig, ProviderId, ResourceRequirements, SubGraph,
TensorAllocator,
};
use tracing::{debug, info, warn};
use super::{
allocator::{create_cpu_allocator, create_numa_cpu_allocator},
kernels::CpuKernel,
simd::{SimdCapabilities, detect_simd_capabilities},
};
/// CPU implementation of `ExecutionProvider`.
///
/// Owns a rayon thread pool, a tensor allocator, the SIMD feature set chosen at
/// construction, and the set of operator type names it accepts.
pub struct CpuExecutionProvider {
    /// Effective configuration; `configure` may update some fields after construction.
    config: CpuProviderConfig,
    /// Passed to `CpuKernel::compile` for every compiled subgraph.
    simd_capabilities: SimdCapabilities,
    /// Sized once at construction; `configure` cannot resize it.
    thread_pool: ThreadPool,
    /// Shared with callers via `get_allocator`.
    allocator: Arc<dyn TensorAllocator>,
    /// Operator `op_type` names accepted by `supports_operation`.
    supported_ops: HashSet<String>,
}
/// Construction-time settings for `CpuExecutionProvider`.
#[derive(Debug, Clone)]
pub struct CpuProviderConfig {
    /// Worker thread count; `None` lets the provider choose a default.
    pub thread_count: Option<usize>,
    /// Memory limit in bytes, if any.
    pub memory_limit: Option<usize>,
    /// NUMA node to allocate on; a negative value means "no NUMA preference".
    pub numa_node: i32,
    /// Whether SIMD capability detection is performed.
    pub enable_simd: bool,
    /// Fusion toggle (NOTE(review): not consulted by the provider code itself — confirm
    /// where it is used).
    pub enable_fusion: bool,
    /// Prefix used when naming worker threads.
    pub thread_pool_name: String,
}

impl Default for CpuProviderConfig {
    /// Automatic thread count, no memory limit, no NUMA preference,
    /// SIMD and fusion enabled, pool named `"cpu-provider"`.
    fn default() -> Self {
        let thread_pool_name = String::from("cpu-provider");
        Self {
            thread_count: None,
            memory_limit: None,
            numa_node: -1,
            enable_simd: true,
            enable_fusion: true,
            thread_pool_name,
        }
    }
}
impl CpuExecutionProvider {
    /// Creates a provider using `CpuProviderConfig::default()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the worker thread pool cannot be built.
    pub fn new() -> Result<Self> {
        Self::with_config(CpuProviderConfig::default())
    }

    /// Creates a provider from an explicit configuration.
    ///
    /// - SIMD detection runs only when `config.enable_simd` is set; otherwise
    ///   `SimdCapabilities::default()` is used.
    /// - `config.thread_count == None` sizes the pool to one fewer than the number
    ///   of logical cores (at least 1). `Some(0)` is forwarded to rayon, which then
    ///   chooses the thread count automatically.
    /// - A non-negative `config.numa_node` selects the NUMA-aware allocator.
    ///
    /// # Errors
    ///
    /// Returns an error if the rayon thread pool cannot be built.
    pub fn with_config(config: CpuProviderConfig) -> Result<Self> {
        let simd_capabilities = if config.enable_simd {
            detect_simd_capabilities()
        } else {
            SimdCapabilities::default()
        };
        info!("Detected SIMD capabilities: {:?}", simd_capabilities);

        let thread_count = config
            .thread_count
            .unwrap_or_else(Self::default_thread_count);
        // The name closure must be 'static, so it owns its own copy of the prefix.
        let thread_pool_name = config.thread_pool_name.clone();
        let thread_pool = ThreadPoolBuilder::new()
            .num_threads(thread_count)
            .thread_name(move |i| format!("{}-worker-{}", thread_pool_name, i))
            .build()
            .map_err(|e| anyhow!("Failed to create thread pool: {}", e))?;
        // Log the size rayon actually chose: `num_threads(0)` means "automatic",
        // so the requested count can differ from the real one.
        info!(
            "Created CPU thread pool with {} threads",
            thread_pool.current_num_threads()
        );

        let allocator: Arc<dyn TensorAllocator> = if config.numa_node >= 0 {
            create_numa_cpu_allocator(config.numa_node)
        } else {
            create_cpu_allocator()
        };

        let supported_ops = Self::default_supported_ops();
        info!(
            "CPU provider supports {} operation types",
            supported_ops.len()
        );

        Ok(Self {
            config,
            simd_capabilities,
            thread_pool,
            allocator,
            supported_ops,
        })
    }

    /// Default worker count: all logical cores but one, never fewer than one.
    fn default_thread_count() -> usize {
        num_cpus::get().saturating_sub(1).max(1)
    }

    /// Builds the set of operator type names this provider accepts.
    fn default_supported_ops() -> HashSet<String> {
        const OPS: &[&str] = &[
            // Elementwise arithmetic
            "Add",
            "Sub",
            "Mul",
            "Div",
            // Linear algebra
            "MatMul",
            "Gemm",
            // Shape manipulation
            "Reshape",
            "Transpose",
            "Flatten",
            "Squeeze",
            "Unsqueeze",
            // Reductions
            "Sum",
            "Mean",
            "Max",
            "Min",
            "ArgMax",
            "ArgMin",
            // Activations
            "ReLU",
            "Sigmoid",
            "Tanh",
            "Softmax",
            // Convolution / pooling / normalization
            "Conv",
            "MaxPool",
            "AveragePool",
            "BatchNormalization",
            // Tensor assembly / indexing
            "Concat",
            "Split",
            "Slice",
            "Gather",
        ];
        OPS.iter().copied().map(String::from).collect()
    }

    /// Returns the configuration this provider is currently using.
    pub fn get_config(&self) -> &CpuProviderConfig {
        &self.config
    }

    /// Returns the SIMD capabilities handed to compiled kernels.
    pub fn get_simd_capabilities(&self) -> &SimdCapabilities {
        &self.simd_capabilities
    }

    /// Returns the provider's dedicated rayon thread pool.
    pub fn get_thread_pool(&self) -> &ThreadPool {
        &self.thread_pool
    }

    /// Returns `true` if `op_type` is one of the supported operator names.
    pub fn supports_operation(&self, op_type: &str) -> bool {
        self.supported_ops.contains(op_type)
    }

    /// Rough relative cost of running `op_spec` on this provider.
    ///
    /// Values are unitless weights with elementwise arithmetic as 1.0. Operator
    /// types not listed here also default to 1.0.
    pub fn estimate_cost(&self, op_spec: &OperatorSpec) -> f64 {
        match op_spec.op_type.as_str() {
            "Add" | "Sub" | "Mul" | "Div" => 1.0,
            "ReLU" | "Sigmoid" | "Tanh" => 2.0,
            "MatMul" | "Gemm" => 10.0,
            "Conv" => 20.0,
            "BatchNormalization" => 5.0,
            "Softmax" => 8.0,
            _ => 1.0,
        }
    }
}
impl Default for CpuExecutionProvider {
    /// Builds a provider with the default configuration.
    ///
    /// # Panics
    ///
    /// Panics if `CpuExecutionProvider::new` returns an error (for example, when
    /// the worker thread pool cannot be built).
    fn default() -> Self {
        match Self::new() {
            Ok(provider) => provider,
            Err(err) => panic!("Failed to create default CPU provider: {:?}", err),
        }
    }
}
impl ExecutionProvider for CpuExecutionProvider {
fn provider_id(&self) -> ProviderId {
ProviderId::CPU
}
fn get_capability(&self) -> ProviderCapability {
let mut cpu_features = Vec::new();
if self.simd_capabilities.sse2 {
cpu_features.push("sse2".to_string());
}
if self.simd_capabilities.sse41 {
cpu_features.push("sse4.1".to_string());
}
if self.simd_capabilities.avx {
cpu_features.push("avx".to_string());
}
if self.simd_capabilities.avx2 {
cpu_features.push("avx2".to_string());
}
if self.simd_capabilities.avx512f {
cpu_features.push("avx512f".to_string());
}
if self.simd_capabilities.fma {
cpu_features.push("fma".to_string());
}
ProviderCapability {
supported_ops: self.supported_ops.clone(),
data_types: vec![
DataType::F32,
DataType::F16,
DataType::F64,
DataType::I8,
DataType::I32,
DataType::U8,
DataType::U32,
DataType::Bool,
],
memory_types: vec![MemoryType::SystemRAM],
performance_profile: PerformanceProfile::CPU,
resource_requirements: ResourceRequirements {
min_memory_bytes: Some(64 * 1024 * 1024), cpu_features,
gpu_memory_bytes: None,
},
}
}
fn can_handle(&self, operators: &[OperatorSpec]) -> Vec<bool> {
operators
.iter()
.map(|op| self.supports_operation(&op.op_type))
.collect()
}
fn compile_subgraph(&self, subgraph: SubGraph) -> Result<Box<dyn CompiledKernel>> {
debug!("Compiling subgraph with {} nodes", subgraph.nodes.len());
for node in &subgraph.nodes {
if !self.supports_operation(&node.op_type) {
return Err(anyhow!(
"Unsupported operation '{}' in subgraph",
node.op_type
));
}
}
let kernel = CpuKernel::compile(subgraph, self.simd_capabilities.clone())?;
debug!("Successfully compiled CPU kernel");
Ok(Box::new(kernel))
}
fn get_allocator(&self) -> Arc<dyn TensorAllocator> {
self.allocator.clone()
}
fn configure(&mut self, config: ProviderConfig) -> Result<()> {
if let Some(thread_count) = config.thread_count {
if thread_count != self.thread_pool.current_num_threads() {
warn!(
"Thread count change requested ({} -> {}), but requires provider recreation",
self.thread_pool.current_num_threads(),
thread_count
);
}
}
if let Some(memory_limit) = config.memory_limit {
self.config.memory_limit = Some(memory_limit);
info!("Updated memory limit to {} bytes", memory_limit);
}
for (key, value) in &config.custom_options {
match key.as_str() {
"numa_node" => {
if let Ok(numa_node) = value.parse::<i32>() {
self.config.numa_node = numa_node;
info!("Updated NUMA node preference to {}", numa_node);
}
}
"enable_simd" => {
if let Ok(enable_simd) = value.parse::<bool>() {
self.config.enable_simd = enable_simd;
info!("Updated SIMD enablement to {}", enable_simd);
}
}
"enable_fusion" => {
if let Ok(enable_fusion) = value.parse::<bool>() {
self.config.enable_fusion = enable_fusion;
info!("Updated fusion enablement to {}", enable_fusion);
}
}
_ => {
warn!("Unknown configuration option: {}", key);
}
}
}
Ok(())
}
fn shutdown(&self) -> Result<()> {
info!("Shutting down CPU execution provider");
debug!("CPU provider shutdown complete");
Ok(())
}
}
pub fn create_cpu_provider() -> Result<Arc<dyn ExecutionProvider>> {
Ok(Arc::new(CpuExecutionProvider::new()?))
}
pub fn create_cpu_provider_with_config(
config: CpuProviderConfig,
) -> Result<Arc<dyn ExecutionProvider>> {
Ok(Arc::new(CpuExecutionProvider::with_config(config)?))
}
pub fn create_numa_cpu_provider(numa_node: i32) -> Result<Arc<dyn ExecutionProvider>> {
let config = CpuProviderConfig {
numa_node,
..Default::default()
};
create_cpu_provider_with_config(config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ronn_core::{AttributeValue, GraphNode};
    use std::collections::HashMap;

    /// Builds an attribute-free operator spec with one F32 input and one F32 output.
    fn f32_op(op_type: &str) -> OperatorSpec {
        OperatorSpec {
            op_type: op_type.to_string(),
            input_types: vec![DataType::F32],
            output_types: vec![DataType::F32],
            attributes: HashMap::new(),
        }
    }

    #[test]
    fn test_provider_creation() -> Result<()> {
        let cpu = CpuExecutionProvider::new()?;
        assert_eq!(cpu.provider_id(), ProviderId::CPU);

        let caps = cpu.get_capability();
        assert_eq!(caps.performance_profile, PerformanceProfile::CPU);
        assert!(!caps.supported_ops.is_empty());
        assert!(caps.data_types.contains(&DataType::F32));
        Ok(())
    }

    #[test]
    fn test_provider_with_config() -> Result<()> {
        let cpu = CpuExecutionProvider::with_config(CpuProviderConfig {
            thread_count: Some(2),
            numa_node: 0,
            enable_simd: false,
            ..Default::default()
        })?;

        let cfg = cpu.get_config();
        assert_eq!(cpu.get_thread_pool().current_num_threads(), 2);
        assert_eq!(cfg.numa_node, 0);
        assert!(!cfg.enable_simd);
        Ok(())
    }

    #[test]
    fn test_operation_support() -> Result<()> {
        let cpu = CpuExecutionProvider::new()?;
        for op in ["Add", "MatMul", "ReLU"] {
            assert!(cpu.supports_operation(op));
        }
        assert!(!cpu.supports_operation("NonexistentOp"));

        let verdicts = cpu.can_handle(&[f32_op("Add"), f32_op("InvalidOp")]);
        assert_eq!(verdicts, vec![true, false]);
        Ok(())
    }

    #[test]
    fn test_subgraph_compilation() -> Result<()> {
        let cpu = CpuExecutionProvider::new()?;
        let in_names = vec!["input1".to_string(), "input2".to_string()];
        let out_names = vec!["output1".to_string()];

        let add_node = GraphNode {
            id: 0,
            op_type: "Add".to_string(),
            attributes: HashMap::new(),
            inputs: in_names.clone(),
            outputs: out_names.clone(),
            name: Some("test_add".to_string()),
        };
        let graph = SubGraph {
            nodes: vec![add_node],
            edges: vec![],
            inputs: in_names,
            outputs: out_names,
        };

        let compiled = cpu.compile_subgraph(graph)?;
        assert_eq!(compiled.get_performance_stats().execution_count, 0);
        Ok(())
    }

    #[test]
    fn test_configuration_update() -> Result<()> {
        let mut cpu = CpuExecutionProvider::new()?;
        let mut custom = HashMap::new();
        custom.insert("enable_simd".to_string(), "false".to_string());
        custom.insert("numa_node".to_string(), "1".to_string());

        cpu.configure(ProviderConfig {
            thread_count: Some(4),
            memory_limit: Some(128 * 1024 * 1024),
            optimization_level: ronn_core::OptimizationLevel::Aggressive,
            custom_options: custom,
        })?;

        let cfg = cpu.get_config();
        assert_eq!(cfg.memory_limit, Some(128 * 1024 * 1024));
        assert!(!cfg.enable_simd);
        assert_eq!(cfg.numa_node, 1);
        Ok(())
    }

    #[test]
    fn test_cost_estimation() -> Result<()> {
        let cpu = CpuExecutionProvider::new()?;
        let add_cost = cpu.estimate_cost(&f32_op("Add"));
        let conv_cost = cpu.estimate_cost(&f32_op("Conv"));
        assert!(conv_cost > add_cost);
        Ok(())
    }

    #[test]
    fn test_provider_shutdown() -> Result<()> {
        CpuExecutionProvider::new()?.shutdown()?;
        Ok(())
    }

    #[test]
    fn test_allocator() -> Result<()> {
        let alloc = CpuExecutionProvider::new()?.get_allocator();
        let buf = alloc.allocate(&[100], DataType::F32)?;
        // 100 f32 elements * 4 bytes each.
        assert_eq!(buf.size, 400);
        assert_eq!(buf.memory_type, MemoryType::SystemRAM);
        alloc.deallocate(buf)?;
        Ok(())
    }

    #[test]
    fn test_factory_functions() -> Result<()> {
        let providers = [
            create_cpu_provider()?,
            create_cpu_provider_with_config(CpuProviderConfig {
                thread_count: Some(1),
                ..Default::default()
            })?,
            create_numa_cpu_provider(0)?,
        ];
        for provider in &providers {
            assert_eq!(provider.provider_id(), ProviderId::CPU);
        }
        Ok(())
    }
}