use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
/// Hosting providers that a workload can be deployed to.
///
/// The human-readable name of each variant comes from its `Display` impl
/// (e.g. `GCP` is shown as "Google Cloud").
///
/// NOTE(review): serde's derive uses the variant identifiers (e.g. `"AWS"`, `"GCP"`)
/// as the serialized form in self-describing formats. Renaming variants (for example to
/// the idiomatic `Aws`/`Gcp`) or reordering them (for index-based formats) is therefore a
/// wire-format break.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CloudProvider {
    AWS,
    GCP,
    Azure,
    DigitalOcean,
    Vultr,
    Hetzner,
    RunPod,
    LambdaLabs,
    PrimeIntellect,
    VastAi,
    Crusoe,
    /// NOTE(review): presumably a catch-all for providers without dedicated support — confirm.
    Generic,
}
impl CloudProvider {
    /// Returns whether this provider is treated as TEE-capable.
    ///
    /// Only AWS, Google Cloud and Azure qualify. The match is exhaustive on purpose:
    /// adding a new variant forces an explicit decision here.
    #[must_use]
    pub fn supports_tee(self) -> bool {
        match self {
            Self::AWS | Self::GCP | Self::Azure => true,
            Self::DigitalOcean
            | Self::Vultr
            | Self::Hetzner
            | Self::RunPod
            | Self::LambdaLabs
            | Self::PrimeIntellect
            | Self::VastAi
            | Self::Crusoe
            | Self::Generic => false,
        }
    }
}
impl std::fmt::Display for CloudProvider {
    /// Writes the human-readable provider name (e.g. "Google Cloud" for `GCP`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::AWS => "AWS",
            Self::GCP => "Google Cloud",
            Self::Azure => "Azure",
            Self::DigitalOcean => "DigitalOcean",
            Self::Vultr => "Vultr",
            Self::Hetzner => "Hetzner",
            Self::RunPod => "RunPod",
            Self::LambdaLabs => "Lambda Labs",
            Self::PrimeIntellect => "Prime Intellect",
            Self::VastAi => "Vast.ai",
            Self::Crusoe => "Crusoe",
            Self::Generic => "Generic",
        };
        // `write_str` emits the label verbatim, exactly like `write!(f, "<literal>")`.
        f.write_str(label)
    }
}
/// Resource requirements of a workload, used to choose a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceSpec {
    /// Requested CPU capacity. Values above 8.0 route to the CPU-intensive provider list.
    /// NOTE(review): unit presumably cores/vCPUs — confirm.
    pub cpu: f32,
    /// Requested memory in GB. Values above 32.0 route to the memory-intensive list.
    pub memory_gb: f32,
    /// Requested storage in GB (not consulted by `ProviderSelector` in this file).
    pub storage_gb: f32,
    /// Number of GPUs requested; `None` means no GPU is requested.
    pub gpu_count: Option<u32>,
    /// Whether spot/preemptible capacity is acceptable.
    /// NOTE(review): not consulted by `ProviderSelector` in this file.
    pub allow_spot: bool,
    /// Whether the workload must run in a TEE. Defaults to `false` when absent from input.
    #[serde(default)]
    pub tee_required: bool,
}
/// Where a workload should be deployed.
///
/// NOTE(review): `ProviderSelector::select_target` in this file only ever produces
/// `CloudInstance`; the other variants are presumably built elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DeploymentTarget {
    /// A VM/instance on the given cloud provider.
    CloudInstance(CloudProvider),
    /// A Kubernetes cluster.
    /// NOTE(review): `context` is presumably a kubeconfig context name — confirm.
    Kubernetes { context: String, namespace: String },
    /// Deploy to `primary`, with a Kubernetes fallback.
    /// NOTE(review): `fallback_k8s` is presumably a kubeconfig context — confirm.
    Hybrid {
        primary: CloudProvider,
        fallback_k8s: String,
    },
}
/// Ordered provider preference lists, one per workload category.
///
/// Order matters: `ProviderSelector` picks the first entry of the matching list as the
/// primary provider. The remaining entries (and other lists) are used for fallbacks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderPreferences {
    /// Used when the workload requests GPUs.
    pub gpu_providers: Vec<CloudProvider>,
    /// Used for high-CPU workloads. Also the fallback source for GPU workloads.
    pub cpu_intensive: Vec<CloudProvider>,
    /// Used for high-memory workloads.
    pub memory_intensive: Vec<CloudProvider>,
    /// Used for all other (standard) workloads.
    pub cost_optimized: Vec<CloudProvider>,
    /// Used when a TEE is required.
    /// NOTE(review): entries are not cross-checked against `CloudProvider::supports_tee`.
    pub tee_capable: Vec<CloudProvider>,
}
impl Default for ProviderPreferences {
    /// Built-in rankings. Each list is ordered most-preferred first.
    fn default() -> Self {
        let gpu_providers = vec![
            CloudProvider::RunPod,
            CloudProvider::LambdaLabs,
            CloudProvider::VastAi,
            CloudProvider::PrimeIntellect,
            CloudProvider::Crusoe,
            CloudProvider::GCP,
            CloudProvider::AWS,
        ];
        let cpu_intensive = vec![
            CloudProvider::Hetzner,
            CloudProvider::Vultr,
            CloudProvider::DigitalOcean,
            CloudProvider::AWS,
        ];
        let memory_intensive = vec![
            CloudProvider::AWS,
            CloudProvider::GCP,
            CloudProvider::Hetzner,
        ];
        let cost_optimized = vec![
            CloudProvider::VastAi,
            CloudProvider::Hetzner,
            CloudProvider::Vultr,
            CloudProvider::DigitalOcean,
        ];
        let tee_capable = vec![CloudProvider::AWS, CloudProvider::GCP, CloudProvider::Azure];

        Self {
            gpu_providers,
            cpu_intensive,
            memory_intensive,
            cost_optimized,
            tee_capable,
        }
    }
}
/// Chooses deployment providers for workloads based on a set of [`ProviderPreferences`].
pub struct ProviderSelector {
    preferences: ProviderPreferences,
}
impl ProviderSelector {
    /// Creates a selector that ranks providers according to `preferences`.
    #[must_use]
    pub fn new(preferences: ProviderPreferences) -> Self {
        Self { preferences }
    }

