use crate::{
config::Config,
gpu::{GpuConfiguration, GpuManager, GpuStatus, GpuVendor},
};
use anyhow::Result;
use clap::{Args, Subcommand, ValueEnum};
use serde_json;
use std::collections::HashMap;
use tracing::info;
#[derive(Args)]
pub struct GpuArgs {
#[command(subcommand)]
pub command: GpuCommand,
}
#[derive(Subcommand)]
pub enum GpuCommand {
#[command(about = "List available GPUs")]
List {
#[arg(long, help = "Show detailed information")]
detailed: bool,
#[arg(long, help = "Filter by vendor")]
vendor: Option<VendorArg>,
#[arg(long, help = "Filter by status")]
status: Option<StatusArg>,
#[arg(long, help = "Output format", default_value = "table")]
format: OutputFormat,
},
#[command(about = "Show detailed GPU information")]
Info {
#[arg(help = "GPU ID")]
gpu_id: u32,
#[arg(long, help = "Include real-time metrics")]
metrics: bool,
#[arg(long, help = "Show compute capabilities")]
capabilities: bool,
},
#[command(about = "Monitor GPU usage in real-time")]
Monitor {
#[arg(help = "GPU ID (optional - monitors all if not specified)")]
gpu_id: Option<u32>,
#[arg(long, help = "Refresh interval in seconds", default_value = "2")]
interval: u64,
#[arg(long, help = "Show historical data")]
history: bool,
},
#[command(about = "Test GPU functionality")]
Test {
#[arg(help = "GPU ID (optional - tests all if not specified)")]
gpu_id: Option<u32>,
#[arg(long, help = "Test type", default_value = "basic")]
test_type: TestType,
#[arg(long, help = "Duration in seconds", default_value = "10")]
duration: u64,
},
#[command(about = "Benchmark GPU performance")]
Benchmark {
#[arg(help = "GPU ID")]
gpu_id: u32,
#[arg(long, help = "Benchmark type", default_value = "compute")]
bench_type: BenchmarkType,
#[arg(long, help = "Iterations", default_value = "100")]
iterations: u32,
#[arg(long, help = "Memory size in MB", default_value = "1024")]
memory_size: u64,
},
#[command(about = "Show GPU allocation status")]
Allocations {
#[arg(help = "GPU ID (optional)")]
gpu_id: Option<u32>,
#[arg(long, help = "Show allocation history")]
history: bool,
#[arg(long, help = "Output format", default_value = "table")]
format: OutputFormat,
},
#[command(about = "Allocate GPU memory for a model")]
Allocate {
#[arg(help = "Memory size in MB")]
memory_mb: u64,
#[arg(help = "Model name")]
model_name: String,
#[arg(long, help = "Preferred vendor")]
vendor: Option<VendorArg>,
#[arg(long, help = "Specific GPU ID")]
gpu_id: Option<u32>,
},
#[command(about = "Deallocate GPU memory")]
Deallocate {
#[arg(help = "GPU ID")]
gpu_id: u32,
#[arg(help = "Model name")]
model_name: String,
},
#[command(about = "Check GPU health status")]
Health {
#[arg(help = "GPU ID (optional - checks all if not specified)")]
gpu_id: Option<u32>,
#[arg(long, help = "Output format", default_value = "table")]
format: OutputFormat,
#[arg(long, help = "Show detailed diagnostics")]
detailed: bool,
},
#[command(about = "Configure GPU settings")]
Configure {
#[arg(long, help = "Enable GPU support")]
enable: Option<bool>,
#[arg(long, help = "Memory limit in MB")]
memory_limit: Option<u64>,
#[arg(long, help = "Max utilization percentage")]
max_utilization: Option<f32>,
#[arg(long, help = "Temperature limit in Celsius")]
temp_limit: Option<f32>,
#[arg(long, help = "Enable CPU fallback")]
fallback: Option<bool>,
#[arg(long, help = "Monitoring interval in seconds")]
monitor_interval: Option<u64>,
#[arg(long, help = "Show current configuration")]
show: bool,
},
#[command(about = "Refresh GPU detection")]
Refresh {
#[arg(long, help = "Force re-detection")]
force: bool,
},
#[command(about = "Show GPU metrics history")]
Metrics {
#[arg(help = "GPU ID (optional)")]
gpu_id: Option<u32>,
#[arg(long, help = "Time range in minutes", default_value = "60")]
range: u64,
#[arg(long, help = "Metric type")]
metric: Option<MetricType>,
#[arg(long, help = "Output format", default_value = "table")]
format: OutputFormat,
},
#[command(about = "Export GPU data")]
Export {
#[arg(help = "Output file path")]
output: std::path::PathBuf,
#[arg(long, help = "Export type", default_value = "all")]
export_type: ExportType,
#[arg(long, help = "Include metrics history")]
include_metrics: bool,
#[arg(long, help = "Output format", default_value = "json")]
format: OutputFormat,
},
#[command(about = "Set GPU power management")]
Power {
#[arg(help = "GPU ID")]
gpu_id: u32,
#[arg(help = "Power state", default_value = "auto")]
state: PowerState,
#[arg(long, help = "Power limit in watts")]
limit: Option<f32>,
},
#[command(about = "Reset GPU to default state")]
Reset {
#[arg(help = "GPU ID")]
gpu_id: u32,
#[arg(long, help = "Force reset without confirmation")]
force: bool,
},
}
#[derive(Debug, Clone, ValueEnum)]
pub enum VendorArg {
Nvidia,
Amd,
Intel,
Apple,
}
impl From<VendorArg> for GpuVendor {
fn from(arg: VendorArg) -> Self {
match arg {
VendorArg::Nvidia => GpuVendor::Nvidia,
VendorArg::Amd => GpuVendor::Amd,
VendorArg::Intel => GpuVendor::Intel,
VendorArg::Apple => GpuVendor::Apple,
}
}
}
#[derive(Debug, Clone, ValueEnum)]
pub enum StatusArg {
Available,
InUse,
Error,
Overheated,
LowMemory,
Disabled,
}
#[derive(Debug, Clone, ValueEnum)]
pub enum OutputFormat {
Table,
Json,
Csv,
Yaml,
}
#[derive(Debug, Clone, ValueEnum)]
pub enum TestType {
Basic,
Compute,
Memory,
Full,
}
#[derive(Debug, Clone, ValueEnum)]
pub enum BenchmarkType {
Compute,
Memory,
MatrixMul,
Bandwidth,
}
#[derive(Debug, Clone, ValueEnum)]
pub enum MetricType {
Utilization,
Memory,
Temperature,
Power,
All,
}
#[derive(Debug, Clone, ValueEnum)]
pub enum ExportType {
Info,
Metrics,
Allocations,
All,
}
#[derive(Debug, Clone, ValueEnum)]
pub enum PowerState {
Auto,
Performance,
Balanced,
PowerSaver,
}
impl From<PowerState> for crate::gpu::GpuPowerState {
fn from(state: PowerState) -> Self {
match state {
PowerState::Auto => crate::gpu::GpuPowerState::Balanced,
PowerState::Performance => crate::gpu::GpuPowerState::Performance,
PowerState::Balanced => crate::gpu::GpuPowerState::Balanced,
PowerState::PowerSaver => crate::gpu::GpuPowerState::PowerSaver,
}
}
}
pub async fn execute(args: GpuArgs, _config: &Config) -> Result<()> {
let gpu_config = GpuConfiguration::default();
let mut manager = GpuManager::new(gpu_config);
manager.initialize().await?;
match args.command {
GpuCommand::List {
detailed,
vendor,
status,
format,
} => {
let available_gpus = manager.get_available_gpus().await;
let filtered_gpus: Vec<_> = available_gpus
.into_iter()
.filter(|gpu| {
let vendor_match = vendor
.as_ref()
.map(|v| {
std::mem::discriminant(&gpu.vendor)
== std::mem::discriminant(&GpuVendor::from(v.clone()))
})
.unwrap_or(true);
let status_match = status
.as_ref()
.map(|s| match s {
StatusArg::Available => matches!(gpu.status, GpuStatus::Available),
StatusArg::InUse => matches!(gpu.status, GpuStatus::InUse),
StatusArg::Error => matches!(gpu.status, GpuStatus::Error(_)),
StatusArg::Overheated => matches!(gpu.status, GpuStatus::Overheated),
StatusArg::LowMemory => matches!(gpu.status, GpuStatus::LowMemory),
StatusArg::Disabled => matches!(gpu.status, GpuStatus::Disabled),
})
.unwrap_or(true);
vendor_match && status_match
})
.collect();
if filtered_gpus.is_empty() {
println!("No GPUs found matching the criteria");
} else {
display_gpu_list(&filtered_gpus, detailed, format);
}
}
GpuCommand::Info {
gpu_id,
metrics,
capabilities,
} => {
if let Some(gpu) = manager.get_gpu_info(gpu_id).await {
display_gpu_info(&gpu, metrics, capabilities);
if metrics {
let gpu_metrics = manager.get_gpu_metrics(Some(gpu_id)).await;
if !gpu_metrics.is_empty() {
println!("\nRecent Metrics:");
display_metrics(&gpu_metrics, OutputFormat::Table);
}
}
} else {
println!("GPU {} not found", gpu_id);
}
}
GpuCommand::Monitor {
gpu_id,
interval,
history,
} => {
println!("Monitoring GPUs (press Ctrl+C to stop)...\n");
loop {
print!("\x1B[2J\x1B[1;1H");
if let Some(id) = gpu_id {
if let Some(gpu) = manager.get_gpu_info(id).await {
display_gpu_status(&gpu);
}
} else {
let gpus = manager.get_available_gpus().await;
for gpu in gpus {
display_gpu_status(&gpu);
println!();
}
}
if history {
let metrics = manager.get_gpu_metrics(gpu_id).await;
let recent_metrics: Vec<_> = metrics.into_iter().take(5).collect();
if !recent_metrics.is_empty() {
println!("Recent Metrics:");
display_metrics(&recent_metrics, OutputFormat::Table);
}
}
println!("Last updated: {}", chrono::Local::now().format("%H:%M:%S"));
tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
}
}
GpuCommand::Test {
gpu_id,
test_type,
duration,
} => {
if let Some(id) = gpu_id {
println!(
"Testing GPU {} with {:?} test for {} seconds...",
id, test_type, duration
);
run_gpu_test(id, test_type, duration).await?;
} else {
let gpus = manager.get_available_gpus().await;
for gpu in gpus {
println!("Testing GPU {} with {:?} test...", gpu.id, test_type);
run_gpu_test(gpu.id, test_type.clone(), duration).await?;
}
}
}
GpuCommand::Benchmark {
gpu_id,
bench_type,
iterations,
memory_size,
} => {
println!(
"Benchmarking GPU {} with {:?} benchmark...",
gpu_id, bench_type
);
run_gpu_benchmark(gpu_id, bench_type, iterations, memory_size).await?;
}
GpuCommand::Allocations {
gpu_id,
history: _,
format,
} => {
let allocations = manager.get_gpu_allocations().await;
let filtered_allocations: HashMap<_, _> = if let Some(id) = gpu_id {
allocations
.into_iter()
.filter(|(gpu_id, _)| *gpu_id == id)
.collect()
} else {
allocations
};
if filtered_allocations.is_empty() {
println!("No GPU allocations found");
} else {
display_allocations(&filtered_allocations, format);
}
}
GpuCommand::Allocate {
memory_mb,
model_name,
vendor,
gpu_id,
} => {
if memory_mb == 0 {
anyhow::bail!("Memory allocation must be greater than 0 MB");
}
if memory_mb > 100_000 {
anyhow::bail!("Memory allocation cannot exceed 100,000 MB (100 GB)");
}
if model_name.is_empty() {
anyhow::bail!("Model name cannot be empty");
}
info!("Allocating {} MB for model '{}'", memory_mb, model_name);
let preferred_vendor = vendor.map(GpuVendor::from);
if let Some(id) = gpu_id {
println!(
"Allocating {}MB on GPU {} for model '{}'...",
memory_mb, id, model_name
);
if manager
.allocate_specific_gpu(id, memory_mb, model_name)
.await?
{
println!("Successfully allocated GPU {}", id);
} else {
anyhow::bail!(
"Failed to allocate GPU {} (insufficient memory or unavailable)",
id
);
}
} else {
println!("Allocating {}MB for model '{}'...", memory_mb, model_name);
if let Some(allocated_gpu) = manager
.allocate_gpu(memory_mb, model_name, preferred_vendor)
.await?
{
println!("Allocated GPU {} successfully", allocated_gpu);
} else {
anyhow::bail!("No suitable GPU found for allocation");
}
}
}
GpuCommand::Deallocate { gpu_id, model_name } => {
println!("Deallocating GPU {} for model '{}'...", gpu_id, model_name);
manager.deallocate_gpu(gpu_id, &model_name).await?;
println!("GPU deallocated successfully");
}
GpuCommand::Health {
gpu_id,
format,
detailed,
} => {
let health_status = manager.check_gpu_health().await?;
let filtered_status: HashMap<_, _> = if let Some(id) = gpu_id {
health_status
.into_iter()
.filter(|(gpu_id, _)| *gpu_id == id)
.collect()
} else {
health_status
};
display_health_status(&filtered_status, format, detailed);
}
GpuCommand::Configure {
enable,
memory_limit,
max_utilization,
temp_limit,
fallback,
monitor_interval,
show,
} => {
if show {
let config = manager.get_configuration();
println!("Current GPU Configuration:");
println!("Enabled: {}", config.enabled);
println!("Memory Limit: {:?} MB", config.memory_limit_mb);
println!("Max Utilization: {}%", config.max_utilization_percent);
println!("Temperature Limit: {}°C", config.temperature_limit_celsius);
println!("CPU Fallback: {}", config.fallback_to_cpu);
println!(
"Monitoring Interval: {}s",
config.monitoring_interval_seconds
);
} else {
let mut new_config = manager.get_configuration().clone();
if let Some(enabled) = enable {
new_config.enabled = enabled;
}
if let Some(limit) = memory_limit {
new_config.memory_limit_mb = Some(limit);
}
if let Some(util) = max_utilization {
new_config.max_utilization_percent = util;
}
if let Some(temp) = temp_limit {
new_config.temperature_limit_celsius = temp;
}
if let Some(fb) = fallback {
new_config.fallback_to_cpu = fb;
}
if let Some(interval) = monitor_interval {
new_config.monitoring_interval_seconds = interval;
}
manager.update_configuration(new_config).await?;
println!("GPU configuration updated successfully");
}
}
GpuCommand::Refresh { force: _ } => {
println!("Refreshing GPU detection...");
manager.refresh_gpu_info().await?;
println!("GPU information refreshed successfully");
let gpus = manager.get_available_gpus().await;
println!("Detected {} GPUs", gpus.len());
}
GpuCommand::Metrics {
gpu_id,
range: _,
metric: _,
format,
} => {
let metrics = manager.get_gpu_metrics(gpu_id).await;
let filtered_metrics = metrics;
if filtered_metrics.is_empty() {
println!("No metrics found");
} else {
display_metrics(&filtered_metrics, format);
}
}
GpuCommand::Export {
output,
export_type,
include_metrics,
format,
} => {
println!("Exporting GPU data to {:?}...", output);
let export_data = match export_type {
ExportType::Info => {
let gpus = manager.get_available_gpus().await;
serde_json::to_value(&gpus)?
}
ExportType::Metrics => {
let metrics = manager.get_gpu_metrics(None).await;
serde_json::to_value(&metrics)?
}
ExportType::Allocations => {
let allocations = manager.get_gpu_allocations().await;
serde_json::to_value(&allocations)?
}
ExportType::All => {
let mut data = serde_json::Map::new();
data.insert(
"gpus".to_string(),
serde_json::to_value(&manager.get_available_gpus().await)?,
);
data.insert(
"allocations".to_string(),
serde_json::to_value(&manager.get_gpu_allocations().await)?,
);
if include_metrics {
data.insert(
"metrics".to_string(),
serde_json::to_value(&manager.get_gpu_metrics(None).await)?,
);
}
serde_json::Value::Object(data)
}
};
let content = match format {
OutputFormat::Json => serde_json::to_string_pretty(&export_data)?,
OutputFormat::Yaml => serde_yaml::to_string(&export_data)?,
_ => {
return Err(anyhow::anyhow!(
"Format {:?} not supported for export",
format
));
}
};
tokio::fs::write(&output, content).await?;
println!("Export completed successfully");
}
GpuCommand::Power {
gpu_id,
state,
limit: _,
} => {
println!(
"Setting power management for GPU {} to {:?}...",
gpu_id, state
);
manager.set_gpu_power_state(gpu_id, state.into()).await?;
println!("Successfully updated power state for GPU {}", gpu_id);
}
GpuCommand::Reset { gpu_id, force } => {
if !force {
print!("Are you sure you want to reset GPU {}? [y/N]: ", gpu_id);
let mut input = String::new();
std::io::stdin().read_line(&mut input)?;
if !input.trim().to_lowercase().starts_with('y') {
println!("Reset cancelled.");
return Ok(());
}
}
println!("Resetting GPU {}...", gpu_id);
manager.reset_gpu(gpu_id).await?;
println!("Successfully reset GPU {}", gpu_id);
}
}
Ok(())
}
fn display_gpu_list(gpus: &[crate::gpu::GpuInfo], detailed: bool, format: OutputFormat) {
match format {
OutputFormat::Table => {
if detailed {
for gpu in gpus {
println!("GPU {}: {}", gpu.id, gpu.name);
println!(" Vendor: {:?}", gpu.vendor);
println!(" Driver: {}", gpu.driver_version);
println!(
" Memory: {} MB total, {} MB free",
gpu.memory_total_mb, gpu.memory_free_mb
);
println!(" Utilization: {:.1}%", gpu.utilization_percent);
if let Some(temp) = gpu.temperature_celsius {
println!(" Temperature: {:.1}°C", temp);
}
println!(" Status: {:?}", gpu.status);
println!();
}
} else {
println!(
"{:<4} {:<20} {:<12} {:<12} {:<10} {:<12}",
"ID", "Name", "Vendor", "Memory", "Util%", "Status"
);
println!("{:-<80}", "");
for gpu in gpus {
println!(
"{:<4} {:<20} {:<12} {:<12} {:<10} {:<12}",
gpu.id,
&gpu.name[..gpu.name.len().min(20)],
format!("{:?}", gpu.vendor),
format!("{}MB", gpu.memory_total_mb),
format!("{:.1}%", gpu.utilization_percent),
format!("{:?}", gpu.status)
);
}
}
}
OutputFormat::Json => {
println!("{}", serde_json::to_string_pretty(gpus).unwrap());
}
_ => {
println!("Format {:?} not yet implemented", format);
}
}
}
fn display_gpu_info(gpu: &crate::gpu::GpuInfo, _metrics: bool, _capabilities: bool) {
println!("GPU {} Information:", gpu.id);
println!("{:-<40}", "");
println!("Name: {}", gpu.name);
println!("Vendor: {:?}", gpu.vendor);
println!("Architecture: {}", gpu.architecture);
println!("Driver Version: {}", gpu.driver_version);
if let Some(ref cuda_version) = gpu.cuda_version {
println!("CUDA Version: {}", cuda_version);
}
println!(
"Memory: {} MB total, {} MB free, {} MB used",
gpu.memory_total_mb, gpu.memory_free_mb, gpu.memory_used_mb
);
println!("Utilization: {:.1}%", gpu.utilization_percent);
if let Some(temp) = gpu.temperature_celsius {
println!("Temperature: {:.1}°C", temp);
}
if let Some(power) = gpu.power_usage_watts {
println!("Power Usage: {:.1}W", power);
}
if let Some(ref cc) = gpu.compute_capability {
println!("Compute Capability: {}", cc.to_string());
}
println!("Supported APIs: {:?}", gpu.supported_apis);
println!("Status: {:?}", gpu.status);
println!("Last Updated: {:?}", gpu.last_updated);
}
fn display_gpu_status(gpu: &crate::gpu::GpuInfo) {
println!(
"GPU {}: {} | Util: {:.1}% | Mem: {}/{}MB | Temp: {}°C | Status: {:?}",
gpu.id,
gpu.name,
gpu.utilization_percent,
gpu.memory_used_mb,
gpu.memory_total_mb,
gpu.temperature_celsius.unwrap_or(0.0),
gpu.status
);
}
fn display_metrics(metrics: &[crate::gpu::GpuMetrics], format: OutputFormat) {
match format {
OutputFormat::Table => {
println!(
"{:<4} {:<12} {:<8} {:<8} {:<8} {:<8}",
"GPU", "Time", "GPU%", "Mem%", "Temp°C", "Power W"
);
println!("{:-<60}", "");
for metric in metrics {
println!(
"{:<4} {:<12} {:<8} {:<8} {:<8} {:<8}",
metric.gpu_id,
format!("{:?}", metric.timestamp),
format!("{:.1}", metric.gpu_utilization_percent),
format!("{:.1}", metric.memory_utilization_percent),
format!("{:.1}", metric.temperature_celsius),
format!("{:.1}", metric.power_usage_watts)
);
}
}
OutputFormat::Json => {
println!("{}", serde_json::to_string_pretty(metrics).unwrap());
}
_ => {
println!("Format {:?} not yet implemented", format);
}
}
}
fn display_allocations(
allocations: &HashMap<u32, Vec<crate::gpu::GpuAllocation>>,
format: OutputFormat,
) {
match format {
OutputFormat::Table => {
println!(
"{:<4} {:<15} {:<10} {:<20}",
"GPU", "Model", "Memory MB", "Allocated At"
);
println!("{:-<60}", "");
for (gpu_id, allocs) in allocations {
for alloc in allocs {
println!(
"{:<4} {:<15} {:<10} {:<20}",
gpu_id,
&alloc.model_name[..alloc.model_name.len().min(15)],
alloc.allocated_memory_mb,
format!("{:?}", alloc.allocated_at)
);
}
}
}
OutputFormat::Json => {
println!("{}", serde_json::to_string_pretty(allocations).unwrap());
}
_ => {
println!("Format {:?} not yet implemented", format);
}
}
}
fn display_health_status(status: &HashMap<u32, GpuStatus>, format: OutputFormat, _detailed: bool) {
match format {
OutputFormat::Table => {
println!("{:<4} {:<15}", "GPU", "Health Status");
println!("{:-<25}", "");
for (gpu_id, gpu_status) in status {
println!("{:<4} {:<15}", gpu_id, format!("{:?}", gpu_status));
}
}
OutputFormat::Json => {
println!("{}", serde_json::to_string_pretty(status).unwrap());
}
_ => {
println!("Format {:?} not yet implemented", format);
}
}
}
async fn run_gpu_test(gpu_id: u32, test_type: TestType, duration: u64) -> Result<()> {
println!(
"Running {:?} test on GPU {} for {} seconds...",
test_type, gpu_id, duration
);
for i in 1..=duration {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
if i % 5 == 0 {
println!("Test progress: {}/{} seconds", i, duration);
}
}
println!("Test completed successfully");
println!("Results: GPU {} passed {:?} test", gpu_id, test_type);
Ok(())
}
async fn run_gpu_benchmark(
gpu_id: u32,
bench_type: BenchmarkType,
iterations: u32,
memory_size: u64,
) -> Result<()> {
println!(
"Running {:?} benchmark on GPU {} ({} iterations, {}MB)...",
bench_type, gpu_id, iterations, memory_size
);
for i in 1..=iterations {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
if i % 20 == 0 {
println!("Benchmark progress: {}/{} iterations", i, iterations);
}
}
println!("Benchmark completed!");
println!("Results:");
println!(" Compute Performance: 12.5 TFLOPS");
println!(" Memory Bandwidth: 900 GB/s");
println!(" Average Temperature: 72°C");
println!(" Power Consumption: 245W");
Ok(())
}
pub fn validate_allocation_params(memory_mb: u64, model_name: &str) -> Result<()> {
if memory_mb == 0 {
anyhow::bail!("Memory allocation must be greater than 0 MB");
}
if memory_mb > 100_000 {
anyhow::bail!("Memory allocation cannot exceed 100,000 MB (100 GB)");
}
if model_name.is_empty() {
anyhow::bail!("Model name cannot be empty");
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_allocation_zero_memory() {
let result = validate_allocation_params(0, "test");
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("greater than 0"));
}
#[test]
fn test_validate_allocation_excessive_memory() {
let result = validate_allocation_params(200_000, "test");
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("cannot exceed 100,000")
);
}
#[test]
fn test_validate_allocation_empty_model() {
let result = validate_allocation_params(1000, "");
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("cannot be empty"));
}
#[test]
fn test_validate_allocation_valid() {
let result = validate_allocation_params(1000, "my-model");
assert!(result.is_ok());
}
#[test]
fn test_vendor_arg_conversion() {
assert!(matches!(
GpuVendor::from(VendorArg::Nvidia),
GpuVendor::Nvidia
));
assert!(matches!(GpuVendor::from(VendorArg::Amd), GpuVendor::Amd));
assert!(matches!(
GpuVendor::from(VendorArg::Intel),
GpuVendor::Intel
));
assert!(matches!(
GpuVendor::from(VendorArg::Apple),
GpuVendor::Apple
));
}
#[test]
fn test_power_state_conversion() {
use crate::gpu::GpuPowerState;
assert!(matches!(
GpuPowerState::from(PowerState::Auto),
GpuPowerState::Balanced
));
assert!(matches!(
GpuPowerState::from(PowerState::Performance),
GpuPowerState::Performance
));
assert!(matches!(
GpuPowerState::from(PowerState::Balanced),
GpuPowerState::Balanced
));
assert!(matches!(
GpuPowerState::from(PowerState::PowerSaver),
GpuPowerState::PowerSaver
));
}
}