use crate::error::{RusTorchError, RusTorchResult};
use std::env;
/// How aggressively [`PlatformOptimizer`] sizes its thread pool.
///
/// Switching levels via `PlatformOptimizer::set_optimization_level`
/// recomputes the recommended thread-pool size.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimizationLevel {
    /// Single-threaded execution.
    None,
    /// Roughly half of the logical cores.
    Basic,
    /// Roughly three quarters of the logical cores (the default level).
    Standard,
    /// Every logical core.
    Aggressive,
}
/// Snapshot of host properties gathered once when a `PlatformOptimizer` is built.
#[derive(Debug, Clone)]
pub struct PlatformFeatures {
    /// Operating-system name as reported by `std::env::consts::OS`.
    pub os: String,
    /// CPU architecture as reported by `std::env::consts::ARCH`.
    pub arch: String,
    /// Number of logical CPU cores.
    pub cpu_cores: usize,
    /// Total physical memory in bytes (may be a fallback estimate).
    pub total_memory: usize,
    /// Assumed cache-line size in bytes.
    pub cache_line_size: usize,
    /// Virtual-memory page size in bytes.
    pub page_size: usize,
    /// Whether transparent huge pages appear to be available (Linux only).
    pub supports_huge_pages: bool,
    /// Whether the target architecture supports software prefetch hints.
    pub supports_prefetch: bool,
}
/// Host-aware tuning helper: detects platform features and recommends a
/// thread-pool size for the selected [`OptimizationLevel`].
pub struct PlatformOptimizer {
    /// Host properties detected when the optimizer was constructed.
    features: PlatformFeatures,
    /// Currently selected optimization level.
    optimization_level: OptimizationLevel,
    /// Recommended worker-thread count for the current level.
    thread_pool_size: usize,
}
impl PlatformOptimizer {
    /// Total-memory figure (8 GiB) used when the OS query is unavailable or
    /// fails. `saturating_mul` keeps the constant valid on 32-bit targets,
    /// where 8 GiB does not fit in `usize`.
    const FALLBACK_TOTAL_MEMORY: usize = 8_usize
        .saturating_mul(1024)
        .saturating_mul(1024)
        .saturating_mul(1024);

    /// Page size assumed when it cannot be queried from the OS.
    const FALLBACK_PAGE_SIZE: usize = 4096;

    /// Cache-line size assumed on every architecture. Must be a power of two,
    /// because `align_memory` rounds with a bit mask.
    /// NOTE(review): some aarch64 parts use 128-byte lines; confirm before
    /// relying on this for false-sharing avoidance.
    const DEFAULT_CACHE_LINE_SIZE: usize = 64;

    /// Creates an optimizer for the current host at [`OptimizationLevel::Standard`].
    ///
    /// Platform features are probed once here, and the thread-pool size is
    /// derived from the detected core count.
    pub fn new() -> Self {
        let features = Self::detect_features();
        let thread_pool_size = Self::calculate_optimal_threads(&features);
        PlatformOptimizer {
            features,
            optimization_level: OptimizationLevel::Standard,
            thread_pool_size,
        }
    }

    /// Probes the running host for every value stored in [`PlatformFeatures`].
    fn detect_features() -> PlatformFeatures {
        PlatformFeatures {
            os: env::consts::OS.to_string(),
            arch: env::consts::ARCH.to_string(),
            cpu_cores: num_cpus::get(),
            total_memory: Self::get_total_memory(),
            cache_line_size: Self::detect_cache_line_size(),
            page_size: Self::get_page_size(),
            supports_huge_pages: Self::check_huge_pages_support(),
            supports_prefetch: Self::check_prefetch_support(),
        }
    }

