use super::{HardwareCapabilities, HardwareConfig, HardwareMetrics, HardwareResult, HardwareType};
use crate::tensor::Tensor;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single device managed through the hardware abstraction layer.
///
/// The trait is used as a trait object (`Box<dyn HardwareDevice>` is what
/// [`HardwareBackend`] hands out). Its `async fn`s are rewritten by `#[async_trait]`
/// into methods returning boxed futures. Implementors must be `Send + Sync`.
#[async_trait]
pub trait HardwareDevice: Send + Sync {
    // ----- identity --------------------------------------------------------

    /// Identifier of this device.
    fn device_id(&self) -> &str;

    /// Category of hardware this device belongs to.
    fn hardware_type(&self) -> HardwareType;

    /// Capabilities advertised by this device.
    fn capabilities(&self) -> &HardwareCapabilities;

    // ----- lifecycle -------------------------------------------------------

    /// Brings the device up using `config`.
    async fn initialize(&mut self, config: &HardwareConfig) -> HardwareResult<()>;

    /// Shuts the device down.
    async fn shutdown(&mut self) -> HardwareResult<()>;

    /// Resets the device.
    async fn reset(&mut self) -> HardwareResult<()>;

    // ----- health & telemetry ----------------------------------------------

    /// Synchronous availability check.
    fn is_available(&self) -> bool;

    /// Synchronous snapshot of the device's current [`DeviceStatus`].
    fn status(&self) -> DeviceStatus;

    /// Collects metrics from the device.
    async fn metrics(&self) -> HardwareResult<HardwareMetrics>;

    // ----- memory & synchronisation ----------------------------------------

    /// Allocates `size` units of device memory (presumably bytes — confirm) and
    /// returns a descriptor for it.
    async fn allocate_memory(&mut self, size: usize) -> HardwareResult<DeviceMemory>;

    /// Releases an allocation. The descriptor is taken by value.
    async fn free_memory(&mut self, memory: DeviceMemory) -> HardwareResult<()>;

    /// Synchronizes with the device.
    ///
    /// NOTE(review): presumably waits for outstanding device work — confirm per backend.
    async fn synchronize(&self) -> HardwareResult<()>;
}
/// A driver/runtime that discovers and constructs [`HardwareDevice`]s.
///
/// Devices are returned as `Box<dyn HardwareDevice>`, so one backend can yield
/// heterogeneous device implementations.
#[async_trait]
pub trait HardwareBackend: Send + Sync {
    // ----- description -----------------------------------------------------

    /// Name of this backend.
    fn name(&self) -> &str;

    /// Version string of this backend.
    fn version(&self) -> &str;

    /// Returns `true` if this backend can drive hardware of `hardware_type`.
    fn is_compatible(&self, hardware_type: HardwareType) -> bool;

    /// Names of the operations this backend supports.
    fn supported_operations(&self) -> &[String];

    /// Validates `config` for this backend.
    ///
    /// # Errors
    /// An `Err` signals that `config` failed validation.
    fn validate_config(&self, config: &HardwareConfig) -> HardwareResult<()>;

    // ----- device construction ---------------------------------------------

    /// Enumerates the devices visible to this backend.
    async fn discover_devices(&self) -> HardwareResult<Vec<Box<dyn HardwareDevice>>>;

    /// Builds a device from `config`.
    async fn create_device(
        &self,
        config: &HardwareConfig,
    ) -> HardwareResult<Box<dyn HardwareDevice>>;
}
/// A named computation that runs on a [`HardwareDevice`].
///
/// Parameters are passed as a string-keyed map of [`OperationParameter`] values.
#[async_trait]
pub trait HardwareOperation: Send + Sync {
    /// Name of this operation.
    fn name(&self) -> &str;

    /// Resources and capabilities this operation needs.
    fn requirements(&self) -> OperationRequirements;

    /// Validates `params` for this operation.
    ///
    /// # Errors
    /// An `Err` signals that `params` failed validation.
    fn validate_params(&self, params: &HashMap<String, OperationParameter>) -> HardwareResult<()>;

    /// Estimated cost of running on `inputs` with `params`.
    ///
    /// The unit/scale of the returned value is not fixed by this trait.
    fn estimate_cost(
        &self,
        inputs: &[Tensor],
        params: &HashMap<String, OperationParameter>,
    ) -> f64;

