use crate::error::{Result, TorshError};
use crate::shape::Shape;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(feature = "std")]
use std::sync::Arc;
/// Tuning knobs deciding when shape operations are offloaded to the GPU.
///
/// Each threshold is compared against an operation-specific size metric
/// (element count, dimension count, or batch length); work below the
/// threshold stays on the CPU.
#[derive(Debug, Clone)]
pub struct AcceleratorConfig {
/// Minimum element count (larger operand) before a broadcast is tried on the GPU.
pub broadcast_threshold: usize,
/// Minimum element count before a reshape is tried on the GPU.
pub reshape_threshold: usize,
/// Minimum number of dimensions before stride computation is tried on the GPU.
pub stride_dimension_threshold: usize,
/// Minimum batch length before batch shape validation is tried on the GPU.
pub batch_validation_threshold: usize,
/// Auto-tuning switch (not read anywhere in this file — presumably consumed elsewhere; TODO confirm).
pub enable_auto_tuning: bool,
/// Target GPU device id (not read anywhere in this file — presumably consumed elsewhere; TODO confirm).
pub device_id: usize,
}
impl Default for AcceleratorConfig {
fn default() -> Self {
Self {
broadcast_threshold: 10_000_000,
reshape_threshold: 5_000_000,
stride_dimension_threshold: 10,
batch_validation_threshold: 100,
enable_auto_tuning: false,
device_id: 0,
}
}
}
impl AcceleratorConfig {
    /// Creates a configuration equal to [`AcceleratorConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the element count above which broadcasts are offloaded.
    pub fn with_broadcast_threshold(self, threshold: usize) -> Self {
        Self {
            broadcast_threshold: threshold,
            ..self
        }
    }

    /// Sets the element count above which reshapes are offloaded.
    pub fn with_reshape_threshold(self, threshold: usize) -> Self {
        Self {
            reshape_threshold: threshold,
            ..self
        }
    }

    /// Sets the dimension count above which stride computation is offloaded.
    pub fn with_stride_dimension_threshold(self, threshold: usize) -> Self {
        Self {
            stride_dimension_threshold: threshold,
            ..self
        }
    }

    /// Sets the batch length above which batch validation is offloaded.
    pub fn with_batch_validation_threshold(self, threshold: usize) -> Self {
        Self {
            batch_validation_threshold: threshold,
            ..self
        }
    }

    /// Enables or disables auto-tuning.
    pub fn with_auto_tuning(self, enable: bool) -> Self {
        Self {
            enable_auto_tuning: enable,
            ..self
        }
    }

    /// Selects the target GPU device.
    pub fn with_device_id(self, device_id: usize) -> Self {
        Self {
            device_id,
            ..self
        }
    }

    /// Aggressive preset for very large tensors: low thresholds, auto-tuning on.
    pub fn for_very_large_tensors() -> Self {
        Self {
            broadcast_threshold: 1_000_000,
            reshape_threshold: 500_000,
            stride_dimension_threshold: 8,
            batch_validation_threshold: 50,
            enable_auto_tuning: true,
            ..Self::default()
        }
    }

    /// Preset tuned for high-dimensional shapes: earlier stride offload.
    pub fn for_high_dimensional() -> Self {
        Self {
            broadcast_threshold: 5_000_000,
            reshape_threshold: 2_000_000,
            stride_dimension_threshold: 6,
            ..Self::default()
        }
    }

