use crate::device::core::DeviceContext;
use crate::device::{Device, DeviceCapabilities, DeviceType};
use crate::error::Result;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[cfg(feature = "parallel")]
use crate::parallel::{ThreadPool, ThreadPoolBuilder};
/// CPU-backed device.
///
/// Holds the per-device [`DeviceContext`], an optional worker pool and the
/// SIMD level that was recorded when the device was constructed.
#[derive(Debug)]
pub struct CpuDevice {
    /// Lifecycle and resource bookkeeping for this device.
    context: DeviceContext,

    /// Worker pool used by `execute_parallel`; when `None`, work runs on the
    /// calling thread.
    #[cfg(feature = "parallel")]
    thread_pool: Option<Arc<ThreadPool>>,

    /// Stand-in used when the `parallel` feature is off; the constructors in
    /// this file always store `None` here.
    #[cfg(not(feature = "parallel"))]
    thread_pool: Option<()>,

    /// SIMD level stored at construction time.
    simd_level: SimdLevel,
}
impl Clone for CpuDevice {
fn clone(&self) -> Self {
Self {
context: DeviceContext::new(DeviceType::Cpu),
thread_pool: self.thread_pool.clone(),
simd_level: self.simd_level,
}
}
}
/// SIMD instruction-set tier reported by a [`CpuDevice`].
///
/// Variants are declared from least to most capable, so the derived ordering
/// allows comparisons such as `level >= SimdLevel::Avx2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SimdLevel {
    /// No SIMD support reported (also used on non-x86_64 targets).
    None,
    /// The `sse` feature is available.
    Sse,
    /// The `avx` feature is available.
    Avx,
    /// The `avx2` feature is available.
    Avx2,
    /// The `avx512f` feature is available.
    Avx512,
}
impl CpuDevice {
    /// Creates a CPU device with a default-configured thread pool (when the
    /// `parallel` feature is enabled) and the detected SIMD level.
    ///
    /// This constructor is infallible: lifecycle transition errors and thread
    /// pool construction errors are deliberately discarded.
    pub fn new() -> Self {
        let context = DeviceContext::new(DeviceType::Cpu);
        context
            .lifecycle()
            .set_state(crate::device::core::DeviceState::Initializing)
            .ok();
        let device = Self {
            context,
            thread_pool: Self::create_thread_pool(),
            simd_level: Self::detect_simd_level(),
        };
        device
            .context
            .lifecycle()
            .set_state(crate::device::core::DeviceState::Ready)
            .ok();
        device
    }

    /// Creates a CPU device whose thread pool is built with `num_threads`.
    ///
    /// Without the `parallel` feature `num_threads` is ignored and no pool is
    /// stored.
    ///
    /// # Errors
    ///
    /// Returns an error if a lifecycle transition fails or, with the
    /// `parallel` feature, if the thread pool cannot be built.
    pub fn with_threads(num_threads: usize) -> Result<Self> {
        let context = DeviceContext::new(DeviceType::Cpu);
        context
            .lifecycle()
            .set_state(crate::device::core::DeviceState::Initializing)?;

        #[cfg(feature = "parallel")]
        let thread_pool = {
            let pool = ThreadPoolBuilder::new()
                .num_threads(num_threads)
                .build()
                .map_err(|e| {
                    crate::error::TorshError::DeviceError(format!(
                        "Failed to create thread pool: {}",
                        e
                    ))
                })?;
            Some(Arc::new(pool))
        };
        #[cfg(not(feature = "parallel"))]
        let thread_pool = None;

        let device = Self {
            context,
            thread_pool,
            simd_level: Self::detect_simd_level(),
        };
        device
            .context
            .lifecycle()
            .set_state(crate::device::core::DeviceState::Ready)?;
        Ok(device)
    }

    /// Returns the thread pool backing this device, if one was created.
    #[cfg(feature = "parallel")]
    pub fn thread_pool(&self) -> Option<&Arc<ThreadPool>> {
        self.thread_pool.as_ref()
    }

    /// Always returns `None`: there is no thread pool without the `parallel`
    /// feature.
    #[cfg(not(feature = "parallel"))]
    pub fn thread_pool(&self) -> Option<()> {
        None
    }

    /// Returns the SIMD level recorded when this device was constructed.
    pub fn simd_level(&self) -> SimdLevel {
        self.simd_level
    }

    /// Runs `work` through the device's thread pool via `install`, or
    /// directly on the calling thread when the device has no pool.
    #[cfg(feature = "parallel")]
    pub fn execute_parallel<F, T>(&self, work: F) -> T
    where
        F: FnOnce() -> T + Send,
        T: Send,
    {
        match &self.thread_pool {
            Some(pool) => pool.install(work),
            None => work(),
        }
    }

