#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::CudaRustError;
use super::config::NutanixConfig;
use super::vgpu_scheduler::WorkloadRequest;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CloudProvider {
Aws,
Azure,
Gcp,
OnPrem,
}
impl std::fmt::Display for CloudProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CloudProvider::Aws => write!(f, "AWS"),
CloudProvider::Azure => write!(f, "Azure"),
CloudProvider::Gcp => write!(f, "GCP"),
CloudProvider::OnPrem => write!(f, "On-Prem"),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ClusterStatus {
Running,
Provisioning,
Updating,
Error,
Stopped,
}
impl std::fmt::Display for ClusterStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ClusterStatus::Running => write!(f, "RUNNING"),
ClusterStatus::Provisioning => write!(f, "PROVISIONING"),
ClusterStatus::Updating => write!(f, "UPDATING"),
ClusterStatus::Error => write!(f, "ERROR"),
ClusterStatus::Stopped => write!(f, "STOPPED"),
}
}
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Nc2Cluster {
pub cluster_id: String,
pub name: String,
pub provider: CloudProvider,
pub region: String,
pub gpu_types: Vec<String>,
pub status: ClusterStatus,
pub gpu_node_count: u32,
pub total_gpu_memory_bytes: u64,
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct WorkloadPlacement {
pub primary_cluster: String,
pub failover_cluster: Option<String>,
pub placement_reason: String,
pub estimated_cost: CostEstimate,
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CostEstimate {
pub hourly_cost: f64,
pub monthly_cost: f64,
pub provider: CloudProvider,
pub instance_type: String,
pub gpu_type: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MigrationStatus {
Initiated,
Transferring,
Restarting,
Completed,
Failed(String),
}
impl std::fmt::Display for MigrationStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MigrationStatus::Initiated => write!(f, "INITIATED"),
MigrationStatus::Transferring => write!(f, "TRANSFERRING"),
MigrationStatus::Restarting => write!(f, "RESTARTING"),
MigrationStatus::Completed => write!(f, "COMPLETED"),
MigrationStatus::Failed(reason) => write!(f, "FAILED: {}", reason),
}
}
}
pub struct Nc2Client {
#[allow(dead_code)]
config: NutanixConfig,
#[cfg(feature = "nutanix")]
#[allow(dead_code)]
client: reqwest::Client,
}
impl Nc2Client {
pub fn new(config: NutanixConfig) -> Result<Self, CudaRustError> {
#[cfg(feature = "nutanix")]
{
let builder = reqwest::Client::builder().timeout(config.timeout);
let client = builder.build().map_err(|e| {
CudaRustError::RuntimeError(format!("Failed to create HTTP client: {}", e))
})?;
Ok(Self { config, client })
}
#[cfg(not(feature = "nutanix"))]
{
Ok(Self { config })
}
}
pub async fn discover_nc2_clusters(&self) -> Result<Vec<Nc2Cluster>, CudaRustError> {
#[cfg(feature = "nutanix")]
{
Err(CudaRustError::RuntimeError(
"NC2 cluster discovery requires Prism Central connection".to_string(),
))
}
#[cfg(not(feature = "nutanix"))]
{
Ok(self.local_nc2_clusters())
}
}
pub async fn find_optimal_placement(
&self,
workload: &WorkloadRequest,
) -> Result<WorkloadPlacement, CudaRustError> {
let clusters = self.discover_nc2_clusters().await?;
let running_clusters: Vec<&Nc2Cluster> = clusters
.iter()
.filter(|c| c.status == ClusterStatus::Running)
.collect();
if running_clusters.is_empty() {
return Err(CudaRustError::RuntimeError(
"No running NC2 clusters available for placement".to_string(),
));
}
let suitable: Vec<&Nc2Cluster> = running_clusters
.iter()
.filter(|c| c.total_gpu_memory_bytes >= workload.min_gpu_memory)
.copied()
.collect();
if suitable.is_empty() {
return Err(CudaRustError::RuntimeError(format!(
"No NC2 cluster has sufficient GPU memory for workload '{}' ({} bytes required)",
workload.name, workload.min_gpu_memory
)));
}
let primary = suitable
.iter()
.min_by(|a, b| {
let cost_a = self.provider_cost_factor(&a.provider);
let cost_b = self.provider_cost_factor(&b.provider);
cost_a
.partial_cmp(&cost_b)
.unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap();
let failover = suitable
.iter()
.find(|c| c.cluster_id != primary.cluster_id)
.map(|c| c.cluster_id.clone());
let gpu_type = primary
.gpu_types
.first()
.cloned()
.unwrap_or_else(|| "unknown".to_string());
let cost = self.estimate_cost(primary.provider.clone(), &gpu_type, 730);
let reason = format!(
"Selected {} ({}) for lowest cost ({:.2}/hr); {} GPU memory available",
primary.name,
primary.provider,
cost.hourly_cost,
format_bytes(primary.total_gpu_memory_bytes),
);
Ok(WorkloadPlacement {
primary_cluster: primary.cluster_id.clone(),
failover_cluster: failover,
placement_reason: reason,
estimated_cost: cost,
})
}
pub fn estimate_cost(
&self,
provider: CloudProvider,
gpu_type: &str,
hours: u32,
) -> CostEstimate {
let (hourly, instance_type) = match (&provider, gpu_type) {
(CloudProvider::Aws, t) if t.contains("A100") => (3.40, "p4d.24xlarge"),
(CloudProvider::Aws, t) if t.contains("H100") => (6.50, "p5.48xlarge"),
(CloudProvider::Aws, _) => (2.10, "g5.xlarge"),
(CloudProvider::Azure, t) if t.contains("A100") => (3.67, "Standard_NC96ads_A100_v4"),
(CloudProvider::Azure, t) if t.contains("H100") => (7.00, "Standard_ND96isr_H100_v5"),
(CloudProvider::Azure, _) => (2.30, "Standard_NC6s_v3"),
(CloudProvider::Gcp, t) if t.contains("A100") => (3.22, "a2-highgpu-1g"),
(CloudProvider::Gcp, t) if t.contains("H100") => (6.20, "a3-highgpu-1g"),
(CloudProvider::Gcp, _) => (1.90, "n1-standard-4-t4"),
(CloudProvider::OnPrem, _) => (0.50, "bare-metal"),
};
CostEstimate {
hourly_cost: hourly,
monthly_cost: hourly * hours as f64,
provider,
instance_type: instance_type.to_string(),
gpu_type: gpu_type.to_string(),
}
}
pub async fn migrate_workload(
&self,
from: &str,
to: &str,
workload_id: &str,
) -> Result<MigrationStatus, CudaRustError> {
#[cfg(feature = "nutanix")]
{
let _ = (from, to, workload_id);
Err(CudaRustError::RuntimeError(
"Workload migration requires Prism Central connection".to_string(),
))
}
#[cfg(not(feature = "nutanix"))]
{
self.local_migrate(from, to, workload_id)
}
}
fn provider_cost_factor(&self, provider: &CloudProvider) -> f64 {
match provider {
CloudProvider::OnPrem => 0.5,
CloudProvider::Gcp => 1.0,
CloudProvider::Aws => 1.1,
CloudProvider::Azure => 1.2,
}
}
#[cfg(not(feature = "nutanix"))]
fn local_nc2_clusters(&self) -> Vec<Nc2Cluster> {
let mut clusters = Vec::new();
if let Ok(output) = std::process::Command::new("curl")
.args([
"-s",
"--connect-timeout",
"1",
"http://169.254.169.254/latest/meta-data/instance-id",
])
.output()
{
if output.status.success() && !output.stdout.is_empty() {
let instance_id =
String::from_utf8_lossy(&output.stdout).to_string();
if !instance_id.is_empty()
&& !instance_id.contains("<!DOCTYPE")
{
let region = std::process::Command::new("curl")
.args([
"-s",
"--connect-timeout",
"1",
"http://169.254.169.254/latest/meta-data/placement/region",
])
.output()
.ok()
.and_then(|o| {
if o.status.success() {
String::from_utf8(o.stdout).ok()
} else {
None
}
})
.unwrap_or_else(|| "unknown".to_string());
clusters.push(Nc2Cluster {
cluster_id: format!("nc2-aws-{}", ®ion),
name: format!("NC2-AWS-{}", region),
provider: CloudProvider::Aws,
region,
gpu_types: vec!["Detected".to_string()],
status: ClusterStatus::Running,
gpu_node_count: 1,
total_gpu_memory_bytes: 0,
});
}
}
}
if clusters.is_empty() {
let has_nvidia = std::process::Command::new("nvidia-smi")
.arg("--list-gpus")
.output()
.map(|o| o.status.success())
.unwrap_or(false);
let has_rocm = std::path::Path::new("/opt/rocm").exists();
if has_nvidia || has_rocm {
clusters.push(Nc2Cluster {
cluster_id: "local-onprem".to_string(),
name: "Local GPU Cluster".to_string(),
provider: CloudProvider::OnPrem,
region: "local".to_string(),
gpu_types: vec!["Local".to_string()],
status: ClusterStatus::Running,
gpu_node_count: 1,
total_gpu_memory_bytes: 0,
});
}
}
clusters
}
#[cfg(not(feature = "nutanix"))]
fn local_migrate(
&self,
from: &str,
to: &str,
workload_id: &str,
) -> Result<MigrationStatus, CudaRustError> {
if from == to {
return Err(CudaRustError::RuntimeError(
"Source and destination clusters must be different".to_string(),
));
}
if workload_id.is_empty() {
return Err(CudaRustError::RuntimeError(
"Workload ID must not be empty".to_string(),
));
}
Ok(MigrationStatus::Initiated)
}
}
fn format_bytes(bytes: u64) -> String {
if bytes >= 1024 * 1024 * 1024 * 1024 {
format!("{:.1} TB", bytes as f64 / (1024.0 * 1024.0 * 1024.0 * 1024.0))
} else if bytes >= 1024 * 1024 * 1024 {
format!("{:.1} GB", bytes as f64 / (1024.0 * 1024.0 * 1024.0))
} else if bytes >= 1024 * 1024 {
format!("{:.1} MB", bytes as f64 / (1024.0 * 1024.0))
} else {
format!("{} B", bytes)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::nutanix::config::GpuVendor;
fn make_client() -> Nc2Client {
let config = NutanixConfig::new("https://prism.example.com:9440", "test-key");
Nc2Client::new(config).unwrap()
}
#[test]
fn test_cloud_provider_display() {
assert_eq!(CloudProvider::Aws.to_string(), "AWS");
assert_eq!(CloudProvider::Azure.to_string(), "Azure");
assert_eq!(CloudProvider::Gcp.to_string(), "GCP");
assert_eq!(CloudProvider::OnPrem.to_string(), "On-Prem");
}
#[test]
fn test_cluster_status_display() {
assert_eq!(ClusterStatus::Running.to_string(), "RUNNING");
assert_eq!(ClusterStatus::Stopped.to_string(), "STOPPED");
assert_eq!(ClusterStatus::Error.to_string(), "ERROR");
}
#[test]
fn test_migration_status_display() {
assert_eq!(MigrationStatus::Initiated.to_string(), "INITIATED");
assert_eq!(MigrationStatus::Completed.to_string(), "COMPLETED");
assert_eq!(
MigrationStatus::Failed("timeout".into()).to_string(),
"FAILED: timeout"
);
}
#[test]
fn test_estimate_cost_aws_a100() {
let client = make_client();
let cost = client.estimate_cost(CloudProvider::Aws, "A100", 100);
assert_eq!(cost.hourly_cost, 3.40);
assert!((cost.monthly_cost - 340.0).abs() < 0.01);
assert_eq!(cost.provider, CloudProvider::Aws);
assert_eq!(cost.instance_type, "p4d.24xlarge");
}
#[test]
fn test_estimate_cost_onprem() {
let client = make_client();
let cost = client.estimate_cost(CloudProvider::OnPrem, "A100", 730);
assert!(cost.hourly_cost < 1.0, "On-prem should be cheapest");
assert_eq!(cost.instance_type, "bare-metal");
}
#[tokio::test]
async fn test_local_discover_nc2_clusters() {
let client = make_client();
let clusters = client.discover_nc2_clusters().await.unwrap();
for cluster in &clusters {
assert!(!cluster.cluster_id.is_empty());
assert!(!cluster.name.is_empty());
assert_eq!(cluster.status, ClusterStatus::Running);
}
}
#[tokio::test]
async fn test_local_find_optimal_placement() {
let client = make_client();
let workload = WorkloadRequest::new("test-job", 0)
.with_vendor(GpuVendor::Nvidia);
let result = client.find_optimal_placement(&workload).await;
if let Ok(placement) = result {
assert!(!placement.primary_cluster.is_empty());
}
}
#[tokio::test]
async fn test_local_find_optimal_placement_insufficient_memory() {
let client = make_client();
let workload =
WorkloadRequest::new("huge-job", 2048 * 1024 * 1024 * 1024);
let result = client.find_optimal_placement(&workload).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_local_migrate_workload_success() {
let client = make_client();
let status = client
.migrate_workload("cluster-a", "cluster-b", "workload-123")
.await
.unwrap();
assert_eq!(status, MigrationStatus::Initiated);
}
#[tokio::test]
async fn test_local_migrate_workload_same_cluster_error() {
let client = make_client();
let result = client
.migrate_workload("cluster-a", "cluster-a", "workload-123")
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_local_migrate_workload_empty_id_error() {
let client = make_client();
let result = client
.migrate_workload("cluster-a", "cluster-b", "")
.await;
assert!(result.is_err());
}
#[test]
fn test_format_bytes() {
assert_eq!(format_bytes(1024 * 1024 * 1024), "1.0 GB");
assert_eq!(format_bytes(80 * 1024 * 1024 * 1024), "80.0 GB");
assert_eq!(format_bytes(1024 * 1024), "1.0 MB");
assert_eq!(format_bytes(500), "500 B");
}
#[test]
fn test_nc2_client_creation() {
let config = NutanixConfig::new("https://prism.example.com:9440", "key");
let client = Nc2Client::new(config);
assert!(client.is_ok());
}
}