Skip to main content

ai_hwaccel/hardware/
neuron.rs

1//! AWS Neuron chip types (Inferentia / Trainium).
2
3use std::fmt;
4
5use serde::{Deserialize, Serialize};
6
/// AWS Neuron chip type.
///
/// Identifies which family of AWS Neuron accelerator a device belongs to.
/// The type is a plain `Copy` tag: it can be compared, hashed (so it is
/// usable as a map/set key) and (de)serialised through serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NeuronChipType {
    /// Inference-optimised (inf1, inf2).
    Inferentia,
    /// Training-optimised (trn1).
    Trainium,
}
15
16impl NeuronChipType {
17    /// HBM per NeuronCore in bytes.
18    pub fn hbm_per_core_bytes(&self) -> u64 {
19        match self {
20            // inf2 NeuronCore-v2: 32 GB HBM per accelerator (2 cores share it)
21            Self::Inferentia => 16 * 1024 * 1024 * 1024,
22            // trn1 NeuronCore-v2: 32 GB HBM per accelerator
23            Self::Trainium => 32 * 1024 * 1024 * 1024,
24        }
25    }
26}
27
28impl fmt::Display for NeuronChipType {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            Self::Inferentia => write!(f, "Inferentia"),
32            Self::Trainium => write!(f, "Trainium"),
33        }
34    }
35}