    /// Conservative preset: offload only for very large workloads.
    pub fn conservative() -> Self {
        Self {
            broadcast_threshold: 50_000_000,
            reshape_threshold: 25_000_000,
            stride_dimension_threshold: 15,
            batch_validation_threshold: 200,
            ..Self::default()
        }
    }
}
/// Counters describing how operations were dispatched between CPU and GPU.
#[derive(Debug, Clone, Default)]
pub struct AcceleratorStats {
/// Total operations dispatched (GPU + CPU).
pub total_operations: usize,
/// Operations dispatched to the GPU path.
pub gpu_operations: usize,
/// Operations that ran on the CPU path (including GPU fallbacks).
pub cpu_operations: usize,
/// Accumulated GPU time in microseconds (only read in this file; presumably populated by callers — TODO confirm).
pub gpu_time_us: u64,
/// Accumulated CPU time in microseconds (only read in this file; presumably populated by callers — TODO confirm).
pub cpu_time_us: u64,
/// Number of GPU attempts that failed and fell back to the CPU.
pub gpu_fallback_count: usize,
}
impl AcceleratorStats {
pub fn new() -> Self {
Self::default()
}
pub fn gpu_usage_percentage(&self) -> f64 {
if self.total_operations == 0 {
0.0
} else {
(self.gpu_operations as f64 / self.total_operations as f64) * 100.0
}
}
pub fn avg_gpu_time_us(&self) -> f64 {
if self.gpu_operations == 0 {
0.0
} else {
self.gpu_time_us as f64 / self.gpu_operations as f64
}
}
pub fn avg_cpu_time_us(&self) -> f64 {
if self.cpu_operations == 0 {
0.0
} else {
self.cpu_time_us as f64 / self.cpu_operations as f64
}
}
pub fn speedup_factor(&self) -> f64 {
if self.gpu_time_us == 0 {
0.0
} else {
self.cpu_time_us as f64 / self.gpu_time_us as f64
}
}
pub fn reset(&mut self) {
*self = Self::default();
}
}
/// Dispatches shape operations to GPU or CPU according to the configured
/// thresholds, tracking per-path statistics behind a mutex (std build).
#[cfg(feature = "std")]
pub struct GpuShapeAccelerator {
// Thresholds governing the offload decisions.
config: AcceleratorConfig,
// Lock-protected dispatch statistics; `Arc` lets snapshots be taken via clone.
stats: Arc<std::sync::Mutex<AcceleratorStats>>,
// Result of the one-time GPU probe performed in `new`.
gpu_available: bool,
}
#[cfg(feature = "std")]
impl GpuShapeAccelerator {
    /// Creates an accelerator with the given configuration.
    ///
    /// GPU availability is probed exactly once here; with the `gpu` feature
    /// disabled the accelerator always dispatches to the CPU.
    pub fn new(config: AcceleratorConfig) -> Result<Self> {
        #[cfg(feature = "gpu")]
        let gpu_available = crate::gpu::is_gpu_available();
        #[cfg(not(feature = "gpu"))]
        let gpu_available = false;
        Ok(Self {
            config,
            stats: Arc::new(std::sync::Mutex::new(AcceleratorStats::new())),
            gpu_available,
        })
    }

    /// Creates an accelerator with [`AcceleratorConfig::default`].
    pub fn default_config() -> Result<Self> {
        Self::new(AcceleratorConfig::default())
    }

    /// Returns whether a GPU was detected at construction time.
    pub fn is_gpu_available(&self) -> bool {
        self.gpu_available
    }

    /// Returns a snapshot of the dispatch statistics.
    pub fn stats(&self) -> AcceleratorStats {
        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .clone()
    }

    /// Resets all dispatch statistics to zero.
    pub fn reset_stats(&self) {
        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .reset();
    }

    /// Shared dispatch skeleton: counts the operation, runs `gpu_op` when
    /// `use_gpu` is set, and transparently falls back to `cpu_op` (recording
    /// the fallback) if the GPU path returns an error.
    ///
    /// The stats mutex is always released before either closure runs so the
    /// actual work never executes under the lock.
    fn dispatch<T>(
        &self,
        use_gpu: bool,
        gpu_op: impl FnOnce() -> Result<T>,
        cpu_op: impl FnOnce() -> Result<T>,
    ) -> Result<T> {
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        stats.total_operations += 1;
        if use_gpu {
            stats.gpu_operations += 1;
            drop(stats);
            match gpu_op() {
                Ok(result) => Ok(result),
                Err(_) => {
                    // Reclassify the failed GPU attempt as a CPU operation.
                    let mut stats = self.stats.lock().expect("lock should not be poisoned");
                    stats.gpu_fallback_count += 1;
                    stats.gpu_operations -= 1;
                    stats.cpu_operations += 1;
                    drop(stats);
                    cpu_op()
                }
            }
        } else {
            stats.cpu_operations += 1;
            drop(stats);
            cpu_op()
        }
    }

    /// Broadcasts `shape1` against `shape2`, offloading to the GPU when the
    /// larger operand reaches `broadcast_threshold`.
    pub fn broadcast(&self, shape1: &Shape, shape2: &Shape) -> Result<Shape> {
        let total_elements = shape1.numel().max(shape2.numel());
        let use_gpu = self.gpu_available && total_elements >= self.config.broadcast_threshold;
        self.dispatch(
            use_gpu,
            || self.broadcast_gpu(shape1, shape2),
            || self.broadcast_cpu(shape1, shape2),
        )
    }

    fn broadcast_cpu(&self, shape1: &Shape, shape2: &Shape) -> Result<Shape> {
        shape1.broadcast_with(shape2)
    }