    /// Returns total physical memory in bytes.
    ///
    /// Falls back to `FALLBACK_TOTAL_MEMORY` when the OS query fails or the
    /// platform has no query implemented.
    fn get_total_memory() -> usize {
        #[cfg(target_os = "linux")]
        {
            // SAFETY: `sysconf` takes no pointers; it only reads a system
            // configuration value and reports failure as -1.
            let (pages, page_size) = unsafe {
                (
                    libc::sysconf(libc::_SC_PHYS_PAGES),
                    libc::sysconf(libc::_SC_PAGESIZE),
                )
            };
            match (usize::try_from(pages), usize::try_from(page_size)) {
                (Ok(pages), Ok(page_size)) if pages > 0 && page_size > 0 => {
                    // Saturate rather than wrap on 32-bit hosts with more
                    // memory than `usize` can express.
                    pages.saturating_mul(page_size)
                }
                _ => Self::FALLBACK_TOTAL_MEMORY,
            }
        }
        #[cfg(target_os = "macos")]
        {
            // `hw.memsize` reports physical memory as a 64-bit byte count.
            let mut mib = [libc::CTL_HW, libc::HW_MEMSIZE];
            let mut total_mem: u64 = 0;
            let mut size: libc::size_t = std::mem::size_of::<u64>();
            // SAFETY: `mib` is a valid 2-element MIB, `total_mem` is a writable
            // u64 whose byte size is passed via `size`, and no new value is
            // being set (null pointer, zero length).
            let rc = unsafe {
                libc::sysctl(
                    mib.as_mut_ptr(),
                    mib.len() as libc::c_uint,
                    &mut total_mem as *mut u64 as *mut libc::c_void,
                    &mut size,
                    std::ptr::null_mut(),
                    0,
                )
            };
            if rc == 0 && total_mem > 0 {
                usize::try_from(total_mem).unwrap_or(usize::MAX)
            } else {
                Self::FALLBACK_TOTAL_MEMORY
            }
        }
        #[cfg(not(any(target_os = "linux", target_os = "macos")))]
        {
            Self::FALLBACK_TOTAL_MEMORY
        }
    }

    /// Returns the assumed cache-line size in bytes (not probed at runtime).
    fn detect_cache_line_size() -> usize {
        Self::DEFAULT_CACHE_LINE_SIZE
    }

    /// Returns the virtual-memory page size in bytes.
    ///
    /// Queries `sysconf(_SC_PAGESIZE)` on Unix and falls back to
    /// `FALLBACK_PAGE_SIZE` on failure or on non-Unix targets.
    fn get_page_size() -> usize {
        #[cfg(unix)]
        {
            // SAFETY: `sysconf` takes no pointers; it only reads a system
            // configuration value and reports failure as -1.
            let raw = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
            // A -1 error must not be cast to `usize::MAX`.
            usize::try_from(raw)
                .ok()
                .filter(|&bytes| bytes > 0)
                .unwrap_or(Self::FALLBACK_PAGE_SIZE)
        }
        #[cfg(not(unix))]
        {
            Self::FALLBACK_PAGE_SIZE
        }
    }

    /// Reports whether transparent huge pages are enabled on Linux.
    ///
    /// The sysfs file lists the modes with the active one in brackets, for
    /// example `always [madvise] never`. Returns `false` when the active mode
    /// is `[never]`, when the file cannot be read, and on every non-Linux OS.
    fn check_huge_pages_support() -> bool {
        #[cfg(target_os = "linux")]
        {
            std::fs::read_to_string("/sys/kernel/mm/transparent_hugepage/enabled")
                .map(|modes| !modes.contains("[never]"))
                .unwrap_or(false)
        }
        #[cfg(not(target_os = "linux"))]
        {
            false
        }
    }

    /// Reports whether the compilation target architecture has prefetch
    /// instructions (x86, x86_64, aarch64).
    ///
    /// Note that `prefetch_read` only issues an instruction on x86_64.
    fn check_prefetch_support() -> bool {
        cfg!(any(
            target_arch = "x86",
            target_arch = "x86_64",
            target_arch = "aarch64"
        ))
    }

    /// Recommends three quarters of the logical cores, and never fewer than
    /// one thread.
    fn calculate_optimal_threads(features: &PlatformFeatures) -> usize {
        // Integer form of `cores * 0.75`, truncated toward zero.
        (features.cpu_cores.saturating_mul(3) / 4).max(1)
    }

    /// Rounds `size` up to the next multiple of the cache-line size.
    ///
    /// # Panics
    ///
    /// Panics if the rounded-up value does not fit in `usize`.
    pub fn align_memory(&self, size: usize) -> usize {
        let alignment = self.features.cache_line_size;
        debug_assert!(
            alignment.is_power_of_two(),
            "cache line size must be a power of two"
        );
        size.checked_add(alignment - 1)
            .map(|padded| padded & !(alignment - 1))
            .expect("align_memory: aligned size overflows usize")
    }

    /// Returns a `Vec<T>` of `T::default()` values whose length is padded so
    /// that its byte size covers a whole number of cache lines.
    ///
    /// The result always holds at least `count` elements. Only the length is
    /// padded: like any `Vec`, the buffer is aligned to `align_of::<T>()`,
    /// not to the cache-line size. For zero-sized `T`, exactly `count`
    /// elements are returned.
    ///
    /// # Panics
    ///
    /// Panics if `count * size_of::<T>()`, or its cache-line-aligned size,
    /// overflows `usize`.
    pub fn allocate_aligned<T>(&self, count: usize) -> RusTorchResult<Vec<T>>
    where
        T: Default + Clone,
    {
        let elem_size = std::mem::size_of::<T>();
        if elem_size == 0 {
            // Zero-sized types occupy no bytes to pad; returning early also
            // avoids a division by zero below.
            return Ok(vec![T::default(); count]);
        }
        let size = count
            .checked_mul(elem_size)
            .expect("allocate_aligned: requested size overflows usize");
        let aligned_count = self.align_memory(size) / elem_size;
        Ok(vec![T::default(); aligned_count])
    }