    /// Runs `work` on the calling thread; the bounds match the `parallel`
    /// variant so callers compile under either configuration.
    #[cfg(not(feature = "parallel"))]
    pub fn execute_parallel<F, T>(&self, work: F) -> T
    where
        F: FnOnce() -> T + Send,
        T: Send,
    {
        work()
    }

    /// Builds a default-configured pool; a build failure yields `None`.
    #[cfg(feature = "parallel")]
    fn create_thread_pool() -> Option<Arc<ThreadPool>> {
        ThreadPoolBuilder::new().build().map(Arc::new).ok()
    }

    /// No pool exists without the `parallel` feature.
    #[cfg(not(feature = "parallel"))]
    fn create_thread_pool() -> Option<()> {
        None
    }

    /// Detects the highest SIMD tier supported by the CPU the program is
    /// running on.
    ///
    /// Detection happens at runtime via `is_x86_feature_detected!`, so the
    /// result reflects the actual processor rather than the target features
    /// the crate happened to be compiled with. Targets other than x86_64
    /// report [`SimdLevel::None`].
    fn detect_simd_level() -> SimdLevel {
        #[cfg(target_arch = "x86_64")]
        {
            // Probe from most to least capable so the best tier wins.
            if std::arch::is_x86_feature_detected!("avx512f") {
                SimdLevel::Avx512
            } else if std::arch::is_x86_feature_detected!("avx2") {
                SimdLevel::Avx2
            } else if std::arch::is_x86_feature_detected!("avx") {
                SimdLevel::Avx
            } else if std::arch::is_x86_feature_detected!("sse") {
                SimdLevel::Sse
            } else {
                SimdLevel::None
            }
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            SimdLevel::None
        }
    }
}
impl Default for CpuDevice {
fn default() -> Self {
Self::new()
}
}
impl Device for CpuDevice {
fn device_type(&self) -> DeviceType {
DeviceType::Cpu
}
fn name(&self) -> &str {
"CPU"
}
fn is_available(&self) -> Result<bool> {
Ok(self.context.lifecycle().is_ready())
}
fn capabilities(&self) -> Result<DeviceCapabilities> {
DeviceCapabilities::detect(DeviceType::Cpu)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
fn reset(&self) -> Result<()> {
self.context.lifecycle().reset()?;
self.context.clear_resources();
self.context
.lifecycle()
.set_state(crate::device::core::DeviceState::Ready)?;
Ok(())
}
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
fn clone_device(&self) -> Result<Box<dyn Device>> {
Ok(Box::new(CpuDevice::new()))
}
}
/// CUDA-backed device addressed by a device index.
#[derive(Debug)]
pub struct CudaDevice {
    /// Lifecycle and resource bookkeeping for this device.
    context: DeviceContext,
    /// Index this device was created with.
    device_index: usize,
    /// Context data produced during initialization, when present.
    cuda_context: Option<CudaContext>,
    /// Registry of streams created through `create_stream`.
    // Only touched by code compiled with the `cuda` feature.
    #[allow(dead_code)]
    stream_manager: Arc<Mutex<CudaStreamManager>>,
}
/// Handles and properties recorded for an initialized CUDA device.
#[derive(Debug)]
struct CudaContext {
    /// Device handle value.
    #[allow(dead_code)]
    device_handle: u32,
    /// Context handle value.
    #[allow(dead_code)]
    context_handle: u64,
    /// Compute capability pair (presumably `(major, minor)`; confirm against
    /// the code that fills it in).
    compute_capability: (u32, u32),
}
/// Bookkeeping for the streams created on one CUDA device.
#[derive(Debug)]
struct CudaStreamManager {
    /// Registered streams keyed by stream id.
    #[allow(dead_code)]
    streams: HashMap<u32, CudaStream>,
    /// Id that the next created stream will receive.
    #[allow(dead_code)]
    next_stream_id: u32,
}
/// Record describing a single CUDA stream.
#[derive(Debug)]
struct CudaStream {
    /// Id under which the stream is registered.
    #[allow(dead_code)]
    stream_id: u32,
    /// Handle value associated with the stream.
    #[allow(dead_code)]
    stream_handle: u64,
    /// Whether this record represents the default stream.
    #[allow(dead_code)]
    is_default: bool,
}
impl CudaDevice {
    /// Creates the CUDA device with the given index.
    ///
    /// # Errors
    ///
    /// Returns a `DeviceError` when the crate was built without the `cuda`
    /// feature, and propagates lifecycle or context-initialization failures
    /// otherwise.
    pub fn new(device_index: usize) -> Result<Self> {
        let context = DeviceContext::new(DeviceType::Cuda(device_index));
        context
            .lifecycle()
            .set_state(crate::device::core::DeviceState::Initializing)?;

        #[cfg(feature = "cuda")]
        {
            let cuda_context = Self::initialize_cuda_context(device_index)?;
            let device = Self {
                context,
                device_index,
                cuda_context: Some(cuda_context),
                stream_manager: Arc::new(Mutex::new(CudaStreamManager::new())),
            };
            device
                .context
                .lifecycle()
                .set_state(crate::device::core::DeviceState::Ready)?;
            Ok(device)
        }
        #[cfg(not(feature = "cuda"))]
        {
            Err(crate::error::TorshError::General(
                crate::error::GeneralError::DeviceError("CUDA support not compiled".to_string()),
            ))
        }
    }