    // NOTE: all *_gpu methods below are placeholders that currently delegate
    // to their CPU counterparts; real kernels can be slotted in later.
    #[cfg(feature = "gpu")]
    fn broadcast_gpu(&self, shape1: &Shape, shape2: &Shape) -> Result<Shape> {
        self.broadcast_cpu(shape1, shape2)
    }

    #[cfg(not(feature = "gpu"))]
    fn broadcast_gpu(&self, shape1: &Shape, shape2: &Shape) -> Result<Shape> {
        self.broadcast_cpu(shape1, shape2)
    }

    /// Reshapes `shape` to `new_dims`, offloading to the GPU when the element
    /// count reaches `reshape_threshold`.
    ///
    /// # Errors
    /// Returns a dimension error when the element counts differ; such rejected
    /// reshapes are not counted in the statistics.
    pub fn reshape(&self, shape: &Shape, new_dims: &[usize]) -> Result<Shape> {
        let numel = shape.numel();
        let new_numel: usize = new_dims.iter().product();
        if numel != new_numel {
            return Err(TorshError::dimension_error(
                &format!(
                    "Cannot reshape tensor of {} elements into shape with {} elements",
                    numel, new_numel
                ),
                "reshape",
            ));
        }
        let use_gpu = self.gpu_available && numel >= self.config.reshape_threshold;
        self.dispatch(
            use_gpu,
            || self.reshape_gpu(shape, new_dims),
            || self.reshape_cpu(new_dims),
        )
    }

    fn reshape_cpu(&self, new_dims: &[usize]) -> Result<Shape> {
        Shape::from_dims(new_dims.to_vec())
    }

    #[cfg(feature = "gpu")]
    fn reshape_gpu(&self, _shape: &Shape, new_dims: &[usize]) -> Result<Shape> {
        self.reshape_cpu(new_dims)
    }

    #[cfg(not(feature = "gpu"))]
    fn reshape_gpu(&self, _shape: &Shape, new_dims: &[usize]) -> Result<Shape> {
        self.reshape_cpu(new_dims)
    }

    /// Validates a batch of candidate shapes (a shape is valid when non-empty
    /// with all positive dimensions), offloading when the batch length reaches
    /// `batch_validation_threshold`.
    pub fn batch_validate(&self, shapes: &[Vec<usize>]) -> Result<Vec<bool>> {
        let use_gpu = self.gpu_available && shapes.len() >= self.config.batch_validation_threshold;
        self.dispatch(
            use_gpu,
            || self.batch_validate_gpu(shapes),
            || self.batch_validate_cpu(shapes),
        )
    }

    fn batch_validate_cpu(&self, shapes: &[Vec<usize>]) -> Result<Vec<bool>> {
        Ok(shapes
            .iter()
            .map(|dims| !dims.is_empty() && dims.iter().all(|&d| d > 0))
            .collect())
    }

    #[cfg(feature = "gpu")]
    fn batch_validate_gpu(&self, shapes: &[Vec<usize>]) -> Result<Vec<bool>> {
        self.batch_validate_cpu(shapes)
    }

    #[cfg(not(feature = "gpu"))]
    fn batch_validate_gpu(&self, shapes: &[Vec<usize>]) -> Result<Vec<bool>> {
        self.batch_validate_cpu(shapes)
    }

    /// Computes row-major (C-order) strides for `dims`, offloading when the
    /// dimension count reaches `stride_dimension_threshold`.
    pub fn compute_strides(&self, dims: &[usize]) -> Result<Vec<usize>> {
        let use_gpu = self.gpu_available && dims.len() >= self.config.stride_dimension_threshold;
        self.dispatch(
            use_gpu,
            || self.compute_strides_gpu(dims),
            || self.compute_strides_cpu(dims),
        )
    }

    fn compute_strides_cpu(&self, dims: &[usize]) -> Result<Vec<usize>> {
        if dims.is_empty() {
            return Ok(Vec::new());
        }
        // Innermost dimension has stride 1; each outer stride is the product
        // of all inner dimension sizes.
        let mut strides = vec![0; dims.len()];
        let mut stride = 1;
        for i in (0..dims.len()).rev() {
            strides[i] = stride;
            stride *= dims[i];
        }
        Ok(strides)
    }

    #[cfg(feature = "gpu")]
    fn compute_strides_gpu(&self, dims: &[usize]) -> Result<Vec<usize>> {
        self.compute_strides_cpu(dims)
    }