    /// Creates a selector using [`ProviderPreferences::default`].
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(ProviderPreferences::default())
    }

    /// Chooses where to deploy a workload with the given `requirements`.
    ///
    /// Currently always yields [`DeploymentTarget::CloudInstance`] wrapping the provider
    /// returned by [`Self::select_provider`].
    ///
    /// # Errors
    ///
    /// Returns an error if no provider is configured for the workload's category.
    pub fn select_target(&self, requirements: &ResourceSpec) -> Result<DeploymentTarget> {
        info!(
            "Selecting deployment target for requirements: {:?}",
            requirements
        );
        let provider = self.select_provider(requirements)?;
        Ok(DeploymentTarget::CloudInstance(provider))
    }

    /// Picks the most-preferred provider for the workload category of `requirements`.
    ///
    /// Categories are checked in priority order:
    /// 1. TEE required
    /// 2. at least one GPU
    /// 3. CPU above 8.0
    /// 4. memory above 32 GB
    /// 5. everything else (cost-optimized)
    ///
    /// The first entry of the matching preference list is returned.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Other`] if the matching preference list is empty.
    pub fn select_provider(&self, requirements: &ResourceSpec) -> Result<CloudProvider> {
        let candidates = if requirements.tee_required {
            info!("TEE required, selecting from TEE-capable providers");
            &self.preferences.tee_capable
        } else if Self::needs_gpu(requirements) {
            info!("GPU required, selecting from GPU providers");
            &self.preferences.gpu_providers
        } else if requirements.cpu > 8.0 {
            info!(
                "High CPU requirement ({}), selecting from CPU-intensive providers",
                requirements.cpu
            );
            &self.preferences.cpu_intensive
        } else if requirements.memory_gb > 32.0 {
            info!(
                "High memory requirement ({}GB), selecting from memory-intensive providers",
                requirements.memory_gb
            );
            &self.preferences.memory_intensive
        } else {
            info!("Standard workload, selecting from cost-optimized providers");
            &self.preferences.cost_optimized
        };
        match candidates.first() {
            Some(provider) => {
                info!("Selected provider: {:?}", provider);
                Ok(*provider)
            }
            None => {
                warn!("No providers configured for workload requirements");
                Err(Error::Other(
                    "No providers configured for the given resource requirements".into(),
                ))
            }
        }
    }

    /// Returns alternative providers to try if the primary choice is unavailable.
    ///
    /// The fallback sources depend on the workload:
    /// - TEE workloads fall back only to other TEE-capable providers.
    /// - GPU workloads fall back to the CPU-intensive list.
    /// - All other workloads fall back to the cost-optimized, CPU-intensive and
    ///   memory-intensive lists, in that order.
    ///
    /// The primary provider (from [`Self::select_provider`]) is excluded. Each provider
    /// appears at most once, at the position of its first occurrence.
    pub fn get_fallback_providers(&self, requirements: &ResourceSpec) -> Vec<CloudProvider> {
        let prefs = &self.preferences;
        let sources: Vec<&[CloudProvider]> = if requirements.tee_required {
            vec![prefs.tee_capable.as_slice()]
        } else if Self::needs_gpu(requirements) {
            vec![prefs.cpu_intensive.as_slice()]
        } else {
            vec![
                prefs.cost_optimized.as_slice(),
                prefs.cpu_intensive.as_slice(),
                prefs.memory_intensive.as_slice(),
            ]
        };

        let primary = self.select_provider(requirements).ok();
        // `Vec::dedup` only drops *consecutive* repeats. The concatenated lists repeat
        // providers (e.g. Hetzner) non-adjacently, so track everything already emitted.
        let mut seen = std::collections::HashSet::new();
        let fallbacks: Vec<CloudProvider> = sources
            .into_iter()
            .flatten()
            .copied()
            .filter(|&p| Some(p) != primary && seen.insert(p))
            .collect();

        if requirements.tee_required {
            info!("TEE fallback providers: {:?}", fallbacks);
        } else {
            info!("Fallback providers: {:?}", fallbacks);
        }
        fallbacks
    }

    /// Returns `true` when the workload asks for at least one GPU.
    ///
    /// `Some(0)` is treated like `None`: a spec that explicitly requests zero GPUs is not
    /// a GPU workload and must not be routed to GPU-only providers.
    fn needs_gpu(requirements: &ResourceSpec) -> bool {
        matches!(requirements.gpu_count, Some(count) if count > 0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor so each test states only the inputs it cares about.
    fn spec(
        cpu: f32,
        memory_gb: f32,
        storage_gb: f32,
        gpu_count: Option<u32>,
        allow_spot: bool,
        tee_required: bool,
    ) -> ResourceSpec {
        ResourceSpec {
            cpu,
            memory_gb,
            storage_gb,
            gpu_count,
            allow_spot,
            tee_required,
        }
    }

    #[test]
    fn test_gpu_provider_selection() {
        let chosen = ProviderSelector::with_defaults()
            .select_provider(&spec(4.0, 16.0, 100.0, Some(1), false, false))
            .unwrap();
        assert_eq!(chosen, CloudProvider::RunPod);
    }

    #[test]
    fn test_cpu_intensive_selection() {
        // 16 CPUs exceeds the 8-CPU threshold; 32 GB is not above the memory threshold.
        let chosen = ProviderSelector::with_defaults()
            .select_provider(&spec(16.0, 32.0, 200.0, None, false, false))
            .unwrap();
        assert_eq!(chosen, CloudProvider::Hetzner);
    }

    #[test]
    fn test_cost_optimized_selection() {
        let chosen = ProviderSelector::with_defaults()
            .select_provider(&spec(2.0, 4.0, 20.0, None, true, false))
            .unwrap();
        assert_eq!(chosen, CloudProvider::VastAi);
    }

    #[test]
    fn test_fallback_providers() {
        let alternatives = ProviderSelector::with_defaults()
            .get_fallback_providers(&spec(4.0, 16.0, 100.0, Some(1), false, false));
        assert!(alternatives.contains(&CloudProvider::Vultr));
        assert!(alternatives.contains(&CloudProvider::DigitalOcean));
        assert!(!alternatives.contains(&CloudProvider::GCP));
    }

    #[test]
    fn test_tee_provider_selection() {
        let chosen = ProviderSelector::with_defaults()
            .select_provider(&spec(2.0, 8.0, 40.0, None, false, true))
            .unwrap();
        assert_eq!(chosen, CloudProvider::AWS);
        assert!(chosen.supports_tee());
    }
}