use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::RwLock;
use ::ndarray::{Array, Dim, Dimension, SliceArg, SliceInfo, SliceInfoElem};
use num_traits;
use crate::array_protocol::{
ArrayFunction, ArrayProtocol, DistributedBackend, DistributedConfig, DistributedNdarray,
DistributionStrategy, GPUBackend, GPUConfig, GPUNdarray, NdarrayWrapper, NotImplemented,
};
use crate::error::CoreResult;
/// Tunable settings that drive automatic device selection.
///
/// The two thresholds are compared against element counts (`Array::len`) by the
/// device-selection helpers in this module.
#[derive(Debug, Clone)]
pub struct AutoDeviceConfig {
/// Element count at or above which an array is placed on the GPU.
pub gpu_threshold: usize,
/// Element count at or above which an array is distributed. This threshold is
/// checked before `gpu_threshold`.
pub distributed_threshold: usize,
/// Forwarded as `GPUConfig::mixed_precision` when an array is converted to a
/// GPU array.
pub enable_mixed_precision: bool,
/// NOTE(review): not read in the visible part of this file; presumably favours
/// lower memory use over speed. Confirm against the consumers.
pub prefer_memory_efficiency: bool,
/// NOTE(review): not read in the visible part of this file; presumably allows
/// implicit transfers between devices. Confirm against the consumers.
pub auto_transfer: bool,
/// NOTE(review): not read in the visible part of this file; presumably prefers
/// keeping data where it already lives. Confirm against the consumers.
pub prefer_data_locality: bool,
/// Backend placed into `GPUConfig::backend` when an array is converted to a
/// GPU array.
pub preferred_gpu_backend: GPUBackend,
/// NOTE(review): not read in the visible part of this file; presumably enables
/// falling back to the CPU when another device is unavailable. Confirm.
pub fallback_to_cpu: bool,
}
impl Default for AutoDeviceConfig {
    /// Builds the stock configuration: GPU from one million elements,
    /// distributed from one hundred million, CUDA as the preferred backend.
    fn default() -> Self {
        Self {
            // Size thresholds, expressed in number of elements.
            gpu_threshold: 1_000_000,
            distributed_threshold: 100_000_000,
            // Behavioural switches.
            enable_mixed_precision: false,
            prefer_memory_efficiency: false,
            auto_transfer: true,
            prefer_data_locality: true,
            preferred_gpu_backend: GPUBackend::CUDA,
            fallback_to_cpu: true,
        }
    }
}
/// Process-wide automatic device-selection configuration, guarded by an `RwLock`.
///
/// The initial values mirror `AutoDeviceConfig::default()`. They are spelled out
/// here because a `static` initializer must be a constant expression and the
/// `Default::default()` trait method cannot be called in one. Keep the two lists
/// in sync when changing a default.
pub static AUTO_DEVICE_CONFIG: RwLock<AutoDeviceConfig> = RwLock::new(AutoDeviceConfig {
gpu_threshold: 1_000_000,
distributed_threshold: 100_000_000,
enable_mixed_precision: false,
prefer_memory_efficiency: false,
auto_transfer: true,
prefer_data_locality: true,
preferred_gpu_backend: GPUBackend::CUDA,
fallback_to_cpu: true,
});
/// Replaces the process-wide automatic device-selection configuration.
///
/// The whole configuration is overwritten with `config`.
///
/// If the lock guarding `AUTO_DEVICE_CONFIG` has been poisoned (a thread
/// panicked while holding it), the guard is recovered and the new value is
/// stored anyway instead of being silently discarded. The configuration is a
/// plain value that is overwritten wholesale, so there is no partially-updated
/// state to protect against. Note that the lock itself remains flagged as
/// poisoned afterwards.
#[allow(dead_code)]
pub fn set_auto_device_config(config: AutoDeviceConfig) {
    let mut global_config = AUTO_DEVICE_CONFIG
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    *global_config = config;
}
/// Returns a snapshot (clone) of the process-wide automatic device-selection
/// configuration.
///
/// If the lock guarding `AUTO_DEVICE_CONFIG` has been poisoned (a thread
/// panicked while holding it), the guard is recovered and the stored
/// configuration is still returned, rather than silently substituting the
/// defaults and discarding whatever the application had configured.
#[allow(dead_code)]
pub fn get_auto_device_config() -> AutoDeviceConfig {
    AUTO_DEVICE_CONFIG
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clone()
}
/// Picks a device for a single array based on its element count.
///
/// The count is compared against the thresholds of the current global
/// configuration: the distributed threshold is tested first, then the GPU
/// threshold, and anything smaller stays on the CPU.
#[allow(dead_code)]
pub fn determine_best_device<T, D>(array: &Array<T, D>) -> DeviceType
where
    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis,
{
    let thresholds = get_auto_device_config();
    match array.len() {
        n if n >= thresholds.distributed_threshold => DeviceType::Distributed,
        n if n >= thresholds.gpu_threshold => DeviceType::GPU,
        _ => DeviceType::CPU,
    }
}
/// Picks a device for an operation over several arrays.
///
/// The decision is based on the combined element count of `arrays`. For the
/// operations `"matmul"`, `"svd"`, `"inverse"` and `"conv2d"` the configured
/// thresholds are lowered (GPU threshold divided by 10, distributed threshold
/// divided by 2) so that these move off the CPU earlier; every other operation
/// name uses the configured thresholds unchanged.
#[allow(dead_code)]
pub fn determine_best_device_for_operation<T, D>(
    arrays: &[&Array<T, D>],
    operation: &str,
) -> DeviceType
where
    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis,
{
    let config = get_auto_device_config();

    // Effective thresholds for this operation.
    let (gpu_threshold, distributed_threshold) = match operation {
        "matmul" | "svd" | "inverse" | "conv2d" => {
            (config.gpu_threshold / 10, config.distributed_threshold / 2)
        }
        _ => (config.gpu_threshold, config.distributed_threshold),
    };

    let total_size: usize = arrays.iter().map(|arr| arr.len()).sum();

    if total_size >= distributed_threshold {
        return DeviceType::Distributed;
    }
    if total_size >= gpu_threshold {
        return DeviceType::GPU;
    }
    DeviceType::CPU
}
/// Execution target that an array can be placed on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
/// Host memory; `convert_to_device` wraps the array in an `NdarrayWrapper`.
CPU,
/// GPU; `convert_to_device` wraps the array in a `GPUNdarray`.
GPU,
/// Distributed; `convert_to_device` builds a `DistributedNdarray`.
Distributed,
}
/// Wraps `array` in the array-protocol implementation matching `device`.
///
/// * `DeviceType::CPU` produces an `NdarrayWrapper`.
/// * `DeviceType::GPU` produces a `GPUNdarray`, using the preferred backend and
///   mixed-precision flag from the current global configuration. The device id
///   (0), async ops (on) and memory fraction (0.9) are fixed here.
/// * `DeviceType::Distributed` produces a `DistributedNdarray` with a fixed
///   layout of two row-wise, balanced chunks on the threaded backend.
///
/// The array is taken by value and moved into the CPU and GPU wrappers, so no
/// extra copy of the data is made for those targets.
#[allow(dead_code)]
pub fn convert_to_device<T, D>(array: Array<T, D>, device: DeviceType) -> Box<dyn ArrayProtocol>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    match device {
        // `array` is already owned: move it instead of deep-copying it.
        DeviceType::CPU => Box::new(NdarrayWrapper::new(array)),
        DeviceType::GPU => {
            let config = get_auto_device_config();
            let gpu_config = GPUConfig {
                backend: config.preferred_gpu_backend,
                device_id: 0,
                async_ops: true,
                mixed_precision: config.enable_mixed_precision,
                memory_fraction: 0.9,
            };
            Box::new(GPUNdarray::new(array, gpu_config))
        }
        DeviceType::Distributed => {
            let dist_config = DistributedConfig {
                chunks: 2,
                balance: true,
                strategy: DistributionStrategy::RowWise,
                backend: DistributedBackend::Threaded,
            };
            // `from_array` only borrows the source array.
            Box::new(DistributedNdarray::from_array(&array, dist_config))
        }
    }
}
/// An array paired with a device choice and a lazily created device-specific
/// representation.
pub struct AutoDevice<T, D>
where
T: Clone
+ Send
+ Sync
+ 'static
+ num_traits::Zero
+ std::ops::Div<f64, Output = T>
+ Default
+ std::ops::Mul<Output = T>
+ std::ops::Add<Output = T>,
D: Dimension + crate::ndarray::RemoveAxis,
SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
/// The source array; device representations are built from clones of it.
array: Array<T, D>,
/// The device currently selected for this array.
device: DeviceType,
/// Cached device-specific representation; `None` until one has been built
/// (it is populated by `on_device`).
device_array: Option<Box<dyn ArrayProtocol>>,
}
impl<T, D> std::fmt::Debug for AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + std::fmt::Debug
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis + std::fmt::Debug + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// Formats the wrapper as a struct. The cached device array is a trait
    /// object, so only whether it is present is printed (as a `bool`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_device_array = self.device_array.is_some();
        let mut out = f.debug_struct("AutoDevice");
        out.field("array", &self.array);
        out.field("device", &self.device);
        out.field("device_array", &has_device_array);
        out.finish()
    }
}
impl<T, D> AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// Wraps `array`, choosing its device with `determine_best_device`.
    ///
    /// No device-specific representation is built yet; that happens on the
    /// first call to [`on_device`](Self::on_device).
    pub fn new(array: Array<T, D>) -> Self {
        Self {
            // Evaluated before `array` is moved into the struct below.
            device: determine_best_device(&array),
            array,
            device_array: None,
        }
    }

    /// Returns the representation of this array on `device`.
    ///
    /// The cached representation is reused when it already targets `device`.
    /// Otherwise the selected device is updated and a new representation is
    /// built from a clone of the source array, replacing the cache.
    pub fn on_device(&mut self, device: DeviceType) -> &dyn ArrayProtocol {
        let cache_is_valid = self.device == device && self.device_array.is_some();
        if !cache_is_valid {
            self.device = device;
            self.device_array = Some(convert_to_device(self.array.clone(), device));
        }
        // The cache is always populated at this point, so this cannot fail.
        self.device_array.as_deref().expect("Operation failed")
    }

    /// Returns the device currently selected for this array.
    pub fn device(&self) -> DeviceType {
        self.device
    }

    /// Returns a reference to the source array.
    pub const fn array(&self) -> &Array<T, D> {
        &self.array
    }
}
impl<T, D> Clone for AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// Clones the source array, the selected device and, if present, the cached
    /// device representation.
    fn clone(&self) -> Self {
        let Self {
            array,
            device,
            device_array,
        } = self;
        Self {
            array: array.clone(),
            device: *device,
            device_array: device_array.clone(),
        }
    }
}
impl<T, D> ArrayProtocol for AutoDevice<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
    /// Dispatches `func` to the device representation of this array.
    ///
    /// When a representation is cached it handles the call directly. Otherwise
    /// a temporary one is built for the device that `determine_best_device`
    /// picks at call time; it is not cached because only `&self` is available.
    fn array_function(
        &self,
        func: &ArrayFunction,
        types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match &self.device_array {
            Some(cached) => cached.array_function(func, types, args, kwargs),
            None => {
                let target = determine_best_device(&self.array);
                let transient = convert_to_device(self.array.clone(), target);
                transient.array_function(func, types, args, kwargs)
            }
        }
    }

    /// Exposes `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Shape of the source array.
    fn shape(&self) -> &[usize] {
        self.array.shape()
    }

    /// `TypeId` of the element type `T`.
    fn dtype(&self) -> TypeId {
        TypeId::of::<T>()
    }

    /// Clones this wrapper into a boxed trait object.
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}
/// Runs `executor` on all `arrays` after placing them on a common device.
///
/// The device is chosen by `determine_best_device_for_operation` from the
/// source arrays and the `operation` name. Each wrapper is then switched to
/// that device via `on_device`, and the resulting device representations are
/// passed to `executor` in the same order as `arrays`. The executor's result is
/// returned unchanged.
#[allow(dead_code)]
pub fn auto_execute<T, D, F, R>(
    arrays: &mut [&mut AutoDevice<T, D>],
    operation: &str,
    executor: F,
) -> CoreResult<R>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    D: Dimension + crate::ndarray::RemoveAxis + 'static,
    SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
    SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
    R: 'static,
{
    // The shared borrows of the source arrays end with this block, before the
    // wrappers are borrowed mutably below.
    let best_device = {
        let sources: Vec<_> = arrays.iter().map(|wrapper| &wrapper.array).collect();
        determine_best_device_for_operation(&sources, operation)
    };

    let mut device_arrays: Vec<&dyn ArrayProtocol> = Vec::with_capacity(arrays.len());
    for wrapper in arrays.iter_mut() {
        device_arrays.push(wrapper.on_device(best_device));
    }

    executor(&device_arrays)
}
pub mod ops {
use super::*;
use crate::array_protocol::operations as ap_ops;
use crate::error::{CoreError, ErrorContext};
pub fn matmul<T, D>(
a: &mut AutoDevice<T, D>,
b: &mut AutoDevice<T, D>,
) -> CoreResult<Box<dyn ArrayProtocol>>
where
T: Clone
+ Send
+ Sync
+ 'static
+ num_traits::Zero
+ std::ops::Div<f64, Output = T>
+ Default
+ std::ops::Mul<Output = T>
+ std::ops::Add<Output = T>,
D: Dimension + crate::ndarray::RemoveAxis + 'static,
SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
auto_execute(&mut [a, b], "matmul", |arrays| {
match ap_ops::matmul(arrays[0], arrays[1]) {
Ok(result) => Ok(result),
Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
e.to_string(),
))),
}
})
}
pub fn add<T, D>(
a: &mut AutoDevice<T, D>,
b: &mut AutoDevice<T, D>,
) -> CoreResult<Box<dyn ArrayProtocol>>
where
T: Clone
+ Send
+ Sync
+ 'static
+ num_traits::Zero
+ std::ops::Div<f64, Output = T>
+ Default
+ std::ops::Mul<Output = T>
+ std::ops::Add<Output = T>,
D: Dimension + crate::ndarray::RemoveAxis + 'static,
SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
auto_execute(&mut [a, b], "add", |arrays| {
match ap_ops::add(arrays[0], arrays[1]) {
Ok(result) => Ok(result),
Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
e.to_string(),
))),
}
})
}
pub fn multiply<T, D>(
a: &mut AutoDevice<T, D>,
b: &mut AutoDevice<T, D>,
) -> CoreResult<Box<dyn ArrayProtocol>>
where
T: Clone
+ Send
+ Sync
+ 'static
+ num_traits::Zero
+ std::ops::Div<f64, Output = T>
+ Default
+ std::ops::Mul<Output = T>
+ std::ops::Add<Output = T>,
D: Dimension + crate::ndarray::RemoveAxis + 'static,
SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>,
SliceInfo<[SliceInfoElem; 2], Dim<[usize; 2]>, Dim<[usize; 2]>>: SliceArg<D>,
{
auto_execute(&mut [a, b], "multiply", |arrays| {
match ap_ops::multiply(arrays[0], arrays[1]) {
Ok(result) => Ok(result),
Err(e) => Err(CoreError::NotImplementedError(ErrorContext::new(
e.to_string(),
))),
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, Array2};

    /// A 100-element array stays on the CPU with the current thresholds and
    /// moves to the GPU once the GPU threshold is lowered to 50.
    #[test]
    fn test_auto_device_selection() {
        crate::array_protocol::init();

        let small_array = Array2::<f64>::ones((10, 10));
        assert_eq!(determine_best_device(&small_array), DeviceType::CPU);

        // Lower only the GPU threshold, keeping every other setting as it is.
        let lowered = AutoDeviceConfig {
            gpu_threshold: 50,
            ..get_auto_device_config()
        };
        set_auto_device_config(lowered);
        assert_eq!(determine_best_device(&small_array), DeviceType::GPU);

        // Restore the defaults so the global state does not leak out of the test.
        set_auto_device_config(AutoDeviceConfig::default());
    }

    /// A tiny array starts on the CPU and can be moved to the GPU explicitly.
    #[test]
    fn test_auto_device_wrapper() {
        crate::array_protocol::init();

        let array = arr2(&[[1.0, 2.0], [3.0, 4.0]]).into_dyn();
        let mut auto_array = AutoDevice::new(array.clone());
        assert_eq!(auto_array.device(), DeviceType::CPU);

        let is_gpu_array = auto_array
            .on_device(DeviceType::GPU)
            .as_any()
            .downcast_ref::<GPUNdarray<f64, crate::ndarray::IxDyn>>()
            .is_some();
        assert!(is_gpu_array);
        assert_eq!(auto_array.device(), DeviceType::GPU);
    }
}