    /// Index this device was created with.
    pub fn device_index(&self) -> usize {
        self.device_index
    }

    /// Compute capability stored in the CUDA context, if a context exists.
    pub fn compute_capability(&self) -> Option<(u32, u32)> {
        self.cuda_context.as_ref().map(|ctx| ctx.compute_capability)
    }

    /// Registers a new non-default stream and returns its id.
    ///
    /// # Errors
    ///
    /// Returns `UnsupportedOperation` without the `cuda` feature, and
    /// propagates stream-handle creation failures with it.
    ///
    /// # Panics
    ///
    /// Panics if the stream manager mutex is poisoned.
    pub fn create_stream(&self) -> Result<u32> {
        #[cfg(feature = "cuda")]
        {
            // Obtain the fallible handle before locking and before consuming
            // an id, so a failure neither burns a stream id nor holds the
            // lock while the handle is being created.
            let stream_handle = self.create_cuda_stream_handle()?;
            let mut manager = self
                .stream_manager
                .lock()
                .expect("lock should not be poisoned");
            let stream_id = manager.next_stream_id;
            manager.next_stream_id += 1;
            manager.streams.insert(
                stream_id,
                CudaStream {
                    stream_id,
                    stream_handle,
                    is_default: false,
                },
            );
            Ok(stream_id)
        }
        #[cfg(not(feature = "cuda"))]
        {
            Err(crate::error::TorshError::General(
                crate::error::GeneralError::UnsupportedOperation {
                    op: "CUDA streams".to_string(),
                    dtype: "N/A".to_string(),
                },
            ))
        }
    }

    /// Waits for outstanding device work.
    ///
    /// Currently a no-op that returns `Ok(())` whether or not the `cuda`
    /// feature is enabled.
    pub fn synchronize_device(&self) -> Result<()> {
        Ok(())
    }

    /// Builds the context record for `device_index`.
    ///
    /// No driver call is made here: the context handle and compute
    /// capability are fixed values.
    ///
    /// # Errors
    ///
    /// Returns a `DeviceError` if `device_index` does not fit in the 32-bit
    /// device handle (previously the index was silently truncated).
    #[cfg(feature = "cuda")]
    fn initialize_cuda_context(device_index: usize) -> Result<CudaContext> {
        let device_handle =
            <u32 as std::convert::TryFrom<usize>>::try_from(device_index).map_err(|_| {
                crate::error::TorshError::General(crate::error::GeneralError::DeviceError(
                    format!(
                        "CUDA device index {} does not fit in a 32-bit device handle",
                        device_index
                    ),
                ))
            })?;
        Ok(CudaContext {
            device_handle,
            context_handle: 0x12345678,
            compute_capability: (8, 6),
        })
    }

    /// Returns the handle for a new stream; currently a fixed value.
    #[cfg(feature = "cuda")]
    fn create_cuda_stream_handle(&self) -> Result<u64> {
        Ok(0x87654321)
    }
}
impl CudaStreamManager {
    /// Creates an empty manager whose first issued stream id will be `1`.
    // Only called from code compiled with the `cuda` feature.
    #[allow(dead_code)]
    fn new() -> Self {
        let streams = HashMap::new();
        Self {
            streams,
            next_stream_id: 1,
        }
    }
}
impl Device for CudaDevice {
    /// `DeviceType::Cuda` carrying this device's index.
    fn device_type(&self) -> DeviceType {
        let index = self.device_index;
        DeviceType::Cuda(index)
    }

    /// Fixed display name shared by all CUDA devices.
    fn name(&self) -> &str {
        "CUDA Device"
    }