    #[cfg(not(feature = "gpu"))]
    fn compute_strides_gpu(&self, dims: &[usize]) -> Result<Vec<usize>> {
        self.compute_strides_cpu(dims)
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &AcceleratorConfig {
        &self.config
    }

    /// Replaces the active configuration.
    pub fn set_config(&mut self, config: AcceleratorConfig) {
        self.config = config;
    }
}
/// Lock-free no_std variant of the accelerator: same offload configuration,
/// but no statistics tracking (no `std::sync::Mutex` available).
#[cfg(not(feature = "std"))]
pub struct GpuShapeAccelerator {
// Thresholds governing the offload decisions.
config: AcceleratorConfig,
// Result of the one-time GPU probe performed in `new`.
gpu_available: bool,
}
#[cfg(not(feature = "std"))]
impl GpuShapeAccelerator {
pub fn new(config: AcceleratorConfig) -> Result<Self> {
#[cfg(feature = "gpu")]
let gpu_available = crate::gpu::is_gpu_available();
#[cfg(not(feature = "gpu"))]
let gpu_available = false;
Ok(Self {
config,
gpu_available,
})
}
pub fn is_gpu_available(&self) -> bool {
self.gpu_available
}
pub fn broadcast(&self, shape1: &Shape, shape2: &Shape) -> Result<Shape> {
shape1.broadcast_with(shape2)
}
pub fn config(&self) -> &AcceleratorConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults must match the documented baseline thresholds.
    #[test]
    fn test_accelerator_config_default() {
        let config = AcceleratorConfig::default();
        assert_eq!(config.broadcast_threshold, 10_000_000);
        assert_eq!(config.reshape_threshold, 5_000_000);
        assert_eq!(config.stride_dimension_threshold, 10);
        assert_eq!(config.batch_validation_threshold, 100);
        assert!(!config.enable_auto_tuning);
        assert_eq!(config.device_id, 0);
    }

    /// Every builder method must set exactly its own field.
    #[test]
    fn test_accelerator_config_builder() {
        let config = AcceleratorConfig::new()
            .with_broadcast_threshold(1_000_000)
            .with_reshape_threshold(500_000)
            .with_stride_dimension_threshold(5)
            .with_batch_validation_threshold(50)
            .with_auto_tuning(true)
            .with_device_id(1);
        assert_eq!(config.broadcast_threshold, 1_000_000);
        assert_eq!(config.reshape_threshold, 500_000);
        assert_eq!(config.stride_dimension_threshold, 5);
        assert_eq!(config.batch_validation_threshold, 50);
        assert!(config.enable_auto_tuning);
        assert_eq!(config.device_id, 1);
    }

    #[test]
    fn test_accelerator_config_presets() {
        let very_large = AcceleratorConfig::for_very_large_tensors();
        assert_eq!(very_large.broadcast_threshold, 1_000_000);
        assert!(very_large.enable_auto_tuning);
        let high_dim = AcceleratorConfig::for_high_dimensional();
        assert_eq!(high_dim.stride_dimension_threshold, 6);
        let conservative = AcceleratorConfig::conservative();
        assert_eq!(conservative.broadcast_threshold, 50_000_000);
    }