    /// Runs the operation on `device`, reading `inputs` and writing into the
    /// caller-provided `outputs` slice.
    ///
    /// NOTE(review): this trait does not say whether `execute` calls
    /// `validate_params` itself — confirm with implementors.
    async fn execute(
        &self,
        device: &mut dyn HardwareDevice,
        inputs: &[Tensor],
        outputs: &mut [Tensor],
        params: &HashMap<String, OperationParameter>,
    ) -> HardwareResult<()>;
}
/// Point-in-time snapshot of a device's state, as returned by
/// `HardwareDevice::status`.
///
/// Serializable via serde; `PartialEq` only (not `Eq`) because of the `f64` fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DeviceStatus {
/// Online flag reported for the device.
pub online: bool,
/// Busy flag reported for the device.
pub busy: bool,
/// Error description, if one is being reported.
pub error: Option<String>,
/// Memory accounting for the device.
pub memory_usage: MemoryUsage,
/// Temperature reading, if available.
/// NOTE(review): units are not specified here — presumably °C; confirm.
pub temperature: Option<f64>,
/// Power reading, if available.
/// NOTE(review): units are not specified here — presumably watts; confirm.
pub power_consumption: Option<f64>,
/// Utilization of the device. The range is not enforced by this type.
/// NOTE(review): presumably a 0.0–1.0 fraction (vs. a percentage) — confirm.
pub utilization: f64,
}
/// Memory accounting for a device, embedded in [`DeviceStatus`].
///
/// NOTE(review): the three counters are independent fields. Nothing here enforces
/// `free == total - used`, so producers must keep them consistent. Units are
/// presumably bytes — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MemoryUsage {
/// Total memory.
pub total: usize,
/// Memory currently in use.
pub used: usize,
/// Memory currently free.
pub free: usize,
/// Fragmentation metric.
/// NOTE(review): scale not defined here — presumably a 0.0–1.0 ratio; confirm.
pub fragmentation: f64,
}
/// Descriptor for a block returned by `HardwareDevice::allocate_memory` and
/// released by passing it (by value) to `HardwareDevice::free_memory`.
///
/// NOTE(review): this type derives `Clone`, so taking it by value in `free_memory`
/// does not stop the same allocation from being freed twice through a copy.
/// Backends should guard against double frees.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DeviceMemory {
/// Backend-defined address/handle of the allocation.
/// NOTE(review): presumably a device-side address rather than a host pointer —
/// do not dereference; confirm with backends.
pub address: usize,
/// Size of the allocation (presumably bytes, matching the `size` passed to
/// `allocate_memory` — confirm).
pub size: usize,
/// Kind of memory backing this allocation.
pub memory_type: MemoryType,
/// Identifier of the owning device (presumably the value of
/// `HardwareDevice::device_id` — confirm).
pub device_id: String,
}
/// Kind of memory backing a [`DeviceMemory`] allocation.
///
/// `Copy + Eq + Hash`, so it can be used directly as a map/set key.
/// The exact meaning of each variant is backend-defined. The per-variant notes
/// below are reviewer assumptions to confirm against the backends.
///
/// NOTE(review): do not reorder variants casually. serde's derived impls pass the
/// variant index to the serializer, and compact (non-self-describing) formats
/// encode that index rather than the name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MemoryType {
/// Presumably memory local to the device (assumption).
Local,
/// Presumably host (CPU) memory (assumption).
Host,
/// Presumably memory shared across devices or host/device (assumption).
Shared,
/// Presumably one address space visible to host and device (assumption).
Unified,
/// Presumably memory that outlives a single session (assumption).
Persistent,
/// Presumably a cache tier (assumption).
Cache,
}
/// Dynamically typed parameter value passed to operations, always as a
/// `HashMap<String, OperationParameter>`.
///
/// The `Array` and `Object` variants make it recursive, so nested JSON-like values
/// can be expressed. It derives `PartialEq` but not `Eq` because of `Float(f64)`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OperationParameter {
/// Signed 64-bit integer.
Integer(i64),
/// 64-bit floating-point number.
Float(f64),
/// Owned string.
String(String),
/// Boolean flag.
Boolean(bool),
/// Ordered list of nested parameters.
Array(Vec<OperationParameter>),
/// String-keyed map of nested parameters.
Object(HashMap<String, OperationParameter>),
}
/// Resource and capability requirements reported by
/// `HardwareOperation::requirements`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OperationRequirements {
/// Minimum memory needed (presumably bytes — confirm).
pub min_memory: usize,
/// Number of compute units needed, if the operation specifies one.
pub compute_units: Option<u32>,
/// Data types the operation works with.
pub data_types: Vec<super::DataType>,
/// Named capabilities the device must provide.
/// NOTE(review): presumably matched against `HardwareCapabilities` — confirm.
pub capabilities: Vec<String>,
/// Optional performance constraints.
pub performance: PerformanceRequirements,
}
/// Optional performance constraints for an operation.
///
/// Every field is an `Option`, and the derived `Default` leaves all of them `None`.
/// NOTE(review): units are not specified in this file — confirm before relying on
/// them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct PerformanceRequirements {
/// Upper bound on latency (units unspecified — presumably ms or s; confirm).
pub max_latency: Option<f64>,
/// Lower bound on throughput (units unspecified — confirm).
pub min_throughput: Option<f64>,
/// Memory bandwidth requirement (units unspecified — confirm).
pub memory_bandwidth: Option<f64>,
/// Power ceiling (units unspecified — presumably watts; confirm).
pub power_limit: Option<f64>,
}
/// An operation driven through a handle.
///
/// The operation is started once with [`start`](Self::start). The returned
/// [`AsyncOperationHandle`] then identifies the in-flight work for the query,
/// result and cancellation methods.
#[async_trait]
pub trait AsyncHardwareOperation: Send + Sync {
    /// Launches the operation on `device` with `inputs` and `params` and returns
    /// a handle to it.
    async fn start(
        &self,
        device: &mut dyn HardwareDevice,
        inputs: &[Tensor],
        params: &HashMap<String, OperationParameter>,
    ) -> HardwareResult<AsyncOperationHandle>;