    /// With the `cuda` feature: ready lifecycle and a present CUDA context.
    /// Without it: always `false`.
    fn is_available(&self) -> Result<bool> {
        #[cfg(feature = "cuda")]
        {
            let lifecycle_ready = self.context.lifecycle().is_ready();
            Ok(lifecycle_ready && self.cuda_context.is_some())
        }
        #[cfg(not(feature = "cuda"))]
        {
            Ok(false)
        }
    }

    /// Delegates to `DeviceCapabilities::detect` for this CUDA device type.
    fn capabilities(&self) -> Result<DeviceCapabilities> {
        let device_type = DeviceType::Cuda(self.device_index);
        DeviceCapabilities::detect(device_type)
    }

    /// Forwards to `CudaDevice::synchronize_device`.
    fn synchronize(&self) -> Result<()> {
        self.synchronize_device()
    }

    /// Resets the lifecycle, clears context resources, empties the stream
    /// registry (with the `cuda` feature) and marks the device ready again.
    fn reset(&self) -> Result<()> {
        self.context.lifecycle().reset()?;
        self.context.clear_resources();

        #[cfg(feature = "cuda")]
        {
            let mut registry = self
                .stream_manager
                .lock()
                .expect("lock should not be poisoned");
            registry.streams.clear();
            // Stream ids start over from 1 after a reset.
            registry.next_stream_id = 1;
        }

        self.context
            .lifecycle()
            .set_state(crate::device::core::DeviceState::Ready)?;
        Ok(())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Builds a new device for the same index via `CudaDevice::new`.
    fn clone_device(&self) -> Result<Box<dyn Device>> {
        let fresh = CudaDevice::new(self.device_index)?;
        Ok(Box::new(fresh))
    }
}
/// Metal-backed device addressed by a device index.
#[derive(Debug)]
pub struct MetalDevice {
    /// Lifecycle and resource bookkeeping for this device.
    context: DeviceContext,
    /// Index this device was created with.
    device_index: usize,
    /// Handle describing the underlying Metal device, when present.
    metal_device: Option<MetalDeviceHandle>,
    /// Command queue created alongside the device handle, when present.
    // Stored but not read anywhere in this file.
    #[allow(dead_code)]
    command_queue: Option<MetalCommandQueue>,
}
/// Identification data for a Metal device.
#[derive(Debug)]
struct MetalDeviceHandle {
    /// Device id value.
    #[allow(dead_code)]
    device_id: u64,
    /// Human-readable device name, exposed through `metal_device_name`.
    name: String,
    /// Registry id value.
    #[allow(dead_code)]
    registry_id: u64,
}
/// Description of a Metal command queue.
#[derive(Debug)]
struct MetalCommandQueue {
    /// Queue id value.
    #[allow(dead_code)]
    queue_id: u64,
    /// Maximum number of command buffers recorded for the queue.
    #[allow(dead_code)]
    max_command_buffers: usize,
}
impl MetalDevice {
pub fn new(device_index: usize) -> Result<Self> {
let context = DeviceContext::new(DeviceType::Metal(device_index));
context
.lifecycle()
.set_state(crate::device::core::DeviceState::Initializing)?;
#[cfg(target_os = "macos")]
{
let metal_device = Self::create_metal_device(device_index)?;
let command_queue = Self::create_command_queue(&metal_device)?;
let device = Self {
context,
device_index,
metal_device: Some(metal_device),
command_queue: Some(command_queue),
};
device
.context
.lifecycle()
.set_state(crate::device::core::DeviceState::Ready)?;
Ok(device)
}
#[cfg(not(target_os = "macos"))]
{
Err(crate::error::TorshError::General(
crate::error::GeneralError::DeviceError(
"Metal device only available on macOS".to_string(),
),
))
}
}
pub fn device_index(&self) -> usize {
self.device_index
}
pub fn metal_device_name(&self) -> Option<&str> {
self.metal_device.as_ref().map(|d| d.name.as_str())
}
pub fn execute_compute_shader(&self, _shader_source: &str) -> Result<()> {
#[cfg(target_os = "macos")]
{
Ok(())
}
#[cfg(not(target_os = "macos"))]
{
Err(crate::error::TorshError::NotImplemented(
"Metal compute shaders not available".to_string(),
))
}
}
#[cfg(target_os = "macos")]
fn create_metal_device(device_index: usize) -> Result<MetalDeviceHandle> {
Ok(MetalDeviceHandle {
device_id: device_index as u64,
name: format!("Apple GPU {}", device_index),
registry_id: 0x1000 + device_index as u64,
})
}
#[cfg(target_os = "macos")]
fn create_command_queue(_device: &MetalDeviceHandle) -> Result<MetalCommandQueue> {
Ok(MetalCommandQueue {
queue_id: 0x2000,
max_command_buffers: 64,
})
}
}
impl Device for MetalDevice {
    /// `DeviceType::Metal` carrying this device's index.
    fn device_type(&self) -> DeviceType {
        let index = self.device_index;
        DeviceType::Metal(index)
    }