    /// Derived quantities (percentages, averages, speedup) and reset.
    #[test]
    fn test_accelerator_stats() {
        let mut stats = AcceleratorStats::new();
        assert_eq!(stats.total_operations, 0);
        assert_eq!(stats.gpu_operations, 0);
        assert_eq!(stats.cpu_operations, 0);
        stats.total_operations = 100;
        stats.gpu_operations = 60;
        stats.cpu_operations = 40;
        stats.gpu_time_us = 1000;
        stats.cpu_time_us = 3000;
        assert_eq!(stats.gpu_usage_percentage(), 60.0);
        assert_eq!(stats.avg_gpu_time_us(), 1000.0 / 60.0);
        assert_eq!(stats.avg_cpu_time_us(), 3000.0 / 40.0);
        assert_eq!(stats.speedup_factor(), 3.0);
        stats.reset();
        assert_eq!(stats.total_operations, 0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_creation() {
        let config = AcceleratorConfig::default();
        let accelerator = GpuShapeAccelerator::new(config);
        assert!(accelerator.is_ok());
        let accelerator = accelerator.expect("accelerator creation should succeed");
        let _gpu_available = accelerator.is_gpu_available();
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_broadcast() {
        let config = AcceleratorConfig::default();
        let accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        let shape1 = Shape::from_dims(vec![10, 20, 30]).expect("shape creation should succeed");
        let shape2 = Shape::from_dims(vec![1, 20, 30]).expect("shape creation should succeed");
        let result = accelerator.broadcast(&shape1, &shape2);
        assert!(result.is_ok());
        let result = result.expect("broadcast should succeed");
        assert_eq!(result.dims(), &[10, 20, 30]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_reshape() {
        let config = AcceleratorConfig::default();
        let accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        let shape = Shape::from_dims(vec![10, 20, 30]).expect("shape creation should succeed");
        let new_dims = vec![10, 600];
        let result = accelerator.reshape(&shape, &new_dims);
        assert!(result.is_ok());
        let result = result.expect("reshape should succeed");
        assert_eq!(result.dims(), &[10, 600]);
    }

    /// Mismatched element counts must be rejected.
    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_reshape_invalid() {
        let config = AcceleratorConfig::default();
        let accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        let shape = Shape::from_dims(vec![10, 20, 30]).expect("shape creation should succeed");
        let new_dims = vec![10, 100];
        let result = accelerator.reshape(&shape, &new_dims);
        assert!(result.is_err());
    }

    /// Empty shapes and zero dimensions are invalid; everything else is valid.
    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_batch_validate() {
        let config = AcceleratorConfig::default();
        let accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        let shapes = vec![
            vec![10, 20],
            vec![30, 40, 50],
            vec![],
            vec![10, 0, 20],
            vec![5, 5, 5, 5],
        ];
        let result = accelerator.batch_validate(&shapes);
        assert!(result.is_ok());
        let result = result.expect("batch_validate should succeed");
        assert_eq!(result.len(), 5);
        assert!(result[0]);
        assert!(result[1]);
        assert!(!result[2]);
        assert!(!result[3]);
        assert!(result[4]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_compute_strides() {
        let config = AcceleratorConfig::default();
        let accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        let dims = vec![10, 20, 30];
        let result = accelerator.compute_strides(&dims);
        assert!(result.is_ok());
        let strides = result.expect("compute_strides should succeed");
        assert_eq!(strides, vec![600, 30, 1]);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_compute_strides_empty() {
        let config = AcceleratorConfig::default();
        let accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        let dims = vec![];
        let result = accelerator.compute_strides(&dims);
        assert!(result.is_ok());
        let strides = result.expect("compute_strides should succeed");
        assert!(strides.is_empty());
    }

    /// Each dispatched operation must be counted exactly once.
    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_stats_tracking() {
        let config = AcceleratorConfig::default();
        let accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        let shape1 = Shape::from_dims(vec![10, 20]).expect("shape creation should succeed");
        let shape2 = Shape::from_dims(vec![1, 20]).expect("shape creation should succeed");
        let _ = accelerator.broadcast(&shape1, &shape2);
        let stats = accelerator.stats();
        assert_eq!(stats.total_operations, 1);
        assert!(stats.gpu_operations + stats.cpu_operations == 1);
        accelerator.reset_stats();
        let stats = accelerator.stats();
        assert_eq!(stats.total_operations, 0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_config_update() {
        let config = AcceleratorConfig::default();
        let mut accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        assert_eq!(accelerator.config().broadcast_threshold, 10_000_000);
        let new_config = AcceleratorConfig::new().with_broadcast_threshold(1_000_000);
        accelerator.set_config(new_config);
        assert_eq!(accelerator.config().broadcast_threshold, 1_000_000);
    }

    /// Strides must satisfy the row-major recurrence for many dimensions.
    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_high_dimensional_strides() {
        let config = AcceleratorConfig::new().with_stride_dimension_threshold(5);
        let accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        let dims = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13];
        let result = accelerator.compute_strides(&dims);
        assert!(result.is_ok());
        let strides = result.expect("compute_strides should succeed");
        assert_eq!(strides.len(), dims.len());
        let mut expected_stride = 1;
        for i in (0..dims.len()).rev() {
            assert_eq!(strides[i], expected_stride);
            expected_stride *= dims[i];
        }
    }

    /// Large batches (past the offload threshold) still validate correctly.
    #[test]
    #[cfg(feature = "std")]
    fn test_accelerator_large_batch_validation() {
        let config = AcceleratorConfig::new().with_batch_validation_threshold(10);
        let accelerator =
            GpuShapeAccelerator::new(config).expect("accelerator creation should succeed");
        let shapes: Vec<Vec<usize>> = (0..50)
            .map(|i| {
                if i % 10 == 0 {
                    vec![]
                } else {
                    vec![10, 20, 30]
                }
            })
            .collect();
        let result = accelerator.batch_validate(&shapes);
        assert!(result.is_ok());
        let validations = result.expect("batch_validate should succeed");
        assert_eq!(validations.len(), 50);
        for i in 0..50 {
            if i % 10 == 0 {
                assert!(!validations[i]);
            } else {
                assert!(validations[i]);
            }
        }
    }
}