    /// Hints the CPU to load the cache line containing `ptr` into all cache
    /// levels ahead of a read.
    ///
    /// On x86_64 this issues an SSE prefetch with the `T0` hint. On 32-bit
    /// x86 it is currently a no-op.
    ///
    /// # Safety
    ///
    /// A prefetch is only a hint and never dereferences `ptr`. Callers should
    /// still pass a pointer into memory they own, to keep the intent clear.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    pub unsafe fn prefetch_read<T>(ptr: *const T) {
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
            // SAFETY: SSE is part of the x86_64 baseline, and prefetch
            // instructions do not fault on any address.
            unsafe { _mm_prefetch(ptr as *const i8, _MM_HINT_T0) };
        }
        #[cfg(target_arch = "x86")]
        {
            let _ = ptr;
        }
    }

    /// No-op fallback for architectures without a prefetch implementation.
    ///
    /// # Safety
    ///
    /// Nothing is done with `_ptr`; the function is `unsafe` only to keep the
    /// same signature as the x86 variant.
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    pub unsafe fn prefetch_read<T>(_ptr: *const T) {}

    /// Placeholder for pinning a thread to a CPU.
    ///
    /// Currently a no-op on every platform: `_thread_id` is ignored and
    /// `Ok(())` is always returned.
    pub fn set_thread_affinity(&self, _thread_id: usize) -> RusTorchResult<()> {
        Ok(())
    }

    /// Returns the host properties detected at construction time.
    pub fn features(&self) -> &PlatformFeatures {
        &self.features
    }

    /// Returns the currently selected optimization level.
    pub fn optimization_level(&self) -> OptimizationLevel {
        self.optimization_level
    }

    /// Selects `level` and recomputes the recommended thread-pool size.
    ///
    /// - `None`: 1 thread
    /// - `Basic`: half the logical cores (at least 1)
    /// - `Standard`: three quarters of the logical cores (at least 1)
    /// - `Aggressive`: every logical core
    pub fn set_optimization_level(&mut self, level: OptimizationLevel) {
        self.optimization_level = level;
        self.thread_pool_size = match level {
            OptimizationLevel::None => 1,
            // `.max(1)` keeps single-core hosts from getting a 0-thread pool.
            OptimizationLevel::Basic => (self.features.cpu_cores / 2).max(1),
            OptimizationLevel::Standard => Self::calculate_optimal_threads(&self.features),
            OptimizationLevel::Aggressive => self.features.cpu_cores,
        };
    }

    /// Returns the recommended worker-thread count for the current level.
    pub fn thread_pool_size(&self) -> usize {
        self.thread_pool_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Probes the host, prints what was found, and sanity-checks the values.
    #[test]
    fn test_platform_detection() {
        let opt = PlatformOptimizer::new();
        let detected = opt.features();
        let gib = 1024 * 1024 * 1024;

        println!("Platform features:");
        println!(" OS: {}", detected.os);
        println!(" Architecture: {}", detected.arch);
        println!(" CPU cores: {}", detected.cpu_cores);
        println!(" Total memory: {} GB", detected.total_memory / gib);
        println!(" Cache line size: {} bytes", detected.cache_line_size);
        println!(" Page size: {} bytes", detected.page_size);
        println!(" Huge pages: {}", detected.supports_huge_pages);
        println!(" Prefetch: {}", detected.supports_prefetch);

        assert!(detected.cpu_cores > 0);
        assert!(detected.cache_line_size > 0);
        assert!(detected.page_size > 0);
    }

    /// Rounding up must never shrink the size and must land on a cache-line boundary.
    #[test]
    fn test_memory_alignment() {
        let opt = PlatformOptimizer::new();
        let requested = 100;
        let rounded = opt.align_memory(requested);
        let line = opt.features().cache_line_size;

        assert!(rounded >= requested);
        assert_eq!(rounded % line, 0);
    }

    /// The extreme levels map to one thread and to every core respectively.
    #[test]
    fn test_optimization_levels() {
        let mut opt = PlatformOptimizer::new();
        let cores = opt.features().cpu_cores;

        opt.set_optimization_level(OptimizationLevel::None);
        assert_eq!(opt.thread_pool_size(), 1);

        opt.set_optimization_level(OptimizationLevel::Aggressive);
        assert_eq!(opt.thread_pool_size(), cores);
    }
}