    /// Current lifecycle state of the operation behind `handle`.
    async fn status(&self, handle: &AsyncOperationHandle) -> HardwareResult<AsyncOperationStatus>;

    /// Output tensors of the operation behind `handle`.
    ///
    /// NOTE(review): presumably only meaningful once `status` reports `Completed`
    /// — confirm with implementors.
    async fn results(&self, handle: &AsyncOperationHandle) -> HardwareResult<Vec<Tensor>>;

    /// Requests cancellation of the operation behind `handle`.
    async fn cancel(&self, handle: &AsyncOperationHandle) -> HardwareResult<()>;
}
/// Handle returned by `AsyncHardwareOperation::start`. It is passed back by
/// reference to `status`, `results` and `cancel`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AsyncOperationHandle {
/// Identifier of the in-flight operation.
pub id: String,
/// Identifier of the device the operation was started on (presumably the value
/// of `HardwareDevice::device_id` — confirm).
pub device_id: String,
/// Name of the operation (presumably its `name()` — confirm).
pub operation_name: String,
/// Wall-clock start time. `SystemTime` is not monotonic, so prefer it for
/// logging/display rather than for measuring elapsed durations.
pub start_time: std::time::SystemTime,
}
/// Lifecycle state reported by `AsyncHardwareOperation::status`.
///
/// NOTE(review): `Completed`, `Failed` and `Cancelled` are presumably terminal
/// states — confirm with implementors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AsyncOperationStatus {
/// Accepted but not yet running.
Queued,
/// Currently executing.
Running,
/// Finished successfully.
Completed,
/// Finished with an error; carries the error message.
Failed(String),
/// Cancelled before completion.
Cancelled,
}
/// Schedules [`HardwareOperation`]s and exposes aggregate statistics.
///
/// Unlike the device/operation traits, this trait is synchronous. It also requires
/// `Debug`, so schedulers can be embedded in `#[derive(Debug)]` types.
pub trait HardwareScheduler: Send + Sync + std::fmt::Debug {
    /// Schedules `operation` to run with `inputs` and `params`.
    ///
    /// NOTE(review): the returned `String` is presumably an identifier for the
    /// scheduled work (e.g. a device or ticket id) — confirm with implementors.
    fn schedule_operation(
        &self,
        operation: &dyn HardwareOperation,
        inputs: &[Tensor],
        params: &HashMap<String, OperationParameter>,
    ) -> HardwareResult<String>;

    /// Snapshot of the scheduler's counters.
    fn statistics(&self) -> SchedulerStatistics;

