1use std::{collections::HashMap, fmt::Debug, sync::Arc};
2
3use crate::{
4 pipeline::AutoDeviceMapParams, utils::debug::DeviceRepr, MemoryUsage, Topology, TryIntoDType,
5};
6use hanzo_ml::{DType, Device, DeviceLocation, Result, Tensor};
7use hanzo_quant::log::once_log_info;
8use hanzo_quant::ShardedVarBuilder;
9use serde::{Deserialize, Serialize};
10use tracing::info;
11
/// One entry of a manual device map: how many repeating layers go on the
/// device with a given ordinal.
#[derive(Debug, Default, Deserialize, Serialize, Clone)]
pub struct DeviceLayerMapMetadata {
    /// Ordinal of the target device; `DeviceMapSetting::into_mapper` compares
    /// it against the `gpu_id` of CUDA/Metal device locations.
    pub ordinal: usize,
    /// Number of consecutive repeating layers assigned to that device.
    pub layers: usize,
}
17
/// Describes how the repeating layers of a model are spread over devices.
///
/// `DeviceMapSetting::into_mapper` turns a setting into a concrete
/// [`DeviceMapper`].
#[derive(Debug, Clone)]
pub enum DeviceMapSetting {
    /// Manual mapping given as per-device layer counts.
    Map(DeviceMapMetadata),
    /// Automatic mapping parameters. `into_mapper` rejects this variant with
    /// an error; it has to be converted to `Map` first.
    Auto(AutoDeviceMapParams),
    /// Builds the same mapper type as `Nccl`, but without a communicator.
    DummyNccl { nm_device: Device },
    /// NCCL-parallelized pipeline using the given communicator.
    Nccl {
        nm_device: Device,
        comm: Arc<hanzo_quant::Comm>,
    },
}
32
/// Manual device-map description: per-device layer counts plus an optional
/// number of layers kept on the host (CPU).
#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceMapMetadata {
    /// Per-device layer counts; `None` means "no device mapping"
    /// (`into_mapper` then builds a `DummyDeviceMapper`).
    device_layers: Option<Vec<DeviceLayerMapMetadata>>,
    /// Number of layers placed on the CPU; when `None`, `into_mapper` uses
    /// whatever is left after the device layers.
    host_layers: Option<usize>,
}
39
40impl DeviceMapMetadata {
41 pub fn from_num_device_layers(device_layers: Vec<DeviceLayerMapMetadata>) -> Self {
42 Self {
43 device_layers: Some(device_layers),
44 host_layers: None,
45 }
46 }
47 pub fn dummy() -> Self {
49 Self {
50 device_layers: None,
51 host_layers: None,
52 }
53 }
54
55 pub fn device_layers(&self) -> Option<&[DeviceLayerMapMetadata]> {
56 self.device_layers.as_deref()
57 }
58
59 pub fn host_layers(&self) -> Option<usize> {
60 self.host_layers
61 }
62
63 pub fn to_cli_spec(&self) -> Option<String> {
64 let layers = self.device_layers.as_ref()?;
65 if layers.is_empty() {
66 return None;
67 }
68 Some(
69 layers
70 .iter()
71 .map(|l| format!("{}:{}", l.ordinal, l.layers))
72 .collect::<Vec<_>>()
73 .join(";"),
74 )
75 }
76}
77
78impl DeviceMapSetting {
79 pub fn dummy() -> Self {
81 Self::Map(DeviceMapMetadata::dummy())
82 }
83 pub fn into_mapper(
84 &self,
85 model_layers: usize,
86 device: &Device,
87 topology: Option<&Topology>,
88 all_devices: &[Device],
89 ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
90 match self {
91 Self::Nccl { nm_device, comm } => {
92 once_log_info("Loading model using a NCCL-parallelized pipeline.");
93 Ok(Box::new(NcclDeviceMapper {
94 nm_device: nm_device.clone(),
95 model_layers,
96 comm: Some(comm.clone()),
97 }))
98 }
99
100 Self::DummyNccl { nm_device } => {
101 once_log_info("Loading model using a NCCL-parallelized pipeline.");
102 Ok(Box::new(NcclDeviceMapper {
103 nm_device: nm_device.clone(),
104 model_layers,
105 comm: None,
106 }))
107 }
108
109 Self::Map(DeviceMapMetadata {
110 device_layers,
111 host_layers,
112 }) => {
113 if let Some(topology) = topology {
114 if topology.layers.iter().all(|x| x.is_none()) {
115 return Ok(Box::new(DummyDeviceMapper {
116 nm_device: device.clone(),
117 }));
118 } else {
119 let layers = topology
120 .layers
121 .iter()
122 .map(|layer| {
123 layer
124 .as_ref()
125 .map(|x| x.device.clone().unwrap_or(device.clone()))
126 .unwrap_or(device.clone())
127 })
128 .collect::<Vec<_>>();
129
130 info!("Loading model according to the following repeating layer mappings based on topology:");
131 for (i, dev) in layers.iter().enumerate() {
132 info!("Layer {i}: {}", dev.device_pretty_repr());
133 }
134
135 return Ok(Box::new(LayerDeviceMapper {
136 mappings: layers,
137 nm_device: device.clone(),
138 }));
139 }
140 }
141
142 let n_device_layers = if let Some(layers) = &device_layers {
145 layers
146 .iter()
147 .map(|metadata| metadata.layers)
148 .sum::<usize>()
149 .clamp(0, model_layers)
150 } else {
151 return Ok(Box::new(DummyDeviceMapper {
152 nm_device: device.clone(),
153 }));
154 };
155 let n_host_layers =
158 host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
159 if n_device_layers + n_host_layers != model_layers {
160 hanzo_ml::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
161 }
162 once_log_info(format!("Model has {model_layers} repeating layers."));
163
164 let mut combined = Vec::with_capacity(model_layers);
166 if device_layers
167 .as_ref()
168 .is_some_and(|layers| layers.len() == 1)
169 {
170 combined.extend(vec![device.clone(); n_device_layers]);
171 } else {
172 let original_seed = if !device.is_cpu() {
173 Some(device.get_current_seed()?)
174 } else {
175 None
176 };
177 for DeviceLayerMapMetadata { ordinal, layers } in
178 device_layers.as_ref().unwrap()
179 {
180 let dev = match device.location() {
181 DeviceLocation::Cpu => Device::Cpu,
182 #[cfg(feature = "rocm")]
183 DeviceLocation::Rocm { .. } => device.clone(),
184 #[cfg(feature = "vulkan")]
185 DeviceLocation::Vulkan { .. } => device.clone(),
186 DeviceLocation::Cuda { gpu_id: device_ord } => {
187 if device_ord == *ordinal {
188 device.clone()
189 } else {
190 let cuda_device = all_devices
191 .iter()
192 .filter(|d| d.is_cuda())
193 .map(|d| {
194 let ordinal = match d.location() {
196 DeviceLocation::Cpu => 0,
197 DeviceLocation::Cuda { gpu_id } => gpu_id,
198 DeviceLocation::Metal { gpu_id } => gpu_id,
199 #[cfg(feature = "rocm")]
200 DeviceLocation::Rocm { gpu_id } => gpu_id,
201 #[cfg(feature = "vulkan")]
202 DeviceLocation::Vulkan { gpu_id } => gpu_id,
203 };
204 (d.clone(), ordinal)
205 })
206 .find(|(_, other_device_ordinal)| {
207 other_device_ordinal == ordinal
208 });
209
210 if let Some((device, _)) = cuda_device {
211 device
212 } else {
213 hanzo_ml::bail!(
214 "Could not find cuda device with ordinal {}",
215 ordinal
216 )
217 }
218 }
219 }
220 DeviceLocation::Metal { gpu_id: device_ord } => {
221 if device_ord == *ordinal {
222 device.clone()
223 } else {
224 Device::new_metal(*ordinal)?
225 }
226 }
227 };
228 if !device.is_cpu() {
229 dev.set_seed(original_seed.unwrap())?;
230 }
231 combined.extend(vec![dev; *layers]);
232 }
233 }
234
235 combined.extend(vec![Device::Cpu; n_host_layers]);
237
238 assert_eq!(combined.len(), model_layers);
240
241 {
243 once_log_info(
244 "Loading model according to the following repeating layer mappings:",
245 );
246 let mut start_index = 0;
247 let mut current_dev = &combined[0];
248
249 for (i, variant) in combined.iter().enumerate().skip(1) {
251 if !variant.same_device(current_dev) {
253 once_log_info(format!(
254 "Layers {}-{}: {} ({} GB)",
255 start_index,
256 i - 1,
257 current_dev.device_pretty_repr(),
258 MemoryUsage
259 .query(current_dev)?
260 .total()
261 .div_ceil(1024 * 1024 * 1024),
262 ));
263 start_index = i; current_dev = variant;
265 }
266 }
267
268 once_log_info(format!(
269 "Layers {}-{}: {} ({} GB)",
270 start_index,
271 combined.len() - 1,
272 current_dev.device_pretty_repr(),
273 MemoryUsage
274 .query(current_dev)?
275 .total()
276 .div_ceil(1024 * 1024 * 1024),
277 ));
278 }
279
280 Ok(Box::new(LayerDeviceMapper {
281 mappings: combined,
282 nm_device: device.clone(),
283 }))
284 }
285 Self::Auto(_) => {
286 hanzo_ml::bail!(".into_mapper does not work on Auto device map, convert it to a Map with DeviceMappedModelLoader::get_device_layers")
287 }
288 }
289 }
290}
291
/// Decides on which device each model layer (and everything that is not a
/// repeating layer, the "non-mapped" part) lives.
///
/// The `loading_isq` flag is passed through by callers; every implementation
/// in this file switches its device choice on it (see the individual
/// implementations for the exact behaviour).
pub trait DeviceMapper: Debug {
    /// Prepares `input` for the layer with index `layer`. Mappers that assign
    /// layers to devices move the tensor to that layer's device; the others
    /// return it unchanged.
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;

