use std::sync::OnceLock;
use crate::quantization::QuantizationLevel;
use crate::registry::AcceleratorRegistry;
/// Raw pricing document, embedded into the binary at compile time from
/// `../data/cloud_pricing.json` (path relative to this source file).
const PRICING_JSON: &str = include_str!("../data/cloud_pricing.json");
/// Cache for the parsed instance list; filled on first use by `all_instances`.
static PARSED_INSTANCES: OnceLock<Vec<CloudGpuInstance>> = OnceLock::new();
/// One GPU instance entry from the embedded pricing data.
///
/// Values are deserialized from the `instances` array of the pricing
/// document by [`all_instances`]; every field is required by the derive.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CloudGpuInstance {
    /// Instance name (presumably the provider's instance-type identifier —
    /// confirm against the data file).
    pub name: String,
    /// Provider identifier. [`recommend_instance`] compares it with the
    /// strings produced by `CloudProvider::as_str` (`"aws"`, `"gcp"`, `"azure"`).
    pub provider: String,
    /// GPU model label; not interpreted anywhere in this file.
    pub gpu: String,
    /// Number of GPUs (not read in this file; meaning taken from the name).
    pub gpu_count: u32,
    /// Memory per GPU (not read in this file; unit presumably matches
    /// `total_gpu_memory_gb` — TODO confirm).
    pub gpu_memory_gb: u32,
    /// Total GPU memory of the instance. [`recommend_instance`] compares this
    /// value against the required bytes divided by `crate::units::BYTES_PER_GIB`.
    pub total_gpu_memory_gb: u32,
    /// vCPU count (not read in this file).
    pub vcpus: u32,
    /// System RAM (not read in this file; unit presumably GB — TODO confirm).
    pub ram_gb: u32,
    /// Interconnect description; free-form text as far as this file is concerned.
    pub interconnect: String,
    /// Hourly price, used as the sort key by [`recommend_instance`].
    /// NOTE(review): currency is not stated here — confirm against the data file.
    pub price_per_hour: f64,
}
/// Providers that an instance lookup can be restricted to.
///
/// Marked `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CloudProvider {
    /// Selects instances whose `provider` string is `"aws"`.
    Aws,
    /// Selects instances whose `provider` string is `"gcp"`.
    Gcp,
    /// Selects instances whose `provider` string is `"azure"`.
    Azure,
}

impl CloudProvider {
    /// Lowercase identifier compared against `CloudGpuInstance::provider`.
    #[inline]
    fn as_str(self) -> &'static str {
        match self {
            CloudProvider::Aws => "aws",
            CloudProvider::Gcp => "gcp",
            CloudProvider::Azure => "azure",
        }
    }
}
/// A candidate instance together with the memory figures it was matched on.
///
/// Produced by [`recommend_instance`] and [`cheapest_instance`].
#[derive(Debug, Clone)]
pub struct InstanceRecommendation {
    /// Copy of the matching entry from [`all_instances`].
    pub instance: CloudGpuInstance,
    /// Memory requirement as returned by `AcceleratorRegistry::estimate_memory`.
    pub memory_required_bytes: u64,
    /// Share of the instance's total GPU memory left unused by the
    /// requirement, expressed as a percentage.
    pub memory_headroom_pct: f64,
}
/// Returns every instance described by the embedded pricing document.
///
/// The document is parsed on the first call and the result is cached in
/// `PARSED_INSTANCES`; later calls return the cached slice. If the document
/// cannot be parsed, or has no `instances` key, the slice is empty.
#[must_use]
pub fn all_instances() -> &'static [CloudGpuInstance] {
    /// Top-level shape of the pricing document; only `instances` is read.
    #[derive(serde::Deserialize)]
    struct PricingData {
        #[serde(default)]
        instances: Vec<CloudGpuInstance>,
    }

    let cached = PARSED_INSTANCES.get_or_init(|| {
        match serde_json::from_str::<PricingData>(PRICING_JSON) {
            Ok(document) => document.instances,
            // A malformed document degrades to an empty catalogue, not a panic.
            Err(_) => Vec::new(),
        }
    });
    cached.as_slice()
}
#[must_use]
pub fn recommend_instance(
model_params: u64,
quant: &QuantizationLevel,
provider: Option<CloudProvider>,
) -> Vec<InstanceRecommendation> {
let needed = AcceleratorRegistry::estimate_memory(model_params, quant);
let needed_gb = (needed as f64) / crate::units::BYTES_PER_GIB;
let mut candidates: Vec<InstanceRecommendation> = all_instances()
.iter()
.filter(|inst| {
if let Some(p) = provider
&& inst.provider != p.as_str()
{
return false;
}
inst.total_gpu_memory_gb as f64 >= needed_gb
})
.map(|inst| {
let headroom = (inst.total_gpu_memory_gb as f64 - needed_gb)
/ inst.total_gpu_memory_gb as f64
* 100.0;
InstanceRecommendation {
instance: inst.clone(),
memory_required_bytes: needed,
memory_headroom_pct: headroom,
}
})
.collect();
candidates.sort_by(|a, b| {
a.instance
.price_per_hour
.partial_cmp(&b.instance.price_per_hour)
.unwrap_or(std::cmp::Ordering::Equal)
});
candidates
}
/// Returns the lowest-priced instance that can hold the model, if any.
///
/// Takes the same arguments as [`recommend_instance`] and yields the first
/// entry of its price-ordered result, or `None` when no instance qualifies.
#[must_use]
pub fn cheapest_instance(
    model_params: u64,
    quant: &QuantizationLevel,
    provider: Option<CloudProvider>,
) -> Option<InstanceRecommendation> {
    let mut ranked = recommend_instance(model_params, quant, provider).into_iter();
    ranked.next()
}