use std::fmt;
use serde::{Deserialize, Serialize};
/// Identifies which kind of Neuron chip a value refers to.
///
/// The type is a plain, field-less enum: it is `Copy`, comparable, hashable,
/// and derives serde's `Serialize`/`Deserialize`.
///
/// NOTE(review): the variant names presumably correspond to the AWS
/// Inferentia and Trainium accelerator families — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NeuronChipType {
/// The "Inferentia" chip kind.
Inferentia,
/// The "Trainium" chip kind.
Trainium,
}
impl NeuronChipType {
    /// Returns the per-core HBM size for this chip type, expressed in bytes.
    ///
    /// The value is a fixed constant per variant: 16 GiB for `Inferentia`
    /// and 32 GiB for `Trainium` (1 GiB = 1024 * 1024 * 1024 bytes).
    pub fn hbm_per_core_bytes(&self) -> u64 {
        /// Number of bytes in one gibibyte (2^30).
        const BYTES_PER_GIB: u64 = 1 << 30;

        // Pick the capacity in whole GiB first, then scale to bytes once.
        let capacity_gib: u64 = match *self {
            Self::Inferentia => 16,
            Self::Trainium => 32,
        };

        capacity_gib * BYTES_PER_GIB
    }
}
impl fmt::Display for NeuronChipType {
    /// Writes the variant's name (`"Inferentia"` or `"Trainium"`) to `f`.
    ///
    /// The text is written verbatim; width, fill, and alignment flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Inferentia => "Inferentia",
            Self::Trainium => "Trainium",
        };
        f.write_str(label)
    }
}