/// Device on which a part of the model is placed.
///
/// The type is `Copy` and comparable with `==`, so it can be passed by value
/// and tested directly against a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DevicePlacement {
/// Placed on the GPU.
Gpu,
/// Placed on the CPU.
Cpu,
}
/// Records a [`DevicePlacement`] for every layer, plus two standalone
/// placements (`embed_placement` and `lm_head_placement`).
#[derive(Debug, Clone)]
pub struct LayerDeviceMap {
/// One placement per layer; position `i` holds the placement of layer `i`.
/// Kept private: read it through the methods on this type.
placements: Vec<DevicePlacement>,
/// Placement for the "embed" component.
/// NOTE(review): presumably the token-embedding weights; confirm against the caller.
pub embed_placement: DevicePlacement,
/// Placement for the "lm_head" component.
/// NOTE(review): presumably the output-projection (LM head) weights; confirm against the caller.
pub lm_head_placement: DevicePlacement,
}
impl LayerDeviceMap {
    /// Creates a map with `num_layers` layers, all placed on the GPU.
    ///
    /// `embed_placement` and `lm_head_placement` are also set to
    /// [`DevicePlacement::Gpu`].
    pub fn all_gpu(num_layers: usize) -> Self {
        // Asking for as many GPU layers as there are layers puts every
        // index below the GPU cutoff.
        Self::with_gpu_layers(num_layers, num_layers)
    }

    /// Creates a map with `num_layers` layers in which the first `gpu_layers`
    /// layers are placed on the GPU and all remaining layers on the CPU.
    ///
    /// If `gpu_layers` is greater than or equal to `num_layers`, every layer
    /// ends up on the GPU. `embed_placement` and `lm_head_placement` are always
    /// set to [`DevicePlacement::Gpu`], regardless of `gpu_layers`.
    pub fn with_gpu_layers(num_layers: usize, gpu_layers: usize) -> Self {
        let placements = (0..num_layers)
            .map(|layer| {
                if layer < gpu_layers {
                    DevicePlacement::Gpu
                } else {
                    DevicePlacement::Cpu
                }
            })
            .collect();

        Self {
            placements,
            embed_placement: DevicePlacement::Gpu,
            lm_head_placement: DevicePlacement::Gpu,
        }
    }

    /// Returns the placement of the layer at `layer_idx`.
    ///
    /// An index past the last layer does not panic; it yields
    /// [`DevicePlacement::Cpu`].
    pub fn placement(&self, layer_idx: usize) -> DevicePlacement {
        match self.placements.get(layer_idx) {
            Some(&device) => device,
            None => DevicePlacement::Cpu,
        }
    }

    /// Returns how many layers are placed on the GPU.
    pub fn gpu_layer_count(&self) -> usize {
        self.placements
            .iter()
            .filter(|device| matches!(device, DevicePlacement::Gpu))
            .count()
    }

    /// Returns how many layers are placed on the CPU.
    pub fn cpu_layer_count(&self) -> usize {
        self.placements
            .iter()
            .filter(|device| matches!(device, DevicePlacement::Cpu))
            .count()
    }
}