use std::fmt;
use serde::{Deserialize, Serialize};
use crate::hardware::AcceleratorType;
use crate::profile::AcceleratorProfile;
/// The kind of accelerator hardware a workload needs in order to run.
///
/// Check a requirement against a concrete device description with
/// [`AcceleratorRequirement::satisfied_by`]. The default value is
/// [`AcceleratorRequirement::None`].
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[non_exhaustive]
pub enum AcceleratorRequirement {
    /// No accelerator is needed; any *available* profile satisfies this.
    #[default]
    None,
    /// A profile whose accelerator reports `is_gpu()`.
    Gpu,
    /// A TPU profile exposing at least `min_chips` chips.
    Tpu {
        /// Minimum number of TPU chips (inclusive) the profile must report.
        min_chips: u32,
    },
    /// A profile whose accelerator is the `Gaudi` variant.
    Gaudi,
    /// A profile whose accelerator is the `AwsNeuron` variant.
    AwsNeuron,
    /// Either a GPU or a TPU (of any chip count).
    GpuOrTpu,
    /// Any accelerator type other than `AcceleratorType::Cpu`.
    AnyAccelerator,
}
impl fmt::Display for AcceleratorRequirement {
    /// Writes a short, lowercase, hyphenated label for the requirement
    /// (e.g. `gpu`, `aws-neuron`, `tpu(4+ chips)`).
    ///
    /// All output goes through [`fmt::Formatter::pad`] rather than `write!`, so
    /// width, fill, alignment and precision flags (e.g. `{:<16}`, `{:>8}`,
    /// `{:.3}`) are honoured exactly as they are for `&str`. With no flags the
    /// output is identical to a plain `write!` of the same label.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::None => f.pad("none"),
            Self::Gpu => f.pad("gpu"),
            // The only label with a runtime component: build it first so that
            // padding applies to the whole string, not just a fragment.
            Self::Tpu { min_chips } => f.pad(&format!("tpu({min_chips}+ chips)")),
            Self::Gaudi => f.pad("gaudi"),
            Self::AwsNeuron => f.pad("aws-neuron"),
            Self::GpuOrTpu => f.pad("gpu-or-tpu"),
            Self::AnyAccelerator => f.pad("any-accelerator"),
        }
    }
}
impl AcceleratorRequirement {
    /// Returns `true` when `profile` can fulfil this requirement.
    ///
    /// A profile whose `available` flag is `false` never satisfies anything,
    /// not even [`AcceleratorRequirement::None`]. For an available profile:
    ///
    /// * `None` always matches.
    /// * `Gpu` matches when the accelerator reports `is_gpu()`.
    /// * `Tpu { min_chips }` matches a TPU whose `chip_count >= min_chips`.
    /// * `Gaudi` / `AwsNeuron` match their same-named accelerator variant.
    /// * `GpuOrTpu` matches when `is_gpu()` or `is_tpu()` holds.
    /// * `AnyAccelerator` matches everything except `AcceleratorType::Cpu`.
    #[must_use]
    #[inline]
    pub fn satisfied_by(&self, profile: &AcceleratorProfile) -> bool {
        let device = &profile.accelerator;

        profile.available
            && match self {
                Self::None => true,
                Self::Gpu => device.is_gpu(),
                // Chip count is an inclusive lower bound.
                Self::Tpu { min_chips } => matches!(
                    device,
                    AcceleratorType::Tpu { chip_count, .. } if *chip_count >= *min_chips
                ),
                Self::Gaudi => matches!(device, AcceleratorType::Gaudi { .. }),
                Self::AwsNeuron => matches!(device, AcceleratorType::AwsNeuron { .. }),
                Self::GpuOrTpu => device.is_gpu() || device.is_tpu(),
                Self::AnyAccelerator => !matches!(device, AcceleratorType::Cpu),
            }
    }
}