    /// On macOS the stored device name (or a generic fallback); elsewhere a
    /// fixed "unavailable" label.
    fn name(&self) -> &str {
        #[cfg(target_os = "macos")]
        {
            match self.metal_device_name() {
                Some(device_name) => device_name,
                None => "Metal Device",
            }
        }
        #[cfg(not(target_os = "macos"))]
        {
            "Metal Device (Unavailable)"
        }
    }

    /// On macOS: ready lifecycle and a present device handle. Elsewhere:
    /// always `false`.
    fn is_available(&self) -> Result<bool> {
        #[cfg(target_os = "macos")]
        {
            let lifecycle_ready = self.context.lifecycle().is_ready();
            Ok(lifecycle_ready && self.metal_device.is_some())
        }
        #[cfg(not(target_os = "macos"))]
        {
            Ok(false)
        }
    }

    /// Delegates to `DeviceCapabilities::detect` for this Metal device type.
    fn capabilities(&self) -> Result<DeviceCapabilities> {
        let device_type = DeviceType::Metal(self.device_index);
        DeviceCapabilities::detect(device_type)
    }

    /// No-op on every target: returns `Ok(())`.
    fn synchronize(&self) -> Result<()> {
        Ok(())
    }

    /// Resets the lifecycle, clears context resources and marks the device
    /// ready again.
    fn reset(&self) -> Result<()> {
        self.context.lifecycle().reset()?;
        self.context.clear_resources();
        self.context
            .lifecycle()
            .set_state(crate::device::core::DeviceState::Ready)?;
        Ok(())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Builds a new device for the same index via `MetalDevice::new`.
    fn clone_device(&self) -> Result<Box<dyn Device>> {
        let fresh = MetalDevice::new(self.device_index)?;
        Ok(Box::new(fresh))
    }
}
/// WebGPU-backed device addressed by a device index.
#[derive(Debug)]
pub struct WgpuDevice {
    /// Lifecycle and resource bookkeeping for this device.
    context: DeviceContext,
    /// Index this device was created with.
    device_index: usize,
    /// Handle describing the underlying device, when present.
    // Only read by code compiled with the `wgpu` feature.
    #[allow(dead_code)]
    wgpu_device: Option<WgpuDeviceHandle>,
    /// Adapter description, exposed through `adapter_info`.
    adapter_info: Option<WgpuAdapterInfo>,
}
/// Identification, limits and feature list recorded for a WebGPU device.
#[derive(Debug)]
struct WgpuDeviceHandle {
    /// Device id value.
    #[allow(dead_code)]
    device_id: u64,
    /// Resource limits recorded for the device.
    #[allow(dead_code)]
    limits: WgpuLimits,
    /// Names of the features recorded for the device.
    #[allow(dead_code)]
    features: Vec<String>,
}
/// Description of the adapter behind a [`WgpuDevice`].
///
/// All fields are private; outside this module the value can only be
/// inspected through its `Debug` output.
#[derive(Debug)]
pub struct WgpuAdapterInfo {
    /// Adapter name.
    #[allow(dead_code)]
    name: String,
    /// Adapter vendor.
    #[allow(dead_code)]
    vendor: String,
    /// Kind of hardware the adapter represents.
    #[allow(dead_code)]
    device_type: WgpuDeviceType,
    /// Graphics backend the adapter uses.
    #[allow(dead_code)]
    backend: WgpuBackend,
}
/// Resource limits recorded for a WebGPU device.
#[derive(Debug)]
struct WgpuLimits {
    /// Maximum number of bind groups.
    #[allow(dead_code)]
    max_bind_groups: u32,
    /// Maximum uniform buffer binding size.
    #[allow(dead_code)]
    max_uniform_buffer_binding_size: u64,
    /// Maximum storage buffer binding size.
    #[allow(dead_code)]
    max_storage_buffer_binding_size: u64,
}
/// Kind of hardware a WebGPU adapter represents.
#[derive(Debug)]
enum WgpuDeviceType {
    /// Discrete GPU.
    #[allow(dead_code)]
    DiscreteGpu,
    /// Integrated GPU.
    #[allow(dead_code)]
    IntegratedGpu,
    /// Virtual GPU.
    #[allow(dead_code)]
    VirtualGpu,
    /// CPU adapter.
    #[allow(dead_code)]
    Cpu,
}
/// Graphics backend a WebGPU adapter runs on.
#[derive(Debug)]
enum WgpuBackend {
    /// Vulkan backend.
    #[allow(dead_code)]
    Vulkan,
    /// Metal backend.
    #[allow(dead_code)]
    Metal,
    /// Direct3D 12 backend.
    #[allow(dead_code)]
    Dx12,
    /// Direct3D 11 backend.
    #[allow(dead_code)]
    Dx11,
    /// OpenGL backend.
    #[allow(dead_code)]
    Gl,
    /// WebGPU as provided by a browser.
    #[allow(dead_code)]
    BrowserWebGpu,
}
impl WgpuDevice {
pub fn new(device_index: usize) -> Result<Self> {
let context = DeviceContext::new(DeviceType::Wgpu(device_index));
context
.lifecycle()
.set_state(crate::device::core::DeviceState::Initializing)?;
#[cfg(feature = "wgpu")]
{
let (wgpu_device, adapter_info) = Self::initialize_wgpu(device_index)?;
let device = Self {
context,
device_index,
wgpu_device: Some(wgpu_device),
adapter_info: Some(adapter_info),
};
device
.context
.lifecycle()
.set_state(crate::device::core::DeviceState::Ready)?;
Ok(device)
}
#[cfg(not(feature = "wgpu"))]
{
Err(crate::error::TorshError::General(
crate::error::GeneralError::DeviceError("WebGPU support not compiled".to_string()),
))
}
}
pub fn device_index(&self) -> usize {
self.device_index
}
pub fn adapter_info(&self) -> Option<&WgpuAdapterInfo> {
self.adapter_info.as_ref()
}
pub fn execute_compute(&self, _shader_source: &str) -> Result<()> {
#[cfg(feature = "wgpu")]
{
Ok(())
}
#[cfg(not(feature = "wgpu"))]
{
Err(crate::error::TorshError::General(
crate::error::GeneralError::UnsupportedOperation {
op: "WebGPU compute".to_string(),
dtype: "N/A".to_string(),
},
))
}
}
#[cfg(feature = "wgpu")]
fn initialize_wgpu(device_index: usize) -> Result<(WgpuDeviceHandle, WgpuAdapterInfo)> {
let device = WgpuDeviceHandle {
device_id: device_index as u64,
limits: WgpuLimits {
max_bind_groups: 4,
max_uniform_buffer_binding_size: 16384,
max_storage_buffer_binding_size: 134217728,
},
features: vec!["compute-shaders".to_string()],
};
let adapter_info = WgpuAdapterInfo {
name: format!("WebGPU Adapter {}", device_index),
vendor: "Unknown".to_string(),
device_type: WgpuDeviceType::DiscreteGpu,
backend: WgpuBackend::Vulkan,
};
Ok((device, adapter_info))
}
}
impl Device for WgpuDevice {
    /// `DeviceType::Wgpu` carrying this device's index.
    fn device_type(&self) -> DeviceType {
        let index = self.device_index;
        DeviceType::Wgpu(index)
    }