    /// Updates scheduling priorities.
    ///
    /// NOTE(review): this trait does not specify whether `priorities` replaces or
    /// merges with existing entries, or what the keys identify — confirm.
    fn update_priorities(&mut self, priorities: HashMap<String, f64>);
}
/// Aggregate counters returned by `HardwareScheduler::statistics`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SchedulerStatistics {
/// Total number of operations scheduled.
pub total_operations: u64,
/// Operation count per device (keys presumably device ids — confirm).
pub operations_per_device: HashMap<String, u64>,
/// Average time spent making a scheduling decision (units unspecified — confirm).
pub avg_scheduling_time: f64,
/// Utilization per device (keys presumably device ids; range unspecified — confirm).
pub device_utilization: HashMap<String, f64>,
/// Number of operations that failed.
pub failed_operations: u64,
}
impl Default for DeviceStatus {
fn default() -> Self {
Self {
online: false,
busy: false,
error: None,
memory_usage: MemoryUsage::default(),
temperature: None,
power_consumption: None,
utilization: 0.0,
}
}
}
impl Default for MemoryUsage {
fn default() -> Self {
Self {
total: 0,
used: 0,
free: 0,
fragmentation: 0.0,
}
}
}
impl Default for OperationRequirements {
fn default() -> Self {
Self {
min_memory: 0,
compute_units: None,
data_types: vec![super::DataType::F32],
capabilities: vec![],
performance: PerformanceRequirements::default(),
}
}
}
impl Default for SchedulerStatistics {
fn default() -> Self {
Self {
total_operations: 0,
operations_per_device: HashMap::new(),
avg_scheduling_time: 0.0,
device_utilization: HashMap::new(),
failed_operations: 0,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// `DeviceStatus::default()` is an offline, idle device with no readings.
    #[test]
    fn test_device_status_default() {
        let status = DeviceStatus::default();
        assert!(!status.online);
        assert!(!status.busy);
        assert!(status.error.is_none());
        assert_eq!(status.memory_usage, MemoryUsage::default());
        assert!(status.temperature.is_none());
        assert!(status.power_consumption.is_none());
        assert_eq!(status.utilization, 0.0);
    }

    /// Every counter in `MemoryUsage::default()` starts at zero.
    #[test]
    fn test_memory_usage_default() {
        let usage = MemoryUsage::default();
        assert_eq!(usage.total, 0);
        assert_eq!(usage.used, 0);
        assert_eq!(usage.free, 0);
        assert_eq!(usage.fragmentation, 0.0);
    }

    #[test]
    fn test_memory_usage_calculation() {
        // Build with struct-update syntax instead of mutating a default value;
        // the latter trips `clippy::field_reassign_with_default`.
        let total = 1000;
        let used = 600;
        let usage = MemoryUsage {
            total,
            used,
            free: total - used,
            ..MemoryUsage::default()
        };
        assert_eq!(usage.free, 400);
        // Fields not set explicitly keep their default value.
        assert_eq!(usage.fragmentation, 0.0);
    }

    #[test]
    fn test_operation_parameter_types() {
        let int_param = OperationParameter::Integer(42);
        let float_param = OperationParameter::Float(std::f64::consts::PI);
        let string_param = OperationParameter::String("test".to_string());
        let bool_param = OperationParameter::Boolean(true);

        assert_eq!(int_param, OperationParameter::Integer(42));
        assert_eq!(float_param, OperationParameter::Float(std::f64::consts::PI));
        assert_eq!(string_param, OperationParameter::String("test".to_string()));
        assert_eq!(bool_param, OperationParameter::Boolean(true));
        // The same payload under a different variant must not compare equal.
        assert_ne!(int_param, OperationParameter::Float(42.0));
    }

    /// `Array` and `Object` nest arbitrary parameters and compare structurally.
    #[test]
    fn test_operation_parameter_nesting() {
        let mut fields = HashMap::new();
        fields.insert("flag".to_string(), OperationParameter::Boolean(false));
        let nested = OperationParameter::Array(vec![
            OperationParameter::Integer(1),
            OperationParameter::Object(fields.clone()),
        ]);
        match &nested {
            OperationParameter::Array(items) => {
                assert_eq!(items.len(), 2);
                assert_eq!(items[1], OperationParameter::Object(fields));
            }
            other => panic!("Expected Array parameter but got {:?}", other),
        }
    }

    #[test]
    fn test_async_operation_status() {
        let status = AsyncOperationStatus::Queued;
        assert_eq!(status, AsyncOperationStatus::Queued);
        assert_ne!(status, AsyncOperationStatus::Running);

        let failed_status = AsyncOperationStatus::Failed("test error".to_string());
        match &failed_status {
            AsyncOperationStatus::Failed(msg) => assert_eq!(msg, "test error"),
            other => panic!("Expected Failed status but got {:?}", other),
        }
        // The failure message participates in equality.
        assert_ne!(
            failed_status,
            AsyncOperationStatus::Failed("other error".to_string())
        );
    }

    #[test]
    fn test_memory_type_equality() {
        assert_eq!(MemoryType::Local, MemoryType::Local);
        assert_ne!(MemoryType::Local, MemoryType::Host);
        assert_eq!(MemoryType::Shared, MemoryType::Shared);
    }

    /// `MemoryType` derives `Copy` and `Hash`, so it works as a set/map key.
    #[test]
    fn test_memory_type_is_hashable_and_copy() {
        let kind = MemoryType::Unified;
        let copied = kind; // `Copy`: `kind` remains usable below.
        assert_eq!(kind, copied);

        let kinds: HashSet<MemoryType> = [MemoryType::Local, MemoryType::Host, MemoryType::Local]
            .iter()
            .copied()
            .collect();
        assert_eq!(kinds.len(), 2);
        assert!(kinds.contains(&MemoryType::Host));
    }
}