    /// Returns `varbuilder`, possibly retargeted to the device chosen for
    /// `layer`.
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder;
    /// Device associated with `layer`, if any.
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
    /// All distinct devices this mapper uses.
    fn get_unique_devices(&self) -> Vec<Device>;
    /// Moves `x` to the non-mapped device, or to the CPU when `loading_isq`
    /// is set (this is what every implementation in this file does).
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
    /// Returns `varbuilder`, possibly retargeted to the non-mapped device.
    fn set_nm_device(&self, varbuilder: ShardedVarBuilder, loading_isq: bool) -> ShardedVarBuilder;
    /// Number of layers this mapper covers.
    fn num_device_mapping_layers(&self) -> usize;
    /// Communicator to use for the layer with index `layer_idx`.
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>>;

    /// Resolves `dtype` against the devices used by this mapper.
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
}
319
/// Mapper that assigns each repeating layer its own device.
#[derive(Debug)]
pub struct LayerDeviceMapper {
    /// Device for each layer, indexed by layer number.
    mappings: Vec<Device>,
    /// Device used for everything that is not a mapped layer.
    nm_device: Device,
}
326
327impl DeviceMapper for LayerDeviceMapper {
328 fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
329 input.to_device(&self.mappings[layer])
330 }
331 fn set_device<'a>(
332 &self,
333 layer: usize,
334 varbuilder: ShardedVarBuilder,
335 loading_isq: bool,
336 ) -> ShardedVarBuilder {
337 if loading_isq {
338 return varbuilder;
339 }
340 varbuilder.set_device(self.mappings[layer].clone())
341 }
342 fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
343 if loading_isq {
344 return Some(&self.nm_device);
345 }
346 self.mappings.get(layer)
347 }
348 fn get_unique_devices(&self) -> Vec<Device> {
349 self.mappings.iter().fold(Vec::new(), |mut acc, device| {
350 if !acc.iter().any(|d| d.same_device(device)) {
351 acc.push(device.clone());
352 }
353 acc
354 })
355 }
356 fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
357 if loading_isq {
358 x.to_device(&Device::Cpu)
359 } else {
360 x.to_device(&self.nm_device)
361 }
362 }
363 fn set_nm_device<'a>(
364 &self,
365 varbuilder: ShardedVarBuilder,
366 loading_isq: bool,
367 ) -> ShardedVarBuilder {
368 if loading_isq {
369 varbuilder
370 } else {
371 varbuilder.set_device(self.nm_device.clone())
372 }
373 }
374 fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
375 dtype
376 .try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
377 .map_err(hanzo_ml::Error::msg)
378 }
379 fn num_device_mapping_layers(&self) -> usize {
380 self.mappings.len()
381 }
382 fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>> {
383 let id = hanzo_quant::Id::new();
384 Ok(Arc::new(hanzo_quant::Comm::from_device(
385 id,
386 self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
387 0,
388 1,
389 )?))
390 }
391}
392
/// Mapper that performs no per-layer mapping: everything uses one device.
#[derive(Debug)]
pub struct DummyDeviceMapper {
    /// The single device used for all layers and non-mapped tensors.
    pub(crate) nm_device: Device,
}
397
398impl DeviceMapper for DummyDeviceMapper {
399 fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
400 Ok(input)
401 }
402 fn set_device<'a>(
403 &self,
404 _: usize,
405 varbuilder: ShardedVarBuilder,
406 loading_isq: bool,
407 ) -> ShardedVarBuilder {
408 if loading_isq {
409 varbuilder.set_device(Device::Cpu)
410 } else {
411 varbuilder.set_device(self.nm_device.clone())
412 }
413 }
414 fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
415 Some(&self.nm_device)
416 }
417 fn get_unique_devices(&self) -> Vec<Device> {
418 vec![self.nm_device.clone()]
419 }
420 fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
421 if loading_isq {
422 x.to_device(&Device::Cpu)
423 } else {
424 x.to_device(&self.nm_device)
425 }
426 }
427 fn set_nm_device<'a>(
428 &self,
429 varbuilder: ShardedVarBuilder,
430 loading_isq: bool,
431 ) -> ShardedVarBuilder {
432 if loading_isq {
433 varbuilder.set_device(Device::Cpu)
434 } else {
435 varbuilder.set_device(self.nm_device.clone())
436 }
437 }
438 fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
439 dtype
440 .try_into_dtype(&[&self.nm_device])
441 .map_err(hanzo_ml::Error::msg)
442 }
443 fn num_device_mapping_layers(&self) -> usize {
444 1
446 }
447 fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>> {
448 let id = hanzo_quant::Id::new();
449 Ok(Arc::new(hanzo_quant::Comm::from_device(
450 id,
451 self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
452 0,
453 1,
454 )?))
455 }
456}
457
/// Mapper used for the NCCL-parallelized pipeline: all layers share one
/// device and, optionally, one communicator.
#[derive(Debug)]
pub struct NcclDeviceMapper {
    /// Device used for all layers and non-mapped tensors.
    nm_device: Device,
    /// Number of repeating layers in the model.
    model_layers: usize,
    /// Shared communicator; when `None`, `get_comm_for` creates a new one on
    /// each call.
    comm: Option<Arc<hanzo_quant::Comm>>,
}
464
465impl DeviceMapper for NcclDeviceMapper {
466 fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
467 Ok(input)
468 }
469 fn set_device<'a>(
470 &self,
471 _: usize,
472 varbuilder: ShardedVarBuilder,
473 loading_isq: bool,
474 ) -> ShardedVarBuilder {
475 if loading_isq {
476 varbuilder.set_device(Device::Cpu)
477 } else {
478 varbuilder.set_device(self.nm_device.clone())
479 }
480 }
481 fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
482 Some(&self.nm_device)
483 }
484 fn get_unique_devices(&self) -> Vec<Device> {
485 vec![self.nm_device.clone()]
486 }
487 fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
488 if loading_isq {
489 x.to_device(&Device::Cpu)
490 } else {
491 x.to_device(&self.nm_device)
492 }
493 }
494 fn set_nm_device<'a>(
495 &self,
496 varbuilder: ShardedVarBuilder,
497 loading_isq: bool,
498 ) -> ShardedVarBuilder {
499 if loading_isq {
500 varbuilder.set_device(Device::Cpu)
501 } else {
502 varbuilder.set_device(self.nm_device.clone())
503 }
504 }
505 fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
506 dtype
507 .try_into_dtype(&[&self.nm_device])
508 .map_err(hanzo_ml::Error::msg)
509 }
510 fn num_device_mapping_layers(&self) -> usize {
511 self.model_layers
512 }
513 fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>> {
514 if let Some(comm) = &self.comm {
515 Ok(comm.clone())
516 } else {
517 let id = hanzo_quant::Id::new();
518 Ok(Arc::new(hanzo_quant::Comm::from_device(
519 id,
520 self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
521 0,
522 1,
523 )?))
524 }
525 }
526}
527
/// Mapper that pairs every layer with its own communicator and device.
#[derive(Debug)]
#[allow(dead_code)]
pub struct NcclPipelineParallelMapper {
    /// `(communicator, device)` for each layer, indexed by layer number.
    mappings: Vec<(Arc<hanzo_quant::Comm>, Device)>,
    /// Device used for everything that is not a mapped layer.
    nm_device: Device,
}
535
536impl DeviceMapper for NcclPipelineParallelMapper {
537 fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
538 input.to_device(&self.mappings[layer].1)
539 }
540 fn set_device<'a>(
541 &self,
542 layer: usize,
543 varbuilder: ShardedVarBuilder,
544 loading_isq: bool,
545 ) -> ShardedVarBuilder {
546 if loading_isq {
547 return varbuilder;
548 }
549 varbuilder.set_device(self.mappings[layer].1.clone())
550 }
551 fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
552 if loading_isq {
553 return Some(&self.nm_device);
554 }
555 self.mappings.get(layer).map(|(_, x)| x)
556 }
557 fn get_unique_devices(&self) -> Vec<Device> {
558 self.mappings
559 .iter()
560 .fold(Vec::new(), |mut acc, (_, device)| {
561 if !acc.iter().any(|d| d.same_device(device)) {
562 acc.push(device.clone());
563 }
564 acc
565 })
566 }
567 fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
568 if loading_isq {
569 x.to_device(&Device::Cpu)
570 } else {
571 x.to_device(&self.nm_device)
572 }
573 }
574 fn set_nm_device<'a>(
575 &self,
576 varbuilder: ShardedVarBuilder,
577 loading_isq: bool,
578 ) -> ShardedVarBuilder {
579 if loading_isq {
580 varbuilder
581 } else {
582 varbuilder.set_device(self.nm_device.clone())
583 }
584 }
585 fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
586 dtype
587 .try_into_dtype(&self.mappings.iter().map(|(_, x)| x).collect::<Vec<_>>())
588 .map_err(hanzo_ml::Error::msg)
589 }
590 fn num_device_mapping_layers(&self) -> usize {
591 self.mappings.len()
592 }
593 fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>> {
594 Ok(self.mappings[layer_idx].0.clone())
595 }
596}
597
/// An attention mask made available on several devices, so it can be looked
/// up by device without another transfer.
pub enum DeviceMappedMask {
    /// No mask.
    None,
    /// Mirrors `AttentionMask::CausalFlash`; carries no tensor.
    CausalFlash,
    /// One copy of a custom mask tensor per device location.
    Custom(HashMap<DeviceLocation, Tensor>),
}
612
613impl DeviceMappedMask {
614 pub fn new(mask: crate::attention::AttentionMask, mapper: &dyn DeviceMapper) -> Result<Self> {
616 match mask {
617 crate::attention::AttentionMask::None => Ok(Self::None),
618 crate::attention::AttentionMask::CausalFlash => Ok(Self::CausalFlash),
619 crate::attention::AttentionMask::Custom(tensor) => {
620 let mut masks = HashMap::new();
621 for device in mapper.get_unique_devices() {
622 let loc = device.location();
623 if let std::collections::hash_map::Entry::Vacant(e) = masks.entry(loc) {
624 e.insert(tensor.to_device(&device)?);
625 }
626 }
627 Ok(Self::Custom(masks))
628 }
629 }
630 }
631
632 pub fn from_single(mask: crate::attention::AttentionMask) -> Self {
634 match mask {
635 crate::attention::AttentionMask::None => Self::None,
636 crate::attention::AttentionMask::CausalFlash => Self::CausalFlash,
637 crate::attention::AttentionMask::Custom(tensor) => {
638 let mut masks = HashMap::new();
639 masks.insert(tensor.device().location(), tensor);
640 Self::Custom(masks)
641 }
642 }
643 }
644
645 pub fn get(&self, device: &Device) -> crate::attention::AttentionMask {
647 match self {
648 Self::None => crate::attention::AttentionMask::None,
649 Self::CausalFlash => crate::attention::AttentionMask::CausalFlash,
650 Self::Custom(masks) => {
651 let tensor = masks
652 .get(&device.location())
653 .expect("DeviceMappedMask: device not in mapper's unique devices");
654 crate::attention::AttentionMask::Custom(tensor.clone())
655 }
656 }
657 }
658}
659
660pub fn get_all_similar_devices(base: &Device) -> Result<Vec<Device>> {
662 let mut devices = Vec::new();
663 match base {
664 Device::Cpu => return Ok(vec![Device::Cpu]),
665 #[cfg(feature = "rocm")]
666 Device::Rocm(_) => return Ok(vec![base.clone()]),
667 #[cfg(feature = "vulkan")]
668 Device::Vulkan(_) => return Ok(vec![base.clone()]),
669 Device::Cuda(_) => {
670 let mut ord = 0;
671 let DeviceLocation::Cuda { gpu_id: base_ord } = base.location() else {
672 hanzo_ml::bail!("location and device do not match");
673 };
674 loop {
675 if base_ord == ord {
676 devices.push(base.clone());
677 ord += 1;
678 continue;
679 }
680 if let Ok(dev) = Device::new_cuda(ord) {
682 devices.push(dev);
683 ord += 1;
684 } else {
685 break;
686 }
687 }
688 }
689 #[cfg(not(feature = "metal"))]
690 Device::Metal(_) => {
691 hanzo_ml::bail!("Not compiled with metal features, but have a metal device.");
692 }
693 #[cfg(feature = "metal")]
694 Device::Metal(_) => {
695 #[cfg(feature = "metal")]
696 let total_ords = hanzo_metal_kernels::metal::Device::all().len();
697 #[cfg(not(feature = "metal"))]
698 let total_ords = 0;
699 let mut ord = 0;
700 let DeviceLocation::Metal { gpu_id: base_ord } = base.location() else {
701 hanzo_ml::bail!("location and device do not match");
702 };
703 loop {
704 if base_ord == ord {
705 devices.push(base.clone());
706 ord += 1;
707 continue;
708 }
709 if total_ords == ord {
710 break;
711 }
712 if let Ok(dev) = Device::new_metal(ord) {
713 devices.push(dev);
714 ord += 1;
715 } else {
716 break;
717 }
718 }
719 }
720 }
721 Ok(devices)
722}