    /// With the `wgpu` feature the adapter name (or a generic fallback);
    /// without it a fixed "unavailable" label.
    fn name(&self) -> &str {
        #[cfg(feature = "wgpu")]
        {
            match &self.adapter_info {
                Some(info) => info.name.as_str(),
                None => "WebGPU Device",
            }
        }
        #[cfg(not(feature = "wgpu"))]
        {
            "WebGPU Device (Unavailable)"
        }
    }

    /// With the `wgpu` feature: ready lifecycle and a present device handle.
    /// Without it: always `false`.
    fn is_available(&self) -> Result<bool> {
        #[cfg(feature = "wgpu")]
        {
            let lifecycle_ready = self.context.lifecycle().is_ready();
            Ok(lifecycle_ready && self.wgpu_device.is_some())
        }
        #[cfg(not(feature = "wgpu"))]
        {
            Ok(false)
        }
    }

    /// Delegates to `DeviceCapabilities::detect` for this WebGPU device type.
    fn capabilities(&self) -> Result<DeviceCapabilities> {
        let device_type = DeviceType::Wgpu(self.device_index);
        DeviceCapabilities::detect(device_type)
    }

    /// No-op in every configuration: returns `Ok(())`.
    fn synchronize(&self) -> Result<()> {
        Ok(())
    }

    /// Resets the lifecycle, clears context resources and marks the device
    /// ready again.
    fn reset(&self) -> Result<()> {
        self.context.lifecycle().reset()?;
        self.context.clear_resources();
        self.context
            .lifecycle()
            .set_state(crate::device::core::DeviceState::Ready)?;
        Ok(())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Builds a new device for the same index via `WgpuDevice::new`.
    fn clone_device(&self) -> Result<Box<dyn Device>> {
        let fresh = WgpuDevice::new(self.device_index)?;
        Ok(Box::new(fresh))
    }
}
/// Stateless entry point for constructing boxed devices by device type.
#[derive(Debug)]
pub struct DeviceFactory;
impl DeviceFactory {
    /// Builds the device matching `device_type`.
    ///
    /// # Errors
    ///
    /// Propagates the constructor error of the selected backend; the CPU
    /// branch cannot fail.
    pub fn create_device(device_type: DeviceType) -> Result<Box<dyn Device>> {
        let device: Box<dyn Device> = match device_type {
            DeviceType::Cpu => Box::new(CpuDevice::new()),
            DeviceType::Cuda(index) => Box::new(CudaDevice::new(index)?),
            DeviceType::Metal(index) => Box::new(MetalDevice::new(index)?),
            DeviceType::Wgpu(index) => Box::new(WgpuDevice::new(index)?),
        };
        Ok(device)
    }

    /// Builds a boxed CPU device via `CpuDevice::with_threads`.
    pub fn create_cpu_with_threads(num_threads: usize) -> Result<Box<dyn Device>> {
        let cpu = CpuDevice::with_threads(num_threads)?;
        Ok(Box::new(cpu))
    }

    /// Reports whether support for `device_type` was compiled in.
    ///
    /// This is a build-configuration check only; the device index is ignored
    /// and no hardware is probed.
    pub fn is_device_type_available(device_type: DeviceType) -> bool {
        let cuda_built = cfg!(feature = "cuda");
        let metal_built = cfg!(target_os = "macos");
        let wgpu_built = cfg!(feature = "wgpu");
        match device_type {
            DeviceType::Cpu => true,
            DeviceType::Cuda(_) => cuda_built,
            DeviceType::Metal(_) => metal_built,
            DeviceType::Wgpu(_) => wgpu_built,
        }
    }

    /// Lists the device types compiled into this build, CPU first, each
    /// non-CPU backend represented by index `0`.
    pub fn available_device_types() -> Vec<DeviceType> {
        let candidates = vec![
            (true, DeviceType::Cpu),
            (cfg!(feature = "cuda"), DeviceType::Cuda(0)),
            (cfg!(target_os = "macos"), DeviceType::Metal(0)),
            (cfg!(feature = "wgpu"), DeviceType::Wgpu(0)),
        ];
        candidates
            .into_iter()
            .filter(|(compiled_in, _)| *compiled_in)
            .map(|(_, device_type)| device_type)
            .collect()
    }
}
/// Helpers for inspecting and downcasting `dyn Device` values.
pub mod utils {
    use super::*;

    /// Downcasts a device reference to the concrete type `T`, if it is one.
    pub fn cast_device<T: Device + 'static>(device: &dyn Device) -> Option<&T> {
        let erased = device.as_any();
        erased.downcast_ref::<T>()
    }

    /// Mutable counterpart of [`cast_device`].
    pub fn cast_device_mut<T: Device + 'static>(device: &mut dyn Device) -> Option<&mut T> {
        let erased = device.as_any_mut();
        erased.downcast_mut::<T>()
    }

    /// `true` when the device is a [`CpuDevice`].
    pub fn is_cpu_device(device: &dyn Device) -> bool {
        device.as_any().is::<CpuDevice>()
    }

    /// `true` when the device is a [`CudaDevice`].
    pub fn is_cuda_device(device: &dyn Device) -> bool {
        device.as_any().is::<CudaDevice>()
    }

    /// `true` when the device is a [`MetalDevice`].
    pub fn is_metal_device(device: &dyn Device) -> bool {
        device.as_any().is::<MetalDevice>()
    }

    /// `true` when the device is a [`WgpuDevice`].
    pub fn is_wgpu_device(device: &dyn Device) -> bool {
        device.as_any().is::<WgpuDevice>()
    }

    /// Short label for the concrete implementation behind `device`, or
    /// `"Unknown"` when it is none of the types defined in this file.
    pub fn device_implementation_name(device: &dyn Device) -> &'static str {
        if is_cpu_device(device) {
            return "CPU";
        }
        if is_cuda_device(device) {
            return "CUDA";
        }
        if is_metal_device(device) {
            return "Metal";
        }
        if is_wgpu_device(device) {
            return "WebGPU";
        }
        "Unknown"
    }

    /// Creates one device per compiled-in device type, silently skipping any
    /// type whose construction fails.
    pub fn create_all_available_devices() -> Vec<Box<dyn Device>> {
        DeviceFactory::available_device_types()
            .into_iter()
            .filter_map(|device_type| DeviceFactory::create_device(device_type).ok())
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh CPU device reports the expected identity and can be cloned.
    #[test]
    fn test_cpu_device() {
        let device = CpuDevice::new();
        assert_eq!(device.device_type(), DeviceType::Cpu);
        assert_eq!(device.name(), "CPU");
        assert!(device.is_available().expect("is_available should succeed"));
        let cloned = device.clone_device().expect("clone_device should succeed");
        assert_eq!(cloned.device_type(), DeviceType::Cpu);
    }

    /// `with_threads` yields a working device in every feature configuration.
    #[test]
    fn test_cpu_device_with_threads() {
        let device = CpuDevice::with_threads(4).expect("with_threads should succeed");
        // A pool only exists when the `parallel` feature is compiled in;
        // without it `thread_pool()` always returns `None`, so an
        // unconditional `is_some()` assertion would fail.
        assert_eq!(
            device.thread_pool().is_some(),
            cfg!(feature = "parallel")
        );
        let result = device.execute_parallel(|| 42);
        assert_eq!(result, 42);
    }

    /// Detection returns one of the known levels (exhaustiveness check).
    #[test]
    fn test_simd_level_detection() {
        let level = CpuDevice::detect_simd_level();
        match level {
            SimdLevel::None
            | SimdLevel::Sse
            | SimdLevel::Avx
            | SimdLevel::Avx2
            | SimdLevel::Avx512 => {}
        }
    }

    /// The factory can always build and advertise the CPU device.
    #[test]
    fn test_device_factory() {
        let cpu_device =
            DeviceFactory::create_device(DeviceType::Cpu).expect("create_device should succeed");
        assert_eq!(cpu_device.device_type(), DeviceType::Cpu);
        assert!(DeviceFactory::is_device_type_available(DeviceType::Cpu));
        let available_types = DeviceFactory::available_device_types();
        assert!(available_types.contains(&DeviceType::Cpu));
    }

    /// Downcast helpers recognise a CPU device and reject the other kinds.
    #[test]
    fn test_device_casting() {
        let device = CpuDevice::new();
        let device_ref: &dyn Device = &device;
        assert!(utils::is_cpu_device(device_ref));
        assert!(!utils::is_cuda_device(device_ref));
        assert!(!utils::is_metal_device(device_ref));
        assert!(!utils::is_wgpu_device(device_ref));
        let cpu_device = utils::cast_device::<CpuDevice>(device_ref);
        assert!(cpu_device.is_some());
        assert_eq!(utils::device_implementation_name(device_ref), "CPU");
    }

    /// CUDA device basics; skipped silently if construction fails.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_device() {
        if let Ok(device) = CudaDevice::new(0) {
            assert_eq!(device.device_type(), DeviceType::Cuda(0));
            assert_eq!(device.device_index(), 0);
            assert!(device.is_available().expect("is_available should succeed"));
            if let Ok(stream_id) = device.create_stream() {
                assert!(stream_id > 0);
            }
        }
    }

    /// Metal device basics; skipped silently if construction fails.
    #[cfg(target_os = "macos")]
    #[test]
    fn test_metal_device() {
        if let Ok(device) = MetalDevice::new(0) {
            assert_eq!(device.device_type(), DeviceType::Metal(0));
            assert_eq!(device.device_index(), 0);
            assert!(device.is_available().expect("is_available should succeed"));
        }
    }

    /// WebGPU device basics; skipped silently if construction fails.
    #[cfg(feature = "wgpu")]
    #[test]
    fn test_wgpu_device() {
        if let Ok(device) = WgpuDevice::new(0) {
            assert_eq!(device.device_type(), DeviceType::Wgpu(0));
            assert_eq!(device.device_index(), 0);
            assert!(device.is_available().expect("is_available should succeed"));
        }
    }

    /// At least the CPU device is always created.
    #[test]
    fn test_create_all_available_devices() {
        let devices = utils::create_all_available_devices();
        assert!(!devices.is_empty());
        assert!(devices.iter().any(|d| d.device_type() == DeviceType::Cpu));
    }
}